Compare commits

...

10 Commits

Author SHA1 Message Date
patil-suraj
1e3d230904 frame wise guidance 2024-04-02 14:47:51 +02:00
patil-suraj
b8f650bac3 pass interpolation_scale to temporal pos embeds 2024-04-02 12:05:52 +02:00
patil-suraj
6eca4655eb fix conversion 2024-04-02 11:34:13 +02:00
patil-suraj
04d18a2669 get rid of sigma before cfg 2024-04-02 11:14:38 +02:00
patil-suraj
80d9ddb061 make pipeline work 2024-04-02 10:11:58 +02:00
patil-suraj
6c825c9497 add pipeline 2024-04-01 23:03:44 +02:00
patil-suraj
64c542a0ac fix loading weights 2024-04-01 22:02:03 +02:00
patil-suraj
04c4d39738 rename temporal pos embeds 2024-04-01 22:01:19 +02:00
patil-suraj
10fe7eeb8a use safetensors to load weights 2024-04-01 21:58:39 +02:00
patil-suraj
9682a04624 add STDiT 2024-04-01 21:49:57 +02:00
11 changed files with 1698 additions and 0 deletions

View File

@@ -0,0 +1,172 @@
import argparse
import os
import torch
from safetensors import safe_open
from diffusers import Transformer3DModel
ckpt_id = "PixArt-alpha/PixArt-alpha"
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/scripts/inference.py#L125
interpolation_scale = {256: 0.5, 512: 1}
def main(args):
state_dict = {}
with safe_open(args.orig_ckpt_path, framework="pt", device="cpu") as f:
for k in f.keys():
state_dict[k] = f.get_tensor(k)
converted_state_dict = {}
# Patch embeddings.
converted_state_dict["pos_embed.proj.weight"] = state_dict.pop("x_embedder.proj.weight")
converted_state_dict["pos_embed.proj.bias"] = state_dict.pop("x_embedder.proj.bias")
# Caption projection.
converted_state_dict["caption_projection.linear_1.weight"] = state_dict.pop("y_embedder.y_proj.fc1.weight")
converted_state_dict["caption_projection.linear_1.bias"] = state_dict.pop("y_embedder.y_proj.fc1.bias")
converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight")
converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias")
# AdaLN-single LN
converted_state_dict["adaln_single.emb.timestep_embedder.linear_1.weight"] = state_dict.pop(
"t_embedder.mlp.0.weight"
)
converted_state_dict["adaln_single.emb.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias")
converted_state_dict["adaln_single.emb.timestep_embedder.linear_2.weight"] = state_dict.pop(
"t_embedder.mlp.2.weight"
)
converted_state_dict["adaln_single.emb.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias")
# Shared norm.
converted_state_dict["adaln_single.linear.weight"] = state_dict.pop("t_block.1.weight")
converted_state_dict["adaln_single.linear.bias"] = state_dict.pop("t_block.1.bias")
for depth in range(28):
# Transformer blocks.
converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop(
f"blocks.{depth}.scale_shift_table"
)
# Attention is all you need 🤘
# Self attention.
q, k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.attn.qkv.weight"), 3, dim=0)
q_bias, k_bias, v_bias = torch.chunk(state_dict.pop(f"blocks.{depth}.attn.qkv.bias"), 3, dim=0)
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.bias"] = q_bias
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.bias"] = k_bias
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.bias"] = v_bias
# Projection.
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict.pop(
f"blocks.{depth}.attn.proj.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = state_dict.pop(
f"blocks.{depth}.attn.proj.bias"
)
# Temporal attention.
q, k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.attn_temp.qkv.weight"), 3, dim=0)
q_bias, k_bias, v_bias = torch.chunk(state_dict.pop(f"blocks.{depth}.attn_temp.qkv.bias"), 3, dim=0)
converted_state_dict[f"transformer_blocks.{depth}.attn_temporal.to_q.weight"] = q
converted_state_dict[f"transformer_blocks.{depth}.attn_temporal.to_q.bias"] = q_bias
converted_state_dict[f"transformer_blocks.{depth}.attn_temporal.to_k.weight"] = k
converted_state_dict[f"transformer_blocks.{depth}.attn_temporal.to_k.bias"] = k_bias
converted_state_dict[f"transformer_blocks.{depth}.attn_temporal.to_v.weight"] = v
converted_state_dict[f"transformer_blocks.{depth}.attn_temporal.to_v.bias"] = v_bias
# Projection.
converted_state_dict[f"transformer_blocks.{depth}.attn_temporal.to_out.0.weight"] = state_dict.pop(
f"blocks.{depth}.attn_temp.proj.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.attn_temporal.to_out.0.bias"] = state_dict.pop(
f"blocks.{depth}.attn_temp.proj.bias"
)
# Feed-forward.
converted_state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.weight"] = state_dict.pop(
f"blocks.{depth}.mlp.fc1.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.bias"] = state_dict.pop(
f"blocks.{depth}.mlp.fc1.bias"
)
converted_state_dict[f"transformer_blocks.{depth}.ff.net.2.weight"] = state_dict.pop(
f"blocks.{depth}.mlp.fc2.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.ff.net.2.bias"] = state_dict.pop(
f"blocks.{depth}.mlp.fc2.bias"
)
# Cross-attention.
q = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.weight")
q_bias = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.bias")
k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.weight"), 2, dim=0)
k_bias, v_bias = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.bias"), 2, dim=0)
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.weight"] = q
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.bias"] = q_bias
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.weight"] = k
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.bias"] = k_bias
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.weight"] = v
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.bias"] = v_bias
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.weight"] = state_dict.pop(
f"blocks.{depth}.cross_attn.proj.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.bias"] = state_dict.pop(
f"blocks.{depth}.cross_attn.proj.bias"
)
# Final block.
converted_state_dict["proj_out.weight"] = state_dict.pop("final_layer.linear.weight")
converted_state_dict["proj_out.bias"] = state_dict.pop("final_layer.linear.bias")
converted_state_dict["scale_shift_table"] = state_dict.pop("final_layer.scale_shift_table")
converted_state_dict["pos_embed_temporal"] = state_dict.pop("pos_embed_temporal")
# DiT XL/2
transformer = Transformer3DModel(
sample_size=(16, args.image_size // 8, args.image_size // 8),
patch_size=(1, 2, 2),
num_layers=28,
attention_head_dim=72,
num_attention_heads=16,
in_channels=4,
out_channels=8,
cross_attention_dim=1152,
num_embeds_ada_norm=1000,
norm_eps=1e-6,
caption_channels=4096,
)
transformer.load_state_dict(converted_state_dict, strict=True)
assert transformer.pos_embed.pos_embed is not None
state_dict.pop("pos_embed")
state_dict.pop("y_embedder.y_embedding")
assert len(state_dict) == 0, f"State dict is not empty, {state_dict.keys()}"
num_model_params = sum(p.numel() for p in transformer.parameters())
print(f"Total number of transformer parameters: {num_model_params}")
transformer.save_pretrained(os.path.join(args.dump_path, "transformer"))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--orig_ckpt_path", default=None, type=str, required=False, help="Path to the checkpoint to convert."
)
parser.add_argument(
"--image_size",
default=256,
type=int,
choices=[256, 512],
required=False,
help="Image size of pretrained model, either 256 or 512.",
)
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
args = parser.parse_args()
main(args)

View File

@@ -90,6 +90,7 @@ else:
"T2IAdapter",
"T5FilmDecoder",
"Transformer2DModel",
"Transformer3DModel",
"UNet1DModel",
"UNet2DConditionModel",
"UNet2DModel",
@@ -256,6 +257,7 @@ else:
"LEditsPPPipelineStableDiffusion",
"LEditsPPPipelineStableDiffusionXL",
"MusicLDMPipeline",
"OpenSoraPipeline",
"PaintByExamplePipeline",
"PIAPipeline",
"PixArtAlphaPipeline",
@@ -483,6 +485,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
T2IAdapter,
T5FilmDecoder,
Transformer2DModel,
Transformer3DModel,
UNet1DModel,
UNet2DConditionModel,
UNet2DModel,
@@ -628,6 +631,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LEditsPPPipelineStableDiffusion,
LEditsPPPipelineStableDiffusionXL,
MusicLDMPipeline,
OpenSoraPipeline,
PaintByExamplePipeline,
PIAPipeline,
PixArtAlphaPipeline,

View File

@@ -38,6 +38,7 @@ if is_torch_available():
_import_structure["transformers.prior_transformer"] = ["PriorTransformer"]
_import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
_import_structure["transformers.transformer_3d"] = ["Transformer3DModel"]
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
_import_structure["unets.unet_1d"] = ["UNet1DModel"]
_import_structure["unets.unet_2d"] = ["UNet2DModel"]
@@ -75,6 +76,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
PriorTransformer,
T5FilmDecoder,
Transformer2DModel,
Transformer3DModel,
TransformerTemporalModel,
)
from .unets import (

View File

@@ -187,6 +187,75 @@ class PatchEmbed(nn.Module):
return (latent + pos_embed).to(latent.dtype)
class PatchEmbed3D(nn.Module):
"""Video to Patch Embedding"""
def __init__(
self,
height=224,
width=224,
patch_size=(1, 2, 2),
in_channels=3,
embed_dim=768,
layer_norm=False,
bias=True,
interpolation_scale=1,
):
super().__init__()
num_patches = (height // patch_size[1]) * (width // patch_size[2])
self.layer_norm = layer_norm
self.emed_dim = embed_dim
self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
if layer_norm:
self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
else:
self.norm = None
self.patch_size = patch_size
# See:
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161
self.height, self.width = height // patch_size[1], width // patch_size[2]
self.base_size = height // patch_size[1]
self.interpolation_scale = interpolation_scale
pos_embed = get_2d_sincos_pos_embed(
embed_dim, int(num_patches**0.5), base_size=self.base_size, interpolation_scale=self.interpolation_scale
)
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
def forward(self, latent):
height, width = latent.shape[-2] // self.patch_size[1], latent.shape[-1] // self.patch_size[2]
latent = self.proj(latent) # (B C T H W)
if self.layer_norm:
batch_size, _, num_frames, height, width = latent.size()
latent = latent.flatten(2).transpose(1, 2)
latent = self.norm(latent)
latent = latent.transpose(1, 2).view(batch_size, self.emed_dim, num_frames, height, width)
latent = latent.flatten(3).permute(0, 2, 3, 1) # BCTHW -> BT(HW)C
# Interpolate positional embeddings if needed.
# (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
if self.height != height or self.width != width:
pos_embed = get_2d_sincos_pos_embed(
embed_dim=self.pos_embed.shape[-1],
grid_size=(height, width),
base_size=self.base_size,
interpolation_scale=self.interpolation_scale,
)
pos_embed = torch.from_numpy(pos_embed)
pos_embed = pos_embed.float().unsqueeze(0).to(latent.device)
else:
pos_embed = self.pos_embed
latent = (latent + pos_embed).to(latent.dtype)
latent = latent.flatten(1, 2) # BT(H*W)C -> B(T*H*W)C
return latent
class TimestepEmbedding(nn.Module):
def __init__(
self,

View File

@@ -6,4 +6,5 @@ if is_torch_available():
from .prior_transformer import PriorTransformer
from .t5_film_transformer import T5FilmDecoder
from .transformer_2d import Transformer2DModel
from .transformer_3d import Transformer3DModel
from .transformer_temporal import TransformerTemporalModel

View File

@@ -0,0 +1,434 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple
import numpy as np
import torch
from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import BaseOutput, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward, _chunked_feed_forward
from ..attention_processor import Attention
from ..embeddings import PatchEmbed3D, PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormSingle
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class Transformer3DModelOutput(BaseOutput):
"""
The output of [`Transformer3DModel`].
Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
distributions for the unnoised latent pixels.
"""
sample: torch.FloatTensor
def get_1d_sincos_pos_embed(embed_dim, length, scale=1.0):
pos = np.arange(0, length)[..., None] / scale
return get_1d_sincos_pos_embed_from_grid(embed_dim, pos)
@maybe_allow_in_graph
class Transformer3DBlock(nn.Module):
r"""
A basic Transformer block.
Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
num_embeds_ada_norm (:
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
cross_attention_dim: int,
num_embeds_ada_norm: Optional[int] = None,
norm_eps: float = 1e-6,
num_temporal_patches: int = 16,
num_spatial_patches: int = 256,
):
super().__init__()
# We keep these boolean flags for backward-compatibility.
self.use_ada_layer_norm_single = True
self.num_embeds_ada_norm = num_embeds_ada_norm
self.num_temporal_patches = num_temporal_patches
self.num_spatial_patches = num_spatial_patches
# Define 3 blocks. Each block has its own normalization layer.
# 1. Spatial Self-Attn
self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=norm_eps)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
bias=True,
)
# 2. Temporal Self-Attn
self.attn_temporal = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
bias=True,
)
# 2. Cross-Attn
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
# the second cross attention block.
self.norm2 = nn.LayerNorm(dim, norm_eps, elementwise_affine=False)
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
bias=True,
)
# 3. Feed-forward
self.ff = FeedForward(dim, activation_fn="gelu-approximate")
# 4. Scale-shift for PixArt-Alpha.
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
# let chunk size default to None
self._chunk_size = None
self._chunk_dim = 0
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
# Sets chunk feed-forward
self._chunk_size = chunk_size
self._chunk_dim = dim
def forward(
self,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor,
timestep: torch.LongTensor,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
temporal_pos_embed: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
# Notice that normalization is always applied before the real computation in the following blocks.
# 0. Self-Attention
batch_size = hidden_states.shape[0]
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
).chunk(6, dim=1)
norm_hidden_states = self.norm1(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
norm_hidden_states = norm_hidden_states.squeeze(1)
# 1. Spatial Self-Attention
# reshape (batch, num_temporal_patches*num_spatial_patches, dim) -> (batch * num_temporal_patches, num_spatial_patches, dim)
norm_hidden_states = norm_hidden_states.view(
batch_size, self.num_temporal_patches, self.num_spatial_patches, -1
)
norm_hidden_states = norm_hidden_states.view(
batch_size * self.num_temporal_patches, self.num_spatial_patches, -1
)
attn_output = self.attn1(norm_hidden_states)
# reshape (batch * num_temporal_patches, num_spatial_patches, dim) -> (batch, num_temporal_patches*num_spatial_patches, dim)
attn_output = attn_output.view(batch_size, self.num_temporal_patches, self.num_spatial_patches, -1)
attn_output = attn_output.view(batch_size, self.num_temporal_patches * self.num_spatial_patches, -1)
attn_output = gate_msa * attn_output
hidden_states = attn_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
# 2. Temporal Self-Attention
# reshape (batch, num_temporal_patches*num_spatial_patches, dim) -> (batch * num_spatial_patches, num_temporal_patches, dim)
temporal_hidden_states = (
hidden_states.view(batch_size, self.num_temporal_patches, self.num_spatial_patches, -1)
.transpose(1, 2)
.contiguous()
)
temporal_hidden_states = temporal_hidden_states.view(
batch_size * self.num_spatial_patches, self.num_temporal_patches, -1
)
if temporal_pos_embed is not None:
temporal_hidden_states = temporal_hidden_states + temporal_pos_embed
attn_output = self.attn_temporal(temporal_hidden_states)
# reshape (batch * num_spatial_patches, num_temporal_patches, dim) -> (batch, num_temporal_patches*num_spatial_patches, dim)
attn_output = (
attn_output.view(batch_size, self.num_spatial_patches, self.num_temporal_patches, -1)
.transpose(1, 2)
.contiguous()
)
attn_output = attn_output.view(batch_size, self.num_temporal_patches * self.num_spatial_patches, -1)
attn_output = gate_msa * attn_output
hidden_states = attn_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
# 3. Cross-Attention
# For PixArt norm2 isn't applied here:
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
norm_hidden_states = hidden_states
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
)
hidden_states = attn_output + hidden_states
# 4. Feed-forward
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
if self._chunk_size is not None:
# "feed_forward_chunk_size" can be used to save memory
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
else:
ff_output = self.ff(norm_hidden_states)
ff_output = gate_mlp * ff_output
hidden_states = ff_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
return hidden_states
class Transformer3DModel(ModelMixin, ConfigMixin):
"""
A 3D Transformer model for image-like data.
Parameters:
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
in_channels (`int`, *optional*):
The number of channels in the input and output (specify if the input is **continuous**).
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
This is fixed during training since it is used to learn a number of position embeddings.
num_embeds_ada_norm ( `int`, *optional*):
The number of diffusion steps used during training. Pass if at least one of the norm_layers is
`AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
added to the hidden states.
During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
num_attention_heads: int = 16,
attention_head_dim: int = 88,
sample_size: Tuple[int] = (2, 4, 4),
patch_size: Tuple[int] = (1, 2, 2),
in_channels: int = 4,
out_channels: int = 8,
num_layers: int = 1,
cross_attention_dim: int = 256,
num_embeds_ada_norm: int = 1000,
norm_eps: float = 1e-6,
caption_channels: int = 256,
interpolation_scale: float = None,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim
# 2. Define input layers
self.height = sample_size[1]
self.width = sample_size[2]
self.num_patches = np.prod([sample_size[i] // patch_size[i] for i in range(3)])
self.num_temporal_patches = sample_size[0] // patch_size[0]
self.num_spatial_patches = self.num_patches // self.num_temporal_patches
self.patch_size = patch_size
interpolation_scale = (
interpolation_scale if interpolation_scale is not None else max(self.config.sample_size[1] // 64, 1)
)
self.pos_embed = PatchEmbed3D(
height=self.height,
width=self.width,
patch_size=patch_size,
in_channels=in_channels,
embed_dim=inner_dim,
interpolation_scale=interpolation_scale,
)
# 3. Define transformers blocks
self.transformer_blocks = nn.ModuleList(
[
Transformer3DBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
cross_attention_dim=cross_attention_dim,
num_embeds_ada_norm=num_embeds_ada_norm,
norm_eps=norm_eps,
num_temporal_patches=self.num_temporal_patches,
num_spatial_patches=self.num_spatial_patches,
)
for d in range(num_layers)
]
)
# 4. Define output layers
self.out_channels = in_channels if out_channels is None else out_channels
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
self.proj_out = nn.Linear(inner_dim, np.prod(patch_size) * self.out_channels)
# 5. PixArt-Alpha blocks.
self.adaln_single = AdaLayerNormSingle(inner_dim)
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
interpolation_scale = max(self.config.sample_size[0] // 16, 1)
temporal_pos_embed = get_1d_sincos_pos_embed(inner_dim, self.num_temporal_patches, scale=interpolation_scale)
temporal_pos_embed = torch.from_numpy(temporal_pos_embed).float().unsqueeze(0).requires_grad_(False)
self.register_buffer("pos_embed_temporal", temporal_pos_embed)
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
timestep: Optional[torch.LongTensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
return_dict: bool = True,
):
"""
The [`Transformer2DModel`] forward method.
Args:
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
Input `hidden_states`.
encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
self-attention.
timestep ( `torch.LongTensor`, *optional*):
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
encoder_attention_mask ( `torch.Tensor`, *optional*):
Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
* Mask `(batch, sequence_length)` True = keep, False = discard.
* Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
above. This bias will be added to the cross-attention scores.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
tuple.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
# 1. Input
height, width = hidden_states.shape[-2] // self.patch_size[1], hidden_states.shape[-1] // self.patch_size[2]
# import ipdb; ipdb.set_trace()
hidden_states = self.pos_embed(hidden_states)
batch_size = hidden_states.shape[0]
added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
timestep, embedded_timestep = self.adaln_single(
timestep, added_cond_kwargs=added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
)
# 2. Blocks
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
for i, block in enumerate(self.transformer_blocks):
hidden_states = block(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
timestep=timestep,
temporal_pos_embed=self.pos_embed_temporal if i == 0 else None,
)
# 3. Output
shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states)
# Modulation
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.squeeze(1)
# unpatchify
hidden_states = hidden_states.reshape(
shape=(
-1,
self.num_temporal_patches,
height,
width,
self.patch_size[1],
self.patch_size[2],
self.out_channels,
)
)
hidden_states = torch.einsum("nthwpqc->ncthpwq", hidden_states)
output = hidden_states.reshape(
shape=(
-1,
self.out_channels,
self.num_temporal_patches,
height * self.patch_size[1],
width * self.patch_size[2],
)
)
if not return_dict:
return (output,)
return Transformer3DModelOutput(sample=output)

View File

@@ -247,6 +247,7 @@ else:
"WuerstchenDecoderPipeline",
"WuerstchenPriorPipeline",
]
_import_structure["open_sora"] = ["OpenSoraPipeline"]
try:
if not is_onnx_available():
raise OptionalDependencyNotAvailable()
@@ -438,6 +439,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LEditsPPPipelineStableDiffusionXL,
)
from .musicldm import MusicLDMPipeline
from .open_sora import OpenSoraPipeline
from .paint_by_example import PaintByExamplePipeline
from .pia import PIAPipeline
from .pixart_alpha import PixArtAlphaPipeline

View File

@@ -0,0 +1,48 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_open_sora"] = ["OpenSoraPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_open_sora import OpenSoraPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)

View File

@@ -0,0 +1,912 @@
# Copyright 2024 Open-Sora Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import html
import inspect
import re
import urllib.parse as ul
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple, Union
import numpy as np
import PIL.Image
import torch
from transformers import T5EncoderModel, T5Tokenizer
from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL, Transformer3DModel
from ...schedulers import DPMSolverMultistepScheduler
from ...utils import (
BACKENDS_MAPPING,
BaseOutput,
deprecate,
is_bs4_available,
is_ftfy_available,
logging,
replace_example_docstring,
)
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
if is_bs4_available():
from bs4 import BeautifulSoup
if is_ftfy_available():
import ftfy
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import OpenSoraPipeline
>>> # You can replace the checkpoint id with "PixArt-alpha/PixArt-XL-2-512x512" too.
>>> pipe = OpenSoraPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)
>>> # Enable memory optimizations.
>>> pipe.enable_model_cpu_offload()
>>> prompt = "A small cactus with a happy face in the Sahara desert."
>>> image = pipe(prompt).images[0]
```
"""
@dataclass
class OpenSoraPipelineOutput(BaseOutput):
r"""
Output class for Stable Video Diffusion pipeline.
Args:
frames (`[List[List[PIL.Image.Image]]`, `np.ndarray`, `torch.FloatTensor`]):
List of denoised PIL images of length `batch_size` or numpy array or torch tensor
of shape `(batch_size, num_frames, height, width, num_channels)`.
"""
frames: Union[List[List[PIL.Image.Image]], np.ndarray, torch.FloatTensor]
def _append_dims(x, target_dims):
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
dims_to_append = target_dims - x.ndim
if dims_to_append < 0:
raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
return x[(...,) + (None,) * dims_to_append]
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
def tensor2vid(video: torch.Tensor, processor: VaeImageProcessor, output_type: str = "np"):
batch_size, channels, num_frames, height, width = video.shape
outputs = []
for batch_idx in range(batch_size):
batch_vid = video[batch_idx].permute(1, 0, 2, 3)
batch_output = processor.postprocess(batch_vid, output_type)
outputs.append(batch_output)
if output_type == "np":
outputs = np.stack(outputs)
elif output_type == "pt":
outputs = torch.stack(outputs)
elif not output_type == "pil":
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
return outputs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
**kwargs,
):
"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used,
`timesteps` must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class OpenSoraPipeline(DiffusionPipeline):
r"""
Pipeline for text-to-image generation using Open-Sora.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`T5EncoderModel`]):
Frozen text-encoder. Open-Sora uses
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
[t5-v1_1-xxl](https://huggingface.co/Open-Sora/Open-Sora/tree/main/t5-v1_1-xxl) variant.
tokenizer (`T5Tokenizer`):
Tokenizer of class
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
transformer ([`Transformer2DModel`]):
A text conditioned `Transformer2DModel` to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
"""
bad_punct_regex = re.compile(
r"["
+ "#®•©™&@·º½¾¿¡§~"
+ r"\)"
+ r"\("
+ r"\]"
+ r"\["
+ r"\}"
+ r"\{"
+ r"\|"
+ "\\"
+ r"\/"
+ r"\*"
+ r"]{1,}"
) # noqa
_optional_components = ["tokenizer", "text_encoder"]
model_cpu_offload_seq = "text_encoder->transformer->vae"
def __init__(
self,
tokenizer: T5Tokenizer,
text_encoder: T5EncoderModel,
vae: AutoencoderKL,
transformer: Transformer3DModel,
scheduler: DPMSolverMultistepScheduler,
):
super().__init__()
self.register_modules(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
# Adapted from https://github.com/Open-Sora/Open-Sora/blob/master/diffusion/model/utils.py
def mask_text_embeddings(self, emb, mask):
if emb.shape[0] == 1:
keep_index = mask.sum().item()
return emb[:, :, :keep_index, :], keep_index
else:
masked_feature = emb * mask[:, None, :, None]
return masked_feature, emb.shape[2]
# Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt
def encode_prompt(
self,
prompt: Union[str, List[str]],
do_classifier_free_guidance: bool = True,
negative_prompt: str = "",
num_images_per_prompt: int = 1,
device: Optional[torch.device] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
prompt_attention_mask: Optional[torch.FloatTensor] = None,
negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
clean_caption: bool = False,
max_sequence_length: int = 120,
**kwargs,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
negative_prompt (`str` or `List[str]`, *optional*):
The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
Open-Sora, this should be "".
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
whether to use classifier free guidance or not
num_images_per_prompt (`int`, *optional*, defaults to 1):
number of images that should be generated per prompt
device: (`torch.device`, *optional*):
torch device to place the resulting embeddings on
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. For Open-Sora, it's should be the embeddings of the ""
string.
clean_caption (`bool`, defaults to `False`):
If `True`, the function will preprocess and clean the provided caption before encoding.
max_sequence_length (`int`, defaults to 120): Maximum sequence length to use for the prompt.
"""
if "mask_feature" in kwargs:
deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version."
deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)
if device is None:
device = self._execution_device
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
# See Section 3.1. of the paper.
max_length = max_sequence_length
if prompt_embeds is None:
prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {max_length} tokens: {removed_text}"
)
prompt_attention_mask = text_inputs.attention_mask
prompt_attention_mask = prompt_attention_mask.to(device)
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
prompt_embeds = prompt_embeds[0]
if self.text_encoder is not None:
dtype = self.text_encoder.dtype
elif self.transformer is not None:
dtype = self.transformer.dtype
else:
dtype = None
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens = [negative_prompt] * batch_size
uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_attention_mask=True,
add_special_tokens=True,
return_tensors="pt",
)
negative_prompt_attention_mask = uncond_input.attention_mask
negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)
negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask
)
negative_prompt_embeds = negative_prompt_embeds[0]
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
else:
negative_prompt_embeds = None
negative_prompt_attention_mask = None
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def check_inputs(
self,
prompt,
height,
width,
negative_prompt,
callback_steps,
prompt_embeds=None,
negative_prompt_embeds=None,
prompt_attention_mask=None,
negative_prompt_attention_mask=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and prompt_attention_mask is None:
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
raise ValueError(
"`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
f" {negative_prompt_attention_mask.shape}."
)
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
def _text_preprocessing(self, text, clean_caption=False):
if clean_caption and not is_bs4_available():
logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
logger.warning("Setting `clean_caption` to False...")
clean_caption = False
if clean_caption and not is_ftfy_available():
logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
logger.warning("Setting `clean_caption` to False...")
clean_caption = False
if not isinstance(text, (tuple, list)):
text = [text]
def process(text: str):
if clean_caption:
text = self._clean_caption(text)
text = self._clean_caption(text)
else:
text = text.lower().strip()
return text
return [process(t) for t in text]
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
def _clean_caption(self, caption):
caption = str(caption)
caption = ul.unquote_plus(caption)
caption = caption.strip().lower()
caption = re.sub("<person>", "person", caption)
# urls:
caption = re.sub(
r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
"",
caption,
) # regex for urls
caption = re.sub(
r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
"",
caption,
) # regex for urls
# html:
caption = BeautifulSoup(caption, features="html.parser").text
# @<nickname>
caption = re.sub(r"@[\w\d]+\b", "", caption)
# 31C0—31EF CJK Strokes
# 31F0—31FF Katakana Phonetic Extensions
# 3200—32FF Enclosed CJK Letters and Months
# 3300—33FF CJK Compatibility
# 3400—4DBF CJK Unified Ideographs Extension A
# 4DC0—4DFF Yijing Hexagram Symbols
# 4E00—9FFF CJK Unified Ideographs
caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
#######################################################
# все виды тире / all types of dash --> "-"
caption = re.sub(
r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
"-",
caption,
)
# кавычки к одному стандарту
caption = re.sub(r"[`´«»“”¨]", '"', caption)
caption = re.sub(r"[]", "'", caption)
# &quot;
caption = re.sub(r"&quot;?", "", caption)
# &amp
caption = re.sub(r"&amp", "", caption)
# ip adresses:
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
# article ids:
caption = re.sub(r"\d:\d\d\s+$", "", caption)
# \n
caption = re.sub(r"\\n", " ", caption)
# "#123"
caption = re.sub(r"#\d{1,3}\b", "", caption)
# "#12345.."
caption = re.sub(r"#\d{5,}\b", "", caption)
# "123456.."
caption = re.sub(r"\b\d{6,}\b", "", caption)
# filenames:
caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
#
caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
# this-is-my-cute-cat / this_is_my_cute_cat
regex2 = re.compile(r"(?:\-|\_)")
if len(re.findall(regex2, caption)) > 3:
caption = re.sub(regex2, " ", caption)
caption = ftfy.fix_text(caption)
caption = html.unescape(html.unescape(caption))
caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
caption = re.sub(r"\bpage\s+\d+\b", "", caption)
caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
caption = re.sub(r"\b\s+\:\s+", r": ", caption)
caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
caption = re.sub(r"\s+", " ", caption)
caption.strip()
caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
caption = re.sub(r"^\.\S+$", "", caption)
return caption.strip()
def prepare_latents(
self,
batch_size: int,
num_frames: int,
num_channels_latents: int,
height: int,
width: int,
dtype: torch.dtype,
device: Union[str, torch.device],
generator: torch.Generator,
latents: Optional[torch.FloatTensor] = None,
):
shape = (
batch_size,
num_channels_latents,
num_frames,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
def decode_latents(self, latents: torch.FloatTensor, num_frames: int, decode_chunk_size: int = 14):
# [batch, channels, frames, height, width] -> [batch, frames, channels, height, width]
latents = latents.permute(0, 2, 1, 3, 4)
# [batch, frames, channels, height, width] -> [batch*frames, channels, height, width]
latents = latents.flatten(0, 1)
latents = 1 / self.vae.config.scaling_factor * latents
forward_vae_fn = self.vae._orig_mod.forward if is_compiled_module(self.vae) else self.vae.forward
accepts_num_frames = "num_frames" in set(inspect.signature(forward_vae_fn).parameters.keys())
# decode decode_chunk_size frames at a time to avoid OOM
frames = []
for i in range(0, latents.shape[0], decode_chunk_size):
num_frames_in = latents[i : i + decode_chunk_size].shape[0]
decode_kwargs = {}
if accepts_num_frames:
# we only pass num_frames_in if it's expected
decode_kwargs["num_frames"] = num_frames_in
frame = self.vae.decode(latents[i : i + decode_chunk_size], **decode_kwargs).sample
frames.append(frame)
frames = torch.cat(frames, dim=0)
# [batch*frames, channels, height, width] -> [batch, channels, frames, height, width]
frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
frames = frames.float()
return frames
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
negative_prompt: str = "",
num_inference_steps: int = 20,
timesteps: List[int] = None,
guidance_scale: float = 4.5,
min_guidance_scale: Optional[float] = None,
max_guidance_scale: Optional[float] = None,
num_images_per_prompt: Optional[int] = 1,
height: Optional[int] = None,
width: Optional[int] = None,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
prompt_attention_mask: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
clean_caption: bool = True,
max_sequence_length: int = 120,
decode_chunk_size: int = 4,
**kwargs,
) -> Union[ImagePipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
num_inference_steps (`int`, *optional*, defaults to 100):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
timesteps are used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 4.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
height (`int`, *optional*, defaults to self.unet.config.sample_size):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size):
The width in pixels of the generated image.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
prompt_attention_mask (`torch.FloatTensor`, *optional*): Pre-generated attention mask for text embeddings.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
Pre-generated attention mask for negative text embeddings.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
clean_caption (`bool`, *optional*, defaults to `True`):
Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
be installed. If the dependencies are not installed, the embeddings will be created from the raw
prompt.
use_resolution_binning (`bool` defaults to `True`):
If set to `True`, the requested height and width are first mapped to the closest resolutions using
`ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to
the requested resolution. Useful for generating non-square images.
max_sequence_length (`int` defaults to 120): Maximum sequence length to use with the `prompt`.
Examples:
Returns:
[`~pipelines.ImagePipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
returned where the first element is a list with the generated images
"""
# 1. Check inputs. Raise error if not correct
num_frames = self.transformer.config.sample_size[0]
height = height or self.transformer.config.sample_size[1] * self.vae_scale_factor
width = width or self.transformer.config.sample_size[2] * self.vae_scale_factor
self.check_inputs(
prompt,
height,
width,
negative_prompt,
callback_steps,
prompt_embeds,
negative_prompt_embeds,
prompt_attention_mask,
negative_prompt_attention_mask,
)
# 2. Default height and width to transformer
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
(
prompt_embeds,
prompt_attention_mask,
negative_prompt_embeds,
negative_prompt_attention_mask,
) = self.encode_prompt(
prompt,
do_classifier_free_guidance,
negative_prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
clean_caption=clean_caption,
max_sequence_length=max_sequence_length,
)
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
# 5. Prepare latents.
latent_channels = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_frames,
latent_channels,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
# 8. Prepare guidance scale
# TODO: Hacky for testing, make this cleaner
if min_guidance_scale is not None and max_guidance_scale is not None:
guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0)
guidance_scale = guidance_scale.to(device, latents.dtype)
guidance_scale = guidance_scale.repeat(batch_size * num_images_per_prompt, 1)
guidance_scale = _append_dims(guidance_scale, latents.ndim).transpose(1, 2)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
current_timestep = t
if not torch.is_tensor(current_timestep):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = latent_model_input.device.type == "mps"
if isinstance(current_timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
elif len(current_timestep.shape) == 0:
current_timestep = current_timestep[None].to(latent_model_input.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
current_timestep = current_timestep.expand(latent_model_input.shape[0])
# predict noise model_output
noise_pred = self.transformer(
latent_model_input,
encoder_hidden_states=prompt_embeds,
encoder_attention_mask=prompt_attention_mask,
timestep=current_timestep,
return_dict=False,
)[0]
# learned sigma
if self.transformer.config.out_channels // 2 == latent_channels:
noise_pred = noise_pred.chunk(2, dim=1)[0]
else:
noise_pred = noise_pred
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute previous image: x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
if not output_type == "latent":
frames = self.decode_latents(latents, num_frames, decode_chunk_size)
frames = tensor2vid(frames, self.image_processor, output_type=output_type)
else:
frames = latents
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (frames,)
return OpenSoraPipelineOutput(frames=frames)

View File

@@ -227,6 +227,21 @@ class Transformer2DModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class Transformer3DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class UNet1DModel(metaclass=DummyObject):
_backends = ["torch"]

39
test.py Normal file
View File

@@ -0,0 +1,39 @@
import torch
from transformers import T5EncoderModel, T5Tokenizer
from diffusers import Transformer3DModel, AutoencoderTiny, DPMSolverMultistepScheduler, OpenSoraPipeline
channels, num_frames, height, width, text_dim = 4, 2, 4, 4, 32
model = Transformer3DModel(
in_channels=channels,
out_channels=channels*2,
cross_attention_dim=1408,
caption_channels=text_dim,
num_embeds_ada_norm=1000,
sample_size=(num_frames, height, width),
)
x = torch.randn(1, channels, num_frames, height, width)
y = torch.randn(1, 77, 32)
t = torch.ones(1)
# with torch.no_grad():
# out = model(x, y, t)
# print(out.sample.shape) # torch.Size([1, 8, 2, 4, 4])
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small")
vae = AutoencoderTiny.from_pretrained("madebyollin/taesd")
scheduler = DPMSolverMultistepScheduler.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="scheduler")
pipe = OpenSoraPipeline(
tokenizer=tokenizer,
text_encoder=text_encoder,
vae=vae,
transformer=model,
scheduler=scheduler,
)
prompt = ""
out = pipe(prompt, num_inference_steps=1, min_guidance_scale=1.0, max_guidance_scale=3.0)