mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-14 07:35:41 +08:00
Compare commits
8 Commits
main
...
enable-to-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
661e1e9f23 | ||
|
|
dae88ad606 | ||
|
|
ce0572701e | ||
|
|
5d074af1a3 | ||
|
|
a3f33a1968 | ||
|
|
33169a17d7 | ||
|
|
a88fdbb031 | ||
|
|
1e2735e775 |
@@ -346,8 +346,6 @@
|
||||
title: Flux2Transformer2DModel
|
||||
- local: api/models/flux_transformer
|
||||
title: FluxTransformer2DModel
|
||||
- local: api/models/glm_image_transformer2d
|
||||
title: GlmImageTransformer2DModel
|
||||
- local: api/models/hidream_image_transformer
|
||||
title: HiDreamImageTransformer2DModel
|
||||
- local: api/models/hunyuan_transformer2d
|
||||
@@ -542,8 +540,6 @@
|
||||
title: Flux2
|
||||
- local: api/pipelines/control_flux_inpaint
|
||||
title: FluxControlInpaint
|
||||
- local: api/pipelines/glm_image
|
||||
title: GLM-Image
|
||||
- local: api/pipelines/hidream
|
||||
title: HiDream-I1
|
||||
- local: api/pipelines/hunyuandit
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# GlmImageTransformer2DModel
|
||||
|
||||
A Diffusion Transformer model for 2D data from [GlmImageTransformer2DModel] (TODO).
|
||||
|
||||
## GlmImageTransformer2DModel
|
||||
|
||||
[[autodoc]] GlmImageTransformer2DModel
|
||||
@@ -1,95 +0,0 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
-->
|
||||
|
||||
# GLM-Image
|
||||
|
||||
## Overview
|
||||
|
||||
GLM-Image is an image generation model adopts a hybrid autoregressive + diffusion decoder architecture, effectively pushing the upper bound of visual fidelity and fine-grained details. In general image generation quality, it aligns with industry-standard LDM-based approaches, while demonstrating significant advantages in knowledge-intensive image generation scenarios.
|
||||
|
||||
Model architecture: a hybrid autoregressive + diffusion decoder design、
|
||||
|
||||
+ Autoregressive generator: a 9B-parameter model initialized from [GLM-4-9B-0414](https://huggingface.co/zai-org/GLM-4-9B-0414), with an expanded vocabulary to incorporate visual tokens. The model first generates a compact encoding of approximately 256 tokens, then expands to 1K–4K tokens, corresponding to 1K–2K high-resolution image outputs. You can check AR model in class `GlmImageForConditionalGeneration` of `transformers` library.
|
||||
+ Diffusion Decoder: a 7B-parameter decoder based on a single-stream DiT architecture for latent-space image decoding. It is equipped with a Glyph Encoder text module, significantly improving accurate text rendering within images.
|
||||
|
||||
Post-training with decoupled reinforcement learning: the model introduces a fine-grained, modular feedback strategy using the GRPO algorithm, substantially enhancing both semantic understanding and visual detail quality.
|
||||
|
||||
+ Autoregressive module: provides low-frequency feedback signals focused on aesthetics and semantic alignment, improving instruction following and artistic expressiveness.
|
||||
+ Decoder module: delivers high-frequency feedback targeting detail fidelity and text accuracy, resulting in highly realistic textures, lighting, and color reproduction, as well as more precise text rendering.
|
||||
|
||||
GLM-Image supports both text-to-image and image-to-image generation within a single model
|
||||
|
||||
+ Text-to-image: generates high-detail images from textual descriptions, with particularly strong performance in information-dense scenarios.
|
||||
+ Image-to-image: supports a wide range of tasks, including image editing, style transfer, multi-subject consistency, and identity-preserving generation for people and objects.
|
||||
|
||||
This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The codebase can be found [here](https://huggingface.co/zai-org/GLM-Image).
|
||||
|
||||
## Usage examples
|
||||
|
||||
### Text to Image Generation
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers.pipelines.glm_image import GlmImagePipeline
|
||||
|
||||
pipe = GlmImagePipeline.from_pretrained("zai-org/GLM-Image",torch_dtype=torch.bfloat16,device_map="cuda")
|
||||
prompt = "A beautifully designed modern food magazine style dessert recipe illustration, themed around a raspberry mousse cake. The overall layout is clean and bright, divided into four main areas: the top left features a bold black title 'Raspberry Mousse Cake Recipe Guide', with a soft-lit close-up photo of the finished cake on the right, showcasing a light pink cake adorned with fresh raspberries and mint leaves; the bottom left contains an ingredient list section, titled 'Ingredients' in a simple font, listing 'Flour 150g', 'Eggs 3', 'Sugar 120g', 'Raspberry puree 200g', 'Gelatin sheets 10g', 'Whipping cream 300ml', and 'Fresh raspberries', each accompanied by minimalist line icons (like a flour bag, eggs, sugar jar, etc.); the bottom right displays four equally sized step boxes, each containing high-definition macro photos and corresponding instructions, arranged from top to bottom as follows: Step 1 shows a whisk whipping white foam (with the instruction 'Whip egg whites to stiff peaks'), Step 2 shows a red-and-white mixture being folded with a spatula (with the instruction 'Gently fold in the puree and batter'), Step 3 shows pink liquid being poured into a round mold (with the instruction 'Pour into mold and chill for 4 hours'), Step 4 shows the finished cake decorated with raspberries and mint leaves (with the instruction 'Decorate with raspberries and mint'); a light brown information bar runs along the bottom edge, with icons on the left representing 'Preparation time: 30 minutes', 'Cooking time: 20 minutes', and 'Servings: 8'. The overall color scheme is dominated by creamy white and light pink, with a subtle paper texture in the background, featuring compact and orderly text and image layout with clear information hierarchy."
|
||||
image = pipe(
|
||||
prompt=prompt,
|
||||
height=32 * 32,
|
||||
width=36 * 32,
|
||||
num_inference_steps=30,
|
||||
guidance_scale=1.5,
|
||||
generator=torch.Generator(device="cuda").manual_seed(42),
|
||||
).images[0]
|
||||
|
||||
image.save("output_t2i.png")
|
||||
```
|
||||
|
||||
### Image to Image Generation
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers.pipelines.glm_image import GlmImagePipeline
|
||||
from PIL import Image
|
||||
|
||||
pipe = GlmImagePipeline.from_pretrained("zai-org/GLM-Image",torch_dtype=torch.bfloat16,device_map="cuda")
|
||||
image_path = "cond.jpg"
|
||||
prompt = "Replace the background of the snow forest with an underground station featuring an automatic escalator."
|
||||
image = Image.open(image_path).convert("RGB")
|
||||
image = pipe(
|
||||
prompt=prompt,
|
||||
image=[image], # can input multiple images for multi-image-to-image generation such as [image, image1]
|
||||
height=33 * 32,
|
||||
width=32 * 32,
|
||||
num_inference_steps=30,
|
||||
guidance_scale=1.5,
|
||||
generator=torch.Generator(device="cuda").manual_seed(42),
|
||||
).images[0]
|
||||
|
||||
image.save("output_i2i.png")
|
||||
```
|
||||
|
||||
+ Since the AR model used in GLM-Image is configured with `do_sample=True` and a temperature of `0.95` by default, the generated images can vary significantly across runs. We do not recommend setting do_sample=False, as this may lead to incorrect or degenerate outputs from the AR model.
|
||||
|
||||
## GlmImagePipeline
|
||||
|
||||
[[autodoc]] pipelines.glm_image.pipeline_glm_image.GlmImagePipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## GlmImagePipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.glm_image.pipeline_output.GlmImagePipelineOutput
|
||||
@@ -314,6 +314,25 @@ Pass the [`ContextParallelConfig`] to [`~ModelMixin.enable_parallelism`].
|
||||
pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2))
|
||||
```
|
||||
|
||||
### parallel_config
|
||||
|
||||
Pass `parallel_config` during model initialization to enable context parallelism.
|
||||
|
||||
```py
|
||||
CKPT_ID = "black-forest-labs/FLUX.1-dev"
|
||||
|
||||
cp_config = ContextParallelConfig(ring_degree=2)
|
||||
transformer = AutoModel.from_pretrained(
|
||||
CKPT_ID,
|
||||
subfolder="transformer",
|
||||
torch_dtype=torch.bfloat16,
|
||||
parallel_config=cp_config
|
||||
)
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
CKPT_ID, transformer=transformer, torch_dtype=torch.bfloat16,
|
||||
).to(device)
|
||||
```
|
||||
### Unified Attention
|
||||
|
||||
[Unified Sequence Parallelism](https://huggingface.co/papers/2405.07719) combines Ring Attention and Ulysses Attention into a single approach for efficient long-sequence processing. It applies Ulysses's *all-to-all* communication first to redistribute heads and sequence tokens, then uses Ring Attention to process the redistributed data, and finally reverses the *all-to-all* to restore the original layout.
|
||||
@@ -341,24 +360,4 @@ We ran a benchmark with Ulysess, Ring, and Unified Attention with [this script](
|
||||
| ring | 13076.492 | 3.82 | 56.02 |
|
||||
| unified_balanced | 11068.705 | 4.52 | 33.85 |
|
||||
|
||||
From the above table, it's clear that Ulysses provides better throughput, but the number of devices it can use remains limited to the number of attention heads, a limitation that is solved by unified attention.
|
||||
|
||||
### parallel_config
|
||||
|
||||
Pass `parallel_config` during model initialization to enable context parallelism.
|
||||
|
||||
```py
|
||||
CKPT_ID = "black-forest-labs/FLUX.1-dev"
|
||||
|
||||
cp_config = ContextParallelConfig(ring_degree=2)
|
||||
transformer = AutoModel.from_pretrained(
|
||||
CKPT_ID,
|
||||
subfolder="transformer",
|
||||
torch_dtype=torch.bfloat16,
|
||||
parallel_config=cp_config
|
||||
)
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
CKPT_ID, transformer=transformer, torch_dtype=torch.bfloat16,
|
||||
).to(device)
|
||||
```
|
||||
From the above table, it's clear that Ulysses provides better throughput, but the number of devices it can use remains limited to number of attention-heads, a limitation that is solved by unified attention.
|
||||
|
||||
@@ -23,7 +23,6 @@ from .utils import (
|
||||
is_torchao_available,
|
||||
is_torchsde_available,
|
||||
is_transformers_available,
|
||||
is_transformers_version,
|
||||
)
|
||||
|
||||
|
||||
@@ -226,7 +225,6 @@ else:
|
||||
"FluxControlNetModel",
|
||||
"FluxMultiControlNetModel",
|
||||
"FluxTransformer2DModel",
|
||||
"GlmImageTransformer2DModel",
|
||||
"HiDreamImageTransformer2DModel",
|
||||
"HunyuanDiT2DControlNetModel",
|
||||
"HunyuanDiT2DModel",
|
||||
@@ -494,7 +492,6 @@ else:
|
||||
"FluxKontextPipeline",
|
||||
"FluxPipeline",
|
||||
"FluxPriorReduxPipeline",
|
||||
"GlmImagePipeline",
|
||||
"HiDreamImagePipeline",
|
||||
"HunyuanDiTControlNetPipeline",
|
||||
"HunyuanDiTPAGPipeline",
|
||||
@@ -982,7 +979,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
FluxControlNetModel,
|
||||
FluxMultiControlNetModel,
|
||||
FluxTransformer2DModel,
|
||||
GlmImageTransformer2DModel,
|
||||
HiDreamImageTransformer2DModel,
|
||||
HunyuanDiT2DControlNetModel,
|
||||
HunyuanDiT2DModel,
|
||||
@@ -1220,7 +1216,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
FluxKontextPipeline,
|
||||
FluxPipeline,
|
||||
FluxPriorReduxPipeline,
|
||||
GlmImagePipeline,
|
||||
HiDreamImagePipeline,
|
||||
HunyuanDiTControlNetPipeline,
|
||||
HunyuanDiTPAGPipeline,
|
||||
|
||||
@@ -98,7 +98,6 @@ if is_torch_available():
|
||||
_import_structure["transformers.transformer_easyanimate"] = ["EasyAnimateTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_flux2"] = ["Flux2Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_glm_image"] = ["GlmImageTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_hunyuan_video15"] = ["HunyuanVideo15Transformer3DModel"]
|
||||
@@ -209,7 +208,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
EasyAnimateTransformer3DModel,
|
||||
Flux2Transformer2DModel,
|
||||
FluxTransformer2DModel,
|
||||
GlmImageTransformer2DModel,
|
||||
HiDreamImageTransformer2DModel,
|
||||
HunyuanDiT2DModel,
|
||||
HunyuanImageTransformer2DModel,
|
||||
|
||||
@@ -1360,12 +1360,12 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
|
||||
# Checks if the model has been loaded in 4-bit or 8-bit with BNB
|
||||
if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
|
||||
if getattr(self, "is_loaded_in_8bit", False):
|
||||
if getattr(self, "is_loaded_in_8bit", False) and is_bitsandbytes_version("<", "0.48.0"):
|
||||
raise ValueError(
|
||||
"Calling `cuda()` is not supported for `8-bit` quantized models. "
|
||||
" Please use the model as it is, since the model has already been set to the correct devices."
|
||||
"Calling `cuda()` is not supported for `8-bit` quantized models with the installed version of bitsandbytes. "
|
||||
f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.48.0."
|
||||
)
|
||||
elif is_bitsandbytes_version("<", "0.43.2"):
|
||||
elif getattr(self, "is_loaded_in_4bit", False) and is_bitsandbytes_version("<", "0.43.2"):
|
||||
raise ValueError(
|
||||
"Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
|
||||
f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
|
||||
@@ -1412,17 +1412,16 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
)
|
||||
|
||||
if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
|
||||
if getattr(self, "is_loaded_in_8bit", False):
|
||||
if getattr(self, "is_loaded_in_8bit", False) and is_bitsandbytes_version("<", "0.48.0"):
|
||||
raise ValueError(
|
||||
"`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the"
|
||||
" model has already been set to the correct devices and casted to the correct `dtype`."
|
||||
"Calling `to()` is not supported for `8-bit` quantized models with the installed version of bitsandbytes. "
|
||||
f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.48.0."
|
||||
)
|
||||
elif is_bitsandbytes_version("<", "0.43.2"):
|
||||
elif getattr(self, "is_loaded_in_4bit", False) and is_bitsandbytes_version("<", "0.43.2"):
|
||||
raise ValueError(
|
||||
"Calling `to()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
|
||||
f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
|
||||
)
|
||||
|
||||
if _is_group_offload_enabled(self) and device_arg_or_kwarg_present:
|
||||
logger.warning(
|
||||
f"The module '{self.__class__.__name__}' is group offloaded and moving it using `.to()` is not supported."
|
||||
|
||||
@@ -27,7 +27,6 @@ if is_torch_available():
|
||||
from .transformer_easyanimate import EasyAnimateTransformer3DModel
|
||||
from .transformer_flux import FluxTransformer2DModel
|
||||
from .transformer_flux2 import Flux2Transformer2DModel
|
||||
from .transformer_glm_image import GlmImageTransformer2DModel
|
||||
from .transformer_hidream_image import HiDreamImageTransformer2DModel
|
||||
from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
|
||||
from .transformer_hunyuan_video15 import HunyuanVideo15Transformer3DModel
|
||||
|
||||
@@ -1,621 +0,0 @@
|
||||
# Copyright 2025 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..attention_processor import Attention
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import LayerNorm, RMSNorm
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class GlmImageCombinedTimestepSizeEmbeddings(nn.Module):
|
||||
def __init__(self, embedding_dim: int, condition_dim: int, pooled_projection_dim: int, timesteps_dim: int = 256):
|
||||
super().__init__()
|
||||
|
||||
self.time_proj = Timesteps(num_channels=timesteps_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.condition_proj = Timesteps(num_channels=condition_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=timesteps_dim, time_embed_dim=embedding_dim)
|
||||
self.condition_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
timestep: torch.Tensor,
|
||||
target_size: torch.Tensor,
|
||||
crop_coords: torch.Tensor,
|
||||
hidden_dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
|
||||
crop_coords_proj = self.condition_proj(crop_coords.flatten()).view(crop_coords.size(0), -1)
|
||||
target_size_proj = self.condition_proj(target_size.flatten()).view(target_size.size(0), -1)
|
||||
|
||||
# (B, 2 * condition_dim)
|
||||
condition_proj = torch.cat([crop_coords_proj, target_size_proj], dim=1)
|
||||
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (B, embedding_dim)
|
||||
condition_emb = self.condition_embedder(condition_proj.to(dtype=hidden_dtype)) # (B, embedding_dim)
|
||||
|
||||
conditioning = timesteps_emb + condition_emb
|
||||
conditioning = F.silu(conditioning)
|
||||
|
||||
return conditioning
|
||||
|
||||
|
||||
class GlmImageImageProjector(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 16,
|
||||
hidden_size: int = 2560,
|
||||
patch_size: int = 2,
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
|
||||
self.proj = nn.Linear(in_channels * patch_size**2, hidden_size)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
post_patch_height = height // self.patch_size
|
||||
post_patch_width = width // self.patch_size
|
||||
|
||||
hidden_states = hidden_states.reshape(
|
||||
batch_size, channel, post_patch_height, self.patch_size, post_patch_width, self.patch_size
|
||||
)
|
||||
hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5).flatten(3, 5).flatten(1, 2)
|
||||
hidden_states = self.proj(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class GlmImageAdaLayerNormZero(nn.Module):
|
||||
def __init__(self, embedding_dim: int, dim: int) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
|
||||
self.norm_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
|
||||
self.linear = nn.Linear(embedding_dim, 12 * dim, bias=True)
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
dtype = hidden_states.dtype
|
||||
norm_hidden_states = self.norm(hidden_states).to(dtype=dtype)
|
||||
norm_encoder_hidden_states = self.norm_context(encoder_hidden_states).to(dtype=dtype)
|
||||
|
||||
emb = self.linear(temb)
|
||||
(
|
||||
shift_msa,
|
||||
c_shift_msa,
|
||||
scale_msa,
|
||||
c_scale_msa,
|
||||
gate_msa,
|
||||
c_gate_msa,
|
||||
shift_mlp,
|
||||
c_shift_mlp,
|
||||
scale_mlp,
|
||||
c_scale_mlp,
|
||||
gate_mlp,
|
||||
c_gate_mlp,
|
||||
) = emb.chunk(12, dim=1)
|
||||
|
||||
hidden_states = norm_hidden_states * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
|
||||
encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_msa.unsqueeze(1)) + c_shift_msa.unsqueeze(1)
|
||||
|
||||
return (
|
||||
hidden_states,
|
||||
gate_msa,
|
||||
shift_mlp,
|
||||
scale_mlp,
|
||||
gate_mlp,
|
||||
encoder_hidden_states,
|
||||
c_gate_msa,
|
||||
c_shift_mlp,
|
||||
c_scale_mlp,
|
||||
c_gate_mlp,
|
||||
)
|
||||
|
||||
|
||||
class GlmImageLayerKVCache:
|
||||
"""KV cache for GlmImage model."""
|
||||
|
||||
def __init__(self):
|
||||
self.k_cache = None
|
||||
self.v_cache = None
|
||||
self.mode: Optional[str] = None # "write", "read", "skip"
|
||||
|
||||
def store(self, k: torch.Tensor, v: torch.Tensor):
|
||||
if self.k_cache is None:
|
||||
self.k_cache = k
|
||||
self.v_cache = v
|
||||
else:
|
||||
self.k_cache = torch.cat([self.k_cache, k], dim=1)
|
||||
self.v_cache = torch.cat([self.v_cache, v], dim=1)
|
||||
|
||||
def get(self, k: torch.Tensor, v: torch.Tensor):
|
||||
if self.k_cache.shape[0] != k.shape[0]:
|
||||
k_cache_expanded = self.k_cache.expand(k.shape[0], -1, -1, -1)
|
||||
v_cache_expanded = self.v_cache.expand(v.shape[0], -1, -1, -1)
|
||||
else:
|
||||
k_cache_expanded = self.k_cache
|
||||
v_cache_expanded = self.v_cache
|
||||
|
||||
k_cache = torch.cat([k_cache_expanded, k], dim=1)
|
||||
v_cache = torch.cat([v_cache_expanded, v], dim=1)
|
||||
return k_cache, v_cache
|
||||
|
||||
def clear(self):
|
||||
self.k_cache = None
|
||||
self.v_cache = None
|
||||
self.mode = None
|
||||
|
||||
|
||||
class GlmImageKVCache:
|
||||
"""Container for all layers' KV caches."""
|
||||
|
||||
def __init__(self, num_layers: int):
|
||||
self.num_layers = num_layers
|
||||
self.caches = [GlmImageLayerKVCache() for _ in range(num_layers)]
|
||||
|
||||
def __getitem__(self, layer_idx: int) -> GlmImageLayerKVCache:
|
||||
return self.caches[layer_idx]
|
||||
|
||||
def set_mode(self, mode: Optional[str]):
|
||||
if mode is not None and mode not in ["write", "read", "skip"]:
|
||||
raise ValueError(f"Invalid mode: {mode}, must be one of 'write', 'read', 'skip'")
|
||||
for cache in self.caches:
|
||||
cache.mode = mode
|
||||
|
||||
def clear(self):
|
||||
for cache in self.caches:
|
||||
cache.clear()
|
||||
|
||||
|
||||
class GlmImageAttnProcessor:
|
||||
"""
|
||||
Processor for implementing scaled dot-product attention for the GlmImage model. It applies a rotary embedding on
|
||||
query and key vectors, but does not include spatial normalization.
|
||||
|
||||
The processor supports passing an attention mask for text tokens. The attention mask should have shape (batch_size,
|
||||
text_seq_length) where 1 indicates a non-padded token and 0 indicates a padded token.
|
||||
"""
|
||||
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("GlmImageAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
kv_cache: Optional[GlmImageLayerKVCache] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
dtype = encoder_hidden_states.dtype
|
||||
|
||||
batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape
|
||||
batch_size, image_seq_length, embed_dim = hidden_states.shape
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
# 1. QKV projections
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
|
||||
query = query.unflatten(2, (attn.heads, -1))
|
||||
key = key.unflatten(2, (attn.heads, -1))
|
||||
value = value.unflatten(2, (attn.heads, -1))
|
||||
|
||||
# 2. QK normalization
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query).to(dtype=dtype)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key).to(dtype=dtype)
|
||||
|
||||
# 3. Rotational positional embeddings applied to latent stream
|
||||
if image_rotary_emb is not None:
|
||||
from ..embeddings import apply_rotary_emb
|
||||
|
||||
query[:, text_seq_length:, :, :] = apply_rotary_emb(
|
||||
query[:, text_seq_length:, :, :], image_rotary_emb, sequence_dim=1, use_real_unbind_dim=-2
|
||||
)
|
||||
key[:, text_seq_length:, :, :] = apply_rotary_emb(
|
||||
key[:, text_seq_length:, :, :], image_rotary_emb, sequence_dim=1, use_real_unbind_dim=-2
|
||||
)
|
||||
|
||||
if kv_cache is not None:
|
||||
if kv_cache.mode == "write":
|
||||
kv_cache.store(key, value)
|
||||
elif kv_cache.mode == "read":
|
||||
key, value = kv_cache.get(key, value)
|
||||
elif kv_cache.mode == "skip":
|
||||
pass
|
||||
|
||||
# 4. Attention
|
||||
if attention_mask is not None:
|
||||
text_attn_mask = attention_mask
|
||||
assert text_attn_mask.dim() == 2, "the shape of text_attn_mask should be (batch_size, text_seq_length)"
|
||||
text_attn_mask = text_attn_mask.float().to(query.device)
|
||||
mix_attn_mask = torch.ones((batch_size, text_seq_length + image_seq_length), device=query.device)
|
||||
mix_attn_mask[:, :text_seq_length] = text_attn_mask
|
||||
mix_attn_mask = mix_attn_mask.unsqueeze(2)
|
||||
attn_mask_matrix = mix_attn_mask @ mix_attn_mask.transpose(1, 2)
|
||||
attention_mask = (attn_mask_matrix > 0).unsqueeze(1).to(query.dtype)
|
||||
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# 5. Output projection
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
encoder_hidden_states, hidden_states = hidden_states.split(
|
||||
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
|
||||
)
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class GlmImageTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int = 2560,
|
||||
num_attention_heads: int = 64,
|
||||
attention_head_dim: int = 40,
|
||||
time_embed_dim: int = 512,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
# 1. Attention
|
||||
self.norm1 = GlmImageAdaLayerNormZero(time_embed_dim, dim)
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
out_dim=dim,
|
||||
bias=True,
|
||||
qk_norm="layer_norm",
|
||||
elementwise_affine=False,
|
||||
eps=1e-5,
|
||||
processor=GlmImageAttnProcessor(),
|
||||
)
|
||||
|
||||
# 2. Feedforward
|
||||
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
|
||||
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
|
||||
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[
|
||||
Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
|
||||
] = None,
|
||||
attention_mask: Optional[Dict[str, torch.Tensor]] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
kv_cache: Optional[GlmImageLayerKVCache] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# 1. Timestep conditioning
|
||||
(
|
||||
norm_hidden_states,
|
||||
gate_msa,
|
||||
shift_mlp,
|
||||
scale_mlp,
|
||||
gate_mlp,
|
||||
norm_encoder_hidden_states,
|
||||
c_gate_msa,
|
||||
c_shift_mlp,
|
||||
c_scale_mlp,
|
||||
c_gate_mlp,
|
||||
) = self.norm1(hidden_states, encoder_hidden_states, temb)
|
||||
|
||||
# 2. Attention
|
||||
attention_kwargs = attention_kwargs or {}
|
||||
|
||||
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
attention_mask=attention_mask,
|
||||
kv_cache=kv_cache,
|
||||
**attention_kwargs,
|
||||
)
|
||||
hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1)
|
||||
encoder_hidden_states = encoder_hidden_states + attn_encoder_hidden_states * c_gate_msa.unsqueeze(1)
|
||||
|
||||
# 3. Feedforward
|
||||
norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
|
||||
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) * (
|
||||
1 + c_scale_mlp.unsqueeze(1)
|
||||
) + c_shift_mlp.unsqueeze(1)
|
||||
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
ff_output_context = self.ff(norm_encoder_hidden_states)
|
||||
hidden_states = hidden_states + ff_output * gate_mlp.unsqueeze(1)
|
||||
encoder_hidden_states = encoder_hidden_states + ff_output_context * c_gate_mlp.unsqueeze(1)
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class GlmImageRotaryPosEmbed(nn.Module):
|
||||
def __init__(self, dim: int, patch_size: int, theta: float = 10000.0) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.dim = dim
|
||||
self.patch_size = patch_size
|
||||
self.theta = theta
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
batch_size, num_channels, height, width = hidden_states.shape
|
||||
height, width = height // self.patch_size, width // self.patch_size
|
||||
|
||||
dim_h, dim_w = self.dim // 2, self.dim // 2
|
||||
h_inv_freq = 1.0 / (
|
||||
self.theta ** (torch.arange(0, dim_h, 2, dtype=torch.float32)[: (dim_h // 2)].float() / dim_h)
|
||||
)
|
||||
w_inv_freq = 1.0 / (
|
||||
self.theta ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: (dim_w // 2)].float() / dim_w)
|
||||
)
|
||||
h_seq = torch.arange(height)
|
||||
w_seq = torch.arange(width)
|
||||
freqs_h = torch.outer(h_seq, h_inv_freq)
|
||||
freqs_w = torch.outer(w_seq, w_inv_freq)
|
||||
|
||||
# Create position matrices for height and width
|
||||
# [height, 1, dim//4] and [1, width, dim//4]
|
||||
freqs_h = freqs_h.unsqueeze(1)
|
||||
freqs_w = freqs_w.unsqueeze(0)
|
||||
# Broadcast freqs_h and freqs_w to [height, width, dim//4]
|
||||
freqs_h = freqs_h.expand(height, width, -1)
|
||||
freqs_w = freqs_w.expand(height, width, -1)
|
||||
|
||||
# Concatenate along last dimension to get [height, width, dim//2]
|
||||
freqs = torch.cat([freqs_h, freqs_w], dim=-1)
|
||||
freqs = torch.cat([freqs, freqs], dim=-1) # [height, width, dim]
|
||||
freqs = freqs.reshape(height * width, -1)
|
||||
return (freqs.cos(), freqs.sin())
|
||||
|
||||
|
||||
class GlmImageAdaLayerNormContinuous(nn.Module):
|
||||
"""
|
||||
GlmImage-only final AdaLN: LN(x) -> Linear(cond) -> chunk -> affine. Matches Megatron: **no activation** before the
|
||||
Linear on conditioning embedding.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
conditioning_embedding_dim: int,
|
||||
elementwise_affine: bool = True,
|
||||
eps: float = 1e-5,
|
||||
bias: bool = True,
|
||||
norm_type: str = "layer_norm",
|
||||
):
|
||||
super().__init__()
|
||||
self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
|
||||
if norm_type == "layer_norm":
|
||||
self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
|
||||
elif norm_type == "rms_norm":
|
||||
self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
|
||||
else:
|
||||
raise ValueError(f"unknown norm_type {norm_type}")
|
||||
|
||||
def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
|
||||
# *** NO SiLU here ***
|
||||
emb = self.linear(conditioning_embedding.to(x.dtype))
|
||||
scale, shift = torch.chunk(emb, 2, dim=1)
|
||||
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
|
||||
return x
|
||||
|
||||
|
||||
class GlmImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
|
||||
r"""
|
||||
Args:
|
||||
patch_size (`int`, defaults to `2`):
|
||||
The size of the patches to use in the patch embedding layer.
|
||||
in_channels (`int`, defaults to `16`):
|
||||
The number of channels in the input.
|
||||
num_layers (`int`, defaults to `30`):
|
||||
The number of layers of Transformer blocks to use.
|
||||
attention_head_dim (`int`, defaults to `40`):
|
||||
The number of channels in each head.
|
||||
num_attention_heads (`int`, defaults to `64`):
|
||||
The number of heads to use for multi-head attention.
|
||||
out_channels (`int`, defaults to `16`):
|
||||
The number of channels in the output.
|
||||
text_embed_dim (`int`, defaults to `1472`):
|
||||
Input dimension of text embeddings from the text encoder.
|
||||
time_embed_dim (`int`, defaults to `512`):
|
||||
Output dimension of timestep embeddings.
|
||||
condition_dim (`int`, defaults to `256`):
|
||||
The embedding dimension of the input SDXL-style resolution conditions (original_size, target_size,
|
||||
crop_coords).
|
||||
pos_embed_max_size (`int`, defaults to `128`):
|
||||
The maximum resolution of the positional embeddings, from which slices of shape `H x W` are taken and added
|
||||
to input patched latents, where `H` and `W` are the latent height and width respectively. A value of 128
|
||||
means that the maximum supported height and width for image generation is `128 * vae_scale_factor *
|
||||
patch_size => 128 * 8 * 2 => 2048`.
|
||||
sample_size (`int`, defaults to `128`):
|
||||
The base resolution of input latents. If height/width is not provided during generation, this value is used
|
||||
to determine the resolution as `sample_size * vae_scale_factor => 128 * 8 => 1024`
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_no_split_modules = [
|
||||
"GlmImageTransformerBlock",
|
||||
"GlmImageImageProjector",
|
||||
"GlmImageImageProjector",
|
||||
]
|
||||
_skip_layerwise_casting_patterns = ["patch_embed", "norm", "proj_out"]
|
||||
_skip_keys = ["kv_caches"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 2,
|
||||
in_channels: int = 16,
|
||||
out_channels: int = 16,
|
||||
num_layers: int = 30,
|
||||
attention_head_dim: int = 40,
|
||||
num_attention_heads: int = 64,
|
||||
text_embed_dim: int = 1472,
|
||||
time_embed_dim: int = 512,
|
||||
condition_dim: int = 256,
|
||||
prior_vq_quantizer_codebook_size: int = 16384,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# GlmImage uses 2 additional SDXL-like conditions - target_size, crop_coords
|
||||
# Each of these are sincos embeddings of shape 2 * condition_dim
|
||||
pooled_projection_dim = 2 * 2 * condition_dim
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
out_channels = out_channels
|
||||
|
||||
# 1. RoPE
|
||||
self.rope = GlmImageRotaryPosEmbed(attention_head_dim, patch_size, theta=10000.0)
|
||||
|
||||
# 2. Patch & Text-timestep embedding
|
||||
self.image_projector = GlmImageImageProjector(in_channels, inner_dim, patch_size)
|
||||
self.glyph_projector = FeedForward(text_embed_dim, inner_dim, inner_dim=inner_dim, activation_fn="gelu")
|
||||
self.prior_token_embedding = nn.Embedding(prior_vq_quantizer_codebook_size, inner_dim)
|
||||
self.prior_projector = FeedForward(inner_dim, inner_dim, inner_dim=inner_dim, activation_fn="linear-silu")
|
||||
|
||||
self.time_condition_embed = GlmImageCombinedTimestepSizeEmbeddings(
|
||||
embedding_dim=time_embed_dim,
|
||||
condition_dim=condition_dim,
|
||||
pooled_projection_dim=pooled_projection_dim,
|
||||
timesteps_dim=time_embed_dim,
|
||||
)
|
||||
|
||||
# 3. Transformer blocks
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
GlmImageTransformerBlock(inner_dim, num_attention_heads, attention_head_dim, time_embed_dim)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 4. Output projection
|
||||
self.norm_out = GlmImageAdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False)
|
||||
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels, bias=True)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
prior_token_id: torch.Tensor,
|
||||
prior_token_drop: torch.Tensor,
|
||||
timestep: torch.LongTensor,
|
||||
target_size: torch.Tensor,
|
||||
crop_coords: torch.Tensor,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
kv_caches: Optional[GlmImageKVCache] = None,
|
||||
image_rotary_emb: Optional[
|
||||
Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
|
||||
] = None,
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
batch_size, num_channels, height, width = hidden_states.shape
|
||||
|
||||
# 1. RoPE
|
||||
if image_rotary_emb is None:
|
||||
image_rotary_emb = self.rope(hidden_states)
|
||||
|
||||
# 2. Patch & Timestep embeddings
|
||||
p = self.config.patch_size
|
||||
post_patch_height = height // p
|
||||
post_patch_width = width // p
|
||||
|
||||
hidden_states = self.image_projector(hidden_states)
|
||||
encoder_hidden_states = self.glyph_projector(encoder_hidden_states)
|
||||
prior_embedding = self.prior_token_embedding(prior_token_id)
|
||||
prior_embedding[prior_token_drop] *= 0.0
|
||||
prior_hidden_states = self.prior_projector(prior_embedding)
|
||||
|
||||
hidden_states = hidden_states + prior_hidden_states
|
||||
|
||||
temb = self.time_condition_embed(timestep, target_size, crop_coords, hidden_states.dtype)
|
||||
|
||||
# 3. Transformer blocks
|
||||
for idx, block in enumerate(self.transformer_blocks):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
attention_mask,
|
||||
attention_kwargs,
|
||||
kv_caches[idx] if kv_caches is not None else None,
|
||||
)
|
||||
else:
|
||||
hidden_states, encoder_hidden_states = block(
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
attention_mask,
|
||||
attention_kwargs,
|
||||
kv_cache=kv_caches[idx] if kv_caches is not None else None,
|
||||
)
|
||||
|
||||
# 4. Output norm & projection
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
# 5. Unpatchify
|
||||
hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, -1, p, p)
|
||||
|
||||
# Rearrange tensor from (B, H_p, W_p, C, p, p) to (B, C, H_p * p, W_p * p)
|
||||
output = hidden_states.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
@@ -68,10 +68,6 @@ class MellonParam:
|
||||
def image_latents(cls, display: str = "input") -> "MellonParam":
|
||||
return cls(name="image_latents", label="Image Latents", type="latents", display=display)
|
||||
|
||||
@classmethod
|
||||
def first_frame_latents(cls, display: str = "input") -> "MellonParam":
|
||||
return cls(name="first_frame_latents", label="First Frame Latents", type="latents", display=display)
|
||||
|
||||
@classmethod
|
||||
def image_latents_with_strength(cls) -> "MellonParam":
|
||||
return cls(
|
||||
@@ -93,10 +89,6 @@ class MellonParam:
|
||||
def embeddings(cls, display: str = "output") -> "MellonParam":
|
||||
return cls(name="embeddings", label="Text Embeddings", type="embeddings", display=display)
|
||||
|
||||
@classmethod
|
||||
def image_embeds(cls, display: str = "output") -> "MellonParam":
|
||||
return cls(name="image_embeds", label="Image Embeddings", type="image_embeds", display=display)
|
||||
|
||||
@classmethod
|
||||
def controlnet_conditioning_scale(cls, default: float = 0.5) -> "MellonParam":
|
||||
return cls(
|
||||
@@ -180,10 +172,6 @@ class MellonParam:
|
||||
def num_frames(cls, default: int = 81) -> "MellonParam":
|
||||
return cls(name="num_frames", label="Frames", type="int", default=default, min=1, max=480, display="slider")
|
||||
|
||||
@classmethod
|
||||
def layers(cls, default: int = 4) -> "MellonParam":
|
||||
return cls(name="layers", label="Layers", type="int", default=default, min=1, max=10, display="slider")
|
||||
|
||||
@classmethod
|
||||
def videos(cls) -> "MellonParam":
|
||||
return cls(name="videos", label="Videos", type="video", display="output")
|
||||
@@ -198,16 +186,6 @@ class MellonParam:
|
||||
"""
|
||||
return cls(name="vae", label="VAE", type="diffusers_auto_model", display="input")
|
||||
|
||||
@classmethod
|
||||
def image_encoder(cls) -> "MellonParam":
|
||||
"""
|
||||
Image Encoder model info dict.
|
||||
|
||||
Contains keys like 'model_id', 'repo_id', 'execution_device' etc. Use components.get_one(model_id) to retrieve
|
||||
the actual model.
|
||||
"""
|
||||
return cls(name="image_encoder", label="Image Encoder", type="diffusers_auto_model", display="input")
|
||||
|
||||
@classmethod
|
||||
def unet(cls) -> "MellonParam":
|
||||
"""
|
||||
|
||||
@@ -84,7 +84,7 @@ class WanImage2VideoImageEncoderStep(SequentialPipelineBlocks):
|
||||
class WanImage2VideoVaeImageEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "wan"
|
||||
block_classes = [WanImageResizeStep, WanVaeImageEncoderStep]
|
||||
block_names = ["image_resize", "vae_encoder"]
|
||||
block_names = ["image_resize", "vae_image_encoder"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
@@ -142,7 +142,7 @@ class WanFLF2VImageEncoderStep(SequentialPipelineBlocks):
|
||||
class WanFLF2VVaeImageEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "wan"
|
||||
block_classes = [WanImageResizeStep, WanImageCropResizeStep, WanFirstLastFrameVaeImageEncoderStep]
|
||||
block_names = ["image_resize", "last_image_resize", "vae_encoder"]
|
||||
block_names = ["image_resize", "last_image_resize", "vae_image_encoder"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
@@ -203,7 +203,7 @@ class WanAutoImageEncoderStep(AutoPipelineBlocks):
|
||||
## vae encoder
|
||||
class WanAutoVaeImageEncoderStep(AutoPipelineBlocks):
|
||||
block_classes = [WanFLF2VVaeImageEncoderStep, WanImage2VideoVaeImageEncoderStep]
|
||||
block_names = ["flf2v_vae_encoder", "image2video_vae_encoder"]
|
||||
block_names = ["flf2v_vae_image_encoder", "image2video_vae_image_encoder"]
|
||||
block_trigger_inputs = ["last_image", "image"]
|
||||
|
||||
@property
|
||||
@@ -251,7 +251,7 @@ class WanAutoBlocks(SequentialPipelineBlocks):
|
||||
block_names = [
|
||||
"text_encoder",
|
||||
"image_encoder",
|
||||
"vae_encoder",
|
||||
"vae_image_encoder",
|
||||
"denoise",
|
||||
"decode",
|
||||
]
|
||||
@@ -353,7 +353,7 @@ class Wan22AutoBlocks(SequentialPipelineBlocks):
|
||||
]
|
||||
block_names = [
|
||||
"text_encoder",
|
||||
"vae_encoder",
|
||||
"vae_image_encoder",
|
||||
"denoise",
|
||||
"decode",
|
||||
]
|
||||
@@ -384,7 +384,7 @@ IMAGE2VIDEO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("image_resize", WanImageResizeStep),
|
||||
("image_encoder", WanImage2VideoImageEncoderStep),
|
||||
("vae_encoder", WanImage2VideoVaeImageEncoderStep),
|
||||
("vae_image_encoder", WanImage2VideoVaeImageEncoderStep),
|
||||
("input", WanTextInputStep),
|
||||
("additional_inputs", WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents"])),
|
||||
("set_timesteps", WanSetTimestepsStep),
|
||||
@@ -401,7 +401,7 @@ FLF2V_BLOCKS = InsertableDict(
|
||||
("image_resize", WanImageResizeStep),
|
||||
("last_image_resize", WanImageCropResizeStep),
|
||||
("image_encoder", WanFLF2VImageEncoderStep),
|
||||
("vae_encoder", WanFLF2VVaeImageEncoderStep),
|
||||
("vae_image_encoder", WanFLF2VVaeImageEncoderStep),
|
||||
("input", WanTextInputStep),
|
||||
("additional_inputs", WanAdditionalInputsStep(image_latent_inputs=["first_last_frame_latents"])),
|
||||
("set_timesteps", WanSetTimestepsStep),
|
||||
@@ -416,7 +416,7 @@ AUTO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", WanTextEncoderStep),
|
||||
("image_encoder", WanAutoImageEncoderStep),
|
||||
("vae_encoder", WanAutoVaeImageEncoderStep),
|
||||
("vae_image_encoder", WanAutoVaeImageEncoderStep),
|
||||
("denoise", WanAutoDenoiseStep),
|
||||
("decode", WanImageVaeDecoderStep),
|
||||
]
|
||||
@@ -438,7 +438,7 @@ TEXT2VIDEO_BLOCKS_WAN22 = InsertableDict(
|
||||
IMAGE2VIDEO_BLOCKS_WAN22 = InsertableDict(
|
||||
[
|
||||
("image_resize", WanImageResizeStep),
|
||||
("vae_encoder", WanImage2VideoVaeImageEncoderStep),
|
||||
("vae_image_encoder", WanImage2VideoVaeImageEncoderStep),
|
||||
("input", WanTextInputStep),
|
||||
("set_timesteps", WanSetTimestepsStep),
|
||||
("prepare_latents", WanPrepareLatentsStep),
|
||||
@@ -450,7 +450,7 @@ IMAGE2VIDEO_BLOCKS_WAN22 = InsertableDict(
|
||||
AUTO_BLOCKS_WAN22 = InsertableDict(
|
||||
[
|
||||
("text_encoder", WanTextEncoderStep),
|
||||
("vae_encoder", WanAutoVaeImageEncoderStep),
|
||||
("vae_image_encoder", WanAutoVaeImageEncoderStep),
|
||||
("denoise", Wan22AutoDenoiseStep),
|
||||
("decode", WanImageVaeDecoderStep),
|
||||
]
|
||||
|
||||
@@ -15,7 +15,6 @@ from ..utils import (
|
||||
is_torch_available,
|
||||
is_torch_npu_available,
|
||||
is_transformers_available,
|
||||
is_transformers_version,
|
||||
)
|
||||
|
||||
|
||||
@@ -435,8 +434,6 @@ else:
|
||||
"QwenImageLayeredPipeline",
|
||||
]
|
||||
_import_structure["chronoedit"] = ["ChronoEditPipeline"]
|
||||
_import_structure["glm_image"] = ["GlmImagePipeline"]
|
||||
|
||||
try:
|
||||
if not is_onnx_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
@@ -679,7 +676,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
ReduxImageEncoder,
|
||||
)
|
||||
from .flux2 import Flux2Pipeline
|
||||
from .glm_image import GlmImagePipeline
|
||||
from .hidream_image import HiDreamImagePipeline
|
||||
from .hunyuan_image import HunyuanImagePipeline, HunyuanImageRefinerPipeline
|
||||
from .hunyuan_video import (
|
||||
|
||||
@@ -52,7 +52,6 @@ from .flux import (
|
||||
FluxKontextPipeline,
|
||||
FluxPipeline,
|
||||
)
|
||||
from .glm_image import GlmImagePipeline
|
||||
from .hunyuandit import HunyuanDiTPipeline
|
||||
from .kandinsky import (
|
||||
KandinskyCombinedPipeline,
|
||||
@@ -169,7 +168,6 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
|
||||
("chroma", ChromaPipeline),
|
||||
("cogview3", CogView3PlusPipeline),
|
||||
("cogview4", CogView4Pipeline),
|
||||
("glm_image", GlmImagePipeline),
|
||||
("cogview4-control", CogView4ControlPipeline),
|
||||
("qwenimage", QwenImagePipeline),
|
||||
("qwenimage-controlnet", QwenImageControlNetPipeline),
|
||||
|
||||
@@ -1,59 +0,0 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
is_transformers_version,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_additional_imports = {}
|
||||
_import_structure = {"pipeline_output": ["GlmImagePipelineOutput"]}
|
||||
|
||||
# Import transformers components so they can be resolved during pipeline loading
|
||||
|
||||
if is_transformers_available() and is_transformers_version(">=", "4.57.4"):
|
||||
try:
|
||||
from transformers import GlmImageForConditionalGeneration, GlmImageProcessor
|
||||
|
||||
_additional_imports["GlmImageForConditionalGeneration"] = GlmImageForConditionalGeneration
|
||||
_additional_imports["GlmImageProcessor"] = GlmImageProcessor
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_glm_image"] = ["GlmImagePipeline"]
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .pipeline_glm_image import GlmImagePipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
for name, value in _additional_imports.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
@@ -1,825 +0,0 @@
|
||||
# Copyright 2025 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import re
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
from transformers import ByT5Tokenizer, PreTrainedModel, ProcessorMixin, T5EncoderModel
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...models import AutoencoderKL, GlmImageTransformer2DModel
|
||||
from ...models.transformers.transformer_glm_image import GlmImageKVCache
|
||||
from ...pipelines.pipeline_utils import DiffusionPipeline
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import is_torch_xla_available, is_transformers_version, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from .pipeline_output import GlmImagePipelineOutput
|
||||
|
||||
|
||||
# Because it's not released in stable as of 13/01/2026. So this is just a proxy.
|
||||
GlmImageProcessor = ProcessorMixin
|
||||
GlmImageForConditionalGeneration = PreTrainedModel
|
||||
if is_transformers_version(">=", "4.57.4"):
|
||||
from transformers import GlmImageForConditionalGeneration, GlmImageProcessor
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from diffusers import GlmImagePipeline
|
||||
|
||||
>>> pipe = GlmImagePipeline.from_pretrained("zai-org/GLM-Image", torch_dtype=torch.bfloat16)
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> prompt = "A photo of an astronaut riding a horse on mars"
|
||||
>>> image = pipe(prompt).images[0]
|
||||
>>> image.save("output.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def calculate_shift(
|
||||
image_seq_len,
|
||||
base_seq_len: int = 256,
|
||||
base_shift: float = 0.25,
|
||||
max_shift: float = 0.75,
|
||||
) -> float:
|
||||
m = (image_seq_len / base_seq_len) ** 0.5
|
||||
mu = m * max_shift + base_shift
|
||||
return mu
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.cogview4.pipeline_cogview4.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
accepts_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
|
||||
if timesteps is not None and sigmas is not None:
|
||||
if not accepts_timesteps and not accepts_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep or sigma schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif timesteps is not None and sigmas is None:
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif timesteps is None and sigmas is not None:
|
||||
if not accepts_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
class GlmImagePipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using GLM-Image.
|
||||
|
||||
This pipeline integrates both the AR (autoregressive) model for token generation and the DiT (diffusion
|
||||
transformer) model for image decoding.
|
||||
|
||||
Args:
|
||||
tokenizer (`PreTrainedTokenizer`):
|
||||
Tokenizer for the text encoder.
|
||||
processor (`AutoProcessor`):
|
||||
Processor for the AR model to handle chat templates and tokenization.
|
||||
text_encoder ([`T5EncoderModel`]):
|
||||
Frozen text-encoder for glyph embeddings.
|
||||
vision_language_encoder ([`GlmImageForConditionalGeneration`]):
|
||||
The AR model that generates image tokens from text prompts.
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
transformer ([`GlmImageTransformer2DModel`]):
|
||||
A text conditioned transformer to denoise the encoded image latents (DiT).
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
"""
|
||||
|
||||
_optional_components = []
|
||||
model_cpu_offload_seq = "vision_language_encoder->text_encoder->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: ByT5Tokenizer,
|
||||
processor: GlmImageProcessor,
|
||||
text_encoder: T5EncoderModel,
|
||||
vision_language_encoder: GlmImageForConditionalGeneration,
|
||||
vae: AutoencoderKL,
|
||||
transformer: GlmImageTransformer2DModel,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
tokenizer=tokenizer,
|
||||
processor=processor,
|
||||
text_encoder=text_encoder,
|
||||
vision_language_encoder=vision_language_encoder,
|
||||
vae=vae,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
|
||||
self.default_sample_size = (
|
||||
self.transformer.config.sample_size
|
||||
if hasattr(self, "transformer")
|
||||
and self.transformer is not None
|
||||
and hasattr(self.transformer.config, "sample_size")
|
||||
else 128
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _compute_generation_params(
|
||||
image_grid_thw,
|
||||
is_text_to_image: bool,
|
||||
):
|
||||
grid_sizes = []
|
||||
grid_hw = []
|
||||
|
||||
for i in range(image_grid_thw.shape[0]):
|
||||
t, h, w = image_grid_thw[i].tolist()
|
||||
grid_sizes.append(int(h * w))
|
||||
grid_hw.append((int(h), int(w)))
|
||||
|
||||
if not is_text_to_image:
|
||||
max_new_tokens = grid_sizes[-1] + 1
|
||||
large_image_start_offset = 0
|
||||
target_grid_h, target_grid_w = grid_hw[-1]
|
||||
else:
|
||||
total_tokens = sum(grid_sizes)
|
||||
max_new_tokens = total_tokens + 1
|
||||
large_image_start_offset = sum(grid_sizes[1:])
|
||||
target_grid_h, target_grid_w = grid_hw[0]
|
||||
return max_new_tokens, large_image_start_offset, target_grid_h, target_grid_w
|
||||
|
||||
@staticmethod
|
||||
def _extract_large_image_tokens(
|
||||
outputs: torch.Tensor, input_length: int, large_image_start_offset: int, large_image_tokens: int
|
||||
) -> torch.Tensor:
|
||||
generated_tokens = outputs[0][input_length:]
|
||||
large_image_start = large_image_start_offset
|
||||
large_image_end = large_image_start + large_image_tokens
|
||||
return generated_tokens[large_image_start:large_image_end]
|
||||
|
||||
@staticmethod
|
||||
def _upsample_token_ids(token_ids: torch.Tensor, token_h: int, token_w: int) -> torch.Tensor:
|
||||
token_ids = token_ids.view(1, 1, token_h, token_w)
|
||||
token_ids = torch.nn.functional.interpolate(token_ids.float(), scale_factor=2, mode="nearest").to(
|
||||
dtype=torch.long
|
||||
)
|
||||
token_ids = token_ids.view(1, -1)
|
||||
return token_ids
|
||||
|
||||
def generate_prior_tokens(
|
||||
self,
|
||||
prompt: str,
|
||||
height: int,
|
||||
width: int,
|
||||
image: Optional[List[PIL.Image.Image]] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
is_text_to_image = image is None or len(image) == 0
|
||||
content = []
|
||||
if image is not None:
|
||||
for img in image:
|
||||
content.append({"type": "image", "image": img})
|
||||
content.append({"type": "text", "text": prompt})
|
||||
messages = [{"role": "user", "content": content}]
|
||||
inputs = self.processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
target_h=height,
|
||||
target_w=width,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
).to(device)
|
||||
|
||||
image_grid_thw = inputs.get("image_grid_thw")
|
||||
max_new_tokens, large_image_offset, token_h, token_w = self._compute_generation_params(
|
||||
image_grid_thw=image_grid_thw, is_text_to_image=is_text_to_image
|
||||
)
|
||||
|
||||
prior_token_image_ids = None
|
||||
if image is not None:
|
||||
prior_token_image_embed = self.vision_language_encoder.get_image_features(
|
||||
inputs["pixel_values"], image_grid_thw[:-1]
|
||||
)
|
||||
prior_token_image_embed = torch.cat(prior_token_image_embed, dim=0)
|
||||
prior_token_image_ids = self.vision_language_encoder.get_image_tokens(
|
||||
prior_token_image_embed, image_grid_thw[:-1]
|
||||
)
|
||||
|
||||
# For GLM-Image, greedy decoding is not allowed; it may cause repetitive outputs.
|
||||
# max_new_tokens must be exactly grid_h * grid_w + 1 (the +1 is for EOS).
|
||||
outputs = self.vision_language_encoder.generate(
|
||||
**inputs,
|
||||
max_new_tokens=max_new_tokens,
|
||||
do_sample=True,
|
||||
)
|
||||
|
||||
prior_token_ids_d32 = self._extract_large_image_tokens(
|
||||
outputs, inputs["input_ids"].shape[-1], large_image_offset, token_h * token_w
|
||||
)
|
||||
prior_token_ids = self._upsample_token_ids(prior_token_ids_d32, token_h, token_w)
|
||||
|
||||
return prior_token_ids, prior_token_image_ids
|
||||
|
||||
def get_glyph_texts(self, prompt):
|
||||
prompt = prompt[0] if isinstance(prompt, list) else prompt
|
||||
ocr_texts = (
|
||||
re.findall(r"'([^']*)'", prompt)
|
||||
+ re.findall(r"“([^“”]*)”", prompt)
|
||||
+ re.findall(r'"([^"]*)"', prompt)
|
||||
+ re.findall(r"「([^「」]*)」", prompt)
|
||||
)
|
||||
return ocr_texts
|
||||
|
||||
def _get_glyph_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
max_sequence_length: int = 2048,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
glyph_texts = self.get_glyph_texts(prompt)
|
||||
input_ids = self.tokenizer(
|
||||
glyph_texts if len(glyph_texts) > 0 else [""],
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
).input_ids
|
||||
input_ids = [
|
||||
[self.tokenizer.pad_token_id] * ((len(input_ids) + 1) % 2) + input_ids_ for input_ids_ in input_ids
|
||||
]
|
||||
max_length = max(len(input_ids_) for input_ids_ in input_ids)
|
||||
attention_mask = torch.tensor(
|
||||
[[1] * len(input_ids_) + [0] * (max_length - len(input_ids_)) for input_ids_ in input_ids], device=device
|
||||
)
|
||||
input_ids = torch.tensor(
|
||||
[input_ids_ + [self.tokenizer.pad_token_id] * (max_length - len(input_ids_)) for input_ids_ in input_ids],
|
||||
device=device,
|
||||
)
|
||||
outputs = self.text_encoder(input_ids, attention_mask=attention_mask)
|
||||
glyph_embeds = outputs.last_hidden_state[attention_mask.bool()].unsqueeze(0)
|
||||
|
||||
return glyph_embeds.to(device=device, dtype=dtype)
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
do_classifier_free_guidance: bool = True,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
max_sequence_length: int = 2048,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use classifier free guidance or not.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of images that should be generated per prompt. torch device to place the resulting embeddings on
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device
|
||||
dtype: (`torch.dtype`, *optional*):
|
||||
torch dtype
|
||||
max_sequence_length (`int`, defaults to `2048`):
|
||||
Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds = self._get_glyph_embeds(prompt, max_sequence_length, device, dtype)
|
||||
|
||||
seq_len = prompt_embeds.size(1)
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# For GLM-Image, negative_prompt must be "" instead of None
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = ""
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
negative_prompt_embeds = self._get_glyph_embeds(negative_prompt, max_sequence_length, device, dtype)
|
||||
|
||||
seq_len = negative_prompt_embeds.size(1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||
if latents is not None:
|
||||
return latents.to(device)
|
||||
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
int(height) // self.vae_scale_factor,
|
||||
int(width) // self.vae_scale_factor,
|
||||
)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
return latents
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
prior_token_ids=None,
|
||||
prior_image_token_ids=None,
|
||||
):
|
||||
if (
|
||||
height is not None
|
||||
and height % (self.vae_scale_factor * self.transformer.config.patch_size * 2) != 0
|
||||
or width is not None
|
||||
and width % (self.transformer.config.patch_size * 2) != 0
|
||||
):
|
||||
# GLM-Image uses 32× downsampling, so the image dimensions must be multiples of 32.
|
||||
raise ValueError(
|
||||
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 4} but are {height} and {width}."
|
||||
)
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
if prompt is not None and prior_token_ids is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prior_token_ids`: {prior_token_ids}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prior_token_ids is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prior_token_ids`. Cannot leave both `prompt` and `prior_token_ids` undefined."
|
||||
)
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
if (prior_token_ids is None and prior_image_token_ids is not None) or (
|
||||
prior_token_ids is not None and prior_image_token_ids is None
|
||||
):
|
||||
raise ValueError(
|
||||
f"Cannot forward only one `prior_token_ids`: {prior_token_ids} or `prior_image_token_ids`:"
|
||||
f" {prior_image_token_ids} provided. Please make sure both are provided or neither."
|
||||
)
|
||||
|
||||
if prior_token_ids is not None and prompt_embeds is None:
|
||||
raise ValueError("`prompt_embeds` must also be provided with `prior_token_ids`.")
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Optional[Union[str, List[str]]] = None,
|
||||
image: Optional[
|
||||
Union[
|
||||
torch.Tensor, PIL.Image.Image, np.ndarray, List[torch.Tensor], List[PIL.Image.Image], List[np.ndarray]
|
||||
]
|
||||
] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
guidance_scale: float = 1.5,
|
||||
num_images_per_prompt: int = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prior_token_ids: Optional[torch.FloatTensor] = None,
|
||||
prior_image_token_ids: Optional[torch.Tensor] = None,
|
||||
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
output_type: str = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[
|
||||
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
||||
] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 2048,
|
||||
) -> Union[GlmImagePipelineOutput, Tuple]:
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. Must contain shape info in the format '<sop>H
|
||||
W<eop>' where H and W are token dimensions (d32). Example: "A beautiful sunset<sop>36 24<eop>"
|
||||
generates a 1152x768 image.
|
||||
image: Optional condition images for image-to-image generation.
|
||||
height (`int`, *optional*):
|
||||
The height in pixels. If not provided, derived from prompt shape info.
|
||||
width (`int`, *optional*):
|
||||
The width in pixels. If not provided, derived from prompt shape info.
|
||||
num_inference_steps (`int`, *optional*, defaults to `50`):
|
||||
The number of denoising steps for DiT.
|
||||
guidance_scale (`float`, *optional*, defaults to `1.5`):
|
||||
Guidance scale for classifier-free guidance.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to `1`):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
Random generator for reproducibility.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
Output format: "pil", "np", or "latent".
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`GlmImagePipelineOutput`] or `tuple`: Generated images.
|
||||
"""
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
# 1. Check inputs
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
prior_token_ids,
|
||||
prior_image_token_ids,
|
||||
)
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
if batch_size != 1:
|
||||
raise ValueError(f"batch_size must be 1 due to AR model limitations, got {batch_size}")
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# 2. Preprocess image tokens and prompt tokens
|
||||
if prior_token_ids is None:
|
||||
prior_token_ids, prior_token_image_ids = self.generate_prior_tokens(
|
||||
prompt=prompt[0] if isinstance(prompt, list) else prompt,
|
||||
image=image,
|
||||
height=height,
|
||||
width=width,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# 3. Preprocess image
|
||||
if image is not None:
|
||||
preprocessed_condition_images = []
|
||||
for img in image:
|
||||
image_height, image_width = img.size[::-1] if isinstance(img, PIL.Image.Image) else img.shape[:2]
|
||||
multiple_of = self.vae_scale_factor * self.transformer.config.patch_size
|
||||
image_height = (image_height // multiple_of) * multiple_of
|
||||
image_width = (image_width // multiple_of) * multiple_of
|
||||
img = self.image_processor.preprocess(img, height=image_height, width=image_width)
|
||||
preprocessed_condition_images.append(img)
|
||||
height = height or image_height
|
||||
width = width or image_width
|
||||
image = preprocessed_condition_images
|
||||
|
||||
# 5. Encode input prompt
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
self.do_classifier_free_guidance,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
# 4. Prepare latents and (optional) image kv cache
|
||||
latent_channels = self.transformer.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
num_channels_latents=latent_channels,
|
||||
height=height,
|
||||
width=width,
|
||||
dtype=prompt_embeds.dtype,
|
||||
device=device,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
)
|
||||
kv_caches = GlmImageKVCache(num_layers=self.transformer.config.num_layers)
|
||||
|
||||
if image is not None:
|
||||
kv_caches.set_mode("write")
|
||||
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.latent_channels, 1, 1)
|
||||
latents_std = torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.latent_channels, 1, 1)
|
||||
|
||||
latents_mean = latents_mean.to(device=device, dtype=prompt_embeds.dtype)
|
||||
latents_std = latents_std.to(device=device, dtype=prompt_embeds.dtype)
|
||||
|
||||
for condition_image, condition_image_prior_token_id in zip(image, prior_token_image_ids):
|
||||
condition_image = condition_image.to(device=device, dtype=prompt_embeds.dtype)
|
||||
condition_latent = retrieve_latents(
|
||||
self.vae.encode(condition_image), generator=generator, sample_mode="argmax"
|
||||
)
|
||||
condition_latent = (condition_latent - latents_mean) / latents_std
|
||||
|
||||
# Do not remove.
|
||||
# It would be use to run the reference image through a
|
||||
# forward pass at timestep 0 and keep the KV cache.
|
||||
_ = self.transformer(
|
||||
hidden_states=condition_latent,
|
||||
encoder_hidden_states=torch.zeros_like(prompt_embeds)[:1, :0, ...],
|
||||
prior_token_id=condition_image_prior_token_id,
|
||||
prior_token_drop=torch.full_like(condition_image_prior_token_id, False, dtype=torch.bool),
|
||||
timestep=torch.zeros((1,), device=device),
|
||||
target_size=torch.tensor([condition_image.shape[-2:]], device=device),
|
||||
crop_coords=torch.zeros((1, 2), device=device),
|
||||
attention_kwargs=attention_kwargs,
|
||||
kv_caches=kv_caches,
|
||||
)
|
||||
|
||||
# 6. Prepare additional timestep conditions
|
||||
target_size = (height, width)
|
||||
target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype, device=device)
|
||||
crops_coords_top_left = torch.tensor([crops_coords_top_left], dtype=prompt_embeds.dtype, device=device)
|
||||
|
||||
target_size = target_size.repeat(batch_size * num_images_per_prompt, 1)
|
||||
crops_coords_top_left = crops_coords_top_left.repeat(batch_size * num_images_per_prompt, 1)
|
||||
|
||||
# Prepare timesteps
|
||||
image_seq_len = ((height // self.vae_scale_factor) * (width // self.vae_scale_factor)) // (
|
||||
self.transformer.config.patch_size**2
|
||||
)
|
||||
timesteps = (
|
||||
np.linspace(self.scheduler.config.num_train_timesteps, 1.0, num_inference_steps + 1)[:-1]
|
||||
if timesteps is None
|
||||
else np.array(timesteps)
|
||||
)
|
||||
timesteps = timesteps.astype(np.int64).astype(np.float32)
|
||||
sigmas = timesteps / self.scheduler.config.num_train_timesteps if sigmas is None else sigmas
|
||||
mu = calculate_shift(
|
||||
image_seq_len,
|
||||
self.scheduler.config.get("base_image_seq_len", 256),
|
||||
self.scheduler.config.get("base_shift", 0.25),
|
||||
self.scheduler.config.get("max_shift", 0.75),
|
||||
)
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler, num_inference_steps, device, timesteps, sigmas, mu=mu
|
||||
)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 7. Denoising loop
|
||||
transformer_dtype = self.transformer.dtype
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
|
||||
prior_token_drop_cond = torch.full_like(prior_token_ids, False, dtype=torch.bool)
|
||||
prior_token_drop_uncond = torch.full_like(prior_token_ids, True, dtype=torch.bool)
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
latent_model_input = latents.to(transformer_dtype)
|
||||
|
||||
timestep = t.expand(latents.shape[0]) - 1
|
||||
|
||||
if image is not None:
|
||||
kv_caches.set_mode("read")
|
||||
|
||||
noise_pred_cond = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
prior_token_id=prior_token_ids,
|
||||
prior_token_drop=prior_token_drop_cond,
|
||||
timestep=timestep,
|
||||
target_size=target_size,
|
||||
crop_coords=crops_coords_top_left,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
kv_caches=kv_caches,
|
||||
)[0].float()
|
||||
|
||||
# perform guidance
|
||||
if self.do_classifier_free_guidance:
|
||||
if image is not None:
|
||||
kv_caches.set_mode("skip")
|
||||
noise_pred_uncond = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
prior_token_id=prior_token_ids,
|
||||
prior_token_drop=prior_token_drop_uncond,
|
||||
timestep=timestep,
|
||||
target_size=target_size,
|
||||
crop_coords=crops_coords_top_left,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
kv_caches=kv_caches,
|
||||
)[0].float()
|
||||
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
||||
else:
|
||||
noise_pred = noise_pred_cond
|
||||
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, self.scheduler.sigmas[i], callback_kwargs)
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
kv_caches.clear()
|
||||
|
||||
if not output_type == "latent":
|
||||
latents = latents.to(self.vae.dtype)
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean)
|
||||
.view(1, self.vae.config.latent_channels, 1, 1)
|
||||
.to(latents.device, latents.dtype)
|
||||
)
|
||||
latents_std = (
|
||||
torch.tensor(self.vae.config.latents_std)
|
||||
.view(1, self.vae.config.latent_channels, 1, 1)
|
||||
.to(latents.device, latents.dtype)
|
||||
)
|
||||
latents = latents * latents_std + latents_mean
|
||||
image = self.vae.decode(latents, return_dict=False, generator=generator)[0]
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
else:
|
||||
image = latents
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return GlmImagePipelineOutput(images=image)
|
||||
@@ -1,21 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
|
||||
from ...utils import BaseOutput
|
||||
|
||||
|
||||
@dataclass
|
||||
class GlmImagePipelineOutput(BaseOutput):
|
||||
"""
|
||||
Output class for CogView3 pipelines.
|
||||
|
||||
Args:
|
||||
images (`List[PIL.Image.Image]` or `np.ndarray`)
|
||||
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
|
||||
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
|
||||
"""
|
||||
|
||||
images: Union[List[PIL.Image.Image], np.ndarray]
|
||||
@@ -60,6 +60,7 @@ from ..utils import (
|
||||
deprecate,
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
is_bitsandbytes_version,
|
||||
is_hpu_available,
|
||||
is_torch_npu_available,
|
||||
is_torch_version,
|
||||
@@ -444,7 +445,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
_, _, is_loaded_in_8bit_bnb = _check_bnb_status(module)
|
||||
|
||||
if is_loaded_in_8bit_bnb:
|
||||
if is_loaded_in_8bit_bnb and (
|
||||
is_bitsandbytes_version("<", "0.48.0") or is_accelerate_version("<", "1.13.0.dev0")
|
||||
):
|
||||
return False
|
||||
|
||||
return hasattr(module, "_hf_hook") and (
|
||||
@@ -523,9 +526,10 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {'4bit' if is_loaded_in_4bit_bnb else '8bit'} and conversion to {dtype} is not supported. Module is still in {'4bit' if is_loaded_in_4bit_bnb else '8bit'} precision."
|
||||
)
|
||||
|
||||
if is_loaded_in_8bit_bnb and device is not None:
|
||||
if is_loaded_in_8bit_bnb and device is not None and is_bitsandbytes_version("<", "0.48.0"):
|
||||
logger.warning(
|
||||
f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` 8bit and moving it to {device} via `.to()` is not supported. Module is still on {module.device}."
|
||||
"You need to upgrade bitsandbytes to at least 0.48.0"
|
||||
)
|
||||
|
||||
# Note: we also handle this at the ModelMixin level. The reason for doing it here too is that modeling
|
||||
@@ -542,6 +546,14 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
# https://github.com/huggingface/transformers/pull/33122. So, we guard this accordingly.
|
||||
if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"):
|
||||
module.to(device=device)
|
||||
# added here https://github.com/huggingface/transformers/pull/43258
|
||||
if (
|
||||
is_loaded_in_8bit_bnb
|
||||
and device is not None
|
||||
and is_transformers_version(">", "4.58.0")
|
||||
and is_bitsandbytes_version(">=", "0.48.0")
|
||||
):
|
||||
module.to(device=device)
|
||||
elif not is_loaded_in_4bit_bnb and not is_loaded_in_8bit_bnb and not is_group_offloaded:
|
||||
module.to(device, dtype)
|
||||
|
||||
@@ -1223,7 +1235,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
# This is because the model would already be placed on a CUDA device.
|
||||
_, _, is_loaded_in_8bit_bnb = _check_bnb_status(model)
|
||||
if is_loaded_in_8bit_bnb:
|
||||
if is_loaded_in_8bit_bnb and (
|
||||
is_transformers_version("<", "4.58.0") or is_bitsandbytes_version("<", "0.48.0")
|
||||
):
|
||||
logger.info(
|
||||
f"Skipping the hook placement for the {model.__class__.__name__} as it is loaded in `bitsandbytes` 8bit."
|
||||
)
|
||||
|
||||
@@ -982,21 +982,6 @@ class FluxTransformer2DModel(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class GlmImageTransformer2DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class HiDreamImageTransformer2DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -1142,21 +1142,6 @@ class FluxPriorReduxPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class GlmImagePipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class HiDreamImagePipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -1,228 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, GlmImagePipeline, GlmImageTransformer2DModel
|
||||
from diffusers.utils import is_transformers_version
|
||||
|
||||
from ...testing_utils import enable_full_determinism, require_torch_accelerator, require_transformers_version_greater
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
if is_transformers_version(">=", "4.57.4"):
|
||||
from transformers import GlmImageConfig, GlmImageForConditionalGeneration, GlmImageProcessor
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
@require_transformers_version_greater("4.57.4")
|
||||
@require_torch_accelerator
|
||||
class GlmImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = GlmImagePipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "negative_prompt"}
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
required_optional_params = frozenset(
|
||||
[
|
||||
"num_inference_steps",
|
||||
"generator",
|
||||
"latents",
|
||||
"return_dict",
|
||||
"callback_on_step_end",
|
||||
"callback_on_step_end_tensor_inputs",
|
||||
]
|
||||
)
|
||||
test_xformers_attention = False
|
||||
test_attention_slicing = False
|
||||
supports_dduf = False
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
glm_config = GlmImageConfig(
|
||||
text_config={
|
||||
"vocab_size": 168064,
|
||||
"hidden_size": 32,
|
||||
"intermediate_size": 32,
|
||||
"num_hidden_layers": 2,
|
||||
"num_attention_heads": 2,
|
||||
"num_key_value_heads": 2,
|
||||
"max_position_embeddings": 512,
|
||||
"vision_vocab_size": 128,
|
||||
"rope_parameters": {"mrope_section": (4, 2, 2)},
|
||||
},
|
||||
vision_config={
|
||||
"depth": 2,
|
||||
"hidden_size": 32,
|
||||
"num_heads": 2,
|
||||
"image_size": 32,
|
||||
"patch_size": 8,
|
||||
"intermediate_size": 32,
|
||||
},
|
||||
vq_config={"embed_dim": 32, "num_embeddings": 128, "latent_channels": 32},
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
vision_language_encoder = GlmImageForConditionalGeneration(glm_config)
|
||||
|
||||
# TODO: move to a public checkpoint
|
||||
processor = GlmImageProcessor.from_pretrained("ZP2Test/GLM-Image", subfolder="processor")
|
||||
|
||||
torch.manual_seed(0)
|
||||
# For GLM-Image, the relationship between components must satisfy:
|
||||
# patch_size × vae_scale_factor = 16 (since AR tokens are upsampled 2× from d32)
|
||||
transformer = GlmImageTransformer2DModel(
|
||||
patch_size=2,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
num_layers=2,
|
||||
attention_head_dim=8,
|
||||
num_attention_heads=2,
|
||||
text_embed_dim=text_encoder.config.hidden_size,
|
||||
time_embed_dim=16,
|
||||
condition_dim=8,
|
||||
prior_vq_quantizer_codebook_size=128,
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKL(
|
||||
block_out_channels=(4, 8, 16, 16),
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
latent_channels=4,
|
||||
norm_num_groups=4,
|
||||
sample_size=128,
|
||||
latents_mean=[0.0] * 4,
|
||||
latents_std=[1.0] * 4,
|
||||
)
|
||||
|
||||
scheduler = FlowMatchEulerDiscreteScheduler()
|
||||
|
||||
components = {
|
||||
"tokenizer": tokenizer,
|
||||
"processor": processor,
|
||||
"text_encoder": text_encoder,
|
||||
"vision_language_encoder": vision_language_encoder,
|
||||
"vae": vae,
|
||||
"transformer": transformer,
|
||||
"scheduler": scheduler,
|
||||
}
|
||||
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
|
||||
height, width = 32, 32
|
||||
|
||||
inputs = {
|
||||
"prompt": "A photo of a cat",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 1.5,
|
||||
"height": height,
|
||||
"width": width,
|
||||
"max_sequence_length": 16,
|
||||
"output_type": "pt",
|
||||
}
|
||||
|
||||
return inputs
|
||||
|
||||
def test_inference(self):
|
||||
device = "cpu"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = pipe(**inputs).images[0]
|
||||
generated_slice = image.flatten()
|
||||
generated_slice = np.concatenate([generated_slice[:8], generated_slice[-8:]])
|
||||
|
||||
# fmt: off
|
||||
expected_slice = np.array(
|
||||
[
|
||||
0.5796329, 0.5005878, 0.45881274, 0.45331675, 0.43688118, 0.4899527, 0.54017603, 0.50983673, 0.3387968, 0.38074082, 0.29942477, 0.33733928, 0.3672544, 0.38462338, 0.40991822, 0.46641728
|
||||
]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
self.assertEqual(image.shape, (3, 32, 32))
|
||||
self.assertTrue(np.allclose(expected_slice, generated_slice, atol=1e-4, rtol=1e-4))
|
||||
|
||||
@unittest.skip("Not supported.")
|
||||
def test_inference_batch_single_identical(self):
|
||||
# GLM-Image has batch_size=1 constraint due to AR model
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported.")
|
||||
def test_inference_batch_consistent(self):
|
||||
# GLM-Image has batch_size=1 constraint due to AR model
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported.")
|
||||
def test_num_images_per_prompt(self):
|
||||
# GLM-Image has batch_size=1 constraint due to AR model
|
||||
pass
|
||||
|
||||
@unittest.skip("Needs to be revisited.")
|
||||
def test_encode_prompt_works_in_isolation(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Needs to be revisited.")
|
||||
def test_pipeline_level_group_offloading_inference(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
"Follow set of tests are relaxed because this pipeline doesn't guarantee same outputs for the same inputs in consecutive runs."
|
||||
)
|
||||
def test_dict_tuple_outputs_equivalent(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Skipped")
|
||||
def test_cpu_offload_forward_pass_twice(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Skipped")
|
||||
def test_sequential_offload_forward_pass_twice(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Skipped")
|
||||
def test_float16_inference(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Skipped")
|
||||
def test_save_load_float16(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Skipped")
|
||||
def test_save_load_local(self):
|
||||
pass
|
||||
@@ -288,31 +288,29 @@ class BnB8bitBasicTests(Base8bitTests):
|
||||
self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params)
|
||||
self.assertTrue(hasattr(linear.weight, "SCB"))
|
||||
|
||||
@require_bitsandbytes_version_greater("0.48.0")
|
||||
def test_device_and_dtype_assignment(self):
|
||||
r"""
|
||||
Test whether trying to cast (or assigning a device to) a model after converting it in 8-bit will throw an error.
|
||||
Checks also if other models are casted correctly.
|
||||
"""
|
||||
with self.assertRaises(ValueError):
|
||||
# Tries with `str`
|
||||
self.model_8bit.to("cpu")
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
# Tries with a `dtype``
|
||||
self.model_8bit.to(torch.float16)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
# Tries with a `device`
|
||||
self.model_8bit.to(torch.device(f"{torch_device}:0"))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
# Tries with a `device`
|
||||
self.model_8bit.float()
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
# Tries with a `device`
|
||||
# Tries with a `dtype`
|
||||
self.model_8bit.half()
|
||||
|
||||
# This should work with 0.48.0
|
||||
self.model_8bit.to("cpu")
|
||||
self.model_8bit.to(torch.device(f"{torch_device}:0"))
|
||||
|
||||
# Test if we did not break anything
|
||||
self.model_fp16 = self.model_fp16.to(dtype=torch.float32, device=torch_device)
|
||||
input_dict_for_transformer = self.get_dummy_inputs()
|
||||
@@ -837,7 +835,7 @@ class BaseBnb8bitSerializationTests(Base8bitTests):
|
||||
|
||||
|
||||
@require_torch_version_greater_equal("2.6.0")
|
||||
@require_bitsandbytes_version_greater("0.45.5")
|
||||
@require_bitsandbytes_version_greater("0.48.0")
|
||||
class Bnb8BitCompileTests(QuantCompileTests, unittest.TestCase):
|
||||
@property
|
||||
def quantization_config(self):
|
||||
@@ -848,7 +846,7 @@ class Bnb8BitCompileTests(QuantCompileTests, unittest.TestCase):
|
||||
)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Test fails because of an offloading problem from Accelerate with confusion in hooks."
|
||||
reason="Test fails because of a type change when recompiling."
|
||||
" Test passes without recompilation context manager. Refer to https://github.com/huggingface/diffusers/pull/12002/files#r2240462757 for details."
|
||||
)
|
||||
def test_torch_compile(self):
|
||||
@@ -858,6 +856,5 @@ class Bnb8BitCompileTests(QuantCompileTests, unittest.TestCase):
|
||||
def test_torch_compile_with_cpu_offload(self):
|
||||
super()._test_torch_compile_with_cpu_offload(torch_dtype=torch.float16)
|
||||
|
||||
@pytest.mark.xfail(reason="Test fails because of an offloading problem from Accelerate with confusion in hooks.")
|
||||
def test_torch_compile_with_group_offload_leaf(self):
|
||||
super()._test_torch_compile_with_group_offload_leaf(torch_dtype=torch.float16, use_stream=True)
|
||||
|
||||
Reference in New Issue
Block a user