Compare commits

...

31 Commits

Author | SHA1 | Message | Date
Sayak Paul | db250958c5 | Merge branch 'main' into migrate-lora-pytest | 2025-12-10 12:36:41 +08:00
sayakpaul | f956ba0db1 | resolve conflicts. | 2025-12-04 20:07:15 +08:00
sayakpaul | f3593a8aa9 | up | 2025-12-03 17:42:36 +08:00
Sayak Paul | 1b6cdea043 | Merge branch 'main' into migrate-lora-pytest | 2025-12-03 17:39:35 +08:00
Sayak Paul | 3fb66f23ac | Merge branch 'main' into migrate-lora-pytest | 2025-11-20 10:13:01 +05:30
sayakpaul | 9c3bed1783 | up | 2025-11-20 10:12:31 +05:30
Sayak Paul | 11b80d09b0 | Merge branch 'main' into migrate-lora-pytest | 2025-11-10 13:27:10 +05:30
Sayak Paul | 9201505554 | Merge branch 'main' into migrate-lora-pytest | 2025-11-06 10:39:44 +05:30
sayakpaul | eece7120dd | up | 2025-11-06 10:31:37 +05:30
Sayak Paul | 2e42205c3a | Merge branch 'main' into migrate-lora-pytest | 2025-11-06 10:24:51 +05:30
Sayak Paul | 757bbf7b05 | Merge branch 'main' into migrate-lora-pytest | 2025-10-24 22:24:15 +05:30
Sayak Paul | 4561c065aa | Merge branch 'main' into migrate-lora-pytest | 2025-10-17 19:29:40 +05:30
Sayak Paul | 4ae5772fef | Merge branch 'main' into migrate-lora-pytest | 2025-10-17 07:55:31 +05:30
sayakpaul | 0d3da485a0 | up | 2025-10-03 21:00:05 +05:30
sayakpaul | 4f5e9a665e | up | 2025-10-03 20:49:50 +05:30
Sayak Paul | 23e5559c54 | Merge branch 'main' into migrate-lora-pytest | 2025-10-03 20:44:52 +05:30
sayakpaul | f8f27891c6 | up | 2025-10-03 20:14:45 +05:30
sayakpaul | 128535cfcd | up | 2025-10-03 20:03:50 +05:30
sayakpaul | bdc9537999 | more fixtures. | 2025-10-03 20:01:26 +05:30
sayakpaul | dae161ed26 | up | 2025-10-03 17:39:55 +05:30
sayakpaul | c4bcf72084 | up | 2025-10-03 16:56:31 +05:30
sayakpaul | 1737b710a2 | up | 2025-10-03 16:45:04 +05:30
sayakpaul | 565d674cc4 | change flux lora integration tests to use pytest | 2025-10-03 16:30:58 +05:30
sayakpaul | 610842af1a | up | 2025-10-03 16:14:36 +05:30
sayakpaul | cba82591e8 | up | 2025-10-03 15:56:37 +05:30
sayakpaul | 949cc1c326 | up | 2025-10-03 14:54:23 +05:30
sayakpaul | ec866f5de8 | tempfile is now a fixture. | 2025-10-03 14:25:54 +05:30
sayakpaul | 7b4bcce602 | up | 2025-10-03 14:10:31 +05:30
sayakpaul | d61bb38fb4 | up | 2025-10-03 13:14:05 +05:30
sayakpaul | 9e92f6bb63 | up | 2025-10-03 12:53:37 +05:30
sayakpaul | 6c6cade1a7 | migrate lora pipeline tests to pytest | 2025-10-03 12:52:56 +05:30
17 changed files with 1132 additions and 1763 deletions
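Every file in the diff follows the same mechanical conversion: the unittest.TestCase LoRA suites become plain pytest classes, @unittest.skip becomes @pytest.mark.skip, @parameterized.expand becomes @pytest.mark.parametrize, self.assert* calls become bare assert statements, and the per-test pipeline/tempdir setup moves into fixtures. A standalone toy sketch of that before/after pattern (illustrative only; the class and test names below are hypothetical and not taken from the diff):

# Illustrative sketch of the unittest -> pytest pattern applied throughout this PR.
# Standalone toy example: the class and test names are hypothetical, not from the diff.
import pytest


class TestExampleLoRA:  # plain class instead of a unittest.TestCase subclass
    @pytest.mark.skip("Not supported in this toy example.")  # was: @unittest.skip(...)
    def test_unsupported_feature(self):
        pass

    def test_simple_check(self):
        # was: self.assertTrue(2 + 2 == 4, "message")
        assert 2 + 2 == 4, "plain asserts replace unittest's assert* helpers"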

View File

@@ -13,16 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import sys
-import unittest
+import pytest
 import torch
 from transformers import AutoTokenizer, UMT5EncoderModel
-from diffusers import (
-    AuraFlowPipeline,
-    AuraFlowTransformer2DModel,
-    FlowMatchEulerDiscreteScheduler,
-)
+from diffusers import AuraFlowPipeline, AuraFlowTransformer2DModel, FlowMatchEulerDiscreteScheduler
 from ..testing_utils import (
     floats_tensor,
@@ -40,7 +36,7 @@ from .utils import PeftLoraLoaderMixinTests  # noqa: E402
 @require_peft_backend
-class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
+class TestAuraFlowLoRA(PeftLoraLoaderMixinTests):
     pipeline_class = AuraFlowPipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler
     scheduler_kwargs = {}
@@ -103,34 +99,34 @@ class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
         return noise, input_ids, pipeline_inputs
-    @unittest.skip("Not supported in AuraFlow.")
+    @pytest.mark.skip("Not supported in AuraFlow.")
     def test_simple_inference_with_text_denoiser_block_scale(self):
         pass
-    @unittest.skip("Not supported in AuraFlow.")
+    @pytest.mark.skip("Not supported in AuraFlow.")
     def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
         pass
-    @unittest.skip("Not supported in AuraFlow.")
+    @pytest.mark.skip("Not supported in AuraFlow.")
     def test_modify_padding_mode(self):
         pass
-    @unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in AuraFlow.")
     def test_simple_inference_with_partial_text_lora(self):
         pass
-    @unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in AuraFlow.")
     def test_simple_inference_with_text_lora(self):
         pass
-    @unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in AuraFlow.")
     def test_simple_inference_with_text_lora_and_scale(self):
         pass
-    @unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in AuraFlow.")
     def test_simple_inference_with_text_lora_fused(self):
         pass
-    @unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in AuraFlow.")
     def test_simple_inference_with_text_lora_save_load(self):
         pass

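A note on the renames such as AuraFlowLoRATests -> TestAuraFlowLoRA: once the class no longer inherits from unittest.TestCase, pytest only collects it when its name matches the default python_classes = "Test*" convention (and the class must not define __init__). That is the likely motivation for the renaming; a hypothetical standalone illustration:

# Hypothetical standalone illustration of pytest's default class-collection rule.
class AuraFlowLoRATests:  # without unittest.TestCase, this name is NOT collected by default
    def test_something(self):
        assert True


class TestAuraFlowLoRA:  # collected: name starts with "Test", pytest's default prefix
    def test_something(self):
        assert True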
View File

@@ -13,10 +13,9 @@
 # limitations under the License.
 import sys
-import unittest
+import pytest
 import torch
-from parameterized import parameterized
 from transformers import AutoTokenizer, T5EncoderModel
 from diffusers import (
@@ -39,7 +38,7 @@ from .utils import PeftLoraLoaderMixinTests  # noqa: E402
 @require_peft_backend
-class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
+class TestCogVideoXLoRA(PeftLoraLoaderMixinTests):
     pipeline_class = CogVideoXPipeline
     scheduler_cls = CogVideoXDPMScheduler
     scheduler_kwargs = {"timestep_spacing": "trailing"}
@@ -119,54 +118,59 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
         return noise, input_ids, pipeline_inputs
-    def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
-        super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
+    def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe):
+        super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3, pipe=pipe)
-    def test_simple_inference_with_text_denoiser_lora_unfused(self):
-        super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
+    def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe):
+        super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3)
-    def test_lora_scale_kwargs_match_fusion(self):
-        super().test_lora_scale_kwargs_match_fusion(expected_atol=9e-3, expected_rtol=9e-3)
+    def test_lora_scale_kwargs_match_fusion(self, base_pipe_output):
+        super().test_lora_scale_kwargs_match_fusion(
+            base_pipe_output=base_pipe_output, expected_atol=9e-3, expected_rtol=9e-3
+        )
-    @parameterized.expand([("block_level", True), ("leaf_level", False)])
+    @pytest.mark.parametrize(
+        "offload_type, use_stream",
+        [("block_level", True), ("leaf_level", False)],
+    )
     @require_torch_accelerator
-    def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
+    def test_group_offloading_inference_denoiser(self, offload_type, use_stream, tmpdirname, pipe):
         # TODO: We don't run the (leaf_level, True) test here that is enabled for other models.
         # The reason for this can be found here: https://github.com/huggingface/diffusers/pull/11804#issuecomment-3013325338
-        super()._test_group_offloading_inference_denoiser(offload_type, use_stream)
+        super()._test_group_offloading_inference_denoiser(offload_type, use_stream, tmpdirname, pipe)
-    @unittest.skip("Not supported in CogVideoX.")
+    @pytest.mark.skip("Not supported in CogVideoX.")
     def test_simple_inference_with_text_denoiser_block_scale(self):
         pass
-    @unittest.skip("Not supported in CogVideoX.")
+    @pytest.mark.skip("Not supported in CogVideoX.")
     def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
         pass
-    @unittest.skip("Not supported in CogVideoX.")
+    @pytest.mark.skip("Not supported in CogVideoX.")
     def test_modify_padding_mode(self):
         pass
-    @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in CogVideoX.")
    def test_simple_inference_with_partial_text_lora(self):
        pass
-    @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in CogVideoX.")
     def test_simple_inference_with_text_lora(self):
         pass
-    @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in CogVideoX.")
     def test_simple_inference_with_text_lora_and_scale(self):
         pass
-    @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in CogVideoX.")
     def test_simple_inference_with_text_lora_fused(self):
         pass
-    @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in CogVideoX.")
     def test_simple_inference_with_text_lora_save_load(self):
         pass
-    @unittest.skip("Not supported in CogVideoX.")
+    @pytest.mark.skip("Not supported in CogVideoX.")
     def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
         pass

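The CogVideoX (and CogView4) group-offloading test shows the parameterized.expand -> pytest.mark.parametrize mapping: the parameter tuples stay identical, but the argument names are now spelled out and delivered as test arguments. A standalone toy version, assuming nothing beyond pytest itself:

# Standalone toy version of the parametrize conversion; not part of the PR itself.
import pytest


class TestOffloadingToy:
    @pytest.mark.parametrize(
        "offload_type, use_stream",
        [("block_level", True), ("leaf_level", False)],
    )
    def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
        # parameterized.expand passed the tuples positionally; parametrize names them explicitly
        assert offload_type in {"block_level", "leaf_level"}
        assert isinstance(use_stream, bool)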
View File

@@ -13,12 +13,9 @@
 # limitations under the License.
 import sys
-import tempfile
-import unittest
-import numpy as np
+import pytest
 import torch
-from parameterized import parameterized
 from transformers import AutoTokenizer, GlmModel
 from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
@@ -28,7 +25,6 @@ from ..testing_utils import (
     require_peft_backend,
     require_torch_accelerator,
     skip_mps,
-    torch_device,
 )
@@ -47,7 +43,7 @@ class TokenizerWrapper:
 @require_peft_backend
 @skip_mps
-class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
+class TestCogView4LoRA(PeftLoraLoaderMixinTests):
     pipeline_class = CogView4Pipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler
     scheduler_kwargs = {}
@@ -113,72 +109,50 @@ class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
         return noise, input_ids, pipeline_inputs
-    def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
-        super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
+    def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe):
+        super().test_simple_inference_with_text_lora_denoiser_fused_multi(pipe=pipe, expected_atol=9e-3)
-    def test_simple_inference_with_text_denoiser_lora_unfused(self):
-        super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
+    def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe):
+        super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3)
-    def test_simple_inference_save_pretrained(self):
-        """
-        Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
-        """
-        components, _, _ = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-        _, _, inputs = self.get_dummy_inputs(with_generator=False)
-        images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        with tempfile.TemporaryDirectory() as tmpdirname:
-            pipe.save_pretrained(tmpdirname)
-            pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
-            pipe_from_pretrained.to(torch_device)
-            images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0]
-            self.assertTrue(
-                np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3),
-                "Loading from saved checkpoints should give same results.",
-            )
-    @parameterized.expand([("block_level", True), ("leaf_level", False)])
+    @pytest.mark.parametrize(
+        "offload_type, use_stream",
+        [("block_level", True), ("leaf_level", False)],
+    )
     @require_torch_accelerator
-    def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
+    def test_group_offloading_inference_denoiser(self, offload_type, use_stream, tmpdirname, pipe):
         # TODO: We don't run the (leaf_level, True) test here that is enabled for other models.
         # The reason for this can be found here: https://github.com/huggingface/diffusers/pull/11804#issuecomment-3013325338
-        super()._test_group_offloading_inference_denoiser(offload_type, use_stream)
+        super()._test_group_offloading_inference_denoiser(offload_type, use_stream, tmpdirname, pipe)
-    @unittest.skip("Not supported in CogView4.")
+    @pytest.mark.skip("Not supported in CogView4.")
     def test_simple_inference_with_text_denoiser_block_scale(self):
         pass
-    @unittest.skip("Not supported in CogView4.")
+    @pytest.mark.skip("Not supported in CogView4.")
     def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
         pass
-    @unittest.skip("Not supported in CogView4.")
+    @pytest.mark.skip("Not supported in CogView4.")
     def test_modify_padding_mode(self):
         pass
-    @unittest.skip("Text encoder LoRA is not supported in CogView4.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in CogView4.")
     def test_simple_inference_with_partial_text_lora(self):
         pass
-    @unittest.skip("Text encoder LoRA is not supported in CogView4.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in CogView4.")
     def test_simple_inference_with_text_lora(self):
         pass
-    @unittest.skip("Text encoder LoRA is not supported in CogView4.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in CogView4.")
     def test_simple_inference_with_text_lora_and_scale(self):
         pass
-    @unittest.skip("Text encoder LoRA is not supported in CogView4.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in CogView4.")
     def test_simple_inference_with_text_lora_fused(self):
         pass
-    @unittest.skip("Text encoder LoRA is not supported in CogView4.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in CogView4.")
     def test_simple_inference_with_text_lora_save_load(self):
         pass

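The new tmpdirname, pipe, and base_pipe_output arguments are pytest fixtures; their definitions live in the shared test utilities and are not part of this diff (the commit messages "tempfile is now a fixture." and "more fixtures." refer to them). A minimal standalone approximation of the tmpdirname pattern, which replaces the per-test `with tempfile.TemporaryDirectory()` blocks, might look like this (pytest's built-in tmp_path fixture is a close equivalent; the real fixtures may differ):

# Minimal standalone approximation of a tmpdirname fixture; assumption, not the PR's actual code.
import os
import tempfile

import pytest


@pytest.fixture
def tmpdirname():
    # replaces the per-test `with tempfile.TemporaryDirectory() as tmpdirname:` blocks
    with tempfile.TemporaryDirectory() as d:
        yield d


def test_writes_into_tmpdir(tmpdirname):
    path = os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")
    open(path, "wb").close()
    assert os.path.isfile(path)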
View File

@@ -16,13 +16,11 @@ import copy
import gc import gc
import os import os
import sys import sys
import tempfile
import unittest
import numpy as np import numpy as np
import pytest
import safetensors.torch import safetensors.torch
import torch import torch
from parameterized import parameterized
from PIL import Image from PIL import Image
from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
@@ -46,14 +44,12 @@ from ..testing_utils import (
if is_peft_available(): if is_peft_available():
from peft.utils import get_peft_model_state_dict from peft.utils import get_peft_model_state_dict
sys.path.append(".") sys.path.append(".")
from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set
from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402
@require_peft_backend @require_peft_backend
class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): class TestFluxLoRA(PeftLoraLoaderMixinTests):
pipeline_class = FluxPipeline pipeline_class = FluxPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {} scheduler_kwargs = {}
@@ -115,165 +111,134 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
return noise, input_ids, pipeline_inputs return noise, input_ids, pipeline_inputs
def test_with_alpha_in_state_dict(self): def test_with_alpha_in_state_dict(self, tmpdirname, pipe):
components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) _, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
pipe.transformer.add_adapter(denoiser_lora_config) pipe.transformer.add_adapter(denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer") assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer"
images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
with tempfile.TemporaryDirectory() as tmpdirname: denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
denoiser_state_dict = get_peft_model_state_dict(pipe.transformer) self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
pipe.unload_lora_weights() pipe.unload_lora_weights()
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
# modify the state dict to have alpha values following # modify the state dict to have alpha values following
# https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA/blob/main/jon_snow.safetensors # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA/blob/main/jon_snow.safetensors
state_dict_with_alpha = safetensors.torch.load_file( state_dict_with_alpha = safetensors.torch.load_file(
os.path.join(tmpdirname, "pytorch_lora_weights.safetensors") os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")
) )
alpha_dict = {} alpha_dict = {}
for k, v in state_dict_with_alpha.items(): for k, v in state_dict_with_alpha.items():
# only do for `transformer` and for the k projections -- should be enough to test. if "transformer" in k and "to_k" in k and ("lora_A" in k):
if "transformer" in k and "to_k" in k and "lora_A" in k: alpha_dict[f"{k}.alpha"] = float(torch.randint(10, 100, size=()))
alpha_dict[f"{k}.alpha"] = float(torch.randint(10, 100, size=())) state_dict_with_alpha.update(alpha_dict)
state_dict_with_alpha.update(alpha_dict)
images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0)).images images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
pipe.unload_lora_weights() pipe.unload_lora_weights()
pipe.load_lora_weights(state_dict_with_alpha) pipe.load_lora_weights(state_dict_with_alpha)
images_lora_with_alpha = pipe(**inputs, generator=torch.manual_seed(0)).images images_lora_with_alpha = pipe(**inputs, generator=torch.manual_seed(0)).images
assert np.allclose(images_lora, images_lora_from_pretrained, atol=0.001, rtol=0.001), (
self.assertTrue( "Loading from saved checkpoints should give same results."
np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
"Loading from saved checkpoints should give same results.",
) )
self.assertFalse(np.allclose(images_lora_with_alpha, images_lora, atol=1e-3, rtol=1e-3)) assert not np.allclose(images_lora_with_alpha, images_lora, atol=0.001, rtol=0.001)
def test_lora_expansion_works_for_absent_keys(self): def test_lora_expansion_works_for_absent_keys(self, base_pipe_output, tmpdirname, pipe):
components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) _, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = self.get_base_pipe_output()
# Modify the config to have a layer which won't be present in the second LoRA we will load. # Modify the config to have a layer which won't be present in the second LoRA we will load.
modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config) modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config)
modified_denoiser_lora_config.target_modules.add("x_embedder") modified_denoiser_lora_config.target_modules.add("x_embedder")
pipe.transformer.add_adapter(modified_denoiser_lora_config) pipe.transformer.add_adapter(modified_denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer") assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer"
images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertFalse( assert not np.allclose(images_lora, base_pipe_output, atol=0.001, rtol=0.001), (
np.allclose(images_lora, output_no_lora, atol=1e-3, rtol=1e-3), "LoRA should lead to different results."
"LoRA should lead to different results.",
) )
denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
with tempfile.TemporaryDirectory() as tmpdirname: assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) pipe.unload_lora_weights()
pipe.unload_lora_weights() pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="one")
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="one") lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
lora_state_dict_without_xembedder = {k: v for (k, v) in lora_state_dict.items() if "x_embedder" not in k}
# Modify the state dict to exclude "x_embedder" related LoRA params.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
lora_state_dict_without_xembedder = {k: v for k, v in lora_state_dict.items() if "x_embedder" not in k}
pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="two") pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="two")
pipe.set_adapters(["one", "two"]) pipe.set_adapters(["one", "two"])
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer") assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer"
images_lora_with_absent_keys = pipe(**inputs, generator=torch.manual_seed(0)).images images_lora_with_absent_keys = pipe(**inputs, generator=torch.manual_seed(0)).images
assert not np.allclose(images_lora, images_lora_with_absent_keys, atol=0.001, rtol=0.001), (
self.assertFalse( "Different LoRAs should lead to different results."
np.allclose(images_lora, images_lora_with_absent_keys, atol=1e-3, rtol=1e-3),
"Different LoRAs should lead to different results.",
) )
self.assertFalse( assert not np.allclose(base_pipe_output, images_lora_with_absent_keys, atol=0.001, rtol=0.001), (
np.allclose(output_no_lora, images_lora_with_absent_keys, atol=1e-3, rtol=1e-3), "LoRA should lead to different results."
"LoRA should lead to different results.",
) )
def test_lora_expansion_works_for_extra_keys(self): def test_lora_expansion_works_for_extra_keys(self, base_pipe_output, tmpdirname, pipe):
components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) _, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = self.get_base_pipe_output()
# Modify the config to have a layer which won't be present in the first LoRA we will load.
modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config) modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config)
modified_denoiser_lora_config.target_modules.add("x_embedder") modified_denoiser_lora_config.target_modules.add("x_embedder")
pipe.transformer.add_adapter(modified_denoiser_lora_config) pipe.transformer.add_adapter(modified_denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer") assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer"
images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertFalse( assert not np.allclose(images_lora, base_pipe_output, atol=0.001, rtol=0.001), (
np.allclose(images_lora, output_no_lora, atol=1e-3, rtol=1e-3), "LoRA should lead to different results."
"LoRA should lead to different results.",
) )
with tempfile.TemporaryDirectory() as tmpdirname: denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
denoiser_state_dict = get_peft_model_state_dict(pipe.transformer) self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict) assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) pipe.unload_lora_weights()
pipe.unload_lora_weights() lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
# Modify the state dict to exclude "x_embedder" related LoRA params. lora_state_dict_without_xembedder = {k: v for (k, v) in lora_state_dict.items() if "x_embedder" not in k}
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="one")
lora_state_dict_without_xembedder = {k: v for k, v in lora_state_dict.items() if "x_embedder" not in k} pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="two")
pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="one")
# Load state dict with `x_embedder`.
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="two")
pipe.set_adapters(["one", "two"]) pipe.set_adapters(["one", "two"])
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer") assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer"
images_lora_with_extra_keys = pipe(**inputs, generator=torch.manual_seed(0)).images images_lora_with_extra_keys = pipe(**inputs, generator=torch.manual_seed(0)).images
assert not np.allclose(images_lora, images_lora_with_extra_keys, atol=0.001, rtol=0.001), (
self.assertFalse( "Different LoRAs should lead to different results."
np.allclose(images_lora, images_lora_with_extra_keys, atol=1e-3, rtol=1e-3),
"Different LoRAs should lead to different results.",
) )
self.assertFalse( assert not np.allclose(base_pipe_output, images_lora_with_extra_keys, atol=0.001, rtol=0.001), (
np.allclose(output_no_lora, images_lora_with_extra_keys, atol=1e-3, rtol=1e-3), "LoRA should lead to different results."
"LoRA should lead to different results.",
) )
@unittest.skip("Not supported in Flux.") @pytest.mark.skip("Not supported in Flux.")
def test_simple_inference_with_text_denoiser_block_scale(self): def test_simple_inference_with_text_denoiser_block_scale(self):
pass pass
@unittest.skip("Not supported in Flux.") @pytest.mark.skip("Not supported in Flux.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass pass
@unittest.skip("Not supported in Flux.") @pytest.mark.skip("Not supported in Flux.")
def test_modify_padding_mode(self): def test_modify_padding_mode(self):
pass pass
@unittest.skip("Not supported in Flux.") @pytest.mark.skip("Not supported in Flux.")
def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
pass pass
class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): class TestFluxControlLoRA(PeftLoraLoaderMixinTests):
pipeline_class = FluxControlPipeline pipeline_class = FluxControlPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {} scheduler_kwargs = {}
@@ -338,12 +303,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
return noise, input_ids, pipeline_inputs return noise, input_ids, pipeline_inputs
def test_with_norm_in_state_dict(self): def test_with_norm_in_state_dict(self, pipe):
components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
logger = logging.get_logger("diffusers.loaders.lora_pipeline") logger = logging.get_logger("diffusers.loaders.lora_pipeline")
@@ -364,39 +324,32 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipe.load_lora_weights(norm_state_dict) pipe.load_lora_weights(norm_state_dict)
lora_load_output = pipe(**inputs, generator=torch.manual_seed(0))[0] lora_load_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue( assert (
"The provided state dict contains normalization layers in addition to LoRA layers" "The provided state dict contains normalization layers in addition to LoRA layers"
in cap_logger.out in cap_logger.out
) )
self.assertTrue(len(pipe.transformer._transformer_norm_layers) > 0) assert len(pipe.transformer._transformer_norm_layers) > 0
pipe.unload_lora_weights() pipe.unload_lora_weights()
lora_unload_output = pipe(**inputs, generator=torch.manual_seed(0))[0] lora_unload_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(pipe.transformer._transformer_norm_layers is None) assert pipe.transformer._transformer_norm_layers is None
self.assertTrue(np.allclose(original_output, lora_unload_output, atol=1e-5, rtol=1e-5)) assert np.allclose(original_output, lora_unload_output, atol=1e-05, rtol=1e-05)
self.assertFalse( assert not np.allclose(original_output, lora_load_output, atol=1e-06, rtol=1e-06), (
np.allclose(original_output, lora_load_output, atol=1e-6, rtol=1e-6), f"{norm_layer} is tested" f"{norm_layer} is tested"
) )
with CaptureLogger(logger) as cap_logger: with CaptureLogger(logger) as cap_logger:
for key in list(norm_state_dict.keys()): for key in list(norm_state_dict.keys()):
norm_state_dict[key.replace("norm", "norm_k_something_random")] = norm_state_dict.pop(key) norm_state_dict[key.replace("norm", "norm_k_something_random")] = norm_state_dict.pop(key)
pipe.load_lora_weights(norm_state_dict) pipe.load_lora_weights(norm_state_dict)
assert "Unsupported keys found in state dict when trying to load normalization layers" in cap_logger.out
self.assertTrue( def test_lora_parameter_expanded_shapes(self, pipe):
"Unsupported keys found in state dict when trying to load normalization layers" in cap_logger.out
)
def test_lora_parameter_expanded_shapes(self):
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
logger = logging.get_logger("diffusers.loaders.lora_pipeline") logger = logging.get_logger("diffusers.loaders.lora_pipeline")
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
@@ -405,24 +358,21 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
transformer = FluxTransformer2DModel.from_config( transformer = FluxTransformer2DModel.from_config(
components["transformer"].config, in_channels=num_channels_without_control components["transformer"].config, in_channels=num_channels_without_control
).to(torch_device) ).to(torch_device)
self.assertTrue( assert transformer.config.in_channels == num_channels_without_control, (
transformer.config.in_channels == num_channels_without_control, f"Expected {num_channels_without_control} channels in the modified transformer but has transformer.config.in_channels={transformer.config.in_channels!r}"
f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}",
) )
original_transformer_state_dict = pipe.transformer.state_dict() original_transformer_state_dict = pipe.transformer.state_dict()
x_embedder_weight = original_transformer_state_dict.pop("x_embedder.weight") x_embedder_weight = original_transformer_state_dict.pop("x_embedder.weight")
incompatible_keys = transformer.load_state_dict(original_transformer_state_dict, strict=False) incompatible_keys = transformer.load_state_dict(original_transformer_state_dict, strict=False)
self.assertTrue( assert "x_embedder.weight" in incompatible_keys.missing_keys, (
"x_embedder.weight" in incompatible_keys.missing_keys, "Could not find x_embedder.weight in the missing keys."
"Could not find x_embedder.weight in the missing keys.",
) )
transformer.x_embedder.weight.data.copy_(x_embedder_weight[..., :num_channels_without_control]) transformer.x_embedder.weight.data.copy_(x_embedder_weight[..., :num_channels_without_control])
pipe.transformer = transformer pipe.transformer = transformer
out_features, in_features = pipe.transformer.x_embedder.weight.shape out_features, in_features = pipe.transformer.x_embedder.weight.shape
rank = 4 rank = 4
dummy_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False) dummy_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False)
dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False) dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False)
lora_state_dict = { lora_state_dict = {
@@ -431,15 +381,13 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
} }
with CaptureLogger(logger) as cap_logger: with CaptureLogger(logger) as cap_logger:
pipe.load_lora_weights(lora_state_dict, "adapter-1") pipe.load_lora_weights(lora_state_dict, "adapter-1")
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0] lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
assert not np.allclose(original_out, lora_out, rtol=0.0001, atol=0.0001)
self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4)) assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features) assert pipe.transformer.config.in_channels == 2 * in_features
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features) assert cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")
self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
# Testing opposite direction where the LoRA params are zero-padded. # Testing opposite direction where the LoRA params are zero-padded.
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
@@ -454,15 +402,13 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
} }
with CaptureLogger(logger) as cap_logger: with CaptureLogger(logger) as cap_logger:
pipe.load_lora_weights(lora_state_dict, "adapter-1") pipe.load_lora_weights(lora_state_dict, "adapter-1")
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0] lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
assert not np.allclose(original_out, lora_out, rtol=0.0001, atol=0.0001)
self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4)) assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features) assert pipe.transformer.config.in_channels == 2 * in_features
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features) assert "The following LoRA modules were zero padded to match the state dict of" in cap_logger.out
self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out)
def test_normal_lora_with_expanded_lora_raises_error(self): def test_normal_lora_with_expanded_lora_raises_error(self):
# Test the following situation. Load a regular LoRA (such as the ones trained on Flux.1-Dev). And then # Test the following situation. Load a regular LoRA (such as the ones trained on Flux.1-Dev). And then
@@ -494,32 +440,28 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
} }
with CaptureLogger(logger) as cap_logger: with CaptureLogger(logger) as cap_logger:
pipe.load_lora_weights(lora_state_dict, "adapter-1") pipe.load_lora_weights(lora_state_dict, "adapter-1")
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") assert pipe.get_active_adapters() == ["adapter-1"]
self.assertTrue(pipe.get_active_adapters() == ["adapter-1"]) assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features) assert pipe.transformer.config.in_channels == 2 * in_features
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features) assert cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")
self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")) (_, _, inputs) = self.get_dummy_inputs(with_generator=False)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0] lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
normal_lora_A = torch.nn.Linear(in_features, rank, bias=False) normal_lora_A = torch.nn.Linear(in_features, rank, bias=False)
normal_lora_B = torch.nn.Linear(rank, out_features, bias=False) normal_lora_B = torch.nn.Linear(rank, out_features, bias=False)
lora_state_dict = { lora_state_dict = {
"transformer.x_embedder.lora_A.weight": normal_lora_A.weight, "transformer.x_embedder.lora_A.weight": normal_lora_A.weight,
"transformer.x_embedder.lora_B.weight": normal_lora_B.weight, "transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
} }
with CaptureLogger(logger) as cap_logger: with CaptureLogger(logger) as cap_logger:
pipe.load_lora_weights(lora_state_dict, "adapter-2") pipe.load_lora_weights(lora_state_dict, "adapter-2")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out) assert "The following LoRA modules were zero padded to match the state dict of" in cap_logger.out
self.assertTrue(pipe.get_active_adapters() == ["adapter-2"]) assert pipe.get_active_adapters() == ["adapter-2"]
lora_output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] lora_output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertFalse(np.allclose(lora_output, lora_output_2, atol=1e-3, rtol=1e-3)) assert not np.allclose(lora_output, lora_output_2, atol=0.001, rtol=0.001)
# Test the opposite case where the first lora has the correct input features and the second lora has expanded input features. # Test the opposite case where the first lora has the correct input features and the second lora has expanded input features.
# This should raise a runtime error on input shapes being incompatible. # This should raise a runtime error on input shapes being incompatible.
@@ -540,32 +482,24 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
out_features, in_features = pipe.transformer.x_embedder.weight.shape out_features, in_features = pipe.transformer.x_embedder.weight.shape
rank = 4 rank = 4
lora_state_dict = { lora_state_dict = {
"transformer.x_embedder.lora_A.weight": normal_lora_A.weight, "transformer.x_embedder.lora_A.weight": normal_lora_A.weight,
"transformer.x_embedder.lora_B.weight": normal_lora_B.weight, "transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
} }
pipe.load_lora_weights(lora_state_dict, "adapter-1") pipe.load_lora_weights(lora_state_dict, "adapter-1")
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") assert pipe.transformer.x_embedder.weight.data.shape[1] == in_features
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features) assert pipe.transformer.config.in_channels == in_features
self.assertTrue(pipe.transformer.config.in_channels == in_features)
lora_state_dict = { lora_state_dict = {
"transformer.x_embedder.lora_A.weight": shape_expander_lora_A.weight, "transformer.x_embedder.lora_A.weight": shape_expander_lora_A.weight,
"transformer.x_embedder.lora_B.weight": shape_expander_lora_B.weight, "transformer.x_embedder.lora_B.weight": shape_expander_lora_B.weight,
} }
# We should check for input shapes being incompatible here. But because above mentioned issue is # We should check for input shapes being incompatible here. But because above mentioned issue is
# not a supported use case, and because of the PEFT renaming, we will currently have a shape # not a supported use case, and because of the PEFT renaming, we will currently have a shape
# mismatch error. # mismatch error.
self.assertRaisesRegex( with pytest.raises(RuntimeError, match="size mismatch for x_embedder.lora_A.adapter-2.weight"):
RuntimeError, pipe.load_lora_weights(lora_state_dict, "adapter-2")
"size mismatch for x_embedder.lora_A.adapter-2.weight",
pipe.load_lora_weights,
lora_state_dict,
"adapter-2",
)
def test_fuse_expanded_lora_with_regular_lora(self): def test_fuse_expanded_lora_with_regular_lora(self):
# This test checks if it works when a lora with expanded shapes (like control loras) but # This test checks if it works when a lora with expanded shapes (like control loras) but
@@ -597,7 +531,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
"transformer.x_embedder.lora_B.weight": shape_expander_lora_B.weight, "transformer.x_embedder.lora_B.weight": shape_expander_lora_B.weight,
} }
pipe.load_lora_weights(lora_state_dict, "adapter-1") pipe.load_lora_weights(lora_state_dict, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0] lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -610,54 +544,44 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
} }
pipe.load_lora_weights(lora_state_dict, "adapter-2") pipe.load_lora_weights(lora_state_dict, "adapter-2")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
lora_output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] lora_output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
pipe.set_adapters(["adapter-1", "adapter-2"], [1.0, 1.0]) pipe.set_adapters(["adapter-1", "adapter-2"], [1.0, 1.0])
lora_output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0] lora_output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0]
assert not np.allclose(lora_output, lora_output_2, atol=0.001, rtol=0.001)
self.assertFalse(np.allclose(lora_output, lora_output_2, atol=1e-3, rtol=1e-3)) assert not np.allclose(lora_output, lora_output_3, atol=0.001, rtol=0.001)
self.assertFalse(np.allclose(lora_output, lora_output_3, atol=1e-3, rtol=1e-3)) assert not np.allclose(lora_output_2, lora_output_3, atol=0.001, rtol=0.001)
self.assertFalse(np.allclose(lora_output_2, lora_output_3, atol=1e-3, rtol=1e-3))
pipe.fuse_lora(lora_scale=1.0, adapter_names=["adapter-1", "adapter-2"]) pipe.fuse_lora(lora_scale=1.0, adapter_names=["adapter-1", "adapter-2"])
lora_output_4 = pipe(**inputs, generator=torch.manual_seed(0))[0] lora_output_4 = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(np.allclose(lora_output_3, lora_output_4, atol=1e-3, rtol=1e-3)) assert np.allclose(lora_output_3, lora_output_4, atol=0.001, rtol=0.001)
def test_load_regular_lora(self): def test_load_regular_lora(self, base_pipe_output, pipe):
# This test checks if a regular lora (think of one trained on Flux.1 Dev for example) can be loaded # This test checks if a regular lora (think of one trained on Flux.1 Dev for example) can be loaded
# into the transformer with more input channels than Flux.1 Dev, for example. Some examples of those # into the transformer with more input channels than Flux.1 Dev, for example. Some examples of those
# transformers include Flux Fill, Flux Control, etc. # transformers include Flux Fill, Flux Control, etc.
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
out_features, in_features = pipe.transformer.x_embedder.weight.shape out_features, in_features = pipe.transformer.x_embedder.weight.shape
rank = 4 rank = 4
in_features = in_features // 2 # to mimic the Flux.1-Dev LoRA. in_features = in_features // 2
normal_lora_A = torch.nn.Linear(in_features, rank, bias=False) normal_lora_A = torch.nn.Linear(in_features, rank, bias=False)
normal_lora_B = torch.nn.Linear(rank, out_features, bias=False) normal_lora_B = torch.nn.Linear(rank, out_features, bias=False)
lora_state_dict = { lora_state_dict = {
"transformer.x_embedder.lora_A.weight": normal_lora_A.weight, "transformer.x_embedder.lora_A.weight": normal_lora_A.weight,
"transformer.x_embedder.lora_B.weight": normal_lora_B.weight, "transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
} }
logger = logging.get_logger("diffusers.loaders.lora_pipeline") logger = logging.get_logger("diffusers.loaders.lora_pipeline")
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
with CaptureLogger(logger) as cap_logger: with CaptureLogger(logger) as cap_logger:
pipe.load_lora_weights(lora_state_dict, "adapter-1") pipe.load_lora_weights(lora_state_dict, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0] lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
assert "The following LoRA modules were zero padded to match the state dict of" in cap_logger.out
self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out) assert pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2) assert not np.allclose(base_pipe_output, lora_output, atol=0.001, rtol=0.001)
self.assertFalse(np.allclose(original_output, lora_output, atol=1e-3, rtol=1e-3))
def test_lora_unload_with_parameter_expanded_shapes(self): def test_lora_unload_with_parameter_expanded_shapes(self):
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
@@ -670,9 +594,8 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
transformer = FluxTransformer2DModel.from_config( transformer = FluxTransformer2DModel.from_config(
components["transformer"].config, in_channels=num_channels_without_control components["transformer"].config, in_channels=num_channels_without_control
).to(torch_device) ).to(torch_device)
self.assertTrue( assert transformer.config.in_channels == num_channels_without_control, (
transformer.config.in_channels == num_channels_without_control, f"Expected {num_channels_without_control} channels in the modified transformer but has transformer.config.in_channels={transformer.config.in_channels!r}"
f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}",
) )
# This should be initialized with a Flux pipeline variant that doesn't accept `control_image`. # This should be initialized with a Flux pipeline variant that doesn't accept `control_image`.
@@ -697,33 +620,31 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
} }
with CaptureLogger(logger) as cap_logger: with CaptureLogger(logger) as cap_logger:
control_pipe.load_lora_weights(lora_state_dict, "adapter-1") control_pipe.load_lora_weights(lora_state_dict, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
inputs["control_image"] = control_image inputs["control_image"] = control_image
lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0] lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0]
assert not np.allclose(original_out, lora_out, rtol=0.0001, atol=0.0001)
self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4)) assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features) assert pipe.transformer.config.in_channels == 2 * in_features
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features) assert cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")
self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
control_pipe.unload_lora_weights(reset_to_overwritten_params=True) control_pipe.unload_lora_weights(reset_to_overwritten_params=True)
self.assertTrue( assert control_pipe.transformer.config.in_channels == num_channels_without_control, (
control_pipe.transformer.config.in_channels == num_channels_without_control, f"Expected {num_channels_without_control} channels in the modified transformer but has control_pipe.transformer.config.in_channels={control_pipe.transformer.config.in_channels!r}"
f"Expected {num_channels_without_control} channels in the modified transformer but has {control_pipe.transformer.config.in_channels=}",
) )
loaded_pipe = FluxPipeline.from_pipe(control_pipe) loaded_pipe = FluxPipeline.from_pipe(control_pipe)
self.assertTrue( assert loaded_pipe.transformer.config.in_channels == num_channels_without_control, (
loaded_pipe.transformer.config.in_channels == num_channels_without_control, f"Expected {num_channels_without_control} channels in the modified transformer but has loaded_pipe.transformer.config.in_channels={loaded_pipe.transformer.config.in_channels!r}"
f"Expected {num_channels_without_control} channels in the modified transformer but has {loaded_pipe.transformer.config.in_channels=}",
) )
inputs.pop("control_image") inputs.pop("control_image")
unloaded_lora_out = loaded_pipe(**inputs, generator=torch.manual_seed(0))[0] unloaded_lora_out = loaded_pipe(**inputs, generator=torch.manual_seed(0))[0]
assert not np.allclose(unloaded_lora_out, lora_out, rtol=0.0001, atol=0.0001)
self.assertFalse(np.allclose(unloaded_lora_out, lora_out, rtol=1e-4, atol=1e-4)) assert np.allclose(unloaded_lora_out, original_out, atol=0.0001, rtol=0.0001)
self.assertTrue(np.allclose(unloaded_lora_out, original_out, atol=1e-4, rtol=1e-4)) assert pipe.transformer.x_embedder.weight.data.shape[1] == in_features
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features) assert pipe.transformer.config.in_channels == in_features
self.assertTrue(pipe.transformer.config.in_channels == in_features)
def test_lora_unload_with_parameter_expanded_shapes_and_no_reset(self): def test_lora_unload_with_parameter_expanded_shapes_and_no_reset(self):
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
@@ -731,14 +652,12 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
logger = logging.get_logger("diffusers.loaders.lora_pipeline") logger = logging.get_logger("diffusers.loaders.lora_pipeline")
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
# Change the transformer config to mimic a real use case.
num_channels_without_control = 4 num_channels_without_control = 4
transformer = FluxTransformer2DModel.from_config( transformer = FluxTransformer2DModel.from_config(
components["transformer"].config, in_channels=num_channels_without_control components["transformer"].config, in_channels=num_channels_without_control
).to(torch_device) ).to(torch_device)
self.assertTrue( assert transformer.config.in_channels == num_channels_without_control, (
transformer.config.in_channels == num_channels_without_control, f"Expected {num_channels_without_control} channels in the modified transformer but has transformer.config.in_channels={transformer.config.in_channels!r}"
f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}",
) )
# This should be initialized with a Flux pipeline variant that doesn't accept `control_image`. # This should be initialized with a Flux pipeline variant that doesn't accept `control_image`.
@@ -763,40 +682,38 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
} }
with CaptureLogger(logger) as cap_logger: with CaptureLogger(logger) as cap_logger:
control_pipe.load_lora_weights(lora_state_dict, "adapter-1") control_pipe.load_lora_weights(lora_state_dict, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
inputs["control_image"] = control_image inputs["control_image"] = control_image
lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0] lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0]
assert not np.allclose(original_out, lora_out, rtol=0.0001, atol=0.0001)
self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4)) assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features) assert pipe.transformer.config.in_channels == 2 * in_features
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features) assert cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")
self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
control_pipe.unload_lora_weights(reset_to_overwritten_params=False) control_pipe.unload_lora_weights(reset_to_overwritten_params=False)
self.assertTrue( assert control_pipe.transformer.config.in_channels == 2 * num_channels_without_control, (
control_pipe.transformer.config.in_channels == 2 * num_channels_without_control, f"Expected {2 * num_channels_without_control} channels in the modified transformer but has control_pipe.transformer.config.in_channels={control_pipe.transformer.config.in_channels!r}"
f"Expected {num_channels_without_control} channels in the modified transformer but has {control_pipe.transformer.config.in_channels=}",
) )
no_lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0] no_lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0]
assert not np.allclose(no_lora_out, lora_out, rtol=0.0001, atol=0.0001)
assert pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2
assert pipe.transformer.config.in_channels == in_features * 2
self.assertFalse(np.allclose(no_lora_out, lora_out, rtol=1e-4, atol=1e-4)) @pytest.mark.skip("Not supported in Flux.")
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2)
self.assertTrue(pipe.transformer.config.in_channels == in_features * 2)
@unittest.skip("Not supported in Flux.")
def test_simple_inference_with_text_denoiser_block_scale(self): def test_simple_inference_with_text_denoiser_block_scale(self):
pass pass
@unittest.skip("Not supported in Flux.") @pytest.mark.skip("Not supported in Flux.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass pass
@unittest.skip("Not supported in Flux.") @pytest.mark.skip("Not supported in Flux.")
def test_modify_padding_mode(self): def test_modify_padding_mode(self):
pass pass
@unittest.skip("Not supported in Flux.") @pytest.mark.skip("Not supported in Flux.")
def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
pass pass
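The unsupported-feature stubs above follow a mechanical conversion; a minimal before/after sketch of the pattern, with placeholder class and test names:

import unittest

import pytest


class OldStyleLoRATests(unittest.TestCase):
    @unittest.skip("Not supported in Flux.")
    def test_modify_padding_mode(self):
        pass


class TestNewStyleLoRA:  # plain class: pytest collects it without TestCase
    @pytest.mark.skip("Not supported in Flux.")
    def test_modify_padding_mode(self):
        pass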
@@ -806,7 +723,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@require_torch_accelerator @require_torch_accelerator
@require_peft_backend @require_peft_backend
@require_big_accelerator @require_big_accelerator
class FluxLoRAIntegrationTests(unittest.TestCase): class TestFluxLoRAIntegration:
"""internal note: The integration slices were obtained on audace. """internal note: The integration slices were obtained on audace.
torch: 2.6.0.dev20241006+cu124 with CUDA 12.5. Need the same setup for the torch: 2.6.0.dev20241006+cu124 with CUDA 12.5. Need the same setup for the
@@ -816,33 +733,27 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
num_inference_steps = 10 num_inference_steps = 10
seed = 0 seed = 0
def setUp(self): @pytest.fixture(scope="function")
super().setUp() def pipeline(self):
gc.collect() gc.collect()
backend_empty_cache(torch_device) backend_empty_cache(torch_device)
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to(
torch_device
)
try:
yield pipe
finally:
del pipe
gc.collect()
backend_empty_cache(torch_device)
self.pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16) def test_flux_the_last_ben(self, pipeline):
pipeline.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors")
def tearDown(self): pipeline.fuse_lora()
super().tearDown() pipeline.unload_lora_weights()
pipeline = pipeline.to(torch_device)
del self.pipeline
gc.collect()
backend_empty_cache(torch_device)
def test_flux_the_last_ben(self):
self.pipeline.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors")
self.pipeline.fuse_lora()
self.pipeline.unload_lora_weights()
# Instead of calling `enable_model_cpu_offload()`, we do an accelerator placement here because the CI
# run supports it. We have about 34GB RAM in the CI runner which kills the test when run with
# `enable_model_cpu_offload()`. We repeat this for the other tests, too.
self.pipeline = self.pipeline.to(torch_device)
prompt = "jon snow eating pizza with ketchup" prompt = "jon snow eating pizza with ketchup"
out = pipeline(
out = self.pipeline(
prompt, prompt,
num_inference_steps=self.num_inference_steps, num_inference_steps=self.num_inference_steps,
guidance_scale=4.0, guidance_scale=4.0,
@@ -851,71 +762,57 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
).images ).images
out_slice = out[0, -3:, -3:, -1].flatten() out_slice = out[0, -3:, -3:, -1].flatten()
expected_slice = np.array([0.1855, 0.1855, 0.1836, 0.1855, 0.1836, 0.1875, 0.1777, 0.1758, 0.2246]) expected_slice = np.array([0.1855, 0.1855, 0.1836, 0.1855, 0.1836, 0.1875, 0.1777, 0.1758, 0.2246])
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
assert max_diff < 0.001
assert max_diff < 1e-3 def test_flux_kohya(self, pipeline):
pipeline.load_lora_weights("Norod78/brain-slug-flux")
def test_flux_kohya(self): pipeline.fuse_lora()
self.pipeline.load_lora_weights("Norod78/brain-slug-flux") pipeline.unload_lora_weights()
self.pipeline.fuse_lora() pipeline = pipeline.to(torch_device)
self.pipeline.unload_lora_weights()
self.pipeline = self.pipeline.to(torch_device)
prompt = "The cat with a brain slug earring" prompt = "The cat with a brain slug earring"
out = self.pipeline( out = pipeline(
prompt, prompt,
num_inference_steps=self.num_inference_steps, num_inference_steps=self.num_inference_steps,
guidance_scale=4.5, guidance_scale=4.5,
output_type="np", output_type="np",
generator=torch.manual_seed(self.seed), generator=torch.manual_seed(self.seed),
).images ).images
out_slice = out[0, -3:, -3:, -1].flatten() out_slice = out[0, -3:, -3:, -1].flatten()
expected_slice = np.array([0.6367, 0.6367, 0.6328, 0.6367, 0.6328, 0.6289, 0.6367, 0.6328, 0.6484]) expected_slice = np.array([0.6367, 0.6367, 0.6328, 0.6367, 0.6328, 0.6289, 0.6367, 0.6328, 0.6484])
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
assert max_diff < 0.001
assert max_diff < 1e-3 def test_flux_kohya_with_text_encoder(self, pipeline):
pipeline.load_lora_weights("cocktailpeanut/optimus", weight_name="optimus.safetensors")
def test_flux_kohya_with_text_encoder(self): pipeline.fuse_lora()
self.pipeline.load_lora_weights("cocktailpeanut/optimus", weight_name="optimus.safetensors") pipeline.unload_lora_weights()
self.pipeline.fuse_lora() pipeline = pipeline.to(torch_device)
self.pipeline.unload_lora_weights()
self.pipeline = self.pipeline.to(torch_device)
prompt = "optimus is cleaning the house with broomstick" prompt = "optimus is cleaning the house with broomstick"
out = self.pipeline( out = pipeline(
prompt, prompt,
num_inference_steps=self.num_inference_steps, num_inference_steps=self.num_inference_steps,
guidance_scale=4.5, guidance_scale=4.5,
output_type="np", output_type="np",
generator=torch.manual_seed(self.seed), generator=torch.manual_seed(self.seed),
).images ).images
out_slice = out[0, -3:, -3:, -1].flatten() out_slice = out[0, -3:, -3:, -1].flatten()
expected_slice = np.array([0.4023, 0.4023, 0.4023, 0.3965, 0.3984, 0.3965, 0.3926, 0.3906, 0.4219]) expected_slice = np.array([0.4023, 0.4023, 0.4023, 0.3965, 0.3984, 0.3965, 0.3926, 0.3906, 0.4219])
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
assert max_diff < 0.001
assert max_diff < 1e-3 def test_flux_kohya_embedders_conversion(self, pipeline):
def test_flux_kohya_embedders_conversion(self):
"""Test that embedders load without throwing errors""" """Test that embedders load without throwing errors"""
self.pipeline.load_lora_weights("rockerBOO/flux-bpo-po-lora") pipeline.load_lora_weights("rockerBOO/flux-bpo-po-lora")
self.pipeline.unload_lora_weights() pipeline.unload_lora_weights()
assert True
def test_flux_xlabs(self):
self.pipeline.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="disney_lora.safetensors")
self.pipeline.fuse_lora()
self.pipeline.unload_lora_weights()
self.pipeline = self.pipeline.to(torch_device)
def test_flux_xlabs(self, pipeline):
pipeline.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="disney_lora.safetensors")
pipeline.fuse_lora()
pipeline.unload_lora_weights()
pipeline = pipeline.to(torch_device)
prompt = "A blue jay standing on a large basket of rainbow macarons, disney style" prompt = "A blue jay standing on a large basket of rainbow macarons, disney style"
out = pipeline(
out = self.pipeline(
prompt, prompt,
num_inference_steps=self.num_inference_steps, num_inference_steps=self.num_inference_steps,
guidance_scale=3.5, guidance_scale=3.5,
@@ -923,23 +820,17 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
generator=torch.manual_seed(self.seed), generator=torch.manual_seed(self.seed),
).images ).images
out_slice = out[0, -3:, -3:, -1].flatten() out_slice = out[0, -3:, -3:, -1].flatten()
expected_slice = np.array([0.3965, 0.4180, 0.4434, 0.4082, 0.4375, 0.4590, 0.4141, 0.4375, 0.4980]) expected_slice = np.array([0.3965, 0.418, 0.4434, 0.4082, 0.4375, 0.459, 0.4141, 0.4375, 0.498])
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
assert max_diff < 0.001
assert max_diff < 1e-3 def test_flux_xlabs_load_lora_with_single_blocks(self, pipeline):
pipeline.load_lora_weights("salinasr/test_xlabs_flux_lora_with_singleblocks", weight_name="lora.safetensors")
def test_flux_xlabs_load_lora_with_single_blocks(self): pipeline.fuse_lora()
self.pipeline.load_lora_weights( pipeline.unload_lora_weights()
"salinasr/test_xlabs_flux_lora_with_singleblocks", weight_name="lora.safetensors" pipeline.enable_model_cpu_offload()
)
self.pipeline.fuse_lora()
self.pipeline.unload_lora_weights()
self.pipeline.enable_model_cpu_offload()
prompt = "a wizard mouse playing chess" prompt = "a wizard mouse playing chess"
out = pipeline(
out = self.pipeline(
prompt, prompt,
num_inference_steps=self.num_inference_steps, num_inference_steps=self.num_inference_steps,
guidance_scale=3.5, guidance_scale=3.5,
@@ -951,40 +842,43 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
[0.04882812, 0.04101562, 0.04882812, 0.03710938, 0.02929688, 0.02734375, 0.0234375, 0.01757812, 0.0390625] [0.04882812, 0.04101562, 0.04882812, 0.03710938, 0.02929688, 0.02734375, 0.0234375, 0.01757812, 0.0390625]
) )
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
assert max_diff < 0.001
assert max_diff < 1e-3
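The setUp/tearDown pair of the integration class becomes a function-scoped fixture that builds the pipeline, yields it to the test, and cleans up afterwards; a condensed sketch, assuming the `torch_device` and `backend_empty_cache` helpers are importable from the testing utilities as shown:

import gc

import pytest
import torch

from diffusers import FluxPipeline
from diffusers.utils.testing_utils import backend_empty_cache, torch_device  # assumed import path


@pytest.fixture(scope="function")
def pipeline():
    gc.collect()
    backend_empty_cache(torch_device)
    pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to(torch_device)
    try:
        yield pipe  # each test receives a fresh, already-placed pipeline as its `pipeline` argument
    finally:
        del pipe
        gc.collect()
        backend_empty_cache(torch_device)


def test_flux_lora_smoke(pipeline):  # illustrative test
    pipeline.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors")
    pipeline.fuse_lora()
    pipeline.unload_lora_weights()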
@nightly @nightly
@require_torch_accelerator @require_torch_accelerator
@require_peft_backend @require_peft_backend
@require_big_accelerator @require_big_accelerator
class FluxControlLoRAIntegrationTests(unittest.TestCase): class TestFluxControlLoRAIntegration:
num_inference_steps = 10 num_inference_steps = 10
seed = 0 seed = 0
prompt = "A robot made of exotic candies and chocolates of different kinds." prompt = "A robot made of exotic candies and chocolates of different kinds."
def setUp(self): @pytest.fixture(scope="function")
super().setUp() def pipeline(self):
gc.collect() gc.collect()
backend_empty_cache(torch_device) backend_empty_cache(torch_device)
pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to(
torch_device
)
try:
yield pipe
finally:
del pipe
gc.collect()
backend_empty_cache(torch_device)
self.pipeline = FluxControlPipeline.from_pretrained( @pytest.mark.parametrize(
"black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16 "lora_ckpt_id",
).to(torch_device) [
"black-forest-labs/FLUX.1-Canny-dev-lora",
def tearDown(self): "black-forest-labs/FLUX.1-Depth-dev-lora",
super().tearDown() ],
)
gc.collect() def test_lora(self, pipeline, lora_ckpt_id):
backend_empty_cache(torch_device) pipeline.load_lora_weights(lora_ckpt_id)
pipeline.fuse_lora()
@parameterized.expand(["black-forest-labs/FLUX.1-Canny-dev-lora", "black-forest-labs/FLUX.1-Depth-dev-lora"]) pipeline.unload_lora_weights()
def test_lora(self, lora_ckpt_id):
self.pipeline.load_lora_weights(lora_ckpt_id)
self.pipeline.fuse_lora()
self.pipeline.unload_lora_weights()
if "Canny" in lora_ckpt_id: if "Canny" in lora_ckpt_id:
control_image = load_image( control_image = load_image(
@@ -995,7 +889,7 @@ class FluxControlLoRAIntegrationTests(unittest.TestCase):
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/depth_condition_image.png" "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/depth_condition_image.png"
) )
image = self.pipeline( image = pipeline(
prompt=self.prompt, prompt=self.prompt,
control_image=control_image, control_image=control_image,
height=1024, height=1024,
@@ -1016,12 +910,18 @@ class FluxControlLoRAIntegrationTests(unittest.TestCase):
assert max_diff < 1e-3 assert max_diff < 1e-3
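The per-checkpoint variants previously generated with `@parameterized.expand` become ordinary pytest parametrizations that compose with the `pipeline` fixture; a minimal sketch of the pattern:

import pytest


@pytest.mark.parametrize(
    "lora_ckpt_id",
    [
        "black-forest-labs/FLUX.1-Canny-dev-lora",
        "black-forest-labs/FLUX.1-Depth-dev-lora",
    ],
)
def test_lora(pipeline, lora_ckpt_id):
    # One collected test case per checkpoint id; `pipeline` comes from the function-scoped fixture.
    pipeline.load_lora_weights(lora_ckpt_id)
    pipeline.fuse_lora()
    pipeline.unload_lora_weights()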
@parameterized.expand(["black-forest-labs/FLUX.1-Canny-dev-lora", "black-forest-labs/FLUX.1-Depth-dev-lora"]) @pytest.mark.parametrize(
def test_lora_with_turbo(self, lora_ckpt_id): "lora_ckpt_id",
self.pipeline.load_lora_weights(lora_ckpt_id) [
self.pipeline.load_lora_weights("ByteDance/Hyper-SD", weight_name="Hyper-FLUX.1-dev-8steps-lora.safetensors") "black-forest-labs/FLUX.1-Canny-dev-lora",
self.pipeline.fuse_lora() "black-forest-labs/FLUX.1-Depth-dev-lora",
self.pipeline.unload_lora_weights() ],
)
def test_lora_with_turbo(self, pipeline, lora_ckpt_id):
pipeline.load_lora_weights(lora_ckpt_id)
pipeline.load_lora_weights("ByteDance/Hyper-SD", weight_name="Hyper-FLUX.1-dev-8steps-lora.safetensors")
pipeline.fuse_lora()
pipeline.unload_lora_weights()
if "Canny" in lora_ckpt_id: if "Canny" in lora_ckpt_id:
control_image = load_image( control_image = load_image(
View File
@@ -13,9 +13,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import sys import sys
import unittest
import numpy as np import numpy as np
import pytest
import torch import torch
from transformers import AutoProcessor, Mistral3ForConditionalGeneration from transformers import AutoProcessor, Mistral3ForConditionalGeneration
@@ -30,7 +30,7 @@ from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa
@require_peft_backend @require_peft_backend
class Flux2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): class TestFlux2LoRA(PeftLoraLoaderMixinTests):
pipeline_class = Flux2Pipeline pipeline_class = Flux2Pipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {} scheduler_kwargs = {}
@@ -133,36 +133,36 @@ class Flux2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False) pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
out = pipe(**inputs)[0] out = pipe(**inputs)[0]
self.assertTrue(np.isnan(out).all()) assert np.isnan(out).all()
@unittest.skip("Not supported in Flux2.") @pytest.mark.skip("Not supported in Flux2.")
def test_simple_inference_with_text_denoiser_block_scale(self): def test_simple_inference_with_text_denoiser_block_scale(self):
pass pass
@unittest.skip("Not supported in Flux2.") @pytest.mark.skip("Not supported in Flux2.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass pass
@unittest.skip("Not supported in Flux2.") @pytest.mark.skip("Not supported in Flux2.")
def test_modify_padding_mode(self): def test_modify_padding_mode(self):
pass pass
@unittest.skip("Text encoder LoRA is not supported in Flux2.") @pytest.mark.skip("Text encoder LoRA is not supported in Flux2.")
def test_simple_inference_with_partial_text_lora(self): def test_simple_inference_with_partial_text_lora(self):
pass pass
@unittest.skip("Text encoder LoRA is not supported in Flux2.") @pytest.mark.skip("Text encoder LoRA is not supported in Flux2.")
def test_simple_inference_with_text_lora(self): def test_simple_inference_with_text_lora(self):
pass pass
@unittest.skip("Text encoder LoRA is not supported in Flux2.") @pytest.mark.skip("Text encoder LoRA is not supported in Flux2.")
def test_simple_inference_with_text_lora_and_scale(self): def test_simple_inference_with_text_lora_and_scale(self):
pass pass
@unittest.skip("Text encoder LoRA is not supported in Flux2.") @pytest.mark.skip("Text encoder LoRA is not supported in Flux2.")
def test_simple_inference_with_text_lora_fused(self): def test_simple_inference_with_text_lora_fused(self):
pass pass
@unittest.skip("Text encoder LoRA is not supported in Flux2.") @pytest.mark.skip("Text encoder LoRA is not supported in Flux2.")
def test_simple_inference_with_text_lora_save_load(self): def test_simple_inference_with_text_lora_save_load(self):
pass pass
View File
@@ -14,9 +14,9 @@
import gc import gc
import sys import sys
import unittest
import numpy as np import numpy as np
import pytest
import torch import torch
from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast
@@ -48,7 +48,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend @require_peft_backend
@skip_mps @skip_mps
class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): class TestHunyuanVideoLoRA(PeftLoraLoaderMixinTests):
pipeline_class = HunyuanVideoPipeline pipeline_class = HunyuanVideoPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {} scheduler_kwargs = {}
@@ -149,46 +149,41 @@ class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
return noise, input_ids, pipeline_inputs return noise, input_ids, pipeline_inputs
def test_simple_inference_with_text_lora_denoiser_fused_multi(self): def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3) super().test_simple_inference_with_text_lora_denoiser_fused_multi(pipe=pipe, expected_atol=9e-3)
def test_simple_inference_with_text_denoiser_lora_unfused(self): def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe):
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3) super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3)
# TODO(aryan): Fix the following test @pytest.mark.skip("Not supported in HunyuanVideo.")
@unittest.skip("This test fails with an error I haven't been able to debug yet.")
def test_simple_inference_save_pretrained(self):
pass
@unittest.skip("Not supported in HunyuanVideo.")
def test_simple_inference_with_text_denoiser_block_scale(self): def test_simple_inference_with_text_denoiser_block_scale(self):
pass pass
@unittest.skip("Not supported in HunyuanVideo.") @pytest.mark.skip("Not supported in HunyuanVideo.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass pass
@unittest.skip("Not supported in HunyuanVideo.") @pytest.mark.skip("Not supported in HunyuanVideo.")
def test_modify_padding_mode(self): def test_modify_padding_mode(self):
pass pass
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.") @pytest.mark.skip("Text encoder LoRA is not supported in HunyuanVideo.")
def test_simple_inference_with_partial_text_lora(self): def test_simple_inference_with_partial_text_lora(self):
pass pass
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.") @pytest.mark.skip("Text encoder LoRA is not supported in HunyuanVideo.")
def test_simple_inference_with_text_lora(self): def test_simple_inference_with_text_lora(self):
pass pass
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.") @pytest.mark.skip("Text encoder LoRA is not supported in HunyuanVideo.")
def test_simple_inference_with_text_lora_and_scale(self): def test_simple_inference_with_text_lora_and_scale(self):
pass pass
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.") @pytest.mark.skip("Text encoder LoRA is not supported in HunyuanVideo.")
def test_simple_inference_with_text_lora_fused(self): def test_simple_inference_with_text_lora_fused(self):
pass pass
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.") @pytest.mark.skip("Text encoder LoRA is not supported in HunyuanVideo.")
def test_simple_inference_with_text_lora_save_load(self): def test_simple_inference_with_text_lora_save_load(self):
pass pass
@@ -197,7 +192,7 @@ class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@require_torch_accelerator @require_torch_accelerator
@require_peft_backend @require_peft_backend
@require_big_accelerator @require_big_accelerator
class HunyuanVideoLoRAIntegrationTests(unittest.TestCase): class TestHunyuanVideoLoRAIntegration:
"""internal note: The integration slices were obtained on DGX. """internal note: The integration slices were obtained on DGX.
torch: 2.5.1+cu124 with CUDA 12.5. Need the same setup for the torch: 2.5.1+cu124 with CUDA 12.5. Need the same setup for the
@@ -207,9 +202,8 @@ class HunyuanVideoLoRAIntegrationTests(unittest.TestCase):
num_inference_steps = 10 num_inference_steps = 10
seed = 0 seed = 0
def setUp(self): @pytest.fixture(scope="function")
super().setUp() def pipeline(self):
gc.collect() gc.collect()
backend_empty_cache(torch_device) backend_empty_cache(torch_device)
@@ -217,27 +211,27 @@ class HunyuanVideoLoRAIntegrationTests(unittest.TestCase):
transformer = HunyuanVideoTransformer3DModel.from_pretrained( transformer = HunyuanVideoTransformer3DModel.from_pretrained(
model_id, subfolder="transformer", torch_dtype=torch.bfloat16 model_id, subfolder="transformer", torch_dtype=torch.bfloat16
) )
self.pipeline = HunyuanVideoPipeline.from_pretrained( pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16).to(
model_id, transformer=transformer, torch_dtype=torch.float16 torch_device
).to(torch_device) )
try:
yield pipe
finally:
del pipe
gc.collect()
backend_empty_cache(torch_device)
def tearDown(self): def test_original_format_cseti(self, pipeline):
super().tearDown() pipeline.load_lora_weights(
gc.collect()
backend_empty_cache(torch_device)
def test_original_format_cseti(self):
self.pipeline.load_lora_weights(
"Cseti/HunyuanVideo-LoRA-Arcane_Jinx-v1", weight_name="csetiarcane-nfjinx-v1-6000.safetensors" "Cseti/HunyuanVideo-LoRA-Arcane_Jinx-v1", weight_name="csetiarcane-nfjinx-v1-6000.safetensors"
) )
self.pipeline.fuse_lora() pipeline.fuse_lora()
self.pipeline.unload_lora_weights() pipeline.unload_lora_weights()
self.pipeline.vae.enable_tiling() pipeline.vae.enable_tiling()
prompt = "CSETIARCANE. A cat walks on the grass, realistic" prompt = "CSETIARCANE. A cat walks on the grass, realistic"
out = self.pipeline( out = pipeline(
prompt=prompt, prompt=prompt,
height=320, height=320,
width=512, width=512,
View File
@@ -13,8 +13,8 @@
# limitations under the License. # limitations under the License.
import sys import sys
import unittest
import pytest
import torch import torch
from transformers import AutoTokenizer, T5EncoderModel from transformers import AutoTokenizer, T5EncoderModel
@@ -34,7 +34,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend @require_peft_backend
class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): class TestLTXVideoLoRA(PeftLoraLoaderMixinTests):
pipeline_class = LTXPipeline pipeline_class = LTXPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {} scheduler_kwargs = {}
@@ -108,40 +108,40 @@ class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
return noise, input_ids, pipeline_inputs return noise, input_ids, pipeline_inputs
def test_simple_inference_with_text_lora_denoiser_fused_multi(self): def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3) super().test_simple_inference_with_text_lora_denoiser_fused_multi(pipe=pipe, expected_atol=9e-3)
def test_simple_inference_with_text_denoiser_lora_unfused(self): def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe):
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3) super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3)
@unittest.skip("Not supported in LTXVideo.") @pytest.mark.skip("Not supported in LTXVideo.")
def test_simple_inference_with_text_denoiser_block_scale(self): def test_simple_inference_with_text_denoiser_block_scale(self):
pass pass
@unittest.skip("Not supported in LTXVideo.") @pytest.mark.skip("Not supported in LTXVideo.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass pass
@unittest.skip("Not supported in LTXVideo.") @pytest.mark.skip("Not supported in LTXVideo.")
def test_modify_padding_mode(self): def test_modify_padding_mode(self):
pass pass
@unittest.skip("Text encoder LoRA is not supported in LTXVideo.") @pytest.mark.skip("Text encoder LoRA is not supported in LTXVideo.")
def test_simple_inference_with_partial_text_lora(self): def test_simple_inference_with_partial_text_lora(self):
pass pass
@unittest.skip("Text encoder LoRA is not supported in LTXVideo.") @pytest.mark.skip("Text encoder LoRA is not supported in LTXVideo.")
def test_simple_inference_with_text_lora(self): def test_simple_inference_with_text_lora(self):
pass pass
@unittest.skip("Text encoder LoRA is not supported in LTXVideo.") @pytest.mark.skip("Text encoder LoRA is not supported in LTXVideo.")
def test_simple_inference_with_text_lora_and_scale(self): def test_simple_inference_with_text_lora_and_scale(self):
pass pass
@unittest.skip("Text encoder LoRA is not supported in LTXVideo.") @pytest.mark.skip("Text encoder LoRA is not supported in LTXVideo.")
def test_simple_inference_with_text_lora_fused(self): def test_simple_inference_with_text_lora_fused(self):
pass pass
@unittest.skip("Text encoder LoRA is not supported in LTXVideo.") @pytest.mark.skip("Text encoder LoRA is not supported in LTXVideo.")
def test_simple_inference_with_text_lora_save_load(self): def test_simple_inference_with_text_lora_save_load(self):
pass pass
View File
@@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
import sys import sys
import unittest
import numpy as np import numpy as np
import pytest import pytest
@@ -36,7 +35,7 @@ from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa
@require_peft_backend @require_peft_backend
class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): class TestLumina2LoRA(PeftLoraLoaderMixinTests):
pipeline_class = Lumina2Pipeline pipeline_class = Lumina2Pipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {} scheduler_kwargs = {}
@@ -101,35 +100,35 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
return noise, input_ids, pipeline_inputs return noise, input_ids, pipeline_inputs
@unittest.skip("Not supported in Lumina2.") @pytest.mark.skip("Not supported in Lumina2.")
def test_simple_inference_with_text_denoiser_block_scale(self): def test_simple_inference_with_text_denoiser_block_scale(self):
pass pass
@unittest.skip("Not supported in Lumina2.") @pytest.mark.skip("Not supported in Lumina2.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass pass
@unittest.skip("Not supported in Lumina2.") @pytest.mark.skip("Not supported in Lumina2.")
def test_modify_padding_mode(self): def test_modify_padding_mode(self):
pass pass
@unittest.skip("Text encoder LoRA is not supported in Lumina2.") @pytest.mark.skip("Text encoder LoRA is not supported in Lumina2.")
def test_simple_inference_with_partial_text_lora(self): def test_simple_inference_with_partial_text_lora(self):
pass pass
@unittest.skip("Text encoder LoRA is not supported in Lumina2.") @pytest.mark.skip("Text encoder LoRA is not supported in Lumina2.")
def test_simple_inference_with_text_lora(self): def test_simple_inference_with_text_lora(self):
pass pass
@unittest.skip("Text encoder LoRA is not supported in Lumina2.") @pytest.mark.skip("Text encoder LoRA is not supported in Lumina2.")
def test_simple_inference_with_text_lora_and_scale(self): def test_simple_inference_with_text_lora_and_scale(self):
pass pass
@unittest.skip("Text encoder LoRA is not supported in Lumina2.") @pytest.mark.skip("Text encoder LoRA is not supported in Lumina2.")
def test_simple_inference_with_text_lora_fused(self): def test_simple_inference_with_text_lora_fused(self):
pass pass
@unittest.skip("Text encoder LoRA is not supported in Lumina2.") @pytest.mark.skip("Text encoder LoRA is not supported in Lumina2.")
def test_simple_inference_with_text_lora_save_load(self): def test_simple_inference_with_text_lora_save_load(self):
pass pass
@@ -139,20 +138,17 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.", reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.",
strict=False, strict=False,
) )
def test_lora_fuse_nan(self): def test_lora_fuse_nan(self, pipe):
components, text_lora_config, denoiser_lora_config = self.get_dummy_components() _, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
if "text_encoder" in self.pipeline_class._lora_loadable_modules: if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config, "adapter-1") denoiser.add_adapter(denoiser_lora_config, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser."
# corrupt one LoRA weight with `inf` values # corrupt one LoRA weight with `inf` values
with torch.no_grad(): with torch.no_grad():
@@ -166,4 +162,4 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False) pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
out = pipe(**inputs)[0] out = pipe(**inputs)[0]
self.assertTrue(np.isnan(out).all()) assert np.isnan(out).all()
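For context, the check above deliberately corrupts one LoRA weight and verifies that fusing with `safe_fusing=False` lets the non-finite values reach the output. A condensed sketch of the idea; it assumes `pipe` already carries an adapter named "adapter-1" and `inputs` are the usual dummy pipeline inputs from the surrounding helpers:

import numpy as np
import torch

# Corrupt a single LoRA weight with `inf` values, mirroring the comment in the test above.
with torch.no_grad():
    for module in pipe.transformer.modules():
        if hasattr(module, "lora_A") and "adapter-1" in module.lora_A:
            module.lora_A["adapter-1"].weight += float("inf")
            break

# Without the safety check, fusing bakes the corruption into the base weights ...
pipe.fuse_lora(components=pipe._lora_loadable_modules, safe_fusing=False)

# ... so the output of a subsequent call is all-NaN.
out = pipe(**inputs)[0]
assert np.isnan(out).all()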
View File
@@ -13,8 +13,8 @@
# limitations under the License. # limitations under the License.
import sys import sys
import unittest
import pytest
import torch import torch
from transformers import AutoTokenizer, T5EncoderModel from transformers import AutoTokenizer, T5EncoderModel
@@ -34,7 +34,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend @require_peft_backend
@skip_mps @skip_mps
class MochiLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): class TestMochiLoRA(PeftLoraLoaderMixinTests):
pipeline_class = MochiPipeline pipeline_class = MochiPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {} scheduler_kwargs = {}
@@ -99,44 +99,44 @@ class MochiLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
return noise, input_ids, pipeline_inputs return noise, input_ids, pipeline_inputs
def test_simple_inference_with_text_lora_denoiser_fused_multi(self): def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3) super().test_simple_inference_with_text_lora_denoiser_fused_multi(pipe=pipe, expected_atol=9e-3)
def test_simple_inference_with_text_denoiser_lora_unfused(self): def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe):
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3) super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3)
@unittest.skip("Not supported in Mochi.") @pytest.mark.skip("Not supported in Mochi.")
def test_simple_inference_with_text_denoiser_block_scale(self): def test_simple_inference_with_text_denoiser_block_scale(self):
pass pass
@unittest.skip("Not supported in Mochi.") @pytest.mark.skip("Not supported in Mochi.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass pass
@unittest.skip("Not supported in Mochi.") @pytest.mark.skip("Not supported in Mochi.")
def test_modify_padding_mode(self): def test_modify_padding_mode(self):
pass pass
@unittest.skip("Text encoder LoRA is not supported in Mochi.") @pytest.mark.skip("Text encoder LoRA is not supported in Mochi.")
def test_simple_inference_with_partial_text_lora(self): def test_simple_inference_with_partial_text_lora(self):
pass pass
@unittest.skip("Text encoder LoRA is not supported in Mochi.") @pytest.mark.skip("Text encoder LoRA is not supported in Mochi.")
def test_simple_inference_with_text_lora(self): def test_simple_inference_with_text_lora(self):
pass pass
@unittest.skip("Text encoder LoRA is not supported in Mochi.") @pytest.mark.skip("Text encoder LoRA is not supported in Mochi.")
def test_simple_inference_with_text_lora_and_scale(self): def test_simple_inference_with_text_lora_and_scale(self):
pass pass
@unittest.skip("Text encoder LoRA is not supported in Mochi.") @pytest.mark.skip("Text encoder LoRA is not supported in Mochi.")
def test_simple_inference_with_text_lora_fused(self): def test_simple_inference_with_text_lora_fused(self):
pass pass
@unittest.skip("Text encoder LoRA is not supported in Mochi.") @pytest.mark.skip("Text encoder LoRA is not supported in Mochi.")
def test_simple_inference_with_text_lora_save_load(self): def test_simple_inference_with_text_lora_save_load(self):
pass pass
@unittest.skip("Not supported in CogVideoX.") @pytest.mark.skip("Not supported in CogVideoX.")
def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
pass pass
View File
@@ -13,8 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import sys import sys
import unittest
import pytest
import torch import torch
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
@@ -34,7 +34,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend @require_peft_backend
class QwenImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): class TestQwenImageLoRA(PeftLoraLoaderMixinTests):
pipeline_class = QwenImagePipeline pipeline_class = QwenImagePipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {} scheduler_kwargs = {}
@@ -96,34 +96,34 @@ class QwenImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
return noise, input_ids, pipeline_inputs return noise, input_ids, pipeline_inputs
@unittest.skip("Not supported in Qwen Image.") @pytest.mark.skip("Not supported in Qwen Image.")
def test_simple_inference_with_text_denoiser_block_scale(self): def test_simple_inference_with_text_denoiser_block_scale(self):
pass pass
@unittest.skip("Not supported in Qwen Image.") @pytest.mark.skip("Not supported in Qwen Image.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass pass
@unittest.skip("Not supported in Qwen Image.") @pytest.mark.skip("Not supported in Qwen Image.")
def test_modify_padding_mode(self): def test_modify_padding_mode(self):
pass pass
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.") @pytest.mark.skip("Text encoder LoRA is not supported in Qwen Image.")
def test_simple_inference_with_partial_text_lora(self): def test_simple_inference_with_partial_text_lora(self):
pass pass
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.") @pytest.mark.skip("Text encoder LoRA is not supported in Qwen Image.")
def test_simple_inference_with_text_lora(self): def test_simple_inference_with_text_lora(self):
pass pass
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.") @pytest.mark.skip("Text encoder LoRA is not supported in Qwen Image.")
def test_simple_inference_with_text_lora_and_scale(self): def test_simple_inference_with_text_lora_and_scale(self):
pass pass
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.") @pytest.mark.skip("Text encoder LoRA is not supported in Qwen Image.")
def test_simple_inference_with_text_lora_fused(self): def test_simple_inference_with_text_lora_fused(self):
pass pass
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.") @pytest.mark.skip("Text encoder LoRA is not supported in Qwen Image.")
def test_simple_inference_with_text_lora_save_load(self): def test_simple_inference_with_text_lora_save_load(self):
pass pass
View File
@@ -13,8 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import sys import sys
import unittest
import pytest
import torch import torch
from transformers import Gemma2Model, GemmaTokenizer from transformers import Gemma2Model, GemmaTokenizer
@@ -29,7 +29,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend @require_peft_backend
class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): class TestSanaLoRA(PeftLoraLoaderMixinTests):
pipeline_class = SanaPipeline pipeline_class = SanaPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {"shift": 7.0} scheduler_kwargs = {"shift": 7.0}
@@ -105,38 +105,38 @@ class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
return noise, input_ids, pipeline_inputs return noise, input_ids, pipeline_inputs
@unittest.skip("Not supported in SANA.") @pytest.mark.skip("Not supported in SANA.")
def test_modify_padding_mode(self): def test_modify_padding_mode(self):
pass pass
@unittest.skip("Not supported in SANA.") @pytest.mark.skip("Not supported in SANA.")
def test_simple_inference_with_text_denoiser_block_scale(self): def test_simple_inference_with_text_denoiser_block_scale(self):
pass pass
@unittest.skip("Not supported in SANA.") @pytest.mark.skip("Not supported in SANA.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass pass
@unittest.skip("Text encoder LoRA is not supported in SANA.") @pytest.mark.skip("Text encoder LoRA is not supported in SANA.")
def test_simple_inference_with_partial_text_lora(self): def test_simple_inference_with_partial_text_lora(self):
pass pass
@unittest.skip("Text encoder LoRA is not supported in SANA.") @pytest.mark.skip("Text encoder LoRA is not supported in SANA.")
def test_simple_inference_with_text_lora(self): def test_simple_inference_with_text_lora(self):
pass pass
@unittest.skip("Text encoder LoRA is not supported in SANA.") @pytest.mark.skip("Text encoder LoRA is not supported in SANA.")
def test_simple_inference_with_text_lora_and_scale(self): def test_simple_inference_with_text_lora_and_scale(self):
pass pass
@unittest.skip("Text encoder LoRA is not supported in SANA.") @pytest.mark.skip("Text encoder LoRA is not supported in SANA.")
def test_simple_inference_with_text_lora_fused(self): def test_simple_inference_with_text_lora_fused(self):
pass pass
@unittest.skip("Text encoder LoRA is not supported in SANA.") @pytest.mark.skip("Text encoder LoRA is not supported in SANA.")
def test_simple_inference_with_text_lora_save_load(self): def test_simple_inference_with_text_lora_save_load(self):
pass pass
@unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment") @pytest.mark.skipif(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
def test_layerwise_casting_inference_denoiser(self): def test_layerwise_casting_inference_denoiser(self):
return super().test_layerwise_casting_inference_denoiser() return super().test_layerwise_casting_inference_denoiser()
View File
@@ -14,9 +14,9 @@
# limitations under the License. # limitations under the License.
import gc import gc
import sys import sys
import unittest
import numpy as np import numpy as np
import pytest
import torch import torch
import torch.nn as nn import torch.nn as nn
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
@@ -55,7 +55,7 @@ if is_accelerate_available():
from accelerate.utils import release_memory from accelerate.utils import release_memory
class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase): class TestStableDiffusionLoRA(PeftLoraLoaderMixinTests):
pipeline_class = StableDiffusionPipeline pipeline_class = StableDiffusionPipeline
scheduler_cls = DDIMScheduler scheduler_cls = DDIMScheduler
scheduler_kwargs = { scheduler_kwargs = {
@@ -91,16 +91,6 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
def output_shape(self): def output_shape(self):
return (1, 64, 64, 3) return (1, 64, 64, 3)
def setUp(self):
super().setUp()
gc.collect()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
backend_empty_cache(torch_device)
# Keeping this test here makes sense because it doesn't involve any integration checks # Keeping this test here makes sense because it doesn't involve any integration checks
# (value assertions on logits). # (value assertions on logits).
@slow @slow
@@ -114,15 +104,8 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
pipe.load_lora_weights(lora_id, adapter_name="adapter-2") pipe.load_lora_weights(lora_id, adapter_name="adapter-2")
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
self.assertTrue( assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
check_if_lora_correctly_set(pipe.text_encoder), assert check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in unet"
"Lora not correctly set in text encoder",
)
self.assertTrue(
check_if_lora_correctly_set(pipe.unet),
"Lora not correctly set in unet",
)
# We will offload the first adapter in CPU and check if the offloading # We will offload the first adapter in CPU and check if the offloading
# has been performed correctly # has been performed correctly
@@ -130,35 +113,35 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
for name, module in pipe.unet.named_modules(): for name, module in pipe.unet.named_modules():
if "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)): if "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
self.assertTrue(module.weight.device == torch.device("cpu")) assert module.weight.device == torch.device("cpu")
elif "adapter-2" in name and not isinstance(module, (nn.Dropout, nn.Identity)): elif "adapter-2" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
self.assertTrue(module.weight.device != torch.device("cpu")) assert module.weight.device != torch.device("cpu")
for name, module in pipe.text_encoder.named_modules(): for name, module in pipe.text_encoder.named_modules():
if "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)): if "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
self.assertTrue(module.weight.device == torch.device("cpu")) assert module.weight.device == torch.device("cpu")
elif "adapter-2" in name and not isinstance(module, (nn.Dropout, nn.Identity)): elif "adapter-2" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
self.assertTrue(module.weight.device != torch.device("cpu")) assert module.weight.device != torch.device("cpu")
pipe.set_lora_device(["adapter-1"], 0) pipe.set_lora_device(["adapter-1"], 0)
for n, m in pipe.unet.named_modules(): for n, m in pipe.unet.named_modules():
if "adapter-1" in n and not isinstance(m, (nn.Dropout, nn.Identity)): if "adapter-1" in n and not isinstance(m, (nn.Dropout, nn.Identity)):
self.assertTrue(m.weight.device != torch.device("cpu")) assert m.weight.device != torch.device("cpu")
for n, m in pipe.text_encoder.named_modules(): for n, m in pipe.text_encoder.named_modules():
if "adapter-1" in n and not isinstance(m, (nn.Dropout, nn.Identity)): if "adapter-1" in n and not isinstance(m, (nn.Dropout, nn.Identity)):
self.assertTrue(m.weight.device != torch.device("cpu")) assert m.weight.device != torch.device("cpu")
pipe.set_lora_device(["adapter-1", "adapter-2"], torch_device) pipe.set_lora_device(["adapter-1", "adapter-2"], torch_device)
for n, m in pipe.unet.named_modules(): for n, m in pipe.unet.named_modules():
if ("adapter-1" in n or "adapter-2" in n) and not isinstance(m, (nn.Dropout, nn.Identity)): if ("adapter-1" in n or "adapter-2" in n) and not isinstance(m, (nn.Dropout, nn.Identity)):
self.assertTrue(m.weight.device != torch.device("cpu")) assert m.weight.device != torch.device("cpu")
for n, m in pipe.text_encoder.named_modules(): for n, m in pipe.text_encoder.named_modules():
if ("adapter-1" in n or "adapter-2" in n) and not isinstance(m, (nn.Dropout, nn.Identity)): if ("adapter-1" in n or "adapter-2" in n) and not isinstance(m, (nn.Dropout, nn.Identity)):
self.assertTrue(m.weight.device != torch.device("cpu")) assert m.weight.device != torch.device("cpu")
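The device-placement assertions above exercise `set_lora_device`, which moves the weights of the named adapters without touching the base model; a minimal usage sketch with illustrative checkpoint and adapter ids:

import torch

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.load_lora_weights("some-user/sd15-lora-a", adapter_name="adapter-0")  # illustrative repos
pipe.load_lora_weights("some-user/sd15-lora-b", adapter_name="adapter-1")
pipe = pipe.to("cuda")

# Park one adapter's LoRA weights on CPU to free accelerator memory ...
pipe.set_lora_device(["adapter-0"], "cpu")

# ... and bring both back (separately or at once) before running inference with them.
pipe.set_lora_device(["adapter-0", "adapter-1"], "cuda")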
@slow @slow
@require_torch_accelerator @require_torch_accelerator
@@ -181,15 +164,9 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
pipe.unet.add_adapter(unet_lora_config, "adapter-1") pipe.unet.add_adapter(unet_lora_config, "adapter-1")
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
self.assertTrue( assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
check_if_lora_correctly_set(pipe.text_encoder),
"Lora not correctly set in text encoder",
)
self.assertTrue( assert check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in unet"
check_if_lora_correctly_set(pipe.unet),
"Lora not correctly set in unet",
)
for name, param in pipe.unet.named_parameters(): for name, param in pipe.unet.named_parameters():
if "lora_" in name: if "lora_" in name:
@@ -225,17 +202,14 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
pipe.unet.add_adapter(config1, adapter_name="adapter-1") pipe.unet.add_adapter(config1, adapter_name="adapter-1")
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
self.assertTrue( assert check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in unet"
check_if_lora_correctly_set(pipe.unet),
"Lora not correctly set in unet",
)
# sanity check that the adapters don't target the same layers, otherwise the test passes even without the fix # sanity check that the adapters don't target the same layers, otherwise the test passes even without the fix
modules_adapter_0 = {n for n, _ in pipe.unet.named_modules() if n.endswith(".adapter-0")} modules_adapter_0 = {n for n, _ in pipe.unet.named_modules() if n.endswith(".adapter-0")}
modules_adapter_1 = {n for n, _ in pipe.unet.named_modules() if n.endswith(".adapter-1")} modules_adapter_1 = {n for n, _ in pipe.unet.named_modules() if n.endswith(".adapter-1")}
self.assertNotEqual(modules_adapter_0, modules_adapter_1) assert modules_adapter_0 != modules_adapter_1
self.assertTrue(modules_adapter_0 - modules_adapter_1) assert modules_adapter_0 - modules_adapter_1
self.assertTrue(modules_adapter_1 - modules_adapter_0) assert modules_adapter_1 - modules_adapter_0
# setting both separately works # setting both separately works
pipe.set_lora_device(["adapter-0"], "cpu") pipe.set_lora_device(["adapter-0"], "cpu")
@@ -243,32 +217,30 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
for name, module in pipe.unet.named_modules(): for name, module in pipe.unet.named_modules():
if "adapter-0" in name and not isinstance(module, (nn.Dropout, nn.Identity)): if "adapter-0" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
self.assertTrue(module.weight.device == torch.device("cpu")) assert module.weight.device == torch.device("cpu")
elif "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)): elif "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
self.assertTrue(module.weight.device == torch.device("cpu")) assert module.weight.device == torch.device("cpu")
# setting both at once also works # setting both at once also works
pipe.set_lora_device(["adapter-0", "adapter-1"], torch_device) pipe.set_lora_device(["adapter-0", "adapter-1"], torch_device)
for name, module in pipe.unet.named_modules(): for name, module in pipe.unet.named_modules():
if "adapter-0" in name and not isinstance(module, (nn.Dropout, nn.Identity)): if "adapter-0" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
self.assertTrue(module.weight.device != torch.device("cpu")) assert module.weight.device != torch.device("cpu")
elif "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)): elif "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
self.assertTrue(module.weight.device != torch.device("cpu")) assert module.weight.device != torch.device("cpu")
@slow @slow
@nightly @nightly
@require_torch_accelerator @require_torch_accelerator
@require_peft_backend @require_peft_backend
class LoraIntegrationTests(unittest.TestCase): class TestSDLoraIntegration:
def setUp(self): @pytest.fixture(autouse=True)
super().setUp() def _gc_and_cache_cleanup(self, torch_device):
gc.collect() gc.collect()
backend_empty_cache(torch_device) backend_empty_cache(torch_device)
yield
def tearDown(self):
super().tearDown()
gc.collect() gc.collect()
backend_empty_cache(torch_device) backend_empty_cache(torch_device)
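Where no shared pipeline object is needed, setUp/tearDown collapse into a single autouse fixture that frames every test in the class with the same cleanup; a minimal sketch, again assuming the suite's testing helpers:

import gc

import pytest

from diffusers.utils.testing_utils import backend_empty_cache, torch_device  # assumed import path


class TestSomeLoraIntegration:  # illustrative class name
    @pytest.fixture(autouse=True)
    def _gc_and_cache_cleanup(self):
        # Runs before every test in the class ...
        gc.collect()
        backend_empty_cache(torch_device)
        yield
        # ... and again once the test has finished.
        gc.collect()
        backend_empty_cache(torch_device)

    def test_placeholder(self):
        assert True  # real tests go here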
@@ -280,10 +252,7 @@ class LoraIntegrationTests(unittest.TestCase):
pipe.load_lora_weights(lora_id) pipe.load_lora_weights(lora_id)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
self.assertTrue( assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
check_if_lora_correctly_set(pipe.text_encoder),
"Lora not correctly set in text encoder",
)
prompt = "a red sks dog" prompt = "a red sks dog"
@@ -312,10 +281,7 @@ class LoraIntegrationTests(unittest.TestCase):
pipe.load_lora_weights(lora_id) pipe.load_lora_weights(lora_id)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
self.assertTrue( assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
check_if_lora_correctly_set(pipe.text_encoder),
"Lora not correctly set in text encoder",
)
prompt = "a red sks dog" prompt = "a red sks dog"
@@ -587,8 +553,8 @@ class LoraIntegrationTests(unittest.TestCase):
).images ).images
unloaded_lora_images = unloaded_lora_images[0, -3:, -3:, -1].flatten() unloaded_lora_images = unloaded_lora_images[0, -3:, -3:, -1].flatten()
self.assertFalse(np.allclose(initial_images, lora_images)) assert not np.allclose(initial_images, lora_images)
self.assertTrue(np.allclose(initial_images, unloaded_lora_images, atol=1e-3)) assert np.allclose(initial_images, unloaded_lora_images, atol=1e-3)
release_memory(pipe) release_memory(pipe)
@@ -625,8 +591,8 @@ class LoraIntegrationTests(unittest.TestCase):
).images ).images
unloaded_lora_images = unloaded_lora_images[0, -3:, -3:, -1].flatten() unloaded_lora_images = unloaded_lora_images[0, -3:, -3:, -1].flatten()
self.assertFalse(np.allclose(initial_images, lora_images)) assert not np.allclose(initial_images, lora_images)
self.assertTrue(np.allclose(initial_images, unloaded_lora_images, atol=1e-3)) assert np.allclose(initial_images, unloaded_lora_images, atol=1e-3)
# make sure we can load a LoRA again after unloading and they don't have # make sure we can load a LoRA again after unloading and they don't have
# any undesired effects. # any undesired effects.
@@ -637,7 +603,7 @@ class LoraIntegrationTests(unittest.TestCase):
).images ).images
lora_images_again = lora_images_again[0, -3:, -3:, -1].flatten() lora_images_again = lora_images_again[0, -3:, -3:, -1].flatten()
self.assertTrue(np.allclose(lora_images, lora_images_again, atol=1e-3)) assert np.allclose(lora_images, lora_images_again, atol=1e-3)
release_memory(pipe) release_memory(pipe)
def test_not_empty_state_dict(self): def test_not_empty_state_dict(self):
@@ -651,7 +617,7 @@ class LoraIntegrationTests(unittest.TestCase):
lcm_lora = load_file(cached_file) lcm_lora = load_file(cached_file)
pipe.load_lora_weights(lcm_lora, adapter_name="lcm") pipe.load_lora_weights(lcm_lora, adapter_name="lcm")
self.assertTrue(lcm_lora != {}) assert lcm_lora != {}
release_memory(pipe) release_memory(pipe)
def test_load_unload_load_state_dict(self): def test_load_unload_load_state_dict(self):
@@ -666,11 +632,11 @@ class LoraIntegrationTests(unittest.TestCase):
previous_state_dict = lcm_lora.copy() previous_state_dict = lcm_lora.copy()
pipe.load_lora_weights(lcm_lora, adapter_name="lcm") pipe.load_lora_weights(lcm_lora, adapter_name="lcm")
self.assertDictEqual(lcm_lora, previous_state_dict) assert lcm_lora == previous_state_dict
pipe.unload_lora_weights() pipe.unload_lora_weights()
pipe.load_lora_weights(lcm_lora, adapter_name="lcm") pipe.load_lora_weights(lcm_lora, adapter_name="lcm")
self.assertDictEqual(lcm_lora, previous_state_dict) assert lcm_lora == previous_state_dict
release_memory(pipe) release_memory(pipe)


@@ -14,9 +14,9 @@
# limitations under the License. # limitations under the License.
import gc import gc
import sys import sys
import unittest
import numpy as np import numpy as np
import pytest
import torch import torch
from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
@@ -51,7 +51,7 @@ if is_accelerate_available():
@require_peft_backend @require_peft_backend
class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): class TestSD3LoRA(PeftLoraLoaderMixinTests):
pipeline_class = StableDiffusion3Pipeline pipeline_class = StableDiffusion3Pipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {} scheduler_kwargs = {}
@@ -113,19 +113,19 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
lora_filename = "lora_peft_format.safetensors" lora_filename = "lora_peft_format.safetensors"
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
@unittest.skip("Not supported in SD3.") @pytest.mark.skip("Not supported in SD3.")
def test_simple_inference_with_text_denoiser_block_scale(self): def test_simple_inference_with_text_denoiser_block_scale(self):
pass pass
@unittest.skip("Not supported in SD3.") @pytest.mark.skip("Not supported in SD3.")
def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
pass pass
@unittest.skip("Not supported in SD3.") @pytest.mark.skip("Not supported in SD3.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass pass
@unittest.skip("Not supported in SD3.") @pytest.mark.skip("Not supported in SD3.")
def test_modify_padding_mode(self): def test_modify_padding_mode(self):
pass pass
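As the hunk above shows, per-test skips move from `@unittest.skip(...)` to `@pytest.mark.skip(...)`, with the string becoming the reason pytest reports for the skipped test. A tiny sketch of the pattern (the test names here are hypothetical, not taken from the diff):

import pytest


class TestHypotheticalPipelineLoRA:
    @pytest.mark.skip("Not supported in this pipeline.")
    def test_denoiser_block_scale(self):
        pass

    @pytest.mark.skip(reason="Text encoder LoRA is not supported here.")
    def test_text_lora(self):
        pass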
@@ -138,17 +138,15 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@require_torch_accelerator @require_torch_accelerator
@require_peft_backend @require_peft_backend
@require_big_accelerator @require_big_accelerator
class SD3LoraIntegrationTests(unittest.TestCase): class TestSD3LoraIntegration:
pipeline_class = StableDiffusion3Img2ImgPipeline pipeline_class = StableDiffusion3Img2ImgPipeline
repo_id = "stabilityai/stable-diffusion-3-medium-diffusers" repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"
def setUp(self): @pytest.fixture(autouse=True)
super().setUp() def _gc_and_cache_cleanup(self, torch_device):
gc.collect() gc.collect()
backend_empty_cache(torch_device) backend_empty_cache(torch_device)
yield
def tearDown(self):
super().tearDown()
gc.collect() gc.collect()
backend_empty_cache(torch_device) backend_empty_cache(torch_device)


@@ -17,9 +17,9 @@ import gc
import importlib import importlib
import sys import sys
import time import time
import unittest
import numpy as np import numpy as np
import pytest
import torch import torch
from packaging import version from packaging import version
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
@@ -59,7 +59,7 @@ if is_accelerate_available():
from accelerate.utils import release_memory from accelerate.utils import release_memory
class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase): class TestStableDiffusionXLLoRA(PeftLoraLoaderMixinTests):
has_two_text_encoders = True has_two_text_encoders = True
pipeline_class = StableDiffusionXLPipeline pipeline_class = StableDiffusionXLPipeline
scheduler_cls = EulerDiscreteScheduler scheduler_cls = EulerDiscreteScheduler
@@ -104,21 +104,11 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
def output_shape(self): def output_shape(self):
return (1, 64, 64, 3) return (1, 64, 64, 3)
def setUp(self):
super().setUp()
gc.collect()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
backend_empty_cache(torch_device)
@is_flaky @is_flaky
def test_multiple_wrong_adapter_name_raises_error(self): def test_multiple_wrong_adapter_name_raises_error(self):
super().test_multiple_wrong_adapter_name_raises_error() super().test_multiple_wrong_adapter_name_raises_error()
def test_simple_inference_with_text_denoiser_lora_unfused(self): def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe):
if torch.cuda.is_available(): if torch.cuda.is_available():
expected_atol = 9e-2 expected_atol = 9e-2
expected_rtol = 9e-2 expected_rtol = 9e-2
@@ -127,10 +117,10 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
expected_rtol = 1e-3 expected_rtol = 1e-3
super().test_simple_inference_with_text_denoiser_lora_unfused( super().test_simple_inference_with_text_denoiser_lora_unfused(
expected_atol=expected_atol, expected_rtol=expected_rtol pipe=pipe, expected_atol=expected_atol, expected_rtol=expected_rtol
) )
def test_simple_inference_with_text_lora_denoiser_fused_multi(self): def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe):
if torch.cuda.is_available(): if torch.cuda.is_available():
expected_atol = 9e-2 expected_atol = 9e-2
expected_rtol = 9e-2 expected_rtol = 9e-2
@@ -139,10 +129,10 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
expected_rtol = 1e-3 expected_rtol = 1e-3
super().test_simple_inference_with_text_lora_denoiser_fused_multi( super().test_simple_inference_with_text_lora_denoiser_fused_multi(
expected_atol=expected_atol, expected_rtol=expected_rtol pipe=pipe, expected_atol=expected_atol, expected_rtol=expected_rtol
) )
def test_lora_scale_kwargs_match_fusion(self): def test_lora_scale_kwargs_match_fusion(self, base_pipe_output):
if torch.cuda.is_available(): if torch.cuda.is_available():
expected_atol = 9e-2 expected_atol = 9e-2
expected_rtol = 9e-2 expected_rtol = 9e-2
@@ -150,21 +140,21 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
expected_atol = 1e-3 expected_atol = 1e-3
expected_rtol = 1e-3 expected_rtol = 1e-3
super().test_lora_scale_kwargs_match_fusion(expected_atol=expected_atol, expected_rtol=expected_rtol) super().test_lora_scale_kwargs_match_fusion(
base_pipe_output=base_pipe_output, expected_atol=expected_atol, expected_rtol=expected_rtol
)
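The SDXL overrides above illustrate another part of the migration: objects each test used to build for itself (`pipe`, `base_pipe_output`) are now pytest fixtures requested through the test signature and forwarded to the shared implementation. A rough sketch of how such a mixin/override pair fits together, assuming the fixtures live on the shared mixin (the fixture bodies below are placeholders, not the real diffusers code):

import pytest


class LoraLoaderSketchMixin:
    # Not collected directly (class name does not start with "Test"),
    # mirroring how the shared PeftLoraLoaderMixinTests is reused.
    @pytest.fixture
    def pipe(self):
        return object()  # the real fixture builds the pipeline under test

    @pytest.fixture
    def base_pipe_output(self, pipe):
        return 0.0  # the real fixture runs the pipeline once without LoRA

    def test_lora_unfused(self, pipe, expected_atol=1e-3, expected_rtol=1e-3):
        assert pipe is not None

    def test_scale_kwargs_match_fusion(self, base_pipe_output, expected_atol=1e-3, expected_rtol=1e-3):
        assert base_pipe_output == 0.0


class TestSomePipelineLoRA(LoraLoaderSketchMixin):
    # The overrides only widen tolerances and pass the injected fixtures through.
    def test_lora_unfused(self, pipe):
        super().test_lora_unfused(pipe=pipe, expected_atol=9e-2, expected_rtol=9e-2)

    def test_scale_kwargs_match_fusion(self, base_pipe_output):
        super().test_scale_kwargs_match_fusion(
            base_pipe_output=base_pipe_output, expected_atol=9e-2, expected_rtol=9e-2
        )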
@slow @slow
@nightly @nightly
@require_torch_accelerator @require_torch_accelerator
@require_peft_backend @require_peft_backend
class LoraSDXLIntegrationTests(unittest.TestCase): class TestLoraSDXLIntegration:
def setUp(self): @pytest.fixture(autouse=True)
super().setUp() def _gc_and_cache_cleanup(self, torch_device):
gc.collect() gc.collect()
backend_empty_cache(torch_device) backend_empty_cache(torch_device)
yield
def tearDown(self):
super().tearDown()
gc.collect() gc.collect()
backend_empty_cache(torch_device) backend_empty_cache(torch_device)
@@ -383,7 +373,7 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
end_time = time.time() end_time = time.time()
elapsed_time_fusion = end_time - start_time elapsed_time_fusion = end_time - start_time
self.assertTrue(elapsed_time_fusion < elapsed_time_non_fusion) assert elapsed_time_fusion < elapsed_time_non_fusion
release_memory(pipe) release_memory(pipe)
@@ -439,14 +429,14 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
for key, value in text_encoder_1_sd.items(): for key, value in text_encoder_1_sd.items():
key = remap_key(key, fused_te_state_dict) key = remap_key(key, fused_te_state_dict)
self.assertTrue(torch.allclose(fused_te_state_dict[key], value)) assert torch.allclose(fused_te_state_dict[key], value)
for key, value in text_encoder_2_sd.items(): for key, value in text_encoder_2_sd.items():
key = remap_key(key, fused_te_2_state_dict) key = remap_key(key, fused_te_2_state_dict)
self.assertTrue(torch.allclose(fused_te_2_state_dict[key], value)) assert torch.allclose(fused_te_2_state_dict[key], value)
for key, value in unet_state_dict.items(): for key, value in unet_state_dict.items():
self.assertTrue(torch.allclose(unet_state_dict[key], value)) assert torch.allclose(unet_state_dict[key], value)
pipe.fuse_lora() pipe.fuse_lora()
pipe.unload_lora_weights() pipe.unload_lora_weights()
@@ -589,7 +579,7 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
pipe.load_lora_weights(lora_id, weight_name="toy_face_sdxl.safetensors", adapter_name="toy") pipe.load_lora_weights(lora_id, weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet") assert check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet"
prompt = "toy_face of a hacker with a hoodie" prompt = "toy_face of a hacker with a hoodie"


@@ -13,8 +13,8 @@
# limitations under the License. # limitations under the License.
import sys import sys
import unittest
import pytest
import torch import torch
from transformers import AutoTokenizer, T5EncoderModel from transformers import AutoTokenizer, T5EncoderModel
@@ -39,7 +39,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend @require_peft_backend
@skip_mps @skip_mps
class WanLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): class TestWanLoRA(PeftLoraLoaderMixinTests):
pipeline_class = WanPipeline pipeline_class = WanPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {} scheduler_kwargs = {}
@@ -104,40 +104,40 @@ class WanLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
return noise, input_ids, pipeline_inputs return noise, input_ids, pipeline_inputs
def test_simple_inference_with_text_lora_denoiser_fused_multi(self): def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3) super().test_simple_inference_with_text_lora_denoiser_fused_multi(pipe=pipe, expected_atol=9e-3)
def test_simple_inference_with_text_denoiser_lora_unfused(self): def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe):
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3) super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3)
@unittest.skip("Not supported in Wan.") @pytest.mark.skip("Not supported in Wan.")
def test_simple_inference_with_text_denoiser_block_scale(self): def test_simple_inference_with_text_denoiser_block_scale(self):
pass pass
@unittest.skip("Not supported in Wan.") @pytest.mark.skip("Not supported in Wan.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass pass
@unittest.skip("Not supported in Wan.") @pytest.mark.skip("Not supported in Wan.")
def test_modify_padding_mode(self): def test_modify_padding_mode(self):
pass pass
@unittest.skip("Text encoder LoRA is not supported in Wan.") @pytest.mark.skip("Text encoder LoRA is not supported in Wan.")
def test_simple_inference_with_partial_text_lora(self): def test_simple_inference_with_partial_text_lora(self):
pass pass
@unittest.skip("Text encoder LoRA is not supported in Wan.") @pytest.mark.skip("Text encoder LoRA is not supported in Wan.")
def test_simple_inference_with_text_lora(self): def test_simple_inference_with_text_lora(self):
pass pass
@unittest.skip("Text encoder LoRA is not supported in Wan.") @pytest.mark.skip("Text encoder LoRA is not supported in Wan.")
def test_simple_inference_with_text_lora_and_scale(self): def test_simple_inference_with_text_lora_and_scale(self):
pass pass
@unittest.skip("Text encoder LoRA is not supported in Wan.") @pytest.mark.skip("Text encoder LoRA is not supported in Wan.")
def test_simple_inference_with_text_lora_fused(self): def test_simple_inference_with_text_lora_fused(self):
pass pass
@unittest.skip("Text encoder LoRA is not supported in Wan.") @pytest.mark.skip("Text encoder LoRA is not supported in Wan.")
def test_simple_inference_with_text_lora_save_load(self): def test_simple_inference_with_text_lora_save_load(self):
pass pass


@@ -14,10 +14,9 @@
import os import os
import sys import sys
import tempfile
import unittest
import numpy as np import numpy as np
import pytest
import safetensors.torch import safetensors.torch
import torch import torch
from PIL import Image from PIL import Image
@@ -32,7 +31,6 @@ from ..testing_utils import (
require_peft_backend, require_peft_backend,
require_peft_version_greater, require_peft_version_greater,
skip_mps, skip_mps,
torch_device,
) )
@@ -47,7 +45,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend @require_peft_backend
@skip_mps @skip_mps
@is_flaky(max_attempts=10, description="very flaky class") @is_flaky(max_attempts=10, description="very flaky class")
class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): class TestWanVACELoRA(PeftLoraLoaderMixinTests):
pipeline_class = WanVACEPipeline pipeline_class = WanVACEPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {} scheduler_kwargs = {}
@@ -121,56 +119,51 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
return noise, input_ids, pipeline_inputs return noise, input_ids, pipeline_inputs
def test_simple_inference_with_text_lora_denoiser_fused_multi(self): def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3) super().test_simple_inference_with_text_lora_denoiser_fused_multi(pipe=pipe, expected_atol=9e-3)
def test_simple_inference_with_text_denoiser_lora_unfused(self): def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe):
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3) super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3)
@unittest.skip("Not supported in Wan VACE.") @pytest.mark.skip("Not supported in Wan VACE.")
def test_simple_inference_with_text_denoiser_block_scale(self): def test_simple_inference_with_text_denoiser_block_scale(self):
pass pass
@unittest.skip("Not supported in Wan VACE.") @pytest.mark.skip("Not supported in Wan VACE.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass pass
@unittest.skip("Not supported in Wan VACE.") @pytest.mark.skip("Not supported in Wan VACE.")
def test_modify_padding_mode(self): def test_modify_padding_mode(self):
pass pass
@unittest.skip("Text encoder LoRA is not supported in Wan VACE.") @pytest.mark.skip("Text encoder LoRA is not supported in Wan VACE.")
def test_simple_inference_with_partial_text_lora(self): def test_simple_inference_with_partial_text_lora(self):
pass pass
@unittest.skip("Text encoder LoRA is not supported in Wan VACE.") @pytest.mark.skip("Text encoder LoRA is not supported in Wan VACE.")
def test_simple_inference_with_text_lora(self): def test_simple_inference_with_text_lora(self):
pass pass
@unittest.skip("Text encoder LoRA is not supported in Wan VACE.") @pytest.mark.skip("Text encoder LoRA is not supported in Wan VACE.")
def test_simple_inference_with_text_lora_and_scale(self): def test_simple_inference_with_text_lora_and_scale(self):
pass pass
@unittest.skip("Text encoder LoRA is not supported in Wan VACE.") @pytest.mark.skip("Text encoder LoRA is not supported in Wan VACE.")
def test_simple_inference_with_text_lora_fused(self): def test_simple_inference_with_text_lora_fused(self):
pass pass
@unittest.skip("Text encoder LoRA is not supported in Wan VACE.") @pytest.mark.skip("Text encoder LoRA is not supported in Wan VACE.")
def test_simple_inference_with_text_lora_save_load(self): def test_simple_inference_with_text_lora_save_load(self):
pass pass
def test_layerwise_casting_inference_denoiser(self):
super().test_layerwise_casting_inference_denoiser()
@require_peft_version_greater("0.13.2") @require_peft_version_greater("0.13.2")
def test_lora_exclude_modules_wanvace(self): def test_lora_exclude_modules_wanvace(self, base_pipe_output, tmpdirname, pipe):
exclude_module_name = "vace_blocks.0.proj_out" exclude_module_name = "vace_blocks.0.proj_out"
components, text_lora_config, denoiser_lora_config = self.get_dummy_components() _, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components).to(torch_device)
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = self.get_base_pipe_output() assert base_pipe_output.shape == self.output_shape
self.assertTrue(output_no_lora.shape == self.output_shape)
# only supported for `denoiser` now # only supported for `denoiser` now
denoiser_lora_config.target_modules = ["proj_out"] denoiser_lora_config.target_modules = ["proj_out"]
@@ -180,36 +173,30 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
) )
# The state dict shouldn't contain the modules to be excluded from LoRA. # The state dict shouldn't contain the modules to be excluded from LoRA.
state_dict_from_model = get_peft_model_state_dict(pipe.transformer, adapter_name="default") state_dict_from_model = get_peft_model_state_dict(pipe.transformer, adapter_name="default")
self.assertTrue(not any(exclude_module_name in k for k in state_dict_from_model)) assert not any(exclude_module_name in k for k in state_dict_from_model)
self.assertTrue(any("proj_out" in k for k in state_dict_from_model)) assert any("proj_out" in k for k in state_dict_from_model)
output_lora_exclude_modules = pipe(**inputs, generator=torch.manual_seed(0))[0] output_lora_exclude_modules = pipe(**inputs, generator=torch.manual_seed(0))[0]
with tempfile.TemporaryDirectory() as tmpdir: modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
lora_state_dicts = self._get_lora_state_dicts(modules_to_save) self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts)
self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts) pipe.unload_lora_weights()
pipe.unload_lora_weights()
# Check in the loaded state dict. # Check in the loaded state dict.
loaded_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) loaded_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
self.assertTrue(not any(exclude_module_name in k for k in loaded_state_dict)) assert not any(exclude_module_name in k for k in loaded_state_dict)
self.assertTrue(any("proj_out" in k for k in loaded_state_dict)) assert any("proj_out" in k for k in loaded_state_dict)
# Check in the state dict obtained after loading LoRA. # Check in the state dict obtained after loading LoRA.
pipe.load_lora_weights(tmpdir) pipe.load_lora_weights(tmpdirname)
state_dict_from_model = get_peft_model_state_dict(pipe.transformer, adapter_name="default_0") state_dict_from_model = get_peft_model_state_dict(pipe.transformer, adapter_name="default_0")
self.assertTrue(not any(exclude_module_name in k for k in state_dict_from_model)) assert not any(exclude_module_name in k for k in state_dict_from_model)
self.assertTrue(any("proj_out" in k for k in state_dict_from_model)) assert any("proj_out" in k for k in state_dict_from_model)
output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue( assert not np.allclose(base_pipe_output, output_lora_exclude_modules, atol=1e-3, rtol=1e-3), (
not np.allclose(output_no_lora, output_lora_exclude_modules, atol=1e-3, rtol=1e-3), "LoRA should change outputs."
"LoRA should change outputs.", )
) assert np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3), (
self.assertTrue( "Lora outputs should match."
np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3), )
"Lora outputs should match.",
)
def test_simple_inference_with_text_denoiser_lora_and_scale(self):
super().test_simple_inference_with_text_denoiser_lora_and_scale()
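The WanVACE hunk above also shows the last recurring change: the `with tempfile.TemporaryDirectory() as tmpdir:` blocks are dropped in favor of a `tmpdirname` fixture handed to the test. A minimal sketch of such a fixture, assuming it is provided by a shared conftest or the mixin (pytest's built-in `tmp_path` fixture would serve the same purpose):

import tempfile

import pytest


@pytest.fixture
def tmpdirname():
    # Fresh temporary directory per test, removed automatically afterwards.
    with tempfile.TemporaryDirectory() as tmpdir:
        yield tmpdir


def test_save_and_reload(tmpdirname):
    # Tests receive the path directly instead of opening their own context manager.
    assert isinstance(tmpdirname, str)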

File diff suppressed because it is too large