mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-23 21:04:56 +08:00
Compare commits
2 Commits
torchao-lo
...
torchao-lo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
568dbaa5cf | ||
|
|
fa273fd179 |
@@ -41,8 +41,6 @@ from ..utils import (
|
|||||||
is_gguf_available,
|
is_gguf_available,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
is_torch_version,
|
is_torch_version,
|
||||||
is_torchao_available,
|
|
||||||
is_torchao_version,
|
|
||||||
logging,
|
logging,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -61,38 +59,6 @@ if is_accelerate_available():
|
|||||||
from accelerate.utils import get_balanced_memory, get_max_memory, offload_weight, set_module_tensor_to_device
|
from accelerate.utils import get_balanced_memory, get_max_memory, offload_weight, set_module_tensor_to_device
|
||||||
|
|
||||||
|
|
||||||
def _update_torch_safe_globals():
|
|
||||||
safe_globals = [
|
|
||||||
(torch.uint1, "torch.uint1"),
|
|
||||||
(torch.uint2, "torch.uint2"),
|
|
||||||
(torch.uint3, "torch.uint3"),
|
|
||||||
(torch.uint4, "torch.uint4"),
|
|
||||||
(torch.uint5, "torch.uint5"),
|
|
||||||
(torch.uint6, "torch.uint6"),
|
|
||||||
(torch.uint7, "torch.uint7"),
|
|
||||||
]
|
|
||||||
try:
|
|
||||||
from torchao.dtypes import NF4Tensor
|
|
||||||
from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl
|
|
||||||
from torchao.dtypes.uintx.uint4_layout import UInt4Tensor
|
|
||||||
from torchao.dtypes.uintx.uintx_layout import UintxAQTTensorImpl, UintxTensor
|
|
||||||
|
|
||||||
safe_globals.extend([UintxTensor, UInt4Tensor, UintxAQTTensorImpl, Float8AQTTensorImpl, NF4Tensor])
|
|
||||||
|
|
||||||
except (ImportError, ModuleNotFoundError) as e:
|
|
||||||
logger.warning(
|
|
||||||
"Unable to import `torchao` Tensor objects. This may affect loading checkpoints serialized with `torchao`"
|
|
||||||
)
|
|
||||||
logger.debug(e)
|
|
||||||
|
|
||||||
finally:
|
|
||||||
torch.serialization.add_safe_globals(safe_globals=safe_globals)
|
|
||||||
|
|
||||||
|
|
||||||
if is_torch_version(">=", "2.6") and is_torchao_available() and is_torchao_version(">=", "0.7.0"):
|
|
||||||
_update_torch_safe_globals()
|
|
||||||
|
|
||||||
|
|
||||||
# Adapted from `transformers` (see modeling_utils.py)
|
# Adapted from `transformers` (see modeling_utils.py)
|
||||||
def _determine_device_map(
|
def _determine_device_map(
|
||||||
model: torch.nn.Module, device_map, max_memory, torch_dtype, keep_in_fp32_modules=[], hf_quantizer=None
|
model: torch.nn.Module, device_map, max_memory, torch_dtype, keep_in_fp32_modules=[], hf_quantizer=None
|
||||||
|
|||||||
@@ -23,7 +23,14 @@ from typing import TYPE_CHECKING, Any, Dict, List, Union
|
|||||||
|
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
from ...utils import get_module_from_name, is_torch_available, is_torch_version, is_torchao_available, logging
|
from ...utils import (
|
||||||
|
get_module_from_name,
|
||||||
|
is_torch_available,
|
||||||
|
is_torch_version,
|
||||||
|
is_torchao_version,
|
||||||
|
is_torchao_available,
|
||||||
|
logging,
|
||||||
|
)
|
||||||
from ..base import DiffusersQuantizer
|
from ..base import DiffusersQuantizer
|
||||||
|
|
||||||
|
|
||||||
@@ -62,6 +69,38 @@ if is_torchao_available():
|
|||||||
from torchao.quantization import quantize_
|
from torchao.quantization import quantize_
|
||||||
|
|
||||||
|
|
||||||
|
def _update_torch_safe_globals():
|
||||||
|
safe_globals = [
|
||||||
|
(torch.uint1, "torch.uint1"),
|
||||||
|
(torch.uint2, "torch.uint2"),
|
||||||
|
(torch.uint3, "torch.uint3"),
|
||||||
|
(torch.uint4, "torch.uint4"),
|
||||||
|
(torch.uint5, "torch.uint5"),
|
||||||
|
(torch.uint6, "torch.uint6"),
|
||||||
|
(torch.uint7, "torch.uint7"),
|
||||||
|
]
|
||||||
|
try:
|
||||||
|
from torchao.dtypes import NF4Tensor
|
||||||
|
from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl
|
||||||
|
from torchao.dtypes.uintx.uint4_layout import UInt4Tensor
|
||||||
|
from torchao.dtypes.uintx.uintx_layout import UintxAQTTensorImpl, UintxTensor
|
||||||
|
|
||||||
|
safe_globals.extend([UintxTensor, UInt4Tensor, UintxAQTTensorImpl, Float8AQTTensorImpl, NF4Tensor])
|
||||||
|
|
||||||
|
except (ImportError, ModuleNotFoundError) as e:
|
||||||
|
logger.warning(
|
||||||
|
"Unable to import `torchao` Tensor objects. This may affect loading checkpoints serialized with `torchao`"
|
||||||
|
)
|
||||||
|
logger.debug(e)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
torch.serialization.add_safe_globals(safe_globals=safe_globals)
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_version(">=", "2.6") and is_torchao_available() and is_torchao_version(">=", "0.7.0"):
|
||||||
|
_update_torch_safe_globals()
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user