Compare commits

...

1 Commit

Author     SHA1        Message  Date
sayakpaul  93eee19d50  up       2025-09-24 10:35:16 +05:30
23 changed files with 221 additions and 292 deletions

View File

@@ -35,13 +35,13 @@ from ...testing_utils import (
     torch_all_close,
     torch_device,
 )
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin, UNetTesterMixin

 enable_full_determinism()


-class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, PEFTTesterMixin, unittest.TestCase):
     model_class = AutoencoderKL
     main_input_name = "sample"
     base_precision = 1e-2

View File

@@ -88,7 +88,12 @@ from ..testing_utils import (
 if is_peft_available():
     from peft import LoraConfig
     from peft.tuners.tuners_utils import BaseTunerLayer
+    from peft.utils import get_peft_model_state_dict
+
+    from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
+    from diffusers.loaders.peft import PeftAdapterMixin


 def calculate_expected_num_shards(index_map_path):
@@ -1113,177 +1118,6 @@ class ModelTesterMixin:
             " from `_deprecated_kwargs = [<deprecated_argument>]`"
         )

-    @parameterized.expand([(4, 4, True), (4, 8, False), (8, 4, False)])
-    @torch.no_grad()
-    @unittest.skipIf(not is_peft_available(), "Only with PEFT")
-    def test_save_load_lora_adapter(self, rank, lora_alpha, use_dora=False):
-        from peft import LoraConfig
-        from peft.utils import get_peft_model_state_dict
-
-        from diffusers.loaders.peft import PeftAdapterMixin
-
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-        model = self.model_class(**init_dict).to(torch_device)
-
-        if not issubclass(model.__class__, PeftAdapterMixin):
-            pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")
-
-        torch.manual_seed(0)
-        output_no_lora = model(**inputs_dict, return_dict=False)[0]
-
-        denoiser_lora_config = LoraConfig(
-            r=rank,
-            lora_alpha=lora_alpha,
-            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
-            init_lora_weights=False,
-            use_dora=use_dora,
-        )
-        model.add_adapter(denoiser_lora_config)
-        self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
-
-        torch.manual_seed(0)
-        outputs_with_lora = model(**inputs_dict, return_dict=False)[0]
-
-        self.assertFalse(torch.allclose(output_no_lora, outputs_with_lora, atol=1e-4, rtol=1e-4))
-
-        with tempfile.TemporaryDirectory() as tmpdir:
-            model.save_lora_adapter(tmpdir)
-            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
-
-            state_dict_loaded = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
-
-            model.unload_lora()
-            self.assertFalse(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
-
-            model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
-            state_dict_retrieved = get_peft_model_state_dict(model, adapter_name="default_0")
-
-            for k in state_dict_loaded:
-                loaded_v = state_dict_loaded[k]
-                retrieved_v = state_dict_retrieved[k].to(loaded_v.device)
-                self.assertTrue(torch.allclose(loaded_v, retrieved_v))
-
-            self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
-
-        torch.manual_seed(0)
-        outputs_with_lora_2 = model(**inputs_dict, return_dict=False)[0]
-
-        self.assertFalse(torch.allclose(output_no_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4))
-        self.assertTrue(torch.allclose(outputs_with_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4))
-
-    @unittest.skipIf(not is_peft_available(), "Only with PEFT")
-    def test_lora_wrong_adapter_name_raises_error(self):
-        from peft import LoraConfig
-
-        from diffusers.loaders.peft import PeftAdapterMixin
-
-        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
-        model = self.model_class(**init_dict).to(torch_device)
-
-        if not issubclass(model.__class__, PeftAdapterMixin):
-            pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")
-
-        denoiser_lora_config = LoraConfig(
-            r=4,
-            lora_alpha=4,
-            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
-            init_lora_weights=False,
-            use_dora=False,
-        )
-        model.add_adapter(denoiser_lora_config)
-        self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
-
-        with tempfile.TemporaryDirectory() as tmpdir:
-            wrong_name = "foo"
-            with self.assertRaises(ValueError) as err_context:
-                model.save_lora_adapter(tmpdir, adapter_name=wrong_name)
-
-            self.assertTrue(f"Adapter name {wrong_name} not found in the model." in str(err_context.exception))
-
-    @parameterized.expand([(4, 4, True), (4, 8, False), (8, 4, False)])
-    @torch.no_grad()
-    @unittest.skipIf(not is_peft_available(), "Only with PEFT")
-    def test_lora_adapter_metadata_is_loaded_correctly(self, rank, lora_alpha, use_dora):
-        from peft import LoraConfig
-
-        from diffusers.loaders.peft import PeftAdapterMixin
-
-        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
-        model = self.model_class(**init_dict).to(torch_device)
-
-        if not issubclass(model.__class__, PeftAdapterMixin):
-            pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")
-
-        denoiser_lora_config = LoraConfig(
-            r=rank,
-            lora_alpha=lora_alpha,
-            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
-            init_lora_weights=False,
-            use_dora=use_dora,
-        )
-        model.add_adapter(denoiser_lora_config)
-        metadata = model.peft_config["default"].to_dict()
-        self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
-
-        with tempfile.TemporaryDirectory() as tmpdir:
-            model.save_lora_adapter(tmpdir)
-            model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
-            self.assertTrue(os.path.isfile(model_file))
-
-            model.unload_lora()
-            self.assertFalse(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
-
-            model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
-            parsed_metadata = model.peft_config["default_0"].to_dict()
-            check_if_dicts_are_equal(metadata, parsed_metadata)
-
-    @torch.no_grad()
-    @unittest.skipIf(not is_peft_available(), "Only with PEFT")
-    def test_lora_adapter_wrong_metadata_raises_error(self):
-        from peft import LoraConfig
-
-        from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
-        from diffusers.loaders.peft import PeftAdapterMixin
-
-        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
-        model = self.model_class(**init_dict).to(torch_device)
-
-        if not issubclass(model.__class__, PeftAdapterMixin):
-            pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")
-
-        denoiser_lora_config = LoraConfig(
-            r=4,
-            lora_alpha=4,
-            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
-            init_lora_weights=False,
-            use_dora=False,
-        )
-        model.add_adapter(denoiser_lora_config)
-        self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
-
-        with tempfile.TemporaryDirectory() as tmpdir:
-            model.save_lora_adapter(tmpdir)
-            model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
-            self.assertTrue(os.path.isfile(model_file))
-
-            # Perturb the metadata in the state dict.
-            loaded_state_dict = safetensors.torch.load_file(model_file)
-            metadata = {"format": "pt"}
-            lora_adapter_metadata = denoiser_lora_config.to_dict()
-            lora_adapter_metadata.update({"foo": 1, "bar": 2})
-            for key, value in lora_adapter_metadata.items():
-                if isinstance(value, set):
-                    lora_adapter_metadata[key] = list(value)
-            metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
-            safetensors.torch.save_file(loaded_state_dict, model_file, metadata=metadata)
-
-            model.unload_lora()
-            self.assertFalse(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
-
-            with self.assertRaises(TypeError) as err_context:
-                model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
-            self.assertTrue("`LoraConfig` class could not be instantiated" in str(err_context.exception))
-
     @require_torch_accelerator
     def test_cpu_offload(self):
         if self.model_class._no_split_modules is None:
@@ -1941,6 +1775,154 @@ class ModelTesterMixin:
         _ = loaded_model(**inputs_dict)


+class PEFTTesterMixin:
+    @require_peft_backend
+    @pytest.mark.parametrize("rank,lora_alpha,use_dora", [(4, 4, True), (4, 8, False), (8, 4, False)])
+    @torch.no_grad()
+    def test_save_load_lora_adapter(self, rank, lora_alpha, use_dora):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+
+        if not issubclass(model.__class__, PeftAdapterMixin):
+            pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")
+
+        torch.manual_seed(0)
+        output_no_lora = model(**inputs_dict, return_dict=False)[0]
+
+        denoiser_lora_config = LoraConfig(
+            r=rank,
+            lora_alpha=lora_alpha,
+            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
+            init_lora_weights=False,
+            use_dora=use_dora,
+        )
+        model.add_adapter(denoiser_lora_config)
+        assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
+
+        torch.manual_seed(0)
+        outputs_with_lora = model(**inputs_dict, return_dict=False)[0]
+
+        assert not torch.allclose(output_no_lora, outputs_with_lora, atol=1e-4, rtol=1e-4)
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            model.save_lora_adapter(tmpdir)
+            model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
+            assert os.path.isfile(model_file)
+
+            state_dict_loaded = safetensors.torch.load_file(model_file)
+
+            model.unload_lora()
+            assert not check_if_lora_correctly_set(model), "LoRA layers not set correctly"
+
+            model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
+            state_dict_retrieved = get_peft_model_state_dict(model, adapter_name="default_0")
+
+            for k, loaded_v in state_dict_loaded.items():
+                retrieved_v = state_dict_retrieved[k].to(loaded_v.device)
+                assert torch.allclose(loaded_v, retrieved_v)
+
+            assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
+
+        torch.manual_seed(0)
+        outputs_with_lora_2 = model(**inputs_dict, return_dict=False)[0]
+
+        assert not torch.allclose(output_no_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4)
+        assert torch.allclose(outputs_with_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4)
+
+    @require_peft_backend
+    def test_lora_wrong_adapter_name_raises_error(self):
+        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+
+        if not issubclass(model.__class__, PeftAdapterMixin):
+            pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")
+
+        denoiser_lora_config = LoraConfig(
+            r=4,
+            lora_alpha=4,
+            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
+            init_lora_weights=False,
+            use_dora=False,
+        )
+        model.add_adapter(denoiser_lora_config)
+        assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            wrong_name = "foo"
+            with pytest.raises(ValueError, match=rf"Adapter name {wrong_name} not found in the model\."):
+                model.save_lora_adapter(tmpdir, adapter_name=wrong_name)
+
+    @require_peft_backend
+    @pytest.mark.parametrize("rank,lora_alpha,use_dora", [(4, 4, True), (4, 8, False), (8, 4, False)])
+    def test_lora_adapter_metadata_is_loaded_correctly(self, rank, lora_alpha, use_dora):
+        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+
+        if not issubclass(model.__class__, PeftAdapterMixin):
+            pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")
+
+        denoiser_lora_config = LoraConfig(
+            r=rank,
+            lora_alpha=lora_alpha,
+            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
+            init_lora_weights=False,
+            use_dora=use_dora,
+        )
+        model.add_adapter(denoiser_lora_config)
+        metadata = model.peft_config["default"].to_dict()
+        assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            model.save_lora_adapter(tmpdir)
+            model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
+            assert os.path.isfile(model_file)
+
+            model.unload_lora()
+            assert not check_if_lora_correctly_set(model), "LoRA layers not set correctly"
+
+            model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
+            parsed_metadata = model.peft_config["default_0"].to_dict()
+            check_if_dicts_are_equal(metadata, parsed_metadata)
+
+    @require_peft_backend
+    def test_lora_adapter_wrong_metadata_raises_error(self):
+        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+
+        if not issubclass(model.__class__, PeftAdapterMixin):
+            pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")
+
+        denoiser_lora_config = LoraConfig(
+            r=4,
+            lora_alpha=4,
+            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
+            init_lora_weights=False,
+            use_dora=False,
+        )
+        model.add_adapter(denoiser_lora_config)
+        assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            model.save_lora_adapter(tmpdir)
+            model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
+            assert os.path.isfile(model_file)
+
+            # Perturb the metadata
+            loaded_state_dict = safetensors.torch.load_file(model_file)
+            metadata = {"format": "pt"}
+            lora_adapter_metadata = denoiser_lora_config.to_dict()
+            lora_adapter_metadata.update({"foo": 1, "bar": 2})
+            for key, value in list(lora_adapter_metadata.items()):
+                if isinstance(value, set):
+                    lora_adapter_metadata[key] = list(value)
+            metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
+            safetensors.torch.save_file(loaded_state_dict, model_file, metadata=metadata)
+
+            model.unload_lora()
+            assert not check_if_lora_correctly_set(model), "LoRA layers not set correctly"
+
+            with pytest.raises(TypeError, match=r"`LoraConfig` class could not be instantiated"):
+                model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
+
+
 @is_staging_test
 class ModelPushToHubTester(unittest.TestCase):
     identifier = uuid.uuid4()

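For orientation, the behavior this mixin locks in can be reproduced standalone. The sketch below assumes a PEFT backend is installed; the tiny `UNet2DConditionModel` configuration is illustrative only, while the adapter methods (`add_adapter`, `save_lora_adapter`, `unload_lora`, `load_lora_adapter`) and `LORA_ADAPTER_METADATA_KEY` are the same ones exercised in the diff above.

    import json
    import tempfile

    from peft import LoraConfig
    from safetensors import safe_open

    from diffusers import UNet2DConditionModel
    from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY

    # Deliberately tiny, illustrative config; real suites build model-specific ones.
    model = UNet2DConditionModel(
        sample_size=16,
        in_channels=4,
        out_channels=4,
        layers_per_block=1,
        block_out_channels=(8, 16),
        norm_num_groups=8,
        cross_attention_dim=16,
        down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
        up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    )

    lora_config = LoraConfig(
        r=4,
        lora_alpha=4,
        target_modules=["to_q", "to_k", "to_v", "to_out.0"],
        init_lora_weights=False,
    )
    model.add_adapter(lora_config)

    with tempfile.TemporaryDirectory() as tmpdir:
        # Writes pytorch_lora_weights.safetensors; the LoraConfig travels in the
        # safetensors metadata under LORA_ADAPTER_METADATA_KEY.
        model.save_lora_adapter(tmpdir)

        with safe_open(f"{tmpdir}/pytorch_lora_weights.safetensors", framework="pt") as f:
            adapter_metadata = json.loads(f.metadata()[LORA_ADAPTER_METADATA_KEY])
        print(adapter_metadata["r"], adapter_metadata["lora_alpha"])

        # Drop the adapter, then restore it from disk; the reloaded adapter is
        # registered under a new name ("default_0").
        model.unload_lora()
        model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
        print(model.peft_config["default_0"].to_dict()["r"])

This metadata round trip is what `test_lora_adapter_metadata_is_loaded_correctly` asserts, and a corrupted `LORA_ADAPTER_METADATA_KEY` payload is what makes `test_lora_adapter_wrong_metadata_raises_error` expect a `TypeError`.
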
View File

@@ -30,13 +30,13 @@ from ...testing_utils import (
     torch_all_close,
     torch_device,
 )
-from ..test_modeling_common import ModelTesterMixin
+from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin

 enable_full_determinism()


-class PriorTransformerTests(ModelTesterMixin, unittest.TestCase):
+class PriorTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
     model_class = PriorTransformer
     main_input_name = "hidden_states"

View File

@@ -20,13 +20,13 @@ import torch
 from diffusers import AuraFlowTransformer2DModel

 from ...testing_utils import enable_full_determinism, torch_device
-from ..test_modeling_common import ModelTesterMixin
+from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin

 enable_full_determinism()


-class AuraFlowTransformerTests(ModelTesterMixin, unittest.TestCase):
+class AuraFlowTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
     model_class = AuraFlowTransformer2DModel
     main_input_name = "hidden_states"
     # We override the items here because the transformer under consideration is small.

View File

@@ -22,7 +22,12 @@ from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor
 from diffusers.models.embeddings import ImageProjection

 from ...testing_utils import enable_full_determinism, torch_device
-from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
+from ..test_modeling_common import (
+    LoraHotSwappingForModelTesterMixin,
+    ModelTesterMixin,
+    PEFTTesterMixin,
+    TorchCompileTesterMixin,
+)

 enable_full_determinism()

@@ -78,7 +83,7 @@ def create_bria_ip_adapter_state_dict(model):
     return ip_state_dict


-class BriaTransformerTests(ModelTesterMixin, unittest.TestCase):
+class BriaTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
     model_class = BriaTransformer2DModel
     main_input_name = "hidden_states"
     # We override the items here because the transformer under consideration is small.

View File

@@ -22,7 +22,12 @@ from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor
 from diffusers.models.embeddings import ImageProjection

 from ...testing_utils import enable_full_determinism, torch_device
-from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
+from ..test_modeling_common import (
+    LoraHotSwappingForModelTesterMixin,
+    ModelTesterMixin,
+    PEFTTesterMixin,
+    TorchCompileTesterMixin,
+)

 enable_full_determinism()

@@ -78,7 +83,7 @@ def create_chroma_ip_adapter_state_dict(model):
     return ip_state_dict


-class ChromaTransformerTests(ModelTesterMixin, unittest.TestCase):
+class ChromaTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
     model_class = ChromaTransformer2DModel
     main_input_name = "hidden_states"
     # We override the items here because the transformer under consideration is small.

View File

@@ -19,17 +19,14 @@ import torch
 from diffusers import CogVideoXTransformer3DModel

-from ...testing_utils import (
-    enable_full_determinism,
-    torch_device,
-)
-from ..test_modeling_common import ModelTesterMixin
+from ...testing_utils import enable_full_determinism, torch_device
+from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin

 enable_full_determinism()


-class CogVideoXTransformerTests(ModelTesterMixin, unittest.TestCase):
+class CogVideoXTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
     model_class = CogVideoXTransformer3DModel
     main_input_name = "hidden_states"
     uses_custom_attn_processor = True

View File

@@ -19,13 +19,13 @@ import torch
 from diffusers import CogView4Transformer2DModel

 from ...testing_utils import enable_full_determinism, torch_device
-from ..test_modeling_common import ModelTesterMixin
+from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin

 enable_full_determinism()


-class CogView3PlusTransformerTests(ModelTesterMixin, unittest.TestCase):
+class CogView3PlusTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
     model_class = CogView4Transformer2DModel
     main_input_name = "hidden_states"
     uses_custom_attn_processor = True

View File

@@ -19,17 +19,14 @@ import torch
 from diffusers import ConsisIDTransformer3DModel

-from ...testing_utils import (
-    enable_full_determinism,
-    torch_device,
-)
-from ..test_modeling_common import ModelTesterMixin
+from ...testing_utils import enable_full_determinism, torch_device
+from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin

 enable_full_determinism()


-class ConsisIDTransformerTests(ModelTesterMixin, unittest.TestCase):
+class ConsisIDTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
     model_class = ConsisIDTransformer3DModel
     main_input_name = "hidden_states"
     uses_custom_attn_processor = True

View File

@@ -22,7 +22,12 @@ from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor
 from diffusers.models.embeddings import ImageProjection

 from ...testing_utils import enable_full_determinism, is_peft_available, torch_device
-from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
+from ..test_modeling_common import (
+    LoraHotSwappingForModelTesterMixin,
+    ModelTesterMixin,
+    PEFTTesterMixin,
+    TorchCompileTesterMixin,
+)

 enable_full_determinism()

@@ -80,7 +85,7 @@ def create_flux_ip_adapter_state_dict(model):
     return ip_state_dict


-class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
+class FluxTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
     model_class = FluxTransformer2DModel
     main_input_name = "hidden_states"
     # We override the items here because the transformer under consideration is small.

View File

@@ -19,17 +19,14 @@ import torch
 from diffusers import HiDreamImageTransformer2DModel

-from ...testing_utils import (
-    enable_full_determinism,
-    torch_device,
-)
-from ..test_modeling_common import ModelTesterMixin
+from ...testing_utils import enable_full_determinism, torch_device
+from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin

 enable_full_determinism()


-class HiDreamTransformerTests(ModelTesterMixin, unittest.TestCase):
+class HiDreamTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
     model_class = HiDreamImageTransformer2DModel
     main_input_name = "hidden_states"
     model_split_percents = [0.8, 0.8, 0.9]

View File

@@ -18,17 +18,14 @@ import torch
 from diffusers import HunyuanVideoTransformer3DModel

-from ...testing_utils import (
-    enable_full_determinism,
-    torch_device,
-)
-from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
+from ...testing_utils import enable_full_determinism, torch_device
+from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin, TorchCompileTesterMixin

 enable_full_determinism()


-class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
+class HunyuanVideoTransformer3DTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
     model_class = HunyuanVideoTransformer3DModel
     main_input_name = "hidden_states"
     uses_custom_attn_processor = True

View File

@@ -18,17 +18,14 @@ import torch
 from diffusers import HunyuanVideoFramepackTransformer3DModel

-from ...testing_utils import (
-    enable_full_determinism,
-    torch_device,
-)
-from ..test_modeling_common import ModelTesterMixin
+from ...testing_utils import enable_full_determinism, torch_device
+from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin

 enable_full_determinism()


-class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
+class HunyuanVideoTransformer3DTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
     model_class = HunyuanVideoFramepackTransformer3DModel
     main_input_name = "hidden_states"
     uses_custom_attn_processor = True

View File

@@ -20,13 +20,13 @@ import torch
 from diffusers import LTXVideoTransformer3DModel

 from ...testing_utils import enable_full_determinism, torch_device
-from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
+from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin, TorchCompileTesterMixin

 enable_full_determinism()


-class LTXTransformerTests(ModelTesterMixin, unittest.TestCase):
+class LTXTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
     model_class = LTXVideoTransformer3DModel
     main_input_name = "hidden_states"
     uses_custom_attn_processor = True

View File

@@ -19,17 +19,14 @@ import torch
 from diffusers import Lumina2Transformer2DModel

-from ...testing_utils import (
-    enable_full_determinism,
-    torch_device,
-)
-from ..test_modeling_common import ModelTesterMixin
+from ...testing_utils import enable_full_determinism, torch_device
+from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin

 enable_full_determinism()


-class Lumina2Transformer2DModelTransformerTests(ModelTesterMixin, unittest.TestCase):
+class Lumina2Transformer2DModelTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
     model_class = Lumina2Transformer2DModel
     main_input_name = "hidden_states"
     uses_custom_attn_processor = True

View File

@@ -20,13 +20,13 @@ import torch
 from diffusers import MochiTransformer3DModel

 from ...testing_utils import enable_full_determinism, torch_device
-from ..test_modeling_common import ModelTesterMixin
+from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin

 enable_full_determinism()


-class MochiTransformerTests(ModelTesterMixin, unittest.TestCase):
+class MochiTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
     model_class = MochiTransformer3DModel
     main_input_name = "hidden_states"
     uses_custom_attn_processor = True

View File

@@ -21,13 +21,13 @@ import torch
 from diffusers import QwenImageTransformer2DModel

 from ...testing_utils import enable_full_determinism, torch_device
-from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
+from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin, TorchCompileTesterMixin

 enable_full_determinism()


-class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
+class QwenImageTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
     model_class = QwenImageTransformer2DModel
     main_input_name = "hidden_states"
     # We override the items here because the transformer under consideration is small.

View File

@@ -18,17 +18,14 @@ import torch
 from diffusers import SanaTransformer2DModel

-from ...testing_utils import (
-    enable_full_determinism,
-    torch_device,
-)
-from ..test_modeling_common import ModelTesterMixin
+from ...testing_utils import enable_full_determinism, torch_device
+from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin

 enable_full_determinism()


-class SanaTransformerTests(ModelTesterMixin, unittest.TestCase):
+class SanaTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
     model_class = SanaTransformer2DModel
     main_input_name = "hidden_states"
     uses_custom_attn_processor = True

View File

@@ -24,13 +24,13 @@ from ...testing_utils import (
     enable_full_determinism,
     torch_device,
 )
-from ..test_modeling_common import ModelTesterMixin
+from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin

 enable_full_determinism()


-class SD3TransformerTests(ModelTesterMixin, unittest.TestCase):
+class SD3TransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
     model_class = SD3Transformer2DModel
     main_input_name = "hidden_states"
     model_split_percents = [0.8, 0.8, 0.9]

View File

@@ -18,17 +18,14 @@ import torch
 from diffusers import SkyReelsV2Transformer3DModel

-from ...testing_utils import (
-    enable_full_determinism,
-    torch_device,
-)
-from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
+from ...testing_utils import enable_full_determinism, torch_device
+from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin, TorchCompileTesterMixin

 enable_full_determinism()


-class SkyReelsV2Transformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
+class SkyReelsV2Transformer3DTests(ModelTesterMixin, PEFTTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
     model_class = SkyReelsV2Transformer3DModel
     main_input_name = "hidden_states"
     uses_custom_attn_processor = True

View File

@@ -18,17 +18,14 @@ import torch
 from diffusers import WanTransformer3DModel

-from ...testing_utils import (
-    enable_full_determinism,
-    torch_device,
-)
-from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
+from ...testing_utils import enable_full_determinism, torch_device
+from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin, TorchCompileTesterMixin

 enable_full_determinism()


-class WanTransformer3DTests(ModelTesterMixin, unittest.TestCase):
+class WanTransformer3DTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
     model_class = WanTransformer3DModel
     main_input_name = "hidden_states"
     uses_custom_attn_processor = True

View File

@@ -55,6 +55,7 @@ from ...testing_utils import (
 from ..test_modeling_common import (
     LoraHotSwappingForModelTesterMixin,
     ModelTesterMixin,
+    PEFTTesterMixin,
     TorchCompileTesterMixin,
     UNetTesterMixin,
 )

@@ -354,7 +355,7 @@ def create_custom_diffusion_layers(model, mock_weights: bool = True):
     return custom_diffusion_attn_procs


-class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, PEFTTesterMixin, unittest.TestCase):
     model_class = UNet2DConditionModel
     main_input_name = "sample"
     # We override the items here because the unet under consideration is small.
@@ -1083,48 +1084,6 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
         assert loaded_model
         assert new_output.sample.shape == (4, 4, 16, 16)

-    @require_peft_backend
-    def test_load_attn_procs_raise_warning(self):
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-        model = self.model_class(**init_dict)
-        model.to(torch_device)
-
-        # forward pass without LoRA
-        with torch.no_grad():
-            non_lora_sample = model(**inputs_dict).sample
-
-        unet_lora_config = get_unet_lora_config()
-        model.add_adapter(unet_lora_config)
-
-        assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
-
-        # forward pass with LoRA
-        with torch.no_grad():
-            lora_sample_1 = model(**inputs_dict).sample
-
-        with tempfile.TemporaryDirectory() as tmpdirname:
-            model.save_attn_procs(tmpdirname)
-            model.unload_lora()
-
-            with self.assertWarns(FutureWarning) as warning:
-                model.load_attn_procs(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
-
-            warning_message = str(warning.warnings[0].message)
-            assert "Using the `load_attn_procs()` method has been deprecated" in warning_message
-
-            # important to still check for the rest of the stuff.
-            assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
-
-            with torch.no_grad():
-                lora_sample_2 = model(**inputs_dict).sample
-
-            assert not torch.allclose(non_lora_sample, lora_sample_1, atol=1e-4, rtol=1e-4), (
-                "LoRA injected UNet should produce different results."
-            )
-            assert torch.allclose(lora_sample_1, lora_sample_2, atol=1e-4, rtol=1e-4), (
-                "Loading from a saved checkpoint should produce identical results."
-            )

     @require_peft_backend
     def test_save_attn_procs_raise_warning(self):
         init_dict, _ = self.prepare_init_args_and_inputs_for_common()

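The hunk above drops the UNet-specific test for the deprecated `load_attn_procs()`/`save_attn_procs()` pair; LoRA coverage for this model now comes from `PEFTTesterMixin`. As a hedged migration sketch (reusing the tiny `model` from the earlier example, with `tmpdir` a writable directory):

    # Deprecated entry points; each emits a FutureWarning per the removed test:
    #     model.save_attn_procs(tmpdir)
    #     model.load_attn_procs(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
    # Supported replacements, as exercised by PEFTTesterMixin:
    model.save_lora_adapter(tmpdir)
    model.unload_lora()
    model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
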
View File

@@ -30,7 +30,7 @@ from ...testing_utils import (
     floats_tensor,
     torch_device,
 )
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin, UNetTesterMixin

 logger = logging.get_logger(__name__)

@@ -38,7 +38,7 @@ logger = logging.get_logger(__name__)
 enable_full_determinism()


-class UNetMotionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class UNetMotionModelTests(ModelTesterMixin, PEFTTesterMixin, UNetTesterMixin, unittest.TestCase):
     model_class = UNetMotionModel
     main_input_name = "sample"