Mirror of https://github.com/huggingface/diffusers.git, synced 2026-02-15 15:25:30 +08:00

Compare commits: torch-main...kernelize (7 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 92199ff3ac | |
| | 04e9323055 | |
| | 9a09162baf | |
| | 33a8a3be0c | |
| | 58743c3ee7 | |
| | 50c0b786d2 | |
| | f5c113e439 | |

setup.py (4 changed lines)
@@ -100,6 +100,7 @@ _deps = [
     "compel==0.1.8",
     "datasets",
     "filelock",
+    "flax>=0.4.1",
     "hf-doc-builder>=0.3.0",
     "huggingface-hub>=0.34.0",
     "requests-mock==1.10.0",
@@ -136,7 +137,6 @@ _deps = [
     "requests",
     "tensorboard",
     "tiktoken>=0.7.0",
-    "flax>=0.4.1",
     "torch>=1.4",
     "torchvision",
     "transformers>=4.41.2",
@@ -252,7 +252,6 @@ if os.name == "nt": # windows
 else:
     extras["flax"] = deps_list("jax", "jaxlib", "flax")
 
-
 extras["dev"] = (
     extras["quality"] + extras["test"] + extras["training"] + extras["docs"] + extras["torch"] + extras["flax"]
 )
@@ -266,7 +265,6 @@ install_requires = [
     deps["requests"],
     deps["safetensors"],
     deps["Pillow"],
-    deps["torch"],
 ]
 
 version_range_max = max(sys.version_info[1], 10) + 1

src/diffusers/dependency_versions_table.py
@@ -7,6 +7,7 @@ deps = {
     "compel": "compel==0.1.8",
     "datasets": "datasets",
     "filelock": "filelock",
+    "flax": "flax>=0.4.1",
     "hf-doc-builder": "hf-doc-builder>=0.3.0",
     "huggingface-hub": "huggingface-hub>=0.34.0",
     "requests-mock": "requests-mock==1.10.0",
@@ -43,7 +44,6 @@ deps = {
     "requests": "requests",
     "tensorboard": "tensorboard",
     "tiktoken": "tiktoken>=0.7.0",
-    "flax": "flax>=0.4.1",
     "torch": "torch>=1.4",
     "torchvision": "torchvision",
     "transformers": "transformers>=4.41.2",

src/diffusers/models/activations.py
@@ -17,10 +17,11 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 
-from ..utils import deprecate
-from ..utils.import_utils import is_torch_npu_available, is_torch_version
+from ..utils import deprecate, get_logger, is_torch_npu_available, is_torch_version
 
 
+logger = get_logger(__name__)
+
 if is_torch_npu_available():
     import torch_npu
 
@@ -31,6 +32,7 @@ ACT2CLS = {
     "gelu": nn.GELU,
     "relu": nn.ReLU,
 }
+KERNELS_REPO_ID = "kernels-community/activation"
 
 
 def get_activation(act_fn: str) -> nn.Module:
@@ -90,6 +92,27 @@ class GELU(nn.Module):
         return hidden_states
 
 
+# TODO: validation checks / consider making Python classes of activations like `transformers`
+# All of these are temporary for now.
+class CUDAOptimizedGELU(GELU):
+    def __init__(self, *args, **kwargs):
+        from kernels import get_kernel
+
+        activation = get_kernel("kernels-community/activation", revision="add_more_act")
+        approximate = kwargs.get("approximate", "none")
+
+        super().__init__(*args, **kwargs)
+        if approximate == "none":
+            self.act_fn = activation.layers.Gelu()
+        elif approximate == "tanh":
+            self.act_fn = activation.layers.GeluTanh()
+
+    def forward(self, hidden_states):
+        hidden_states = self.proj(hidden_states)
+        hidden_states = self.act_fn(hidden_states)
+        return hidden_states
+
+
 class GEGLU(nn.Module):
     r"""
     A [variant](https://huggingface.co/papers/2002.05202) of the gated linear unit activation function.
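
The new `CUDAOptimizedGELU` swaps the stock GELU computation for the Hub-hosted `Gelu`/`GeluTanh` layers while inheriting the `proj` linear layer from diffusers' `GELU`. A minimal usage sketch, assuming the `kernels` package is installed, the Hub repo is reachable, and a CUDA device is available (the dimensions below are illustrative):

```python
import torch

from diffusers.models.activations import CUDAOptimizedGELU  # only present on this branch

# Same constructor shape as diffusers' GELU: a Linear projection followed by the activation,
# except the activation now comes from the kernels-community/activation Hub repo.
act = CUDAOptimizedGELU(dim_in=64, dim_out=128, approximate="tanh").to("cuda", torch.float16)

x = torch.randn(2, 16, 64, device="cuda", dtype=torch.float16)
y = act(x)  # proj -> activation.layers.GeluTanh
print(y.shape)  # torch.Size([2, 16, 128])
```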

src/diffusers/models/normalization.py
@@ -20,11 +20,20 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from ..utils import is_torch_npu_available, is_torch_version
+from ..utils import is_kernels_available, is_torch_npu_available, is_torch_version
+from ..utils.constants import DIFFUSERS_ENABLE_HUB_KERNELS
+from ..utils.kernels_utils import use_kernel_forward_from_hub
 from .activations import get_activation
 from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings
 
 
+if is_kernels_available() and DIFFUSERS_ENABLE_HUB_KERNELS:
+    from kernels import get_kernel
+
+    activation = get_kernel("kernels-community/activation", revision="add_more_act")
+    silu_kernel = activation.layers.Silu
+
+
 class AdaLayerNorm(nn.Module):
     r"""
     Norm layer modified to incorporate timestep embeddings.
@@ -57,7 +66,10 @@ class AdaLayerNorm(nn.Module):
         else:
             self.emb = None
 
-        self.silu = nn.SiLU()
+        if DIFFUSERS_ENABLE_HUB_KERNELS:
+            self.silu = silu_kernel()
+        else:
+            self.silu = nn.SiLU()
         self.linear = nn.Linear(embedding_dim, output_dim)
         self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
 
@@ -144,7 +156,10 @@ class AdaLayerNormZero(nn.Module):
         else:
             self.emb = None
 
-        self.silu = nn.SiLU()
+        if DIFFUSERS_ENABLE_HUB_KERNELS:
+            self.silu = silu_kernel()
+        else:
+            self.silu = nn.SiLU()
         self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
         if norm_type == "layer_norm":
             self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
@@ -183,7 +198,10 @@ class AdaLayerNormZeroSingle(nn.Module):
     def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
         super().__init__()
 
-        self.silu = nn.SiLU()
+        if DIFFUSERS_ENABLE_HUB_KERNELS:
+            self.silu = silu_kernel()
+        else:
+            self.silu = nn.SiLU()
         self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
         if norm_type == "layer_norm":
             self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
@@ -335,7 +353,10 @@ class AdaLayerNormContinuous(nn.Module):
         norm_type="layer_norm",
     ):
         super().__init__()
-        self.silu = nn.SiLU()
+        if DIFFUSERS_ENABLE_HUB_KERNELS:
+            self.silu = silu_kernel()
+        else:
+            self.silu = nn.SiLU()
         self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
         if norm_type == "layer_norm":
             self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
@@ -508,6 +529,7 @@ else:
             return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps)
 
 
+@use_kernel_forward_from_hub("RMSNorm")
 class RMSNorm(nn.Module):
     r"""
     RMS Norm as introduced in https://huggingface.co/papers/1910.07467 by Zhang et al.
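
The guarded import above binds `silu_kernel` once at module import time, so the `AdaLayerNorm*` constructors can switch between the Hub kernel and `nn.SiLU` without per-call overhead. A sketch of the opt-in, assuming `DIFFUSERS_ENABLE_HUB_KERNELS` (defined in `..utils.constants`, not shown in this diff) is derived from the environment variable of the same name and that `kernels` is installed:

```python
import os

# Assumption: the constant mirrors this environment variable; it must be set before
# diffusers.models.normalization is imported, because the guarded block runs at import time.
os.environ["DIFFUSERS_ENABLE_HUB_KERNELS"] = "1"

from diffusers.models.normalization import AdaLayerNormZero

norm = AdaLayerNormZero(embedding_dim=1536)
# With the flag on, `silu` is the Hub Silu layer from kernels-community/activation;
# with it off (or without `kernels`), it stays a plain torch.nn.SiLU as before.
print(type(norm.silu))
```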

src/diffusers/models/transformers/transformer_flux.py
@@ -22,7 +22,8 @@ import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, is_kernels_available, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.constants import DIFFUSERS_ENABLE_HUB_KERNELS
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
 from ..attention_dispatch import dispatch_attention_fn
@@ -40,6 +41,12 @@ from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNo
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+if is_kernels_available() and DIFFUSERS_ENABLE_HUB_KERNELS:
+    from kernels import get_kernel
+
+    activation = get_kernel("kernels-community/activation", revision="add_more_act")
+    gelu_tanh_kernel = activation.layers.GeluTanh
+
 
 def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
     query = attn.to_q(hidden_states)
@@ -300,8 +307,14 @@ class FluxAttention(torch.nn.Module, AttentionModuleMixin):
         self.added_kv_proj_dim = added_kv_proj_dim
         self.added_proj_bias = added_proj_bias
 
-        self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
-        self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+        if DIFFUSERS_ENABLE_HUB_KERNELS:
+            from ..normalization import RMSNorm
+
+            self.norm_q = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+            self.norm_k = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+        else:
+            self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+            self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
         self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
         self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
         self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
@@ -312,8 +325,14 @@ class FluxAttention(torch.nn.Module, AttentionModuleMixin):
             self.to_out.append(torch.nn.Dropout(dropout))
 
         if added_kv_proj_dim is not None:
-            self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
-            self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
+            if DIFFUSERS_ENABLE_HUB_KERNELS:
+                from ..normalization import RMSNorm
+
+                self.norm_added_q = RMSNorm(dim_head, eps=eps)
+                self.norm_added_k = RMSNorm(dim_head, eps=eps)
+            else:
+                self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
+                self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
             self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
             self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
             self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
@@ -351,6 +370,11 @@ class FluxSingleTransformerBlock(nn.Module):
         self.norm = AdaLayerNormZeroSingle(dim)
         self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
         self.act_mlp = nn.GELU(approximate="tanh")
+        # if not DIFFUSERS_ENABLE_HUB_KERNELS:
+        #     self.act_mlp = nn.GELU(approximate="tanh")
+        # else:
+        #     self.act_mlp = gelu_tanh_kernel()
+
         self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
 
         self.attn = FluxAttention(
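
Putting the pieces together, a hedged end-to-end sketch of how the opt-in is meant to be exercised on Flux. Assumptions: the flag is backed by the environment variable of the same name, `kernels` is installed, and you have access to a Flux checkpoint (the model id is only illustrative). Under the flag, `FluxAttention` builds its q/k norms from diffusers' `RMSNorm`, so the `RMSNorm` → `LigerRMSNorm` mapping registered in kernels_utils.py can be picked up by the `kernels` machinery:

```python
import os

os.environ["DIFFUSERS_ENABLE_HUB_KERNELS"] = "1"  # set before importing diffusers

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")

# AdaLayerNorm* layers now call the Hub Silu kernel, and the attention q/k norms are
# diffusers RMSNorm modules decorated with @use_kernel_forward_from_hub("RMSNorm").
image = pipe("a photo of a corgi astronaut", num_inference_steps=28).images[0]
image.save("flux_kernels.png")
```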

src/diffusers/pipelines/pipeline_utils.py
@@ -505,6 +505,13 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
             os.environ["PT_HPU_MAX_COMPOUND_OP_SIZE"] = "1"
             logger.debug("Environment variable set: PT_HPU_MAX_COMPOUND_OP_SIZE=1")
 
+            if dtype in (torch.bfloat16, None) and kwargs.pop("sdp_on_bf16", True):
+                if hasattr(torch._C, "_set_math_sdp_allow_fp16_bf16_reduction"):
+                    torch._C._set_math_sdp_allow_fp16_bf16_reduction(True)
+                    logger.warning(
+                        "Enabled SDP with BF16 precision on HPU. To disable, please use `.to('hpu', sdp_on_bf16=False)`"
+                    )
+
         module_names, _ = self._get_signature_keys(self)
         modules = [getattr(self, n, None) for n in module_names]
         modules = [m for m in modules if isinstance(m, torch.nn.Module)]
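
This hunk adds a new `sdp_on_bf16` keyword to `DiffusionPipeline.to()` for HPU: BF16 reductions in the math SDP backend are enabled by default and can be opted out of. A usage sketch, assuming an Intel Gaudi (HPU) environment with the Habana PyTorch bridge (the checkpoint is only illustrative):

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", torch_dtype=torch.bfloat16)

pipe = pipe.to("hpu")  # default: BF16 SDP reduction enabled, warning logged
# pipe = pipe.to("hpu", sdp_on_bf16=False)  # opt out and keep the previous behaviour
```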

src/diffusers/utils/kernels_utils.py
@@ -1,3 +1,5 @@
+from typing import Union
+
 from ..utils import get_logger
 from .import_utils import is_kernels_available
 
@@ -21,3 +23,42 @@ def _get_fa3_from_hub():
     except Exception as e:
         logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}")
         raise
+
+
+if is_kernels_available():
+    from kernels import (
+        Device,
+        LayerRepository,
+        register_kernel_mapping,
+        replace_kernel_forward_from_hub,
+        use_kernel_forward_from_hub,
+    )
+
+    _KERNEL_MAPPING: dict[str, dict[Union[Device, str], LayerRepository]] = {
+        "RMSNorm": {
+            "cuda": LayerRepository(repo_id="kernels-community/liger_kernels", layer_name="LigerRMSNorm"),
+        },
+    }
+
+    register_kernel_mapping(_KERNEL_MAPPING)
+
+else:
+    # Stub to make decorators in transformers work when `kernels`
+    # is not installed.
+    def use_kernel_forward_from_hub(*args, **kwargs):
+        def decorator(cls):
+            return cls
+
+        return decorator
+
+    class LayerRepository:
+        def __init__(self, *args, **kwargs):
+            raise RuntimeError("LayerRepository requires `kernels` to be installed. Run `pip install kernels`.")
+
+    def replace_kernel_forward_from_hub(*args, **kwargs):
+        raise RuntimeError(
+            "replace_kernel_forward_from_hub requires `kernels` to be installed. Run `pip install kernels`."
+        )
+
+    def register_kernel_mapping(*args, **kwargs):
+        raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`.")
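
When `kernels` is missing, the stubs above keep imports such as the decorated `RMSNorm` working: the decorator degrades to an identity wrapper, and only the mapping-related entry points raise. A self-contained sketch of that fallback (the decorator body is copied from the diff; `TinyNorm` is a hypothetical stand-in for a decorated diffusers module):

```python
# Identity-wrapper behaviour of the stubbed decorator, runnable without `kernels` installed.
def use_kernel_forward_from_hub(*args, **kwargs):
    def decorator(cls):
        return cls

    return decorator


@use_kernel_forward_from_hub("RMSNorm")
class TinyNorm:
    def forward(self, x):
        return x


assert TinyNorm().forward(41) == 41  # the class is returned unchanged, forward untouched
```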