Mirror of https://github.com/huggingface/diffusers.git, synced 2025-12-10 14:34:55 +08:00.

Compare commits: 31 commits, from `progress-b…` to `migrate-lo…`. Relative to the first branch, the diff below drops the rank-zero-gated progress-bar and logging helpers from the model and pipeline loading utilities, and migrates the LoRA test suites from unittest-style classes to pytest classes, fixtures, and marks.
| SHA1 |
|---|
| db250958c5 |
| f956ba0db1 |
| f3593a8aa9 |
| 1b6cdea043 |
| 3fb66f23ac |
| 9c3bed1783 |
| 11b80d09b0 |
| 9201505554 |
| eece7120dd |
| 2e42205c3a |
| 757bbf7b05 |
| 4561c065aa |
| 4ae5772fef |
| 0d3da485a0 |
| 4f5e9a665e |
| 23e5559c54 |
| f8f27891c6 |
| 128535cfcd |
| bdc9537999 |
| dae161ed26 |
| c4bcf72084 |
| 1737b710a2 |
| 565d674cc4 |
| 610842af1a |
| cba82591e8 |
| 949cc1c326 |
| ec866f5de8 |
| 7b4bcce602 |
| d61bb38fb4 |
| 9e92f6bb63 |
| 6c6cade1a7 |
@@ -47,7 +47,6 @@ from ..utils import (
    is_torch_version,
    logging,
)
from ..utils.distributed_utils import is_torch_dist_rank_zero

logger = logging.get_logger(__name__)

@@ -430,12 +429,8 @@ def _load_shard_files_with_threadpool(
        low_cpu_mem_usage=low_cpu_mem_usage,
    )

    tqdm_kwargs = {"total": len(shard_files), "desc": "Loading checkpoint shards"}
    if not is_torch_dist_rank_zero():
        tqdm_kwargs["disable"] = True

    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        with logging.tqdm(**tqdm_kwargs) as pbar:
        with logging.tqdm(total=len(shard_files), desc="Loading checkpoint shards") as pbar:
            futures = [executor.submit(load_one, shard_file) for shard_file in shard_files]
            for future in as_completed(futures):
                result = future.result()
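The hunk above swaps a kwargs-dict-driven progress bar, which could be disabled per process, for a direct `logging.tqdm(...)` call. For reference, here is a minimal, self-contained sketch of the kwargs-dict pattern using plain `tqdm`; the function name and the `show_progress` flag are illustrative, not diffusers internals.

```python
from tqdm.auto import tqdm

def load_shards(shard_files, show_progress: bool = True):
    tqdm_kwargs = {"total": len(shard_files), "desc": "Loading checkpoint shards"}
    if not show_progress:
        tqdm_kwargs["disable"] = True  # the bar is never drawn on this process
    with tqdm(**tqdm_kwargs) as pbar:
        for shard_file in shard_files:
            # ... load the shard here ...
            pbar.update(1)
```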
@@ -59,8 +59,11 @@ from ..utils import (
    is_torch_version,
    logging,
)
from ..utils.distributed_utils import is_torch_dist_rank_zero
from ..utils.hub_utils import PushToHubMixin, load_or_create_model_card, populate_model_card
from ..utils.hub_utils import (
    PushToHubMixin,
    load_or_create_model_card,
    populate_model_card,
)
from ..utils.torch_utils import empty_device_cache
from ._modeling_parallel import ContextParallelConfig, ContextParallelModelPlan, ParallelConfig
from .model_loading_utils import (

@@ -1669,10 +1672,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
    else:
        shard_files = resolved_model_file
    if len(resolved_model_file) > 1:
        shard_tqdm_kwargs = {"desc": "Loading checkpoint shards"}
        if not is_torch_dist_rank_zero():
            shard_tqdm_kwargs["disable"] = True
        shard_files = logging.tqdm(resolved_model_file, **shard_tqdm_kwargs)
        shard_files = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")

    for shard_file in shard_files:
        offload_index, state_dict_index, _mismatched_keys, _error_msgs = load_fn(shard_file)
@@ -67,7 +67,6 @@ from ..utils import (
    logging,
    numpy_to_pil,
)
from ..utils.distributed_utils import is_torch_dist_rank_zero
from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card
from ..utils.torch_utils import empty_device_cache, get_device, is_compiled_module

@@ -983,11 +982,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
        # 7. Load each module in the pipeline
        current_device_map = None
        _maybe_warn_for_wrong_component_in_quant_config(init_dict, quantization_config)
        logging_tqdm_kwargs = {"desc": "Loading pipeline components..."}
        if not is_torch_dist_rank_zero():
            logging_tqdm_kwargs["disable"] = True

        for name, (library_name, class_name) in logging.tqdm(init_dict.items(), **logging_tqdm_kwargs):
        for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."):
            # 7.1 device_map shenanigans
            if final_device_map is not None:
                if isinstance(final_device_map, dict) and len(final_device_map) > 0:

@@ -1913,14 +1908,10 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
            )

        progress_bar_config = dict(self._progress_bar_config)
        if "disable" not in progress_bar_config:
            progress_bar_config["disable"] = not is_torch_dist_rank_zero()

        if iterable is not None:
            return tqdm(iterable, **progress_bar_config)
            return tqdm(iterable, **self._progress_bar_config)
        elif total is not None:
            return tqdm(total=total, **progress_bar_config)
            return tqdm(total=total, **self._progress_bar_config)
        else:
            raise ValueError("Either `total` or `iterable` has to be defined.")
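After this change `progress_bar` forwards `self._progress_bar_config` to `tqdm` unchanged, so silencing bars on selected processes is back in the caller's hands via `set_progress_bar_config`. A hedged usage sketch, where the checkpoint id is only an example:

```python
import torch.distributed as dist
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1")
is_rank_zero = not (dist.is_available() and dist.is_initialized()) or dist.get_rank() == 0
pipe.set_progress_bar_config(disable=not is_rank_zero)  # only rank 0 draws the denoising bar
```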
@@ -1,36 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


try:
    import torch
except ImportError:
    torch = None


def is_torch_dist_rank_zero() -> bool:
    if torch is None:
        return True

    dist_module = getattr(torch, "distributed", None)
    if dist_module is None or not dist_module.is_available():
        return True

    if not dist_module.is_initialized():
        return True

    try:
        return dist_module.get_rank() == 0
    except (RuntimeError, ValueError):
        return True
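For context, a helper like the one deleted above is normally used so that multi-process launches (for example `torchrun --nproc_per_node=2 train.py`) emit a message or progress bar only once. The sketch below re-declares an equivalent helper so it stands alone; on the removed branch the real one lived in `diffusers.utils.distributed_utils`.

```python
import torch

def is_rank_zero() -> bool:
    # Single-process runs, or runs without an initialized process group, count as rank zero.
    dist = getattr(torch, "distributed", None)
    if dist is None or not dist.is_available() or not dist.is_initialized():
        return True
    return dist.get_rank() == 0

if is_rank_zero():
    print("downloading checkpoint shards")  # printed by exactly one process
```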
@@ -32,8 +32,6 @@ from typing import Dict, Optional

from tqdm import auto as tqdm_lib

from .distributed_utils import is_torch_dist_rank_zero


_lock = threading.Lock()
_default_handler: Optional[logging.Handler] = None

@@ -49,23 +47,6 @@ log_levels = {
_default_log_level = logging.WARNING

_tqdm_active = True
_rank_zero_filter = None


class _RankZeroFilter(logging.Filter):
    def filter(self, record):
        # Always allow rank-zero logs, but keep debug-level messages from all ranks for troubleshooting.
        return is_torch_dist_rank_zero() or record.levelno <= logging.DEBUG


def _ensure_rank_zero_filter(logger: logging.Logger) -> None:
    global _rank_zero_filter

    if _rank_zero_filter is None:
        _rank_zero_filter = _RankZeroFilter()

    if not any(isinstance(f, _RankZeroFilter) for f in logger.filters):
        logger.addFilter(_rank_zero_filter)


def _get_default_logging_level() -> int:

@@ -109,7 +90,6 @@ def _configure_library_root_logger() -> None:
    library_root_logger.addHandler(_default_handler)
    library_root_logger.setLevel(_get_default_logging_level())
    library_root_logger.propagate = False
    _ensure_rank_zero_filter(library_root_logger)


def _reset_library_root_logger() -> None:

@@ -140,9 +120,7 @@ def get_logger(name: Optional[str] = None) -> logging.Logger:
        name = _get_library_name()

    _configure_library_root_logger()
    logger = logging.getLogger(name)
    _ensure_rank_zero_filter(logger)
    return logger
    return logging.getLogger(name)


def get_verbosity() -> int:
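The removed `_RankZeroFilter` relied on the standard `logging.Filter` hook: a filter attached to a logger can veto records before any handler sees them. A small self-contained illustration, keyed on level rather than rank so it runs anywhere:

```python
import logging

class WarningsOnly(logging.Filter):
    def filter(self, record: logging.LogRecord) -> bool:
        # Returning False drops the record, exactly how a rank-based filter would
        # silence processes other than rank zero.
        return record.levelno >= logging.WARNING

logger = logging.getLogger("example")
logger.addFilter(WarningsOnly())
logger.warning("kept by the filter")
logger.info("dropped by the filter")
```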
@@ -13,16 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest

import pytest
import torch
from transformers import AutoTokenizer, UMT5EncoderModel

from diffusers import (
    AuraFlowPipeline,
    AuraFlowTransformer2DModel,
    FlowMatchEulerDiscreteScheduler,
)
from diffusers import AuraFlowPipeline, AuraFlowTransformer2DModel, FlowMatchEulerDiscreteScheduler

from ..testing_utils import (
    floats_tensor,

@@ -40,7 +36,7 @@ from .utils import PeftLoraLoaderMixinTests  # noqa: E402


@require_peft_backend
class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class TestAuraFlowLoRA(PeftLoraLoaderMixinTests):
    pipeline_class = AuraFlowPipeline
    scheduler_cls = FlowMatchEulerDiscreteScheduler
    scheduler_kwargs = {}

@@ -103,34 +99,34 @@ class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
        return noise, input_ids, pipeline_inputs

    @unittest.skip("Not supported in AuraFlow.")
    @pytest.mark.skip("Not supported in AuraFlow.")
    def test_simple_inference_with_text_denoiser_block_scale(self):
        pass

    @unittest.skip("Not supported in AuraFlow.")
    @pytest.mark.skip("Not supported in AuraFlow.")
    def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
        pass

    @unittest.skip("Not supported in AuraFlow.")
    @pytest.mark.skip("Not supported in AuraFlow.")
    def test_modify_padding_mode(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
    @pytest.mark.skip("Text encoder LoRA is not supported in AuraFlow.")
    def test_simple_inference_with_partial_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
    @pytest.mark.skip("Text encoder LoRA is not supported in AuraFlow.")
    def test_simple_inference_with_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
    @pytest.mark.skip("Text encoder LoRA is not supported in AuraFlow.")
    def test_simple_inference_with_text_lora_and_scale(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
    @pytest.mark.skip("Text encoder LoRA is not supported in AuraFlow.")
    def test_simple_inference_with_text_lora_fused(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
    @pytest.mark.skip("Text encoder LoRA is not supported in AuraFlow.")
    def test_simple_inference_with_text_lora_save_load(self):
        pass
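The test-suite half of the compare replaces `unittest.TestCase` subclasses with plain pytest test classes, so the skip decorators change accordingly. A minimal illustration of the two styles; class and method names are made up:

```python
import unittest
import pytest

class OldStyleTests(unittest.TestCase):
    @unittest.skip("Not supported.")
    def test_feature(self):
        self.assertTrue(True)

class TestNewStyle:  # no TestCase base; pytest collects it by the Test* naming convention
    @pytest.mark.skip(reason="Not supported.")
    def test_feature(self):
        assert True
```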
@@ -13,10 +13,9 @@
# limitations under the License.

import sys
import unittest

import pytest
import torch
from parameterized import parameterized
from transformers import AutoTokenizer, T5EncoderModel

from diffusers import (

@@ -39,7 +38,7 @@ from .utils import PeftLoraLoaderMixinTests  # noqa: E402


@require_peft_backend
class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class TestCogVideoXLoRA(PeftLoraLoaderMixinTests):
    pipeline_class = CogVideoXPipeline
    scheduler_cls = CogVideoXDPMScheduler
    scheduler_kwargs = {"timestep_spacing": "trailing"}

@@ -119,54 +118,59 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
        return noise, input_ids, pipeline_inputs

    def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
        super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
    def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe):
        super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3, pipe=pipe)

    def test_simple_inference_with_text_denoiser_lora_unfused(self):
        super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
    def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe):
        super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3)

    def test_lora_scale_kwargs_match_fusion(self):
        super().test_lora_scale_kwargs_match_fusion(expected_atol=9e-3, expected_rtol=9e-3)
    def test_lora_scale_kwargs_match_fusion(self, base_pipe_output):
        super().test_lora_scale_kwargs_match_fusion(
            base_pipe_output=base_pipe_output, expected_atol=9e-3, expected_rtol=9e-3
        )

    @parameterized.expand([("block_level", True), ("leaf_level", False)])
    @pytest.mark.parametrize(
        "offload_type, use_stream",
        [("block_level", True), ("leaf_level", False)],
    )
    @require_torch_accelerator
    def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
    def test_group_offloading_inference_denoiser(self, offload_type, use_stream, tmpdirname, pipe):
        # TODO: We don't run the (leaf_level, True) test here that is enabled for other models.
        # The reason for this can be found here: https://github.com/huggingface/diffusers/pull/11804#issuecomment-3013325338
        super()._test_group_offloading_inference_denoiser(offload_type, use_stream)
        super()._test_group_offloading_inference_denoiser(offload_type, use_stream, tmpdirname, pipe)

    @unittest.skip("Not supported in CogVideoX.")
    @pytest.mark.skip("Not supported in CogVideoX.")
    def test_simple_inference_with_text_denoiser_block_scale(self):
        pass

    @unittest.skip("Not supported in CogVideoX.")
    @pytest.mark.skip("Not supported in CogVideoX.")
    def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
        pass

    @unittest.skip("Not supported in CogVideoX.")
    @pytest.mark.skip("Not supported in CogVideoX.")
    def test_modify_padding_mode(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
    @pytest.mark.skip("Text encoder LoRA is not supported in CogVideoX.")
    def test_simple_inference_with_partial_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
    @pytest.mark.skip("Text encoder LoRA is not supported in CogVideoX.")
    def test_simple_inference_with_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
    @pytest.mark.skip("Text encoder LoRA is not supported in CogVideoX.")
    def test_simple_inference_with_text_lora_and_scale(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
    @pytest.mark.skip("Text encoder LoRA is not supported in CogVideoX.")
    def test_simple_inference_with_text_lora_fused(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
    @pytest.mark.skip("Text encoder LoRA is not supported in CogVideoX.")
    def test_simple_inference_with_text_lora_save_load(self):
        pass

    @unittest.skip("Not supported in CogVideoX.")
    @pytest.mark.skip("Not supported in CogVideoX.")
    def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
        pass
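Besides the skip markers, `parameterized.expand` becomes `pytest.mark.parametrize`, and the migrated tests take extra arguments such as `pipe`, `tmpdirname`, and `base_pipe_output`, which read like pytest fixtures presumably provided by the shared `PeftLoraLoaderMixinTests` or a conftest. A self-contained sketch of the mechanism with a stand-in fixture:

```python
import pytest

@pytest.fixture
def pipe():
    # Stand-in object; the real suite presumably builds a dummy pipeline here.
    return object()

@pytest.mark.parametrize(
    "offload_type, use_stream",
    [("block_level", True), ("leaf_level", False)],
)
def test_group_offloading(offload_type, use_stream, pipe):
    # pytest injects the parametrized values and the `pipe` fixture by name.
    assert offload_type in {"block_level", "leaf_level"}
    assert isinstance(use_stream, bool)
    assert pipe is not None
```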
@@ -13,12 +13,9 @@
# limitations under the License.

import sys
import tempfile
import unittest

import numpy as np
import pytest
import torch
from parameterized import parameterized
from transformers import AutoTokenizer, GlmModel

from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler

@@ -28,7 +25,6 @@ from ..testing_utils import (
    require_peft_backend,
    require_torch_accelerator,
    skip_mps,
    torch_device,
)


@@ -47,7 +43,7 @@ class TokenizerWrapper:

@require_peft_backend
@skip_mps
class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class TestCogView4LoRA(PeftLoraLoaderMixinTests):
    pipeline_class = CogView4Pipeline
    scheduler_cls = FlowMatchEulerDiscreteScheduler
    scheduler_kwargs = {}

@@ -113,72 +109,50 @@ class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
        return noise, input_ids, pipeline_inputs

    def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
        super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
    def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe):
        super().test_simple_inference_with_text_lora_denoiser_fused_multi(pipe=pipe, expected_atol=9e-3)

    def test_simple_inference_with_text_denoiser_lora_unfused(self):
        super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
    def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe):
        super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3)

    def test_simple_inference_save_pretrained(self):
        """
        Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
        """
        components, _, _ = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

        with tempfile.TemporaryDirectory() as tmpdirname:
            pipe.save_pretrained(tmpdirname)

            pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
            pipe_from_pretrained.to(torch_device)

        images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0]

        self.assertTrue(
            np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3),
            "Loading from saved checkpoints should give same results.",
        )

    @parameterized.expand([("block_level", True), ("leaf_level", False)])
    @pytest.mark.parametrize(
        "offload_type, use_stream",
        [("block_level", True), ("leaf_level", False)],
    )
    @require_torch_accelerator
    def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
    def test_group_offloading_inference_denoiser(self, offload_type, use_stream, tmpdirname, pipe):
        # TODO: We don't run the (leaf_level, True) test here that is enabled for other models.
        # The reason for this can be found here: https://github.com/huggingface/diffusers/pull/11804#issuecomment-3013325338
        super()._test_group_offloading_inference_denoiser(offload_type, use_stream)
        super()._test_group_offloading_inference_denoiser(offload_type, use_stream, tmpdirname, pipe)

    @unittest.skip("Not supported in CogView4.")
    @pytest.mark.skip("Not supported in CogView4.")
    def test_simple_inference_with_text_denoiser_block_scale(self):
        pass

    @unittest.skip("Not supported in CogView4.")
    @pytest.mark.skip("Not supported in CogView4.")
    def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
        pass

    @unittest.skip("Not supported in CogView4.")
    @pytest.mark.skip("Not supported in CogView4.")
    def test_modify_padding_mode(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in CogView4.")
    @pytest.mark.skip("Text encoder LoRA is not supported in CogView4.")
    def test_simple_inference_with_partial_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in CogView4.")
    @pytest.mark.skip("Text encoder LoRA is not supported in CogView4.")
    def test_simple_inference_with_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in CogView4.")
    @pytest.mark.skip("Text encoder LoRA is not supported in CogView4.")
    def test_simple_inference_with_text_lora_and_scale(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in CogView4.")
    @pytest.mark.skip("Text encoder LoRA is not supported in CogView4.")
    def test_simple_inference_with_text_lora_fused(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in CogView4.")
    @pytest.mark.skip("Text encoder LoRA is not supported in CogView4.")
    def test_simple_inference_with_text_lora_save_load(self):
        pass
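The rewritten tests also stop opening `tempfile.TemporaryDirectory()` inline and instead accept a `tmpdirname` argument, which again looks like a fixture. A hedged sketch of both styles; the `tmpdirname` fixture shown here simply wraps pytest's built-in `tmp_path` and is an assumption about how the real suite provides it:

```python
import os
import tempfile
import pytest

def write_marker(dirname: str) -> str:
    path = os.path.join(dirname, "marker.txt")
    with open(path, "w") as f:
        f.write("ok")
    return path

def test_roundtrip_unittest_style():
    with tempfile.TemporaryDirectory() as tmpdirname:  # each test manages its own directory
        assert os.path.isfile(write_marker(tmpdirname))

@pytest.fixture
def tmpdirname(tmp_path):
    return str(tmp_path)

def test_roundtrip_pytest_style(tmpdirname):
    assert os.path.isfile(write_marker(tmpdirname))
```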
@@ -16,13 +16,11 @@ import copy
import gc
import os
import sys
import tempfile
import unittest

import numpy as np
import pytest
import safetensors.torch
import torch
from parameterized import parameterized
from PIL import Image
from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel

@@ -46,14 +44,12 @@ from ..testing_utils import (

if is_peft_available():
    from peft.utils import get_peft_model_state_dict

sys.path.append(".")

from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set  # noqa: E402
from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set


@require_peft_backend
class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class TestFluxLoRA(PeftLoraLoaderMixinTests):
    pipeline_class = FluxPipeline
    scheduler_cls = FlowMatchEulerDiscreteScheduler
    scheduler_kwargs = {}

@@ -115,165 +111,134 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
        return noise, input_ids, pipeline_inputs

    def test_with_alpha_in_state_dict(self):
        components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
    def test_with_alpha_in_state_dict(self, tmpdirname, pipe):
        _, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        pipe.transformer.add_adapter(denoiser_lora_config)
        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
        assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer"

        images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images

        with tempfile.TemporaryDirectory() as tmpdirname:
            denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
            self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
        denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
        self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)

            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
            pipe.unload_lora_weights()
            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
        assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
        pipe.unload_lora_weights()
        pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))

            # modify the state dict to have alpha values following
            # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA/blob/main/jon_snow.safetensors
            state_dict_with_alpha = safetensors.torch.load_file(
                os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")
            )
            alpha_dict = {}
            for k, v in state_dict_with_alpha.items():
                # only do for `transformer` and for the k projections -- should be enough to test.
                if "transformer" in k and "to_k" in k and "lora_A" in k:
                    alpha_dict[f"{k}.alpha"] = float(torch.randint(10, 100, size=()))
            state_dict_with_alpha.update(alpha_dict)
        # modify the state dict to have alpha values following
        # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA/blob/main/jon_snow.safetensors
        state_dict_with_alpha = safetensors.torch.load_file(
            os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")
        )
        alpha_dict = {}
        for k, v in state_dict_with_alpha.items():
            if "transformer" in k and "to_k" in k and ("lora_A" in k):
                alpha_dict[f"{k}.alpha"] = float(torch.randint(10, 100, size=()))
        state_dict_with_alpha.update(alpha_dict)

        images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0)).images
        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
        assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"

        pipe.unload_lora_weights()
        pipe.load_lora_weights(state_dict_with_alpha)
        images_lora_with_alpha = pipe(**inputs, generator=torch.manual_seed(0)).images

        self.assertTrue(
            np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
            "Loading from saved checkpoints should give same results.",
        assert np.allclose(images_lora, images_lora_from_pretrained, atol=0.001, rtol=0.001), (
            "Loading from saved checkpoints should give same results."
        )
        self.assertFalse(np.allclose(images_lora_with_alpha, images_lora, atol=1e-3, rtol=1e-3))
        assert not np.allclose(images_lora_with_alpha, images_lora, atol=0.001, rtol=0.001)

    def test_lora_expansion_works_for_absent_keys(self):
        components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
    def test_lora_expansion_works_for_absent_keys(self, base_pipe_output, tmpdirname, pipe):
        _, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        output_no_lora = self.get_base_pipe_output()

        # Modify the config to have a layer which won't be present in the second LoRA we will load.
        modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config)
        modified_denoiser_lora_config.target_modules.add("x_embedder")

        pipe.transformer.add_adapter(modified_denoiser_lora_config)
        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
        assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer"

        images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
        self.assertFalse(
            np.allclose(images_lora, output_no_lora, atol=1e-3, rtol=1e-3),
            "LoRA should lead to different results.",
        assert not np.allclose(images_lora, base_pipe_output, atol=0.001, rtol=0.001), (
            "LoRA should lead to different results."
        )
        denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
        self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)

        with tempfile.TemporaryDirectory() as tmpdirname:
            denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
            self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
        assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))

            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
            pipe.unload_lora_weights()
            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="one")

            # Modify the state dict to exclude "x_embedder" related LoRA params.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
            lora_state_dict_without_xembedder = {k: v for k, v in lora_state_dict.items() if "x_embedder" not in k}
        pipe.unload_lora_weights()
        pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="one")
        lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
        lora_state_dict_without_xembedder = {k: v for (k, v) in lora_state_dict.items() if "x_embedder" not in k}

        pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="two")
        pipe.set_adapters(["one", "two"])
        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
        assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer"

        images_lora_with_absent_keys = pipe(**inputs, generator=torch.manual_seed(0)).images

        self.assertFalse(
            np.allclose(images_lora, images_lora_with_absent_keys, atol=1e-3, rtol=1e-3),
            "Different LoRAs should lead to different results.",
        assert not np.allclose(images_lora, images_lora_with_absent_keys, atol=0.001, rtol=0.001), (
            "Different LoRAs should lead to different results."
        )
        self.assertFalse(
            np.allclose(output_no_lora, images_lora_with_absent_keys, atol=1e-3, rtol=1e-3),
            "LoRA should lead to different results.",
        assert not np.allclose(base_pipe_output, images_lora_with_absent_keys, atol=0.001, rtol=0.001), (
            "LoRA should lead to different results."
        )

    def test_lora_expansion_works_for_extra_keys(self):
        components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
    def test_lora_expansion_works_for_extra_keys(self, base_pipe_output, tmpdirname, pipe):
        _, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)
        output_no_lora = self.get_base_pipe_output()

        # Modify the config to have a layer which won't be present in the first LoRA we will load.
        modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config)
        modified_denoiser_lora_config.target_modules.add("x_embedder")

        pipe.transformer.add_adapter(modified_denoiser_lora_config)
        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
        assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer"

        images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
        self.assertFalse(
            np.allclose(images_lora, output_no_lora, atol=1e-3, rtol=1e-3),
            "LoRA should lead to different results.",
        assert not np.allclose(images_lora, base_pipe_output, atol=0.001, rtol=0.001), (
            "LoRA should lead to different results."
        )

        with tempfile.TemporaryDirectory() as tmpdirname:
            denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
            self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
        denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
        self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
        assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))

            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
            pipe.unload_lora_weights()
            # Modify the state dict to exclude "x_embedder" related LoRA params.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
            lora_state_dict_without_xembedder = {k: v for k, v in lora_state_dict.items() if "x_embedder" not in k}
            pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="one")

            # Load state dict with `x_embedder`.
            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="two")
        pipe.unload_lora_weights()
        lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
        lora_state_dict_without_xembedder = {k: v for (k, v) in lora_state_dict.items() if "x_embedder" not in k}
        pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="one")
        pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="two")

        pipe.set_adapters(["one", "two"])
        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
        assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer"

        images_lora_with_extra_keys = pipe(**inputs, generator=torch.manual_seed(0)).images

        self.assertFalse(
            np.allclose(images_lora, images_lora_with_extra_keys, atol=1e-3, rtol=1e-3),
            "Different LoRAs should lead to different results.",
        assert not np.allclose(images_lora, images_lora_with_extra_keys, atol=0.001, rtol=0.001), (
            "Different LoRAs should lead to different results."
        )
        self.assertFalse(
            np.allclose(output_no_lora, images_lora_with_extra_keys, atol=1e-3, rtol=1e-3),
            "LoRA should lead to different results.",
        assert not np.allclose(base_pipe_output, images_lora_with_extra_keys, atol=0.001, rtol=0.001), (
            "LoRA should lead to different results."
        )

    @unittest.skip("Not supported in Flux.")
    @pytest.mark.skip("Not supported in Flux.")
    def test_simple_inference_with_text_denoiser_block_scale(self):
        pass

    @unittest.skip("Not supported in Flux.")
    @pytest.mark.skip("Not supported in Flux.")
    def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
        pass

    @unittest.skip("Not supported in Flux.")
    @pytest.mark.skip("Not supported in Flux.")
    def test_modify_padding_mode(self):
        pass

    @unittest.skip("Not supported in Flux.")
    @pytest.mark.skip("Not supported in Flux.")
    def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
        pass
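Most of the Flux changes above replace `self.assertTrue`/`self.assertFalse` wrappers around `np.allclose` with bare `assert` statements whose failure message sits in a trailing parenthesised string, a form that pytest's assertion rewriting reports well. A tiny self-contained example of the pattern:

```python
import numpy as np

images_a = np.zeros((1, 8, 8, 3))
images_b = np.ones((1, 8, 8, 3))

assert not np.allclose(images_a, images_b, atol=0.001, rtol=0.001), (
    "LoRA should lead to different results."
)
```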
class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class TestFluxControlLoRA(PeftLoraLoaderMixinTests):
    pipeline_class = FluxControlPipeline
    scheduler_cls = FlowMatchEulerDiscreteScheduler
    scheduler_kwargs = {}

@@ -338,12 +303,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
        return noise, input_ids, pipeline_inputs

    def test_with_norm_in_state_dict(self):
        components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

    def test_with_norm_in_state_dict(self, pipe):
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        logger = logging.get_logger("diffusers.loaders.lora_pipeline")

@@ -364,39 +324,32 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
                pipe.load_lora_weights(norm_state_dict)
                lora_load_output = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
            assert (
                "The provided state dict contains normalization layers in addition to LoRA layers"
                in cap_logger.out
            )
            self.assertTrue(len(pipe.transformer._transformer_norm_layers) > 0)
            assert len(pipe.transformer._transformer_norm_layers) > 0

            pipe.unload_lora_weights()
            lora_unload_output = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(pipe.transformer._transformer_norm_layers is None)
            self.assertTrue(np.allclose(original_output, lora_unload_output, atol=1e-5, rtol=1e-5))
            self.assertFalse(
                np.allclose(original_output, lora_load_output, atol=1e-6, rtol=1e-6), f"{norm_layer} is tested"
            assert pipe.transformer._transformer_norm_layers is None
            assert np.allclose(original_output, lora_unload_output, atol=1e-05, rtol=1e-05)
            assert not np.allclose(original_output, lora_load_output, atol=1e-06, rtol=1e-06), (
                f"{norm_layer} is tested"
            )

        with CaptureLogger(logger) as cap_logger:
            for key in list(norm_state_dict.keys()):
                norm_state_dict[key.replace("norm", "norm_k_something_random")] = norm_state_dict.pop(key)
            pipe.load_lora_weights(norm_state_dict)
        assert "Unsupported keys found in state dict when trying to load normalization layers" in cap_logger.out

        self.assertTrue(
            "Unsupported keys found in state dict when trying to load normalization layers" in cap_logger.out
        )

    def test_lora_parameter_expanded_shapes(self):
    def test_lora_parameter_expanded_shapes(self, pipe):
        components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        _, _, inputs = self.get_dummy_inputs(with_generator=False)
        original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]

        original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
        logger = logging.get_logger("diffusers.loaders.lora_pipeline")
        logger.setLevel(logging.DEBUG)
@@ -405,24 +358,21 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
        transformer = FluxTransformer2DModel.from_config(
            components["transformer"].config, in_channels=num_channels_without_control
        ).to(torch_device)
        self.assertTrue(
            transformer.config.in_channels == num_channels_without_control,
            f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}",
        assert transformer.config.in_channels == num_channels_without_control, (
            f"Expected {num_channels_without_control} channels in the modified transformer but has transformer.config.in_channels={transformer.config.in_channels!r}"
        )

        original_transformer_state_dict = pipe.transformer.state_dict()
        x_embedder_weight = original_transformer_state_dict.pop("x_embedder.weight")
        incompatible_keys = transformer.load_state_dict(original_transformer_state_dict, strict=False)
        self.assertTrue(
            "x_embedder.weight" in incompatible_keys.missing_keys,
            "Could not find x_embedder.weight in the missing keys.",
        assert "x_embedder.weight" in incompatible_keys.missing_keys, (
            "Could not find x_embedder.weight in the missing keys."
        )

        transformer.x_embedder.weight.data.copy_(x_embedder_weight[..., :num_channels_without_control])
        pipe.transformer = transformer

        out_features, in_features = pipe.transformer.x_embedder.weight.shape
        rank = 4

        dummy_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False)
        dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False)
        lora_state_dict = {

@@ -431,15 +381,13 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
        }
        with CaptureLogger(logger) as cap_logger:
            pipe.load_lora_weights(lora_state_dict, "adapter-1")

        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
        assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"

        lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0]

        self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
        self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
        self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
        self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
        assert not np.allclose(original_out, lora_out, rtol=0.0001, atol=0.0001)
        assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features
        assert pipe.transformer.config.in_channels == 2 * in_features
        assert cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")

        # Testing opposite direction where the LoRA params are zero-padded.
        components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)

@@ -454,15 +402,13 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
        }
        with CaptureLogger(logger) as cap_logger:
            pipe.load_lora_weights(lora_state_dict, "adapter-1")

        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
        assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"

        lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0]

        self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
        self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
        self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
        self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out)
        assert not np.allclose(original_out, lora_out, rtol=0.0001, atol=0.0001)
        assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features
        assert pipe.transformer.config.in_channels == 2 * in_features
        assert "The following LoRA modules were zero padded to match the state dict of" in cap_logger.out

    def test_normal_lora_with_expanded_lora_raises_error(self):
        # Test the following situation. Load a regular LoRA (such as the ones trained on Flux.1-Dev). And then
@@ -494,32 +440,28 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
        }
        with CaptureLogger(logger) as cap_logger:
            pipe.load_lora_weights(lora_state_dict, "adapter-1")

        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
        self.assertTrue(pipe.get_active_adapters() == ["adapter-1"])
        self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
        self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
        self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))

        _, _, inputs = self.get_dummy_inputs(with_generator=False)
        assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
        assert pipe.get_active_adapters() == ["adapter-1"]
        assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features
        assert pipe.transformer.config.in_channels == 2 * in_features
        assert cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")
        (_, _, inputs) = self.get_dummy_inputs(with_generator=False)
        lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0]

        normal_lora_A = torch.nn.Linear(in_features, rank, bias=False)
        normal_lora_B = torch.nn.Linear(rank, out_features, bias=False)
        lora_state_dict = {
            "transformer.x_embedder.lora_A.weight": normal_lora_A.weight,
            "transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
        }

        with CaptureLogger(logger) as cap_logger:
            pipe.load_lora_weights(lora_state_dict, "adapter-2")

        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
        self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out)
        self.assertTrue(pipe.get_active_adapters() == ["adapter-2"])
        assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
        assert "The following LoRA modules were zero padded to match the state dict of" in cap_logger.out
        assert pipe.get_active_adapters() == ["adapter-2"]

        lora_output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertFalse(np.allclose(lora_output, lora_output_2, atol=1e-3, rtol=1e-3))
        assert not np.allclose(lora_output, lora_output_2, atol=0.001, rtol=0.001)

        # Test the opposite case where the first lora has the correct input features and the second lora has expanded input features.
        # This should raise a runtime error on input shapes being incompatible.

@@ -540,32 +482,24 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):

        out_features, in_features = pipe.transformer.x_embedder.weight.shape
        rank = 4

        lora_state_dict = {
            "transformer.x_embedder.lora_A.weight": normal_lora_A.weight,
            "transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
        }
        pipe.load_lora_weights(lora_state_dict, "adapter-1")

        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
        self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features)
        self.assertTrue(pipe.transformer.config.in_channels == in_features)
        assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
        assert pipe.transformer.x_embedder.weight.data.shape[1] == in_features
        assert pipe.transformer.config.in_channels == in_features

        lora_state_dict = {
            "transformer.x_embedder.lora_A.weight": shape_expander_lora_A.weight,
            "transformer.x_embedder.lora_B.weight": shape_expander_lora_B.weight,
        }

        # We should check for input shapes being incompatible here. But because above mentioned issue is
        # not a supported use case, and because of the PEFT renaming, we will currently have a shape
        # mismatch error.
        self.assertRaisesRegex(
            RuntimeError,
            "size mismatch for x_embedder.lora_A.adapter-2.weight",
            pipe.load_lora_weights,
            lora_state_dict,
            "adapter-2",
        )
        with pytest.raises(RuntimeError, match="size mismatch for x_embedder.lora_A.adapter-2.weight"):
            pipe.load_lora_weights(lora_state_dict, "adapter-2")

    def test_fuse_expanded_lora_with_regular_lora(self):
        # This test checks if it works when a lora with expanded shapes (like control loras) but

@@ -597,7 +531,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
            "transformer.x_embedder.lora_B.weight": shape_expander_lora_B.weight,
        }
        pipe.load_lora_weights(lora_state_dict, "adapter-1")
        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
        assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"

        _, _, inputs = self.get_dummy_inputs(with_generator=False)
        lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
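The expected-exception check likewise moves from `self.assertRaisesRegex(...)` to a `pytest.raises` context manager, where `match` is applied as a regular-expression search against the exception message. A self-contained example of the idiom:

```python
import pytest

def parse_positive(value: str) -> int:
    number = int(value)
    if number <= 0:
        raise ValueError(f"expected a positive integer, got {number}")
    return number

# unittest style: self.assertRaisesRegex(ValueError, "positive integer", parse_positive, "-3")
with pytest.raises(ValueError, match="positive integer"):
    parse_positive("-3")
```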
@@ -610,54 +544,44 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
        }

        pipe.load_lora_weights(lora_state_dict, "adapter-2")
        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
        assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"

        lora_output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]

        pipe.set_adapters(["adapter-1", "adapter-2"], [1.0, 1.0])
        lora_output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0]

        self.assertFalse(np.allclose(lora_output, lora_output_2, atol=1e-3, rtol=1e-3))
        self.assertFalse(np.allclose(lora_output, lora_output_3, atol=1e-3, rtol=1e-3))
        self.assertFalse(np.allclose(lora_output_2, lora_output_3, atol=1e-3, rtol=1e-3))
        assert not np.allclose(lora_output, lora_output_2, atol=0.001, rtol=0.001)
        assert not np.allclose(lora_output, lora_output_3, atol=0.001, rtol=0.001)
        assert not np.allclose(lora_output_2, lora_output_3, atol=0.001, rtol=0.001)

        pipe.fuse_lora(lora_scale=1.0, adapter_names=["adapter-1", "adapter-2"])
        lora_output_4 = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(np.allclose(lora_output_3, lora_output_4, atol=1e-3, rtol=1e-3))
        assert np.allclose(lora_output_3, lora_output_4, atol=0.001, rtol=0.001)

    def test_load_regular_lora(self):
    def test_load_regular_lora(self, base_pipe_output, pipe):
        # This test checks if a regular lora (think of one trained on Flux.1 Dev for example) can be loaded
        # into the transformer with more input channels than Flux.1 Dev, for example. Some examples of those
        # transformers include Flux Fill, Flux Control, etc.
        components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]

        out_features, in_features = pipe.transformer.x_embedder.weight.shape
        rank = 4
        in_features = in_features // 2  # to mimic the Flux.1-Dev LoRA.
        in_features = in_features // 2
        normal_lora_A = torch.nn.Linear(in_features, rank, bias=False)
        normal_lora_B = torch.nn.Linear(rank, out_features, bias=False)
        lora_state_dict = {
            "transformer.x_embedder.lora_A.weight": normal_lora_A.weight,
            "transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
        }

        logger = logging.get_logger("diffusers.loaders.lora_pipeline")
        logger.setLevel(logging.INFO)
        with CaptureLogger(logger) as cap_logger:
            pipe.load_lora_weights(lora_state_dict, "adapter-1")
        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
        assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"

        lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0]

        self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out)
        self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2)
        self.assertFalse(np.allclose(original_output, lora_output, atol=1e-3, rtol=1e-3))
        assert "The following LoRA modules were zero padded to match the state dict of" in cap_logger.out
        assert pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2
        assert not np.allclose(base_pipe_output, lora_output, atol=0.001, rtol=0.001)

    def test_lora_unload_with_parameter_expanded_shapes(self):
        components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
@@ -670,9 +594,8 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
        transformer = FluxTransformer2DModel.from_config(
            components["transformer"].config, in_channels=num_channels_without_control
        ).to(torch_device)
        self.assertTrue(
            transformer.config.in_channels == num_channels_without_control,
            f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}",
        assert transformer.config.in_channels == num_channels_without_control, (
            f"Expected {num_channels_without_control} channels in the modified transformer but has transformer.config.in_channels={transformer.config.in_channels!r}"
        )

        # This should be initialized with a Flux pipeline variant that doesn't accept `control_image`.

@@ -697,33 +620,31 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
        }
        with CaptureLogger(logger) as cap_logger:
            control_pipe.load_lora_weights(lora_state_dict, "adapter-1")
        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
        assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"

        inputs["control_image"] = control_image
        lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0]

        self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
        self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
        self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
        self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
        assert not np.allclose(original_out, lora_out, rtol=0.0001, atol=0.0001)
        assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features
        assert pipe.transformer.config.in_channels == 2 * in_features
        assert cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")

        control_pipe.unload_lora_weights(reset_to_overwritten_params=True)
        self.assertTrue(
            control_pipe.transformer.config.in_channels == num_channels_without_control,
            f"Expected {num_channels_without_control} channels in the modified transformer but has {control_pipe.transformer.config.in_channels=}",
        assert control_pipe.transformer.config.in_channels == num_channels_without_control, (
            f"Expected {num_channels_without_control} channels in the modified transformer but has control_pipe.transformer.config.in_channels={control_pipe.transformer.config.in_channels!r}"
        )

        loaded_pipe = FluxPipeline.from_pipe(control_pipe)
        self.assertTrue(
            loaded_pipe.transformer.config.in_channels == num_channels_without_control,
            f"Expected {num_channels_without_control} channels in the modified transformer but has {loaded_pipe.transformer.config.in_channels=}",
        assert loaded_pipe.transformer.config.in_channels == num_channels_without_control, (
            f"Expected {num_channels_without_control} channels in the modified transformer but has loaded_pipe.transformer.config.in_channels={loaded_pipe.transformer.config.in_channels!r}"
        )

        inputs.pop("control_image")
        unloaded_lora_out = loaded_pipe(**inputs, generator=torch.manual_seed(0))[0]

        self.assertFalse(np.allclose(unloaded_lora_out, lora_out, rtol=1e-4, atol=1e-4))
        self.assertTrue(np.allclose(unloaded_lora_out, original_out, atol=1e-4, rtol=1e-4))
        self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features)
        self.assertTrue(pipe.transformer.config.in_channels == in_features)
        assert not np.allclose(unloaded_lora_out, lora_out, rtol=0.0001, atol=0.0001)
        assert np.allclose(unloaded_lora_out, original_out, atol=0.0001, rtol=0.0001)
        assert pipe.transformer.x_embedder.weight.data.shape[1] == in_features
        assert pipe.transformer.config.in_channels == in_features

    def test_lora_unload_with_parameter_expanded_shapes_and_no_reset(self):
        components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)

@@ -731,14 +652,12 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
        logger = logging.get_logger("diffusers.loaders.lora_pipeline")
        logger.setLevel(logging.DEBUG)

        # Change the transformer config to mimic a real use case.
        num_channels_without_control = 4
        transformer = FluxTransformer2DModel.from_config(
            components["transformer"].config, in_channels=num_channels_without_control
        ).to(torch_device)
        self.assertTrue(
            transformer.config.in_channels == num_channels_without_control,
            f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}",
        assert transformer.config.in_channels == num_channels_without_control, (
            f"Expected {num_channels_without_control} channels in the modified transformer but has transformer.config.in_channels={transformer.config.in_channels!r}"
        )

        # This should be initialized with a Flux pipeline variant that doesn't accept `control_image`.

@@ -763,40 +682,38 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
        }
        with CaptureLogger(logger) as cap_logger:
            control_pipe.load_lora_weights(lora_state_dict, "adapter-1")
        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
        assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"

        inputs["control_image"] = control_image
        lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0]

        self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
        self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
        self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
        self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
        assert not np.allclose(original_out, lora_out, rtol=0.0001, atol=0.0001)
        assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features
        assert pipe.transformer.config.in_channels == 2 * in_features
        assert cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")

        control_pipe.unload_lora_weights(reset_to_overwritten_params=False)
        self.assertTrue(
            control_pipe.transformer.config.in_channels == 2 * num_channels_without_control,
            f"Expected {num_channels_without_control} channels in the modified transformer but has {control_pipe.transformer.config.in_channels=}",
        assert control_pipe.transformer.config.in_channels == 2 * num_channels_without_control, (
            f"Expected {num_channels_without_control} channels in the modified transformer but has control_pipe.transformer.config.in_channels={control_pipe.transformer.config.in_channels!r}"
        )

        no_lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0]
        assert not np.allclose(no_lora_out, lora_out, rtol=0.0001, atol=0.0001)
        assert pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2
        assert pipe.transformer.config.in_channels == in_features * 2

        self.assertFalse(np.allclose(no_lora_out, lora_out, rtol=1e-4, atol=1e-4))
        self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2)
        self.assertTrue(pipe.transformer.config.in_channels == in_features * 2)

    @unittest.skip("Not supported in Flux.")
    @pytest.mark.skip("Not supported in Flux.")
    def test_simple_inference_with_text_denoiser_block_scale(self):
pass
|
||||
|
||||
@unittest.skip("Not supported in Flux.")
|
||||
@pytest.mark.skip("Not supported in Flux.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported in Flux.")
|
||||
@pytest.mark.skip("Not supported in Flux.")
|
||||
def test_modify_padding_mode(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported in Flux.")
|
||||
@pytest.mark.skip("Not supported in Flux.")
|
||||
def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
|
||||
pass
|
||||
|
||||
@@ -806,7 +723,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@require_torch_accelerator
@require_peft_backend
@require_big_accelerator
class FluxLoRAIntegrationTests(unittest.TestCase):
class TestFluxLoRAIntegration:
"""internal note: The integration slices were obtained on audace.

torch: 2.6.0.dev20241006+cu124 with CUDA 12.5. Need the same setup for the
@@ -816,33 +733,27 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
num_inference_steps = 10
seed = 0

def setUp(self):
super().setUp()

@pytest.fixture(scope="function")
def pipeline(self):
gc.collect()
backend_empty_cache(torch_device)
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to(
torch_device
)
try:
yield pipe
finally:
del pipe
gc.collect()
backend_empty_cache(torch_device)

self.pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)

def tearDown(self):
super().tearDown()

del self.pipeline
gc.collect()
backend_empty_cache(torch_device)

def test_flux_the_last_ben(self):
self.pipeline.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors")
self.pipeline.fuse_lora()
self.pipeline.unload_lora_weights()
# Instead of calling `enable_model_cpu_offload()`, we do a accelerator placement here because the CI
# run supports it. We have about 34GB RAM in the CI runner which kills the test when run with
# `enable_model_cpu_offload()`. We repeat this for the other tests, too.
self.pipeline = self.pipeline.to(torch_device)

def test_flux_the_last_ben(self, pipeline):
pipeline.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors")
pipeline.fuse_lora()
pipeline.unload_lora_weights()
pipeline = pipeline.to(torch_device)
prompt = "jon snow eating pizza with ketchup"

out = self.pipeline(
out = pipeline(
prompt,
num_inference_steps=self.num_inference_steps,
guidance_scale=4.0,
@@ -851,71 +762,57 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
).images
out_slice = out[0, -3:, -3:, -1].flatten()
expected_slice = np.array([0.1855, 0.1855, 0.1836, 0.1855, 0.1836, 0.1875, 0.1777, 0.1758, 0.2246])

max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
assert max_diff < 0.001

assert max_diff < 1e-3

def test_flux_kohya(self):
self.pipeline.load_lora_weights("Norod78/brain-slug-flux")
self.pipeline.fuse_lora()
self.pipeline.unload_lora_weights()
self.pipeline = self.pipeline.to(torch_device)

def test_flux_kohya(self, pipeline):
pipeline.load_lora_weights("Norod78/brain-slug-flux")
pipeline.fuse_lora()
pipeline.unload_lora_weights()
pipeline = pipeline.to(torch_device)
prompt = "The cat with a brain slug earring"
out = self.pipeline(
out = pipeline(
prompt,
num_inference_steps=self.num_inference_steps,
guidance_scale=4.5,
output_type="np",
generator=torch.manual_seed(self.seed),
).images

out_slice = out[0, -3:, -3:, -1].flatten()
expected_slice = np.array([0.6367, 0.6367, 0.6328, 0.6367, 0.6328, 0.6289, 0.6367, 0.6328, 0.6484])

max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
assert max_diff < 0.001

assert max_diff < 1e-3

def test_flux_kohya_with_text_encoder(self):
self.pipeline.load_lora_weights("cocktailpeanut/optimus", weight_name="optimus.safetensors")
self.pipeline.fuse_lora()
self.pipeline.unload_lora_weights()
self.pipeline = self.pipeline.to(torch_device)

def test_flux_kohya_with_text_encoder(self, pipeline):
pipeline.load_lora_weights("cocktailpeanut/optimus", weight_name="optimus.safetensors")
pipeline.fuse_lora()
pipeline.unload_lora_weights()
pipeline = pipeline.to(torch_device)
prompt = "optimus is cleaning the house with broomstick"
out = self.pipeline(
out = pipeline(
prompt,
num_inference_steps=self.num_inference_steps,
guidance_scale=4.5,
output_type="np",
generator=torch.manual_seed(self.seed),
).images

out_slice = out[0, -3:, -3:, -1].flatten()
expected_slice = np.array([0.4023, 0.4023, 0.4023, 0.3965, 0.3984, 0.3965, 0.3926, 0.3906, 0.4219])

max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
assert max_diff < 0.001

assert max_diff < 1e-3

def test_flux_kohya_embedders_conversion(self):
def test_flux_kohya_embedders_conversion(self, pipeline):
"""Test that embedders load without throwing errors"""
self.pipeline.load_lora_weights("rockerBOO/flux-bpo-po-lora")
self.pipeline.unload_lora_weights()

assert True

def test_flux_xlabs(self):
self.pipeline.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="disney_lora.safetensors")
self.pipeline.fuse_lora()
self.pipeline.unload_lora_weights()
self.pipeline = self.pipeline.to(torch_device)
pipeline.load_lora_weights("rockerBOO/flux-bpo-po-lora")
pipeline.unload_lora_weights()

def test_flux_xlabs(self, pipeline):
pipeline.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="disney_lora.safetensors")
pipeline.fuse_lora()
pipeline.unload_lora_weights()
pipeline = pipeline.to(torch_device)
prompt = "A blue jay standing on a large basket of rainbow macarons, disney style"

out = self.pipeline(
out = pipeline(
prompt,
num_inference_steps=self.num_inference_steps,
guidance_scale=3.5,
@@ -923,23 +820,17 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
generator=torch.manual_seed(self.seed),
).images
out_slice = out[0, -3:, -3:, -1].flatten()
expected_slice = np.array([0.3965, 0.4180, 0.4434, 0.4082, 0.4375, 0.4590, 0.4141, 0.4375, 0.4980])

expected_slice = np.array([0.3965, 0.418, 0.4434, 0.4082, 0.4375, 0.459, 0.4141, 0.4375, 0.498])
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
assert max_diff < 0.001

assert max_diff < 1e-3

def test_flux_xlabs_load_lora_with_single_blocks(self):
self.pipeline.load_lora_weights(
"salinasr/test_xlabs_flux_lora_with_singleblocks", weight_name="lora.safetensors"
)
self.pipeline.fuse_lora()
self.pipeline.unload_lora_weights()
self.pipeline.enable_model_cpu_offload()

def test_flux_xlabs_load_lora_with_single_blocks(self, pipeline):
pipeline.load_lora_weights("salinasr/test_xlabs_flux_lora_with_singleblocks", weight_name="lora.safetensors")
pipeline.fuse_lora()
pipeline.unload_lora_weights()
pipeline.enable_model_cpu_offload()
prompt = "a wizard mouse playing chess"

out = self.pipeline(
out = pipeline(
prompt,
num_inference_steps=self.num_inference_steps,
guidance_scale=3.5,
@@ -951,40 +842,43 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
[0.04882812, 0.04101562, 0.04882812, 0.03710938, 0.02929688, 0.02734375, 0.0234375, 0.01757812, 0.0390625]
)
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)

assert max_diff < 1e-3
assert max_diff < 0.001


@nightly
@require_torch_accelerator
@require_peft_backend
@require_big_accelerator
class FluxControlLoRAIntegrationTests(unittest.TestCase):
class TestFluxControlLoRAIntegration:
num_inference_steps = 10
seed = 0
prompt = "A robot made of exotic candies and chocolates of different kinds."

def setUp(self):
super().setUp()

@pytest.fixture(scope="function")
def pipeline(self):
gc.collect()
backend_empty_cache(torch_device)
pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to(
torch_device
)
try:
yield pipe
finally:
del pipe
gc.collect()
backend_empty_cache(torch_device)

self.pipeline = FluxControlPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to(torch_device)

def tearDown(self):
super().tearDown()

gc.collect()
backend_empty_cache(torch_device)

@parameterized.expand(["black-forest-labs/FLUX.1-Canny-dev-lora", "black-forest-labs/FLUX.1-Depth-dev-lora"])
def test_lora(self, lora_ckpt_id):
self.pipeline.load_lora_weights(lora_ckpt_id)
self.pipeline.fuse_lora()
self.pipeline.unload_lora_weights()
@pytest.mark.parametrize(
"lora_ckpt_id",
[
"black-forest-labs/FLUX.1-Canny-dev-lora",
"black-forest-labs/FLUX.1-Depth-dev-lora",
],
)
def test_lora(self, pipeline, lora_ckpt_id):
pipeline.load_lora_weights(lora_ckpt_id)
pipeline.fuse_lora()
pipeline.unload_lora_weights()

if "Canny" in lora_ckpt_id:
control_image = load_image(
@@ -995,7 +889,7 @@ class FluxControlLoRAIntegrationTests(unittest.TestCase):
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/depth_condition_image.png"
)

image = self.pipeline(
image = pipeline(
prompt=self.prompt,
control_image=control_image,
height=1024,
@@ -1016,12 +910,18 @@ class FluxControlLoRAIntegrationTests(unittest.TestCase):

assert max_diff < 1e-3

@parameterized.expand(["black-forest-labs/FLUX.1-Canny-dev-lora", "black-forest-labs/FLUX.1-Depth-dev-lora"])
def test_lora_with_turbo(self, lora_ckpt_id):
self.pipeline.load_lora_weights(lora_ckpt_id)
self.pipeline.load_lora_weights("ByteDance/Hyper-SD", weight_name="Hyper-FLUX.1-dev-8steps-lora.safetensors")
self.pipeline.fuse_lora()
self.pipeline.unload_lora_weights()
@pytest.mark.parametrize(
"lora_ckpt_id",
[
"black-forest-labs/FLUX.1-Canny-dev-lora",
"black-forest-labs/FLUX.1-Depth-dev-lora",
],
)
def test_lora_with_turbo(self, pipeline, lora_ckpt_id):
pipeline.load_lora_weights(lora_ckpt_id)
pipeline.load_lora_weights("ByteDance/Hyper-SD", weight_name="Hyper-FLUX.1-dev-8steps-lora.safetensors")
pipeline.fuse_lora()
pipeline.unload_lora_weights()

if "Canny" in lora_ckpt_id:
control_image = load_image(

@@ -13,9 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest

import numpy as np
import pytest
import torch
from transformers import AutoProcessor, Mistral3ForConditionalGeneration

@@ -30,7 +30,7 @@ from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa


@require_peft_backend
class Flux2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class TestFlux2LoRA(PeftLoraLoaderMixinTests):
pipeline_class = Flux2Pipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {}
@@ -133,36 +133,36 @@ class Flux2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
out = pipe(**inputs)[0]

self.assertTrue(np.isnan(out).all())
assert np.isnan(out).all()

@unittest.skip("Not supported in Flux2.")
@pytest.mark.skip("Not supported in Flux2.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass

@unittest.skip("Not supported in Flux2.")
@pytest.mark.skip("Not supported in Flux2.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass

@unittest.skip("Not supported in Flux2.")
@pytest.mark.skip("Not supported in Flux2.")
def test_modify_padding_mode(self):
pass

@unittest.skip("Text encoder LoRA is not supported in Flux2.")
@pytest.mark.skip("Text encoder LoRA is not supported in Flux2.")
def test_simple_inference_with_partial_text_lora(self):
pass

@unittest.skip("Text encoder LoRA is not supported in Flux2.")
@pytest.mark.skip("Text encoder LoRA is not supported in Flux2.")
def test_simple_inference_with_text_lora(self):
pass

@unittest.skip("Text encoder LoRA is not supported in Flux2.")
@pytest.mark.skip("Text encoder LoRA is not supported in Flux2.")
def test_simple_inference_with_text_lora_and_scale(self):
pass

@unittest.skip("Text encoder LoRA is not supported in Flux2.")
@pytest.mark.skip("Text encoder LoRA is not supported in Flux2.")
def test_simple_inference_with_text_lora_fused(self):
pass

@unittest.skip("Text encoder LoRA is not supported in Flux2.")
@pytest.mark.skip("Text encoder LoRA is not supported in Flux2.")
def test_simple_inference_with_text_lora_save_load(self):
pass

@@ -14,9 +14,9 @@

import gc
import sys
import unittest

import numpy as np
import pytest
import torch
from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast

@@ -48,7 +48,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402

@require_peft_backend
@skip_mps
class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class TestHunyuanVideoLoRA(PeftLoraLoaderMixinTests):
pipeline_class = HunyuanVideoPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {}
@@ -149,46 +149,41 @@ class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):

return noise, input_ids, pipeline_inputs

def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(pipe=pipe, expected_atol=9e-3)

def test_simple_inference_with_text_denoiser_lora_unfused(self):
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe):
super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3)

# TODO(aryan): Fix the following test
@unittest.skip("This test fails with an error I haven't been able to debug yet.")
def test_simple_inference_save_pretrained(self):
pass

@unittest.skip("Not supported in HunyuanVideo.")
@pytest.mark.skip("Not supported in HunyuanVideo.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass

@unittest.skip("Not supported in HunyuanVideo.")
@pytest.mark.skip("Not supported in HunyuanVideo.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass

@unittest.skip("Not supported in HunyuanVideo.")
@pytest.mark.skip("Not supported in HunyuanVideo.")
def test_modify_padding_mode(self):
pass

@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
@pytest.mark.skip("Text encoder LoRA is not supported in HunyuanVideo.")
def test_simple_inference_with_partial_text_lora(self):
pass

@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
@pytest.mark.skip("Text encoder LoRA is not supported in HunyuanVideo.")
def test_simple_inference_with_text_lora(self):
pass

@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
@pytest.mark.skip("Text encoder LoRA is not supported in HunyuanVideo.")
def test_simple_inference_with_text_lora_and_scale(self):
pass

@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
@pytest.mark.skip("Text encoder LoRA is not supported in HunyuanVideo.")
def test_simple_inference_with_text_lora_fused(self):
pass

@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
@pytest.mark.skip("Text encoder LoRA is not supported in HunyuanVideo.")
def test_simple_inference_with_text_lora_save_load(self):
pass

@@ -197,7 +192,7 @@ class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@require_torch_accelerator
@require_peft_backend
@require_big_accelerator
class HunyuanVideoLoRAIntegrationTests(unittest.TestCase):
class TestHunyuanVideoLoRAIntegration:
"""internal note: The integration slices were obtained on DGX.

torch: 2.5.1+cu124 with CUDA 12.5. Need the same setup for the
@@ -207,9 +202,8 @@ class HunyuanVideoLoRAIntegrationTests(unittest.TestCase):
num_inference_steps = 10
seed = 0

def setUp(self):
super().setUp()

@pytest.fixture(scope="function")
def pipeline(self):
gc.collect()
backend_empty_cache(torch_device)

@@ -217,27 +211,27 @@ class HunyuanVideoLoRAIntegrationTests(unittest.TestCase):
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
self.pipeline = HunyuanVideoPipeline.from_pretrained(
model_id, transformer=transformer, torch_dtype=torch.float16
).to(torch_device)
pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16).to(
torch_device
)
try:
yield pipe
finally:
del pipe
gc.collect()
backend_empty_cache(torch_device)

def tearDown(self):
super().tearDown()

gc.collect()
backend_empty_cache(torch_device)

def test_original_format_cseti(self):
self.pipeline.load_lora_weights(
def test_original_format_cseti(self, pipeline):
pipeline.load_lora_weights(
"Cseti/HunyuanVideo-LoRA-Arcane_Jinx-v1", weight_name="csetiarcane-nfjinx-v1-6000.safetensors"
)
self.pipeline.fuse_lora()
self.pipeline.unload_lora_weights()
self.pipeline.vae.enable_tiling()
pipeline.fuse_lora()
pipeline.unload_lora_weights()
pipeline.vae.enable_tiling()

prompt = "CSETIARCANE. A cat walks on the grass, realistic"

out = self.pipeline(
out = pipeline(
prompt=prompt,
height=320,
width=512,

@@ -13,8 +13,8 @@
# limitations under the License.

import sys
import unittest

import pytest
import torch
from transformers import AutoTokenizer, T5EncoderModel

@@ -34,7 +34,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402


@require_peft_backend
class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class TestLTXVideoLoRA(PeftLoraLoaderMixinTests):
pipeline_class = LTXPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {}
@@ -108,40 +108,40 @@ class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):

return noise, input_ids, pipeline_inputs

def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(pipe=pipe, expected_atol=9e-3)

def test_simple_inference_with_text_denoiser_lora_unfused(self):
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe):
super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3)

@unittest.skip("Not supported in LTXVideo.")
@pytest.mark.skip("Not supported in LTXVideo.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass

@unittest.skip("Not supported in LTXVideo.")
@pytest.mark.skip("Not supported in LTXVideo.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass

@unittest.skip("Not supported in LTXVideo.")
@pytest.mark.skip("Not supported in LTXVideo.")
def test_modify_padding_mode(self):
pass

@unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
@pytest.mark.skip("Text encoder LoRA is not supported in LTXVideo.")
def test_simple_inference_with_partial_text_lora(self):
pass

@unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
@pytest.mark.skip("Text encoder LoRA is not supported in LTXVideo.")
def test_simple_inference_with_text_lora(self):
pass

@unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
@pytest.mark.skip("Text encoder LoRA is not supported in LTXVideo.")
def test_simple_inference_with_text_lora_and_scale(self):
pass

@unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
@pytest.mark.skip("Text encoder LoRA is not supported in LTXVideo.")
def test_simple_inference_with_text_lora_fused(self):
pass

@unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
@pytest.mark.skip("Text encoder LoRA is not supported in LTXVideo.")
def test_simple_inference_with_text_lora_save_load(self):
pass

@@ -13,7 +13,6 @@
# limitations under the License.

import sys
import unittest

import numpy as np
import pytest
@@ -36,7 +35,7 @@ from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa


@require_peft_backend
class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class TestLumina2LoRA(PeftLoraLoaderMixinTests):
pipeline_class = Lumina2Pipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {}
@@ -101,35 +100,35 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):

return noise, input_ids, pipeline_inputs

@unittest.skip("Not supported in Lumina2.")
@pytest.mark.skip("Not supported in Lumina2.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass

@unittest.skip("Not supported in Lumina2.")
@pytest.mark.skip("Not supported in Lumina2.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass

@unittest.skip("Not supported in Lumina2.")
@pytest.mark.skip("Not supported in Lumina2.")
def test_modify_padding_mode(self):
pass

@unittest.skip("Text encoder LoRA is not supported in Lumina2.")
@pytest.mark.skip("Text encoder LoRA is not supported in Lumina2.")
def test_simple_inference_with_partial_text_lora(self):
pass

@unittest.skip("Text encoder LoRA is not supported in Lumina2.")
@pytest.mark.skip("Text encoder LoRA is not supported in Lumina2.")
def test_simple_inference_with_text_lora(self):
pass

@unittest.skip("Text encoder LoRA is not supported in Lumina2.")
@pytest.mark.skip("Text encoder LoRA is not supported in Lumina2.")
def test_simple_inference_with_text_lora_and_scale(self):
pass

@unittest.skip("Text encoder LoRA is not supported in Lumina2.")
@pytest.mark.skip("Text encoder LoRA is not supported in Lumina2.")
def test_simple_inference_with_text_lora_fused(self):
pass

@unittest.skip("Text encoder LoRA is not supported in Lumina2.")
@pytest.mark.skip("Text encoder LoRA is not supported in Lumina2.")
def test_simple_inference_with_text_lora_save_load(self):
pass

@@ -139,20 +138,17 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.",
strict=False,
)
def test_lora_fuse_nan(self):
components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
def test_lora_fuse_nan(self, pipe):
_, text_lora_config, denoiser_lora_config = self.get_dummy_components()
_, _, inputs = self.get_dummy_inputs(with_generator=False)

if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"

denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser."

# corrupt one LoRA weight with `inf` values
with torch.no_grad():
@@ -166,4 +162,4 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
out = pipe(**inputs)[0]

self.assertTrue(np.isnan(out).all())
assert np.isnan(out).all()

@@ -13,8 +13,8 @@
# limitations under the License.

import sys
import unittest

import pytest
import torch
from transformers import AutoTokenizer, T5EncoderModel

@@ -34,7 +34,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402

@require_peft_backend
@skip_mps
class MochiLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class TestMochiLoRA(PeftLoraLoaderMixinTests):
pipeline_class = MochiPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {}
@@ -99,44 +99,44 @@ class MochiLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):

return noise, input_ids, pipeline_inputs

def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(pipe=pipe, expected_atol=9e-3)

def test_simple_inference_with_text_denoiser_lora_unfused(self):
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe):
super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3)

@unittest.skip("Not supported in Mochi.")
@pytest.mark.skip("Not supported in Mochi.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass

@unittest.skip("Not supported in Mochi.")
@pytest.mark.skip("Not supported in Mochi.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass

@unittest.skip("Not supported in Mochi.")
@pytest.mark.skip("Not supported in Mochi.")
def test_modify_padding_mode(self):
pass

@unittest.skip("Text encoder LoRA is not supported in Mochi.")
@pytest.mark.skip("Text encoder LoRA is not supported in Mochi.")
def test_simple_inference_with_partial_text_lora(self):
pass

@unittest.skip("Text encoder LoRA is not supported in Mochi.")
@pytest.mark.skip("Text encoder LoRA is not supported in Mochi.")
def test_simple_inference_with_text_lora(self):
pass

@unittest.skip("Text encoder LoRA is not supported in Mochi.")
@pytest.mark.skip("Text encoder LoRA is not supported in Mochi.")
def test_simple_inference_with_text_lora_and_scale(self):
pass

@unittest.skip("Text encoder LoRA is not supported in Mochi.")
@pytest.mark.skip("Text encoder LoRA is not supported in Mochi.")
def test_simple_inference_with_text_lora_fused(self):
pass

@unittest.skip("Text encoder LoRA is not supported in Mochi.")
@pytest.mark.skip("Text encoder LoRA is not supported in Mochi.")
def test_simple_inference_with_text_lora_save_load(self):
pass

@unittest.skip("Not supported in CogVideoX.")
@pytest.mark.skip("Not supported in CogVideoX.")
def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
pass

@@ -13,8 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest

import pytest
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer

@@ -34,7 +34,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402


@require_peft_backend
class QwenImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class TestQwenImageLoRA(PeftLoraLoaderMixinTests):
pipeline_class = QwenImagePipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {}
@@ -96,34 +96,34 @@ class QwenImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):

return noise, input_ids, pipeline_inputs

@unittest.skip("Not supported in Qwen Image.")
@pytest.mark.skip("Not supported in Qwen Image.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass

@unittest.skip("Not supported in Qwen Image.")
@pytest.mark.skip("Not supported in Qwen Image.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass

@unittest.skip("Not supported in Qwen Image.")
@pytest.mark.skip("Not supported in Qwen Image.")
def test_modify_padding_mode(self):
pass

@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
@pytest.mark.skip("Text encoder LoRA is not supported in Qwen Image.")
def test_simple_inference_with_partial_text_lora(self):
pass

@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
@pytest.mark.skip("Text encoder LoRA is not supported in Qwen Image.")
def test_simple_inference_with_text_lora(self):
pass

@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
@pytest.mark.skip("Text encoder LoRA is not supported in Qwen Image.")
def test_simple_inference_with_text_lora_and_scale(self):
pass

@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
@pytest.mark.skip("Text encoder LoRA is not supported in Qwen Image.")
def test_simple_inference_with_text_lora_fused(self):
pass

@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
@pytest.mark.skip("Text encoder LoRA is not supported in Qwen Image.")
def test_simple_inference_with_text_lora_save_load(self):
pass

@@ -13,8 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest

import pytest
import torch
from transformers import Gemma2Model, GemmaTokenizer

@@ -29,7 +29,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402


@require_peft_backend
class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class TestSanaLoRA(PeftLoraLoaderMixinTests):
pipeline_class = SanaPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {"shift": 7.0}
@@ -105,38 +105,38 @@ class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):

return noise, input_ids, pipeline_inputs

@unittest.skip("Not supported in SANA.")
@pytest.mark.skip("Not supported in SANA.")
def test_modify_padding_mode(self):
pass

@unittest.skip("Not supported in SANA.")
@pytest.mark.skip("Not supported in SANA.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass

@unittest.skip("Not supported in SANA.")
@pytest.mark.skip("Not supported in SANA.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass

@unittest.skip("Text encoder LoRA is not supported in SANA.")
@pytest.mark.skip("Text encoder LoRA is not supported in SANA.")
def test_simple_inference_with_partial_text_lora(self):
pass

@unittest.skip("Text encoder LoRA is not supported in SANA.")
@pytest.mark.skip("Text encoder LoRA is not supported in SANA.")
def test_simple_inference_with_text_lora(self):
pass

@unittest.skip("Text encoder LoRA is not supported in SANA.")
@pytest.mark.skip("Text encoder LoRA is not supported in SANA.")
def test_simple_inference_with_text_lora_and_scale(self):
pass

@unittest.skip("Text encoder LoRA is not supported in SANA.")
@pytest.mark.skip("Text encoder LoRA is not supported in SANA.")
def test_simple_inference_with_text_lora_fused(self):
pass

@unittest.skip("Text encoder LoRA is not supported in SANA.")
@pytest.mark.skip("Text encoder LoRA is not supported in SANA.")
def test_simple_inference_with_text_lora_save_load(self):
pass

@unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
@pytest.mark.skipif(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
def test_layerwise_casting_inference_denoiser(self):
return super().test_layerwise_casting_inference_denoiser()

@@ -14,9 +14,9 @@
# limitations under the License.
import gc
import sys
import unittest

import numpy as np
import pytest
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
@@ -55,7 +55,7 @@ if is_accelerate_available():
from accelerate.utils import release_memory


class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
class TestStableDiffusionLoRA(PeftLoraLoaderMixinTests):
pipeline_class = StableDiffusionPipeline
scheduler_cls = DDIMScheduler
scheduler_kwargs = {
@@ -91,16 +91,6 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
def output_shape(self):
return (1, 64, 64, 3)

def setUp(self):
super().setUp()
gc.collect()
backend_empty_cache(torch_device)

def tearDown(self):
super().tearDown()
gc.collect()
backend_empty_cache(torch_device)

# Keeping this test here makes sense because it doesn't look any integration
# (value assertions on logits).
@slow
@@ -114,15 +104,8 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
pipe.load_lora_weights(lora_id, adapter_name="adapter-2")
pipe = pipe.to(torch_device)

self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder),
"Lora not correctly set in text encoder",
)

self.assertTrue(
check_if_lora_correctly_set(pipe.unet),
"Lora not correctly set in unet",
)
assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
assert check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in unet"

# We will offload the first adapter in CPU and check if the offloading
# has been performed correctly
@@ -130,35 +113,35 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):

for name, module in pipe.unet.named_modules():
if "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
self.assertTrue(module.weight.device == torch.device("cpu"))
assert module.weight.device == torch.device("cpu")
elif "adapter-2" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
self.assertTrue(module.weight.device != torch.device("cpu"))
assert module.weight.device != torch.device("cpu")

for name, module in pipe.text_encoder.named_modules():
if "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
self.assertTrue(module.weight.device == torch.device("cpu"))
assert module.weight.device == torch.device("cpu")
elif "adapter-2" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
self.assertTrue(module.weight.device != torch.device("cpu"))
assert module.weight.device != torch.device("cpu")

pipe.set_lora_device(["adapter-1"], 0)

for n, m in pipe.unet.named_modules():
if "adapter-1" in n and not isinstance(m, (nn.Dropout, nn.Identity)):
self.assertTrue(m.weight.device != torch.device("cpu"))
assert m.weight.device != torch.device("cpu")

for n, m in pipe.text_encoder.named_modules():
if "adapter-1" in n and not isinstance(m, (nn.Dropout, nn.Identity)):
self.assertTrue(m.weight.device != torch.device("cpu"))
assert m.weight.device != torch.device("cpu")

pipe.set_lora_device(["adapter-1", "adapter-2"], torch_device)

for n, m in pipe.unet.named_modules():
if ("adapter-1" in n or "adapter-2" in n) and not isinstance(m, (nn.Dropout, nn.Identity)):
self.assertTrue(m.weight.device != torch.device("cpu"))
assert m.weight.device != torch.device("cpu")

for n, m in pipe.text_encoder.named_modules():
if ("adapter-1" in n or "adapter-2" in n) and not isinstance(m, (nn.Dropout, nn.Identity)):
self.assertTrue(m.weight.device != torch.device("cpu"))
assert m.weight.device != torch.device("cpu")

@slow
@require_torch_accelerator
@@ -181,15 +164,9 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
pipe.unet.add_adapter(unet_lora_config, "adapter-1")
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")

self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder),
"Lora not correctly set in text encoder",
)
assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"

self.assertTrue(
check_if_lora_correctly_set(pipe.unet),
"Lora not correctly set in unet",
)
assert check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in unet"

for name, param in pipe.unet.named_parameters():
if "lora_" in name:
@@ -225,17 +202,14 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
pipe.unet.add_adapter(config1, adapter_name="adapter-1")
pipe = pipe.to(torch_device)

self.assertTrue(
check_if_lora_correctly_set(pipe.unet),
"Lora not correctly set in unet",
)
assert check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in unet"

# sanity check that the adapters don't target the same layers, otherwise the test passes even without the fix
modules_adapter_0 = {n for n, _ in pipe.unet.named_modules() if n.endswith(".adapter-0")}
modules_adapter_1 = {n for n, _ in pipe.unet.named_modules() if n.endswith(".adapter-1")}
self.assertNotEqual(modules_adapter_0, modules_adapter_1)
self.assertTrue(modules_adapter_0 - modules_adapter_1)
self.assertTrue(modules_adapter_1 - modules_adapter_0)
assert modules_adapter_0 != modules_adapter_1
assert modules_adapter_0 - modules_adapter_1
assert modules_adapter_1 - modules_adapter_0

# setting both separately works
pipe.set_lora_device(["adapter-0"], "cpu")
@@ -243,32 +217,30 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):

for name, module in pipe.unet.named_modules():
if "adapter-0" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
self.assertTrue(module.weight.device == torch.device("cpu"))
assert module.weight.device == torch.device("cpu")
elif "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
self.assertTrue(module.weight.device == torch.device("cpu"))
assert module.weight.device == torch.device("cpu")

# setting both at once also works
pipe.set_lora_device(["adapter-0", "adapter-1"], torch_device)

for name, module in pipe.unet.named_modules():
if "adapter-0" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
self.assertTrue(module.weight.device != torch.device("cpu"))
assert module.weight.device != torch.device("cpu")
elif "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
self.assertTrue(module.weight.device != torch.device("cpu"))
assert module.weight.device != torch.device("cpu")


@slow
@nightly
@require_torch_accelerator
@require_peft_backend
class LoraIntegrationTests(unittest.TestCase):
def setUp(self):
super().setUp()
class TestSDLoraIntegration:
@pytest.fixture(autouse=True)
def _gc_and_cache_cleanup(self, torch_device):
gc.collect()
backend_empty_cache(torch_device)

def tearDown(self):
super().tearDown()
yield
gc.collect()
backend_empty_cache(torch_device)

@@ -280,10 +252,7 @@ class LoraIntegrationTests(unittest.TestCase):
pipe.load_lora_weights(lora_id)
pipe = pipe.to(torch_device)

self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder),
"Lora not correctly set in text encoder",
)
assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"

prompt = "a red sks dog"

@@ -312,10 +281,7 @@ class LoraIntegrationTests(unittest.TestCase):
pipe.load_lora_weights(lora_id)
pipe = pipe.to(torch_device)

self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder),
"Lora not correctly set in text encoder",
)
assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"

prompt = "a red sks dog"

@@ -587,8 +553,8 @@ class LoraIntegrationTests(unittest.TestCase):
).images
unloaded_lora_images = unloaded_lora_images[0, -3:, -3:, -1].flatten()

self.assertFalse(np.allclose(initial_images, lora_images))
self.assertTrue(np.allclose(initial_images, unloaded_lora_images, atol=1e-3))
assert not np.allclose(initial_images, lora_images)
assert np.allclose(initial_images, unloaded_lora_images, atol=1e-3)

release_memory(pipe)

@@ -625,8 +591,8 @@ class LoraIntegrationTests(unittest.TestCase):
).images
unloaded_lora_images = unloaded_lora_images[0, -3:, -3:, -1].flatten()

self.assertFalse(np.allclose(initial_images, lora_images))
self.assertTrue(np.allclose(initial_images, unloaded_lora_images, atol=1e-3))
assert not np.allclose(initial_images, lora_images)
assert np.allclose(initial_images, unloaded_lora_images, atol=1e-3)

# make sure we can load a LoRA again after unloading and they don't have
# any undesired effects.
@@ -637,7 +603,7 @@ class LoraIntegrationTests(unittest.TestCase):
).images
lora_images_again = lora_images_again[0, -3:, -3:, -1].flatten()

self.assertTrue(np.allclose(lora_images, lora_images_again, atol=1e-3))
assert np.allclose(lora_images, lora_images_again, atol=1e-3)
release_memory(pipe)

def test_not_empty_state_dict(self):
@@ -651,7 +617,7 @@ class LoraIntegrationTests(unittest.TestCase):
lcm_lora = load_file(cached_file)

pipe.load_lora_weights(lcm_lora, adapter_name="lcm")
self.assertTrue(lcm_lora != {})
assert lcm_lora != {}
release_memory(pipe)

def test_load_unload_load_state_dict(self):
@@ -666,11 +632,11 @@ class LoraIntegrationTests(unittest.TestCase):
previous_state_dict = lcm_lora.copy()

pipe.load_lora_weights(lcm_lora, adapter_name="lcm")
self.assertDictEqual(lcm_lora, previous_state_dict)
assert lcm_lora == previous_state_dict

pipe.unload_lora_weights()
pipe.load_lora_weights(lcm_lora, adapter_name="lcm")
self.assertDictEqual(lcm_lora, previous_state_dict)
assert lcm_lora == previous_state_dict

release_memory(pipe)


@@ -14,9 +14,9 @@
# limitations under the License.
import gc
import sys
import unittest

import numpy as np
import pytest
import torch
from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel

@@ -51,7 +51,7 @@ if is_accelerate_available():


@require_peft_backend
class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class TestSD3LoRA(PeftLoraLoaderMixinTests):
pipeline_class = StableDiffusion3Pipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {}
@@ -113,19 +113,19 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
lora_filename = "lora_peft_format.safetensors"
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)

@unittest.skip("Not supported in SD3.")
@pytest.mark.skip("Not supported in SD3.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass

@unittest.skip("Not supported in SD3.")
@pytest.mark.skip("Not supported in SD3.")
def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
pass

@unittest.skip("Not supported in SD3.")
@pytest.mark.skip("Not supported in SD3.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass

@unittest.skip("Not supported in SD3.")
@pytest.mark.skip("Not supported in SD3.")
def test_modify_padding_mode(self):
pass

@@ -138,17 +138,15 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@require_torch_accelerator
@require_peft_backend
@require_big_accelerator
class SD3LoraIntegrationTests(unittest.TestCase):
class TestSD3LoraIntegration:
pipeline_class = StableDiffusion3Img2ImgPipeline
repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"

def setUp(self):
super().setUp()
@pytest.fixture(autouse=True)
def _gc_and_cache_cleanup(self, torch_device):
gc.collect()
backend_empty_cache(torch_device)

def tearDown(self):
super().tearDown()
yield
gc.collect()
backend_empty_cache(torch_device)


@@ -17,9 +17,9 @@ import gc
|
||||
import importlib
|
||||
import sys
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from packaging import version
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
@@ -59,7 +59,7 @@ if is_accelerate_available():
|
||||
from accelerate.utils import release_memory
|
||||
|
||||
|
||||
class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
|
||||
class TestStableDiffusionXLLoRA(PeftLoraLoaderMixinTests):
|
||||
has_two_text_encoders = True
|
||||
pipeline_class = StableDiffusionXLPipeline
|
||||
scheduler_cls = EulerDiscreteScheduler
|
||||
@@ -104,21 +104,11 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
    def output_shape(self):
        return (1, 64, 64, 3)

    def setUp(self):
        super().setUp()
        gc.collect()
        backend_empty_cache(torch_device)

    def tearDown(self):
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)

    @is_flaky
    def test_multiple_wrong_adapter_name_raises_error(self):
        super().test_multiple_wrong_adapter_name_raises_error()

    def test_simple_inference_with_text_denoiser_lora_unfused(self):
    def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe):
        if torch.cuda.is_available():
            expected_atol = 9e-2
            expected_rtol = 9e-2
@@ -127,10 +117,10 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
            expected_rtol = 1e-3

        super().test_simple_inference_with_text_denoiser_lora_unfused(
            expected_atol=expected_atol, expected_rtol=expected_rtol
            pipe=pipe, expected_atol=expected_atol, expected_rtol=expected_rtol
        )

    def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
    def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe):
        if torch.cuda.is_available():
            expected_atol = 9e-2
            expected_rtol = 9e-2
@@ -139,10 +129,10 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
            expected_rtol = 1e-3

        super().test_simple_inference_with_text_lora_denoiser_fused_multi(
            expected_atol=expected_atol, expected_rtol=expected_rtol
            pipe=pipe, expected_atol=expected_atol, expected_rtol=expected_rtol
        )

    def test_lora_scale_kwargs_match_fusion(self):
    def test_lora_scale_kwargs_match_fusion(self, base_pipe_output):
        if torch.cuda.is_available():
            expected_atol = 9e-2
            expected_rtol = 9e-2
@@ -150,21 +140,21 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
            expected_atol = 1e-3
            expected_rtol = 1e-3

        super().test_lora_scale_kwargs_match_fusion(expected_atol=expected_atol, expected_rtol=expected_rtol)
        super().test_lora_scale_kwargs_match_fusion(
            base_pipe_output=base_pipe_output, expected_atol=expected_atol, expected_rtol=expected_rtol
        )

@slow
@nightly
@require_torch_accelerator
@require_peft_backend
class LoraSDXLIntegrationTests(unittest.TestCase):
    def setUp(self):
        super().setUp()
class TestLoraSDXLIntegration:
    @pytest.fixture(autouse=True)
    def _gc_and_cache_cleanup(self, torch_device):
        gc.collect()
        backend_empty_cache(torch_device)

    def tearDown(self):
        super().tearDown()
        yield
        gc.collect()
        backend_empty_cache(torch_device)

@@ -383,7 +373,7 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
        end_time = time.time()
        elapsed_time_fusion = end_time - start_time

        self.assertTrue(elapsed_time_fusion < elapsed_time_non_fusion)
        assert elapsed_time_fusion < elapsed_time_non_fusion

        release_memory(pipe)

@@ -439,14 +429,14 @@ class LoraSDXLIntegrationTests(unittest.TestCase):

        for key, value in text_encoder_1_sd.items():
            key = remap_key(key, fused_te_state_dict)
            self.assertTrue(torch.allclose(fused_te_state_dict[key], value))
            assert torch.allclose(fused_te_state_dict[key], value)

        for key, value in text_encoder_2_sd.items():
            key = remap_key(key, fused_te_2_state_dict)
            self.assertTrue(torch.allclose(fused_te_2_state_dict[key], value))
            assert torch.allclose(fused_te_2_state_dict[key], value)

        for key, value in unet_state_dict.items():
            self.assertTrue(torch.allclose(unet_state_dict[key], value))
            assert torch.allclose(unet_state_dict[key], value)

        pipe.fuse_lora()
        pipe.unload_lora_weights()
@@ -589,7 +579,7 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
        pipe.load_lora_weights(lora_id, weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
        pipe = pipe.to(torch_device)

        self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
        assert check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet"

        prompt = "toy_face of a hacker with a hoodie"

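Editor's note (not part of the diff): the SDXL hunks above replace TestCase assertions with plain asserts, which pytest rewrites to report the compared values on failure. A minimal sketch of that pattern; check_fused_state_dict is a hypothetical helper used only for illustration.

import torch


def check_fused_state_dict(fused_sd, reference_sd):
    for key, value in reference_sd.items():
        # unittest style: self.assertTrue(torch.allclose(fused_sd[key], value))
        # pytest style: a bare assert with an optional message.
        assert torch.allclose(fused_sd[key], value), f"mismatch for key {key}"


def test_check_fused_state_dict_roundtrip():
    sd = {"w": torch.ones(2)}
    check_fused_state_dict(sd, {"w": torch.ones(2)})
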
@@ -13,8 +13,8 @@
# limitations under the License.

import sys
import unittest

import pytest
import torch
from transformers import AutoTokenizer, T5EncoderModel

@@ -39,7 +39,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402

@require_peft_backend
@skip_mps
class WanLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class TestWanLoRA(PeftLoraLoaderMixinTests):
    pipeline_class = WanPipeline
    scheduler_cls = FlowMatchEulerDiscreteScheduler
    scheduler_kwargs = {}
@@ -104,40 +104,40 @@ class WanLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):

        return noise, input_ids, pipeline_inputs

    def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
        super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
    def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe):
        super().test_simple_inference_with_text_lora_denoiser_fused_multi(pipe=pipe, expected_atol=9e-3)

    def test_simple_inference_with_text_denoiser_lora_unfused(self):
        super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
    def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe):
        super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3)

    @unittest.skip("Not supported in Wan.")
    @pytest.mark.skip("Not supported in Wan.")
    def test_simple_inference_with_text_denoiser_block_scale(self):
        pass

    @unittest.skip("Not supported in Wan.")
    @pytest.mark.skip("Not supported in Wan.")
    def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
        pass

    @unittest.skip("Not supported in Wan.")
    @pytest.mark.skip("Not supported in Wan.")
    def test_modify_padding_mode(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Wan.")
    @pytest.mark.skip("Text encoder LoRA is not supported in Wan.")
    def test_simple_inference_with_partial_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Wan.")
    @pytest.mark.skip("Text encoder LoRA is not supported in Wan.")
    def test_simple_inference_with_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Wan.")
    @pytest.mark.skip("Text encoder LoRA is not supported in Wan.")
    def test_simple_inference_with_text_lora_and_scale(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Wan.")
    @pytest.mark.skip("Text encoder LoRA is not supported in Wan.")
    def test_simple_inference_with_text_lora_fused(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Wan.")
    @pytest.mark.skip("Text encoder LoRA is not supported in Wan.")
    def test_simple_inference_with_text_lora_save_load(self):
        pass

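Editor's note (not part of the diff): the Wan hunks above switch unittest skips to pytest marks and let tests receive the pipeline as a pipe fixture argument. The sketch below only illustrates those two patterns; the pipe fixture body is a stand-in, not the real implementation from the shared PeftLoraLoaderMixinTests.

import pytest


class TestWanLoRASketch:
    @pytest.fixture
    def pipe(self):
        # Stand-in: the real fixture would build the dummy WanPipeline from
        # the mixin's get_dummy_components().
        return object()

    @pytest.mark.skip(reason="Text encoder LoRA is not supported in Wan.")
    def test_simple_inference_with_text_lora(self):
        pass

    def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe):
        # The fixture-provided pipeline arrives as a plain test argument.
        assert pipe is not None
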
@@ -14,10 +14,9 @@

import os
import sys
import tempfile
import unittest

import numpy as np
import pytest
import safetensors.torch
import torch
from PIL import Image
@@ -32,7 +31,6 @@ from ..testing_utils import (
    require_peft_backend,
    require_peft_version_greater,
    skip_mps,
    torch_device,
)


@@ -47,7 +45,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
@skip_mps
@is_flaky(max_attempts=10, description="very flaky class")
class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class TestWanVACELoRA(PeftLoraLoaderMixinTests):
    pipeline_class = WanVACEPipeline
    scheduler_cls = FlowMatchEulerDiscreteScheduler
    scheduler_kwargs = {}
@@ -121,56 +119,51 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):

        return noise, input_ids, pipeline_inputs

    def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
        super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
    def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe):
        super().test_simple_inference_with_text_lora_denoiser_fused_multi(pipe=pipe, expected_atol=9e-3)

    def test_simple_inference_with_text_denoiser_lora_unfused(self):
        super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
    def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe):
        super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3)

    @unittest.skip("Not supported in Wan VACE.")
    @pytest.mark.skip("Not supported in Wan VACE.")
    def test_simple_inference_with_text_denoiser_block_scale(self):
        pass

    @unittest.skip("Not supported in Wan VACE.")
    @pytest.mark.skip("Not supported in Wan VACE.")
    def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
        pass

    @unittest.skip("Not supported in Wan VACE.")
    @pytest.mark.skip("Not supported in Wan VACE.")
    def test_modify_padding_mode(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
    @pytest.mark.skip("Text encoder LoRA is not supported in Wan VACE.")
    def test_simple_inference_with_partial_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
    @pytest.mark.skip("Text encoder LoRA is not supported in Wan VACE.")
    def test_simple_inference_with_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
    @pytest.mark.skip("Text encoder LoRA is not supported in Wan VACE.")
    def test_simple_inference_with_text_lora_and_scale(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
    @pytest.mark.skip("Text encoder LoRA is not supported in Wan VACE.")
    def test_simple_inference_with_text_lora_fused(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
    @pytest.mark.skip("Text encoder LoRA is not supported in Wan VACE.")
    def test_simple_inference_with_text_lora_save_load(self):
        pass

    def test_layerwise_casting_inference_denoiser(self):
        super().test_layerwise_casting_inference_denoiser()

    @require_peft_version_greater("0.13.2")
    def test_lora_exclude_modules_wanvace(self):
    def test_lora_exclude_modules_wanvace(self, base_pipe_output, tmpdirname, pipe):
        exclude_module_name = "vace_blocks.0.proj_out"
        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components).to(torch_device)
        _, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        output_no_lora = self.get_base_pipe_output()
        self.assertTrue(output_no_lora.shape == self.output_shape)
        assert base_pipe_output.shape == self.output_shape

        # only supported for `denoiser` now
        denoiser_lora_config.target_modules = ["proj_out"]
@@ -180,36 +173,30 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
        )
        # The state dict shouldn't contain the modules to be excluded from LoRA.
        state_dict_from_model = get_peft_model_state_dict(pipe.transformer, adapter_name="default")
        self.assertTrue(not any(exclude_module_name in k for k in state_dict_from_model))
        self.assertTrue(any("proj_out" in k for k in state_dict_from_model))
        assert not any(exclude_module_name in k for k in state_dict_from_model)
        assert any("proj_out" in k for k in state_dict_from_model)
        output_lora_exclude_modules = pipe(**inputs, generator=torch.manual_seed(0))[0]

        with tempfile.TemporaryDirectory() as tmpdir:
            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
            self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts)
            pipe.unload_lora_weights()
        modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
        lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
        self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts)
        pipe.unload_lora_weights()

            # Check in the loaded state dict.
            loaded_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            self.assertTrue(not any(exclude_module_name in k for k in loaded_state_dict))
            self.assertTrue(any("proj_out" in k for k in loaded_state_dict))
        # Check in the loaded state dict.
        loaded_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
        assert not any(exclude_module_name in k for k in loaded_state_dict)
        assert any("proj_out" in k for k in loaded_state_dict)

            # Check in the state dict obtained after loading LoRA.
            pipe.load_lora_weights(tmpdir)
            state_dict_from_model = get_peft_model_state_dict(pipe.transformer, adapter_name="default_0")
            self.assertTrue(not any(exclude_module_name in k for k in state_dict_from_model))
            self.assertTrue(any("proj_out" in k for k in state_dict_from_model))
        # Check in the state dict obtained after loading LoRA.
        pipe.load_lora_weights(tmpdirname)
        state_dict_from_model = get_peft_model_state_dict(pipe.transformer, adapter_name="default_0")
        assert not any(exclude_module_name in k for k in state_dict_from_model)
        assert any("proj_out" in k for k in state_dict_from_model)

        output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(
            not np.allclose(output_no_lora, output_lora_exclude_modules, atol=1e-3, rtol=1e-3),
            "LoRA should change outputs.",
        )
        self.assertTrue(
            np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3),
            "Lora outputs should match.",
        )

    def test_simple_inference_with_text_denoiser_lora_and_scale(self):
        super().test_simple_inference_with_text_denoiser_lora_and_scale()
        output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
        assert not np.allclose(base_pipe_output, output_lora_exclude_modules, atol=1e-3, rtol=1e-3), (
            "LoRA should change outputs."
        )
        assert np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3), (
            "Lora outputs should match."
        )

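Editor's note (not part of the diff): test_lora_exclude_modules_wanvace above also swaps the inline tempfile.TemporaryDirectory() for a tmpdirname fixture argument. A minimal sketch of that shape, using pytest's built-in tmp_path as a stand-in for however the suite actually defines tmpdirname.

import pytest


@pytest.fixture
def tmpdirname(tmp_path):
    # Stand-in: expose the temporary directory as a plain string path, matching
    # save_lora_weights(save_directory=tmpdirname, ...) in the migrated test.
    return str(tmp_path)


def test_tmpdirname_is_usable(tmpdirname):
    assert tmpdirname, "Expected a usable temporary directory path."
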
tests/lora/utils.py (1658 lines changed): diff suppressed because it is too large.