Compare commits

...

8 Commits

Author   SHA1         Message                                               Date
DN6      361ca4cc95   Merge branch 'main' into remove-unittest-modelmixin   2025-09-24 09:16:37 +05:30
DN6      e4a1c036f1   update                                                2025-09-24 09:14:40 +05:30
DN6      89e3563ff5   update                                                2025-09-24 07:53:25 +05:30
DN6      208af22812   Merge branch 'main' into automodel-custom-model       2025-09-24 07:49:48 +05:30
DN6      f506704717   Merge branch 'main' into automodel-custom-model       2025-09-24 07:49:19 +05:30
DN6      8f63b4b8c8   update                                                2025-09-23 16:37:48 +05:30
DN6      178d35e2cb   Merge branch 'main' into automodel-custom-model       2025-09-19 21:56:47 +05:30
DN6      a7654b8c67   update                                                2025-08-29 12:11:17 +05:30

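The diff below migrates diffusers' model tests from unittest.TestCase helpers to plain pytest. As a quick reference, here is a minimal, self-contained sketch (illustrative only, not taken from the diff; class, test, and logger names are hypothetical) of the idioms the migration relies on: bare asserts, pytest.raises / pytest.warns, pytest.mark.skipif, and the caplog fixture.

import logging
import warnings

import pytest


class TestMigrationSketch:
    def test_assertions_and_exceptions(self):
        assert 1 + 1 == 2  # replaces self.assertEqual / self.assertTrue
        with pytest.raises(ValueError) as err:  # replaces self.assertRaises
            int("not-a-number")
        # ExceptionInfo exposes the raised exception as .value, not .exception
        assert "invalid literal" in str(err.value)

    def test_warnings_and_logs(self, caplog):
        with pytest.warns(DeprecationWarning) as record:  # replaces self.assertWarns
            warnings.warn("old API", DeprecationWarning)
        assert "old API" in str(record.list[0].message)  # .list replaces .warnings

        # caplog replaces self.assertLogs for capturing log output
        caplog.set_level(logging.WARNING, logger="some.logger")
        logging.getLogger("some.logger").warning("heads up")
        assert "heads up" in caplog.text

    @pytest.mark.skipif(True, reason="replaces unittest.skipIf")
    def test_skipped(self):
        pass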

@@ -22,12 +22,11 @@ import os
import re
import tempfile
import traceback
import unittest
import unittest.mock as mock
import uuid
import warnings
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union
from unittest import mock
import numpy as np
import pytest
@@ -210,16 +209,18 @@ def cast_maybe_tensor_dtype(maybe_tensor, current_dtype, target_dtype):
return maybe_tensor
class ModelUtilsTest(unittest.TestCase):
def tearDown(self):
super().tearDown()
class TestModelUtils:
def teardown_method(self):
pass
def test_missing_key_loading_warning_message(self):
with self.assertLogs("diffusers.models.modeling_utils", level="WARNING") as logs:
UNet2DConditionModel.from_pretrained("hf-internal-testing/stable-diffusion-broken", subfolder="unet")
def test_missing_key_loading_warning_message(self, caplog):
import logging
caplog.set_level(logging.WARNING, logger="diffusers.models.modeling_utils")
UNet2DConditionModel.from_pretrained("hf-internal-testing/stable-diffusion-broken", subfolder="unet")
# make sure that error message states what keys are missing
assert "conv_out.bias" in " ".join(logs.output)
assert "conv_out.bias" in caplog.text
@parameterized.expand(
[
@@ -236,7 +237,7 @@ class ModelUtilsTest(unittest.TestCase):
kwargs["subfolder"] = subfolder
return UNet2DConditionModel.from_pretrained(path, **kwargs)
with self.assertWarns(FutureWarning) as warning:
with pytest.warns(FutureWarning) as warning:
if use_local:
with tempfile.TemporaryDirectory() as tmpdirname:
tmpdirname = snapshot_download(repo_id=repo_id)
@@ -244,8 +245,8 @@ class ModelUtilsTest(unittest.TestCase):
else:
_ = load_model(repo_id)
warning_message = str(warning.warnings[0].message)
self.assertIn("This serialization format is now deprecated to standardize the serialization", warning_message)
warning_message = str(warning.list[0].message)
assert "This serialization format is now deprecated to standardize the serialization" in warning_message
# Local tests are already covered down below.
@parameterized.expand(
@@ -306,7 +307,7 @@ class ModelUtilsTest(unittest.TestCase):
with mock.patch("requests.Session.get", return_value=error_response):
# Should fail with local_files_only=False (network required)
# We would make a network call with model_info
with self.assertRaises(OSError):
with pytest.raises(OSError):
FluxTransformer2DModel.from_pretrained(
repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=False
)
@@ -328,7 +329,7 @@ class ModelUtilsTest(unittest.TestCase):
os.remove(cached_shard_file)
# Attempting to load from cache should raise an error
with self.assertRaises(OSError) as context:
with pytest.raises(OSError) as context:
FluxTransformer2DModel.from_pretrained(
repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=True
)
@@ -339,8 +340,8 @@ class ModelUtilsTest(unittest.TestCase):
f"Expected error about missing shard, got: {error_msg}"
)
@unittest.skip("Flaky behaviour on CI. Re-enable after migrating to new runners")
@unittest.skipIf(torch_device == "mps", reason="Test not supported for MPS.")
@pytest.mark.skip(reason="Flaky behaviour on CI. Re-enable after migrating to new runners")
@pytest.mark.skipif(torch_device == "mps", reason="Test not supported for MPS.")
def test_one_request_upon_cached(self):
use_safetensors = False
@@ -373,7 +374,7 @@ class ModelUtilsTest(unittest.TestCase):
)
def test_weight_overwrite(self):
with tempfile.TemporaryDirectory() as tmpdirname, self.assertRaises(ValueError) as error_context:
with tempfile.TemporaryDirectory() as tmpdirname, pytest.raises(ValueError) as error_context:
UNet2DConditionModel.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch",
subfolder="unet",
@@ -414,9 +415,9 @@ class ModelUtilsTest(unittest.TestCase):
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
if name in model._keep_in_fp32_modules:
self.assertTrue(module.weight.dtype == torch.float32)
assert module.weight.dtype == torch.float32
else:
self.assertTrue(module.weight.dtype == torch_dtype)
assert module.weight.dtype == torch_dtype
def get_dummy_inputs():
batch_size = 2
@@ -466,9 +467,9 @@ class UNetTesterMixin:
if isinstance(output, dict):
output = output.to_tuple()[0]
self.assertIsNotNone(output)
assert output is not None
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
assert output.shape == expected_shape, "Input and output shapes do not match"
class ModelTesterMixin:
@@ -478,6 +479,18 @@ class ModelTesterMixin:
model_split_percents = [0.5, 0.7, 0.9]
uses_custom_attn_processor = False
def get_init_dict(self):
raise NotImplementedError(
"You need to implement `get_init_dict(self)` in the child test class. "
"See existing pipeline tests for reference."
)
def get_dummy_inputs(self, device, seed=0):
raise NotImplementedError(
"You need to implement `get_dummy_inputs(self, device, seed)` in the child test class. "
"See existing pipeline tests for reference."
)
def check_device_map_is_respected(self, model, device_map):
for param_name, param in model.named_parameters():
# Find device in device_map
@@ -488,9 +501,9 @@ class ModelTesterMixin:
param_device = device_map[param_name]
if param_device in ["cpu", "disk"]:
self.assertEqual(param.device, torch.device("meta"))
assert param.device == torch.device("meta")
else:
self.assertEqual(param.device, torch.device(param_device))
assert param.device == torch.device(param_device)
def test_from_save_pretrained(self, expected_max_diff=5e-5):
if self.forward_requires_fresh_args:
@@ -529,7 +542,7 @@ class ModelTesterMixin:
new_image = new_image.to_tuple()[0]
max_diff = (image - new_image).abs().max().item()
self.assertLessEqual(max_diff, expected_max_diff, "Models give different forward passes")
assert max_diff <= expected_max_diff, "Models give different forward passes"
def test_getattr_is_correct(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -563,18 +576,18 @@ class ModelTesterMixin:
assert cap_logger.out == ""
# warning should be thrown
with self.assertWarns(FutureWarning):
with pytest.warns(FutureWarning):
assert model.test_attribute == 5
with self.assertWarns(FutureWarning):
with pytest.warns(FutureWarning):
assert getattr(model, "test_attribute") == 5
with self.assertRaises(AttributeError) as error:
with pytest.raises(AttributeError) as error:
model.does_not_exist
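# pytest's ExceptionInfo exposes the raised exception as `.value` (unittest's context manager used `.exception`)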
assert str(error.value) == f"'{type(model).__name__}' object has no attribute 'does_not_exist'"
@unittest.skipIf(
@pytest.mark.skipif(
torch_device != "npu" or not is_torch_npu_available(),
reason="torch npu flash attention is only available with NPU and `torch_npu` installed",
)
@@ -621,7 +634,7 @@ class ModelTesterMixin:
assert torch.allclose(output, output_3, atol=self.base_precision)
assert torch.allclose(output_2, output_3, atol=self.base_precision)
@unittest.skipIf(
@pytest.mark.skipif(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
@@ -748,7 +761,7 @@ class ModelTesterMixin:
new_model.set_default_attn_processor()
# non-variant cannot be loaded
with self.assertRaises(OSError) as error_context:
with pytest.raises(OSError) as error_context:
self.model_class.from_pretrained(tmpdirname)
# make sure that error message states what keys are missing
@@ -773,11 +786,11 @@ class ModelTesterMixin:
new_image = new_image.to_tuple()[0]
max_diff = (image - new_image).abs().max().item()
self.assertLessEqual(max_diff, expected_max_diff, "Models give different forward passes")
assert max_diff <= expected_max_diff, "Models give different forward passes"
@is_torch_compile
@require_torch_2
@unittest.skipIf(
@pytest.mark.skipif(
get_python_version() == (3, 12),
reason="Torch Dynamo isn't yet supported for Python 3.12.",
)
@@ -839,7 +852,7 @@ class ModelTesterMixin:
out_1 = out_1[~np.isnan(out_1)]
out_2 = out_2[~np.isnan(out_2)]
max_diff = np.amax(np.abs(out_1 - out_2))
self.assertLessEqual(max_diff, expected_max_diff)
assert max_diff <= expected_max_diff
def test_output(self, expected_output_shape=None):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -853,16 +866,16 @@ class ModelTesterMixin:
if isinstance(output, dict):
output = output.to_tuple()[0]
self.assertIsNotNone(output)
assert output is not None
# input & output have to have the same shape
input_tensor = inputs_dict[self.main_input_name]
if expected_output_shape is None:
expected_shape = input_tensor.shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
assert output.shape == expected_shape, "Input and output shapes do not match"
else:
self.assertEqual(output.shape, expected_output_shape, "Input and output shapes do not match")
assert output.shape == expected_output_shape, "Input and output shapes do not match"
def test_model_from_pretrained(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -883,7 +896,7 @@ class ModelTesterMixin:
for param_name in model.state_dict().keys():
param_1 = model.state_dict()[param_name]
param_2 = new_model.state_dict()[param_name]
self.assertEqual(param_1.shape, param_2.shape)
assert param_1.shape == param_2.shape
with torch.no_grad():
output_1 = model(**inputs_dict)
@@ -896,7 +909,7 @@ class ModelTesterMixin:
if isinstance(output_2, dict):
output_2 = output_2.to_tuple()[0]
self.assertEqual(output_1.shape, output_2.shape)
assert output_1.shape == output_2.shape
@require_torch_accelerator_with_training
def test_training(self):
@@ -955,16 +968,13 @@ class ModelTesterMixin:
elif tuple_object is None:
return
else:
self.assertTrue(
torch.allclose(
set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
),
msg=(
"Tuple and dict output are not equal. Difference:"
f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
),
assert torch.allclose(
set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
), (
"Tuple and dict output are not equal. Difference:"
f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
)
if self.forward_requires_fresh_args:
@@ -996,15 +1006,15 @@ class ModelTesterMixin:
# at init model should have gradient checkpointing disabled
model = self.model_class(**init_dict)
self.assertFalse(model.is_gradient_checkpointing)
assert not model.is_gradient_checkpointing
# check enable works
model.enable_gradient_checkpointing()
self.assertTrue(model.is_gradient_checkpointing)
assert model.is_gradient_checkpointing
# check disable works
model.disable_gradient_checkpointing()
self.assertFalse(model.is_gradient_checkpointing)
assert not model.is_gradient_checkpointing
@require_torch_accelerator_with_training
def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_tol=5e-5, skip: set[str] = {}):
@@ -1048,7 +1058,7 @@ class ModelTesterMixin:
loss_2.backward()
# compare the output and parameters gradients
self.assertTrue((loss - loss_2).abs() < loss_tolerance)
assert (loss - loss_2).abs() < loss_tolerance
named_params = dict(model.named_parameters())
named_params_2 = dict(model_2.named_parameters())
@@ -1061,9 +1071,9 @@ class ModelTesterMixin:
# It currently errors out the gradient checkpointing test because the gradients for attn2.to_out are None
if param.grad is None:
continue
self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=param_grad_tol))
assert torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=param_grad_tol)
@unittest.skipIf(torch_device == "mps", "This test is not supported for MPS devices.")
@pytest.mark.skipif(torch_device == "mps", reason="This test is not supported for MPS devices.")
def test_gradient_checkpointing_is_applied(
self, expected_set=None, attention_head_dim=None, num_attention_heads=None, block_out_channels=None
):
@@ -1087,7 +1097,7 @@ class ModelTesterMixin:
modules_with_gc_enabled = {}
for submodule in model.modules():
if hasattr(submodule, "gradient_checkpointing"):
self.assertTrue(submodule.gradient_checkpointing)
assert submodule.gradient_checkpointing
modules_with_gc_enabled[submodule.__class__.__name__] = True
assert set(modules_with_gc_enabled.keys()) == expected_set
@@ -1115,7 +1125,7 @@ class ModelTesterMixin:
@parameterized.expand([(4, 4, True), (4, 8, False), (8, 4, False)])
@torch.no_grad()
@unittest.skipIf(not is_peft_available(), "Only with PEFT")
@pytest.mark.skipif(not is_peft_available(), reason="Only with PEFT")
def test_save_load_lora_adapter(self, rank, lora_alpha, use_dora=False):
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
@@ -1139,21 +1149,21 @@ class ModelTesterMixin:
use_dora=use_dora,
)
model.add_adapter(denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
torch.manual_seed(0)
outputs_with_lora = model(**inputs_dict, return_dict=False)[0]
self.assertFalse(torch.allclose(output_no_lora, outputs_with_lora, atol=1e-4, rtol=1e-4))
assert not torch.allclose(output_no_lora, outputs_with_lora, atol=1e-4, rtol=1e-4)
with tempfile.TemporaryDirectory() as tmpdir:
model.save_lora_adapter(tmpdir)
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
assert os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
state_dict_loaded = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
model.unload_lora()
self.assertFalse(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
assert not check_if_lora_correctly_set(model), "LoRA layers not set correctly"
model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
state_dict_retrieved = get_peft_model_state_dict(model, adapter_name="default_0")
@@ -1161,17 +1171,17 @@ class ModelTesterMixin:
for k in state_dict_loaded:
loaded_v = state_dict_loaded[k]
retrieved_v = state_dict_retrieved[k].to(loaded_v.device)
self.assertTrue(torch.allclose(loaded_v, retrieved_v))
assert torch.allclose(loaded_v, retrieved_v)
self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
torch.manual_seed(0)
outputs_with_lora_2 = model(**inputs_dict, return_dict=False)[0]
self.assertFalse(torch.allclose(output_no_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4))
self.assertTrue(torch.allclose(outputs_with_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4))
assert not torch.allclose(output_no_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4)
assert torch.allclose(outputs_with_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4)
@unittest.skipIf(not is_peft_available(), "Only with PEFT")
@pytest.mark.skipif(not is_peft_available(), reason="Only with PEFT")
def test_lora_wrong_adapter_name_raises_error(self):
from peft import LoraConfig
@@ -1191,18 +1201,18 @@ class ModelTesterMixin:
use_dora=False,
)
model.add_adapter(denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
with tempfile.TemporaryDirectory() as tmpdir:
wrong_name = "foo"
with self.assertRaises(ValueError) as err_context:
with pytest.raises(ValueError) as err_context:
model.save_lora_adapter(tmpdir, adapter_name=wrong_name)
self.assertTrue(f"Adapter name {wrong_name} not found in the model." in str(err_context.exception))
assert f"Adapter name {wrong_name} not found in the model." in str(err_context.exception)
@parameterized.expand([(4, 4, True), (4, 8, False), (8, 4, False)])
@torch.no_grad()
@unittest.skipIf(not is_peft_available(), "Only with PEFT")
@pytest.mark.skipif(not is_peft_available(), reason="Only with PEFT")
def test_lora_adapter_metadata_is_loaded_correctly(self, rank, lora_alpha, use_dora):
from peft import LoraConfig
@@ -1223,22 +1233,22 @@ class ModelTesterMixin:
)
model.add_adapter(denoiser_lora_config)
metadata = model.peft_config["default"].to_dict()
self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
with tempfile.TemporaryDirectory() as tmpdir:
model.save_lora_adapter(tmpdir)
model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
self.assertTrue(os.path.isfile(model_file))
assert os.path.isfile(model_file)
model.unload_lora()
self.assertFalse(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
assert not check_if_lora_correctly_set(model), "LoRA layers not set correctly"
model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
parsed_metadata = model.peft_config["default_0"].to_dict()
check_if_dicts_are_equal(metadata, parsed_metadata)
@torch.no_grad()
@unittest.skipIf(not is_peft_available(), "Only with PEFT")
@pytest.mark.skipif(not is_peft_available(), reason="Only with PEFT")
def test_lora_adapter_wrong_metadata_raises_error(self):
from peft import LoraConfig
@@ -1259,12 +1269,12 @@ class ModelTesterMixin:
use_dora=False,
)
model.add_adapter(denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
with tempfile.TemporaryDirectory() as tmpdir:
model.save_lora_adapter(tmpdir)
model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
self.assertTrue(os.path.isfile(model_file))
assert os.path.isfile(model_file)
# Perturb the metadata in the state dict.
loaded_state_dict = safetensors.torch.load_file(model_file)
@@ -1278,11 +1288,11 @@ class ModelTesterMixin:
safetensors.torch.save_file(loaded_state_dict, model_file, metadata=metadata)
model.unload_lora()
self.assertFalse(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
assert not check_if_lora_correctly_set(model), "LoRA layers not set correctly"
with self.assertRaises(TypeError) as err_context:
with pytest.raises(TypeError) as err_context:
model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
self.assertTrue("`LoraConfig` class could not be instantiated" in str(err_context.exception))
assert "`LoraConfig` class could not be instantiated" in str(err_context.exception)
@require_torch_accelerator
def test_cpu_offload(self):
@@ -1306,13 +1316,13 @@ class ModelTesterMixin:
max_memory = {0: max_size, "cpu": model_size * 2}
new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
# Making sure part of the model will actually end up offloaded
self.assertSetEqual(set(new_model.hf_device_map.values()), {0, "cpu"})
assert set(new_model.hf_device_map.values()) == {0, "cpu"}
self.check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
assert torch.allclose(base_output[0], new_output[0], atol=1e-5)
@require_torch_accelerator
def test_disk_offload_without_safetensors(self):
@@ -1333,7 +1343,7 @@ class ModelTesterMixin:
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir, safe_serialization=False)
with self.assertRaises(ValueError):
with pytest.raises(ValueError):
# This errors out because it's missing an offload folder
new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
@@ -1345,7 +1355,7 @@ class ModelTesterMixin:
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
assert torch.allclose(base_output[0], new_output[0], atol=1e-5)
@require_torch_accelerator
def test_disk_offload_with_safetensors(self):
@@ -1373,7 +1383,7 @@ class ModelTesterMixin:
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
assert torch.allclose(base_output[0], new_output[0], atol=1e-5)
@require_torch_multi_accelerator
def test_model_parallelism(self):
@@ -1397,14 +1407,14 @@ class ModelTesterMixin:
max_memory = {0: max_size, 1: model_size * 2, "cpu": model_size * 2}
new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
# Making sure part of the model will actually end up offloaded
self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})
assert set(new_model.hf_device_map.values()) == {0, 1}
self.check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
assert torch.allclose(base_output[0], new_output[0], atol=1e-5)
@require_torch_accelerator
def test_sharded_checkpoints(self):
@@ -1419,14 +1429,14 @@ class ModelTesterMixin:
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small.
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
self.assertTrue(os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
assert os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME))
# Now check if the right number of shards exists. First, let's get the number of shards.
# Since this number can be dependent on the model being tested, it's important that we calculate it
# instead of hardcoding it.
expected_num_shards = caculate_expected_num_shards(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME))
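# (Not shown in this diff: a helper like this typically counts the unique shard filenames referenced by the index's weight_map.)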
actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
self.assertTrue(actual_num_shards == expected_num_shards)
assert actual_num_shards == expected_num_shards
new_model = self.model_class.from_pretrained(tmp_dir).eval()
new_model = new_model.to(torch_device)
@@ -1436,7 +1446,7 @@ class ModelTesterMixin:
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
new_output = new_model(**inputs_dict)
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
assert torch.allclose(base_output[0], new_output[0], atol=1e-5)
@require_torch_accelerator
def test_sharded_checkpoints_with_variant(self):
@@ -1457,14 +1467,14 @@ class ModelTesterMixin:
model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB", variant=variant)
index_filename = _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
self.assertTrue(os.path.exists(os.path.join(tmp_dir, index_filename)))
assert os.path.exists(os.path.join(tmp_dir, index_filename))
# Now check if the right number of shards exists. First, let's get the number of shards.
# Since this number can be dependent on the model being tested, it's important that we calculate it
# instead of hardcoding it.
expected_num_shards = caculate_expected_num_shards(os.path.join(tmp_dir, index_filename))
actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
self.assertTrue(actual_num_shards == expected_num_shards)
assert actual_num_shards == expected_num_shards
new_model = self.model_class.from_pretrained(tmp_dir, variant=variant).eval()
new_model = new_model.to(torch_device)
@@ -1474,7 +1484,7 @@ class ModelTesterMixin:
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
new_output = new_model(**inputs_dict)
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
assert torch.allclose(base_output[0], new_output[0], atol=1e-5)
@require_torch_accelerator
def test_sharded_checkpoints_with_parallel_loading(self):
@@ -1489,14 +1499,14 @@ class ModelTesterMixin:
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small.
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
self.assertTrue(os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
assert os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME))
# Now check if the right number of shards exists. First, let's get the number of shards.
# Since this number can be dependent on the model being tested, it's important that we calculate it
# instead of hardcoding it.
expected_num_shards = caculate_expected_num_shards(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME))
actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
self.assertTrue(actual_num_shards == expected_num_shards)
assert actual_num_shards == expected_num_shards
# Load with parallel loading
os.environ["HF_ENABLE_PARALLEL_LOADING"] = "yes"
@@ -1507,7 +1517,7 @@ class ModelTesterMixin:
if "generator" in inputs_dict:
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
new_output = new_model(**inputs_dict)
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
assert torch.allclose(base_output[0], new_output[0], atol=1e-5)
# set to no.
os.environ["HF_ENABLE_PARALLEL_LOADING"] = "no"
@@ -1526,14 +1536,14 @@ class ModelTesterMixin:
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small.
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
self.assertTrue(os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
assert os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME))
# Now check if the right number of shards exists. First, let's get the number of shards.
# Since this number can be dependent on the model being tested, it's important that we calculate it
# instead of hardcoding it.
expected_num_shards = caculate_expected_num_shards(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME))
actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
self.assertTrue(actual_num_shards == expected_num_shards)
assert actual_num_shards == expected_num_shards
new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto")
@@ -1541,7 +1551,7 @@ class ModelTesterMixin:
if "generator" in inputs_dict:
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
new_output = new_model(**inputs_dict)
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
assert torch.allclose(base_output[0], new_output[0], atol=1e-5)
# This test is okay without a GPU because we're not running any execution. We're just serializing
# and checking if the resultant files follow an expected format.
@@ -1559,14 +1569,14 @@ class ModelTesterMixin:
tmp_dir, variant=variant, max_shard_size=f"{max_shard_size}KB", safe_serialization=use_safe
)
index_variant = _add_variant(SAFE_WEIGHTS_INDEX_NAME if use_safe else WEIGHTS_INDEX_NAME, variant)
self.assertTrue(os.path.exists(os.path.join(tmp_dir, index_variant)))
assert os.path.exists(os.path.join(tmp_dir, index_variant))
# Now check if the right number of shards exists. First, let's get the number of shards.
# Since this number can be dependent on the model being tested, it's important that we calculate it
# instead of hardcoding it.
expected_num_shards = caculate_expected_num_shards(os.path.join(tmp_dir, index_variant))
actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(extension)])
self.assertTrue(actual_num_shards == expected_num_shards)
assert actual_num_shards == expected_num_shards
# Check if the variant is present as a substring in the checkpoints.
shard_files = [
@@ -1634,9 +1644,9 @@ class ModelTesterMixin:
if any(re.search(pattern, name) for pattern in patterns_to_check):
dtype_to_check = compute_dtype
if getattr(submodule, "weight", None) is not None:
self.assertEqual(submodule.weight.dtype, dtype_to_check)
assert submodule.weight.dtype == dtype_to_check
if getattr(submodule, "bias", None) is not None:
self.assertEqual(submodule.bias.dtype, dtype_to_check)
assert submodule.bias.dtype == dtype_to_check
def test_layerwise_casting(storage_dtype, compute_dtype):
torch.manual_seed(0)
@@ -1651,7 +1661,7 @@ class ModelTesterMixin:
# The precision test is not very important for fast tests. In most cases, the outputs will not be the same.
# We just want to make sure that the layerwise casting is working as expected.
self.assertTrue(numpy_cosine_similarity_distance(base_slice, output) < 1.0)
assert numpy_cosine_similarity_distance(base_slice, output) < 1.0
test_layerwise_casting(torch.float16, torch.float32)
test_layerwise_casting(torch.float8_e4m3fn, torch.float32)
@@ -1692,15 +1702,15 @@ class ModelTesterMixin:
)
compute_capability = get_torch_cuda_device_capability() if torch_device == "cuda" else None
self.assertTrue(fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint < fp32_memory_footprint)
assert fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint < fp32_memory_footprint
# NOTE: the following assertion would fail on our CI (running Tesla T4) due to bf16 using more memory than fp32.
# On other devices, such as DGX (Ampere) and Audace (Ada), the test passes. So, we conditionally check it.
if compute_capability and compute_capability >= LEAST_COMPUTE_CAPABILITY:
self.assertTrue(fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory)
assert fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory
# On this dummy test case with a small model, sometimes fp8_e4m3_fp32 max memory usage is higher than fp32 by a few
# bytes. This only happens for some models, so we allow a small tolerance.
# For any real model being tested, the order would be fp8_e4m3_bf16 < fp8_e4m3_fp32 < fp32.
self.assertTrue(
assert (
fp8_e4m3_fp32_max_memory < fp32_max_memory
or abs(fp8_e4m3_fp32_max_memory - fp32_max_memory) < MB_TOLERANCE
)
@@ -1716,12 +1726,10 @@ class ModelTesterMixin:
@torch.no_grad()
def run_forward(model):
self.assertTrue(
all(
module._diffusers_hook.get_hook("group_offloading") is not None
for module in model.modules()
if hasattr(module, "_diffusers_hook")
)
assert all(
module._diffusers_hook.get_hook("group_offloading") is not None
for module in model.modules()
if hasattr(module, "_diffusers_hook")
)
model.eval()
return model(**inputs_dict)[0]
@@ -1753,10 +1761,10 @@ class ModelTesterMixin:
)
output_with_group_offloading4 = run_forward(model)
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5))
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-5))
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5))
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5))
assert torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5)
assert torch.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-5)
assert torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5)
assert torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5)
@parameterized.expand([(False, "block_level"), (True, "leaf_level")])
@require_torch_accelerator
@@ -1838,7 +1846,7 @@ class ModelTesterMixin:
**additional_kwargs,
)
has_safetensors = glob.glob(f"{tmpdir}/*.safetensors")
self.assertTrue(has_safetensors, "No safetensors found in the directory.")
assert has_safetensors, "No safetensors found in the directory."
# For "leaf-level", there is a prefetching hook which makes this check a bit non-deterministic
# in nature. So, skip it.
@@ -1856,7 +1864,7 @@ class ModelTesterMixin:
raise ValueError(f"Following files are missing: {', '.join(missing_files)}")
output_with_group_offloading = _run_forward(model, inputs_dict)
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading, atol=atol))
assert torch.allclose(output_without_group_offloading, output_with_group_offloading, atol=atol)
def test_auto_model(self, expected_max_diff=5e-5):
if self.forward_requires_fresh_args:
@@ -1895,10 +1903,8 @@ class ModelTesterMixin:
output_auto = output_auto.to_tuple()[0]
max_diff = (output_original - output_auto).abs().max().item()
self.assertLessEqual(
max_diff,
expected_max_diff,
f"AutoModel forward pass diff: {max_diff} exceeds threshold {expected_max_diff}",
assert max_diff <= expected_max_diff, (
f"AutoModel forward pass diff: {max_diff} exceeds threshold {expected_max_diff}"
)
@parameterized.expand(
@@ -1912,7 +1918,7 @@ class ModelTesterMixin:
model = self.model_class(**init_dict)
with tempfile.TemporaryDirectory() as tmpdir:
model.save_pretrained(tmpdir)
with self.assertRaises(ValueError) as err_ctx:
with pytest.raises(ValueError) as err_ctx:
_ = self.model_class.from_pretrained(tmpdir, device_map=device_map)
assert msg_substring in str(err_ctx.value)
@@ -1942,7 +1948,7 @@ class ModelTesterMixin:
@is_staging_test
class ModelPushToHubTester(unittest.TestCase):
class TestModelPushToHub:
identifier = uuid.uuid4()
repo_id = f"test-model-{identifier}"
org_repo_id = f"valid_org/{repo_id}-org"
@@ -1962,7 +1968,7 @@ class ModelPushToHubTester(unittest.TestCase):
new_model = UNet2DConditionModel.from_pretrained(f"{USER}/{self.repo_id}")
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
assert torch.equal(p1, p2)
# Reset repo
delete_repo(token=TOKEN, repo_id=self.repo_id)
@@ -1973,7 +1979,7 @@ class ModelPushToHubTester(unittest.TestCase):
new_model = UNet2DConditionModel.from_pretrained(f"{USER}/{self.repo_id}")
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
assert torch.equal(p1, p2)
# Reset repo
delete_repo(self.repo_id, token=TOKEN)
@@ -1993,7 +1999,7 @@ class ModelPushToHubTester(unittest.TestCase):
new_model = UNet2DConditionModel.from_pretrained(self.org_repo_id)
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
assert torch.equal(p1, p2)
# Reset repo
delete_repo(token=TOKEN, repo_id=self.org_repo_id)
@@ -2004,12 +2010,12 @@ class ModelPushToHubTester(unittest.TestCase):
new_model = UNet2DConditionModel.from_pretrained(self.org_repo_id)
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
assert torch.equal(p1, p2)
# Reset repo
delete_repo(self.org_repo_id, token=TOKEN)
@unittest.skipIf(
@pytest.mark.skipif(
not is_jinja_available(),
reason="Model card tests cannot be performed without Jinja installed.",
)
@@ -2296,7 +2302,7 @@ class LoraHotSwappingForModelTesterMixin:
# check error when not passing valid adapter name
name = "does-not-exist"
msg = f"Trying to hotswap LoRA adapter '{name}' but there is no existing adapter by that name"
with self.assertRaisesRegex(ValueError, msg):
with pytest.raises(ValueError, match=msg):
model.load_lora_adapter(file_name1, adapter_name=name, hotswap=True, prefix=None)
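# pytest.raises(..., match=...) checks the exception message with re.search, mirroring unittest's assertRaisesRegex.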
@parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
@@ -2357,10 +2363,10 @@ class LoraHotSwappingForModelTesterMixin:
model.add_adapter(lora_config)
msg = re.escape("Call `enable_lora_hotswap` before loading the first adapter.")
with self.assertRaisesRegex(RuntimeError, msg):
with pytest.raises(RuntimeError, match=msg):
model.enable_lora_hotswap(target_rank=32)
def test_enable_lora_hotswap_called_after_adapter_added_warning(self):
def test_enable_lora_hotswap_called_after_adapter_added_warning(self, caplog):
# ensure that enable_lora_hotswap is called before loading the first adapter
from diffusers.loaders.peft import logger
@@ -2371,9 +2377,11 @@ class LoraHotSwappingForModelTesterMixin:
msg = (
"It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
)
with self.assertLogs(logger=logger, level="WARNING") as cm:
import logging
with caplog.at_level(logging.WARNING, logger=logger.name):
model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
assert any(msg in log for log in cm.output)
assert any(msg in record.message for record in caplog.records)
def test_enable_lora_hotswap_called_after_adapter_added_ignore(self):
# check possibility to ignore the error/warning
@@ -2384,7 +2392,7 @@ class LoraHotSwappingForModelTesterMixin:
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always") # Capture all warnings
model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
self.assertEqual(len(w), 0, f"Expected no warnings, but got: {[str(warn.message) for warn in w]}")
assert len(w) == 0, f"Expected no warnings, but got: {[str(warn.message) for warn in w]}"
def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self):
# check that wrong argument value raises an error
@@ -2393,22 +2401,24 @@ class LoraHotSwappingForModelTesterMixin:
model = self.model_class(**init_dict).to(torch_device)
model.add_adapter(lora_config)
msg = re.escape("check_compiles should be one of 'error', 'warn', or 'ignore', got 'wrong-argument' instead.")
with self.assertRaisesRegex(ValueError, msg):
with pytest.raises(ValueError, match=msg):
model.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument")
def test_hotswap_second_adapter_targets_more_layers_raises(self):
def test_hotswap_second_adapter_targets_more_layers_raises(self, caplog):
# check the error and log
from diffusers.loaders.peft import logger
# at the moment, PEFT requires the 2nd adapter to target the same or a subset of layers
target_modules0 = ["to_q"]
target_modules1 = ["to_q", "to_k"]
with self.assertRaises(RuntimeError): # peft raises RuntimeError
with self.assertLogs(logger=logger, level="ERROR") as cm:
with pytest.raises(RuntimeError): # peft raises RuntimeError
import logging
with caplog.at_level(logging.ERROR, logger=logger.name):
self.check_model_hotswap(
do_compile=True, rank0=8, rank1=8, target_modules0=target_modules0, target_modules1=target_modules1
)
assert any("Hotswapping adapter0 was unsuccessful" in log for log in cm.output)
assert any("Hotswapping adapter0 was unsuccessful" in record.message for record in caplog.records)
@parameterized.expand([(11, 11), (7, 13), (13, 7)])
@require_torch_version_greater("2.7.1")