Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-08 13:34:27 +08:00)
Compare commits (1 commit)
fix-timeou...fix-lumina
| Author | SHA1 | Date |
|---|---|---|
| | b397cf685e | |
```diff
@@ -15,6 +15,8 @@
 import sys
 import unittest
 
+import numpy as np
+import pytest
 import torch
 from transformers import AutoTokenizer, GemmaForCausalLM
 
@@ -24,12 +26,12 @@ from diffusers import (
     Lumina2Text2ImgPipeline,
     Lumina2Transformer2DModel,
 )
-from diffusers.utils.testing_utils import floats_tensor, require_peft_backend
+from diffusers.utils.testing_utils import floats_tensor, is_torch_version, require_peft_backend, skip_mps, torch_device
 
 
 sys.path.append(".")
 
-from utils import PeftLoraLoaderMixinTests  # noqa: E402
+from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set  # noqa: E402
 
 
 @require_peft_backend
@@ -130,3 +132,41 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     @unittest.skip("Text encoder LoRA is not supported in Lumina2.")
     def test_simple_inference_with_text_lora_save_load(self):
         pass
+
+    @skip_mps
+    @pytest.mark.xfail(
+        condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"),
+        reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.",
+        strict=False,
+    )
+    def test_lora_fuse_nan(self):
+        for scheduler_cls in self.scheduler_classes:
+            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+            pipe = self.pipeline_class(**components)
+            pipe = pipe.to(torch_device)
+            pipe.set_progress_bar_config(disable=None)
+            _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+                self.assertTrue(
+                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+                )
+
+            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
+            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+            # corrupt one LoRA weight with `inf` values
+            with torch.no_grad():
+                pipe.transformer.layers[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")
+
+            # with `safe_fusing=True` we should see an Error
+            with self.assertRaises(ValueError):
+                pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
+
+            # without we should not see an error, but every image will be black
+            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
+            out = pipe(**inputs)[0]
+
+            self.assertTrue(np.isnan(out).all())
```
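The new `test_lora_fuse_nan` exercises the `safe_fusing` flag of `fuse_lora`: with `safe_fusing=True`, fusing an adapter whose weights contain non-finite values raises a `ValueError`; with `safe_fusing=False`, fusion proceeds and the NaNs propagate into every generated image. Below is a minimal user-facing sketch of the same behavior; the model and LoRA identifiers are placeholders, not part of this commit.

```python
import torch
from diffusers import DiffusionPipeline

# Placeholder identifiers for illustration only; substitute a real base
# model and LoRA checkpoint.
pipe = DiffusionPipeline.from_pretrained("base/model-id", torch_dtype=torch.float16)
pipe.load_lora_weights("user/lora-id", adapter_name="adapter-1")

# With safe_fusing=True, fuse_lora() checks the fused weights and raises a
# ValueError if any are non-finite (inf/NaN), instead of silently producing
# black images at inference time.
try:
    pipe.fuse_lora(safe_fusing=True)
except ValueError:
    print("Adapter weights are corrupted (inf/NaN); fusion was refused.")
```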