Compare commits

...

9 Commits

Author SHA1 Message Date
sayakpaul
45e5e03d03 up 2026-02-27 17:41:44 +05:30
Christopher
5910a1cc6c Fixing Kohya loras loading: Flux.1-dev loras with TE ("lora_te1_" prefix) (#13188)
* fixing text encoder lora loading

* following Cursor's review
2026-02-27 15:43:41 +05:30
Sayak Paul
6663decbce Merge branch 'main' into update-kernel-hub-repos 2026-02-25 07:28:50 +05:30
Sayak Paul
2b7ed4c8dc Merge branch 'main' into update-kernel-hub-repos 2026-02-20 16:20:11 +05:30
sayakpaul
67f4691cab resolve conflicts. 2026-02-19 18:22:49 +05:30
sayakpaul
e10fe61303 fix version and force updated kernels. 2026-02-19 18:00:01 +05:30
Sayak Paul
348350cf24 Merge branch 'main' into update-kernel-hub-repos 2026-02-19 17:53:46 +05:30
Sayak Paul
af35e3806c Merge branch 'main' into update-kernel-hub-repos 2026-02-19 09:35:15 +05:30
sayakpaul
d6bc647932 change to updated repo and version. 2026-02-18 23:46:06 +05:30
4 changed files with 34 additions and 8 deletions

View File

@@ -856,7 +856,7 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
)
state_dict = {k: v for k, v in state_dict.items() if not k.startswith("text_encoders.t5xxl.transformer.")}
has_diffb = any("diff_b" in k and k.startswith(("lora_unet_", "lora_te_")) for k in state_dict)
has_diffb = any("diff_b" in k and k.startswith(("lora_unet_", "lora_te_", "lora_te1_")) for k in state_dict)
if has_diffb:
zero_status_diff_b = state_dict_all_zero(state_dict, ".diff_b")
if zero_status_diff_b:
@@ -895,7 +895,7 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
state_dict = {
_custom_replace(k, limit_substrings): v
for k, v in state_dict.items()
if k.startswith(("lora_unet_", "lora_te_"))
if k.startswith(("lora_unet_", "lora_te_", "lora_te1_"))
}
if any("text_projection" in k for k in state_dict):

View File

@@ -38,6 +38,7 @@ from ..utils import (
is_flash_attn_available,
is_flash_attn_version,
is_kernels_available,
is_kernels_version,
is_sageattention_available,
is_sageattention_version,
is_torch_npu_available,
@@ -318,6 +319,7 @@ class _HubKernelConfig:
repo_id: str
function_attr: str
revision: str | None = None
version: int | None = None
kernel_fn: Callable | None = None
wrapped_forward_attr: str | None = None
wrapped_backward_attr: str | None = None
@@ -327,31 +329,34 @@ class _HubKernelConfig:
# Registry for hub-based attention kernels
_HUB_KERNELS_REGISTRY: dict["AttentionBackendName", _HubKernelConfig] = {
# TODO: temporary revision for now. Remove when merged upstream into `main`.
AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn3",
function_attr="flash_attn_func",
revision="fake-ops-return-probs",
wrapped_forward_attr="flash_attn_interface._flash_attn_forward",
wrapped_backward_attr="flash_attn_interface._flash_attn_backward",
version=1,
),
AttentionBackendName._FLASH_3_VARLEN_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn3",
function_attr="flash_attn_varlen_func",
# revision="fake-ops-return-probs",
version=1,
),
AttentionBackendName.FLASH_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn2",
function_attr="flash_attn_func",
revision=None,
wrapped_forward_attr="flash_attn_interface._wrapped_flash_attn_forward",
wrapped_backward_attr="flash_attn_interface._wrapped_flash_attn_backward",
version=1,
),
AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn2", function_attr="flash_attn_varlen_func", revision=None
repo_id="kernels-community/flash-attn2",
function_attr="flash_attn_varlen_func",
version=1,
),
AttentionBackendName.SAGE_HUB: _HubKernelConfig(
repo_id="kernels-community/sage_attention", function_attr="sageattn", revision=None
repo_id="kernels-community/sage-attention",
function_attr="sageattn",
version=1,
),
}
@@ -521,6 +526,10 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
raise RuntimeError(
f"Backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
)
if not is_kernels_version(">=", "0.12"):
raise RuntimeError(
f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12. Please update with `pip install -U kernels`."
)
elif backend == AttentionBackendName.AITER:
if not _CAN_USE_AITER_ATTN:

View File

@@ -86,6 +86,7 @@ from .import_utils import (
is_inflect_available,
is_invisible_watermark_available,
is_kernels_available,
is_kernels_version,
is_kornia_available,
is_librosa_available,
is_matplotlib_available,

View File

@@ -724,6 +724,22 @@ def is_transformers_version(operation: str, version: str):
return compare_versions(parse(_transformers_version), operation, version)
@cache
def is_kernels_version(operation: str, version: str):
"""
Compares the current Kernels version to a given reference with an operation.
Args:
operation (`str`):
A string representation of an operator, such as `">"` or `"<="`
version (`str`):
A version string
"""
if not _kernels_available:
return False
return compare_versions(parse(_kernels_version), operation, version)
@cache
def is_hf_hub_version(operation: str, version: str):
"""