mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-20 11:24:43 +08:00
Compare commits
2 Commits
torchao-lo
...
torchao-lo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
568dbaa5cf | ||
|
|
fa273fd179 |
@@ -41,8 +41,6 @@ from ..utils import (
|
||||
is_gguf_available,
|
||||
is_torch_available,
|
||||
is_torch_version,
|
||||
is_torchao_available,
|
||||
is_torchao_version,
|
||||
logging,
|
||||
)
|
||||
|
||||
@@ -61,38 +59,6 @@ if is_accelerate_available():
|
||||
from accelerate.utils import get_balanced_memory, get_max_memory, offload_weight, set_module_tensor_to_device
|
||||
|
||||
|
||||
def _update_torch_safe_globals():
|
||||
safe_globals = [
|
||||
(torch.uint1, "torch.uint1"),
|
||||
(torch.uint2, "torch.uint2"),
|
||||
(torch.uint3, "torch.uint3"),
|
||||
(torch.uint4, "torch.uint4"),
|
||||
(torch.uint5, "torch.uint5"),
|
||||
(torch.uint6, "torch.uint6"),
|
||||
(torch.uint7, "torch.uint7"),
|
||||
]
|
||||
try:
|
||||
from torchao.dtypes import NF4Tensor
|
||||
from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl
|
||||
from torchao.dtypes.uintx.uint4_layout import UInt4Tensor
|
||||
from torchao.dtypes.uintx.uintx_layout import UintxAQTTensorImpl, UintxTensor
|
||||
|
||||
safe_globals.extend([UintxTensor, UInt4Tensor, UintxAQTTensorImpl, Float8AQTTensorImpl, NF4Tensor])
|
||||
|
||||
except (ImportError, ModuleNotFoundError) as e:
|
||||
logger.warning(
|
||||
"Unable to import `torchao` Tensor objects. This may affect loading checkpoints serialized with `torchao`"
|
||||
)
|
||||
logger.debug(e)
|
||||
|
||||
finally:
|
||||
torch.serialization.add_safe_globals(safe_globals=safe_globals)
|
||||
|
||||
|
||||
if is_torch_version(">=", "2.6") and is_torchao_available() and is_torchao_version(">=", "0.7.0"):
|
||||
_update_torch_safe_globals()
|
||||
|
||||
|
||||
# Adapted from `transformers` (see modeling_utils.py)
|
||||
def _determine_device_map(
|
||||
model: torch.nn.Module, device_map, max_memory, torch_dtype, keep_in_fp32_modules=[], hf_quantizer=None
|
||||
|
||||
@@ -23,7 +23,14 @@ from typing import TYPE_CHECKING, Any, Dict, List, Union
|
||||
|
||||
from packaging import version
|
||||
|
||||
from ...utils import get_module_from_name, is_torch_available, is_torch_version, is_torchao_available, logging
|
||||
from ...utils import (
|
||||
get_module_from_name,
|
||||
is_torch_available,
|
||||
is_torch_version,
|
||||
is_torchao_version,
|
||||
is_torchao_available,
|
||||
logging,
|
||||
)
|
||||
from ..base import DiffusersQuantizer
|
||||
|
||||
|
||||
@@ -62,6 +69,38 @@ if is_torchao_available():
|
||||
from torchao.quantization import quantize_
|
||||
|
||||
|
||||
def _update_torch_safe_globals():
|
||||
safe_globals = [
|
||||
(torch.uint1, "torch.uint1"),
|
||||
(torch.uint2, "torch.uint2"),
|
||||
(torch.uint3, "torch.uint3"),
|
||||
(torch.uint4, "torch.uint4"),
|
||||
(torch.uint5, "torch.uint5"),
|
||||
(torch.uint6, "torch.uint6"),
|
||||
(torch.uint7, "torch.uint7"),
|
||||
]
|
||||
try:
|
||||
from torchao.dtypes import NF4Tensor
|
||||
from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl
|
||||
from torchao.dtypes.uintx.uint4_layout import UInt4Tensor
|
||||
from torchao.dtypes.uintx.uintx_layout import UintxAQTTensorImpl, UintxTensor
|
||||
|
||||
safe_globals.extend([UintxTensor, UInt4Tensor, UintxAQTTensorImpl, Float8AQTTensorImpl, NF4Tensor])
|
||||
|
||||
except (ImportError, ModuleNotFoundError) as e:
|
||||
logger.warning(
|
||||
"Unable to import `torchao` Tensor objects. This may affect loading checkpoints serialized with `torchao`"
|
||||
)
|
||||
logger.debug(e)
|
||||
|
||||
finally:
|
||||
torch.serialization.add_safe_globals(safe_globals=safe_globals)
|
||||
|
||||
|
||||
if is_torch_version(">=", "2.6") and is_torchao_available() and is_torchao_version(">=", "0.7.0"):
|
||||
_update_torch_safe_globals()
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user