Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-06 20:44:33 +08:00)

Compare commits: ltx-098-me ... remove-uni (8 commits)

Commits (SHA1):
- 361ca4cc95
- e4a1c036f1
- 89e3563ff5
- 208af22812
- f506704717
- 8f63b4b8c8
- 178d35e2cb
- a7654b8c67
@@ -22,12 +22,11 @@ import os
 import re
 import tempfile
 import traceback
-import unittest
-import unittest.mock as mock
 import uuid
 import warnings
 from collections import defaultdict
 from typing import Dict, List, Optional, Tuple, Union
+from unittest import mock

 import numpy as np
 import pytest
@@ -210,16 +209,18 @@ def cast_maybe_tensor_dtype(maybe_tensor, current_dtype, target_dtype):
     return maybe_tensor


-class ModelUtilsTest(unittest.TestCase):
-    def tearDown(self):
-        super().tearDown()
+class TestModelUtils:
+    def teardown_method(self):
+        pass

-    def test_missing_key_loading_warning_message(self):
-        with self.assertLogs("diffusers.models.modeling_utils", level="WARNING") as logs:
-            UNet2DConditionModel.from_pretrained("hf-internal-testing/stable-diffusion-broken", subfolder="unet")
+    def test_missing_key_loading_warning_message(self, caplog):
+        import logging
+
+        caplog.set_level(logging.WARNING, logger="diffusers.models.modeling_utils")
+        UNet2DConditionModel.from_pretrained("hf-internal-testing/stable-diffusion-broken", subfolder="unet")

         # make sure that error message states what keys are missing
-        assert "conv_out.bias" in " ".join(logs.output)
+        assert "conv_out.bias" in caplog.text

     @parameterized.expand(
         [
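Aside from the diff: a minimal, self-contained sketch of the pattern the hunks above apply — a plain test class with teardown_method and pytest's caplog fixture in place of unittest's tearDown/assertLogs. The class and logger names below are illustrative, not taken from the diffusers suite.

    # Illustrative sketch only, not part of this diff.
    import logging


    class TestExample:
        def teardown_method(self):
            # runs after every test method, like unittest's tearDown()
            pass

        def test_emits_warning(self, caplog):
            caplog.set_level(logging.WARNING, logger="my_library.module")
            logging.getLogger("my_library.module").warning("missing key: conv_out.bias")
            # caplog.text concatenates all captured log records
            assert "conv_out.bias" in caplog.text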
@@ -236,7 +237,7 @@ class ModelUtilsTest(unittest.TestCase):
             kwargs["subfolder"] = subfolder
             return UNet2DConditionModel.from_pretrained(path, **kwargs)

-        with self.assertWarns(FutureWarning) as warning:
+        with pytest.warns(FutureWarning) as warning:
             if use_local:
                 with tempfile.TemporaryDirectory() as tmpdirname:
                     tmpdirname = snapshot_download(repo_id=repo_id)
@@ -244,8 +245,8 @@ class ModelUtilsTest(unittest.TestCase):
             else:
                 _ = load_model(repo_id)

-        warning_message = str(warning.warnings[0].message)
-        self.assertIn("This serialization format is now deprecated to standardize the serialization", warning_message)
+        warning_message = str(warning.list[0].message)
+        assert "This serialization format is now deprecated to standardize the serialization" in warning_message

     # Local tests are already covered down below.
     @parameterized.expand(
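For reference, pytest.warns records warnings much like assertWarns does; the context object exposes the captured warnings via its .list attribute, which the hunk above relies on. A standalone sketch with a hypothetical helper, not from the diff:

    # Illustrative sketch only.
    import warnings

    import pytest


    def deprecated_load():
        warnings.warn("This serialization format is now deprecated", FutureWarning)


    def test_deprecation_message():
        with pytest.warns(FutureWarning) as record:
            deprecated_load()
        # record.list holds warnings.WarningMessage entries in emission order
        assert "deprecated" in str(record.list[0].message)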
@@ -306,7 +307,7 @@ class ModelUtilsTest(unittest.TestCase):
             with mock.patch("requests.Session.get", return_value=error_response):
                 # Should fail with local_files_only=False (network required)
                 # We would make a network call with model_info
-                with self.assertRaises(OSError):
+                with pytest.raises(OSError):
                     FluxTransformer2DModel.from_pretrained(
                         repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=False
                     )
@@ -328,7 +329,7 @@ class ModelUtilsTest(unittest.TestCase):
             os.remove(cached_shard_file)

             # Attempting to load from cache should raise an error
-            with self.assertRaises(OSError) as context:
+            with pytest.raises(OSError) as context:
                 FluxTransformer2DModel.from_pretrained(
                     repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=True
                 )
@@ -339,8 +340,8 @@ class ModelUtilsTest(unittest.TestCase):
                 f"Expected error about missing shard, got: {error_msg}"
             )

-    @unittest.skip("Flaky behaviour on CI. Re-enable after migrating to new runners")
-    @unittest.skipIf(torch_device == "mps", reason="Test not supported for MPS.")
+    @pytest.mark.skip(reason="Flaky behaviour on CI. Re-enable after migrating to new runners")
+    @pytest.mark.skipif(torch_device == "mps", reason="Test not supported for MPS.")
     def test_one_request_upon_cached(self):
         use_safetensors = False

@@ -373,7 +374,7 @@ class ModelUtilsTest(unittest.TestCase):
         )

     def test_weight_overwrite(self):
-        with tempfile.TemporaryDirectory() as tmpdirname, self.assertRaises(ValueError) as error_context:
+        with tempfile.TemporaryDirectory() as tmpdirname, pytest.raises(ValueError) as error_context:
             UNet2DConditionModel.from_pretrained(
                 "hf-internal-testing/tiny-stable-diffusion-torch",
                 subfolder="unet",
@@ -414,9 +415,9 @@ class ModelUtilsTest(unittest.TestCase):
         for name, module in model.named_modules():
             if isinstance(module, torch.nn.Linear):
                 if name in model._keep_in_fp32_modules:
-                    self.assertTrue(module.weight.dtype == torch.float32)
+                    assert module.weight.dtype == torch.float32
                 else:
-                    self.assertTrue(module.weight.dtype == torch_dtype)
+                    assert module.weight.dtype == torch_dtype

         def get_dummy_inputs():
             batch_size = 2
@@ -466,9 +467,9 @@ class UNetTesterMixin:
         if isinstance(output, dict):
             output = output.to_tuple()[0]

-        self.assertIsNotNone(output)
+        assert output is not None
         expected_shape = inputs_dict["sample"].shape
-        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
+        assert output.shape == expected_shape, "Input and output shapes do not match"


 class ModelTesterMixin:
@@ -478,6 +479,18 @@ class ModelTesterMixin:
     model_split_percents = [0.5, 0.7, 0.9]
     uses_custom_attn_processor = False

+    def get_init_dict(self):
+        raise NotImplementedError(
+            "You need to implement `get_init_dict(self)` in the child test class. "
+            "See existing pipeline tests for reference."
+        )
+
+    def get_dummy_inputs(self, device, seed=0):
+        raise NotImplementedError(
+            "You need to implement `get_dummy_inputs(self, device, seed)` in the child test class. "
+            "See existing pipeline tests for reference."
+        )
+
     def check_device_map_is_respected(self, model, device_map):
         for param_name, param in model.named_parameters():
             # Find device in device_map
@@ -488,9 +501,9 @@ class ModelTesterMixin:

             param_device = device_map[param_name]
             if param_device in ["cpu", "disk"]:
-                self.assertEqual(param.device, torch.device("meta"))
+                assert param.device == torch.device("meta")
             else:
-                self.assertEqual(param.device, torch.device(param_device))
+                assert param.device == torch.device(param_device)

     def test_from_save_pretrained(self, expected_max_diff=5e-5):
         if self.forward_requires_fresh_args:
@@ -529,7 +542,7 @@ class ModelTesterMixin:
             new_image = new_image.to_tuple()[0]

         max_diff = (image - new_image).abs().max().item()
-        self.assertLessEqual(max_diff, expected_max_diff, "Models give different forward passes")
+        assert max_diff <= expected_max_diff, "Models give different forward passes"

     def test_getattr_is_correct(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -563,18 +576,18 @@ class ModelTesterMixin:
         assert cap_logger.out == ""

         # warning should be thrown
-        with self.assertWarns(FutureWarning):
+        with pytest.warns(FutureWarning):
             assert model.test_attribute == 5

-        with self.assertWarns(FutureWarning):
+        with pytest.warns(FutureWarning):
             assert getattr(model, "test_attribute") == 5

-        with self.assertRaises(AttributeError) as error:
+        with pytest.raises(AttributeError) as error:
             model.does_not_exist

         assert str(error.exception) == f"'{type(model).__name__}' object has no attribute 'does_not_exist'"

-    @unittest.skipIf(
+    @pytest.mark.skipif(
         torch_device != "npu" or not is_torch_npu_available(),
         reason="torch npu flash attention is only available with NPU and `torch_npu` installed",
     )
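One difference worth keeping in mind when reading the assertRaises conversions above: unittest's assertRaises context manager exposes the caught exception as .exception, whereas pytest.raises yields an ExceptionInfo object whose exception lives on .value. A standalone sketch with illustrative names, not taken from the diff:

    # Illustrative sketch only.
    import pytest


    def test_attribute_error_message():
        class Thing:
            pass

        with pytest.raises(AttributeError) as excinfo:
            Thing().does_not_exist
        # pytest exposes the raised exception object as excinfo.value
        assert "does_not_exist" in str(excinfo.value)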
@@ -621,7 +634,7 @@ class ModelTesterMixin:
         assert torch.allclose(output, output_3, atol=self.base_precision)
         assert torch.allclose(output_2, output_3, atol=self.base_precision)

-    @unittest.skipIf(
+    @pytest.mark.skipif(
         torch_device != "cuda" or not is_xformers_available(),
         reason="XFormers attention is only available with CUDA and `xformers` installed",
     )
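The skip decorators above move from unittest.skipIf to pytest.mark.skipif; pytest evaluates the condition at collection time and, when the condition is a boolean rather than a string, asks for an explicit reason= keyword, which is why every converted decorator carries one. A minimal sketch (the GPU check below is illustrative, not from the diff):

    # Illustrative sketch only.
    import shutil

    import pytest


    @pytest.mark.skip(reason="unconditionally skipped")
    def test_always_skipped():
        ...


    @pytest.mark.skipif(shutil.which("nvidia-smi") is None, reason="requires an NVIDIA GPU driver")
    def test_needs_gpu_driver():
        ...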
@@ -748,7 +761,7 @@ class ModelTesterMixin:
             new_model.set_default_attn_processor()

             # non-variant cannot be loaded
-            with self.assertRaises(OSError) as error_context:
+            with pytest.raises(OSError) as error_context:
                 self.model_class.from_pretrained(tmpdirname)

             # make sure that error message states what keys are missing
@@ -773,11 +786,11 @@ class ModelTesterMixin:
             new_image = new_image.to_tuple()[0]

         max_diff = (image - new_image).abs().max().item()
-        self.assertLessEqual(max_diff, expected_max_diff, "Models give different forward passes")
+        assert max_diff <= expected_max_diff, "Models give different forward passes"

     @is_torch_compile
     @require_torch_2
-    @unittest.skipIf(
+    @pytest.mark.skipif(
         get_python_version == (3, 12),
         reason="Torch Dynamo isn't yet supported for Python 3.12.",
     )
@@ -839,7 +852,7 @@ class ModelTesterMixin:
         out_1 = out_1[~np.isnan(out_1)]
         out_2 = out_2[~np.isnan(out_2)]
         max_diff = np.amax(np.abs(out_1 - out_2))
-        self.assertLessEqual(max_diff, expected_max_diff)
+        assert max_diff <= expected_max_diff

     def test_output(self, expected_output_shape=None):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -853,16 +866,16 @@ class ModelTesterMixin:
         if isinstance(output, dict):
             output = output.to_tuple()[0]

-        self.assertIsNotNone(output)
+        assert output is not None

         # input & output have to have the same shape
         input_tensor = inputs_dict[self.main_input_name]

         if expected_output_shape is None:
             expected_shape = input_tensor.shape
-            self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
+            assert output.shape == expected_shape, "Input and output shapes do not match"
         else:
-            self.assertEqual(output.shape, expected_output_shape, "Input and output shapes do not match")
+            assert output.shape == expected_output_shape, "Input and output shapes do not match"

     def test_model_from_pretrained(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -883,7 +896,7 @@ class ModelTesterMixin:
         for param_name in model.state_dict().keys():
             param_1 = model.state_dict()[param_name]
             param_2 = new_model.state_dict()[param_name]
-            self.assertEqual(param_1.shape, param_2.shape)
+            assert param_1.shape == param_2.shape

         with torch.no_grad():
             output_1 = model(**inputs_dict)
@@ -896,7 +909,7 @@ class ModelTesterMixin:
         if isinstance(output_2, dict):
             output_2 = output_2.to_tuple()[0]

-        self.assertEqual(output_1.shape, output_2.shape)
+        assert output_1.shape == output_2.shape

     @require_torch_accelerator_with_training
     def test_training(self):
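The bare assert statements introduced above rely on pytest's assertion rewriting, which reports the compared operands on failure, so assertEqual/assertLessEqual-style helpers are not needed. A trivial standalone sketch, not from the diff:

    # Illustrative sketch only: on failure pytest prints both tuples.
    def test_shapes_match():
        expected_shape = (2, 4, 32, 32)
        output_shape = (2, 4, 32, 32)
        assert output_shape == expected_shape, "Input and output shapes do not match"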
@@ -955,16 +968,13 @@ class ModelTesterMixin:
             elif tuple_object is None:
                 return
             else:
-                self.assertTrue(
-                    torch.allclose(
-                        set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
-                    ),
-                    msg=(
-                        "Tuple and dict output are not equal. Difference:"
-                        f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
-                        f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
-                        f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
-                    ),
-                )
+                assert torch.allclose(
+                    set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
+                ), (
+                    "Tuple and dict output are not equal. Difference:"
+                    f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
+                    f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
+                    f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
+                )

         if self.forward_requires_fresh_args:
@@ -996,15 +1006,15 @@ class ModelTesterMixin:

         # at init model should have gradient checkpointing disabled
         model = self.model_class(**init_dict)
-        self.assertFalse(model.is_gradient_checkpointing)
+        assert not model.is_gradient_checkpointing

         # check enable works
         model.enable_gradient_checkpointing()
-        self.assertTrue(model.is_gradient_checkpointing)
+        assert model.is_gradient_checkpointing

         # check disable works
         model.disable_gradient_checkpointing()
-        self.assertFalse(model.is_gradient_checkpointing)
+        assert not model.is_gradient_checkpointing

     @require_torch_accelerator_with_training
     def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_tol=5e-5, skip: set[str] = {}):
@@ -1048,7 +1058,7 @@ class ModelTesterMixin:
         loss_2.backward()

         # compare the output and parameters gradients
-        self.assertTrue((loss - loss_2).abs() < loss_tolerance)
+        assert (loss - loss_2).abs() < loss_tolerance
         named_params = dict(model.named_parameters())
         named_params_2 = dict(model_2.named_parameters())

@@ -1061,9 +1071,9 @@ class ModelTesterMixin:
             # It currently errors out the gradient checkpointing test because the gradients for attn2.to_out is None
             if param.grad is None:
                 continue
-            self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=param_grad_tol))
+            assert torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=param_grad_tol)

-    @unittest.skipIf(torch_device == "mps", "This test is not supported for MPS devices.")
+    @pytest.mark.skipif(torch_device == "mps", reason="This test is not supported for MPS devices.")
     def test_gradient_checkpointing_is_applied(
         self, expected_set=None, attention_head_dim=None, num_attention_heads=None, block_out_channels=None
     ):
@@ -1087,7 +1097,7 @@ class ModelTesterMixin:
         modules_with_gc_enabled = {}
         for submodule in model.modules():
             if hasattr(submodule, "gradient_checkpointing"):
-                self.assertTrue(submodule.gradient_checkpointing)
+                assert submodule.gradient_checkpointing
                 modules_with_gc_enabled[submodule.__class__.__name__] = True

         assert set(modules_with_gc_enabled.keys()) == expected_set
@@ -1115,7 +1125,7 @@ class ModelTesterMixin:

     @parameterized.expand([(4, 4, True), (4, 8, False), (8, 4, False)])
     @torch.no_grad()
-    @unittest.skipIf(not is_peft_available(), "Only with PEFT")
+    @pytest.mark.skipif(not is_peft_available(), reason="Only with PEFT")
     def test_save_load_lora_adapter(self, rank, lora_alpha, use_dora=False):
         from peft import LoraConfig
         from peft.utils import get_peft_model_state_dict
@@ -1139,21 +1149,21 @@ class ModelTesterMixin:
             use_dora=use_dora,
         )
         model.add_adapter(denoiser_lora_config)
-        self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
+        assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"

         torch.manual_seed(0)
         outputs_with_lora = model(**inputs_dict, return_dict=False)[0]

-        self.assertFalse(torch.allclose(output_no_lora, outputs_with_lora, atol=1e-4, rtol=1e-4))
+        assert not torch.allclose(output_no_lora, outputs_with_lora, atol=1e-4, rtol=1e-4)

         with tempfile.TemporaryDirectory() as tmpdir:
             model.save_lora_adapter(tmpdir)
-            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+            assert os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))

             state_dict_loaded = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))

             model.unload_lora()
-            self.assertFalse(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
+            assert not check_if_lora_correctly_set(model), "LoRA layers not set correctly"

             model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
             state_dict_retrieved = get_peft_model_state_dict(model, adapter_name="default_0")
@@ -1161,17 +1171,17 @@ class ModelTesterMixin:
             for k in state_dict_loaded:
                 loaded_v = state_dict_loaded[k]
                 retrieved_v = state_dict_retrieved[k].to(loaded_v.device)
-                self.assertTrue(torch.allclose(loaded_v, retrieved_v))
+                assert torch.allclose(loaded_v, retrieved_v)

-            self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
+            assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"

             torch.manual_seed(0)
             outputs_with_lora_2 = model(**inputs_dict, return_dict=False)[0]

-            self.assertFalse(torch.allclose(output_no_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4))
-            self.assertTrue(torch.allclose(outputs_with_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4))
+            assert not torch.allclose(output_no_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4)
+            assert torch.allclose(outputs_with_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4)

-    @unittest.skipIf(not is_peft_available(), "Only with PEFT")
+    @pytest.mark.skipif(not is_peft_available(), reason="Only with PEFT")
     def test_lora_wrong_adapter_name_raises_error(self):
         from peft import LoraConfig

@@ -1191,18 +1201,18 @@ class ModelTesterMixin:
             use_dora=False,
         )
         model.add_adapter(denoiser_lora_config)
-        self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
+        assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"

         with tempfile.TemporaryDirectory() as tmpdir:
             wrong_name = "foo"
-            with self.assertRaises(ValueError) as err_context:
+            with pytest.raises(ValueError) as err_context:
                 model.save_lora_adapter(tmpdir, adapter_name=wrong_name)

-            self.assertTrue(f"Adapter name {wrong_name} not found in the model." in str(err_context.exception))
+            assert f"Adapter name {wrong_name} not found in the model." in str(err_context.exception)

     @parameterized.expand([(4, 4, True), (4, 8, False), (8, 4, False)])
     @torch.no_grad()
-    @unittest.skipIf(not is_peft_available(), "Only with PEFT")
+    @pytest.mark.skipif(not is_peft_available(), reason="Only with PEFT")
     def test_lora_adapter_metadata_is_loaded_correctly(self, rank, lora_alpha, use_dora):
         from peft import LoraConfig

@@ -1223,22 +1233,22 @@ class ModelTesterMixin:
         )
         model.add_adapter(denoiser_lora_config)
         metadata = model.peft_config["default"].to_dict()
-        self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
+        assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"

         with tempfile.TemporaryDirectory() as tmpdir:
             model.save_lora_adapter(tmpdir)
             model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
-            self.assertTrue(os.path.isfile(model_file))
+            assert os.path.isfile(model_file)

             model.unload_lora()
-            self.assertFalse(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
+            assert not check_if_lora_correctly_set(model), "LoRA layers not set correctly"

             model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
             parsed_metadata = model.peft_config["default_0"].to_dict()
             check_if_dicts_are_equal(metadata, parsed_metadata)

     @torch.no_grad()
-    @unittest.skipIf(not is_peft_available(), "Only with PEFT")
+    @pytest.mark.skipif(not is_peft_available(), reason="Only with PEFT")
     def test_lora_adapter_wrong_metadata_raises_error(self):
         from peft import LoraConfig

@@ -1259,12 +1269,12 @@ class ModelTesterMixin:
             use_dora=False,
         )
         model.add_adapter(denoiser_lora_config)
-        self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
+        assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"

         with tempfile.TemporaryDirectory() as tmpdir:
             model.save_lora_adapter(tmpdir)
             model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
-            self.assertTrue(os.path.isfile(model_file))
+            assert os.path.isfile(model_file)

             # Perturb the metadata in the state dict.
             loaded_state_dict = safetensors.torch.load_file(model_file)
@@ -1278,11 +1288,11 @@ class ModelTesterMixin:
             safetensors.torch.save_file(loaded_state_dict, model_file, metadata=metadata)

             model.unload_lora()
-            self.assertFalse(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
+            assert not check_if_lora_correctly_set(model), "LoRA layers not set correctly"

-            with self.assertRaises(TypeError) as err_context:
+            with pytest.raises(TypeError) as err_context:
                 model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
-            self.assertTrue("`LoraConfig` class could not be instantiated" in str(err_context.exception))
+            assert "`LoraConfig` class could not be instantiated" in str(err_context.exception)

     @require_torch_accelerator
     def test_cpu_offload(self):
@@ -1306,13 +1316,13 @@ class ModelTesterMixin:
             max_memory = {0: max_size, "cpu": model_size * 2}
             new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
             # Making sure part of the model will actually end up offloaded
-            self.assertSetEqual(set(new_model.hf_device_map.values()), {0, "cpu"})
+            assert set(new_model.hf_device_map.values()) == {0, "cpu"}

             self.check_device_map_is_respected(new_model, new_model.hf_device_map)
             torch.manual_seed(0)
             new_output = new_model(**inputs_dict)

-            self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
+            assert torch.allclose(base_output[0], new_output[0], atol=1e-5)

     @require_torch_accelerator
     def test_disk_offload_without_safetensors(self):
@@ -1333,7 +1343,7 @@ class ModelTesterMixin:

         with tempfile.TemporaryDirectory() as tmp_dir:
             model.cpu().save_pretrained(tmp_dir, safe_serialization=False)
-            with self.assertRaises(ValueError):
+            with pytest.raises(ValueError):
                 # This errors out because it's missing an offload folder
                 new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)

@@ -1345,7 +1355,7 @@ class ModelTesterMixin:
             torch.manual_seed(0)
             new_output = new_model(**inputs_dict)

-            self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
+            assert torch.allclose(base_output[0], new_output[0], atol=1e-5)

     @require_torch_accelerator
     def test_disk_offload_with_safetensors(self):
@@ -1373,7 +1383,7 @@ class ModelTesterMixin:
             torch.manual_seed(0)
             new_output = new_model(**inputs_dict)

-            self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
+            assert torch.allclose(base_output[0], new_output[0], atol=1e-5)

     @require_torch_multi_accelerator
     def test_model_parallelism(self):
@@ -1397,14 +1407,14 @@ class ModelTesterMixin:
             max_memory = {0: max_size, 1: model_size * 2, "cpu": model_size * 2}
             new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
             # Making sure part of the model will actually end up offloaded
-            self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})
+            assert set(new_model.hf_device_map.values()) == {0, 1}

             self.check_device_map_is_respected(new_model, new_model.hf_device_map)

             torch.manual_seed(0)
             new_output = new_model(**inputs_dict)

-            self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
+            assert torch.allclose(base_output[0], new_output[0], atol=1e-5)

     @require_torch_accelerator
     def test_sharded_checkpoints(self):
@@ -1419,14 +1429,14 @@ class ModelTesterMixin:
         max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small.
         with tempfile.TemporaryDirectory() as tmp_dir:
             model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
-            self.assertTrue(os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
+            assert os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME))

             # Now check if the right number of shards exists. First, let's get the number of shards.
             # Since this number can be dependent on the model being tested, it's important that we calculate it
             # instead of hardcoding it.
             expected_num_shards = caculate_expected_num_shards(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME))
             actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
-            self.assertTrue(actual_num_shards == expected_num_shards)
+            assert actual_num_shards == expected_num_shards

             new_model = self.model_class.from_pretrained(tmp_dir).eval()
             new_model = new_model.to(torch_device)
@@ -1436,7 +1446,7 @@ class ModelTesterMixin:
                 _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
             new_output = new_model(**inputs_dict)

-            self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
+            assert torch.allclose(base_output[0], new_output[0], atol=1e-5)

     @require_torch_accelerator
     def test_sharded_checkpoints_with_variant(self):
@@ -1457,14 +1467,14 @@ class ModelTesterMixin:
             model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB", variant=variant)

             index_filename = _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
-            self.assertTrue(os.path.exists(os.path.join(tmp_dir, index_filename)))
+            assert os.path.exists(os.path.join(tmp_dir, index_filename))

             # Now check if the right number of shards exists. First, let's get the number of shards.
             # Since this number can be dependent on the model being tested, it's important that we calculate it
             # instead of hardcoding it.
             expected_num_shards = caculate_expected_num_shards(os.path.join(tmp_dir, index_filename))
             actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
-            self.assertTrue(actual_num_shards == expected_num_shards)
+            assert actual_num_shards == expected_num_shards

             new_model = self.model_class.from_pretrained(tmp_dir, variant=variant).eval()
             new_model = new_model.to(torch_device)
@@ -1474,7 +1484,7 @@ class ModelTesterMixin:
                 _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
             new_output = new_model(**inputs_dict)

-            self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
+            assert torch.allclose(base_output[0], new_output[0], atol=1e-5)

     @require_torch_accelerator
     def test_sharded_checkpoints_with_parallel_loading(self):
@@ -1489,14 +1499,14 @@ class ModelTesterMixin:
         max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small.
         with tempfile.TemporaryDirectory() as tmp_dir:
             model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
-            self.assertTrue(os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
+            assert os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME))

             # Now check if the right number of shards exists. First, let's get the number of shards.
             # Since this number can be dependent on the model being tested, it's important that we calculate it
             # instead of hardcoding it.
             expected_num_shards = caculate_expected_num_shards(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME))
             actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
-            self.assertTrue(actual_num_shards == expected_num_shards)
+            assert actual_num_shards == expected_num_shards

             # Load with parallel loading
             os.environ["HF_ENABLE_PARALLEL_LOADING"] = "yes"
@@ -1507,7 +1517,7 @@ class ModelTesterMixin:
             if "generator" in inputs_dict:
                 _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
             new_output = new_model(**inputs_dict)
-            self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
+            assert torch.allclose(base_output[0], new_output[0], atol=1e-5)
             # set to no.
             os.environ["HF_ENABLE_PARALLEL_LOADING"] = "no"

@@ -1526,14 +1536,14 @@ class ModelTesterMixin:
         max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small.
         with tempfile.TemporaryDirectory() as tmp_dir:
             model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
-            self.assertTrue(os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
+            assert os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME))

             # Now check if the right number of shards exists. First, let's get the number of shards.
             # Since this number can be dependent on the model being tested, it's important that we calculate it
             # instead of hardcoding it.
             expected_num_shards = caculate_expected_num_shards(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME))
             actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
-            self.assertTrue(actual_num_shards == expected_num_shards)
+            assert actual_num_shards == expected_num_shards

             new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto")

@@ -1541,7 +1551,7 @@ class ModelTesterMixin:
             if "generator" in inputs_dict:
                 _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
             new_output = new_model(**inputs_dict)
-            self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
+            assert torch.allclose(base_output[0], new_output[0], atol=1e-5)

     # This test is okay without a GPU because we're not running any execution. We're just serializing
     # and check if the resultant files are following an expected format.
@@ -1559,14 +1569,14 @@ class ModelTesterMixin:
                     tmp_dir, variant=variant, max_shard_size=f"{max_shard_size}KB", safe_serialization=use_safe
                 )
                 index_variant = _add_variant(SAFE_WEIGHTS_INDEX_NAME if use_safe else WEIGHTS_INDEX_NAME, variant)
-                self.assertTrue(os.path.exists(os.path.join(tmp_dir, index_variant)))
+                assert os.path.exists(os.path.join(tmp_dir, index_variant))

                 # Now check if the right number of shards exists. First, let's get the number of shards.
                 # Since this number can be dependent on the model being tested, it's important that we calculate it
                 # instead of hardcoding it.
                 expected_num_shards = caculate_expected_num_shards(os.path.join(tmp_dir, index_variant))
                 actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(extension)])
-                self.assertTrue(actual_num_shards == expected_num_shards)
+                assert actual_num_shards == expected_num_shards

                 # Check if the variant is present as a substring in the checkpoints.
                 shard_files = [
@@ -1634,9 +1644,9 @@ class ModelTesterMixin:
                 if any(re.search(pattern, name) for pattern in patterns_to_check):
                     dtype_to_check = compute_dtype
                 if getattr(submodule, "weight", None) is not None:
-                    self.assertEqual(submodule.weight.dtype, dtype_to_check)
+                    assert submodule.weight.dtype == dtype_to_check
                 if getattr(submodule, "bias", None) is not None:
-                    self.assertEqual(submodule.bias.dtype, dtype_to_check)
+                    assert submodule.bias.dtype == dtype_to_check

         def test_layerwise_casting(storage_dtype, compute_dtype):
             torch.manual_seed(0)
@@ -1651,7 +1661,7 @@ class ModelTesterMixin:

             # The precision test is not very important for fast tests. In most cases, the outputs will not be the same.
             # We just want to make sure that the layerwise casting is working as expected.
-            self.assertTrue(numpy_cosine_similarity_distance(base_slice, output) < 1.0)
+            assert numpy_cosine_similarity_distance(base_slice, output) < 1.0

         test_layerwise_casting(torch.float16, torch.float32)
         test_layerwise_casting(torch.float8_e4m3fn, torch.float32)
@@ -1692,15 +1702,15 @@ class ModelTesterMixin:
         )

         compute_capability = get_torch_cuda_device_capability() if torch_device == "cuda" else None
-        self.assertTrue(fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint < fp32_memory_footprint)
+        assert fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint < fp32_memory_footprint
         # NOTE: the following assertion would fail on our CI (running Tesla T4) due to bf16 using more memory than fp32.
         # On other devices, such as DGX (Ampere) and Audace (Ada), the test passes. So, we conditionally check it.
         if compute_capability and compute_capability >= LEAST_COMPUTE_CAPABILITY:
-            self.assertTrue(fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory)
+            assert fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory
         # On this dummy test case with a small model, sometimes fp8_e4m3_fp32 max memory usage is higher than fp32 by a few
         # bytes. This only happens for some models, so we allow a small tolerance.
         # For any real model being tested, the order would be fp8_e4m3_bf16 < fp8_e4m3_fp32 < fp32.
-        self.assertTrue(
+        assert (
             fp8_e4m3_fp32_max_memory < fp32_max_memory
             or abs(fp8_e4m3_fp32_max_memory - fp32_max_memory) < MB_TOLERANCE
         )
@@ -1716,12 +1726,10 @@ class ModelTesterMixin:

         @torch.no_grad()
         def run_forward(model):
-            self.assertTrue(
-                all(
-                    module._diffusers_hook.get_hook("group_offloading") is not None
-                    for module in model.modules()
-                    if hasattr(module, "_diffusers_hook")
-                )
-            )
+            assert all(
+                module._diffusers_hook.get_hook("group_offloading") is not None
+                for module in model.modules()
+                if hasattr(module, "_diffusers_hook")
+            )
             model.eval()
             return model(**inputs_dict)[0]
@@ -1753,10 +1761,10 @@ class ModelTesterMixin:
         )
         output_with_group_offloading4 = run_forward(model)

-        self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5))
-        self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-5))
-        self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5))
-        self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5))
+        assert torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5)
+        assert torch.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-5)
+        assert torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5)
+        assert torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5)

     @parameterized.expand([(False, "block_level"), (True, "leaf_level")])
     @require_torch_accelerator
@@ -1838,7 +1846,7 @@ class ModelTesterMixin:
                 **additional_kwargs,
             )
             has_safetensors = glob.glob(f"{tmpdir}/*.safetensors")
-            self.assertTrue(has_safetensors, "No safetensors found in the directory.")
+            assert has_safetensors, "No safetensors found in the directory."

             # For "leaf-level", there is a prefetching hook which makes this check a bit non-deterministic
             # in nature. So, skip it.
@@ -1856,7 +1864,7 @@ class ModelTesterMixin:
                 raise ValueError(f"Following files are missing: {', '.join(missing_files)}")

             output_with_group_offloading = _run_forward(model, inputs_dict)
-            self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading, atol=atol))
+            assert torch.allclose(output_without_group_offloading, output_with_group_offloading, atol=atol)

     def test_auto_model(self, expected_max_diff=5e-5):
         if self.forward_requires_fresh_args:
@@ -1895,10 +1903,8 @@ class ModelTesterMixin:
             output_auto = output_auto.to_tuple()[0]

         max_diff = (output_original - output_auto).abs().max().item()
-        self.assertLessEqual(
-            max_diff,
-            expected_max_diff,
-            f"AutoModel forward pass diff: {max_diff} exceeds threshold {expected_max_diff}",
+        assert max_diff <= expected_max_diff, (
+            f"AutoModel forward pass diff: {max_diff} exceeds threshold {expected_max_diff}"
         )

     @parameterized.expand(
@@ -1912,7 +1918,7 @@ class ModelTesterMixin:
         model = self.model_class(**init_dict)
         with tempfile.TemporaryDirectory() as tmpdir:
             model.save_pretrained(tmpdir)
-            with self.assertRaises(ValueError) as err_ctx:
+            with pytest.raises(ValueError) as err_ctx:
                 _ = self.model_class.from_pretrained(tmpdir, device_map=device_map)

             assert msg_substring in str(err_ctx.exception)
@@ -1942,7 +1948,7 @@ class ModelTesterMixin:


 @is_staging_test
-class ModelPushToHubTester(unittest.TestCase):
+class TestModelPushToHub:
     identifier = uuid.uuid4()
     repo_id = f"test-model-{identifier}"
     org_repo_id = f"valid_org/{repo_id}-org"
@@ -1962,7 +1968,7 @@ class ModelPushToHubTester(unittest.TestCase):

         new_model = UNet2DConditionModel.from_pretrained(f"{USER}/{self.repo_id}")
         for p1, p2 in zip(model.parameters(), new_model.parameters()):
-            self.assertTrue(torch.equal(p1, p2))
+            assert torch.equal(p1, p2)

         # Reset repo
         delete_repo(token=TOKEN, repo_id=self.repo_id)
@@ -1973,7 +1979,7 @@ class ModelPushToHubTester(unittest.TestCase):

         new_model = UNet2DConditionModel.from_pretrained(f"{USER}/{self.repo_id}")
         for p1, p2 in zip(model.parameters(), new_model.parameters()):
-            self.assertTrue(torch.equal(p1, p2))
+            assert torch.equal(p1, p2)

         # Reset repo
         delete_repo(self.repo_id, token=TOKEN)
@@ -1993,7 +1999,7 @@ class ModelPushToHubTester(unittest.TestCase):

         new_model = UNet2DConditionModel.from_pretrained(self.org_repo_id)
         for p1, p2 in zip(model.parameters(), new_model.parameters()):
-            self.assertTrue(torch.equal(p1, p2))
+            assert torch.equal(p1, p2)

         # Reset repo
         delete_repo(token=TOKEN, repo_id=self.org_repo_id)
@@ -2004,12 +2010,12 @@ class ModelPushToHubTester(unittest.TestCase):

         new_model = UNet2DConditionModel.from_pretrained(self.org_repo_id)
         for p1, p2 in zip(model.parameters(), new_model.parameters()):
-            self.assertTrue(torch.equal(p1, p2))
+            assert torch.equal(p1, p2)

         # Reset repo
         delete_repo(self.org_repo_id, token=TOKEN)

-    @unittest.skipIf(
+    @pytest.mark.skipif(
         not is_jinja_available(),
         reason="Model card tests cannot be performed without Jinja installed.",
     )
@@ -2296,7 +2302,7 @@ class LoraHotSwappingForModelTesterMixin:
         # check error when not passing valid adapter name
         name = "does-not-exist"
         msg = f"Trying to hotswap LoRA adapter '{name}' but there is no existing adapter by that name"
-        with self.assertRaisesRegex(ValueError, msg):
+        with pytest.raises(ValueError, match=msg):
             model.load_lora_adapter(file_name1, adapter_name=name, hotswap=True, prefix=None)

     @parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
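pytest.raises(..., match=msg) mirrors assertRaisesRegex: the pattern is searched as a regex against str() of the raised exception, which is why literal messages elsewhere in this diff are wrapped in re.escape(). A standalone sketch, not from the diff:

    # Illustrative sketch only.
    import re

    import pytest


    def test_match_literal_message():
        msg = re.escape("value must be in range [0, 1).")
        with pytest.raises(ValueError, match=msg):
            raise ValueError("value must be in range [0, 1).")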
@@ -2357,10 +2363,10 @@ class LoraHotSwappingForModelTesterMixin:
         model.add_adapter(lora_config)

         msg = re.escape("Call `enable_lora_hotswap` before loading the first adapter.")
-        with self.assertRaisesRegex(RuntimeError, msg):
+        with pytest.raises(RuntimeError, match=msg):
             model.enable_lora_hotswap(target_rank=32)

-    def test_enable_lora_hotswap_called_after_adapter_added_warning(self):
+    def test_enable_lora_hotswap_called_after_adapter_added_warning(self, caplog):
         # ensure that enable_lora_hotswap is called before loading the first adapter
         from diffusers.loaders.peft import logger

@@ -2371,9 +2377,11 @@ class LoraHotSwappingForModelTesterMixin:
         msg = (
             "It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
         )
-        with self.assertLogs(logger=logger, level="WARNING") as cm:
+        import logging
+
+        with caplog.at_level(logging.WARNING, logger=logger.name):
             model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
-        assert any(msg in log for log in cm.output)
+        assert any(msg in record.message for record in caplog.records)

     def test_enable_lora_hotswap_called_after_adapter_added_ignore(self):
         # check possibility to ignore the error/warning
@@ -2384,7 +2392,7 @@ class LoraHotSwappingForModelTesterMixin:
         with warnings.catch_warnings(record=True) as w:
             warnings.simplefilter("always") # Capture all warnings
             model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
-            self.assertEqual(len(w), 0, f"Expected no warnings, but got: {[str(warn.message) for warn in w]}")
+            assert len(w) == 0, f"Expected no warnings, but got: {[str(warn.message) for warn in w]}"

     def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self):
         # check that wrong argument value raises an error
@@ -2393,22 +2401,24 @@ class LoraHotSwappingForModelTesterMixin:
         model = self.model_class(**init_dict).to(torch_device)
         model.add_adapter(lora_config)
         msg = re.escape("check_compiles should be one of 'error', 'warn', or 'ignore', got 'wrong-argument' instead.")
-        with self.assertRaisesRegex(ValueError, msg):
+        with pytest.raises(ValueError, match=msg):
             model.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument")

-    def test_hotswap_second_adapter_targets_more_layers_raises(self):
+    def test_hotswap_second_adapter_targets_more_layers_raises(self, caplog):
         # check the error and log
         from diffusers.loaders.peft import logger

         # at the moment, PEFT requires the 2nd adapter to target the same or a subset of layers
         target_modules0 = ["to_q"]
         target_modules1 = ["to_q", "to_k"]
-        with self.assertRaises(RuntimeError): # peft raises RuntimeError
-            with self.assertLogs(logger=logger, level="ERROR") as cm:
+        with pytest.raises(RuntimeError): # peft raises RuntimeError
+            import logging
+
+            with caplog.at_level(logging.ERROR, logger=logger.name):
                 self.check_model_hotswap(
                     do_compile=True, rank0=8, rank1=8, target_modules0=target_modules0, target_modules1=target_modules1
                 )
-        assert any("Hotswapping adapter0 was unsuccessful" in log for log in cm.output)
+        assert any("Hotswapping adapter0 was unsuccessful" in record.message for record in caplog.records)

     @parameterized.expand([(11, 11), (7, 13), (13, 7)])
     @require_torch_version_greater("2.7.1")
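caplog.at_level(), used in the final hunks above, temporarily sets the level of the given logger inside the with block and restores it on exit; the captured records remain available on the fixture afterwards. A standalone sketch (the logger name is illustrative, not from the diff):

    # Illustrative sketch only.
    import logging

    logger = logging.getLogger("my_library.hotswap")


    def test_warns_about_hotswap(caplog):
        with caplog.at_level(logging.WARNING, logger=logger.name):
            logger.warning("Hotswapping adapter0 was unsuccessful")
        # caplog.messages is the list of formatted message strings
        assert any("unsuccessful" in message for message in caplog.messages)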