Compare commits

...

1 Commit

Author: Marc Sun
SHA1: d7fa445453
Date: 2026-01-14 20:33:30 +05:30
Message: Remove 8bit device restriction (#12972)

* allow to
* update version
* fix version again
* again
* Update src/diffusers/pipelines/pipeline_utils.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* style
* xfail
* add pr

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
3 changed files with 34 additions and 23 deletions

View File

@@ -1360,12 +1360,12 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         # Checks if the model has been loaded in 4-bit or 8-bit with BNB
         if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
-            if getattr(self, "is_loaded_in_8bit", False):
+            if getattr(self, "is_loaded_in_8bit", False) and is_bitsandbytes_version("<", "0.48.0"):
                 raise ValueError(
-                    "Calling `cuda()` is not supported for `8-bit` quantized models. "
-                    " Please use the model as it is, since the model has already been set to the correct devices."
+                    "Calling `cuda()` is not supported for `8-bit` quantized models with the installed version of bitsandbytes. "
+                    f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.48.0."
                 )
-            elif is_bitsandbytes_version("<", "0.43.2"):
+            elif getattr(self, "is_loaded_in_4bit", False) and is_bitsandbytes_version("<", "0.43.2"):
                 raise ValueError(
                     "Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
                     f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
@@ -1412,17 +1412,16 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             )
         if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
-            if getattr(self, "is_loaded_in_8bit", False):
+            if getattr(self, "is_loaded_in_8bit", False) and is_bitsandbytes_version("<", "0.48.0"):
                 raise ValueError(
-                    "`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the"
-                    " model has already been set to the correct devices and casted to the correct `dtype`."
+                    "Calling `to()` is not supported for `8-bit` quantized models with the installed version of bitsandbytes. "
+                    f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.48.0."
                 )
-            elif is_bitsandbytes_version("<", "0.43.2"):
+            elif getattr(self, "is_loaded_in_4bit", False) and is_bitsandbytes_version("<", "0.43.2"):
                 raise ValueError(
                     "Calling `to()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
                     f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
                 )
         if _is_group_offload_enabled(self) and device_arg_or_kwarg_present:
             logger.warning(
                 f"The module '{self.__class__.__name__}' is group offloaded and moving it using `.to()` is not supported."

View File

@@ -60,6 +60,7 @@ from ..utils import (
     deprecate,
     is_accelerate_available,
     is_accelerate_version,
+    is_bitsandbytes_version,
     is_hpu_available,
     is_torch_npu_available,
     is_torch_version,
@@ -444,7 +445,10 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
             _, _, is_loaded_in_8bit_bnb = _check_bnb_status(module)
-            if is_loaded_in_8bit_bnb:
+            # https://github.com/huggingface/accelerate/pull/3907
+            if is_loaded_in_8bit_bnb and (
+                is_bitsandbytes_version("<", "0.48.0") or is_accelerate_version("<", "1.13.0.dev0")
+            ):
                 return False
             return hasattr(module, "_hf_hook") and (
@@ -523,9 +527,10 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                     f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {'4bit' if is_loaded_in_4bit_bnb else '8bit'} and conversion to {dtype} is not supported. Module is still in {'4bit' if is_loaded_in_4bit_bnb else '8bit'} precision."
                 )
-            if is_loaded_in_8bit_bnb and device is not None:
+            if is_loaded_in_8bit_bnb and device is not None and is_bitsandbytes_version("<", "0.48.0"):
                 logger.warning(
                     f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` 8bit and moving it to {device} via `.to()` is not supported. Module is still on {module.device}."
+                    "You need to upgrade bitsandbytes to at least 0.48.0"
                 )
             # Note: we also handle this at the ModelMixin level. The reason for doing it here too is that modeling
@@ -542,6 +547,14 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
             # https://github.com/huggingface/transformers/pull/33122. So, we guard this accordingly.
             if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"):
                 module.to(device=device)
+            # added here https://github.com/huggingface/transformers/pull/43258
+            if (
+                is_loaded_in_8bit_bnb
+                and device is not None
+                and is_transformers_version(">", "4.58.0")
+                and is_bitsandbytes_version(">=", "0.48.0")
+            ):
+                module.to(device=device)
             elif not is_loaded_in_4bit_bnb and not is_loaded_in_8bit_bnb and not is_group_offloaded:
                 module.to(device, dtype)
@@ -1223,7 +1236,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
             # This is because the model would already be placed on a CUDA device.
             _, _, is_loaded_in_8bit_bnb = _check_bnb_status(model)
-            if is_loaded_in_8bit_bnb:
+            if is_loaded_in_8bit_bnb and (
+                is_transformers_version("<", "4.58.0") or is_bitsandbytes_version("<", "0.48.0")
+            ):
                 logger.info(
                     f"Skipping the hook placement for the {model.__class__.__name__} as it is loaded in `bitsandbytes` 8bit."
                 )
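
The pipeline-level changes mirror the model-level ones: `DiffusionPipeline.to()` only treats an 8-bit module as unmovable, and `enable_model_cpu_offload()` only skips hook placement for it, when the surrounding stack is too old. The sketch below restates the combined version gate; the helper name is hypothetical and not part of the diff.

# Sketch only: the helper name is hypothetical; the version checks restate the
# gates added in this diff for DiffusionPipeline.to() and enable_model_cpu_offload().
from diffusers.utils import is_accelerate_version, is_bitsandbytes_version, is_transformers_version


def bnb_8bit_module_is_movable() -> bool:
    # An 8-bit bitsandbytes module can be moved with `.to(device)` and offloaded
    # with hooks only when all three libraries support it: bitsandbytes >= 0.48.0,
    # a transformers release newer than 4.58.0 (for transformers-backed components,
    # https://github.com/huggingface/transformers/pull/43258), and accelerate >=
    # 1.13.0.dev0 (https://github.com/huggingface/accelerate/pull/3907).
    return (
        is_bitsandbytes_version(">=", "0.48.0")
        and is_transformers_version(">", "4.58.0")
        and is_accelerate_version(">=", "1.13.0.dev0")
    )

With all three in place, `pipe.to(device)` actually moves 8-bit components instead of only warning, and `enable_model_cpu_offload()` no longer skips hook placement for them.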

View File

@@ -288,31 +288,29 @@ class BnB8bitBasicTests(Base8bitTests):
         self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params)
         self.assertTrue(hasattr(linear.weight, "SCB"))
+    @require_bitsandbytes_version_greater("0.48.0")
     def test_device_and_dtype_assignment(self):
         r"""
         Test whether trying to cast (or assigning a device to) a model after converting it in 8-bit will throw an error.
         Checks also if other models are casted correctly.
         """
-        with self.assertRaises(ValueError):
-            # Tries with `str`
-            self.model_8bit.to("cpu")
         with self.assertRaises(ValueError):
             # Tries with a `dtype``
             self.model_8bit.to(torch.float16)
-        with self.assertRaises(ValueError):
-            # Tries with a `device`
-            self.model_8bit.to(torch.device(f"{torch_device}:0"))
         with self.assertRaises(ValueError):
             # Tries with a `device`
             self.model_8bit.float()
         with self.assertRaises(ValueError):
-            # Tries with a `device`
+            # Tries with a `dtype`
             self.model_8bit.half()
+        # This should work with 0.48.0
+        self.model_8bit.to("cpu")
+        self.model_8bit.to(torch.device(f"{torch_device}:0"))
         # Test if we did not break anything
         self.model_fp16 = self.model_fp16.to(dtype=torch.float32, device=torch_device)
         input_dict_for_transformer = self.get_dummy_inputs()
@@ -837,7 +835,7 @@ class BaseBnb8bitSerializationTests(Base8bitTests):
 @require_torch_version_greater_equal("2.6.0")
-@require_bitsandbytes_version_greater("0.45.5")
+@require_bitsandbytes_version_greater("0.48.0")
 class Bnb8BitCompileTests(QuantCompileTests, unittest.TestCase):
     @property
     def quantization_config(self):
@@ -848,7 +846,7 @@ class Bnb8BitCompileTests(QuantCompileTests, unittest.TestCase):
         )
     @pytest.mark.xfail(
-        reason="Test fails because of an offloading problem from Accelerate with confusion in hooks."
+        reason="Test fails because of a type change when recompiling."
         " Test passes without recompilation context manager. Refer to https://github.com/huggingface/diffusers/pull/12002/files#r2240462757 for details."
     )
     def test_torch_compile(self):
@@ -858,6 +856,5 @@ class Bnb8BitCompileTests(QuantCompileTests, unittest.TestCase):
     def test_torch_compile_with_cpu_offload(self):
         super()._test_torch_compile_with_cpu_offload(torch_dtype=torch.float16)
-    @pytest.mark.xfail(reason="Test fails because of an offloading problem from Accelerate with confusion in hooks.")
     def test_torch_compile_with_group_offload_leaf(self):
         super()._test_torch_compile_with_group_offload_leaf(torch_dtype=torch.float16, use_stream=True)
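
The reworked test captures the new behaviour: dtype casts on an 8-bit model must still raise, while device moves succeed once bitsandbytes 0.48.0 is installed. A condensed paraphrase follows; it is not the test itself, and `model_8bit` stands for any diffusers model loaded with `BitsAndBytesConfig(load_in_8bit=True)`.

# Condensed paraphrase of the updated test_device_and_dtype_assignment.
# Assumes `model_8bit` is a diffusers ModelMixin loaded in 8-bit via bitsandbytes
# and that bitsandbytes >= 0.48.0 is installed.
import pytest
import torch


def check_8bit_device_and_dtype_rules(model_8bit, torch_device="cuda"):
    # Casting the quantized model to another dtype must still fail.
    for cast in (lambda m: m.to(torch.float16), lambda m: m.float(), lambda m: m.half()):
        with pytest.raises(ValueError):
            cast(model_8bit)

    # Device moves are no longer rejected with bitsandbytes >= 0.48.0.
    model_8bit.to("cpu")
    model_8bit.to(torch.device(f"{torch_device}:0"))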