Compare commits

1 Commit
Vasiliy Kuznetsov
1cdb8723b8 fix torchao quantizer for new torchao versions (#12901)
* fix torchao quantizer for new torchao versions

Summary:

`torchao==0.16.0` (not yet released) ships some backwards-compatibility-breaking changes; this PR updates the diffusers repo to handle them. Specifics on the changes:
1. `UInt4Tensor` is removed: https://github.com/pytorch/ao/pull/3536
2. the old float8 tensor v1 classes are removed: https://github.com/pytorch/ao/pull/3510

In this PR:
1. move the `logger` definition to the top of the file (it previously sat in the
   middle of the file) so earlier code paths get proper error messages
2. gate the removed torchao objects behind a torchao version check
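The version-gating idea in point 2 can be sketched as follows. This is a hypothetical, simplified helper for illustration only: class names are kept as strings so the example runs without torchao installed, whereas the real quantizer code imports and registers the classes themselves.

```python
# Hypothetical sketch of gating removed torchao objects by version.
# The real code imports the classes; strings are used here so the
# example runs without torchao installed.
from packaging.version import parse as parse_version


def safe_global_names(torchao_version: str) -> list[str]:
    # these classes exist in all supported torchao versions
    names = ["UintxTensor", "UintxAQTTensorImpl", "NF4Tensor"]
    if not (parse_version(torchao_version) > parse_version("0.15.0")):
        # UInt4Tensor and Float8AQTTensorImpl were removed in torchao 0.16.0,
        # so only register them for older versions
        names += ["UInt4Tensor", "Float8AQTTensorImpl"]
    return names
```

Note the `> 0.15.0` comparison rather than `>= 0.16.0`; the diff below explains why that matters for nightly builds.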

Test Plan:

Importing diffusers objects with the new torchao version works:

```bash
> python -c "import torchao; print(torchao.__version__); from diffusers import StableDiffusionPipeline"
0.16.0.dev20251229+cu129
```


* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-12-30 10:04:54 +05:30


```diff
@@ -36,6 +36,9 @@ from ...utils import (
 from ..base import DiffusersQuantizer

+logger = logging.get_logger(__name__)
+
 if TYPE_CHECKING:
     from ...models.modeling_utils import ModelMixin
@@ -83,11 +86,19 @@ def _update_torch_safe_globals():
     ]
     try:
         from torchao.dtypes import NF4Tensor
-        from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl
-        from torchao.dtypes.uintx.uint4_layout import UInt4Tensor
         from torchao.dtypes.uintx.uintx_layout import UintxAQTTensorImpl, UintxTensor

-        safe_globals.extend([UintxTensor, UInt4Tensor, UintxAQTTensorImpl, Float8AQTTensorImpl, NF4Tensor])
+        safe_globals.extend([UintxTensor, UintxAQTTensorImpl, NF4Tensor])
+        # note: is_torchao_version(">=", "0.16.0") does not work correctly
+        # with torchao nightly, so using a ">" check which does work correctly
+        if is_torchao_version(">", "0.15.0"):
+            pass
+        else:
+            from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl
+            from torchao.dtypes.uintx.uint4_layout import UInt4Tensor
+
+            safe_globals.extend([UInt4Tensor, Float8AQTTensorImpl])
     except (ImportError, ModuleNotFoundError) as e:
         logger.warning(
@@ -123,9 +134,6 @@ def fuzzy_match_size(config_name: str) -> Optional[str]:
     return None

-logger = logging.get_logger(__name__)
-
 def _quantization_type(weight):
     from torchao.dtypes import AffineQuantizedTensor
     from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
```
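The `"> 0.15.0"` workaround noted in the hunk above follows from PEP 440 version ordering, which `packaging` implements: a dev (nightly) release sorts *before* the release it precedes, so a `>=` check against `0.16.0` misclassifies a `0.16.0` nightly, while a `>` check against the previous release classifies it correctly.

```python
# PEP 440 ordering: a dev release sorts before its base release,
# which is why a ">= 0.16.0" check fails on a 0.16.0 nightly build.
from packaging.version import Version

nightly = Version("0.16.0.dev20251229+cu129")

print(nightly >= Version("0.16.0"))  # False: the nightly precedes the 0.16.0 release
print(nightly > Version("0.15.0"))   # True: the nightly is newer than 0.15.0
```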