Compare commits

...

2 Commits

Author SHA1 Message Date
DN6
568dbaa5cf update 2025-03-10 16:52:24 +05:30
DN6
fa273fd179 update 2025-03-10 16:47:18 +05:30
2 changed files with 40 additions and 35 deletions

View File

@@ -41,8 +41,6 @@ from ..utils import (
is_gguf_available,
is_torch_available,
is_torch_version,
is_torchao_available,
is_torchao_version,
logging,
)
@@ -61,38 +59,6 @@ if is_accelerate_available():
from accelerate.utils import get_balanced_memory, get_max_memory, offload_weight, set_module_tensor_to_device
def _update_torch_safe_globals():
safe_globals = [
(torch.uint1, "torch.uint1"),
(torch.uint2, "torch.uint2"),
(torch.uint3, "torch.uint3"),
(torch.uint4, "torch.uint4"),
(torch.uint5, "torch.uint5"),
(torch.uint6, "torch.uint6"),
(torch.uint7, "torch.uint7"),
]
try:
from torchao.dtypes import NF4Tensor
from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl
from torchao.dtypes.uintx.uint4_layout import UInt4Tensor
from torchao.dtypes.uintx.uintx_layout import UintxAQTTensorImpl, UintxTensor
safe_globals.extend([UintxTensor, UInt4Tensor, UintxAQTTensorImpl, Float8AQTTensorImpl, NF4Tensor])
except (ImportError, ModuleNotFoundError) as e:
logger.warning(
"Unable to import `torchao` Tensor objects. This may affect loading checkpoints serialized with `torchao`"
)
logger.debug(e)
finally:
torch.serialization.add_safe_globals(safe_globals=safe_globals)
if is_torch_version(">=", "2.6") and is_torchao_available() and is_torchao_version(">=", "0.7.0"):
_update_torch_safe_globals()
# Adapted from `transformers` (see modeling_utils.py)
def _determine_device_map(
model: torch.nn.Module, device_map, max_memory, torch_dtype, keep_in_fp32_modules=[], hf_quantizer=None

View File

@@ -23,7 +23,14 @@ from typing import TYPE_CHECKING, Any, Dict, List, Union
from packaging import version
from ...utils import get_module_from_name, is_torch_available, is_torch_version, is_torchao_available, logging
from ...utils import (
get_module_from_name,
is_torch_available,
is_torch_version,
is_torchao_version,
is_torchao_available,
logging,
)
from ..base import DiffusersQuantizer
@@ -62,6 +69,38 @@ if is_torchao_available():
from torchao.quantization import quantize_
def _update_torch_safe_globals():
safe_globals = [
(torch.uint1, "torch.uint1"),
(torch.uint2, "torch.uint2"),
(torch.uint3, "torch.uint3"),
(torch.uint4, "torch.uint4"),
(torch.uint5, "torch.uint5"),
(torch.uint6, "torch.uint6"),
(torch.uint7, "torch.uint7"),
]
try:
from torchao.dtypes import NF4Tensor
from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl
from torchao.dtypes.uintx.uint4_layout import UInt4Tensor
from torchao.dtypes.uintx.uintx_layout import UintxAQTTensorImpl, UintxTensor
safe_globals.extend([UintxTensor, UInt4Tensor, UintxAQTTensorImpl, Float8AQTTensorImpl, NF4Tensor])
except (ImportError, ModuleNotFoundError) as e:
logger.warning(
"Unable to import `torchao` Tensor objects. This may affect loading checkpoints serialized with `torchao`"
)
logger.debug(e)
finally:
torch.serialization.add_safe_globals(safe_globals=safe_globals)
if is_torch_version(">=", "2.6") and is_torchao_available() and is_torchao_version(">=", "0.7.0"):
_update_torch_safe_globals()
logger = logging.get_logger(__name__)