Compare commits

...

2 Commits

Author SHA1 Message Date
sayakpaul
3af9470115 up 2025-12-30 14:09:56 +05:30
Vasiliy Kuznetsov
1cdb8723b8 fix torchao quantizer for new torchao versions (#12901)
* fix torchao quantizer for new torchao versions

Summary:

`torchao==0.16.0` (not yet released) has some bc-breaking changes, this
PR fixes the diffusers repo with those changes. Specifics on the
changes:
1. `UInt4Tensor` is removed: https://github.com/pytorch/ao/pull/3536
2. old float8 tensors v1 are removed: https://github.com/pytorch/ao/pull/3510

In this PR:
1. move the logger variable up (not sure why it was in the middle of the
   file before) to get better error messages
2. gate the old torchao objects by torchao version

Test Plan:

import diffusers objects with new versions of torchao works:

```bash
> python -c "import torchao; print(torchao.__version__); from diffusers import StableDiffusionPipeline"
0.16.0.dev20251229+cu129
```

Reviewers:

Subscribers:

Tasks:

Tags:

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-12-30 10:04:54 +05:30
2 changed files with 16 additions and 8 deletions

View File

@@ -263,8 +263,8 @@ def main():
world_size = dist.get_world_size()
pipeline = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, device_map=device
)
"black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to(device)
pipeline.transformer.set_attention_backend("_native_cudnn")
cp_config = ContextParallelConfig(ring_degree=world_size)

View File

@@ -36,6 +36,9 @@ from ...utils import (
from ..base import DiffusersQuantizer
logger = logging.get_logger(__name__)
if TYPE_CHECKING:
from ...models.modeling_utils import ModelMixin
@@ -83,11 +86,19 @@ def _update_torch_safe_globals():
]
try:
from torchao.dtypes import NF4Tensor
from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl
from torchao.dtypes.uintx.uint4_layout import UInt4Tensor
from torchao.dtypes.uintx.uintx_layout import UintxAQTTensorImpl, UintxTensor
safe_globals.extend([UintxTensor, UInt4Tensor, UintxAQTTensorImpl, Float8AQTTensorImpl, NF4Tensor])
safe_globals.extend([UintxTensor, UintxAQTTensorImpl, NF4Tensor])
# note: is_torchao_version(">=", "0.16.0") does not work correctly
# with torchao nightly, so using a ">" check which does work correctly
if is_torchao_version(">", "0.15.0"):
pass
else:
from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl
from torchao.dtypes.uintx.uint4_layout import UInt4Tensor
safe_globals.extend([UInt4Tensor, Float8AQTTensorImpl])
except (ImportError, ModuleNotFoundError) as e:
logger.warning(
@@ -123,9 +134,6 @@ def fuzzy_match_size(config_name: str) -> Optional[str]:
return None
logger = logging.get_logger(__name__)
def _quantization_type(weight):
from torchao.dtypes import AffineQuantizedTensor
from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor