Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-14 23:55:41 +08:00)
Compare commits: 1 commit (transforme… / main). Commit SHA1: d7fa445453
@@ -1360,12 +1360,12 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         # Checks if the model has been loaded in 4-bit or 8-bit with BNB
         if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
-            if getattr(self, "is_loaded_in_8bit", False):
+            if getattr(self, "is_loaded_in_8bit", False) and is_bitsandbytes_version("<", "0.48.0"):
                 raise ValueError(
-                    "Calling `cuda()` is not supported for `8-bit` quantized models. "
-                    " Please use the model as it is, since the model has already been set to the correct devices."
+                    "Calling `cuda()` is not supported for `8-bit` quantized models with the installed version of bitsandbytes. "
+                    f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.48.0."
                 )
-            elif is_bitsandbytes_version("<", "0.43.2"):
+            elif getattr(self, "is_loaded_in_4bit", False) and is_bitsandbytes_version("<", "0.43.2"):
                 raise ValueError(
                     "Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
                     f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
@@ -1412,17 +1412,16 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 )

         if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
-            if getattr(self, "is_loaded_in_8bit", False):
+            if getattr(self, "is_loaded_in_8bit", False) and is_bitsandbytes_version("<", "0.48.0"):
                 raise ValueError(
-                    "`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the"
-                    " model has already been set to the correct devices and casted to the correct `dtype`."
+                    "Calling `to()` is not supported for `8-bit` quantized models with the installed version of bitsandbytes. "
+                    f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.48.0."
                 )
-            elif is_bitsandbytes_version("<", "0.43.2"):
+            elif getattr(self, "is_loaded_in_4bit", False) and is_bitsandbytes_version("<", "0.43.2"):
                 raise ValueError(
                     "Calling `to()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
                     f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
                 )

         if _is_group_offload_enabled(self) and device_arg_or_kwarg_present:
             logger.warning(
                 f"The module '{self.__class__.__name__}' is group offloaded and moving it using `.to()` is not supported."
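For context, a minimal usage sketch (not part of the commit) of what the relaxed guards permit. The checkpoint and model class are illustrative, and bitsandbytes >= 0.48.0 plus a CUDA device are assumed:

```python
import torch
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel

model = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # illustrative checkpoint
    subfolder="transformer",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    torch_dtype=torch.float16,
)

# With bitsandbytes >= 0.48.0 these device moves no longer raise;
# older installs still hit the ValueError paths above.
model.to("cuda")
model.cuda()

# Dtype casts on the 8-bit model are still rejected (see the updated test below).
try:
    model.to(torch.float16)
except ValueError as err:
    print(err)
```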
@@ -60,6 +60,7 @@ from ..utils import (
     deprecate,
     is_accelerate_available,
     is_accelerate_version,
+    is_bitsandbytes_version,
     is_hpu_available,
     is_torch_npu_available,
     is_torch_version,
@@ -444,7 +445,10 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):

            _, _, is_loaded_in_8bit_bnb = _check_bnb_status(module)

-            if is_loaded_in_8bit_bnb:
+            # https://github.com/huggingface/accelerate/pull/3907
+            if is_loaded_in_8bit_bnb and (
+                is_bitsandbytes_version("<", "0.48.0") or is_accelerate_version("<", "1.13.0.dev0")
+            ):
                return False

            return hasattr(module, "_hf_hook") and (
@@ -523,9 +527,10 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                    f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {'4bit' if is_loaded_in_4bit_bnb else '8bit'} and conversion to {dtype} is not supported. Module is still in {'4bit' if is_loaded_in_4bit_bnb else '8bit'} precision."
                )

-            if is_loaded_in_8bit_bnb and device is not None:
+            if is_loaded_in_8bit_bnb and device is not None and is_bitsandbytes_version("<", "0.48.0"):
                logger.warning(
                    f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` 8bit and moving it to {device} via `.to()` is not supported. Module is still on {module.device}."
+                    "You need to upgrade bitsandbytes to at least 0.48.0"
                )

            # Note: we also handle this at the ModelMixin level. The reason for doing it here too is that modeling
@@ -542,6 +547,14 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                # https://github.com/huggingface/transformers/pull/33122. So, we guard this accordingly.
                if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"):
                    module.to(device=device)
+                # added here https://github.com/huggingface/transformers/pull/43258
+                if (
+                    is_loaded_in_8bit_bnb
+                    and device is not None
+                    and is_transformers_version(">", "4.58.0")
+                    and is_bitsandbytes_version(">=", "0.48.0")
+                ):
+                    module.to(device=device)
                elif not is_loaded_in_4bit_bnb and not is_loaded_in_8bit_bnb and not is_group_offloaded:
                    module.to(device, dtype)
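A hedged sketch (also not part of the commit) of what the new branch changes at the pipeline level, assuming transformers > 4.58.0 and bitsandbytes >= 0.48.0; the checkpoint is again illustrative:

```python
import torch
from diffusers import BitsAndBytesConfig, FluxPipeline, FluxTransformer2DModel

transformer_8bit = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # illustrative checkpoint
    subfolder="transformer",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    torch_dtype=torch.float16,
)
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=transformer_8bit,
    torch_dtype=torch.float16,
)

# Previously the 8-bit component was skipped here; with the versions above,
# DiffusionPipeline.to() now moves it along with the other components.
pipe.to("cuda")
```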
@@ -1223,7 +1236,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):

            # This is because the model would already be placed on a CUDA device.
            _, _, is_loaded_in_8bit_bnb = _check_bnb_status(model)
-            if is_loaded_in_8bit_bnb:
+            if is_loaded_in_8bit_bnb and (
+                is_transformers_version("<", "4.58.0") or is_bitsandbytes_version("<", "0.48.0")
+            ):
                logger.info(
                    f"Skipping the hook placement for the {model.__class__.__name__} as it is loaded in `bitsandbytes` 8bit."
                )
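Reusing the pipeline construction from the sketch above (but calling offload instead of `pipe.to("cuda")`), the offload path this guard affects: with older transformers or bitsandbytes the hook placement for the 8-bit component is skipped and the `logger.info` above fires; with newer versions the hook is placed like for any other component. Prompt and step count are illustrative:

```python
# Assumes `pipe` was built as in the previous sketch (8-bit transformer component),
# without calling pipe.to("cuda") first.
pipe.enable_model_cpu_offload()
image = pipe("a tiny robot watering plants", num_inference_steps=4).images[0]
```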
@@ -288,31 +288,29 @@ class BnB8bitBasicTests(Base8bitTests):
        self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params)
        self.assertTrue(hasattr(linear.weight, "SCB"))

+    @require_bitsandbytes_version_greater("0.48.0")
    def test_device_and_dtype_assignment(self):
        r"""
        Test whether trying to cast (or assigning a device to) a model after converting it in 8-bit will throw an error.
        Checks also if other models are casted correctly.
        """
-        with self.assertRaises(ValueError):
-            # Tries with `str`
-            self.model_8bit.to("cpu")
-
        with self.assertRaises(ValueError):
            # Tries with a `dtype``
            self.model_8bit.to(torch.float16)

-        with self.assertRaises(ValueError):
-            # Tries with a `device`
-            self.model_8bit.to(torch.device(f"{torch_device}:0"))
-
        with self.assertRaises(ValueError):
            # Tries with a `device`
            self.model_8bit.float()

        with self.assertRaises(ValueError):
-            # Tries with a `device`
+            # Tries with a `dtype`
            self.model_8bit.half()

+        # This should work with 0.48.0
+        self.model_8bit.to("cpu")
+        self.model_8bit.to(torch.device(f"{torch_device}:0"))
+
        # Test if we did not break anything
        self.model_fp16 = self.model_fp16.to(dtype=torch.float32, device=torch_device)
        input_dict_for_transformer = self.get_dummy_inputs()
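The test above is now gated on a sufficiently new bitsandbytes via the `require_bitsandbytes_version_greater("0.48.0")` decorator. A small, runnable sketch for checking locally whether an environment satisfies the version gates this commit relies on; it assumes both helpers remain importable from `diffusers.utils`, as the diff above does:

```python
from diffusers.utils import is_bitsandbytes_version, is_transformers_version

# Gate used by ModelMixin.cuda()/.to() for 8-bit models:
print("bitsandbytes >= 0.48.0:", is_bitsandbytes_version(">=", "0.48.0"))

# Additional gate used by DiffusionPipeline.to() before moving an 8-bit component:
print("transformers > 4.58.0:", is_transformers_version(">", "4.58.0"))
```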
@@ -837,7 +835,7 @@ class BaseBnb8bitSerializationTests(Base8bitTests):


@require_torch_version_greater_equal("2.6.0")
-@require_bitsandbytes_version_greater("0.45.5")
+@require_bitsandbytes_version_greater("0.48.0")
class Bnb8BitCompileTests(QuantCompileTests, unittest.TestCase):
    @property
    def quantization_config(self):
@@ -848,7 +846,7 @@ class Bnb8BitCompileTests(QuantCompileTests, unittest.TestCase):
        )

    @pytest.mark.xfail(
-        reason="Test fails because of an offloading problem from Accelerate with confusion in hooks."
+        reason="Test fails because of a type change when recompiling."
        " Test passes without recompilation context manager. Refer to https://github.com/huggingface/diffusers/pull/12002/files#r2240462757 for details."
    )
    def test_torch_compile(self):
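For reference, a minimal sketch (not taken from QuantCompileTests) of the pattern these compile tests exercise; the checkpoint is illustrative, and the xfail above tracks a dtype change on recompilation, so behavior may vary:

```python
import torch
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel

model = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # illustrative checkpoint
    subfolder="transformer",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    torch_dtype=torch.float16,
)

# Compile the 8-bit quantized module; the test class above gates on torch >= 2.6.0.
model = torch.compile(model)
```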
@@ -858,6 +856,5 @@ class Bnb8BitCompileTests(QuantCompileTests, unittest.TestCase):
    def test_torch_compile_with_cpu_offload(self):
        super()._test_torch_compile_with_cpu_offload(torch_dtype=torch.float16)

-    @pytest.mark.xfail(reason="Test fails because of an offloading problem from Accelerate with confusion in hooks.")
    def test_torch_compile_with_group_offload_leaf(self):
        super()._test_torch_compile_with_group_offload_leaf(torch_dtype=torch.float16, use_stream=True)