Compare commits

...

14 Commits

Author          SHA1        Message                                        Date
yiyixuxu        cfe19a31b9  up                                             2026-01-08 03:15:17 +01:00
zRzRzRzRzRzRzR  170d0ba160  remove useless func                            2026-01-08 00:04:24 +08:00
zRzRzRzRzRzRzR  1cf277d36d  remove sop                                     2026-01-07 23:51:40 +08:00
zRzRzRzRzRzRzR  e2b31f8b15  Update pipeline_glm_image.py                   2026-01-07 22:55:16 +08:00
Yuxuan Zhang    acd13d8769  Merge branch 'huggingface:main' into cogview   2026-01-07 19:00:28 +08:00
zRzRzRzRzRzRzR  b3d1b5547b  merge2pipeline                                 2026-01-07 18:59:30 +08:00
zRzRzRzRzRzRzR  22fe6c9023  init with encoder                              2026-01-07 18:15:01 +08:00
zRzRzRzRzRzRzR  ec678a1fb7  update                                         2026-01-07 17:16:39 +08:00
zRzRzRzRzRzRzR  adcc53206b  2                                              2026-01-07 17:07:13 +08:00
zRzRzRzRzRzRzR  e13fb76552  rename                                         2026-01-07 16:55:02 +08:00
zRzRzRzRzRzRzR  bcc9c303f6  Update __init__.py                             2026-01-07 15:59:47 +08:00
zRzRzRzRzRzRzR  57fd26d8fe  add 1                                          2026-01-07 15:41:55 +08:00
zRzRzRzRzRzRzR  b98decfe5f  add                                            2026-01-07 15:35:56 +08:00
zRzRzRzRzRzRzR  ec9a82fc3f  init                                           2026-01-07 15:31:11 +08:00
13 changed files with 1625 additions and 0 deletions

View File: docs/source/en/_toctree.yml

@@ -353,6 +353,8 @@
title: Flux2Transformer2DModel
- local: api/models/flux_transformer
title: FluxTransformer2DModel
- local: api/models/glm_image_transformer2d
title: GlmImageTransformer2DModel
- local: api/models/hidream_image_transformer
title: HiDreamImageTransformer2DModel
- local: api/models/hunyuan_transformer2d
@@ -541,6 +543,8 @@
title: Flux2
- local: api/pipelines/control_flux_inpaint
title: FluxControlInpaint
- local: api/pipelines/glm_image
title: GLM-Image
- local: api/pipelines/hidream
title: HiDream-I1
- local: api/pipelines/hunyuandit

View File: docs/source/en/api/models/glm_image_transformer2d.md

@@ -0,0 +1,18 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# GlmImageTransformer2DModel
A Diffusion Transformer model for 2D data from [GLM-Image](https://huggingface.co/zai-org).
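A minimal loading sketch. The `zai-org/GLM-Image` repo id is taken from the pipeline docs in this PR; the `transformer` subfolder is an assumption about the checkpoint layout and may differ in the published weights.

```python
import torch

from diffusers import GlmImageTransformer2DModel

# Load only the DiT component in bfloat16 (hypothetical checkpoint layout).
transformer = GlmImageTransformer2DModel.from_pretrained(
    "zai-org/GLM-Image", subfolder="transformer", torch_dtype=torch.bfloat16
)
```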
## GlmImageTransformer2DModel
[[autodoc]] GlmImageTransformer2DModel

View File: docs/source/en/api/pipelines/glm_image.md

@@ -0,0 +1,31 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-->
# GLM-Image
> [!TIP]
> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/zai-org). The original weights can be found under [hf.co/zai-org](https://huggingface.co/zai-org).
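A minimal text-to-image sketch, based on the example docstring added in `pipeline_glm_image.py` in this PR. The `zai-org/GLM-Image` repo id comes from that example and may change before release; the 1152x768 resolution is an illustrative choice.

```python
import torch

from diffusers import GlmImagePipeline

pipe = GlmImagePipeline.from_pretrained("zai-org/GLM-Image", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# The AR prior rounds height/width down to multiples of 32.
prompt = "A photo of an astronaut riding a horse on mars"
image = pipe(prompt, height=1152, width=768).images[0]
image.save("output.png")
```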
## GlmImagePipeline
[[autodoc]] GlmImagePipeline
- all
- __call__
## GlmImagePipelineOutput
[[autodoc]] pipelines.glm_image.pipeline_output.GlmImagePipelineOutput

View File: src/diffusers/__init__.py

@@ -223,6 +223,7 @@ else:
"FluxControlNetModel",
"FluxMultiControlNetModel",
"FluxTransformer2DModel",
"GlmImageTransformer2DModel",
"HiDreamImageTransformer2DModel",
"HunyuanDiT2DControlNetModel",
"HunyuanDiT2DModel",
@@ -487,6 +488,7 @@ else:
"FluxKontextPipeline",
"FluxPipeline",
"FluxPriorReduxPipeline",
"GlmImagePipeline",
"HiDreamImagePipeline",
"HunyuanDiTControlNetPipeline",
"HunyuanDiTPAGPipeline",
@@ -969,6 +971,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxControlNetModel,
FluxMultiControlNetModel,
FluxTransformer2DModel,
GlmImageTransformer2DModel,
HiDreamImageTransformer2DModel,
HunyuanDiT2DControlNetModel,
HunyuanDiT2DModel,
@@ -1203,6 +1206,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxKontextPipeline,
FluxPipeline,
FluxPriorReduxPipeline,
GlmImagePipeline,
HiDreamImagePipeline,
HunyuanDiTControlNetPipeline,
HunyuanDiTPAGPipeline,

View File: src/diffusers/models/__init__.py

@@ -96,6 +96,7 @@ if is_torch_available():
_import_structure["transformers.transformer_easyanimate"] = ["EasyAnimateTransformer3DModel"]
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
_import_structure["transformers.transformer_flux2"] = ["Flux2Transformer2DModel"]
_import_structure["transformers.transformer_glm_image"] = ["GlmImageTransformer2DModel"]
_import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"]
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
_import_structure["transformers.transformer_hunyuan_video15"] = ["HunyuanVideo15Transformer3DModel"]
@@ -203,6 +204,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
EasyAnimateTransformer3DModel,
Flux2Transformer2DModel,
FluxTransformer2DModel,
GlmImageTransformer2DModel,
HiDreamImageTransformer2DModel,
HunyuanDiT2DModel,
HunyuanImageTransformer2DModel,

View File: src/diffusers/models/embeddings.py

@@ -1658,6 +1658,37 @@ class CogView3CombinedTimestepSizeEmbeddings(nn.Module):
return conditioning
class GlmImageCombinedTimestepSizeEmbeddings(nn.Module):
def __init__(self, embedding_dim: int, condition_dim: int, pooled_projection_dim: int, timesteps_dim: int = 256):
super().__init__()
self.time_proj = Timesteps(num_channels=timesteps_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
self.condition_proj = Timesteps(num_channels=condition_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=timesteps_dim, time_embed_dim=embedding_dim)
self.condition_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
def forward(
self,
timestep: torch.Tensor,
target_size: torch.Tensor,
crop_coords: torch.Tensor,
hidden_dtype: torch.dtype,
) -> torch.Tensor:
timesteps_proj = self.time_proj(timestep)
crop_coords_proj = self.condition_proj(crop_coords.flatten()).view(crop_coords.size(0), -1)
target_size_proj = self.condition_proj(target_size.flatten()).view(target_size.size(0), -1)
# (B, 2 * condition_dim)
condition_proj = torch.cat([crop_coords_proj, target_size_proj], dim=1)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (B, embedding_dim)
condition_emb = self.condition_embedder(condition_proj.to(dtype=hidden_dtype)) # (B, embedding_dim)
conditioning = timesteps_emb + condition_emb
return conditioning
class HunyuanDiTAttentionPool(nn.Module):
# Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6

View File: src/diffusers/models/transformers/__init__.py

@@ -27,6 +27,7 @@ if is_torch_available():
from .transformer_easyanimate import EasyAnimateTransformer3DModel
from .transformer_flux import FluxTransformer2DModel
from .transformer_flux2 import Flux2Transformer2DModel
from .transformer_glm_image import GlmImageTransformer2DModel
from .transformer_hidream_image import HiDreamImageTransformer2DModel
from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
from .transformer_hunyuan_video15 import HunyuanVideo15Transformer3DModel

View File: src/diffusers/models/transformers/transformer_glm_image.py

@@ -0,0 +1,567 @@
# Copyright 2025 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import Attention
from ..cache_utils import CacheMixin
from ..embeddings import GlmImageCombinedTimestepSizeEmbeddings
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import LayerNorm, RMSNorm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class GlmImageImageProjector(nn.Module):
def __init__(
self,
in_channels: int = 16,
hidden_size: int = 2560,
patch_size: int = 2,
):
super().__init__()
self.patch_size = patch_size
self.proj = nn.Linear(in_channels * patch_size**2, hidden_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, channel, height, width = hidden_states.shape
post_patch_height = height // self.patch_size
post_patch_width = width // self.patch_size
hidden_states = hidden_states.reshape(
batch_size, channel, post_patch_height, self.patch_size, post_patch_width, self.patch_size
)
hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5).flatten(3, 5).flatten(1, 2)
hidden_states = self.proj(hidden_states)
return hidden_states
class GlmImageAdaLayerNormZero(nn.Module):
def __init__(self, embedding_dim: int, dim: int) -> None:
super().__init__()
self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
self.norm_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
self.linear = nn.Linear(embedding_dim, 12 * dim, bias=True)
def forward(
self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
dtype = hidden_states.dtype
norm_hidden_states = self.norm(hidden_states).to(dtype=dtype)
norm_encoder_hidden_states = self.norm_context(encoder_hidden_states).to(dtype=dtype)
emb = self.linear(temb)
(
shift_msa,
c_shift_msa,
scale_msa,
c_scale_msa,
gate_msa,
c_gate_msa,
shift_mlp,
c_shift_mlp,
scale_mlp,
c_scale_mlp,
gate_mlp,
c_gate_mlp,
) = emb.chunk(12, dim=1)
hidden_states = norm_hidden_states * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_msa.unsqueeze(1)) + c_shift_msa.unsqueeze(1)
return (
hidden_states,
gate_msa,
shift_mlp,
scale_mlp,
gate_mlp,
encoder_hidden_states,
c_gate_msa,
c_shift_mlp,
c_scale_mlp,
c_gate_mlp,
)
class GlmImageLayerKVCache:
"""KV cache for GlmImage model."""
def __init__(self):
self.k_cache = None
self.v_cache = None
self.mode: Optional[str] = None # "write", "read", "skip"
def store(self, k: torch.Tensor, v: torch.Tensor):
if self.k_cache is None:
self.k_cache = k
self.v_cache = v
else:
self.k_cache = torch.cat([self.k_cache, k], dim=2)
self.v_cache = torch.cat([self.v_cache, v], dim=2)
def get(self):
return self.k_cache, self.v_cache
def clear(self):
self.k_cache = None
self.v_cache = None
self.mode = None
class GlmImageKVCache:
"""Container for all layers' KV caches."""
def __init__(self, num_layers: int):
self.num_layers = num_layers
self.caches = [GlmImageLayerKVCache() for _ in range(num_layers)]
def __getitem__(self, layer_idx: int) -> GlmImageLayerKVCache:
return self.caches[layer_idx]
def set_mode(self, mode: Optional[str]):
if mode is not None and mode not in ["write", "read", "skip"]:
raise ValueError(f"Invalid mode: {mode}, must be one of 'write', 'read', 'skip'")
for cache in self.caches:
cache.mode = mode
def clear(self):
for cache in self.caches:
cache.clear()
class GlmImageAttnProcessor:
"""
Processor for implementing scaled dot-product attention for the GlmImage model. It applies a rotary embedding on
query and key vectors, but does not include spatial normalization.
The processor supports passing an attention mask for text tokens. The attention mask should have shape (batch_size,
text_seq_length) where 1 indicates a non-padded token and 0 indicates a padded token.
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("GlmImageAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
kv_cache: Optional[GlmImageLayerKVCache] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
dtype = encoder_hidden_states.dtype
batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape
batch_size, image_seq_length, embed_dim = hidden_states.shape
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
# 1. QKV projections
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
# 2. QK normalization
if attn.norm_q is not None:
query = attn.norm_q(query).to(dtype=dtype)
if attn.norm_k is not None:
key = attn.norm_k(key).to(dtype=dtype)
# 3. Rotational positional embeddings applied to latent stream
if image_rotary_emb is not None:
from ..embeddings import apply_rotary_emb
query[:, :, text_seq_length:, :] = apply_rotary_emb(
query[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
)
key[:, :, text_seq_length:, :] = apply_rotary_emb(
key[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
)
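# KV-cache modes (set by the pipeline): "write" stores this pass's key/value (used when encoding
# condition images), "read" prepends the cached condition-image key/value to the current ones, and
# "skip" leaves the cache untouched (used for the unconditional CFG branch).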
if kv_cache is not None:
if kv_cache.mode == "write":
kv_cache.store(key, value)
elif kv_cache.mode == "read":
k_cache, v_cache = kv_cache.get()
key = torch.cat([k_cache, key], dim=2) if k_cache is not None else key
value = torch.cat([v_cache, value], dim=2) if v_cache is not None else value
elif kv_cache.mode == "skip":
pass
# 4. Attention
if attention_mask is not None:
text_attn_mask = attention_mask
assert text_attn_mask.dim() == 2, "the shape of text_attn_mask should be (batch_size, text_seq_length)"
text_attn_mask = text_attn_mask.float().to(query.device)
mix_attn_mask = torch.ones((batch_size, text_seq_length + image_seq_length), device=query.device)
mix_attn_mask[:, :text_seq_length] = text_attn_mask
mix_attn_mask = mix_attn_mask.unsqueeze(2)
attn_mask_matrix = mix_attn_mask @ mix_attn_mask.transpose(1, 2)
attention_mask = (attn_mask_matrix > 0).unsqueeze(1).to(query.dtype)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
hidden_states = hidden_states.type_as(query)
# 5. Output projection
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states, hidden_states = hidden_states.split(
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
)
return hidden_states, encoder_hidden_states
@maybe_allow_in_graph
class GlmImageTransformerBlock(nn.Module):
def __init__(
self,
dim: int = 2560,
num_attention_heads: int = 64,
attention_head_dim: int = 40,
time_embed_dim: int = 512,
) -> None:
super().__init__()
# 1. Attention
self.norm1 = GlmImageAdaLayerNormZero(time_embed_dim, dim)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
out_dim=dim,
bias=True,
qk_norm="layer_norm",
elementwise_affine=False,
eps=1e-5,
processor=GlmImageAttnProcessor(),
)
# 2. Feedforward
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[
Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
] = None,
attention_mask: Optional[Dict[str, torch.Tensor]] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
kv_cache: Optional[GlmImageLayerKVCache] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# 1. Timestep conditioning
(
norm_hidden_states,
gate_msa,
shift_mlp,
scale_mlp,
gate_mlp,
norm_encoder_hidden_states,
c_gate_msa,
c_shift_mlp,
c_scale_mlp,
c_gate_mlp,
) = self.norm1(hidden_states, encoder_hidden_states, temb)
# 2. Attention
if attention_kwargs is None:
attention_kwargs = {}
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
attention_mask=attention_mask,
kv_cache=kv_cache,
**attention_kwargs,
)
hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1)
encoder_hidden_states = encoder_hidden_states + attn_encoder_hidden_states * c_gate_msa.unsqueeze(1)
# 3. Feedforward
norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) * (
1 + c_scale_mlp.unsqueeze(1)
) + c_shift_mlp.unsqueeze(1)
ff_output = self.ff(norm_hidden_states)
ff_output_context = self.ff(norm_encoder_hidden_states)
hidden_states = hidden_states + ff_output * gate_mlp.unsqueeze(1)
encoder_hidden_states = encoder_hidden_states + ff_output_context * c_gate_mlp.unsqueeze(1)
return hidden_states, encoder_hidden_states
class GlmImageRotaryPosEmbed(nn.Module):
def __init__(self, dim: int, patch_size: int, theta: float = 10000.0) -> None:
super().__init__()
self.dim = dim
self.patch_size = patch_size
self.theta = theta
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size, num_channels, height, width = hidden_states.shape
height, width = height // self.patch_size, width // self.patch_size
dim_h, dim_w = self.dim // 2, self.dim // 2
h_inv_freq = 1.0 / (
self.theta ** (torch.arange(0, dim_h, 2, dtype=torch.float32)[: (dim_h // 2)].float() / dim_h)
)
w_inv_freq = 1.0 / (
self.theta ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: (dim_w // 2)].float() / dim_w)
)
h_seq = torch.arange(height)
w_seq = torch.arange(width)
freqs_h = torch.outer(h_seq, h_inv_freq)
freqs_w = torch.outer(w_seq, w_inv_freq)
# Create position matrices for height and width
# [height, 1, dim//4] and [1, width, dim//4]
freqs_h = freqs_h.unsqueeze(1)
freqs_w = freqs_w.unsqueeze(0)
# Broadcast freqs_h and freqs_w to [height, width, dim//4]
freqs_h = freqs_h.expand(height, width, -1)
freqs_w = freqs_w.expand(height, width, -1)
# Concatenate along last dimension to get [height, width, dim//2]
freqs = torch.cat([freqs_h, freqs_w], dim=-1)
freqs = torch.cat([freqs, freqs], dim=-1) # [height, width, dim]
freqs = freqs.reshape(height * width, -1)
return (freqs.cos(), freqs.sin())
class GlmImageAdaLayerNormContinuous(nn.Module):
"""
GlmImage-only final AdaLN: LN(x) -> Linear(cond) -> chunk -> affine. Matches Megatron: **no activation** before the
Linear on conditioning embedding.
"""
def __init__(
self,
embedding_dim: int,
conditioning_embedding_dim: int,
elementwise_affine: bool = True,
eps: float = 1e-5,
bias: bool = True,
norm_type: str = "layer_norm",
):
super().__init__()
self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
if norm_type == "layer_norm":
self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
elif norm_type == "rms_norm":
self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
else:
raise ValueError(f"unknown norm_type {norm_type}")
def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
# *** NO SiLU here ***
emb = self.linear(conditioning_embedding.to(x.dtype))
scale, shift = torch.chunk(emb, 2, dim=1)
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
return x
class GlmImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
r"""
Args:
patch_size (`int`, defaults to `2`):
The size of the patches to use in the patch embedding layer.
in_channels (`int`, defaults to `16`):
The number of channels in the input.
num_layers (`int`, defaults to `30`):
The number of layers of Transformer blocks to use.
attention_head_dim (`int`, defaults to `40`):
The number of channels in each head.
num_attention_heads (`int`, defaults to `64`):
The number of heads to use for multi-head attention.
out_channels (`int`, defaults to `16`):
The number of channels in the output.
text_embed_dim (`int`, defaults to `1472`):
Input dimension of text embeddings from the text encoder.
time_embed_dim (`int`, defaults to `512`):
Output dimension of timestep embeddings.
condition_dim (`int`, defaults to `256`):
The embedding dimension of the input SDXL-style resolution conditions (original_size, target_size,
crop_coords).
pos_embed_max_size (`int`, defaults to `128`):
The maximum resolution of the positional embeddings, from which slices of shape `H x W` are taken and added
to input patched latents, where `H` and `W` are the latent height and width respectively. A value of 128
means that the maximum supported height and width for image generation is `128 * vae_scale_factor *
patch_size => 128 * 8 * 2 => 2048`.
sample_size (`int`, defaults to `128`):
The base resolution of input latents. If height/width is not provided during generation, this value is used
to determine the resolution as `sample_size * vae_scale_factor => 128 * 8 => 1024`
"""
_supports_gradient_checkpointing = True
_no_split_modules = [
"GlmImageTransformerBlock",
"GlmImageImageProjector",
"GlmImageImageProjector",
]
_skip_layerwise_casting_patterns = ["patch_embed", "norm", "proj_out"]
@register_to_config
def __init__(
self,
patch_size: int = 2,
in_channels: int = 16,
out_channels: int = 16,
num_layers: int = 30,
attention_head_dim: int = 40,
num_attention_heads: int = 64,
text_embed_dim: int = 1472,
time_embed_dim: int = 512,
condition_dim: int = 256,
prior_vq_quantizer_codebook_size: int = 16384,
):
super().__init__()
# GlmImage uses 2 additional SDXL-like conditions - target_size, crop_coords
# Each of these are sincos embeddings of shape 2 * condition_dim
pooled_projection_dim = 2 * 2 * condition_dim
inner_dim = num_attention_heads * attention_head_dim
out_channels = out_channels
# 1. RoPE
self.rope = GlmImageRotaryPosEmbed(attention_head_dim, patch_size, theta=10000.0)
# 2. Patch & Text-timestep embedding
self.image_projector = GlmImageImageProjector(in_channels, inner_dim, patch_size)
self.glyph_projector = FeedForward(text_embed_dim, inner_dim, inner_dim=inner_dim, activation_fn="gelu")
self.prior_token_embedding = nn.Embedding(prior_vq_quantizer_codebook_size, inner_dim)
self.prior_projector = FeedForward(inner_dim, inner_dim, inner_dim=inner_dim, activation_fn="linear-silu")
self.time_condition_embed = GlmImageCombinedTimestepSizeEmbeddings(
embedding_dim=time_embed_dim,
condition_dim=condition_dim,
pooled_projection_dim=pooled_projection_dim,
timesteps_dim=time_embed_dim,
)
# 3. Transformer blocks
self.transformer_blocks = nn.ModuleList(
[
GlmImageTransformerBlock(inner_dim, num_attention_heads, attention_head_dim, time_embed_dim)
for _ in range(num_layers)
]
)
# 4. Output projection
self.norm_out = GlmImageAdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False)
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels, bias=True)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
prior_token_id: torch.Tensor,
prior_token_drop: torch.Tensor,
timestep: torch.LongTensor,
target_size: torch.Tensor,
crop_coords: torch.Tensor,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
attention_mask: Optional[torch.Tensor] = None,
kv_caches: Optional[GlmImageKVCache] = None,
image_rotary_emb: Optional[
Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
] = None,
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
batch_size, num_channels, height, width = hidden_states.shape
# 1. RoPE
if image_rotary_emb is None:
image_rotary_emb = self.rope(hidden_states)
# 2. Patch & Timestep embeddings
p = self.config.patch_size
post_patch_height = height // p
post_patch_width = width // p
hidden_states = self.image_projector(hidden_states)
encoder_hidden_states = self.glyph_projector(encoder_hidden_states)
prior_embedding = self.prior_token_embedding(prior_token_id)
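# Zero out prior-token embeddings flagged as dropped; the pipeline drops all of them for the
# unconditional CFG branch.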
prior_embedding[prior_token_drop] *= 0.0
prior_hidden_states = self.prior_projector(prior_embedding)
hidden_states = hidden_states + prior_hidden_states
temb = self.time_condition_embed(timestep, target_size, crop_coords, hidden_states.dtype)
temb = F.silu(temb)
# 3. Transformer blocks
for idx, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
attention_mask,
attention_kwargs,
kv_caches[idx] if kv_caches is not None else None,
)
else:
hidden_states, encoder_hidden_states = block(
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
attention_mask,
attention_kwargs,
kv_cache=kv_caches[idx] if kv_caches is not None else None,
)
# 4. Output norm & projection
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)
# 5. Unpatchify
hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, -1, p, p)
output = hidden_states.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)

View File: src/diffusers/pipelines/auto_pipeline.py

@@ -52,6 +52,7 @@ from .flux import (
FluxKontextPipeline,
FluxPipeline,
)
from .glm_image import GlmImagePipeline
from .hunyuandit import HunyuanDiTPipeline
from .kandinsky import (
KandinskyCombinedPipeline,
@@ -167,6 +168,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
("chroma", ChromaPipeline),
("cogview3", CogView3PlusPipeline),
("cogview4", CogView4Pipeline),
("glm_image", GlmImagePipeline),
("cogview4-control", CogView4ControlPipeline),
("qwenimage", QwenImagePipeline),
("qwenimage-controlnet", QwenImageControlNetPipeline),

View File: src/diffusers/pipelines/glm_image/__init__.py

@@ -0,0 +1,47 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_additional_imports = {}
_import_structure = {"pipeline_output": ["GlmImagePipelineOutput"]}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_glm_image"] = ["GlmImagePipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .pipeline_glm_image import GlmImagePipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
for name, value in _additional_imports.items():
setattr(sys.modules[__name__], name, value)

View File: src/diffusers/pipelines/glm_image/pipeline_glm_image.py

@@ -0,0 +1,882 @@
# Copyright 2025 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import re
from math import sqrt
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import PIL
import torch
from transformers import ByT5Tokenizer, GlmImageForConditionalGeneration, GlmImageProcessor, T5EncoderModel
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import VaeImageProcessor
from ...loaders import CogView4LoraLoaderMixin
from ...models import AutoencoderKL, GlmImageTransformer2DModel
from ...models.transformers.transformer_glm_image import GlmImageKVCache
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from .pipeline_output import GlmImagePipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```python
>>> import torch
>>> from diffusers import GlmImagePipeline
>>> pipe = GlmImagePipeline.from_pretrained("zai-org/GLM-Image", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> prompt = "A photo of an astronaut riding a horse on mars<sop>36 24<eop>"
>>> image = pipe(prompt).images[0]
>>> image.save("output.png")
```
"""
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
base_shift: float = 0.25,
max_shift: float = 0.75,
) -> float:
m = (image_seq_len / base_seq_len) ** 0.5
mu = m * max_shift + base_shift
return mu
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
"""
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
accepts_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if timesteps is not None and sigmas is not None:
if not accepts_timesteps and not accepts_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep or sigma schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif timesteps is not None and sigmas is None:
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif timesteps is None and sigmas is not None:
if not accepts_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
r"""
Pipeline for text-to-image generation using GLM-Image.
This pipeline integrates both the AR (autoregressive) model for token generation and the DiT (diffusion
transformer) model for image decoding.
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`T5EncoderModel`]):
Frozen text-encoder for glyph embeddings.
tokenizer (`PreTrainedTokenizer`):
Tokenizer for the text encoder.
processor (`AutoProcessor`):
Processor for the AR model to handle chat templates and tokenization.
vision_language_encoder ([`GlmImageForConditionalGeneration`]):
The AR model that generates image tokens from text prompts.
transformer ([`GlmImageTransformer2DModel`]):
A text conditioned transformer to denoise the encoded image latents (DiT).
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
"""
_optional_components = []
model_cpu_offload_seq = "transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds"]
def __init__(
self,
tokenizer: ByT5Tokenizer,
processor: GlmImageProcessor,
text_encoder: T5EncoderModel,
vision_language_encoder: GlmImageForConditionalGeneration,
vae: AutoencoderKL,
transformer: GlmImageTransformer2DModel,
scheduler: FlowMatchEulerDiscreteScheduler,
):
super().__init__()
self.register_modules(
tokenizer=tokenizer,
processor=processor,
text_encoder=text_encoder,
vision_language_encoder=vision_language_encoder,
vae=vae,
transformer=transformer,
scheduler=scheduler,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.default_sample_size = (
self.transformer.config.sample_size
if hasattr(self, "transformer")
and self.transformer is not None
and hasattr(self.transformer.config, "sample_size")
else 128
)
def _build_image_grid_thw(
self,
token_h: int,
token_w: int,
prev_token_h: int,
prev_token_w: int,
existing_grid: Optional[torch.Tensor] = None,
device: Optional[torch.device] = None,
) -> torch.Tensor:
if existing_grid is None or existing_grid.numel() == 0:
return torch.tensor(
[
[1, token_h, token_w],
[1, prev_token_h, prev_token_w],
],
device=device,
)
else:
return torch.cat([existing_grid.to(device), torch.tensor([[1, token_h, token_w]], device=device)], dim=0)
def _calculate_ar_generation_params(
self, token_h: int, token_w: int, prev_token_h: int, prev_token_w: int, is_text_to_image: bool
) -> Tuple[int, int]:
"""
Calculate max_new_tokens and large_image_start_offset for AR generation.
"""
large_image_tokens = token_h * token_w
small_image_tokens = prev_token_h * prev_token_w
if is_text_to_image:
max_new_tokens = small_image_tokens + large_image_tokens + 1
large_image_start_offset = small_image_tokens
else:
max_new_tokens = large_image_tokens + 1
large_image_start_offset = 0
return max_new_tokens, large_image_start_offset
def _extract_large_image_tokens(
self, outputs: torch.Tensor, input_length: int, large_image_start_offset: int, large_image_tokens: int
) -> torch.Tensor:
generated_tokens = outputs[0][input_length:]
large_image_start = large_image_start_offset
large_image_end = large_image_start + large_image_tokens
return generated_tokens[large_image_start:large_image_end]
def _upsample_d32_to_d16(self, token_ids: torch.Tensor, token_h: int, token_w: int) -> torch.Tensor:
"""
Upsample token IDs from d32 format to d16 format.
AR model generates tokens at d32 resolution (each token = 32x32 pixels). DiT expects tokens at d16 resolution
(each token = 16x16 pixels). This function performs 2x nearest-neighbor upsampling.
Args:
token_ids: Token IDs of shape [N] where N = token_h * token_w
token_h: Height in d32 token units
token_w: Width in d32 token units
Returns:
Upsampled token IDs of shape [1, N*4] where N*4 = (token_h*2) * (token_w*2)
"""
token_ids = token_ids.view(1, 1, token_h, token_w)
token_ids = torch.nn.functional.interpolate(token_ids.float(), scale_factor=2, mode="nearest").to(
dtype=torch.long
)
token_ids = token_ids.view(1, -1)
return token_ids
def _build_prompt_with_shape(
self,
prompt: str,
height: int,
width: int,
is_text_to_image: bool,
factor: int = 32,
) -> Tuple[str, int, int, int, int]:
"""
Build prompt with shape info (<sop>H W<eop>) based on height and width.
Args:
prompt: The raw text prompt without shape info
height: Target image height in pixels
width: Target image width in pixels
is_text_to_image: Whether this is text-to-image (True) or image-to-image (False)
Returns:
Tuple of (expanded_prompt, token_h, token_w, prev_token_h, prev_token_w)
"""
token_h = height // factor
token_w = width // factor
ratio = token_h / token_w
prev_token_h = int(sqrt(ratio) * (factor // 2))
prev_token_w = int(sqrt(1 / ratio) * (factor // 2))
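# Example: height=1152, width=768, factor=32 -> token_h=36, token_w=24 for the full-resolution grid,
# with a smaller 19x13 preview grid (prev_token_h/prev_token_w) used for text-to-image generation.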
if is_text_to_image:
expanded_prompt = f"{prompt}<sop>{token_h} {token_w}<eop><sop>{prev_token_h} {prev_token_w}<eop>"
else:
expanded_prompt = f"{prompt}<sop>{token_h} {token_w}<eop>"
return expanded_prompt, token_h, token_w, prev_token_h, prev_token_w
def generate_prior_tokens(
self,
prompt: str,
height: int,
width: int,
image: Optional[List[PIL.Image.Image]] = None,
factor: int = 32,
) -> Tuple[torch.Tensor, int, int]:
"""
Generate prior tokens using the AR (vision_language_encoder) model.
Automatically builds the prompt with shape info based on height/width. Users only need to provide the raw text
prompt without <sop>...<eop> tags.
Args:
prompt: The raw text prompt (without shape info)
height: Target image height in pixels (must be divisible by factor)
width: Target image width in pixels (must be divisible by factor)
image: Optional list of condition images for image-to-image generation
factor: Token size factor (32 for d32 tokens)
Returns:
Tuple of (prior_token_ids, pixel_height, pixel_width)
- prior_token_ids: Upsampled to d16 format, shape [1, token_h*token_w*4]
- pixel_height: Image height in pixels (aligned to factor)
- pixel_width: Image width in pixels (aligned to factor)
"""
device = self.vision_language_encoder.device
height = (height // factor) * factor
width = (width // factor) * factor
is_text_to_image = image is None or len(image) == 0
expanded_prompt, token_h, token_w, prev_h, prev_w = self._build_prompt_with_shape(
prompt, height, width, is_text_to_image
)
content = []
if image is not None:
for img in image:
content.append({"type": "image", "image": img})
content.append({"type": "text", "text": expanded_prompt})
messages = [{"role": "user", "content": content}]
inputs = self.processor.apply_chat_template(
messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
)
existing_grid = inputs.get("image_grid_thw")
inputs["image_grid_thw"] = self._build_image_grid_thw(
token_h,
token_w,
prev_h,
prev_w,
existing_grid=existing_grid if not is_text_to_image else None,
device=device,
)
max_new_tokens, large_image_offset = self._calculate_ar_generation_params(
token_h, token_w, prev_h, prev_w, is_text_to_image
)
large_image_tokens = token_h * token_w
inputs = inputs.to(device)
input_length = inputs["input_ids"].shape[-1]
outputs = self.vision_language_encoder.generate(
**inputs,
max_new_tokens=max_new_tokens,
do_sample=True,
)
prior_token_ids_d32 = self._extract_large_image_tokens(
outputs, input_length, large_image_offset, large_image_tokens
)
prior_token_ids = self._upsample_d32_to_d16(prior_token_ids_d32, token_h, token_w)
pixel_height = token_h * factor
pixel_width = token_w * factor
return prior_token_ids, pixel_height, pixel_width
def get_glyph_texts(self, prompt):
prompt = prompt[0] if isinstance(prompt, list) else prompt
ocr_texts = (
re.findall(r"'([^']*)'", prompt)
+ re.findall(r"“([^“”]*)”", prompt)
+ re.findall(r'"([^"]*)"', prompt)
+ re.findall(r"「([^「」]*)」", prompt)
)
return ocr_texts
def _get_glyph_embeds(
self,
prompt: Union[str, List[str]] = None,
max_sequence_length: int = 2048,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
glyph_texts = self.get_glyph_texts(prompt)
input_ids = self.tokenizer(
glyph_texts if len(glyph_texts) > 0 else [""],
max_length=max_sequence_length,
truncation=True,
).input_ids
input_ids = [
[self.tokenizer.pad_token_id] * ((len(input_ids) + 1) % 2) + input_ids_ for input_ids_ in input_ids
]
max_length = max(len(input_ids_) for input_ids_ in input_ids)
attention_mask = torch.tensor(
[[1] * len(input_ids_) + [0] * (max_length - len(input_ids_)) for input_ids_ in input_ids], device=device
)
input_ids = torch.tensor(
[input_ids_ + [self.tokenizer.pad_token_id] * (max_length - len(input_ids_)) for input_ids_ in input_ids],
device=device,
)
outputs = self.text_encoder(input_ids, attention_mask=attention_mask)
glyph_embeds = outputs.last_hidden_state[attention_mask.bool()].unsqueeze(0)
return glyph_embeds.to(device=device, dtype=dtype)
def encode_prompt(
self,
prompt: Union[str, List[str]],
do_classifier_free_guidance: bool = True,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
max_sequence_length: int = 2048,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
Whether to use classifier free guidance or not.
num_images_per_prompt (`int`, *optional*, defaults to 1):
Number of images that should be generated per prompt. torch device to place the resulting embeddings on
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
device: (`torch.device`, *optional*):
torch device
dtype: (`torch.dtype`, *optional*):
torch dtype
max_sequence_length (`int`, defaults to `2048`):
Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.
"""
device = device or self._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_embeds = self._get_glyph_embeds(prompt, max_sequence_length, device, dtype)
seq_len = prompt_embeds.size(1)
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
negative_prompt_embeds = None
if do_classifier_free_guidance:
negative_prompt = ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds = self._get_glyph_embeds(negative_prompt, max_sequence_length, device, dtype)
seq_len = negative_prompt_embeds.size(1)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
return prompt_embeds, negative_prompt_embeds
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
if latents is not None:
return latents.to(device)
shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
return latents
def check_inputs(
self,
prompt,
height,
width,
callback_on_step_end_tensor_inputs,
prompt_embeds=None,
):
if (
height is not None
and height % (self.vae_scale_factor * self.transformer.config.patch_size) != 0
or width is not None
and width % (self.vae_scale_factor * self.transformer.config.patch_size) != 0
):
logger.warning(
f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.transformer.config.patch_size} but are {height} and {width}. Dimensions will be resized accordingly."
)
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
@property
def guidance_scale(self):
return self._guidance_scale
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def num_timesteps(self):
return self._num_timesteps
@property
def attention_kwargs(self):
return self._attention_kwargs
@property
def current_timestep(self):
return self._current_timestep
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Optional[Union[str, List[str]]] = None,
image: Optional[
Union[
torch.Tensor, PIL.Image.Image, np.ndarray, List[torch.Tensor], List[PIL.Image.Image], List[np.ndarray]
]
] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 1.5,
num_images_per_prompt: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
output_type: str = "pil",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 2048,
) -> Union[GlmImagePipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. Provide the raw text only; shape info in the
format `<sop>H W<eop>` (where H and W are d32 token dimensions) is appended automatically from
`height` and `width` by `generate_prior_tokens`.
image: Optional condition images for image-to-image generation.
height (`int`, *optional*):
The height in pixels. If not provided, derived from prompt shape info.
width (`int`, *optional*):
The width in pixels. If not provided, derived from prompt shape info.
num_inference_steps (`int`, *optional*, defaults to `50`):
The number of denoising steps for DiT.
guidance_scale (`float`, *optional*, defaults to `1.5`):
Guidance scale for classifier-free guidance.
num_images_per_prompt (`int`, *optional*, defaults to `1`):
The number of images to generate per prompt.
generator (`torch.Generator`, *optional*):
Random generator for reproducibility.
output_type (`str`, *optional*, defaults to `"pil"`):
Output format: "pil", "np", or "latent".
Examples:
Returns:
[`GlmImagePipelineOutput`] or `tuple`: Generated images.
"""
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
# 1. Check inputs
self.check_inputs(
prompt,
height,
width,
callback_on_step_end_tensor_inputs,
prompt_embeds,
)
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
assert batch_size == 1, "batch_size must be 1"
device = self._execution_device
ar_condition_images = None
if image is not None:
if not isinstance(image, list):
image = [image]
ar_condition_images = []
for img in image:
if isinstance(img, PIL.Image.Image):
ar_condition_images.append(img)
elif isinstance(img, torch.Tensor):
img_np = img.cpu().numpy()
if img_np.ndim == 4:
img_np = img_np[0]
if img_np.shape[0] in [1, 3, 4]:
img_np = img_np.transpose(1, 2, 0)
if img_np.max() <= 1.0:
img_np = (img_np * 255).astype(np.uint8)
ar_condition_images.append(PIL.Image.fromarray(img_np))
else:
ar_condition_images.append(PIL.Image.fromarray(img))
prior_token_id, ar_height, ar_width = self.generate_prior_tokens(
prompt=prompt[0] if isinstance(prompt, list) else prompt,
image=ar_condition_images,
height=height,
width=width,
)
height = height or ar_height
width = width or ar_width
# 3. Encode input prompt
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
self.do_classifier_free_guidance,
num_images_per_prompt=num_images_per_prompt,
prompt_embeds=prompt_embeds,
max_sequence_length=max_sequence_length,
device=device,
dtype=self.dtype,
)
# 4. process images
condition_images_prior_token_id = None
if image is not None:
preprocessed_condition_images = []
condition_images_prior_token_id = []
for img in image:
image_height, image_width = img.size[::-1] if isinstance(img, PIL.Image.Image) else img.shape[:2]
multiple_of = self.vae_scale_factor * self.transformer.config.patch_size
image_height = (image_height // multiple_of) * multiple_of
image_width = (image_width // multiple_of) * multiple_of
img = self.image_processor.preprocess(img, height=image_height, width=image_width)
preprocessed_condition_images.append(img)
image = preprocessed_condition_images
# 5. Prepare latents and (optional) image kv cache
latent_channels = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size=batch_size * num_images_per_prompt,
num_channels_latents=latent_channels,
height=height,
width=width,
dtype=torch.float32,
device=device,
generator=generator,
latents=latents,
)
kv_caches = GlmImageKVCache(num_layers=self.transformer.config.num_layers)
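# For image-to-image: each condition image is pushed through the transformer once in "write" mode so
# its per-layer keys/values are cached, then reused in "read" mode during the denoising loop below.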
if image is not None and condition_images_prior_token_id is not None:
kv_caches.set_mode("write")
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, self.vae.config.latent_channels, 1, 1)
.to(self.vae.device, self.vae.dtype)
)
latents_std = (
torch.tensor(self.vae.config.latents_std)
.view(1, self.vae.config.latent_channels, 1, 1)
.to(self.vae.device, self.vae.dtype)
)
empty_glyph_hiddens = torch.zeros_like(prompt_embeds)[:1, :0, ...]
for condition_image, condition_image_prior_token_id in zip(image, condition_images_prior_token_id):
condition_image = condition_image.to(device=device, dtype=self.vae.dtype)
condition_latent = retrieve_latents(
self.vae.encode(condition_image), generator=generator, sample_mode="argmax"
)
condition_latent = (condition_latent - latents_mean) / latents_std
_ = self.transformer(
hidden_states=condition_latent,
encoder_hidden_states=empty_glyph_hiddens,
prior_token_id=condition_image_prior_token_id,
prior_token_drop=torch.full_like(condition_image_prior_token_id, False, dtype=torch.bool),
timestep=torch.zeros((1,), device=device),
target_size=torch.tensor([condition_image.shape[-2:]], device=device),
crop_coords=torch.zeros((1, 2), device=device),
attention_kwargs=attention_kwargs,
kv_caches=kv_caches,
)
# 6. Prepare additional timestep conditions
target_size = (height, width)
target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype, device=device)
crops_coords_top_left = torch.tensor([crops_coords_top_left], dtype=prompt_embeds.dtype, device=device)
target_size = target_size.repeat(batch_size * num_images_per_prompt, 1)
crops_coords_top_left = crops_coords_top_left.repeat(batch_size * num_images_per_prompt, 1)
# Prepare timesteps
image_seq_len = ((height // self.vae_scale_factor) * (width // self.vae_scale_factor)) // (
self.transformer.config.patch_size**2
)
timesteps = (
np.linspace(self.scheduler.config.num_train_timesteps, 1.0, num_inference_steps + 1)[:-1]
if timesteps is None
else np.array(timesteps)
)
timesteps = timesteps.astype(np.int64).astype(np.float32)
sigmas = timesteps / self.scheduler.config.num_train_timesteps if sigmas is None else sigmas
mu = calculate_shift(
image_seq_len,
self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("base_shift", 0.25),
self.scheduler.config.get("max_shift", 0.75),
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, timesteps, sigmas, mu=mu
)
self._num_timesteps = len(timesteps)
# 7. Denoising loop
transformer_dtype = self.transformer.dtype
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
prior_token_drop_cond = torch.full_like(prior_token_id, False, dtype=torch.bool)
prior_token_drop_uncond = torch.full_like(prior_token_id, True, dtype=torch.bool)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t
latent_model_input = latents.to(transformer_dtype)
timestep = t.expand(latents.shape[0]) - 1
if image is not None:
kv_caches.set_mode("read")
noise_pred_cond = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
prior_token_id=prior_token_id,
prior_token_drop=prior_token_drop_cond,
timestep=timestep,
target_size=target_size,
crop_coords=crops_coords_top_left,
attention_kwargs=attention_kwargs,
return_dict=False,
kv_caches=kv_caches,
)[0].float()
# perform guidance
if self.do_classifier_free_guidance:
if image is not None:
kv_caches.set_mode("skip")
noise_pred_uncond = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=negative_prompt_embeds,
prior_token_id=prior_token_id,
prior_token_drop=prior_token_drop_uncond,
timestep=timestep,
target_size=target_size,
crop_coords=crops_coords_top_left,
attention_kwargs=attention_kwargs,
return_dict=False,
kv_caches=kv_caches,
)[0].float()
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
else:
noise_pred = noise_pred_cond
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, self.scheduler.sigmas[i], callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
self._current_timestep = None
kv_caches.clear()
if not output_type == "latent":
latents = latents.to(self.vae.dtype)
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, self.vae.config.latent_channels, 1, 1)
.to(latents.device, latents.dtype)
)
latents_std = (
torch.tensor(self.vae.config.latents_std)
.view(1, self.vae.config.latent_channels, 1, 1)
.to(latents.device, latents.dtype)
)
latents = latents * latents_std + latents_mean
image = self.vae.decode(latents, return_dict=False, generator=generator)[0]
else:
image = latents
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return GlmImagePipelineOutput(images=image)

View File: src/diffusers/pipelines/glm_image/pipeline_output.py

@@ -0,0 +1,21 @@
from dataclasses import dataclass
from typing import List, Union
import numpy as np
import PIL.Image
from ...utils import BaseOutput
@dataclass
class GlmImagePipelineOutput(BaseOutput):
"""
Output class for GLM-Image pipelines.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
"""
images: Union[List[PIL.Image.Image], np.ndarray]

View File: src/diffusers/utils/dummy_pt_objects.py

@@ -967,6 +967,21 @@ class HiDreamImageTransformer2DModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class GlmImageTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class HunyuanDiT2DControlNetModel(metaclass=DummyObject):
_backends = ["torch"]