mirror of
https://github.com/huggingface/diffusers.git
synced 2026-03-24 17:38:15 +08:00
Compare commits
2 Commits
group-offl
...
dynamic-mo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b1df740aac | ||
|
|
8d20369792 |
@@ -299,7 +299,10 @@ def get_cached_module_file(
|
|||||||
# Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file.
|
# Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file.
|
||||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||||
|
|
||||||
module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file)
|
if subfolder is not None:
|
||||||
|
module_file_or_url = os.path.join(pretrained_model_name_or_path, subfolder, module_file)
|
||||||
|
else:
|
||||||
|
module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file)
|
||||||
|
|
||||||
if os.path.isfile(module_file_or_url):
|
if os.path.isfile(module_file_or_url):
|
||||||
resolved_module_file = module_file_or_url
|
resolved_module_file = module_file_or_url
|
||||||
@@ -384,7 +387,11 @@ def get_cached_module_file(
|
|||||||
if not os.path.exists(submodule_path / module_folder):
|
if not os.path.exists(submodule_path / module_folder):
|
||||||
os.makedirs(submodule_path / module_folder)
|
os.makedirs(submodule_path / module_folder)
|
||||||
module_needed = f"{module_needed}.py"
|
module_needed = f"{module_needed}.py"
|
||||||
shutil.copyfile(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed)
|
if subfolder is not None:
|
||||||
|
source_path = os.path.join(pretrained_model_name_or_path, subfolder, module_needed)
|
||||||
|
else:
|
||||||
|
source_path = os.path.join(pretrained_model_name_or_path, module_needed)
|
||||||
|
shutil.copyfile(source_path, submodule_path / module_needed)
|
||||||
else:
|
else:
|
||||||
# Get the commit hash
|
# Get the commit hash
|
||||||
# TODO: we will get this info in the etag soon, so retrieve it from there and not here.
|
# TODO: we will get this info in the etag soon, so retrieve it from there and not here.
|
||||||
|
|||||||
@@ -1,6 +1,10 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import torch
|
||||||
from transformers import CLIPTextModel, LongformerModel
|
from transformers import CLIPTextModel, LongformerModel
|
||||||
|
|
||||||
from diffusers.models import AutoModel, UNet2DConditionModel
|
from diffusers.models import AutoModel, UNet2DConditionModel
|
||||||
@@ -35,6 +39,45 @@ class TestAutoModel(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
assert isinstance(model, CLIPTextModel)
|
assert isinstance(model, CLIPTextModel)
|
||||||
|
|
||||||
|
def test_load_dynamic_module_from_local_path_with_subfolder(self):
|
||||||
|
CUSTOM_MODEL_CODE = (
|
||||||
|
"import torch\n"
|
||||||
|
"from diffusers import ModelMixin, ConfigMixin\n"
|
||||||
|
"from diffusers.configuration_utils import register_to_config\n"
|
||||||
|
"\n"
|
||||||
|
"class CustomModel(ModelMixin, ConfigMixin):\n"
|
||||||
|
" @register_to_config\n"
|
||||||
|
" def __init__(self, hidden_size=8):\n"
|
||||||
|
" super().__init__()\n"
|
||||||
|
" self.linear = torch.nn.Linear(hidden_size, hidden_size)\n"
|
||||||
|
"\n"
|
||||||
|
" def forward(self, x):\n"
|
||||||
|
" return self.linear(x)\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
subfolder = "custom_model"
|
||||||
|
model_dir = os.path.join(tmpdir, subfolder)
|
||||||
|
os.makedirs(model_dir)
|
||||||
|
|
||||||
|
with open(os.path.join(model_dir, "modeling.py"), "w") as f:
|
||||||
|
f.write(CUSTOM_MODEL_CODE)
|
||||||
|
|
||||||
|
config = {
|
||||||
|
"_class_name": "CustomModel",
|
||||||
|
"_diffusers_version": "0.0.0",
|
||||||
|
"auto_map": {"AutoModel": "modeling.CustomModel"},
|
||||||
|
"hidden_size": 8,
|
||||||
|
}
|
||||||
|
with open(os.path.join(model_dir, "config.json"), "w") as f:
|
||||||
|
json.dump(config, f)
|
||||||
|
|
||||||
|
torch.save({}, os.path.join(model_dir, "diffusion_pytorch_model.bin"))
|
||||||
|
|
||||||
|
model = AutoModel.from_pretrained(tmpdir, subfolder=subfolder, trust_remote_code=True)
|
||||||
|
assert model.__class__.__name__ == "CustomModel"
|
||||||
|
assert model.config["hidden_size"] == 8
|
||||||
|
|
||||||
|
|
||||||
class TestAutoModelFromConfig(unittest.TestCase):
|
class TestAutoModelFromConfig(unittest.TestCase):
|
||||||
@patch(
|
@patch(
|
||||||
|
|||||||
Reference in New Issue
Block a user