mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-07 21:14:44 +08:00
Compare commits
1 Commits
custom-blo
...
lora-model
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
93eee19d50 |
@@ -35,13 +35,13 @@ from ...testing_utils import (
|
||||
torch_all_close,
|
||||
torch_device,
|
||||
)
|
||||
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin, UNetTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, PEFTTesterMixin, unittest.TestCase):
|
||||
model_class = AutoencoderKL
|
||||
main_input_name = "sample"
|
||||
base_precision = 1e-2
|
||||
|
||||
@@ -88,7 +88,12 @@ from ..testing_utils import (
|
||||
|
||||
|
||||
if is_peft_available():
|
||||
from peft import LoraConfig
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
from peft.utils import get_peft_model_state_dict
|
||||
|
||||
from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
|
||||
from diffusers.loaders.peft import PeftAdapterMixin
|
||||
|
||||
|
||||
def caculate_expected_num_shards(index_map_path):
|
||||
@@ -1113,177 +1118,6 @@ class ModelTesterMixin:
|
||||
" from `_deprecated_kwargs = [<deprecated_argument>]`"
|
||||
)
|
||||
|
||||
@parameterized.expand([(4, 4, True), (4, 8, False), (8, 4, False)])
|
||||
@torch.no_grad()
|
||||
@unittest.skipIf(not is_peft_available(), "Only with PEFT")
|
||||
def test_save_load_lora_adapter(self, rank, lora_alpha, use_dora=False):
|
||||
from peft import LoraConfig
|
||||
from peft.utils import get_peft_model_state_dict
|
||||
|
||||
from diffusers.loaders.peft import PeftAdapterMixin
|
||||
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
if not issubclass(model.__class__, PeftAdapterMixin):
|
||||
pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")
|
||||
|
||||
torch.manual_seed(0)
|
||||
output_no_lora = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
denoiser_lora_config = LoraConfig(
|
||||
r=rank,
|
||||
lora_alpha=lora_alpha,
|
||||
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
|
||||
init_lora_weights=False,
|
||||
use_dora=use_dora,
|
||||
)
|
||||
model.add_adapter(denoiser_lora_config)
|
||||
self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
|
||||
|
||||
torch.manual_seed(0)
|
||||
outputs_with_lora = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
self.assertFalse(torch.allclose(output_no_lora, outputs_with_lora, atol=1e-4, rtol=1e-4))
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
model.save_lora_adapter(tmpdir)
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
state_dict_loaded = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
|
||||
model.unload_lora()
|
||||
self.assertFalse(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
|
||||
|
||||
model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
|
||||
state_dict_retrieved = get_peft_model_state_dict(model, adapter_name="default_0")
|
||||
|
||||
for k in state_dict_loaded:
|
||||
loaded_v = state_dict_loaded[k]
|
||||
retrieved_v = state_dict_retrieved[k].to(loaded_v.device)
|
||||
self.assertTrue(torch.allclose(loaded_v, retrieved_v))
|
||||
|
||||
self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
|
||||
|
||||
torch.manual_seed(0)
|
||||
outputs_with_lora_2 = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
self.assertFalse(torch.allclose(output_no_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4))
|
||||
self.assertTrue(torch.allclose(outputs_with_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4))
|
||||
|
||||
@unittest.skipIf(not is_peft_available(), "Only with PEFT")
|
||||
def test_lora_wrong_adapter_name_raises_error(self):
|
||||
from peft import LoraConfig
|
||||
|
||||
from diffusers.loaders.peft import PeftAdapterMixin
|
||||
|
||||
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
if not issubclass(model.__class__, PeftAdapterMixin):
|
||||
pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")
|
||||
|
||||
denoiser_lora_config = LoraConfig(
|
||||
r=4,
|
||||
lora_alpha=4,
|
||||
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
|
||||
init_lora_weights=False,
|
||||
use_dora=False,
|
||||
)
|
||||
model.add_adapter(denoiser_lora_config)
|
||||
self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
wrong_name = "foo"
|
||||
with self.assertRaises(ValueError) as err_context:
|
||||
model.save_lora_adapter(tmpdir, adapter_name=wrong_name)
|
||||
|
||||
self.assertTrue(f"Adapter name {wrong_name} not found in the model." in str(err_context.exception))
|
||||
|
||||
@parameterized.expand([(4, 4, True), (4, 8, False), (8, 4, False)])
|
||||
@torch.no_grad()
|
||||
@unittest.skipIf(not is_peft_available(), "Only with PEFT")
|
||||
def test_lora_adapter_metadata_is_loaded_correctly(self, rank, lora_alpha, use_dora):
|
||||
from peft import LoraConfig
|
||||
|
||||
from diffusers.loaders.peft import PeftAdapterMixin
|
||||
|
||||
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
if not issubclass(model.__class__, PeftAdapterMixin):
|
||||
pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")
|
||||
|
||||
denoiser_lora_config = LoraConfig(
|
||||
r=rank,
|
||||
lora_alpha=lora_alpha,
|
||||
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
|
||||
init_lora_weights=False,
|
||||
use_dora=use_dora,
|
||||
)
|
||||
model.add_adapter(denoiser_lora_config)
|
||||
metadata = model.peft_config["default"].to_dict()
|
||||
self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
model.save_lora_adapter(tmpdir)
|
||||
model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
|
||||
self.assertTrue(os.path.isfile(model_file))
|
||||
|
||||
model.unload_lora()
|
||||
self.assertFalse(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
|
||||
|
||||
model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
|
||||
parsed_metadata = model.peft_config["default_0"].to_dict()
|
||||
check_if_dicts_are_equal(metadata, parsed_metadata)
|
||||
|
||||
@torch.no_grad()
|
||||
@unittest.skipIf(not is_peft_available(), "Only with PEFT")
|
||||
def test_lora_adapter_wrong_metadata_raises_error(self):
|
||||
from peft import LoraConfig
|
||||
|
||||
from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
|
||||
from diffusers.loaders.peft import PeftAdapterMixin
|
||||
|
||||
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
if not issubclass(model.__class__, PeftAdapterMixin):
|
||||
pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")
|
||||
|
||||
denoiser_lora_config = LoraConfig(
|
||||
r=4,
|
||||
lora_alpha=4,
|
||||
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
|
||||
init_lora_weights=False,
|
||||
use_dora=False,
|
||||
)
|
||||
model.add_adapter(denoiser_lora_config)
|
||||
self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
model.save_lora_adapter(tmpdir)
|
||||
model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
|
||||
self.assertTrue(os.path.isfile(model_file))
|
||||
|
||||
# Perturb the metadata in the state dict.
|
||||
loaded_state_dict = safetensors.torch.load_file(model_file)
|
||||
metadata = {"format": "pt"}
|
||||
lora_adapter_metadata = denoiser_lora_config.to_dict()
|
||||
lora_adapter_metadata.update({"foo": 1, "bar": 2})
|
||||
for key, value in lora_adapter_metadata.items():
|
||||
if isinstance(value, set):
|
||||
lora_adapter_metadata[key] = list(value)
|
||||
metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
|
||||
safetensors.torch.save_file(loaded_state_dict, model_file, metadata=metadata)
|
||||
|
||||
model.unload_lora()
|
||||
self.assertFalse(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
|
||||
|
||||
with self.assertRaises(TypeError) as err_context:
|
||||
model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
|
||||
self.assertTrue("`LoraConfig` class could not be instantiated" in str(err_context.exception))
|
||||
|
||||
@require_torch_accelerator
|
||||
def test_cpu_offload(self):
|
||||
if self.model_class._no_split_modules is None:
|
||||
@@ -1941,6 +1775,154 @@ class ModelTesterMixin:
|
||||
_ = loaded_model(**inputs_dict)
|
||||
|
||||
|
||||
class PEFTTesterMixin:
|
||||
@require_peft_backend
|
||||
@pytest.mark.parametrize("rank,lora_alpha,use_dora", [(4, 4, True), (4, 8, False), (8, 4, False)])
|
||||
@torch.no_grad()
|
||||
def test_save_load_lora_adapter(self, rank, lora_alpha, use_dora):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
if not issubclass(model.__class__, PeftAdapterMixin):
|
||||
pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")
|
||||
|
||||
torch.manual_seed(0)
|
||||
output_no_lora = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
denoiser_lora_config = LoraConfig(
|
||||
r=rank,
|
||||
lora_alpha=lora_alpha,
|
||||
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
|
||||
init_lora_weights=False,
|
||||
use_dora=use_dora,
|
||||
)
|
||||
model.add_adapter(denoiser_lora_config)
|
||||
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
|
||||
|
||||
torch.manual_seed(0)
|
||||
outputs_with_lora = model(**inputs_dict, return_dict=False)[0]
|
||||
assert not torch.allclose(output_no_lora, outputs_with_lora, atol=1e-4, rtol=1e-4)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
model.save_lora_adapter(tmpdir)
|
||||
model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
|
||||
assert os.path.isfile(model_file)
|
||||
|
||||
state_dict_loaded = safetensors.torch.load_file(model_file)
|
||||
|
||||
model.unload_lora()
|
||||
assert not check_if_lora_correctly_set(model), "LoRA layers not set correctly"
|
||||
|
||||
model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
|
||||
state_dict_retrieved = get_peft_model_state_dict(model, adapter_name="default_0")
|
||||
|
||||
for k, loaded_v in state_dict_loaded.items():
|
||||
retrieved_v = state_dict_retrieved[k].to(loaded_v.device)
|
||||
assert torch.allclose(loaded_v, retrieved_v)
|
||||
|
||||
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
|
||||
|
||||
torch.manual_seed(0)
|
||||
outputs_with_lora_2 = model(**inputs_dict, return_dict=False)[0]
|
||||
assert not torch.allclose(output_no_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4)
|
||||
assert torch.allclose(outputs_with_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4)
|
||||
|
||||
@require_peft_backend
|
||||
def test_lora_wrong_adapter_name_raises_error(self):
|
||||
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
if not issubclass(model.__class__, PeftAdapterMixin):
|
||||
pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")
|
||||
|
||||
denoiser_lora_config = LoraConfig(
|
||||
r=4,
|
||||
lora_alpha=4,
|
||||
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
|
||||
init_lora_weights=False,
|
||||
use_dora=False,
|
||||
)
|
||||
model.add_adapter(denoiser_lora_config)
|
||||
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
wrong_name = "foo"
|
||||
with pytest.raises(ValueError, match=rf"Adapter name {wrong_name} not found in the model\."):
|
||||
model.save_lora_adapter(tmpdir, adapter_name=wrong_name)
|
||||
|
||||
@require_peft_backend
|
||||
@pytest.mark.parametrize("rank,lora_alpha,use_dora", [(4, 4, True), (4, 8, False), (8, 4, False)])
|
||||
def test_lora_adapter_metadata_is_loaded_correctly(self, rank, lora_alpha, use_dora):
|
||||
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
if not issubclass(model.__class__, PeftAdapterMixin):
|
||||
pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")
|
||||
|
||||
denoiser_lora_config = LoraConfig(
|
||||
r=rank,
|
||||
lora_alpha=lora_alpha,
|
||||
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
|
||||
init_lora_weights=False,
|
||||
use_dora=use_dora,
|
||||
)
|
||||
model.add_adapter(denoiser_lora_config)
|
||||
metadata = model.peft_config["default"].to_dict()
|
||||
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
model.save_lora_adapter(tmpdir)
|
||||
model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
|
||||
assert os.path.isfile(model_file)
|
||||
|
||||
model.unload_lora()
|
||||
assert not check_if_lora_correctly_set(model), "LoRA layers not set correctly"
|
||||
|
||||
model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
|
||||
parsed_metadata = model.peft_config["default_0"].to_dict()
|
||||
check_if_dicts_are_equal(metadata, parsed_metadata)
|
||||
|
||||
@require_peft_backend
|
||||
def test_lora_adapter_wrong_metadata_raises_error(self):
|
||||
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
if not issubclass(model.__class__, PeftAdapterMixin):
|
||||
pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")
|
||||
|
||||
denoiser_lora_config = LoraConfig(
|
||||
r=4,
|
||||
lora_alpha=4,
|
||||
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
|
||||
init_lora_weights=False,
|
||||
use_dora=False,
|
||||
)
|
||||
model.add_adapter(denoiser_lora_config)
|
||||
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
model.save_lora_adapter(tmpdir)
|
||||
model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
|
||||
assert os.path.isfile(model_file)
|
||||
|
||||
# Perturb the metadata
|
||||
loaded_state_dict = safetensors.torch.load_file(model_file)
|
||||
metadata = {"format": "pt"}
|
||||
lora_adapter_metadata = denoiser_lora_config.to_dict()
|
||||
lora_adapter_metadata.update({"foo": 1, "bar": 2})
|
||||
for key, value in list(lora_adapter_metadata.items()):
|
||||
if isinstance(value, set):
|
||||
lora_adapter_metadata[key] = list(value)
|
||||
metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
|
||||
safetensors.torch.save_file(loaded_state_dict, model_file, metadata=metadata)
|
||||
|
||||
model.unload_lora()
|
||||
assert not check_if_lora_correctly_set(model), "LoRA layers not set correctly"
|
||||
|
||||
with pytest.raises(TypeError, match=r"`LoraConfig` class could not be instantiated"):
|
||||
model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
|
||||
|
||||
|
||||
@is_staging_test
|
||||
class ModelPushToHubTester(unittest.TestCase):
|
||||
identifier = uuid.uuid4()
|
||||
|
||||
@@ -30,13 +30,13 @@ from ...testing_utils import (
|
||||
torch_all_close,
|
||||
torch_device,
|
||||
)
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class PriorTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
class PriorTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
|
||||
model_class = PriorTransformer
|
||||
main_input_name = "hidden_states"
|
||||
|
||||
|
||||
@@ -20,13 +20,13 @@ import torch
|
||||
from diffusers import AuraFlowTransformer2DModel
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class AuraFlowTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
class AuraFlowTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
|
||||
model_class = AuraFlowTransformer2DModel
|
||||
main_input_name = "hidden_states"
|
||||
# We override the items here because the transformer under consideration is small.
|
||||
|
||||
@@ -22,7 +22,12 @@ from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor
|
||||
from diffusers.models.embeddings import ImageProjection
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
|
||||
from ..test_modeling_common import (
|
||||
LoraHotSwappingForModelTesterMixin,
|
||||
ModelTesterMixin,
|
||||
PEFTTesterMixin,
|
||||
TorchCompileTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
@@ -78,7 +83,7 @@ def create_bria_ip_adapter_state_dict(model):
|
||||
return ip_state_dict
|
||||
|
||||
|
||||
class BriaTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
class BriaTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
|
||||
model_class = BriaTransformer2DModel
|
||||
main_input_name = "hidden_states"
|
||||
# We override the items here because the transformer under consideration is small.
|
||||
|
||||
@@ -22,7 +22,12 @@ from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor
|
||||
from diffusers.models.embeddings import ImageProjection
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
|
||||
from ..test_modeling_common import (
|
||||
LoraHotSwappingForModelTesterMixin,
|
||||
ModelTesterMixin,
|
||||
PEFTTesterMixin,
|
||||
TorchCompileTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
@@ -78,7 +83,7 @@ def create_chroma_ip_adapter_state_dict(model):
|
||||
return ip_state_dict
|
||||
|
||||
|
||||
class ChromaTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
class ChromaTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
|
||||
model_class = ChromaTransformer2DModel
|
||||
main_input_name = "hidden_states"
|
||||
# We override the items here because the transformer under consideration is small.
|
||||
|
||||
@@ -19,17 +19,14 @@ import torch
|
||||
|
||||
from diffusers import CogVideoXTransformer3DModel
|
||||
|
||||
from ...testing_utils import (
|
||||
enable_full_determinism,
|
||||
torch_device,
|
||||
)
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class CogVideoXTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
class CogVideoXTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
|
||||
model_class = CogVideoXTransformer3DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
|
||||
@@ -19,13 +19,13 @@ import torch
|
||||
from diffusers import CogView4Transformer2DModel
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class CogView3PlusTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
class CogView3PlusTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
|
||||
model_class = CogView4Transformer2DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
|
||||
@@ -19,17 +19,14 @@ import torch
|
||||
|
||||
from diffusers import ConsisIDTransformer3DModel
|
||||
|
||||
from ...testing_utils import (
|
||||
enable_full_determinism,
|
||||
torch_device,
|
||||
)
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class ConsisIDTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
class ConsisIDTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
|
||||
model_class = ConsisIDTransformer3DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
|
||||
@@ -22,7 +22,12 @@ from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor
|
||||
from diffusers.models.embeddings import ImageProjection
|
||||
|
||||
from ...testing_utils import enable_full_determinism, is_peft_available, torch_device
|
||||
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
|
||||
from ..test_modeling_common import (
|
||||
LoraHotSwappingForModelTesterMixin,
|
||||
ModelTesterMixin,
|
||||
PEFTTesterMixin,
|
||||
TorchCompileTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
@@ -80,7 +85,7 @@ def create_flux_ip_adapter_state_dict(model):
|
||||
return ip_state_dict
|
||||
|
||||
|
||||
class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
class FluxTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
|
||||
model_class = FluxTransformer2DModel
|
||||
main_input_name = "hidden_states"
|
||||
# We override the items here because the transformer under consideration is small.
|
||||
|
||||
@@ -19,17 +19,14 @@ import torch
|
||||
|
||||
from diffusers import HiDreamImageTransformer2DModel
|
||||
|
||||
from ...testing_utils import (
|
||||
enable_full_determinism,
|
||||
torch_device,
|
||||
)
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class HiDreamTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
class HiDreamTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
|
||||
model_class = HiDreamImageTransformer2DModel
|
||||
main_input_name = "hidden_states"
|
||||
model_split_percents = [0.8, 0.8, 0.9]
|
||||
|
||||
@@ -18,17 +18,14 @@ import torch
|
||||
|
||||
from diffusers import HunyuanVideoTransformer3DModel
|
||||
|
||||
from ...testing_utils import (
|
||||
enable_full_determinism,
|
||||
torch_device,
|
||||
)
|
||||
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin, TorchCompileTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
|
||||
class HunyuanVideoTransformer3DTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
|
||||
model_class = HunyuanVideoTransformer3DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
|
||||
@@ -18,17 +18,14 @@ import torch
|
||||
|
||||
from diffusers import HunyuanVideoFramepackTransformer3DModel
|
||||
|
||||
from ...testing_utils import (
|
||||
enable_full_determinism,
|
||||
torch_device,
|
||||
)
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
|
||||
class HunyuanVideoTransformer3DTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
|
||||
model_class = HunyuanVideoFramepackTransformer3DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
|
||||
@@ -20,13 +20,13 @@ import torch
|
||||
from diffusers import LTXVideoTransformer3DModel
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
|
||||
from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin, TorchCompileTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class LTXTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
class LTXTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
|
||||
model_class = LTXVideoTransformer3DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
|
||||
@@ -19,17 +19,14 @@ import torch
|
||||
|
||||
from diffusers import Lumina2Transformer2DModel
|
||||
|
||||
from ...testing_utils import (
|
||||
enable_full_determinism,
|
||||
torch_device,
|
||||
)
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class Lumina2Transformer2DModelTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
class Lumina2Transformer2DModelTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
|
||||
model_class = Lumina2Transformer2DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
|
||||
@@ -20,13 +20,13 @@ import torch
|
||||
from diffusers import MochiTransformer3DModel
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class MochiTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
class MochiTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
|
||||
model_class = MochiTransformer3DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
|
||||
@@ -21,13 +21,13 @@ import torch
|
||||
from diffusers import QwenImageTransformer2DModel
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
|
||||
from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin, TorchCompileTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
class QwenImageTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
|
||||
model_class = QwenImageTransformer2DModel
|
||||
main_input_name = "hidden_states"
|
||||
# We override the items here because the transformer under consideration is small.
|
||||
|
||||
@@ -18,17 +18,14 @@ import torch
|
||||
|
||||
from diffusers import SanaTransformer2DModel
|
||||
|
||||
from ...testing_utils import (
|
||||
enable_full_determinism,
|
||||
torch_device,
|
||||
)
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class SanaTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
class SanaTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
|
||||
model_class = SanaTransformer2DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
|
||||
@@ -24,13 +24,13 @@ from ...testing_utils import (
|
||||
enable_full_determinism,
|
||||
torch_device,
|
||||
)
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class SD3TransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
class SD3TransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
|
||||
model_class = SD3Transformer2DModel
|
||||
main_input_name = "hidden_states"
|
||||
model_split_percents = [0.8, 0.8, 0.9]
|
||||
|
||||
@@ -18,17 +18,14 @@ import torch
|
||||
|
||||
from diffusers import SkyReelsV2Transformer3DModel
|
||||
|
||||
from ...testing_utils import (
|
||||
enable_full_determinism,
|
||||
torch_device,
|
||||
)
|
||||
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin, TorchCompileTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class SkyReelsV2Transformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
|
||||
class SkyReelsV2Transformer3DTests(ModelTesterMixin, PEFTTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
|
||||
model_class = SkyReelsV2Transformer3DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
|
||||
@@ -18,17 +18,14 @@ import torch
|
||||
|
||||
from diffusers import WanTransformer3DModel
|
||||
|
||||
from ...testing_utils import (
|
||||
enable_full_determinism,
|
||||
torch_device,
|
||||
)
|
||||
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin, TorchCompileTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class WanTransformer3DTests(ModelTesterMixin, unittest.TestCase):
|
||||
class WanTransformer3DTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
|
||||
model_class = WanTransformer3DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
|
||||
@@ -55,6 +55,7 @@ from ...testing_utils import (
|
||||
from ..test_modeling_common import (
|
||||
LoraHotSwappingForModelTesterMixin,
|
||||
ModelTesterMixin,
|
||||
PEFTTesterMixin,
|
||||
TorchCompileTesterMixin,
|
||||
UNetTesterMixin,
|
||||
)
|
||||
@@ -354,7 +355,7 @@ def create_custom_diffusion_layers(model, mock_weights: bool = True):
|
||||
return custom_diffusion_attn_procs
|
||||
|
||||
|
||||
class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, PEFTTesterMixin, unittest.TestCase):
|
||||
model_class = UNet2DConditionModel
|
||||
main_input_name = "sample"
|
||||
# We override the items here because the unet under consideration is small.
|
||||
@@ -1083,48 +1084,6 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_peft_backend
|
||||
def test_load_attn_procs_raise_warning(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
# forward pass without LoRA
|
||||
with torch.no_grad():
|
||||
non_lora_sample = model(**inputs_dict).sample
|
||||
|
||||
unet_lora_config = get_unet_lora_config()
|
||||
model.add_adapter(unet_lora_config)
|
||||
|
||||
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
|
||||
|
||||
# forward pass with LoRA
|
||||
with torch.no_grad():
|
||||
lora_sample_1 = model(**inputs_dict).sample
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_attn_procs(tmpdirname)
|
||||
model.unload_lora()
|
||||
|
||||
with self.assertWarns(FutureWarning) as warning:
|
||||
model.load_attn_procs(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
|
||||
warning_message = str(warning.warnings[0].message)
|
||||
assert "Using the `load_attn_procs()` method has been deprecated" in warning_message
|
||||
|
||||
# import to still check for the rest of the stuff.
|
||||
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
|
||||
|
||||
with torch.no_grad():
|
||||
lora_sample_2 = model(**inputs_dict).sample
|
||||
|
||||
assert not torch.allclose(non_lora_sample, lora_sample_1, atol=1e-4, rtol=1e-4), (
|
||||
"LoRA injected UNet should produce different results."
|
||||
)
|
||||
assert torch.allclose(lora_sample_1, lora_sample_2, atol=1e-4, rtol=1e-4), (
|
||||
"Loading from a saved checkpoint should produce identical results."
|
||||
)
|
||||
|
||||
@require_peft_backend
|
||||
def test_save_attn_procs_raise_warning(self):
|
||||
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
@@ -30,7 +30,7 @@ from ...testing_utils import (
|
||||
floats_tensor,
|
||||
torch_device,
|
||||
)
|
||||
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin, UNetTesterMixin
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -38,7 +38,7 @@ logger = logging.get_logger(__name__)
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class UNetMotionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
class UNetMotionModelTests(ModelTesterMixin, PEFTTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = UNetMotionModel
|
||||
main_input_name = "sample"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user