Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-25 21:15:57 +08:00)

Compare commits: modular-te... ... yiyi-refac... (6 commits)

8b177ff9c6
437cc9d734
5faf4e93f7
800c3cc28d
b06f152224
3bdf428309
@@ -15,6 +15,7 @@ from diffusers import (
     AutoencoderKLWan,
     DPMSolverMultistepScheduler,
     FlowMatchEulerDiscreteScheduler,
+    SanaVideoCausalTransformer3DModel,
     SanaVideoPipeline,
     SanaVideoTransformer3DModel,
     UniPCMultistepScheduler,
@@ -24,7 +25,10 @@ from diffusers.utils.import_utils import is_accelerate_available
 
 CTX = init_empty_weights if is_accelerate_available else nullcontext
 
-ckpt_ids = ["Efficient-Large-Model/SANA-Video_2B_480p/checkpoints/SANA_Video_2B_480p.pth"]
+ckpt_ids = [
+    "Efficient-Large-Model/SANA-Video_2B_480p/checkpoints/SANA_Video_2B_480p.pth",
+    "Efficient-Large-Model/Sana-Video_2B_480p_LongLive/checkpoints/SANA_Video_2B_480p_LongLive.pth",
+]
 # https://github.com/NVlabs/Sana/blob/main/inference_video_scripts/inference_sana_video.py
 
 
@@ -98,6 +102,10 @@ def main(args):
     else:
         raise ValueError(f"Video size {args.video_size} is not supported.")
 
+    use_causal_linear_attn = False
+    if "Sana-Video_2B_480p_LongLive" in file_path:
+        use_causal_linear_attn = True
+
     for depth in range(layer_num):
         # Transformer blocks.
         converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop(
@@ -201,7 +209,10 @@ def main(args):
         "rope_max_seq_len": 1024,
     }
 
-    transformer = SanaVideoTransformer3DModel(**transformer_kwargs)
+    if use_causal_linear_attn:
+        transformer = SanaVideoCausalTransformer3DModel(**transformer_kwargs)
+    else:
+        transformer = SanaVideoTransformer3DModel(**transformer_kwargs)
 
     transformer.load_state_dict(converted_state_dict, strict=True, assign=True)
 
@@ -314,7 +325,7 @@ if __name__ == "__main__":
         choices=["flow-dpm_solver", "flow-euler", "uni-pc"],
         help="Scheduler type to use.",
     )
-    parser.add_argument("--task", default="t2v", type=str, required=True, help="Task to convert, t2v or i2v.")
+    parser.add_argument("--task", default="t2v", type=str, required=True, choices=["t2v", "i2v"], help="Task to convert, t2v or i2v.")
     parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
     parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipeline elements in one.")
    parser.add_argument("--dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="Weight dtype.")
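For reference, a checkpoint converted with `--save_full_pipeline` can then be loaded like any other diffusers pipeline; a minimal sketch, assuming a hypothetical `./sana-video-converted` output directory:

```python
# Minimal loading sketch; "./sana-video-converted" stands in for the --dump_path argument.
import torch

from diffusers import SanaVideoPipeline

pipe = SanaVideoPipeline.from_pretrained("./sana-video-converted", torch_dtype=torch.bfloat16)
pipe.to("cuda")
```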
@@ -252,6 +252,7 @@ else:
         "SanaControlNetModel",
         "SanaTransformer2DModel",
         "SanaVideoTransformer3DModel",
+        "SanaVideoCausalTransformer3DModel",
         "SD3ControlNetModel",
         "SD3MultiControlNetModel",
         "SD3Transformer2DModel",
@@ -559,7 +560,7 @@ else:
         "SanaSprintImg2ImgPipeline",
         "SanaSprintPipeline",
         "SanaVideoPipeline",
-        "SanaVideoPipeline",
+        "LongSanaVideoPipeline",
         "SemanticStableDiffusionPipeline",
         "ShapEImg2ImgPipeline",
         "ShapEPipeline",
@@ -974,6 +975,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         QwenImageTransformer2DModel,
         SanaControlNetModel,
         SanaTransformer2DModel,
+        SanaVideoCausalTransformer3DModel,
         SanaVideoTransformer3DModel,
         SD3ControlNetModel,
         SD3MultiControlNetModel,
@@ -1215,6 +1217,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         LDMTextToImagePipeline,
         LEditsPPPipelineStableDiffusion,
         LEditsPPPipelineStableDiffusionXL,
+        LongSanaVideoPipeline,
         LTXConditionPipeline,
         LTXImageToVideoPipeline,
         LTXLatentUpsamplePipeline,

@@ -108,6 +108,7 @@ if is_torch_available():
     _import_structure["transformers.transformer_prx"] = ["PRXTransformer2DModel"]
     _import_structure["transformers.transformer_qwenimage"] = ["QwenImageTransformer2DModel"]
     _import_structure["transformers.transformer_sana_video"] = ["SanaVideoTransformer3DModel"]
+    _import_structure["transformers.transformer_sana_video_causal"] = ["SanaVideoCausalTransformer3DModel"]
     _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
     _import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"]
     _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
@@ -217,6 +218,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         PRXTransformer2DModel,
         QwenImageTransformer2DModel,
         SanaTransformer2DModel,
+        SanaVideoCausalTransformer3DModel,
         SanaVideoTransformer3DModel,
         SD3Transformer2DModel,
         SkyReelsV2Transformer3DModel,

@@ -40,6 +40,7 @@ if is_torch_available():
     from .transformer_prx import PRXTransformer2DModel
     from .transformer_qwenimage import QwenImageTransformer2DModel
     from .transformer_sana_video import SanaVideoTransformer3DModel
+    from .transformer_sana_video_causal import SanaVideoCausalTransformer3DModel
     from .transformer_sd3 import SD3Transformer2DModel
     from .transformer_skyreels_v2 import SkyReelsV2Transformer3DModel
     from .transformer_temporal import TransformerTemporalModel
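The `_import_structure` entries above feed diffusers' lazy-import machinery; a minimal sketch of that pattern, assuming `_LazyModule` from `diffusers.utils` (simplified relative to the real `__init__.py`):

```python
# Simplified sketch of the lazy-import registration the entries above plug into.
from diffusers.utils import _LazyModule

_import_structure = {
    "transformers.transformer_sana_video_causal": ["SanaVideoCausalTransformer3DModel"],
}

# The real __init__.py installs this object into sys.modules so the transformer class
# is only imported the first time it is accessed.
lazy_module = _LazyModule("diffusers.models", __file__, _import_structure)
```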
@@ -94,6 +94,98 @@ class GLUMBTempConv(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CachedGLUMBConvTemp(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
expand_ratio: float = 4,
|
||||
norm_type: Optional[str] = None,
|
||||
residual_connection: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
hidden_channels = int(expand_ratio * in_channels)
|
||||
self.norm_type = norm_type
|
||||
self.residual_connection = residual_connection
|
||||
|
||||
self.nonlinearity = nn.SiLU()
|
||||
self.conv_inverted = nn.Conv2d(in_channels, hidden_channels * 2, 1, 1, 0)
|
||||
self.conv_depth = nn.Conv2d(hidden_channels * 2, hidden_channels * 2, 3, 1, 1, groups=hidden_channels * 2)
|
||||
self.conv_point = nn.Conv2d(hidden_channels, out_channels, 1, 1, 0, bias=False)
|
||||
|
||||
self.norm = None
|
||||
if norm_type == "rms_norm":
|
||||
self.norm = RMSNorm(out_channels, eps=1e-5, elementwise_affine=True, bias=True)
|
||||
|
||||
self.conv_temp = nn.Conv2d(
|
||||
out_channels, out_channels, kernel_size=(3, 1), stride=1, padding=(1, 0), bias=False
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
save_kv_cache: bool = False,
|
||||
kv_cache: Optional[list] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, list]]:
|
||||
"""
|
||||
hidden_states: shape [B, T, H, W, C]
|
||||
kv_cache: list, with kv_cache[0/1/2] for optional cached states (only kv_cache[2] is used here for temporal)
|
||||
"""
|
||||
|
||||
if self.residual_connection:
|
||||
residual = hidden_states
|
||||
|
||||
batch_size, num_frames, height, width, num_channels = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size * num_frames, height, width, num_channels).permute(0, 3, 1, 2)
|
||||
|
||||
hidden_states = self.conv_inverted(hidden_states)
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
|
||||
hidden_states = self.conv_depth(hidden_states)
|
||||
hidden_states, gate = torch.chunk(hidden_states, 2, dim=1)
|
||||
hidden_states = hidden_states * self.nonlinearity(gate)
|
||||
|
||||
hidden_states = self.conv_point(hidden_states)
|
||||
|
||||
# Temporal aggregation with kv_cache support
|
||||
hidden_states_temporal = hidden_states.view(batch_size, num_frames, num_channels, height * width).permute(
|
||||
0, 2, 1, 3
|
||||
)
|
||||
|
||||
padding_size = self.conv_temp.kernel_size[0] // 2 # usually 1
|
||||
hidden_states_temporal_in = hidden_states_temporal
|
||||
padded_size = 0
|
||||
|
||||
# If using cache, prepend cached frames from last chunk along time axis (dim 2)
|
||||
if kv_cache is not None:
|
||||
if len(kv_cache) < 3:
|
||||
kv_cache.extend([None] * (3 - len(kv_cache)))
|
||||
if kv_cache[2] is not None:
|
||||
hidden_states_temporal_in = torch.cat([kv_cache[2], hidden_states_temporal], dim=2)
|
||||
padded_size = kv_cache[2].shape[2]
|
||||
# Save last padding_size frames for next chunk
|
||||
if save_kv_cache:
|
||||
kv_cache[2] = hidden_states_temporal[:, :, -padding_size:, :].detach().clone()
|
||||
else:
|
||||
if save_kv_cache:
|
||||
kv_cache = [None, None, hidden_states_temporal[:, :, -padding_size:, :].detach().clone()]
|
||||
|
||||
t_conv_out = self.conv_temp(hidden_states_temporal_in)[:, :, padded_size:]
|
||||
hidden_states = hidden_states_temporal + t_conv_out
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).view(batch_size, num_frames, height, width, num_channels)
|
||||
|
||||
if self.norm_type == "rms_norm":
|
||||
hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)
|
||||
|
||||
if self.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
if kv_cache is not None or save_kv_cache:
|
||||
return hidden_states, kv_cache
|
||||
return hidden_states
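As an illustration of the list-based cache API above, a minimal sketch (assuming `CachedGLUMBConvTemp` as defined above is in scope): the temporal conv only needs the last `padding_size` frames of the previous chunk, so a long clip can be processed chunk by chunk.

```python
# Chunk-by-chunk usage of the temporal cache; shapes are [B, T, H, W, C].
import torch

ff = CachedGLUMBConvTemp(in_channels=32, out_channels=32, residual_connection=False)
chunks = torch.randn(2, 8, 4, 4, 32).chunk(2, dim=1)  # two 4-frame chunks

kv_cache = None
outputs = []
for chunk in chunks:
    out, kv_cache = ff(chunk, save_kv_cache=True, kv_cache=kv_cache)
    outputs.append(out)
video = torch.cat(outputs, dim=1)  # [2, 8, 4, 4, 32]
```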
|
||||
|
||||
|
||||
class SanaLinearAttnProcessor3_0:
|
||||
r"""
|
||||
Processor for implementing scaled dot-product linear attention.
|
||||
@@ -172,6 +264,121 @@ class SanaLinearAttnProcessor3_0:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SanaCausalLinearAttnProcessor1_0:
|
||||
r"""
|
||||
Processor for implementing causal linear attention with KV cache support.
|
||||
Designed for autoregressive generation scenarios.
|
||||
"""
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
rotary_emb: Optional[torch.Tensor] = None,
|
||||
save_kv_cache: bool = False,
|
||||
kv_cache: Optional[list] = None,
|
||||
) -> torch.Tensor:
|
||||
original_dtype = hidden_states.dtype
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
# Project input to query, key, value
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
# Apply normalization
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# Reshape to multi-head format: B, N, C -> B, N, H, C
|
||||
query = query.unflatten(2, (attn.heads, -1))
|
||||
key = key.unflatten(2, (attn.heads, -1))
|
||||
value = value.unflatten(2, (attn.heads, -1))
|
||||
|
||||
# Apply lightweight linear attention kernel (ReLU)
|
||||
query = F.relu(query)
|
||||
key = F.relu(key)
|
||||
|
||||
# Apply rotary position embeddings
|
||||
if rotary_emb is not None:
|
||||
|
||||
def apply_rotary_emb(
|
||||
hidden_states: torch.Tensor,
|
||||
freqs_cos: torch.Tensor,
|
||||
freqs_sin: torch.Tensor,
|
||||
):
|
||||
x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
|
||||
cos = freqs_cos[..., 0::2]
|
||||
sin = freqs_sin[..., 1::2]
|
||||
out = torch.empty_like(hidden_states)
|
||||
out[..., 0::2] = x1 * cos - x2 * sin
|
||||
out[..., 1::2] = x1 * sin + x2 * cos
|
||||
return out.type_as(hidden_states)
|
||||
|
||||
query_rotate = apply_rotary_emb(query, *rotary_emb)
|
||||
key_rotate = apply_rotary_emb(key, *rotary_emb)
|
||||
|
||||
# Permute to attention computation format: B, N, H, C -> B, H, C, N
|
||||
query = query.permute(0, 2, 3, 1)
|
||||
key = key.permute(0, 2, 3, 1)
|
||||
query_rotate = query_rotate.permute(0, 2, 3, 1)
|
||||
key_rotate = key_rotate.permute(0, 2, 3, 1)
|
||||
value = value.permute(0, 2, 3, 1)
|
||||
|
||||
# Cast to float for numerical stability
|
||||
query_rotate, key_rotate, value = query_rotate.float(), key_rotate.float(), value.float()
|
||||
|
||||
# Linear attention computation with KV cache support
|
||||
# Compute key sum for normalization: sum over sequence dimension
|
||||
k_sum = key.sum(dim=-1, keepdim=True).transpose(-2, -1) # B, H, 1, C
|
||||
|
||||
# Compute value-key product: V @ K^T
|
||||
scores = torch.matmul(value, key_rotate.transpose(-1, -2)) # B, H, C, C
|
||||
|
||||
# Handle KV cache for autoregressive generation
|
||||
if kv_cache is not None:
|
||||
cached_vk, cached_k_sum = kv_cache[0], kv_cache[1]
|
||||
|
||||
# Save current step's KV to cache if requested
|
||||
if save_kv_cache:
|
||||
kv_cache[0] = scores.detach().clone()
|
||||
kv_cache[1] = k_sum.detach().clone()
|
||||
|
||||
# Accumulate with previous cached values
|
||||
if cached_vk is not None and cached_k_sum is not None:
|
||||
scores = scores + cached_vk
|
||||
k_sum = k_sum + cached_k_sum
|
||||
|
||||
# Normalization factor: 1 / (K_sum @ Q + epsilon)
|
||||
z = 1 / (k_sum @ query + 1e-15)
|
||||
|
||||
# Final attention output: (V @ K^T) @ Q
|
||||
hidden_states = torch.matmul(scores, query_rotate)
|
||||
|
||||
# Apply normalization
|
||||
hidden_states = hidden_states * z
|
||||
|
||||
# Reshape back: B, H, C, N -> B, N, C
|
||||
hidden_states = hidden_states.flatten(1, 2).transpose(1, 2)
|
||||
hidden_states = hidden_states.to(original_dtype)
|
||||
|
||||
# Output projection
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
# Return with cache if applicable
|
||||
if kv_cache is not None:
|
||||
return hidden_states, kv_cache
|
||||
|
||||
return hidden_states
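The cached quantities are exactly the sufficient statistics of linear attention: `V @ K^T` and the key sum. A small numerical check of that claim (a standalone sketch with a ReLU kernel and no rotary embedding, not the processor itself):

```python
import torch

torch.manual_seed(0)
B, H, C, N = 1, 2, 8, 6  # batch, heads, head dim, tokens (layout B, H, C, N as above)
q = torch.relu(torch.randn(B, H, C, N, dtype=torch.float64))
k = torch.relu(torch.randn(B, H, C, N, dtype=torch.float64))
v = torch.randn(B, H, C, N, dtype=torch.float64)

def linear_attn(q, k, v):
    scores = v @ k.transpose(-1, -2)                       # B, H, C, C
    k_sum = k.sum(dim=-1, keepdim=True).transpose(-2, -1)  # B, H, 1, C
    return (scores @ q) / (k_sum @ q + 1e-15)

full = linear_attn(q, k, v)

# Chunked: cache (scores, k_sum) from the first 3 tokens, then accumulate the rest.
scores = v[..., :3] @ k[..., :3].transpose(-1, -2)
k_sum = k[..., :3].sum(dim=-1, keepdim=True).transpose(-2, -1)
scores = scores + v[..., 3:] @ k[..., 3:].transpose(-1, -2)
k_sum = k_sum + k[..., 3:].sum(dim=-1, keepdim=True).transpose(-2, -1)
chunked = (scores @ q) / (k_sum @ q + 1e-15)

assert torch.allclose(full, chunked)
```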
|
||||
|
||||
|
||||
class WanRotaryPosEmbed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -454,6 +661,146 @@ class SanaVideoTransformerBlock(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SanaVideoCausalTransformerBlock(nn.Module):
|
||||
r"""
|
||||
Transformer block with KV cache support for causal linear attention.
|
||||
Used in LongSana for autoregressive generation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int = 2240,
|
||||
num_attention_heads: int = 20,
|
||||
attention_head_dim: int = 112,
|
||||
dropout: float = 0.0,
|
||||
num_cross_attention_heads: Optional[int] = 20,
|
||||
cross_attention_head_dim: Optional[int] = 112,
|
||||
cross_attention_dim: Optional[int] = 2240,
|
||||
attention_bias: bool = True,
|
||||
norm_elementwise_affine: bool = False,
|
||||
norm_eps: float = 1e-6,
|
||||
attention_out_bias: bool = True,
|
||||
mlp_ratio: float = 3.0,
|
||||
qk_norm: Optional[str] = "rms_norm_across_heads",
|
||||
rope_max_seq_len: int = 1024,
|
||||
self_attn_processor: Optional[nn.Module] = None,
|
||||
ffn_processor: Optional[nn.Module] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
# 1. Self Attention - must use causal linear attention
|
||||
if self_attn_processor is None:
|
||||
self_attn_processor = SanaCausalLinearAttnProcessor1_0()
|
||||
self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=norm_eps)
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
kv_heads=num_attention_heads if qk_norm is not None else None,
|
||||
qk_norm=qk_norm,
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
cross_attention_dim=None,
|
||||
processor=self_attn_processor,
|
||||
)
|
||||
|
||||
# 2. Cross Attention
|
||||
if cross_attention_dim is not None:
|
||||
self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
|
||||
self.attn2 = Attention(
|
||||
query_dim=dim,
|
||||
qk_norm=qk_norm,
|
||||
kv_heads=num_cross_attention_heads if qk_norm is not None else None,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
heads=num_cross_attention_heads,
|
||||
dim_head=cross_attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=True,
|
||||
out_bias=attention_out_bias,
|
||||
processor=SanaAttnProcessor2_0(),
|
||||
)
|
||||
|
||||
# 3. Feed-forward - must use cached conv
|
||||
if ffn_processor is None:
|
||||
ffn_processor = CachedGLUMBConvTemp
|
||||
self.ff = ffn_processor(dim, dim, mlp_ratio, norm_type=None, residual_connection=False)
|
||||
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
frames: int = None,
|
||||
height: int = None,
|
||||
width: int = None,
|
||||
rotary_emb: Optional[torch.Tensor] = None,
|
||||
save_kv_cache: bool = False,
|
||||
kv_cache: Optional[list] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, list]]:
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
# 1. Modulation
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
||||
self.scale_shift_table[None, None] + timestep.reshape(batch_size, timestep.shape[1], 6, -1)
|
||||
).unbind(dim=2)
|
||||
|
||||
# 2. Self Attention with KV cache
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
|
||||
norm_hidden_states = norm_hidden_states.to(hidden_states.dtype)
|
||||
|
||||
# Causal linear attention always supports kv_cache
|
||||
attn_result = self.attn1(
|
||||
norm_hidden_states,
|
||||
rotary_emb=rotary_emb,
|
||||
save_kv_cache=save_kv_cache,
|
||||
kv_cache=kv_cache,
|
||||
)
|
||||
if isinstance(attn_result, tuple):
|
||||
attn_output, kv_cache = attn_result
|
||||
else:
|
||||
attn_output = attn_result
|
||||
|
||||
hidden_states = hidden_states + gate_msa * attn_output
|
||||
|
||||
# 3. Cross Attention (no cache)
|
||||
if self.attn2 is not None:
|
||||
attn_output = self.attn2(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
)
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
# 4. Feed-forward with KV cache
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
|
||||
|
||||
norm_hidden_states = norm_hidden_states.unflatten(1, (frames, height, width))
|
||||
|
||||
# Cached conv always supports kv_cache
|
||||
ff_result = self.ff(
|
||||
norm_hidden_states,
|
||||
save_kv_cache=save_kv_cache,
|
||||
kv_cache=kv_cache,
|
||||
)
|
||||
if isinstance(ff_result, tuple):
|
||||
ff_output, kv_cache = ff_result
|
||||
else:
|
||||
ff_output = ff_result
|
||||
|
||||
ff_output = ff_output.flatten(1, 3)
|
||||
hidden_states = hidden_states + gate_mlp * ff_output
|
||||
|
||||
if kv_cache is not None or save_kv_cache:
|
||||
return hidden_states, kv_cache
|
||||
return hidden_states
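The modulation at the top of the block's `forward` broadcasts the `(6, dim)` table against the embedded timestep; a small shape sketch of that step (illustrative sizes only, the middle dimension is often 1):

```python
import torch

dim, batch, t = 16, 2, 1
scale_shift_table = torch.randn(6, dim) / dim**0.5
timestep = torch.randn(batch, t, 6 * dim)  # output of the AdaLN-single embedder, reshaped

shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
    scale_shift_table[None, None] + timestep.reshape(batch, t, 6, -1)
).unbind(dim=2)
print(shift_msa.shape)  # torch.Size([2, 1, 16]); broadcasts over the token dimension
```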
|
||||
|
||||
|
||||
class SanaVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, AttentionMixin):
|
||||
r"""
|
||||
A 3D Transformer model introduced in [Sana-Video](https://huggingface.co/papers/2509.24695) family of models.
|
||||
@@ -703,3 +1050,271 @@ class SanaVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro
|
||||
return (output,)
|
||||
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
|
||||
|
||||
class SanaVideoCausalTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, AttentionMixin):
|
||||
r"""
|
||||
A 3D Transformer model with KV cache support for LongSana autoregressive generation.
|
||||
|
||||
This model extends Sana-Video with causal linear attention and cached convolutions
|
||||
to enable efficient long video generation through chunked processing.
|
||||
|
||||
Args:
|
||||
in_channels (`int`, defaults to `16`):
|
||||
The number of channels in the input.
|
||||
out_channels (`int`, *optional*, defaults to `16`):
|
||||
The number of channels in the output.
|
||||
num_attention_heads (`int`, defaults to `20`):
|
||||
The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, defaults to `112`):
|
||||
The number of channels in each head.
|
||||
num_layers (`int`, defaults to `20`):
|
||||
The number of layers of Transformer blocks to use.
|
||||
num_cross_attention_heads (`int`, *optional*, defaults to `20`):
|
||||
The number of heads to use for cross-attention.
|
||||
cross_attention_head_dim (`int`, *optional*, defaults to `112`):
|
||||
The number of channels in each head for cross-attention.
|
||||
cross_attention_dim (`int`, *optional*, defaults to `2240`):
|
||||
The number of channels in the cross-attention output.
|
||||
caption_channels (`int`, defaults to `2304`):
|
||||
The number of channels in the caption embeddings.
|
||||
mlp_ratio (`float`, defaults to `2.5`):
|
||||
The expansion ratio to use in the GLUMBConv layer.
|
||||
dropout (`float`, defaults to `0.0`):
|
||||
The dropout probability.
|
||||
attention_bias (`bool`, defaults to `False`):
|
||||
Whether to use bias in the attention layer.
|
||||
sample_size (`int`, defaults to `30`):
|
||||
The base size of the input latent.
|
||||
patch_size (`Tuple[int, int, int]`, defaults to `(1, 2, 2)`):
|
||||
The size of the patches to use in the patch embedding layer.
|
||||
norm_elementwise_affine (`bool`, defaults to `False`):
|
||||
Whether to use elementwise affinity in the normalization layer.
|
||||
norm_eps (`float`, defaults to `1e-6`):
|
||||
The epsilon value for the normalization layer.
|
||||
qk_norm (`str`, *optional*, defaults to `"rms_norm_across_heads"`):
|
||||
The normalization to use for the query and key.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["SanaVideoCausalTransformerBlock", "SanaModulatedNorm"]
|
||||
_skip_layerwise_casting_patterns = ["patch_embedding", "norm"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 16,
|
||||
out_channels: Optional[int] = 16,
|
||||
num_attention_heads: int = 20,
|
||||
attention_head_dim: int = 112,
|
||||
num_layers: int = 20,
|
||||
num_cross_attention_heads: Optional[int] = 20,
|
||||
cross_attention_head_dim: Optional[int] = 112,
|
||||
cross_attention_dim: Optional[int] = 2240,
|
||||
caption_channels: int = 2304,
|
||||
mlp_ratio: float = 2.5,
|
||||
dropout: float = 0.0,
|
||||
attention_bias: bool = False,
|
||||
sample_size: int = 30,
|
||||
patch_size: Tuple[int, int, int] = (1, 2, 2),
|
||||
norm_elementwise_affine: bool = False,
|
||||
norm_eps: float = 1e-6,
|
||||
interpolation_scale: Optional[int] = None,
|
||||
guidance_embeds: bool = False,
|
||||
guidance_embeds_scale: float = 0.1,
|
||||
qk_norm: Optional[str] = "rms_norm_across_heads",
|
||||
rope_max_seq_len: int = 1024,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
out_channels = out_channels or in_channels
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
# 1. Patch & position embedding
|
||||
self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
|
||||
self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
|
||||
|
||||
# 2. Additional condition embeddings
|
||||
if guidance_embeds:
|
||||
self.time_embed = SanaCombinedTimestepGuidanceEmbeddings(inner_dim)
|
||||
else:
|
||||
self.time_embed = AdaLayerNormSingle(inner_dim)
|
||||
|
||||
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
|
||||
self.caption_norm = RMSNorm(inner_dim, eps=1e-5, elementwise_affine=True)
|
||||
|
||||
# 3. Transformer blocks - use causal versions
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
SanaVideoCausalTransformerBlock(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
dropout=dropout,
|
||||
num_cross_attention_heads=num_cross_attention_heads,
|
||||
cross_attention_head_dim=cross_attention_head_dim,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attention_bias=attention_bias,
|
||||
norm_elementwise_affine=norm_elementwise_affine,
|
||||
norm_eps=norm_eps,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qk_norm=qk_norm,
|
||||
self_attn_processor=SanaCausalLinearAttnProcessor1_0(),
|
||||
ffn_processor=CachedGLUMBConvTemp,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 4. Output blocks
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
|
||||
self.norm_out = SanaModulatedNorm(inner_dim, elementwise_affine=False, eps=1e-6)
|
||||
self.proj_out = nn.Linear(inner_dim, math.prod(patch_size) * out_channels)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
guidance: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
controlnet_block_samples: Optional[Tuple[torch.Tensor]] = None,
|
||||
save_kv_cache: bool = False,
|
||||
kv_cache: Optional[list] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
|
||||
if attention_mask is not None and attention_mask.ndim == 2:
|
||||
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
|
||||
attention_mask = attention_mask.unsqueeze(1)
|
||||
|
||||
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
||||
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
|
||||
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
|
||||
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
||||
|
||||
# 1. Input
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
p_t, p_h, p_w = self.config.patch_size
|
||||
post_patch_num_frames = num_frames // p_t
|
||||
post_patch_height = height // p_h
|
||||
post_patch_width = width // p_w
|
||||
|
||||
rotary_emb = self.rope(hidden_states)
|
||||
|
||||
hidden_states = self.patch_embedding(hidden_states)
|
||||
hidden_states = hidden_states.flatten(2).transpose(1, 2)
|
||||
|
||||
if guidance is not None:
|
||||
timestep, embedded_timestep = self.time_embed(
|
||||
timestep.flatten(), guidance=guidance, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
else:
|
||||
timestep, embedded_timestep = self.time_embed(
|
||||
timestep.flatten(), batch_size=batch_size, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
|
||||
timestep = timestep.view(batch_size, -1, timestep.size(-1))
|
||||
embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1))
|
||||
|
||||
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
|
||||
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
|
||||
|
||||
encoder_hidden_states = self.caption_norm(encoder_hidden_states)
|
||||
|
||||
# 2. Transformer blocks with KV cache
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
# Note: gradient checkpointing doesn't support kv_cache (requires tuple return)
|
||||
if kv_cache is not None:
|
||||
logger.warning("KV cache is not supported with gradient checkpointing. Disabling KV cache.")
|
||||
kv_cache = None
|
||||
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
timestep,
|
||||
post_patch_num_frames,
|
||||
post_patch_height,
|
||||
post_patch_width,
|
||||
rotary_emb,
|
||||
)
|
||||
if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples):
|
||||
hidden_states = hidden_states + controlnet_block_samples[index_block - 1]
|
||||
|
||||
else:
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
# Get kv_cache for this block if available
|
||||
block_kv_cache = kv_cache[index_block] if kv_cache is not None else None
|
||||
|
||||
block_result = block(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
timestep,
|
||||
post_patch_num_frames,
|
||||
post_patch_height,
|
||||
post_patch_width,
|
||||
rotary_emb,
|
||||
save_kv_cache=save_kv_cache,
|
||||
kv_cache=block_kv_cache,
|
||||
)
|
||||
|
||||
# Handle return value (could be tensor or tuple)
|
||||
if isinstance(block_result, tuple):
|
||||
hidden_states, updated_kv_cache = block_result
|
||||
if kv_cache is not None:
|
||||
kv_cache[index_block] = updated_kv_cache
|
||||
else:
|
||||
hidden_states = block_result
|
||||
|
||||
if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples):
|
||||
hidden_states = hidden_states + controlnet_block_samples[index_block - 1]
|
||||
|
||||
# 3. Normalization
|
||||
hidden_states = self.norm_out(hidden_states, embedded_timestep, self.scale_shift_table)
|
||||
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
# 5. Unpatchify
|
||||
hidden_states = hidden_states.reshape(
|
||||
batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
|
||||
)
|
||||
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
|
||||
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
if kv_cache is not None or save_kv_cache:
|
||||
return (output, kv_cache)
|
||||
return (output,)
|
||||
|
||||
if kv_cache is not None or save_kv_cache:
|
||||
return Transformer2DModelOutput(sample=output), kv_cache
|
||||
return Transformer2DModelOutput(sample=output)
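The unpatchify step at the end of `forward` can be checked in isolation; a minimal sketch with toy sizes (assuming the same `(p_t, p_h, p_w) = (1, 2, 2)` patching):

```python
import torch

B, C_out = 1, 16
p_t, p_h, p_w = 1, 2, 2
T_p, H_p, W_p = 2, 3, 3                                            # post-patch sizes
tokens = torch.randn(B, T_p * H_p * W_p, p_t * p_h * p_w * C_out)  # proj_out output

x = tokens.reshape(B, T_p, H_p, W_p, p_t, p_h, p_w, -1)
x = x.permute(0, 7, 1, 4, 2, 5, 3, 6)
video = x.flatten(6, 7).flatten(4, 5).flatten(2, 3)
print(video.shape)  # torch.Size([1, 16, 2, 6, 6]) -> (B, C, T, H, W)
```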
|
||||
|
||||
@@ -0,0 +1,820 @@
|
||||
# Copyright 2025 The HuggingFace Team and SANA-Video Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers, BaseOutput
|
||||
from ..attention import AttentionMixin
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..attention_processor import Attention
|
||||
from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormSingle, RMSNorm
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class SanaBlockKvCache:
|
||||
vk: Optional[torch.Tensor] = None
|
||||
k_sum: Optional[torch.Tensor] = None
|
||||
temporal_cache: Optional[torch.Tensor] = None
|
||||
_enable_save: bool = False
|
||||
|
||||
def disable_save(self):
|
||||
self._enable_save = False
|
||||
|
||||
def enable_save(self):
|
||||
self._enable_save = True
|
||||
|
||||
def maybe_save(
|
||||
self,
|
||||
vk: Optional[torch.Tensor]=None,
|
||||
k_sum: Optional[torch.Tensor]=None,
|
||||
temporal_cache: Optional[torch.Tensor]=None,
|
||||
):
|
||||
if not self._enable_save:
|
||||
return
|
||||
|
||||
if vk is not None:
|
||||
self.vk = vk.detach().clone()
|
||||
if k_sum is not None:
|
||||
self.k_sum = k_sum.detach().clone()
|
||||
if temporal_cache is not None:
|
||||
self.temporal_cache = temporal_cache.detach().clone()
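A minimal usage sketch of the cache container above (assuming the dataclass is in scope): `maybe_save` is a no-op until saving is explicitly enabled.

```python
import torch

cache = SanaBlockKvCache()
cache.maybe_save(vk=torch.ones(1))  # ignored: saving is disabled by default
assert cache.vk is None

cache.enable_save()
cache.maybe_save(vk=torch.ones(1), k_sum=torch.zeros(1))
assert cache.vk is not None and cache.temporal_cache is None
```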
|
||||
|
||||
|
||||
@dataclass
|
||||
class SanaVideoCausalTransformer3DModelOutput(BaseOutput):
|
||||
"""
|
||||
The output of [`SanaVideoCausalTransformer3DModel`].
|
||||
|
||||
Args:
|
||||
sample (`torch.Tensor` of shape `(batch_size, num_channels, num_frames, height, width)`):
|
||||
The hidden states output conditioned on the `encoder_hidden_states` input.
|
||||
kv_caches (`List[SanaBlockKvCache]`, *optional*):
The per-block KV caches for the transformer blocks.
|
||||
"""
|
||||
|
||||
sample: "torch.Tensor" # noqa: F821
|
||||
kv_caches: Optional[List[SanaBlockKvCache]] = None
|
||||
|
||||
|
||||
class CachedGLUMBConvTemp(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
expand_ratio: float = 4,
|
||||
norm_type: Optional[str] = None,
|
||||
residual_connection: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
hidden_channels = int(expand_ratio * in_channels)
|
||||
self.norm_type = norm_type
|
||||
self.residual_connection = residual_connection
|
||||
|
||||
self.nonlinearity = nn.SiLU()
|
||||
self.conv_inverted = nn.Conv2d(in_channels, hidden_channels * 2, 1, 1, 0)
|
||||
self.conv_depth = nn.Conv2d(hidden_channels * 2, hidden_channels * 2, 3, 1, 1, groups=hidden_channels * 2)
|
||||
self.conv_point = nn.Conv2d(hidden_channels, out_channels, 1, 1, 0, bias=False)
|
||||
|
||||
self.norm = None
|
||||
if norm_type == "rms_norm":
|
||||
self.norm = RMSNorm(out_channels, eps=1e-5, elementwise_affine=True, bias=True)
|
||||
|
||||
self.conv_temp = nn.Conv2d(
|
||||
out_channels, out_channels, kernel_size=(3, 1), stride=1, padding=(1, 0), bias=False
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: Optional[SanaBlockKvCache] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[SanaBlockKvCache]]:
|
||||
"""
|
||||
hidden_states: shape [B, T, H, W, C]
|
||||
kv_cache: SanaBlockKvCache, with optional cached states (only temporal_cache is used here for temporal)
|
||||
"""
|
||||
|
||||
if self.residual_connection:
|
||||
residual = hidden_states
|
||||
|
||||
batch_size, num_frames, height, width, num_channels = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size * num_frames, height, width, num_channels).permute(0, 3, 1, 2)
|
||||
|
||||
hidden_states = self.conv_inverted(hidden_states)
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
|
||||
hidden_states = self.conv_depth(hidden_states)
|
||||
hidden_states, gate = torch.chunk(hidden_states, 2, dim=1)
|
||||
hidden_states = hidden_states * self.nonlinearity(gate)
|
||||
|
||||
hidden_states = self.conv_point(hidden_states)
|
||||
|
||||
# Temporal aggregation with kv_cache support
|
||||
hidden_states_temporal = hidden_states.view(batch_size, num_frames, num_channels, height * width).permute(
|
||||
0, 2, 1, 3
|
||||
)
|
||||
|
||||
padding_size = self.conv_temp.kernel_size[0] // 2 # usually 1
|
||||
hidden_states_temporal_in = hidden_states_temporal
|
||||
padded_size = 0
|
||||
|
||||
# If using cache, prepend cached frames from last chunk along time axis (dim 2)
|
||||
if kv_cache is not None:
|
||||
if kv_cache.temporal_cache is not None:
|
||||
hidden_states_temporal_in = torch.cat([kv_cache.temporal_cache, hidden_states_temporal], dim=2)
|
||||
padded_size = kv_cache.temporal_cache.shape[2]
|
||||
# Save last padding_size frames for next chunk
|
||||
kv_cache.maybe_save(
|
||||
temporal_cache=hidden_states_temporal[:, :, -padding_size:, :],
|
||||
)
|
||||
|
||||
t_conv_out = self.conv_temp(hidden_states_temporal_in)[:, :, padded_size:]
|
||||
hidden_states = hidden_states_temporal + t_conv_out
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).view(batch_size, num_frames, height, width, num_channels)
|
||||
|
||||
if self.norm_type == "rms_norm":
|
||||
hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)
|
||||
|
||||
if self.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
return hidden_states, kv_cache
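A chunked-usage sketch for this `SanaBlockKvCache`-based variant (assuming the class and dataclass above are in scope); each call returns the cache so it can be threaded into the next chunk:

```python
import torch

ff = CachedGLUMBConvTemp(in_channels=32, out_channels=32, residual_connection=False)
cache = SanaBlockKvCache()
cache.enable_save()

out1, cache = ff(torch.randn(1, 4, 4, 4, 32), kv_cache=cache)  # first chunk fills temporal_cache
out2, cache = ff(torch.randn(1, 4, 4, 4, 32), kv_cache=cache)  # second chunk sees the cached frame
```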
|
||||
|
||||
|
||||
class SanaCausalLinearAttnProcessor1_0:
|
||||
r"""
|
||||
Processor for implementing causal linear attention with KV cache support.
|
||||
Designed for autoregressive generation scenarios.
|
||||
"""
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
rotary_emb: Optional[torch.Tensor] = None,
|
||||
kv_cache: Optional[SanaBlockKvCache] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[SanaBlockKvCache]]:
|
||||
original_dtype = hidden_states.dtype
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
# Project input to query, key, value
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
# Apply normalization
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# Reshape to multi-head format: B, N, C -> B, N, H, C
|
||||
query = query.unflatten(2, (attn.heads, -1))
|
||||
key = key.unflatten(2, (attn.heads, -1))
|
||||
value = value.unflatten(2, (attn.heads, -1))
|
||||
|
||||
# Apply lightweight linear attention kernel (ReLU)
|
||||
query = F.relu(query)
|
||||
key = F.relu(key)
|
||||
|
||||
# Apply rotary position embeddings
|
||||
if rotary_emb is not None:
|
||||
|
||||
def apply_rotary_emb(
|
||||
hidden_states: torch.Tensor,
|
||||
freqs_cos: torch.Tensor,
|
||||
freqs_sin: torch.Tensor,
|
||||
):
|
||||
x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
|
||||
cos = freqs_cos[..., 0::2]
|
||||
sin = freqs_sin[..., 1::2]
|
||||
out = torch.empty_like(hidden_states)
|
||||
out[..., 0::2] = x1 * cos - x2 * sin
|
||||
out[..., 1::2] = x1 * sin + x2 * cos
|
||||
return out.type_as(hidden_states)
|
||||
|
||||
query_rotate = apply_rotary_emb(query, *rotary_emb)
|
||||
key_rotate = apply_rotary_emb(key, *rotary_emb)
|
||||
|
||||
# Permute to attention computation format: B, N, H, C -> B, H, C, N
|
||||
query = query.permute(0, 2, 3, 1)
|
||||
key = key.permute(0, 2, 3, 1)
|
||||
query_rotate = query_rotate.permute(0, 2, 3, 1)
|
||||
key_rotate = key_rotate.permute(0, 2, 3, 1)
|
||||
value = value.permute(0, 2, 3, 1)
|
||||
|
||||
# Cast to float for numerical stability
|
||||
query_rotate, key_rotate, value = query_rotate.float(), key_rotate.float(), value.float()
|
||||
|
||||
# Linear attention computation with KV cache support
|
||||
# Compute key sum for normalization: sum over sequence dimension
|
||||
k_sum = key.sum(dim=-1, keepdim=True).transpose(-2, -1) # B, H, 1, C
|
||||
|
||||
# Compute value-key product: V @ K^T
|
||||
scores = torch.matmul(value, key_rotate.transpose(-1, -2)) # B, H, C, C
|
||||
|
||||
# Handle KV cache for autoregressive generation
|
||||
if kv_cache is not None:
|
||||
cached_vk, cached_k_sum = kv_cache.vk, kv_cache.k_sum
|
||||
kv_cache.maybe_save(vk=scores, k_sum=k_sum)
|
||||
if cached_vk is not None and cached_k_sum is not None:
|
||||
scores = scores + cached_vk
|
||||
k_sum = k_sum + cached_k_sum
|
||||
|
||||
# Normalization factor: 1 / (K_sum @ Q + epsilon)
|
||||
z = 1 / (k_sum @ query + 1e-15)
|
||||
|
||||
# Final attention output: (V @ K^T) @ Q
|
||||
hidden_states = torch.matmul(scores, query_rotate)
|
||||
|
||||
# Apply normalization
|
||||
hidden_states = hidden_states * z
|
||||
|
||||
# Reshape back: B, H, C, N -> B, N, C
|
||||
hidden_states = hidden_states.flatten(1, 2).transpose(1, 2)
|
||||
hidden_states = hidden_states.to(original_dtype)
|
||||
|
||||
# Output projection
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
return hidden_states, kv_cache
|
||||
|
||||
|
||||
# Copied from diffusers.models.transformers.transformer_sana_video.WanRotaryPosEmbed
|
||||
class WanRotaryPosEmbed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
attention_head_dim: int,
|
||||
patch_size: Tuple[int, int, int],
|
||||
max_seq_len: int,
|
||||
theta: float = 10000.0,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.attention_head_dim = attention_head_dim
|
||||
self.patch_size = patch_size
|
||||
self.max_seq_len = max_seq_len
|
||||
|
||||
h_dim = w_dim = 2 * (attention_head_dim // 6)
|
||||
t_dim = attention_head_dim - h_dim - w_dim
|
||||
|
||||
self.t_dim = t_dim
|
||||
self.h_dim = h_dim
|
||||
self.w_dim = w_dim
|
||||
|
||||
freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
|
||||
|
||||
freqs_cos = []
|
||||
freqs_sin = []
|
||||
|
||||
for dim in [t_dim, h_dim, w_dim]:
|
||||
freq_cos, freq_sin = get_1d_rotary_pos_embed(
|
||||
dim,
|
||||
max_seq_len,
|
||||
theta,
|
||||
use_real=True,
|
||||
repeat_interleave_real=True,
|
||||
freqs_dtype=freqs_dtype,
|
||||
)
|
||||
freqs_cos.append(freq_cos)
|
||||
freqs_sin.append(freq_sin)
|
||||
|
||||
self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False)
|
||||
self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
p_t, p_h, p_w = self.patch_size
|
||||
ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
|
||||
|
||||
split_sizes = [self.t_dim, self.h_dim, self.w_dim]
|
||||
|
||||
freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
|
||||
freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
|
||||
|
||||
freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
|
||||
freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
|
||||
freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
|
||||
|
||||
freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
|
||||
freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
|
||||
freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
|
||||
|
||||
freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
|
||||
freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
|
||||
|
||||
return freqs_cos, freqs_sin
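The rope above splits each attention head across time, height and width; a small worked example of the split for the default `attention_head_dim=112`:

```python
attention_head_dim = 112
h_dim = w_dim = 2 * (attention_head_dim // 6)  # 36 channels each for height and width
t_dim = attention_head_dim - h_dim - w_dim     # 40 channels for time
assert t_dim + h_dim + w_dim == attention_head_dim
```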
|
||||
|
||||
|
||||
# Copied from diffusers.models.transformers.transformer_sana_video.SanaModulatedNorm
|
||||
class SanaModulatedNorm(nn.Module):
|
||||
def __init__(self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim, elementwise_affine=elementwise_affine, eps=eps)
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.Tensor, temb: torch.Tensor, scale_shift_table: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.norm(hidden_states)
|
||||
shift, scale = (scale_shift_table[None, None] + temb[:, :, None].to(scale_shift_table.device)).unbind(dim=2)
|
||||
hidden_states = hidden_states * (1 + scale) + shift
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from diffusers.models.transformers.transformer_sana_video.SanaCombinedTimestepGuidanceEmbeddings
|
||||
class SanaCombinedTimestepGuidanceEmbeddings(nn.Module):
|
||||
def __init__(self, embedding_dim):
|
||||
super().__init__()
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
|
||||
self.guidance_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
|
||||
|
||||
def forward(self, timestep: torch.Tensor, guidance: torch.Tensor = None, hidden_dtype: torch.dtype = None):
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
|
||||
|
||||
guidance_proj = self.guidance_condition_proj(guidance)
|
||||
guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=hidden_dtype))
|
||||
conditioning = timesteps_emb + guidance_emb
|
||||
|
||||
return self.linear(self.silu(conditioning)), conditioning
|
||||
|
||||
|
||||
# Copied from diffusers.models.transformers.transformer_sana_video.SanaAttnProcessor2_0
|
||||
class SanaAttnProcessor2_0:
|
||||
r"""
|
||||
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
||||
"""
|
||||
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("SanaAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
# scaled_dot_product_attention expects attention_mask shape to be
|
||||
# (batch, heads, source_length, target_length)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim)
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim)
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.type_as(query)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SanaVideoCausalTransformerBlock(nn.Module):
|
||||
r"""
|
||||
Transformer block with KV cache support for causal linear attention.
|
||||
Used in LongSana for autoregressive generation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int = 2240,
|
||||
num_attention_heads: int = 20,
|
||||
attention_head_dim: int = 112,
|
||||
dropout: float = 0.0,
|
||||
num_cross_attention_heads: Optional[int] = 20,
|
||||
cross_attention_head_dim: Optional[int] = 112,
|
||||
cross_attention_dim: Optional[int] = 2240,
|
||||
attention_bias: bool = True,
|
||||
norm_elementwise_affine: bool = False,
|
||||
norm_eps: float = 1e-6,
|
||||
attention_out_bias: bool = True,
|
||||
mlp_ratio: float = 3.0,
|
||||
qk_norm: Optional[str] = "rms_norm_across_heads",
|
||||
rope_max_seq_len: int = 1024,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
# 1. Self Attention - must use causal linear attention
|
||||
self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=norm_eps)
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
kv_heads=num_attention_heads if qk_norm is not None else None,
|
||||
qk_norm=qk_norm,
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
cross_attention_dim=None,
|
||||
processor=SanaCausalLinearAttnProcessor1_0(),
|
||||
)
|
||||
|
||||
# 2. Cross Attention
|
||||
if cross_attention_dim is not None:
|
||||
self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
|
||||
self.attn2 = Attention(
|
||||
query_dim=dim,
|
||||
qk_norm=qk_norm,
|
||||
kv_heads=num_cross_attention_heads if qk_norm is not None else None,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
heads=num_cross_attention_heads,
|
||||
dim_head=cross_attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=True,
|
||||
out_bias=attention_out_bias,
|
||||
processor=SanaAttnProcessor2_0(),
|
||||
)
|
||||
|
||||
# 3. Feed-forward - must use cached conv
|
||||
self.ff = CachedGLUMBConvTemp(dim, dim, mlp_ratio, norm_type=None, residual_connection=False)
|
||||
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
frames: int = None,
|
||||
height: int = None,
|
||||
width: int = None,
|
||||
rotary_emb: Optional[torch.Tensor] = None,
|
||||
kv_cache: Optional[SanaBlockKvCache] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[SanaBlockKvCache]]:
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
# 1. Modulation
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
||||
self.scale_shift_table[None, None] + timestep.reshape(batch_size, timestep.shape[1], 6, -1)
|
||||
).unbind(dim=2)
|
||||
|
||||
# 2. Self Attention with KV cache
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
|
||||
norm_hidden_states = norm_hidden_states.to(hidden_states.dtype)
|
||||
|
||||
# Causal linear attention always supports kv_cache
|
||||
attn_output, kv_cache = self.attn1(
|
||||
norm_hidden_states,
|
||||
rotary_emb=rotary_emb,
|
||||
kv_cache=kv_cache,
|
||||
)
|
||||
hidden_states = hidden_states + gate_msa * attn_output
|
||||
|
||||
# 3. Cross Attention (no cache)
|
||||
if self.attn2 is not None:
|
||||
attn_output = self.attn2(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
)
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
# 4. Feed-forward with KV cache
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
|
||||
|
||||
norm_hidden_states = norm_hidden_states.unflatten(1, (frames, height, width))
|
||||
|
||||
# Cached conv always supports kv_cache
|
||||
ff_output, kv_cache = self.ff(
|
||||
norm_hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
)
|
||||
|
||||
ff_output = ff_output.flatten(1, 3)
|
||||
hidden_states = hidden_states + gate_mlp * ff_output
|
||||
|
||||
return hidden_states, kv_cache
|
||||
|
||||
|
||||
class SanaVideoCausalTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, AttentionMixin):
|
||||
r"""
|
||||
A 3D Transformer model with KV cache support for LongSana autoregressive generation.
|
||||
|
||||
This model extends Sana-Video with causal linear attention and cached convolutions
|
||||
to enable efficient long video generation through chunked processing.
|
||||
|
||||
Args:
|
||||
in_channels (`int`, defaults to `16`):
|
||||
The number of channels in the input.
|
||||
out_channels (`int`, *optional*, defaults to `16`):
|
||||
The number of channels in the output.
|
||||
num_attention_heads (`int`, defaults to `20`):
|
||||
The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, defaults to `112`):
|
||||
The number of channels in each head.
|
||||
num_layers (`int`, defaults to `20`):
|
||||
The number of layers of Transformer blocks to use.
|
||||
num_cross_attention_heads (`int`, *optional*, defaults to `20`):
|
||||
The number of heads to use for cross-attention.
|
||||
cross_attention_head_dim (`int`, *optional*, defaults to `112`):
|
||||
The number of channels in each head for cross-attention.
|
||||
cross_attention_dim (`int`, *optional*, defaults to `2240`):
|
||||
The number of channels in the cross-attention output.
|
||||
caption_channels (`int`, defaults to `2304`):
|
||||
The number of channels in the caption embeddings.
|
||||
mlp_ratio (`float`, defaults to `2.5`):
|
||||
The expansion ratio to use in the GLUMBConv layer.
|
||||
dropout (`float`, defaults to `0.0`):
|
||||
The dropout probability.
|
||||
attention_bias (`bool`, defaults to `False`):
|
||||
Whether to use bias in the attention layer.
|
||||
sample_size (`int`, defaults to `30`):
|
||||
The base size of the input latent.
|
||||
patch_size (`Tuple[int, int, int]`, defaults to `(1, 2, 2)`):
|
||||
The size of the patches to use in the patch embedding layer.
|
||||
norm_elementwise_affine (`bool`, defaults to `False`):
|
||||
Whether to use elementwise affinity in the normalization layer.
|
||||
norm_eps (`float`, defaults to `1e-6`):
|
||||
The epsilon value for the normalization layer.
|
||||
qk_norm (`str`, *optional*, defaults to `"rms_norm_across_heads"`):
|
||||
The normalization to use for the query and key.
|
||||
"""

    _supports_gradient_checkpointing = True
    _no_split_modules = ["SanaVideoCausalTransformerBlock", "SanaModulatedNorm"]
    _skip_layerwise_casting_patterns = ["patch_embedding", "norm"]

    @register_to_config
    def __init__(
        self,
        in_channels: int = 16,
        out_channels: Optional[int] = 16,
        num_attention_heads: int = 20,
        attention_head_dim: int = 112,
        num_layers: int = 20,
        num_cross_attention_heads: Optional[int] = 20,
        cross_attention_head_dim: Optional[int] = 112,
        cross_attention_dim: Optional[int] = 2240,
        caption_channels: int = 2304,
        mlp_ratio: float = 2.5,
        dropout: float = 0.0,
        attention_bias: bool = False,
        sample_size: int = 30,
        patch_size: Tuple[int, int, int] = (1, 2, 2),
        norm_elementwise_affine: bool = False,
        norm_eps: float = 1e-6,
        interpolation_scale: Optional[int] = None,
        guidance_embeds: bool = False,
        guidance_embeds_scale: float = 0.1,
        qk_norm: Optional[str] = "rms_norm_across_heads",
        rope_max_seq_len: int = 1024,
    ) -> None:
        super().__init__()

        out_channels = out_channels or in_channels
        inner_dim = num_attention_heads * attention_head_dim

        # 1. Patch & position embedding
        self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
        self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)

        # 2. Additional condition embeddings
        if guidance_embeds:
            self.time_embed = SanaCombinedTimestepGuidanceEmbeddings(inner_dim)
        else:
            self.time_embed = AdaLayerNormSingle(inner_dim)

        self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
        self.caption_norm = RMSNorm(inner_dim, eps=1e-5, elementwise_affine=True)

        # 3. Transformer blocks - use causal versions
        self.transformer_blocks = nn.ModuleList(
            [
                SanaVideoCausalTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    num_cross_attention_heads=num_cross_attention_heads,
                    cross_attention_head_dim=cross_attention_head_dim,
                    cross_attention_dim=cross_attention_dim,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                    mlp_ratio=mlp_ratio,
                    qk_norm=qk_norm,
                )
                for _ in range(num_layers)
            ]
        )

        # 4. Output blocks
        self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
        self.norm_out = SanaModulatedNorm(inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(inner_dim, math.prod(patch_size) * out_channels)

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        guidance: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        kv_caches: Optional[List[SanaBlockKvCache]] = None,
        return_dict: bool = True,
    ) -> Union[Tuple[torch.Tensor, ...], SanaVideoCausalTransformer3DModelOutput]:
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
                )

        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
        if attention_mask is not None and attention_mask.ndim == 2:
            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 1. Input
        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        p_t, p_h, p_w = self.config.patch_size
        post_patch_num_frames = num_frames // p_t
        post_patch_height = height // p_h
        post_patch_width = width // p_w

        rotary_emb = self.rope(hidden_states)

        hidden_states = self.patch_embedding(hidden_states)
        hidden_states = hidden_states.flatten(2).transpose(1, 2)

        if guidance is not None:
            timestep, embedded_timestep = self.time_embed(
                timestep.flatten(), guidance=guidance, hidden_dtype=hidden_states.dtype
            )
        else:
            timestep, embedded_timestep = self.time_embed(
                timestep.flatten(), batch_size=batch_size, hidden_dtype=hidden_states.dtype
            )

        timestep = timestep.view(batch_size, -1, timestep.size(-1))
        embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1))

        encoder_hidden_states = self.caption_projection(encoder_hidden_states)
        encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])

        encoder_hidden_states = self.caption_norm(encoder_hidden_states)

        # 2. Transformer blocks with KV cache
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            # Note: gradient checkpointing doesn't support kv_cache (requires tuple return)
            if kv_caches is not None:
                logger.warning("KV cache is not supported with gradient checkpointing. Disabling KV cache.")
                kv_caches = None

            for index_block, block in enumerate(self.transformer_blocks):
                hidden_states, _ = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    timestep,
                    post_patch_num_frames,
                    post_patch_height,
                    post_patch_width,
                    rotary_emb,
                    kv_cache=None,
                )
        else:
            for index_block, block in enumerate(self.transformer_blocks):
                # Get kv_cache for this block if available
                block_kv_cache = kv_caches[index_block] if kv_caches is not None else None

                hidden_states, block_kv_cache = block(
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    timestep,
                    post_patch_num_frames,
                    post_patch_height,
                    post_patch_width,
                    rotary_emb,
                    kv_cache=block_kv_cache,
                )

                # Store the updated cache for this block so it can be reused on the next call
                if kv_caches is not None:
                    kv_caches[index_block] = block_kv_cache

        # 3. Normalization
        hidden_states = self.norm_out(hidden_states, embedded_timestep, self.scale_shift_table)

        hidden_states = self.proj_out(hidden_states)

        # 4. Unpatchify
        hidden_states = hidden_states.reshape(
            batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
        )
        hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
        output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output, kv_caches)

        return SanaVideoCausalTransformer3DModelOutput(sample=output, kv_cache=kv_caches)
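
For orientation, the sketch below shows how the per-block `kv_caches` returned by `forward` can be threaded across chunked calls. It is a minimal illustration under assumptions, not the `LongSanaVideoPipeline` logic (that file's diff is suppressed further below): `SanaBlockKvCache` is assumed to be default-constructible, a single transformer call stands in for the full scheduler loop, and the helper name and its inputs (`latent_chunks`, `prompt_embeds`, `prompt_attention_mask`, `timestep`) are placeholders.

# Minimal sketch of chunked long-video denoising with per-block KV caches (assumptions noted above).
from typing import List, Optional

import torch


def generate_long_video_latents(
    transformer,              # SanaVideoCausalTransformer3DModel
    latent_chunks,            # iterable of (B, C, F_chunk, H, W) latent tensors, one per chunk
    prompt_embeds,            # (B, seq_len, caption_channels) text embeddings
    prompt_attention_mask,    # (B, seq_len) mask for the text embeddings
    timestep,                 # (B,) denoising timestep(s); a real pipeline iterates over many
    kv_caches: Optional[List] = None,
):
    if kv_caches is None:
        # Assumption: a fresh cache per transformer block can be default-constructed.
        kv_caches = [SanaBlockKvCache() for _ in range(len(transformer.transformer_blocks))]

    chunks = []
    for chunk_latents in latent_chunks:
        out = transformer(
            hidden_states=chunk_latents,
            encoder_hidden_states=prompt_embeds,
            encoder_attention_mask=prompt_attention_mask,
            timestep=timestep,
            kv_caches=kv_caches,  # updated in place and also returned on the output
            return_dict=True,
        )
        kv_caches = out.kv_cache  # carry the causal state into the next chunk
        chunks.append(out.sample)

    # Concatenate the per-chunk outputs along the frame axis.
    return torch.cat(chunks, dim=2), kv_caches
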
@@ -313,6 +313,7 @@ else:
    ]
    _import_structure["sana_video"] = [
        "SanaVideoPipeline",
        "LongSanaVideoPipeline",
        "SanaImageToVideoPipeline",
    ]
    _import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"]
@@ -758,7 +759,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        SanaSprintImg2ImgPipeline,
        SanaSprintPipeline,
    )
    from .sana_video import SanaImageToVideoPipeline, SanaVideoPipeline
    from .sana_video import LongSanaVideoPipeline, SanaImageToVideoPipeline, SanaVideoPipeline
    from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
    from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
    from .stable_audio import StableAudioPipeline, StableAudioProjectionModel

@@ -26,7 +26,7 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PixArtImageProcessor
from ...loaders import SanaLoraLoaderMixin
from ...models import AutoencoderDC, SanaTransformer2DModel
from ...schedulers import DPMSolverMultistepScheduler
from ...schedulers import SCMScheduler
from ...utils import (
    BACKENDS_MAPPING,
    USE_PEFT_BACKEND,
@@ -156,7 +156,7 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
        text_encoder: Gemma2PreTrainedModel,
        vae: AutoencoderDC,
        transformer: SanaTransformer2DModel,
        scheduler: DPMSolverMultistepScheduler,
        scheduler: SCMScheduler,
    ):
        super().__init__()

@@ -24,6 +24,7 @@ except OptionalDependencyNotAvailable:
else:
    _import_structure["pipeline_sana_video"] = ["SanaVideoPipeline"]
    _import_structure["pipeline_sana_video_i2v"] = ["SanaImageToVideoPipeline"]
    _import_structure["pipeline_longsana"] = ["LongSanaVideoPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
@@ -33,6 +34,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        from .pipeline_longsana import LongSanaVideoPipeline
        from .pipeline_sana_video import SanaVideoPipeline
        from .pipeline_sana_video_i2v import SanaImageToVideoPipeline
else:

1241  src/diffusers/pipelines/sana_video/pipeline_longsana.py  (new file)
File diff suppressed because it is too large
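
Because the new pipeline's diff is suppressed, the following is only a hedged sketch of how it might be loaded once the conversion script has produced a checkpoint; the repository path and the call arguments are assumptions modeled on SanaVideoPipeline, and the output is assumed to expose `frames` like other diffusers video pipelines.

# Hypothetical usage sketch for LongSanaVideoPipeline; all names below that are not part of
# the import additions above are placeholders, not the confirmed API.
import torch

from diffusers import LongSanaVideoPipeline

pipe = LongSanaVideoPipeline.from_pretrained(
    "path/to/converted/SANA-Video-LongLive",  # placeholder for a converted LongLive checkpoint
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

output = pipe(prompt="a corgi surfing a wave at sunset", num_frames=241)  # argument names assumed
video = output.frames[0]
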
@@ -1383,6 +1383,21 @@ class SanaTransformer2DModel(metaclass=DummyObject):
        requires_backends(cls, ["torch"])


class SanaVideoCausalTransformer3DModel(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class SanaVideoTransformer3DModel(metaclass=DummyObject):
    _backends = ["torch"]

@@ -1727,6 +1727,21 @@ class LEditsPPPipelineStableDiffusionXL(metaclass=DummyObject):
        requires_backends(cls, ["torch", "transformers"])


class LongSanaVideoPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class LTXConditionPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

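
These dummy classes keep `import diffusers` working when the optional backends are missing; touching the class then fails with an install hint from `requires_backends`. A small, environment-dependent illustration (the exact error text may differ):

# In an environment without `transformers`, instantiation raises ImportError via requires_backends.
from diffusers.utils import dummy_torch_and_transformers_objects as dummies

try:
    dummies.LongSanaVideoPipeline()  # only fails when the listed backends are missing
except ImportError as err:
    print(err)
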