Compare commits

...

6 Commits

Author     SHA1        Message           Date
sayakpaul  92199ff3ac  up                2025-09-22 16:46:49 +05:30
sayakpaul  04e9323055  up                2025-09-18 17:23:04 +05:30
sayakpaul  9a09162baf  up                2025-09-18 14:59:00 +05:30
sayakpaul  33a8a3be0c  up                2025-09-18 14:49:48 +05:30
sayakpaul  58743c3ee7  kernelize gelu.   2025-09-16 18:09:12 +05:30
sayakpaul  50c0b786d2  start kernelize.  2025-09-15 16:26:52 +05:30
4 changed files with 122 additions and 12 deletions
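Note: every change below follows the same pattern — when the `kernels` package is installed and hub kernels are enabled, an optimized layer is fetched from the `kernels-community/activation` repo on the Hub and dropped in for the stock torch.nn module. A minimal sketch of that fetch pattern, reusing the repo id, revision, and layer names that appear verbatim in the diff (assumes `pip install kernels` and a CUDA device):

import torch
import torch.nn.functional as F
from kernels import get_kernel

# Fetch the activation kernel collection from the Hub (same repo/revision as in the diff).
activation = get_kernel("kernels-community/activation", revision="add_more_act")

silu = activation.layers.Silu()           # drop-in for nn.SiLU()
gelu_tanh = activation.layers.GeluTanh()  # drop-in for nn.GELU(approximate="tanh")

x = torch.randn(4, 256, device="cuda", dtype=torch.float16)
print(torch.allclose(silu(x), F.silu(x), atol=1e-3))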

View File

@@ -17,10 +17,11 @@ import torch
 import torch.nn.functional as F
 from torch import nn

-from ..utils import deprecate
-from ..utils.import_utils import is_torch_npu_available, is_torch_version
+from ..utils import deprecate, get_logger, is_torch_npu_available, is_torch_version
+
+logger = get_logger(__name__)

 if is_torch_npu_available():
     import torch_npu
@@ -31,6 +32,7 @@ ACT2CLS = {
     "gelu": nn.GELU,
     "relu": nn.ReLU,
 }

+KERNELS_REPO_ID = "kernels-community/activation"


 def get_activation(act_fn: str) -> nn.Module:
@@ -90,6 +92,27 @@ class GELU(nn.Module):
         return hidden_states


+# TODO: validation checks / consider making Python classes of activations like `transformers`
+# All of these are temporary for now.
+class CUDAOptimizedGELU(GELU):
+    def __init__(self, *args, **kwargs):
+        from kernels import get_kernel
+
+        activation = get_kernel("kernels-community/activation", revision="add_more_act")
+        approximate = kwargs.get("approximate", "none")
+        super().__init__(*args, **kwargs)
+        if approximate == "none":
+            self.act_fn = activation.layers.Gelu()
+        elif approximate == "tanh":
+            self.act_fn = activation.layers.GeluTanh()
+
+    def forward(self, hidden_states):
+        hidden_states = self.proj(hidden_states)
+        hidden_states = self.act_fn(hidden_states)
+        return hidden_states
+
+
 class GEGLU(nn.Module):
     r"""
     A [variant](https://huggingface.co/papers/2002.05202) of the gated linear unit activation function.
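Note: CUDAOptimizedGELU keeps the parent GELU's nn.Linear projection, so checkpoint weights load unchanged; only the element-wise activation is swapped for the Hub kernel. A hedged usage sketch, assuming the class ships in diffusers.models.activations alongside the existing GELU(dim_in, dim_out, approximate, bias) signature and that a CUDA device is available:

import torch
from diffusers.models.activations import GELU, CUDAOptimizedGELU

ref = GELU(dim_in=128, dim_out=512, approximate="tanh").to("cuda", torch.float16)
fast = CUDAOptimizedGELU(dim_in=128, dim_out=512, approximate="tanh").to("cuda", torch.float16)
fast.load_state_dict(ref.state_dict())  # same parameters; only act_fn differs

x = torch.randn(2, 16, 128, device="cuda", dtype=torch.float16)
print(torch.allclose(ref(x), fast(x), atol=1e-2))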

View File

@@ -20,11 +20,20 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from ..utils import is_torch_npu_available, is_torch_version
+from ..utils import is_kernels_available, is_torch_npu_available, is_torch_version
+from ..utils.constants import DIFFUSERS_ENABLE_HUB_KERNELS
+from ..utils.kernels_utils import use_kernel_forward_from_hub
 from .activations import get_activation
 from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings


+if is_kernels_available() and DIFFUSERS_ENABLE_HUB_KERNELS:
+    from kernels import get_kernel
+
+    activation = get_kernel("kernels-community/activation", revision="add_more_act")
+    silu_kernel = activation.layers.Silu
+
+
 class AdaLayerNorm(nn.Module):
     r"""
     Norm layer modified to incorporate timestep embeddings.
@@ -57,7 +66,10 @@ class AdaLayerNorm(nn.Module):
         else:
             self.emb = None

-        self.silu = nn.SiLU()
+        if DIFFUSERS_ENABLE_HUB_KERNELS:
+            self.silu = silu_kernel()
+        else:
+            self.silu = nn.SiLU()
         self.linear = nn.Linear(embedding_dim, output_dim)
         self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
@@ -144,7 +156,10 @@ class AdaLayerNormZero(nn.Module):
         else:
             self.emb = None

-        self.silu = nn.SiLU()
+        if DIFFUSERS_ENABLE_HUB_KERNELS:
+            self.silu = silu_kernel()
+        else:
+            self.silu = nn.SiLU()
         self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
         if norm_type == "layer_norm":
             self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
@@ -183,7 +198,10 @@ class AdaLayerNormZeroSingle(nn.Module):
     def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
         super().__init__()

-        self.silu = nn.SiLU()
+        if DIFFUSERS_ENABLE_HUB_KERNELS:
+            self.silu = silu_kernel()
+        else:
+            self.silu = nn.SiLU()
         self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
         if norm_type == "layer_norm":
             self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
@@ -335,7 +353,10 @@ class AdaLayerNormContinuous(nn.Module):
         norm_type="layer_norm",
     ):
         super().__init__()
-        self.silu = nn.SiLU()
+        if DIFFUSERS_ENABLE_HUB_KERNELS:
+            self.silu = silu_kernel()
+        else:
+            self.silu = nn.SiLU()
         self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
         if norm_type == "layer_norm":
             self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
@@ -508,6 +529,7 @@ else:
             return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps)


+@use_kernel_forward_from_hub("RMSNorm")
 class RMSNorm(nn.Module):
     r"""
     RMS Norm as introduced in https://huggingface.co/papers/1910.07467 by Zhang et al.
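Note: the same gate repeats in every AdaLayerNorm* constructor above — when DIFFUSERS_ENABLE_HUB_KERNELS is set (presumably an environment-driven flag in diffusers.utils.constants) and `kernels` is importable, the parameter-free Hub Silu layer replaces nn.SiLU, so state dicts are identical either way. RMSNorm is handled differently: the decorator only tags the class, and the actual forward swap happens later via the mapping registered in the utils change at the end of this diff. A hedged, self-contained sketch of the SiLU gating pattern, with a hypothetical module name:

import torch
import torch.nn as nn
from diffusers.utils import is_kernels_available
from diffusers.utils.constants import DIFFUSERS_ENABLE_HUB_KERNELS

if is_kernels_available() and DIFFUSERS_ENABLE_HUB_KERNELS:
    from kernels import get_kernel

    silu_kernel = get_kernel("kernels-community/activation", revision="add_more_act").layers.Silu


class TimestepProj(nn.Module):  # hypothetical module, mirrors the AdaLayerNorm.__init__ pattern
    def __init__(self, dim: int):
        super().__init__()
        # Kernelized SiLU when the flag is on, otherwise the stock implementation.
        self.silu = silu_kernel() if DIFFUSERS_ENABLE_HUB_KERNELS else nn.SiLU()
        self.linear = nn.Linear(dim, 2 * dim)

    def forward(self, emb: torch.Tensor) -> torch.Tensor:
        return self.linear(self.silu(emb))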

View File

@@ -22,7 +22,8 @@ import torch.nn.functional as F

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, is_kernels_available, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.constants import DIFFUSERS_ENABLE_HUB_KERNELS
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
 from ..attention_dispatch import dispatch_attention_fn
@@ -40,6 +41,12 @@ from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

+if is_kernels_available() and DIFFUSERS_ENABLE_HUB_KERNELS:
+    from kernels import get_kernel
+
+    activation = get_kernel("kernels-community/activation", revision="add_more_act")
+    gelu_tanh_kernel = activation.layers.GeluTanh
+

 def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
     query = attn.to_q(hidden_states)
@@ -300,8 +307,14 @@ class FluxAttention(torch.nn.Module, AttentionModuleMixin):
         self.added_kv_proj_dim = added_kv_proj_dim
         self.added_proj_bias = added_proj_bias

-        self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
-        self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+        if DIFFUSERS_ENABLE_HUB_KERNELS:
+            from ..normalization import RMSNorm
+
+            self.norm_q = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+            self.norm_k = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+        else:
+            self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+            self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
         self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
         self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
         self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
@@ -312,8 +325,14 @@ class FluxAttention(torch.nn.Module, AttentionModuleMixin):
         self.to_out.append(torch.nn.Dropout(dropout))

         if added_kv_proj_dim is not None:
-            self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
-            self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
+            if DIFFUSERS_ENABLE_HUB_KERNELS:
+                from ..normalization import RMSNorm
+
+                self.norm_added_q = RMSNorm(dim_head, eps=eps)
+                self.norm_added_k = RMSNorm(dim_head, eps=eps)
+            else:
+                self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
+                self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
             self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
             self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
             self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
@@ -351,6 +370,11 @@ class FluxSingleTransformerBlock(nn.Module):
         self.norm = AdaLayerNormZeroSingle(dim)
         self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
         self.act_mlp = nn.GELU(approximate="tanh")
+        # if not DIFFUSERS_ENABLE_HUB_KERNELS:
+        #     self.act_mlp = nn.GELU(approximate="tanh")
+        # else:
+        #     self.act_mlp = gelu_tanh_kernel()
+
         self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)

         self.attn = FluxAttention(
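Note: switching norm_q/norm_k from torch.nn.RMSNorm to the diffusers RMSNorm is numerically a no-op on its own; it matters because the diffusers class is tagged with use_kernel_forward_from_hub("RMSNorm"), so its forward can later be replaced by the Liger kernel mapped in the utils change below. A hedged sanity check of the swap (assumes torch >= 2.4 for torch.nn.RMSNorm and the diffusers RMSNorm(dim, eps, elementwise_affine) signature):

import torch
from diffusers.models.normalization import RMSNorm

dim_head, eps = 128, 1e-6
ref = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=True)
alt = RMSNorm(dim_head, eps, elementwise_affine=True)
alt.weight.data.copy_(ref.weight.data)  # both initialize the scale to ones

x = torch.randn(2, 24, 512, dim_head)
print(torch.allclose(ref(x), alt(x), atol=1e-5))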

View File

@@ -1,3 +1,5 @@
+from typing import Union
+
 from ..utils import get_logger
 from .import_utils import is_kernels_available
@@ -21,3 +23,42 @@ def _get_fa3_from_hub():
         except Exception as e:
             logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}")
             raise


+if is_kernels_available():
+    from kernels import (
+        Device,
+        LayerRepository,
+        register_kernel_mapping,
+        replace_kernel_forward_from_hub,
+        use_kernel_forward_from_hub,
+    )
+
+    _KERNEL_MAPPING: dict[str, dict[Union[Device, str], LayerRepository]] = {
+        "RMSNorm": {
+            "cuda": LayerRepository(repo_id="kernels-community/liger_kernels", layer_name="LigerRMSNorm"),
+        },
+    }
+
+    register_kernel_mapping(_KERNEL_MAPPING)
+
+else:
+    # Stub to make decorators in transformers work when `kernels`
+    # is not installed.
+    def use_kernel_forward_from_hub(*args, **kwargs):
+        def decorator(cls):
+            return cls
+
+        return decorator
+
+    class LayerRepository:
+        def __init__(self, *args, **kwargs):
+            raise RuntimeError("LayerRepository requires `kernels` to be installed. Run `pip install kernels`.")
+
+    def replace_kernel_forward_from_hub(*args, **kwargs):
+        raise RuntimeError(
+            "replace_kernel_forward_from_hub requires `kernels` to be installed. Run `pip install kernels`."
+        )
+
+    def register_kernel_mapping(*args, **kwargs):
+        raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`.")