Compare commits

...

7 Commits

Author SHA1 Message Date
yiyixuxu 759cab86fc style 2024-10-05 00:21:07 +02:00
yiyixuxu 9fbd0283be update tranformer + conversion script 2024-10-05 00:11:45 +02:00
zR eeba2997a1 update 2024-10-05 00:03:34 +08:00
zR a8f8d7ba96 using with DDPM 2024-10-04 19:51:02 +08:00
zR cefca0f677 right way 2024-10-04 19:24:37 +08:00
zR 21439e248f draft of transformers convert 2024-10-03 23:20:29 +08:00
zR b3cadb8c34 draft of cogview3plus 2024-10-02 21:49:06 +08:00
14 changed files with 1474 additions and 5 deletions

View File

@@ -0,0 +1,183 @@
"""
Convert a CogView3 checkpoint to the Diffusers format, so the weights can be loaded and used
with the Diffusers library.
Example usage:
python scripts/convert_cogview3_to_diffusers.py \
--original_state_dict_repo_id "THUDM/cogview3" \
--filename "cogview3.pt" \
--transformer \
--output_path "./cogview3_diffusers" \
--dtype "bf16"
Alternatively, if you have a local checkpoint:
python scripts/convert_cogview3_to_diffusers.py \
--checkpoint_path '/raid/.cache/huggingface/models--ZP2HF--CogView3-SAT/snapshots/ca86ce9ba94f9a7f2dd109e7a59e4c8ad04121be/cogview3plus_3b/1/mp_rank_00_model_states.pt' \
--transformer \
--output_path "/raid/yiyi/cogview3_diffusers" \
--dtype "bf16"
Arguments:
--original_state_dict_repo_id: The Hugging Face repo ID containing the original checkpoint.
--filename: The filename of the checkpoint in the repo (default: "cogview3.pt").
--checkpoint_path: Path to a local checkpoint file (alternative to repo_id and filename).
--transformer: Flag to convert the transformer model.
--output_path: The path to save the converted model.
--dtype: The dtype to save the model in (default: "bf16", options: "fp16", "bf16", "fp32").
Note: You must provide either --original_state_dict_repo_id or --checkpoint_path.
"""
import argparse
from contextlib import nullcontext
import torch
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download
from diffusers import CogView3PlusTransformer2DModel
from diffusers.utils.import_utils import is_accelerate_available
CTX = init_empty_weights if is_accelerate_available() else nullcontext
parser = argparse.ArgumentParser()
parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
parser.add_argument("--filename", default="flux.safetensors", type=str)
parser.add_argument("--checkpoint_path", default=None, type=str)
parser.add_argument("--transformer", action="store_true")
parser.add_argument("--output_path", type=str)
parser.add_argument("--dtype", type=str, default="bf16")
args = parser.parse_args()
def load_original_checkpoint(args):
if args.original_state_dict_repo_id is not None:
ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename)
elif args.checkpoint_path is not None:
ckpt_path = args.checkpoint_path
else:
raise ValueError("Please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")
original_state_dict = torch.load(ckpt_path, map_location="cpu")
return original_state_dict
# this is specific to `AdaLayerNormContinuous`:
# the diffusers implementation splits the linear projection into (scale, shift), while CogView3 splits it into (shift, scale)
def swap_scale_shift(weight, dim):
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
return new_weight
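# Sanity-check sketch (hypothetical values): for a weight laid out as [shift; scale]
# along dim 0, swap_scale_shift returns [scale; shift]:
#   w = torch.arange(4.0)       # shift = [0., 1.], scale = [2., 3.]
#   swap_scale_shift(w, dim=0)  # tensor([2., 3., 0., 1.])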
def convert_cogview3_transformer_checkpoint_to_diffusers(original_state_dict):
new_state_dict = {}
# Convert pos_embed
new_state_dict["pos_embed.proj.weight"] = original_state_dict.pop("mixins.patch_embed.proj.weight")
new_state_dict["pos_embed.proj.bias"] = original_state_dict.pop("mixins.patch_embed.proj.bias")
new_state_dict["pos_embed.text_proj.weight"] = original_state_dict.pop("mixins.patch_embed.text_proj.weight")
new_state_dict["pos_embed.text_proj.bias"] = original_state_dict.pop("mixins.patch_embed.text_proj.bias")
# Convert time_text_embed
new_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop(
"time_embed.0.weight"
)
new_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop("time_embed.0.bias")
new_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop(
"time_embed.2.weight"
)
new_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop("time_embed.2.bias")
new_state_dict["time_text_embed.text_embedder.linear_1.weight"] = original_state_dict.pop("label_emb.0.0.weight")
new_state_dict["time_text_embed.text_embedder.linear_1.bias"] = original_state_dict.pop("label_emb.0.0.bias")
new_state_dict["time_text_embed.text_embedder.linear_2.weight"] = original_state_dict.pop("label_emb.0.2.weight")
new_state_dict["time_text_embed.text_embedder.linear_2.bias"] = original_state_dict.pop("label_emb.0.2.bias")
# Convert transformer blocks
for i in range(30):
block_prefix = f"transformer_blocks.{i}."
old_prefix = f"transformer.layers.{i}."
adaln_prefix = f"mixins.adaln.adaln_modules.{i}."
new_state_dict[block_prefix + "norm1.linear.weight"] = original_state_dict.pop(adaln_prefix + "1.weight")
new_state_dict[block_prefix + "norm1.linear.bias"] = original_state_dict.pop(adaln_prefix + "1.bias")
qkv_weight = original_state_dict.pop(old_prefix + "attention.query_key_value.weight")
qkv_bias = original_state_dict.pop(old_prefix + "attention.query_key_value.bias")
q, k, v = qkv_weight.chunk(3, dim=0)
q_bias, k_bias, v_bias = qkv_bias.chunk(3, dim=0)
new_state_dict[block_prefix + "attn.to_q.weight"] = q
new_state_dict[block_prefix + "attn.to_q.bias"] = q_bias
new_state_dict[block_prefix + "attn.to_k.weight"] = k
new_state_dict[block_prefix + "attn.to_k.bias"] = k_bias
new_state_dict[block_prefix + "attn.to_v.weight"] = v
new_state_dict[block_prefix + "attn.to_v.bias"] = v_bias
new_state_dict[block_prefix + "attn.to_out.0.weight"] = original_state_dict.pop(
old_prefix + "attention.dense.weight"
)
new_state_dict[block_prefix + "attn.to_out.0.bias"] = original_state_dict.pop(
old_prefix + "attention.dense.bias"
)
new_state_dict[block_prefix + "ff.net.0.proj.weight"] = original_state_dict.pop(
old_prefix + "mlp.dense_h_to_4h.weight"
)
new_state_dict[block_prefix + "ff.net.0.proj.bias"] = original_state_dict.pop(
old_prefix + "mlp.dense_h_to_4h.bias"
)
new_state_dict[block_prefix + "ff.net.2.weight"] = original_state_dict.pop(
old_prefix + "mlp.dense_4h_to_h.weight"
)
new_state_dict[block_prefix + "ff.net.2.bias"] = original_state_dict.pop(old_prefix + "mlp.dense_4h_to_h.bias")
# Convert final norm and projection
new_state_dict["norm_out.linear.weight"] = swap_scale_shift(
original_state_dict.pop("mixins.final_layer.adaln.1.weight"), dim=0
)
new_state_dict["norm_out.linear.bias"] = swap_scale_shift(
original_state_dict.pop("mixins.final_layer.adaln.1.bias"), dim=0
)
new_state_dict["proj_out.weight"] = original_state_dict.pop("mixins.final_layer.linear.weight")
new_state_dict["proj_out.bias"] = original_state_dict.pop("mixins.final_layer.linear.bias")
return new_state_dict
def main(args):
original_ckpt = load_original_checkpoint(args)
original_ckpt = original_ckpt["module"]
original_ckpt = {k.replace("model.diffusion_model.", ""): v for k, v in original_ckpt.items()}
original_dtype = next(iter(original_ckpt.values())).dtype
dtype = None
if args.dtype is None:
dtype = original_dtype
elif args.dtype == "fp16":
dtype = torch.float16
elif args.dtype == "bf16":
dtype = torch.bfloat16
elif args.dtype == "fp32":
dtype = torch.float32
else:
raise ValueError(f"Unsupported dtype: {args.dtype}")
if args.transformer:
converted_transformer_state_dict = convert_cogview3_transformer_checkpoint_to_diffusers(original_ckpt)
transformer = CogView3PlusTransformer2DModel()
transformer.load_state_dict(converted_transformer_state_dict, strict=True)
print(f"Saving CogView3 Transformer in Diffusers format in {args.output_path}/transformer")
transformer.to(dtype).save_pretrained(f"{args.output_path}/transformer")
if len(original_ckpt) > 0:
print(f"Warning: {len(original_ckpt)} keys were not converted and will be saved as is: {original_ckpt.keys()}")
if __name__ == "__main__":
main(args)
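# Once saved, the converted weights can be loaded back through the standard diffusers
# `from_pretrained` API. A minimal sketch, assuming the output path from the example above:
#   from diffusers import CogView3PlusTransformer2DModel
#   transformer = CogView3PlusTransformer2DModel.from_pretrained(
#       "./cogview3_diffusers/transformer", torch_dtype=torch.bfloat16
#   )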

View File

@@ -84,6 +84,7 @@ else:
"AutoencoderOobleck",
"AutoencoderTiny",
"CogVideoXTransformer3DModel",
"CogView3PlusTransformer2DModel",
"ConsistencyDecoderVAE",
"ControlNetModel",
"ControlNetXSAdapter",
@@ -258,6 +259,7 @@ else:
"CogVideoXImageToVideoPipeline",
"CogVideoXPipeline",
"CogVideoXVideoToVideoPipeline",
"CogView3PlusPipeline",
"CycleDiffusionPipeline",
"FluxControlNetImg2ImgPipeline",
"FluxControlNetInpaintPipeline",
@@ -558,6 +560,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderOobleck,
AutoencoderTiny,
CogVideoXTransformer3DModel,
CogView3PlusTransformer2DModel,
ConsistencyDecoderVAE,
ControlNetModel,
ControlNetXSAdapter,
@@ -710,6 +713,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
CogVideoXImageToVideoPipeline,
CogVideoXPipeline,
CogVideoXVideoToVideoPipeline,
CogView3PlusPipeline,
CycleDiffusionPipeline,
FluxControlNetImg2ImgPipeline,
FluxControlNetInpaintPipeline,

View File

@@ -54,6 +54,7 @@ if is_torch_available():
_import_structure["transformers.stable_audio_transformer"] = ["StableAudioDiTModel"]
_import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
@@ -98,6 +99,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .transformers import (
AuraFlowTransformer2DModel,
CogVideoXTransformer3DModel,
CogView3PlusTransformer2DModel,
DiTTransformer2DModel,
DualTransformer2DModel,
FluxTransformer2DModel,

View File

@@ -122,6 +122,7 @@ class Attention(nn.Module):
out_dim: int = None,
context_pre_only=None,
pre_only=False,
layernorm_elementwise_affine: bool = True,
):
super().__init__()
@@ -179,8 +180,8 @@ class Attention(nn.Module):
self.norm_q = None
self.norm_k = None
elif qk_norm == "layer_norm":
self.norm_q = nn.LayerNorm(dim_head, eps=eps)
self.norm_k = nn.LayerNorm(dim_head, eps=eps)
self.norm_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=layernorm_elementwise_affine)
self.norm_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=layernorm_elementwise_affine)
elif qk_norm == "fp32_layer_norm":
self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)

View File

@@ -714,6 +714,58 @@ class FluxPosEmbed(nn.Module):
return freqs_cos, freqs_sin
class CogView3PlusPatchEmbed(nn.Module):
def __init__(
self,
in_channels: int = 16,
hidden_size: int = 2560,
patch_size: int = 2,
text_hidden_size: int = 4096,
pos_embed_max_size: int = 128,
):
super().__init__()
self.in_channels = in_channels
self.hidden_size = hidden_size
self.patch_size = patch_size
self.text_hidden_size = text_hidden_size
self.pos_embed_max_size = pos_embed_max_size
# Linear projection for image patches
self.proj = nn.Linear(in_channels * patch_size**2, hidden_size)
# Linear projection for text embeddings
self.text_proj = nn.Linear(text_hidden_size, hidden_size)
pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, base_size=pos_embed_max_size)
pos_embed = pos_embed.reshape(pos_embed_max_size, pos_embed_max_size, hidden_size)
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float(), persistent=False)
def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor = None) -> torch.Tensor:
batch_size, channel, height, width = hidden_states.shape
if height % self.patch_size != 0 or width % self.patch_size != 0:
raise ValueError("Height and width must be divisible by patch size")
height = height // self.patch_size
width = width // self.patch_size
hidden_states = hidden_states.view(batch_size, channel, height, self.patch_size, width, self.patch_size)
hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5).contiguous()
hidden_states = hidden_states.view(batch_size, height * width, channel * self.patch_size * self.patch_size)
# Project the patches
hidden_states = self.proj(hidden_states)
encoder_hidden_states = self.text_proj(encoder_hidden_states)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
# Calculate text_length
text_length = encoder_hidden_states.shape[1]
image_pos_embed = self.pos_embed[:height, :width].reshape(height * width, -1)
text_pos_embed = torch.zeros(
(text_length, self.hidden_size), dtype=image_pos_embed.dtype, device=image_pos_embed.device
)
pos_embed = torch.cat([text_pos_embed, image_pos_embed], dim=0)[None, ...]
return (hidden_states + pos_embed).to(hidden_states.dtype)
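# Shape sketch (hypothetical values): with the defaults above (in_channels=16, patch_size=2,
# hidden_size=2560), a latent of shape (B, 16, 64, 64) is folded into (B, 32*32, 64) patches
# and projected to (B, 1024, 2560); text embeddings (B, T, 4096) are projected to (B, T, 2560).
# The output is the concatenation (B, T + 1024, 2560), with sincos positional embeddings added
# to the image tokens only (the text positions receive zeros).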
class TimestepEmbedding(nn.Module):
def __init__(
self,
@@ -1018,6 +1070,27 @@ class IPAdapterFaceIDImageProjection(nn.Module):
return self.norm(x)
class CogView3CombineTimestepLabelEmbedding(nn.Module):
def __init__(self, time_embed_dim, label_embed_dim, in_channels=2560):
super().__init__()
self.time_proj = Timesteps(num_channels=in_channels, flip_sin_to_cos=True, downscale_freq_shift=1)
self.timestep_embedder = TimestepEmbedding(in_channels=in_channels, time_embed_dim=time_embed_dim)
self.label_embedder = nn.Sequential(
nn.Linear(label_embed_dim, time_embed_dim),
nn.SiLU(),
nn.Linear(time_embed_dim, time_embed_dim),
)
def forward(self, timestep, class_labels, hidden_dtype=None):
t_proj = self.time_proj(timestep)
t_emb = self.timestep_embedder(t_proj.to(dtype=hidden_dtype))
label_emb = self.label_embedder(class_labels)
emb = t_emb + label_emb
return emb
class CombinedTimestepLabelEmbeddings(nn.Module):
def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
super().__init__()
@@ -1038,11 +1111,11 @@ class CombinedTimestepLabelEmbeddings(nn.Module):
class CombinedTimestepTextProjEmbeddings(nn.Module):
def __init__(self, embedding_dim, pooled_projection_dim):
def __init__(self, embedding_dim, pooled_projection_dim, timesteps_dim=256):
super().__init__()
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
self.time_proj = Timesteps(num_channels=timesteps_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=timesteps_dim, time_embed_dim=embedding_dim)
self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
def forward(self, timestep, pooled_projection):

View File

@@ -355,6 +355,51 @@ class LuminaLayerNormContinuous(nn.Module):
return x
class CogView3PlusAdaLayerNormZeroTextImage(nn.Module):
r"""
Adaptive layer norm zero (adaLN-Zero) applied jointly to the text and image streams.
Parameters:
embedding_dim (`int`): The size of the conditioning embedding.
dim (`int`): The dimension of the features to normalize.
"""
def __init__(self, embedding_dim: int, dim: int):
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, 12 * dim, bias=True)
self.norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
self.norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
def forward(
self,
x: torch.Tensor,
context: torch.Tensor,
emb: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
emb = self.linear(self.silu(emb))
(
shift_msa,
scale_msa,
gate_msa,
shift_mlp,
scale_mlp,
gate_mlp,
c_shift_msa,
c_scale_msa,
c_gate_msa,
c_shift_mlp,
c_scale_mlp,
c_gate_mlp,
) = emb.chunk(12, dim=1)
normed_x = self.norm_x(x)
normed_context = self.norm_c(context)
x = normed_x * (1 + scale_msa[:, None]) + shift_msa[:, None]
context = normed_context * (1 + c_scale_msa[:, None]) + c_shift_msa[:, None]
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp, context, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp
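# Layout note: the single linear maps the conditioning embedding to 12 * dim, chunked into two
# adaLN-Zero parameter sets of six each -- (shift, scale, gate) for attention and for the MLP,
# first for the image stream (x) and then for the context (text) stream. Only the attention
# shift/scale are applied here; the gates and MLP shift/scale are returned for the caller to use.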
class CogVideoXLayerNormZero(nn.Module):
def __init__(
self,

View File

@@ -14,6 +14,7 @@ if is_torch_available():
from .stable_audio_transformer import StableAudioDiTModel
from .t5_film_transformer import T5FilmDecoder
from .transformer_2d import Transformer2DModel
from .transformer_cogview3plus import CogView3PlusTransformer2DModel
from .transformer_flux import FluxTransformer2DModel
from .transformer_sd3 import SD3Transformer2DModel
from .transformer_temporal import TransformerTemporalModel

View File

@@ -0,0 +1,364 @@
# Copyright 2024 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Union
import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...models.attention import FeedForward
from ...models.attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
from ...models.modeling_utils import ModelMixin
from ...models.normalization import AdaLayerNormContinuous
from ...utils import is_torch_version, logging
from ..embeddings import CogView3PlusPatchEmbed, CombinedTimestepTextProjEmbeddings
from ..modeling_outputs import Transformer2DModelOutput
from ..normalization import CogView3PlusAdaLayerNormZeroTextImage
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class CogView3PlusTransformerBlock(nn.Module):
"""
Transformer block used in CogView3, following the AdalnAttentionMixin style of the original implementation, simplified so that qk layer norm is always applied.
"""
def __init__(
self,
dim: int = 2560,
num_attention_heads: int = 64,
attention_head_dim: int = 40,
time_embed_dim: int = 512,
):
super().__init__()
self.norm1 = CogView3PlusAdaLayerNormZeroTextImage(embedding_dim=time_embed_dim, dim=dim)
self.attn = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
out_dim=dim,
bias=True,
qk_norm="layer_norm",
layernorm_elementwise_affine=False,
eps=1e-6,
)
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
def forward(
self,
hidden_states: torch.Tensor,
emb: torch.Tensor,
text_length: int,
) -> torch.Tensor:
encoder_hidden_states, hidden_states = hidden_states[:, :text_length], hidden_states[:, text_length:]
# norm1
(
norm_hidden_states,
gate_msa,
shift_mlp,
scale_mlp,
gate_mlp,
norm_encoder_hidden_states,
c_gate_msa,
c_shift_mlp,
c_scale_mlp,
c_gate_mlp,
) = self.norm1(hidden_states, encoder_hidden_states, emb)
# Attention
attn_input = torch.cat((norm_encoder_hidden_states, norm_hidden_states), dim=1)
attn_output = self.attn(hidden_states=attn_input)
context_attn_output, attn_output = attn_output[:, :text_length], attn_output[:, text_length:]
attn_output = gate_msa.unsqueeze(1) * attn_output
hidden_states = hidden_states + attn_output
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
encoder_hidden_states = encoder_hidden_states + context_attn_output
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
norm_hidden_states = torch.cat((norm_encoder_hidden_states, norm_hidden_states), dim=1)
ff_output = self.ff(norm_hidden_states)
context_ff_output, ff_output = ff_output[:, :text_length], ff_output[:, text_length:]
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output
hidden_states = torch.cat((encoder_hidden_states, hidden_states), dim=1)
return hidden_states
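# Note: the block operates on one concatenated sequence laid out as [text tokens ; image tokens].
# `text_length` splits it for the per-stream adaLN modulation, attention is computed jointly over
# the re-concatenated sequence, and the output is returned in the same concatenated layout.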
class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
"""
The Transformer model introduced in CogView3.
Reference: https://arxiv.org/abs/2403.05121
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
sample_size: int = 128,
patch_size: int = 2,
in_channels: int = 16,
num_layers: int = 30,
attention_head_dim: int = 40,
num_attention_heads: int = 64,
out_channels: int = 16,
encoder_hidden_states_dim: int = 4096,
pooled_projection_dim: int = 1536,
pos_embed_max_size: int = 128,
time_embed_dim: int = 512,
):
super().__init__()
self.out_channels = out_channels
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
self.pos_embed = CogView3PlusPatchEmbed(
in_channels=self.config.in_channels,
hidden_size=self.inner_dim,
patch_size=self.config.patch_size,
text_hidden_size=self.config.encoder_hidden_states_dim,
pos_embed_max_size=self.config.pos_embed_max_size,
)
self.time_text_embed = CombinedTimestepTextProjEmbeddings(
embedding_dim=self.config.time_embed_dim,
pooled_projection_dim=self.config.pooled_projection_dim,
timesteps_dim=self.inner_dim,
)
self.transformer_blocks = nn.ModuleList(
[
CogView3PlusTransformerBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
time_embed_dim=self.config.time_embed_dim,
)
for _ in range(self.config.num_layers)
]
)
self.norm_out = AdaLayerNormContinuous(
embedding_dim=self.inner_dim,
conditioning_embedding_dim=self.config.time_embed_dim,
elementwise_affine=False,
eps=1e-6,
)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
self.gradient_checkpointing = False
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model,
indexed by their weight names.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedJointAttnProcessor2_0
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
self.original_attn_processors = None
for _, attn_processor in self.attn_processors.items():
if "Added" in str(attn_processor.__class__.__name__):
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
self.original_attn_processors = self.attn_processors
for module in self.modules():
if isinstance(module, Attention):
module.fuse_projections(fuse=True)
self.set_attn_processor(FusedJointAttnProcessor2_0())
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
pooled_projections: torch.FloatTensor = None,
timestep: torch.LongTensor = None,
return_dict: bool = True,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
"""
The [`CogView3PlusTransformer2DModel`] forward method.
Args:
hidden_states (`torch.FloatTensor`): Input `hidden_states` (patched image latents).
encoder_hidden_states (`torch.FloatTensor`): Text embeddings from the text encoder.
pooled_projections (`torch.FloatTensor`): Pooled text embeddings used for the timestep/text conditioning.
timestep (`torch.LongTensor`): Indicates denoising step.
return_dict (`bool`, *optional*, defaults to `True`): Whether to return a `Transformer2DModelOutput`.
Returns:
Output tensor or `Transformer2DModelOutput`.
"""
height, width = hidden_states.shape[-2:]
text_length = encoder_hidden_states.shape[1]
hidden_states = self.pos_embed(
hidden_states, encoder_hidden_states
) # takes care of adding positional embeddings too.
emb = self.time_text_embed(timestep, pooled_projections)
for index_block, block in enumerate(self.transformer_blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
emb,
text_length,
**ckpt_kwargs,
)
else:
hidden_states = block(
hidden_states=hidden_states,
emb=emb,
text_length=text_length,
)
hidden_states = hidden_states[:, text_length:]
hidden_states = self.norm_out(hidden_states, emb)
hidden_states = self.proj_out(hidden_states) # (batch_size, height*width, patch_size*patch_size*out_channels)
# unpatchify
patch_size = self.config.patch_size
height = height // patch_size
width = width // patch_size
hidden_states = hidden_states.reshape(
shape=(hidden_states.shape[0], height, width, self.out_channels, patch_size, patch_size)
)
hidden_states = torch.einsum("nhwcpq->nchpwq", hidden_states)
output = hidden_states.reshape(
shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
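# Unpatchify sketch (hypothetical values): with patch_size=2 and out_channels=16, a 64x64 input
# latent yields a proj_out of shape (B, 1024, 2*2*16); the reshape and einsum rearrange
# (B, 32, 32, 16, 2, 2) -> (B, 16, 32, 2, 32, 2), and the final reshape recovers (B, 16, 64, 64).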

View File

@@ -145,6 +145,9 @@ else:
"CogVideoXImageToVideoPipeline",
"CogVideoXVideoToVideoPipeline",
]
_import_structure["cogview3"] = [
"CogView3PlusPipeline",
]
_import_structure["controlnet"].extend(
[
"BlipDiffusionControlNetPipeline",
@@ -469,6 +472,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .aura_flow import AuraFlowPipeline
from .blip_diffusion import BlipDiffusionPipeline
from .cogvideo import CogVideoXImageToVideoPipeline, CogVideoXPipeline, CogVideoXVideoToVideoPipeline
from .cogview3 import CogView3PlusPipeline
from .controlnet import (
BlipDiffusionControlNetPipeline,
StableDiffusionControlNetImg2ImgPipeline,

View File

@@ -20,6 +20,7 @@ from huggingface_hub.utils import validate_hf_hub_args
from ..configuration_utils import ConfigMixin
from ..utils import is_sentencepiece_available
from .aura_flow import AuraFlowPipeline
from .cogview3 import CogView3PlusPipeline
from .controlnet import (
StableDiffusionControlNetImg2ImgPipeline,
StableDiffusionControlNetInpaintPipeline,
@@ -118,6 +119,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
("flux", FluxPipeline),
("flux-controlnet", FluxControlNetPipeline),
("lumina", LuminaText2ImgPipeline),
("cogview3", CogView3PlusPipeline),
]
)

View File

@@ -0,0 +1,47 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_additional_imports = {}
_import_structure = {"pipeline_output": ["CogView3PlusPipelineOutput"]}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_cogview3plus"] = ["CogView3PlusPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .pipeline_cogview3plus import CogView3PlusPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
for name, value in _additional_imports.items():
setattr(sys.modules[__name__], name, value)

View File

@@ -0,0 +1,707 @@
# Copyright 2024 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch
from transformers import T5EncoderModel, T5TokenizerFast
from ...image_processor import VaeImageProcessor
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers.transformer_cogview3plus import CogView3PlusTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
USE_PEFT_BACKEND,
is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import CogView3PipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import CogView3PlusPipeline
>>> pipe = CogView3PlusPipeline.from_pretrained("THUDM/CogView3-Plus-3B", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> prompt = "A cat holding a sign that says hello world"
>>> # Refer to the pipeline documentation for the full set of supported arguments.
>>> image = pipe(prompt).images[0]
>>> image.save("cat.png")
```
"""
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
max_seq_len: int = 4096,
base_shift: float = 0.5,
max_shift: float = 1.16,
):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
mu = image_seq_len * m + b
return mu
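# Worked example (defaults above): m = (1.16 - 0.5) / (4096 - 256) ≈ 1.719e-4 and
# b = 0.5 - m * 256 ≈ 0.456, so image_seq_len = 1024 gives mu ≈ 1024 * 1.719e-4 + 0.456 ≈ 0.632.
# `mu` is later passed to `retrieve_timesteps` to shift the flow-matching sigma schedule.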
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class CogView3PlusPipeline(DiffusionPipeline, PeftAdapterMixin, FromOriginalModelMixin):
r"""
The CogView3 pipeline for text-to-image generation.
Reference: https://arxiv.org/abs/2403.05121
Args:
transformer ([`CogView3PlusTransformer2DModel`]):
Conditional Transformer architecture to denoise the encoded image latents.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`T5EncoderModel`]):
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
tokenizer (`T5TokenizerFast`):
Tokenizer of class
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
"""
model_cpu_offload_seq = "text_encoder->transformer->vae"
_optional_components = []
_callback_tensor_inputs = ["latents", "prompt_embeds"]
def __init__(
self,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKL,
text_encoder: T5EncoderModel,
tokenizer: T5TokenizerFast,
transformer: CogView3PlusTransformer2DModel,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
scheduler=scheduler,
)
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
)
self.default_sample_size = 64
def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
num_images_per_prompt: int = 1,
max_sequence_length: int = 256,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_length=False,
return_overflowing_tokens=False,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to "
f" {max_sequence_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)[0].to(
dtype=dtype, device=device
)
_, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
return prompt_embeds
def encode_prompt(
self,
prompt: Union[str, List[str]],
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
max_sequence_length: int = 512,
lora_scale: Optional[float] = None,
):
r"""
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
lora_scale (`float`, *optional*):
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
"""
device = device or self._execution_device
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, PeftAdapterMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if self.text_encoder is not None and USE_PEFT_BACKEND:
scale_lora_layers(self.text_encoder, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt_embeds is None:
prompt_embeds = self._get_t5_prompt_embeds(
prompt=prompt,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
)
if self.text_encoder is not None:
if isinstance(self, PeftAdapterMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder, lora_scale)
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
return prompt_embeds, pooled_prompt_embeds, text_ids
def check_inputs(
self,
prompt,
height,
width,
prompt_embeds=None,
pooled_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if prompt_embeds is not None and pooled_prompt_embeds is None:
raise ValueError(
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
)
if max_sequence_length is not None and max_sequence_length > 512:
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
@staticmethod
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
latent_image_ids = latent_image_ids.reshape(
latent_image_id_height * latent_image_id_width, latent_image_id_channels
)
return latent_image_ids.to(device=device, dtype=dtype)
@staticmethod
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
return latents
@staticmethod
def _unpack_latents(latents, height, width, vae_scale_factor):
batch_size, num_patches, channels = latents.shape
height = height // vae_scale_factor
width = width // vae_scale_factor
latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)
latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
return latents
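# Round-trip sketch (hypothetical shapes): _pack_latents folds each 2x2 spatial patch into the
# channel dimension, turning (B, 16, 64, 64) into (B, 32*32, 64); _unpack_latents inverts this,
# so packing followed by unpacking recovers the original (B, 16, 64, 64) tensor.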
def enable_vae_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.vae.enable_slicing()
def disable_vae_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_slicing()
def enable_vae_tiling(self):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and
allows processing larger images.
"""
self.vae.enable_tiling()
def disable_vae_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_tiling()
def prepare_latents(
self,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
):
height = 2 * (int(height) // self.vae_scale_factor)
width = 2 * (int(width) // self.vae_scale_factor)
shape = (batch_size, num_channels_latents, height, width)
if latents is not None:
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
return latents.to(device=device, dtype=dtype), latent_image_ids
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
return latents, latent_image_ids
@property
def guidance_scale(self):
return self._guidance_scale
@property
def joint_attention_kwargs(self):
return self._joint_attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 28,
timesteps: List[int] = None,
guidance_scale: float = 3.5,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
instead.
height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
num_inference_steps (`int`, *optional*, defaults to 28):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 3.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. A higher guidance scale encourages the model to generate images that are closely linked to the text
`prompt`, usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.cogview3.pipeline_output.CogView3PipelineOutput`] instead
of a plain tuple.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that is called at the end of each denoising step during inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
Examples:
Returns:
[`~pipelines.cogview3.CogView3PipelineOutput`] or `tuple`:
[`~pipelines.cogview3.CogView3PipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is a list with the generated images.
"""
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
height,
width,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
)
self._guidance_scale = guidance_scale
self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
lora_scale = (
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
)
(
prompt_embeds,
pooled_prompt_embeds,
text_ids,
) = self.encode_prompt(
prompt=prompt,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
)
# 4. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels // 4
latents, latent_image_ids = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 5. Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
image_seq_len = latents.shape[1]
mu = calculate_shift(
image_seq_len,
self.scheduler.config.base_image_seq_len,
self.scheduler.config.max_image_seq_len,
self.scheduler.config.base_shift,
self.scheduler.config.max_shift,
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
timesteps,
sigmas,
mu=mu,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# NOTE: unlike Flux, the CogView3Plus transformer defined above has no `guidance_embeds` config, so no guidance tensor is prepared
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
# pass only the arguments accepted by `CogView3PlusTransformer2DModel.forward`
noise_pred = self.transformer(
hidden_states=latents,
encoder_hidden_states=prompt_embeds,
pooled_projections=pooled_prompt_embeds,
timestep=timestep,
return_dict=False,
)[0]
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
if output_type == "latent":
image = latents
else:
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
image = self.vae.decode(latents, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return CogView3PipelineOutput(images=image)

View File

@@ -0,0 +1,21 @@
from dataclasses import dataclass
from typing import List, Union
import numpy as np
import PIL.Image
from ...utils import BaseOutput
@dataclass
class CogView3PipelineOutput(BaseOutput):
"""
Output class for CogView3 pipelines.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or a NumPy array of shape `(batch_size, height, width,
num_channels)` representing the denoised images of the diffusion pipeline.
"""
images: Union[List[PIL.Image.Image], np.ndarray]

View File

@@ -422,6 +422,21 @@ class FluxPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class CogView3PlusPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class HunyuanDiTControlNetPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]