mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-14 00:14:23 +08:00
Compare commits
6 Commits
pipeline-s
...
kernelize
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
92199ff3ac | ||
|
|
04e9323055 | ||
|
|
9a09162baf | ||
|
|
33a8a3be0c | ||
|
|
58743c3ee7 | ||
|
|
50c0b786d2 |
@@ -17,10 +17,11 @@ import torch
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from ..utils import deprecate
|
from ..utils import deprecate, get_logger, is_torch_npu_available, is_torch_version
|
||||||
from ..utils.import_utils import is_torch_npu_available, is_torch_version
|
|
||||||
|
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
if is_torch_npu_available():
|
if is_torch_npu_available():
|
||||||
import torch_npu
|
import torch_npu
|
||||||
|
|
||||||
@@ -31,6 +32,7 @@ ACT2CLS = {
|
|||||||
"gelu": nn.GELU,
|
"gelu": nn.GELU,
|
||||||
"relu": nn.ReLU,
|
"relu": nn.ReLU,
|
||||||
}
|
}
|
||||||
|
KERNELS_REPO_ID = "kernels-community/activation"
|
||||||
|
|
||||||
|
|
||||||
def get_activation(act_fn: str) -> nn.Module:
|
def get_activation(act_fn: str) -> nn.Module:
|
||||||
@@ -90,6 +92,27 @@ class GELU(nn.Module):
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: validation checks / consider making Python classes of activations like `transformers`
|
||||||
|
# All of these are temporary for now.
|
||||||
|
class CUDAOptimizedGELU(GELU):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
from kernels import get_kernel
|
||||||
|
|
||||||
|
activation = get_kernel("kernels-community/activation", revision="add_more_act")
|
||||||
|
approximate = kwargs.get("approximate", "none")
|
||||||
|
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
if approximate == "none":
|
||||||
|
self.act_fn = activation.layers.Gelu()
|
||||||
|
elif approximate == "tanh":
|
||||||
|
self.act_fn = activation.layers.GeluTanh()
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
hidden_states = self.proj(hidden_states)
|
||||||
|
hidden_states = self.act_fn(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class GEGLU(nn.Module):
|
class GEGLU(nn.Module):
|
||||||
r"""
|
r"""
|
||||||
A [variant](https://huggingface.co/papers/2002.05202) of the gated linear unit activation function.
|
A [variant](https://huggingface.co/papers/2002.05202) of the gated linear unit activation function.
|
||||||
|
|||||||
@@ -20,11 +20,20 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from ..utils import is_torch_npu_available, is_torch_version
|
from ..utils import is_kernels_available, is_torch_npu_available, is_torch_version
|
||||||
|
from ..utils.constants import DIFFUSERS_ENABLE_HUB_KERNELS
|
||||||
|
from ..utils.kernels_utils import use_kernel_forward_from_hub
|
||||||
from .activations import get_activation
|
from .activations import get_activation
|
||||||
from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings
|
from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings
|
||||||
|
|
||||||
|
|
||||||
|
if is_kernels_available() and DIFFUSERS_ENABLE_HUB_KERNELS:
|
||||||
|
from kernels import get_kernel
|
||||||
|
|
||||||
|
activation = get_kernel("kernels-community/activation", revision="add_more_act")
|
||||||
|
silu_kernel = activation.layers.Silu
|
||||||
|
|
||||||
|
|
||||||
class AdaLayerNorm(nn.Module):
|
class AdaLayerNorm(nn.Module):
|
||||||
r"""
|
r"""
|
||||||
Norm layer modified to incorporate timestep embeddings.
|
Norm layer modified to incorporate timestep embeddings.
|
||||||
@@ -57,7 +66,10 @@ class AdaLayerNorm(nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.emb = None
|
self.emb = None
|
||||||
|
|
||||||
self.silu = nn.SiLU()
|
if DIFFUSERS_ENABLE_HUB_KERNELS:
|
||||||
|
self.silu = silu_kernel()
|
||||||
|
else:
|
||||||
|
self.silu = nn.SiLU()
|
||||||
self.linear = nn.Linear(embedding_dim, output_dim)
|
self.linear = nn.Linear(embedding_dim, output_dim)
|
||||||
self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
|
self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
|
||||||
|
|
||||||
@@ -144,7 +156,10 @@ class AdaLayerNormZero(nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.emb = None
|
self.emb = None
|
||||||
|
|
||||||
self.silu = nn.SiLU()
|
if DIFFUSERS_ENABLE_HUB_KERNELS:
|
||||||
|
self.silu = silu_kernel()
|
||||||
|
else:
|
||||||
|
self.silu = nn.SiLU()
|
||||||
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
|
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
|
||||||
if norm_type == "layer_norm":
|
if norm_type == "layer_norm":
|
||||||
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
|
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
|
||||||
@@ -183,7 +198,10 @@ class AdaLayerNormZeroSingle(nn.Module):
|
|||||||
def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
|
def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.silu = nn.SiLU()
|
if DIFFUSERS_ENABLE_HUB_KERNELS:
|
||||||
|
self.silu = silu_kernel()
|
||||||
|
else:
|
||||||
|
self.silu = nn.SiLU()
|
||||||
self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
|
self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
|
||||||
if norm_type == "layer_norm":
|
if norm_type == "layer_norm":
|
||||||
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
|
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
|
||||||
@@ -335,7 +353,10 @@ class AdaLayerNormContinuous(nn.Module):
|
|||||||
norm_type="layer_norm",
|
norm_type="layer_norm",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.silu = nn.SiLU()
|
if DIFFUSERS_ENABLE_HUB_KERNELS:
|
||||||
|
self.silu = silu_kernel()
|
||||||
|
else:
|
||||||
|
self.silu = nn.SiLU()
|
||||||
self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
|
self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
|
||||||
if norm_type == "layer_norm":
|
if norm_type == "layer_norm":
|
||||||
self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
|
self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
|
||||||
@@ -508,6 +529,7 @@ else:
|
|||||||
return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps)
|
return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps)
|
||||||
|
|
||||||
|
|
||||||
|
@use_kernel_forward_from_hub("RMSNorm")
|
||||||
class RMSNorm(nn.Module):
|
class RMSNorm(nn.Module):
|
||||||
r"""
|
r"""
|
||||||
RMS Norm as introduced in https://huggingface.co/papers/1910.07467 by Zhang et al.
|
RMS Norm as introduced in https://huggingface.co/papers/1910.07467 by Zhang et al.
|
||||||
|
|||||||
@@ -22,7 +22,8 @@ import torch.nn.functional as F
|
|||||||
|
|
||||||
from ...configuration_utils import ConfigMixin, register_to_config
|
from ...configuration_utils import ConfigMixin, register_to_config
|
||||||
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
|
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
|
||||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
from ...utils import USE_PEFT_BACKEND, is_kernels_available, logging, scale_lora_layers, unscale_lora_layers
|
||||||
|
from ...utils.constants import DIFFUSERS_ENABLE_HUB_KERNELS
|
||||||
from ...utils.torch_utils import maybe_allow_in_graph
|
from ...utils.torch_utils import maybe_allow_in_graph
|
||||||
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
|
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
|
||||||
from ..attention_dispatch import dispatch_attention_fn
|
from ..attention_dispatch import dispatch_attention_fn
|
||||||
@@ -40,6 +41,12 @@ from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNo
|
|||||||
|
|
||||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
if is_kernels_available() and DIFFUSERS_ENABLE_HUB_KERNELS:
|
||||||
|
from kernels import get_kernel
|
||||||
|
|
||||||
|
activation = get_kernel("kernels-community/activation", revision="add_more_act")
|
||||||
|
gelu_tanh_kernel = activation.layers.GeluTanh
|
||||||
|
|
||||||
|
|
||||||
def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
|
def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
|
||||||
query = attn.to_q(hidden_states)
|
query = attn.to_q(hidden_states)
|
||||||
@@ -300,8 +307,14 @@ class FluxAttention(torch.nn.Module, AttentionModuleMixin):
|
|||||||
self.added_kv_proj_dim = added_kv_proj_dim
|
self.added_kv_proj_dim = added_kv_proj_dim
|
||||||
self.added_proj_bias = added_proj_bias
|
self.added_proj_bias = added_proj_bias
|
||||||
|
|
||||||
self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
if DIFFUSERS_ENABLE_HUB_KERNELS:
|
||||||
self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
from ..normalization import RMSNorm
|
||||||
|
|
||||||
|
self.norm_q = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
||||||
|
self.norm_k = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
||||||
|
else:
|
||||||
|
self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
||||||
|
self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
||||||
self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
|
self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
|
||||||
self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
|
self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
|
||||||
self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
|
self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
|
||||||
@@ -312,8 +325,14 @@ class FluxAttention(torch.nn.Module, AttentionModuleMixin):
|
|||||||
self.to_out.append(torch.nn.Dropout(dropout))
|
self.to_out.append(torch.nn.Dropout(dropout))
|
||||||
|
|
||||||
if added_kv_proj_dim is not None:
|
if added_kv_proj_dim is not None:
|
||||||
self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
|
if DIFFUSERS_ENABLE_HUB_KERNELS:
|
||||||
self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
|
from ..normalization import RMSNorm
|
||||||
|
|
||||||
|
self.norm_added_q = RMSNorm(dim_head, eps=eps)
|
||||||
|
self.norm_added_k = RMSNorm(dim_head, eps=eps)
|
||||||
|
else:
|
||||||
|
self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
|
||||||
|
self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
|
||||||
self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
|
self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
|
||||||
self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
|
self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
|
||||||
self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
|
self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
|
||||||
@@ -351,6 +370,11 @@ class FluxSingleTransformerBlock(nn.Module):
|
|||||||
self.norm = AdaLayerNormZeroSingle(dim)
|
self.norm = AdaLayerNormZeroSingle(dim)
|
||||||
self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
|
self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
|
||||||
self.act_mlp = nn.GELU(approximate="tanh")
|
self.act_mlp = nn.GELU(approximate="tanh")
|
||||||
|
# if not DIFFUSERS_ENABLE_HUB_KERNELS:
|
||||||
|
# self.act_mlp = nn.GELU(approximate="tanh")
|
||||||
|
# else:
|
||||||
|
# self.act_mlp = gelu_tanh_kernel()
|
||||||
|
|
||||||
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
|
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
|
||||||
|
|
||||||
self.attn = FluxAttention(
|
self.attn = FluxAttention(
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
from typing import Union
|
||||||
|
|
||||||
from ..utils import get_logger
|
from ..utils import get_logger
|
||||||
from .import_utils import is_kernels_available
|
from .import_utils import is_kernels_available
|
||||||
|
|
||||||
@@ -21,3 +23,42 @@ def _get_fa3_from_hub():
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}")
|
logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
if is_kernels_available():
|
||||||
|
from kernels import (
|
||||||
|
Device,
|
||||||
|
LayerRepository,
|
||||||
|
register_kernel_mapping,
|
||||||
|
replace_kernel_forward_from_hub,
|
||||||
|
use_kernel_forward_from_hub,
|
||||||
|
)
|
||||||
|
|
||||||
|
_KERNEL_MAPPING: dict[str, dict[Union[Device, str], LayerRepository]] = {
|
||||||
|
"RMSNorm": {
|
||||||
|
"cuda": LayerRepository(repo_id="kernels-community/liger_kernels", layer_name="LigerRMSNorm"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
register_kernel_mapping(_KERNEL_MAPPING)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Stub to make decorators int transformers work when `kernels`
|
||||||
|
# is not installed.
|
||||||
|
def use_kernel_forward_from_hub(*args, **kwargs):
|
||||||
|
def decorator(cls):
|
||||||
|
return cls
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
class LayerRepository:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
raise RuntimeError("LayerRepository requires `kernels` to be installed. Run `pip install kernels`.")
|
||||||
|
|
||||||
|
def replace_kernel_forward_from_hub(*args, **kwargs):
|
||||||
|
raise RuntimeError(
|
||||||
|
"replace_kernel_forward_from_hub requires `kernels` to be installed. Run `pip install kernels`."
|
||||||
|
)
|
||||||
|
|
||||||
|
def register_kernel_mapping(*args, **kwargs):
|
||||||
|
raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`.")
|
||||||
|
|||||||
Reference in New Issue
Block a user