Compare commits

...

1 Commits

Author SHA1 Message Date
sayakpaul
b397cf685e fix: lumina2 lora fuse_nan test 2025-02-26 20:14:02 +05:30

View File

@@ -15,6 +15,8 @@
import sys
import unittest
import numpy as np
import pytest
import torch
from transformers import AutoTokenizer, GemmaForCausalLM
@@ -24,12 +26,12 @@ from diffusers import (
Lumina2Text2ImgPipeline,
Lumina2Transformer2DModel,
)
from diffusers.utils.testing_utils import floats_tensor, require_peft_backend
from diffusers.utils.testing_utils import floats_tensor, is_torch_version, require_peft_backend, skip_mps, torch_device
sys.path.append(".")
from utils import PeftLoraLoaderMixinTests # noqa: E402
from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402
@require_peft_backend
@@ -130,3 +132,41 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@unittest.skip("Text encoder LoRA is not supported in Lumina2.")
def test_simple_inference_with_text_lora_save_load(self):
pass
@skip_mps
@pytest.mark.xfail(
condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"),
reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.",
strict=False,
)
def test_lora_fuse_nan(self):
for scheduler_cls in self.scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
# corrupt one LoRA weight with `inf` values
with torch.no_grad():
pipe.transformer.layers[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")
# with `safe_fusing=True` we should see an Error
with self.assertRaises(ValueError):
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
# without we should not see an error, but every image will be black
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
out = pipe(**inputs)[0]
self.assertTrue(np.isnan(out).all())