Compare commits

...

3 Commits

Author SHA1 Message Date
DN6
b1df740aac update 2026-02-28 12:02:38 +05:30
DN6
8d20369792 update 2026-02-28 11:19:49 +05:30
Christopher
5910a1cc6c Fixing Kohya loras loading: Flux.1-dev loras with TE ("lora_te1_" prefix) (#13188)
* fixing text encoder lora loading

* following Cursor's review
2026-02-27 15:43:41 +05:30
3 changed files with 54 additions and 4 deletions

View File

@@ -856,7 +856,7 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
) )
state_dict = {k: v for k, v in state_dict.items() if not k.startswith("text_encoders.t5xxl.transformer.")} state_dict = {k: v for k, v in state_dict.items() if not k.startswith("text_encoders.t5xxl.transformer.")}
has_diffb = any("diff_b" in k and k.startswith(("lora_unet_", "lora_te_")) for k in state_dict) has_diffb = any("diff_b" in k and k.startswith(("lora_unet_", "lora_te_", "lora_te1_")) for k in state_dict)
if has_diffb: if has_diffb:
zero_status_diff_b = state_dict_all_zero(state_dict, ".diff_b") zero_status_diff_b = state_dict_all_zero(state_dict, ".diff_b")
if zero_status_diff_b: if zero_status_diff_b:
@@ -895,7 +895,7 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
state_dict = { state_dict = {
_custom_replace(k, limit_substrings): v _custom_replace(k, limit_substrings): v
for k, v in state_dict.items() for k, v in state_dict.items()
if k.startswith(("lora_unet_", "lora_te_")) if k.startswith(("lora_unet_", "lora_te_", "lora_te1_"))
} }
if any("text_projection" in k for k in state_dict): if any("text_projection" in k for k in state_dict):

View File

@@ -299,7 +299,10 @@ def get_cached_module_file(
# Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file. # Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file.
pretrained_model_name_or_path = str(pretrained_model_name_or_path) pretrained_model_name_or_path = str(pretrained_model_name_or_path)
module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file) if subfolder is not None:
module_file_or_url = os.path.join(pretrained_model_name_or_path, subfolder, module_file)
else:
module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file)
if os.path.isfile(module_file_or_url): if os.path.isfile(module_file_or_url):
resolved_module_file = module_file_or_url resolved_module_file = module_file_or_url
@@ -384,7 +387,11 @@ def get_cached_module_file(
if not os.path.exists(submodule_path / module_folder): if not os.path.exists(submodule_path / module_folder):
os.makedirs(submodule_path / module_folder) os.makedirs(submodule_path / module_folder)
module_needed = f"{module_needed}.py" module_needed = f"{module_needed}.py"
shutil.copyfile(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed) if subfolder is not None:
source_path = os.path.join(pretrained_model_name_or_path, subfolder, module_needed)
else:
source_path = os.path.join(pretrained_model_name_or_path, module_needed)
shutil.copyfile(source_path, submodule_path / module_needed)
else: else:
# Get the commit hash # Get the commit hash
# TODO: we will get this info in the etag soon, so retrieve it from there and not here. # TODO: we will get this info in the etag soon, so retrieve it from there and not here.

View File

@@ -1,6 +1,10 @@
import json
import os
import tempfile
import unittest import unittest
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import torch
from transformers import CLIPTextModel, LongformerModel from transformers import CLIPTextModel, LongformerModel
from diffusers.models import AutoModel, UNet2DConditionModel from diffusers.models import AutoModel, UNet2DConditionModel
@@ -35,6 +39,45 @@ class TestAutoModel(unittest.TestCase):
) )
assert isinstance(model, CLIPTextModel) assert isinstance(model, CLIPTextModel)
def test_load_dynamic_module_from_local_path_with_subfolder(self):
CUSTOM_MODEL_CODE = (
"import torch\n"
"from diffusers import ModelMixin, ConfigMixin\n"
"from diffusers.configuration_utils import register_to_config\n"
"\n"
"class CustomModel(ModelMixin, ConfigMixin):\n"
" @register_to_config\n"
" def __init__(self, hidden_size=8):\n"
" super().__init__()\n"
" self.linear = torch.nn.Linear(hidden_size, hidden_size)\n"
"\n"
" def forward(self, x):\n"
" return self.linear(x)\n"
)
with tempfile.TemporaryDirectory() as tmpdir:
subfolder = "custom_model"
model_dir = os.path.join(tmpdir, subfolder)
os.makedirs(model_dir)
with open(os.path.join(model_dir, "modeling.py"), "w") as f:
f.write(CUSTOM_MODEL_CODE)
config = {
"_class_name": "CustomModel",
"_diffusers_version": "0.0.0",
"auto_map": {"AutoModel": "modeling.CustomModel"},
"hidden_size": 8,
}
with open(os.path.join(model_dir, "config.json"), "w") as f:
json.dump(config, f)
torch.save({}, os.path.join(model_dir, "diffusion_pytorch_model.bin"))
model = AutoModel.from_pretrained(tmpdir, subfolder=subfolder, trust_remote_code=True)
assert model.__class__.__name__ == "CustomModel"
assert model.config["hidden_size"] == 8
class TestAutoModelFromConfig(unittest.TestCase): class TestAutoModelFromConfig(unittest.TestCase):
@patch( @patch(