mirror of
https://github.com/huggingface/diffusers.git
synced 2026-03-02 23:00:39 +08:00
Compare commits
3 Commits
update-mod
...
dynamic-mo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b1df740aac | ||
|
|
8d20369792 | ||
|
|
5910a1cc6c |
@@ -856,7 +856,7 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
|
|||||||
)
|
)
|
||||||
state_dict = {k: v for k, v in state_dict.items() if not k.startswith("text_encoders.t5xxl.transformer.")}
|
state_dict = {k: v for k, v in state_dict.items() if not k.startswith("text_encoders.t5xxl.transformer.")}
|
||||||
|
|
||||||
has_diffb = any("diff_b" in k and k.startswith(("lora_unet_", "lora_te_")) for k in state_dict)
|
has_diffb = any("diff_b" in k and k.startswith(("lora_unet_", "lora_te_", "lora_te1_")) for k in state_dict)
|
||||||
if has_diffb:
|
if has_diffb:
|
||||||
zero_status_diff_b = state_dict_all_zero(state_dict, ".diff_b")
|
zero_status_diff_b = state_dict_all_zero(state_dict, ".diff_b")
|
||||||
if zero_status_diff_b:
|
if zero_status_diff_b:
|
||||||
@@ -895,7 +895,7 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
|
|||||||
state_dict = {
|
state_dict = {
|
||||||
_custom_replace(k, limit_substrings): v
|
_custom_replace(k, limit_substrings): v
|
||||||
for k, v in state_dict.items()
|
for k, v in state_dict.items()
|
||||||
if k.startswith(("lora_unet_", "lora_te_"))
|
if k.startswith(("lora_unet_", "lora_te_", "lora_te1_"))
|
||||||
}
|
}
|
||||||
|
|
||||||
if any("text_projection" in k for k in state_dict):
|
if any("text_projection" in k for k in state_dict):
|
||||||
|
|||||||
@@ -299,7 +299,10 @@ def get_cached_module_file(
|
|||||||
# Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file.
|
# Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file.
|
||||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||||
|
|
||||||
module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file)
|
if subfolder is not None:
|
||||||
|
module_file_or_url = os.path.join(pretrained_model_name_or_path, subfolder, module_file)
|
||||||
|
else:
|
||||||
|
module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file)
|
||||||
|
|
||||||
if os.path.isfile(module_file_or_url):
|
if os.path.isfile(module_file_or_url):
|
||||||
resolved_module_file = module_file_or_url
|
resolved_module_file = module_file_or_url
|
||||||
@@ -384,7 +387,11 @@ def get_cached_module_file(
|
|||||||
if not os.path.exists(submodule_path / module_folder):
|
if not os.path.exists(submodule_path / module_folder):
|
||||||
os.makedirs(submodule_path / module_folder)
|
os.makedirs(submodule_path / module_folder)
|
||||||
module_needed = f"{module_needed}.py"
|
module_needed = f"{module_needed}.py"
|
||||||
shutil.copyfile(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed)
|
if subfolder is not None:
|
||||||
|
source_path = os.path.join(pretrained_model_name_or_path, subfolder, module_needed)
|
||||||
|
else:
|
||||||
|
source_path = os.path.join(pretrained_model_name_or_path, module_needed)
|
||||||
|
shutil.copyfile(source_path, submodule_path / module_needed)
|
||||||
else:
|
else:
|
||||||
# Get the commit hash
|
# Get the commit hash
|
||||||
# TODO: we will get this info in the etag soon, so retrieve it from there and not here.
|
# TODO: we will get this info in the etag soon, so retrieve it from there and not here.
|
||||||
|
|||||||
@@ -1,6 +1,10 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import torch
|
||||||
from transformers import CLIPTextModel, LongformerModel
|
from transformers import CLIPTextModel, LongformerModel
|
||||||
|
|
||||||
from diffusers.models import AutoModel, UNet2DConditionModel
|
from diffusers.models import AutoModel, UNet2DConditionModel
|
||||||
@@ -35,6 +39,45 @@ class TestAutoModel(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
assert isinstance(model, CLIPTextModel)
|
assert isinstance(model, CLIPTextModel)
|
||||||
|
|
||||||
|
def test_load_dynamic_module_from_local_path_with_subfolder(self):
|
||||||
|
CUSTOM_MODEL_CODE = (
|
||||||
|
"import torch\n"
|
||||||
|
"from diffusers import ModelMixin, ConfigMixin\n"
|
||||||
|
"from diffusers.configuration_utils import register_to_config\n"
|
||||||
|
"\n"
|
||||||
|
"class CustomModel(ModelMixin, ConfigMixin):\n"
|
||||||
|
" @register_to_config\n"
|
||||||
|
" def __init__(self, hidden_size=8):\n"
|
||||||
|
" super().__init__()\n"
|
||||||
|
" self.linear = torch.nn.Linear(hidden_size, hidden_size)\n"
|
||||||
|
"\n"
|
||||||
|
" def forward(self, x):\n"
|
||||||
|
" return self.linear(x)\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
subfolder = "custom_model"
|
||||||
|
model_dir = os.path.join(tmpdir, subfolder)
|
||||||
|
os.makedirs(model_dir)
|
||||||
|
|
||||||
|
with open(os.path.join(model_dir, "modeling.py"), "w") as f:
|
||||||
|
f.write(CUSTOM_MODEL_CODE)
|
||||||
|
|
||||||
|
config = {
|
||||||
|
"_class_name": "CustomModel",
|
||||||
|
"_diffusers_version": "0.0.0",
|
||||||
|
"auto_map": {"AutoModel": "modeling.CustomModel"},
|
||||||
|
"hidden_size": 8,
|
||||||
|
}
|
||||||
|
with open(os.path.join(model_dir, "config.json"), "w") as f:
|
||||||
|
json.dump(config, f)
|
||||||
|
|
||||||
|
torch.save({}, os.path.join(model_dir, "diffusion_pytorch_model.bin"))
|
||||||
|
|
||||||
|
model = AutoModel.from_pretrained(tmpdir, subfolder=subfolder, trust_remote_code=True)
|
||||||
|
assert model.__class__.__name__ == "CustomModel"
|
||||||
|
assert model.config["hidden_size"] == 8
|
||||||
|
|
||||||
|
|
||||||
class TestAutoModelFromConfig(unittest.TestCase):
|
class TestAutoModelFromConfig(unittest.TestCase):
|
||||||
@patch(
|
@patch(
|
||||||
|
|||||||
Reference in New Issue
Block a user