Compare commits


2 Commits

Author       SHA1         Message      Date
Sayak Paul   2945a4fff7   up           2025-12-11 04:47:57 +00:00
Sayak Paul   2f947c423f   start mako   2025-12-10 13:15:32 +00:00
4 changed files with 933 additions and 7 deletions

check_mako.py Normal file (165 lines)
View File

@@ -0,0 +1,165 @@
import torch
import torch.nn as nn
from diffusers import AutoModel, WanPipeline, WanTransformer3DModel
from diffusers.utils import export_to_video
import triton
from functools import partial
from argparse import ArgumentParser
CKPT_ID = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
@torch.no_grad()
def fuse_qkv_for_wan_transformer_3d_model(model: "WanTransformer3DModel") -> "WanTransformer3DModel":
"""
In-place Q/K/V fusion for WanTransformer3DModel.
For each WanTransformerBlock:
* attn1: create (w_qkv_self, b_qkv_self) by concatenating Q/K/V.
* attn2: create (w_kv_cross, b_kv_cross) by concatenating K/V.
The fused tensors are registered as nn.Parameters on the corresponding
WanAttention modules and populated via `load_state_dict`.
"""
for block in getattr(model, "blocks", []):
# ------------------------------------------------------------------
# 1. Self-attention: fuse Q, K, V -> (w_qkv_self, b_qkv_self)
# ------------------------------------------------------------------
attn1 = getattr(block, "attn1", None)
if attn1 is not None and not hasattr(attn1, "w_qkv_self"):
# Grab existing projections
w_q = attn1.to_q.weight.data
w_k = attn1.to_k.weight.data
w_v = attn1.to_v.weight.data
b_q = attn1.to_q.bias.data
b_k = attn1.to_k.bias.data
b_v = attn1.to_v.bias.data
# Fuse along the out_features dimension (dim=0)
fused_w = torch.cat([w_q, w_k, w_v], dim=0)
fused_b = torch.cat([b_q, b_k, b_v], dim=0)
out_features, in_features = fused_w.shape
device = fused_w.device
dtype = fused_w.dtype
# Register fused parameters with the requested names
attn1.register_parameter(
"w_qkv_self",
nn.Parameter(torch.empty((out_features, in_features), device=device, dtype=dtype)),
)
attn1.register_parameter(
"b_qkv_self",
nn.Parameter(torch.empty((out_features,), device=device, dtype=dtype)),
)
# Load via state-dict mechanism (so it works nicely with checkpoints)
attn1.load_state_dict(
{"w_qkv_self": fused_w, "b_qkv_self": fused_b},
strict=False,
)
# ------------------------------------------------------------------
# 2. Cross-attention: fuse K, V -> (w_kv_cross, b_kv_cross)
# ------------------------------------------------------------------
attn2 = getattr(block, "attn2", None)
if attn2 is not None and not hasattr(attn2, "w_kv_cross"):
w_k = attn2.to_k.weight.data
w_v = attn2.to_v.weight.data
b_k = attn2.to_k.bias.data
b_v = attn2.to_v.bias.data
fused_w = torch.cat([w_k, w_v], dim=0)
fused_b = torch.cat([b_k, b_v], dim=0)
out_features, in_features = fused_w.shape
device = fused_w.device
dtype = fused_w.dtype
attn2.register_parameter(
"w_kv_cross",
nn.Parameter(torch.empty((out_features, in_features), device=device, dtype=dtype)),
)
attn2.register_parameter(
"b_kv_cross",
nn.Parameter(torch.empty((out_features,), device=device, dtype=dtype)),
)
attn2.load_state_dict(
{"w_kv_cross": fused_w, "b_kv_cross": fused_b},
strict=False,
)
return model
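# Optional sanity check for the fusion above (illustrative sketch; the benchmark never calls it).
# It assumes `fuse_qkv_for_wan_transformer_3d_model` has already been applied and verifies that the
# fused projections reproduce the separate Q/K/V (and K/V) layers on a random input.
@torch.no_grad()
def verify_fused_projections(model: "WanTransformer3DModel") -> None:
    block = model.blocks[0]
    attn1, attn2 = block.attn1, block.attn2
    # Self-attention: fused QKV vs. separate projections.
    x = torch.randn(1, 4, attn1.to_q.in_features, device=attn1.to_q.weight.device, dtype=attn1.to_q.weight.dtype)
    ref = torch.cat([attn1.to_q(x), attn1.to_k(x), attn1.to_v(x)], dim=-1)
    fused = nn.functional.linear(x, attn1.w_qkv_self, attn1.b_qkv_self)
    torch.testing.assert_close(fused, ref)
    # Cross-attention: fused KV vs. separate projections.
    ctx = torch.randn(1, 4, attn2.to_k.in_features, device=attn2.to_k.weight.device, dtype=attn2.to_k.weight.dtype)
    ref_kv = torch.cat([attn2.to_k(ctx), attn2.to_v(ctx)], dim=-1)
    fused_kv = nn.functional.linear(ctx, attn2.w_kv_cross, attn2.b_kv_cross)
    torch.testing.assert_close(fused_kv, ref_kv)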
def load_pipeline():
vae = AutoModel.from_pretrained(CKPT_ID, subfolder="vae", torch_dtype=torch.float32)
pipeline = WanPipeline.from_pretrained(
CKPT_ID, vae=vae, torch_dtype=torch.bfloat16
).to("cuda")
pipeline.set_progress_bar_config(disable=True)
return pipeline
def get_prompts():
prompt = """
The camera rushes from far to near in a low-angle shot,
revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
shadows and warm highlights. Medium composition, front view, low angle, with depth of field.
"""
negative_prompt = """
Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
"""
return prompt, negative_prompt
# Batch size is fixed at 2 and `max_sequence_length` at 256 to match the shapes the kernels were written for.
def run_inference(pipeline, prompt, negative_prompt, num_inference_steps=50):
output = pipeline(
prompt=[prompt] * 2,
negative_prompt=negative_prompt,
num_frames=81,
guidance_scale=5.0,
num_inference_steps=num_inference_steps,
max_sequence_length=256,
generator=torch.manual_seed(0)
).frames[0]
return output
def main(args):
pipe = load_pipeline()
if args.use_mako:
from diffusers.models.transformers import wan_mako_attention_processor
print("Using MaKO kernel.")
pipe.transformer = fuse_qkv_for_wan_transformer_3d_model(pipe.transformer)
pipe.transformer.set_attn_processor(wan_mako_attention_processor.WanMakoAttnProcessor())
if args.use_compile:
pipe.transformer.compile_repeated_blocks()
prompt, negative_prompt = get_prompts()
for _ in range(3):
_ = run_inference(pipe, prompt, negative_prompt, 1)
inference_func = partial(run_inference, pipe, prompt=prompt, negative_prompt=negative_prompt)
    latency = triton.testing.do_bench(inference_func, warmup=1, rep=1)  # `do_bench` reports time in milliseconds
    print(f"{args=}, {latency=} ms.")
output = inference_func()
export_to_video(output, f"output_mako@{args.use_mako}_compile@{args.use_compile}.mp4", fps=16)
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--use_mako", action="store_true")
parser.add_argument("--use_compile", action="store_true")
args = parser.parse_args()
main(args)
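# Example invocations (assuming a CUDA GPU with diffusers, triton, and the Wan2.1 1.3B checkpoint available):
#   python check_mako.py                           # eager baseline
#   python check_mako.py --use_mako                # fused Triton (MaKO) kernels
#   python check_mako.py --use_mako --use_compile  # fused kernels + compiled transformer blocks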

transformer_wan.py
View File

@@ -31,6 +31,7 @@ from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import FP32LayerNorm
from .wan_mako_kernels import triton_adaptive_norm, triton_matmul, fused_matmul_residual, fused_matmul_residual_gate
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -461,6 +462,11 @@ class WanTransformerBlock(nn.Module):
temb: torch.Tensor,
rotary_emb: torch.Tensor,
) -> torch.Tensor:
        # Note: instead of applying the output projections inside the attention processor,
        # we apply them here so they can be fused with the residual add (and gate).
if not hidden_states.is_contiguous():
hidden_states = hidden_states.contiguous()
B, S, D = hidden_states.shape
if temb.ndim == 4:
# temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v)
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
@@ -480,21 +486,47 @@ class WanTransformerBlock(nn.Module):
).chunk(6, dim=1)
# 1. Self-attention
# norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
# Fused Adaptive LayerNorm
norm_hidden_states = triton_adaptive_norm(hidden_states, scale_msa, shift_msa, self.norm1.eps)
attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb)
# hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
# Fused Output Projection + Gated Residual
hidden_states = fused_matmul_residual_gate(
attn_output, self.attn1.to_out[0].weight, self.attn1.to_out[0].bias,
hidden_states, gate_msa, S_rows=S
)
# 2. Cross-attention
norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None)
# hidden_states = hidden_states + attn_output
# Fused Cross-Attn Output Proj + Residual (no gate)
hidden_states = fused_matmul_residual(
attn_output, self.attn2.to_out[0].weight, self.attn2.to_out[0].bias, hidden_states
)
# 3. Feed-forward
# norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
# hidden_states
# )
norm_hidden_states = triton_adaptive_norm(hidden_states, c_scale_msa, c_shift_msa, self.norm3.eps)
# ff_output = self.ffn(norm_hidden_states)
# hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
# Fused Linear + GELU
ff_in = triton_matmul(
norm_hidden_states,
self.ffn.net[0].proj.weight,
self.ffn.net[0].proj.bias,
activation="gelu"
)
# Fused Second Linear + Gated Residual
hidden_states = fused_matmul_residual_gate(
ff_in, self.ffn.net[2].weight, self.ffn.net[2].bias,
hidden_states, c_gate_msa, S_rows=S
)
return hidden_states

wan_mako_attention_processor.py Normal file (81 lines)
View File

@@ -0,0 +1,81 @@
import torch
import torch.nn.functional as F
from typing import Optional, Tuple
from .transformer_wan import WanAttention
from .wan_mako_kernels import triton_matmul, triton_rms_norm, triton_rms_norm2
# TODO: incorporate I2V support
class WanMakoAttnProcessor:
def __call__(
self,
attn: "WanAttention",
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
B, S, D = hidden_states.shape
H = attn.heads
head_dim = D // H # or assert against attn.inner_dim if needed
if attn.add_k_proj is not None:
# 512 is the context length of the text encoder, hardcoded for now
image_context_length = encoder_hidden_states.shape[1] - 512
# if you don't need this, drop it to avoid the slice
encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
# --- QKV / KV projections -------------------------------------------------
if hasattr(attn, "w_qkv_self") and hasattr(attn, "b_qkv_self"):
# Fused QKV via single matmul (self-attention)
qkv = triton_matmul(hidden_states, attn.w_qkv_self, attn.b_qkv_self)
q, k, v = qkv.chunk(3, dim=-1)
else:
# Q Projection (self from hidden_states)
q = triton_matmul(hidden_states, attn.to_q.weight, attn.to_q.bias)
# Fused KV Projection (cross from encoder_hidden_states)
kv = triton_matmul(encoder_hidden_states, attn.w_kv_cross, attn.b_kv_cross)
k, v = kv.chunk(2, dim=-1)
# --- Fused RMS Norm for Q and K to reduce 1 launch ----------------------
q, k = triton_rms_norm2(
q, attn.norm_q.weight,
k, attn.norm_k.weight,
attn.norm_q.eps,
)
# --- Reshape ------------------------------------
q, k, v = (a.unflatten(2, (attn.heads, -1)) for a in (q, k, v))
# --- Rotary embedding -----------------------------------------------------
if rotary_emb is not None:
def apply_rotary_emb(
hidden_states: torch.Tensor,
freqs_cos: torch.Tensor,
freqs_sin: torch.Tensor,
):
x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
cos = freqs_cos[..., 0::2]
sin = freqs_sin[..., 1::2]
out = torch.empty_like(hidden_states)
out[..., 0::2] = x1 * cos - x2 * sin
out[..., 1::2] = x1 * sin + x2 * cos
return out.type_as(hidden_states)
q = apply_rotary_emb(q, *rotary_emb)
k = apply_rotary_emb(k, *rotary_emb)
# --- Scaled dot-product attention ----------------------------------------
q, k, v = (x.permute(0, 2, 1, 3) for x in (q, k, v))
attn_out = F.scaled_dot_product_attention(
q, k, v,
dropout_p=0.0,
is_causal=False,
)
# (B, H, S, head_dim) -> (B, S, D)
attn_out = attn_out.transpose(1, 2).reshape(B, S, D)
return attn_out.contiguous() if not attn_out.is_contiguous() else attn_out

wan_mako_kernels.py Normal file (648 lines)
View File

@@ -0,0 +1,648 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import triton
import triton.language as tl
# ============================================================================
# Triton Kernels
# ============================================================================
@triton.autotune(
configs=[
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32}, num_stages=3, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 64}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 64}, num_stages=3, num_warps=8),
],
key=["M", "N", "K"],
)
@triton.jit
def matmul_kernel(
A, B, C,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BIAS, HAS_BIAS: tl.constexpr,
ACT_GELU_TANH: tl.constexpr,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_K)
a_ptrs = A + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = B + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, K, BLOCK_K):
a = tl.load(a_ptrs, mask=(offs_m[:, None] < M) & (k + offs_k[None, :] < K), other=0.0)
b = tl.load(b_ptrs, mask=(k + offs_k[:, None] < K) & (offs_n[None, :] < N), other=0.0)
acc += tl.dot(a, b)
a_ptrs += BLOCK_K * stride_ak
b_ptrs += BLOCK_K * stride_bk
if HAS_BIAS:
bias = tl.load(BIAS + offs_n, mask=offs_n < N, other=0.0)
acc += bias[None, :]
if ACT_GELU_TANH:
# Fused GELU (tanh approximation)
x = acc
c0 = 0.7978845608028654 # sqrt(2/pi)
c1 = 0.044715
x3 = x * x * x
inner = c0 * (x + c1 * x3)
e2 = tl.exp(2.0 * inner)
tanh_val = (e2 - 1.0) / (e2 + 1.0)
acc = 0.5 * x * (1.0 + tanh_val)
c_ptrs = C + (offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn)
tl.store(c_ptrs, acc, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
@triton.jit
def adaptive_layernorm_kernel(
X, Scale, Shift, Out,
stride_xb, stride_xs, stride_xd,
stride_sb, stride_sd,
stride_tb, stride_td,
stride_ob, stride_os, stride_od,
N, eps,
BLOCK_N: tl.constexpr
):
pid_b = tl.program_id(0)
pid_s = tl.program_id(1)
off_x = pid_b * stride_xb + pid_s * stride_xs
off_s = pid_b * stride_sb
off_t = pid_b * stride_tb
off_out = pid_b * stride_ob + pid_s * stride_os
cols = tl.arange(0, BLOCK_N)
mask = cols < N
x = tl.load(X + off_x + cols * stride_xd, mask=mask, other=0.0).to(tl.float32)
mean = tl.sum(x, axis=0) / N
x_centered = x - mean
var = tl.sum(x_centered * x_centered, axis=0) / N
rstd = tl.rsqrt(var + eps)
norm = x_centered * rstd
scale = tl.load(Scale + off_s + cols * stride_sd, mask=mask, other=0.0).to(tl.float32)
shift = tl.load(Shift + off_t + cols * stride_td, mask=mask, other=0.0).to(tl.float32)
out = norm * (1.0 + scale) + shift
tl.store(Out + off_out + cols * stride_od, out, mask=mask)
@triton.jit
def rms_norm_kernel(
X, W, Out,
stride_xb, stride_xs, stride_xd,
stride_w,
stride_ob, stride_os, stride_od,
N, eps,
BLOCK_N: tl.constexpr
):
pid_b = tl.program_id(0)
pid_s = tl.program_id(1)
off_x = pid_b * stride_xb + pid_s * stride_xs
off_out = pid_b * stride_ob + pid_s * stride_os
cols = tl.arange(0, BLOCK_N)
mask = cols < N
x = tl.load(X + off_x + cols * stride_xd, mask=mask, other=0.0).to(tl.float32)
ms = tl.sum(x * x, axis=0) / N
rms = tl.rsqrt(ms + eps)
w = tl.load(W + cols * stride_w, mask=mask, other=0.0).to(tl.float32)
out = x * rms * w
tl.store(Out + off_out + cols * stride_od, out, mask=mask)
# New: fused RMSNorm for (Q, K) together to eliminate one kernel launch and reuse scheduling
@triton.jit
def rms_norm2_kernel(
X1, W1, Out1,
X2, W2, Out2,
stride_xb, stride_xs, stride_xd,
stride_w1, stride_w2,
stride_o1b, stride_o1s, stride_o1d,
stride_o2b, stride_o2s, stride_o2d,
N, eps,
BLOCK_N: tl.constexpr
):
pid_b = tl.program_id(0)
pid_s = tl.program_id(1)
off_x = pid_b * stride_xb + pid_s * stride_xs
cols = tl.arange(0, BLOCK_N)
mask = cols < N
x1 = tl.load(X1 + off_x + cols * stride_xd, mask=mask, other=0.0).to(tl.float32)
x2 = tl.load(X2 + off_x + cols * stride_xd, mask=mask, other=0.0).to(tl.float32)
ms1 = tl.sum(x1 * x1, axis=0) / N
ms2 = tl.sum(x2 * x2, axis=0) / N
rms1 = tl.rsqrt(ms1 + eps)
rms2 = tl.rsqrt(ms2 + eps)
w1 = tl.load(W1 + cols * stride_w1, mask=mask, other=0.0).to(tl.float32)
w2 = tl.load(W2 + cols * stride_w2, mask=mask, other=0.0).to(tl.float32)
y1 = x1 * rms1 * w1
y2 = x2 * rms2 * w2
tl.store(Out1 + pid_b * stride_o1b + pid_s * stride_o1s + cols * stride_o1d, y1, mask=mask)
tl.store(Out2 + pid_b * stride_o2b + pid_s * stride_o2s + cols * stride_o2d, y2, mask=mask)
@triton.autotune(
configs=[
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32}, num_stages=3, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 64}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 64}, num_stages=3, num_warps=8),
],
key=["M", "N", "K"],
)
@triton.jit
def matmul_bias_resgate_kernel(
A, B, X, G, C,
M, N, K, S_rows,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_xm, stride_xn,
stride_gb, stride_gs, stride_gn,
stride_cm, stride_cn,
BIAS, HAS_BIAS: tl.constexpr,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) # rows
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) # cols
offs_k = tl.arange(0, BLOCK_K)
a_ptrs = A + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = B + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, K, BLOCK_K):
a = tl.load(a_ptrs, mask=(offs_m[:, None] < M) & (k + offs_k[None, :] < K), other=0.0)
b = tl.load(b_ptrs, mask=(k + offs_k[:, None] < K) & (offs_n[None, :] < N), other=0.0)
acc += tl.dot(a, b)
a_ptrs += BLOCK_K * stride_ak
b_ptrs += BLOCK_K * stride_bk
if HAS_BIAS:
bias = tl.load(BIAS + offs_n, mask=offs_n < N, other=0.0)
acc += bias[None, :]
# Load residual X
x_ptrs = X + (offs_m[:, None] * stride_xm + offs_n[None, :] * stride_xn)
x_val = tl.load(x_ptrs, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N), other=0.0).to(tl.float32)
# Compute gate index using batch from rows
b_rows = (offs_m // S_rows)[:, None]
g_ptrs = G + (b_rows * stride_gb + 0 * stride_gs + offs_n[None, :] * stride_gn)
g_val = tl.load(g_ptrs, mask=(offs_n[None, :] < N), other=0.0).to(tl.float32)
out = x_val + acc * g_val
c_ptrs = C + (offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn)
tl.store(c_ptrs, out, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
@triton.autotune(
configs=[
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32}, num_stages=3, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 64}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 64}, num_stages=3, num_warps=8),
],
key=["M", "N", "K"],
)
@triton.jit
def matmul_bias_resadd_kernel(
A, B, X, C,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_xm, stride_xn,
stride_cm, stride_cn,
BIAS, HAS_BIAS: tl.constexpr,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) # rows
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) # cols
offs_k = tl.arange(0, BLOCK_K)
a_ptrs = A + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = B + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, K, BLOCK_K):
a = tl.load(a_ptrs, mask=(offs_m[:, None] < M) & (k + offs_k[None, :] < K), other=0.0)
b = tl.load(b_ptrs, mask=(k + offs_k[:, None] < K) & (offs_n[None, :] < N), other=0.0)
acc += tl.dot(a, b)
a_ptrs += BLOCK_K * stride_ak
b_ptrs += BLOCK_K * stride_bk
if HAS_BIAS:
bias = tl.load(BIAS + offs_n, mask=offs_n < N, other=0.0)
acc += bias[None, :]
# Load residual and add
x_ptrs = X + (offs_m[:, None] * stride_xm + offs_n[None, :] * stride_xn)
x_val = tl.load(x_ptrs, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N), other=0.0).to(tl.float32)
out = x_val + acc
c_ptrs = C + (offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn)
tl.store(c_ptrs, out, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
# ============================================================================
# Helpers
# ============================================================================
def triton_matmul(x, w, bias=None, activation=""):
is_3d = x.ndim == 3
if is_3d:
B, S, K = x.shape
M = B * S
x_2d = x.view(M, K)
else:
M, K = x.shape
x_2d = x
N = w.shape[0]
out = torch.empty((M, N), device=x.device, dtype=x.dtype)
stride_am, stride_ak = x_2d.stride(0), x_2d.stride(1)
stride_bk, stride_bn = w.stride(1), w.stride(0)
grid = lambda meta: (triton.cdiv(M, meta['BLOCK_M']), triton.cdiv(N, meta['BLOCK_N']))
has_bias = bias is not None
is_gelu = activation == "gelu"
matmul_kernel[grid](
x_2d, w, out,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
out.stride(0), out.stride(1),
bias if has_bias else x_2d,
HAS_BIAS=has_bias,
ACT_GELU_TANH=is_gelu
)
if is_3d:
return out.view(B, S, N)
return out
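# Eager reference for `triton_matmul` (a sketch intended only for testing the kernel, not used in the
# hot path): out = x @ w.T + bias accumulated in fp32, optionally followed by tanh-approximate GELU,
# then cast back to the input dtype.
def matmul_reference(x, w, bias=None, activation=""):
    out = F.linear(x.float(), w.float(), bias.float() if bias is not None else None)
    if activation == "gelu":
        out = F.gelu(out, approximate="tanh")
    return out.to(x.dtype)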
def triton_adaptive_norm(x, scale, shift, eps):
B, S, D = x.shape
out = torch.empty_like(x)
def get_strides_bd(t):
if t.ndim == 3: return t.stride(0), t.stride(2)
return t.stride(0), t.stride(1)
ss_b, ss_d = get_strides_bd(scale)
st_b, st_d = get_strides_bd(shift)
BLOCK_N = triton.next_power_of_2(D)
grid = (B, S)
adaptive_layernorm_kernel[grid](
x, scale, shift, out,
x.stride(0), x.stride(1), x.stride(2),
ss_b, ss_d,
st_b, st_d,
out.stride(0), out.stride(1), out.stride(2),
D, eps,
BLOCK_N=BLOCK_N
)
return out
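# Eager reference for `triton_adaptive_norm` (testing sketch): LayerNorm without affine parameters in
# fp32, then `* (1 + scale) + shift`. Like the kernel, `scale`/`shift` are indexed per batch, i.e. they
# broadcast over the sequence dimension and may be (B, D) or (B, 1, D).
def adaptive_norm_reference(x, scale, shift, eps):
    if scale.ndim == 2:
        scale = scale.unsqueeze(1)
    if shift.ndim == 2:
        shift = shift.unsqueeze(1)
    norm = F.layer_norm(x.float(), (x.shape[-1],), None, None, eps)
    return (norm * (1.0 + scale.float()) + shift.float()).to(x.dtype)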
def triton_rms_norm(x, weight, eps):
B, S, D = x.shape
out = torch.empty_like(x)
BLOCK_N = triton.next_power_of_2(D)
grid = (B, S)
rms_norm_kernel[grid](
x, weight, out,
x.stride(0), x.stride(1), x.stride(2),
weight.stride(0),
out.stride(0), out.stride(1), out.stride(2),
D, eps,
BLOCK_N=BLOCK_N
)
return out
def triton_rms_norm2(x1, w1, x2, w2, eps):
# x1, x2: (B, S, D) with same B,S,D
B, S, D = x1.shape
out1 = torch.empty_like(x1)
out2 = torch.empty_like(x2)
BLOCK_N = triton.next_power_of_2(D)
grid = (B, S)
rms_norm2_kernel[grid](
x1, w1, out1,
x2, w2, out2,
x1.stride(0), x1.stride(1), x1.stride(2),
w1.stride(0), w2.stride(0),
out1.stride(0), out1.stride(1), out1.stride(2),
out2.stride(0), out2.stride(1), out2.stride(2),
D, eps,
BLOCK_N=BLOCK_N
)
return out1, out2
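# Eager reference for `triton_rms_norm2` (testing sketch): both inputs are RMS-normalized
# independently in fp32, each with its own weight and a shared eps, then cast back.
def rms_norm2_reference(x1, w1, x2, w2, eps):
    def _rms(x, w):
        x32 = x.float()
        inv = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + eps)
        return (x32 * inv * w.float()).to(x.dtype)
    return _rms(x1, w1), _rms(x2, w2)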
def fused_matmul_residual_gate(A, W, bias, X, G, S_rows):
# A: (B, S, K), W: (N, K), X: (B, S, N), G: (B, 1 or None, N) or (B, N)
B, S, K = A.shape
M = B * S
N = W.shape[0]
A2d = A.view(M, K)
X2d = X.view(M, N)
out = torch.empty_like(X2d)
grid = lambda meta: (triton.cdiv(M, meta['BLOCK_M']), triton.cdiv(N, meta['BLOCK_N']))
has_bias = bias is not None
# Strides for gate (B, 1, N) or (B, N)
stride_gb = G.stride(0)
stride_gs = 0 if (G.ndim == 3 and G.shape[1] == 1) else (G.stride(1) if G.ndim == 3 else 0)
stride_gn = G.stride(-1)
matmul_bias_resgate_kernel[grid](
A2d, W, X2d, G, out,
M, N, K, S_rows,
A2d.stride(0), A2d.stride(1),
W.stride(1), W.stride(0),
X2d.stride(0), X2d.stride(1),
stride_gb, stride_gs, stride_gn,
out.stride(0), out.stride(1),
bias if has_bias else W, # dummy if no bias
HAS_BIAS=has_bias
)
return out.view(B, S, N)
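# Eager reference for `fused_matmul_residual_gate` (testing sketch):
# out = X + (A @ W.T + bias) * G, with the gate G broadcast over the sequence dimension.
def matmul_residual_gate_reference(A, W, bias, X, G):
    proj = F.linear(A.float(), W.float(), bias.float() if bias is not None else None)
    if G.ndim == 2:
        G = G.unsqueeze(1)
    return (X.float() + proj * G.float()).to(X.dtype)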
def fused_matmul_residual(A, W, bias, X):
# A: (B, S, K), W: (N, K), X: (B, S, N)
B, S, K = A.shape
M = B * S
N = W.shape[0]
A2d = A.view(M, K)
X2d = X.view(M, N)
out = torch.empty_like(X2d)
grid = lambda meta: (triton.cdiv(M, meta['BLOCK_M']), triton.cdiv(N, meta['BLOCK_N']))
has_bias = bias is not None
matmul_bias_resadd_kernel[grid](
A2d, W, X2d, out,
M, N, K,
A2d.stride(0), A2d.stride(1),
W.stride(1), W.stride(0),
X2d.stride(0), X2d.stride(1),
out.stride(0), out.stride(1),
bias if has_bias else W,
HAS_BIAS=has_bias
)
return out.view(B, S, N)
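# Eager reference for `fused_matmul_residual` (testing sketch): out = X + (A @ W.T + bias).
def matmul_residual_reference(A, W, bias, X):
    proj = F.linear(A.float(), W.float(), bias.float() if bias is not None else None)
    return (X.float() + proj).to(X.dtype)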
# ============================================================================
# Optimized Model
# ============================================================================
class ModelNew(nn.Module):
def __init__(
self,
dim: int = 1536,
ffn_dim: int = 8960,
num_heads: int = 12,
cross_attn_norm: bool = False,
eps: float = 1e-6,
) -> None:
super().__init__()
# Preserve parameter structure (state_dict compatibility)
class WanAttention(nn.Module):
def __init__(self, dim, heads, dim_head, eps, dropout, added_kv_proj_dim, cross_attention_dim_head):
super().__init__()
self.inner_dim = dim_head * heads
self.heads = heads
self.dim_head = dim_head
self.added_kv_proj_dim = added_kv_proj_dim
self.cross_attention_dim_head = cross_attention_dim_head
self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads
self.to_q = nn.Linear(dim, self.inner_dim, bias=True)
self.to_k = nn.Linear(dim, self.kv_inner_dim, bias=True)
self.to_v = nn.Linear(dim, self.kv_inner_dim, bias=True)
self.to_out = nn.ModuleList([
nn.Linear(self.inner_dim, dim, bias=True),
nn.Dropout(dropout),
])
self.norm_q = nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
self.norm_k = nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
self.add_k_proj = self.add_v_proj = None
if added_kv_proj_dim is not None:
self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
self.norm_added_k = nn.RMSNorm(dim_head * heads, eps=eps)
class FeedForward(nn.Module):
def __init__(self, dim, inner_dim, dropout, bias=True):
super().__init__()
self.net = nn.ModuleList([])
class GELU(nn.Module):
def __init__(self, dim_in, dim_out, approximate, bias):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out, bias=bias)
self.approximate = approximate
self.net.append(GELU(dim, inner_dim, "tanh", bias))
self.net.append(nn.Dropout(dropout))
self.net.append(nn.Linear(inner_dim, dim, bias=bias))
class FP32LayerNorm(nn.LayerNorm):
pass # Placeholder for structure
class WanTransformerBlock(nn.Module):
def __init__(self, dim, ffn_dim, num_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim):
super().__init__()
self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
self.attn1 = WanAttention(dim, num_heads, dim // num_heads, eps, 0.0, None, None)
self.attn2 = WanAttention(dim, num_heads, dim // num_heads, eps, 0.0, added_kv_proj_dim, dim // num_heads)
self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
self.ffn = FeedForward(dim, ffn_dim, 0.0, bias=True)
self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)
self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
self.eps = eps
self.block = WanTransformerBlock(dim, ffn_dim, num_heads, "rms_norm_across_heads", cross_attn_norm, eps, None)
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
@torch.no_grad()
def fuse(self):
# Pre-concatenate weights for fused operations
# Self-Attention QKV
self.w_qkv_self = torch.cat([self.block.attn1.to_q.weight, self.block.attn1.to_k.weight, self.block.attn1.to_v.weight], dim=0)
self.b_qkv_self = torch.cat([self.block.attn1.to_q.bias, self.block.attn1.to_k.bias, self.block.attn1.to_v.bias], dim=0)
# Cross-Attention KV
self.w_kv_cross = torch.cat([self.block.attn2.to_k.weight, self.block.attn2.to_v.weight], dim=0)
self.b_kv_cross = torch.cat([self.block.attn2.to_k.bias, self.block.attn2.to_v.bias], dim=0)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
) -> torch.Tensor:
block = self.block
B, S, D = hidden_states.shape
H = self.num_heads
Dh = self.head_dim
# Modulation
if temb.ndim == 4:
mods = (block.scale_shift_table.unsqueeze(0) + temb.float()).chunk(6, dim=2)
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = [x.squeeze(2) for x in mods]
else:
mods = (block.scale_shift_table + temb.float()).chunk(6, dim=1)
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = mods
# --------------------------------------------------------------------
# 1. Self-Attention
# --------------------------------------------------------------------
# Fused Adaptive LayerNorm
norm_hidden = triton_adaptive_norm(hidden_states, scale_msa, shift_msa, block.eps)
# Fused QKV via single matmul
qkv = triton_matmul(norm_hidden, self.w_qkv_self, self.b_qkv_self)
q, k, v = qkv.chunk(3, dim=-1)
# Fused RMS Norm for Q and K to reduce 1 launch
q, k = triton_rms_norm2(q, block.attn1.norm_q.weight, k, block.attn1.norm_k.weight, block.attn1.norm_q.eps)
# Reshape & Attention
q = q.view(B, S, H, Dh).transpose(1, 2)
k = k.view(B, S, H, Dh).transpose(1, 2)
v = v.view(B, S, H, Dh).transpose(1, 2)
attn_out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
attn_out = attn_out.transpose(1, 2).contiguous().view(B, S, D)
# Fused Output Projection + Gated Residual
hidden_states = fused_matmul_residual_gate(
attn_out, block.attn1.to_out[0].weight, block.attn1.to_out[0].bias,
hidden_states, gate_msa, S_rows=S
)
# --------------------------------------------------------------------
# 2. Cross-Attention
# --------------------------------------------------------------------
# Norm
if not isinstance(block.norm2, nn.Identity):
norm_hidden = block.norm2(hidden_states.float()).type_as(hidden_states)
else:
norm_hidden = hidden_states
# Q Projection
q2 = triton_matmul(norm_hidden, block.attn2.to_q.weight, block.attn2.to_q.bias)
# Fused KV Projection
kv2 = triton_matmul(encoder_hidden_states, self.w_kv_cross, self.b_kv_cross)
k2, v2 = kv2.chunk(2, dim=-1)
# RMS Norm
q2 = triton_rms_norm(q2, block.attn2.norm_q.weight, block.attn2.norm_q.eps)
k2 = triton_rms_norm(k2, block.attn2.norm_k.weight, block.attn2.norm_k.eps)
# Attention
q2 = q2.view(B, S, H, Dh).transpose(1, 2)
T_text = encoder_hidden_states.shape[1]
k2 = k2.view(B, T_text, H, Dh).transpose(1, 2)
v2 = v2.view(B, T_text, H, Dh).transpose(1, 2)
attn_out2 = F.scaled_dot_product_attention(q2, k2, v2, dropout_p=0.0, is_causal=False)
attn_out2 = attn_out2.transpose(1, 2).contiguous().view(B, S, D)
# Fused Cross-Attn Output Proj + Residual (no gate)
hidden_states = fused_matmul_residual(
attn_out2, block.attn2.to_out[0].weight, block.attn2.to_out[0].bias, hidden_states
)
# --------------------------------------------------------------------
# 3. Feed-Forward
# --------------------------------------------------------------------
# Adaptive Norm
norm_hidden = triton_adaptive_norm(hidden_states, c_scale_msa, c_shift_msa, block.eps)
# Fused Linear + GELU
ff_in = triton_matmul(
norm_hidden,
block.ffn.net[0].proj.weight,
block.ffn.net[0].proj.bias,
activation="gelu"
)
# Fused Second Linear + Gated Residual
hidden_states = fused_matmul_residual_gate(
ff_in, block.ffn.net[2].weight, block.ffn.net[2].bias,
hidden_states, c_gate_msa, S_rows=S
)
return hidden_states
def get_inputs():
# randomly generate input tensors based on the model architecture (Wan 1.3B)
batch_size = 2
seq_len = 256 # Number of latent tokens (e.g., from video patches)
dim = 1536 # Hidden dimension for Wan 1.3B
# hidden_states: [batch_size, seq_len, dim]
hidden_states = torch.randn(batch_size, seq_len, dim).cuda()
# encoder_hidden_states: [batch_size, text_seq_len, dim] (text embeddings)
text_seq_len = 512
encoder_hidden_states = torch.randn(batch_size, text_seq_len, dim).cuda()
# temb: [batch_size, 6, dim] (timestep embedding projected to 6 modulation vectors)
temb = torch.randn(batch_size, 6, dim).cuda()
return [hidden_states, encoder_hidden_states, temb]
def get_init_inputs():
# Initialization parameters for Wan 1.3B: dim, ffn_dim, num_heads, cross_attn_norm, eps
return [1536, 8960, 12, False, 1e-6]
if __name__ == "__main__":
model = ModelNew(*get_init_inputs()).cuda()
model.fuse()
# Get inputs and run forward pass
inputs = get_inputs()
with torch.no_grad():
output = model(*inputs)
print(f"{output.shape=}")