mirror of
https://github.com/huggingface/diffusers.git
synced 2026-03-29 20:07:48 +08:00
Compare commits
2 Commits
revisit-de
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f2be8bd6b3 | ||
|
|
7da22b9db5 |
3
.github/workflows/claude_review.yml
vendored
3
.github/workflows/claude_review.yml
vendored
@@ -32,6 +32,9 @@ jobs:
|
||||
)
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 1
|
||||
- uses: anthropics/claude-code-action@v1
|
||||
with:
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
|
||||
1
.github/workflows/pr_dependency_test.yml
vendored
1
.github/workflows/pr_dependency_test.yml
vendored
@@ -6,7 +6,6 @@ on:
|
||||
- main
|
||||
paths:
|
||||
- "src/diffusers/**.py"
|
||||
- "tests/**.py"
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
|
||||
@@ -6,7 +6,6 @@ on:
|
||||
- main
|
||||
paths:
|
||||
- "src/diffusers/**.py"
|
||||
- "tests/**.py"
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
@@ -27,7 +26,7 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install -e .
|
||||
pip install torch pytest
|
||||
pip install torch torchvision torchaudio pytest
|
||||
- name: Check for soft dependencies
|
||||
run: |
|
||||
pytest tests/others/test_dependencies.py
|
||||
|
||||
@@ -5,13 +5,10 @@ import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image, ImageOps
|
||||
from torchvision.transforms import InterpolationMode
|
||||
from torchvision.transforms.functional import normalize, resize
|
||||
|
||||
from ...utils import get_logger, is_torchvision_available, load_image
|
||||
|
||||
|
||||
if is_torchvision_available():
|
||||
from torchvision.transforms import InterpolationMode
|
||||
from torchvision.transforms.functional import normalize, resize
|
||||
from ...utils import get_logger, load_image
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -470,8 +470,8 @@ class TorchAoConfig(QuantizationConfigMixin):
|
||||
self.post_init()
|
||||
|
||||
def post_init(self):
|
||||
if is_torchao_version("<=", "0.9.0"):
|
||||
raise ValueError("TorchAoConfig requires torchao > 0.9.0. Please upgrade with `pip install -U torchao`.")
|
||||
if is_torchao_version("<", "0.15.0"):
|
||||
raise ValueError("TorchAoConfig requires torchao >= 0.15.0. Please upgrade with `pip install -U torchao`.")
|
||||
|
||||
from torchao.quantization.quant_api import AOBaseConfig
|
||||
|
||||
@@ -495,8 +495,8 @@ class TorchAoConfig(QuantizationConfigMixin):
|
||||
@classmethod
|
||||
def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
|
||||
"""Create configuration from a dictionary."""
|
||||
if not is_torchao_version(">", "0.9.0"):
|
||||
raise NotImplementedError("TorchAoConfig requires torchao > 0.9.0 for construction from dict")
|
||||
if not is_torchao_version(">=", "0.15.0"):
|
||||
raise NotImplementedError("TorchAoConfig requires torchao >= 0.15.0 for construction from dict")
|
||||
config_dict = config_dict.copy()
|
||||
quant_type = config_dict.pop("quant_type")
|
||||
|
||||
|
||||
@@ -113,7 +113,7 @@ if (
|
||||
is_torch_available()
|
||||
and is_torch_version(">=", "2.6.0")
|
||||
and is_torchao_available()
|
||||
and is_torchao_version(">=", "0.7.0")
|
||||
and is_torchao_version(">=", "0.15.0")
|
||||
):
|
||||
_update_torch_safe_globals()
|
||||
|
||||
@@ -168,10 +168,10 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
|
||||
raise ImportError(
|
||||
"Loading a TorchAO quantized model requires the torchao library. Please install with `pip install torchao`"
|
||||
)
|
||||
torchao_version = version.parse(importlib.metadata.version("torch"))
|
||||
if torchao_version < version.parse("0.7.0"):
|
||||
torchao_version = version.parse(importlib.metadata.version("torchao"))
|
||||
if torchao_version < version.parse("0.15.0"):
|
||||
raise RuntimeError(
|
||||
f"The minimum required version of `torchao` is 0.7.0, but the current version is {torchao_version}. Please upgrade with `pip install -U torchao`."
|
||||
f"The minimum required version of `torchao` is 0.15.0, but the current version is {torchao_version}. Please upgrade with `pip install -U torchao`."
|
||||
)
|
||||
|
||||
self.offload = False
|
||||
|
||||
@@ -13,14 +13,16 @@
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import unittest
|
||||
from importlib import import_module
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestDependencies:
|
||||
class DependencyTester(unittest.TestCase):
|
||||
def test_diffusers_import(self):
|
||||
import diffusers # noqa: F401
|
||||
try:
|
||||
import diffusers # noqa: F401
|
||||
except ImportError:
|
||||
assert False
|
||||
|
||||
def test_backend_registration(self):
|
||||
import diffusers
|
||||
@@ -50,36 +52,3 @@ class TestDependencies:
|
||||
if hasattr(diffusers.pipelines, cls_name):
|
||||
pipeline_folder_module = ".".join(str(cls_module.__module__).split(".")[:3])
|
||||
_ = import_module(pipeline_folder_module, str(cls_name))
|
||||
|
||||
def test_pipeline_module_imports(self):
|
||||
"""Import every pipeline submodule whose dependencies are satisfied,
|
||||
to catch unguarded optional-dep imports (e.g., torchvision).
|
||||
|
||||
Uses inspect.getmembers to discover classes that the lazy loader can
|
||||
actually resolve (same self-filtering as test_pipeline_imports), then
|
||||
imports the full module path instead of truncating to the folder level.
|
||||
"""
|
||||
import diffusers
|
||||
import diffusers.pipelines
|
||||
|
||||
failures = []
|
||||
all_classes = inspect.getmembers(diffusers, inspect.isclass)
|
||||
|
||||
for cls_name, cls_module in all_classes:
|
||||
if not hasattr(diffusers.pipelines, cls_name):
|
||||
continue
|
||||
if "dummy_" in cls_module.__module__:
|
||||
continue
|
||||
|
||||
full_module_path = cls_module.__module__
|
||||
try:
|
||||
import_module(full_module_path)
|
||||
except ImportError as e:
|
||||
failures.append(f"{full_module_path}: {e}")
|
||||
except Exception:
|
||||
# Non-import errors (e.g., missing config) are fine; we only
|
||||
# care about unguarded import statements.
|
||||
pass
|
||||
|
||||
if failures:
|
||||
pytest.fail("Unguarded optional-dependency imports found:\n" + "\n".join(failures))
|
||||
|
||||
@@ -14,13 +14,11 @@
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import importlib.metadata
|
||||
import tempfile
|
||||
import unittest
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
from packaging import version
|
||||
from parameterized import parameterized
|
||||
from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
|
||||
|
||||
@@ -82,18 +80,17 @@ if is_torchao_available():
|
||||
Float8WeightOnlyConfig,
|
||||
Int4WeightOnlyConfig,
|
||||
Int8DynamicActivationInt8WeightConfig,
|
||||
Int8DynamicActivationIntxWeightConfig,
|
||||
Int8WeightOnlyConfig,
|
||||
IntxWeightOnlyConfig,
|
||||
)
|
||||
from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
|
||||
from torchao.utils import get_model_size_in_bytes
|
||||
|
||||
if version.parse(importlib.metadata.version("torchao")) >= version.Version("0.10.0"):
|
||||
from torchao.quantization import Int8DynamicActivationIntxWeightConfig, IntxWeightOnlyConfig
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_torch_accelerator
|
||||
@require_torchao_version_greater_or_equal("0.14.0")
|
||||
@require_torchao_version_greater_or_equal("0.15.0")
|
||||
class TorchAoConfigTest(unittest.TestCase):
|
||||
def test_to_dict(self):
|
||||
"""
|
||||
@@ -128,7 +125,7 @@ class TorchAoConfigTest(unittest.TestCase):
|
||||
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
|
||||
@require_torch
|
||||
@require_torch_accelerator
|
||||
@require_torchao_version_greater_or_equal("0.14.0")
|
||||
@require_torchao_version_greater_or_equal("0.15.0")
|
||||
class TorchAoTest(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
gc.collect()
|
||||
@@ -527,7 +524,7 @@ class TorchAoTest(unittest.TestCase):
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
_ = pipe(**inputs)
|
||||
|
||||
@require_torchao_version_greater_or_equal("0.9.0")
|
||||
@require_torchao_version_greater_or_equal("0.15.0")
|
||||
def test_aobase_config(self):
|
||||
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
|
||||
components = self.get_dummy_components(quantization_config)
|
||||
@@ -540,7 +537,7 @@ class TorchAoTest(unittest.TestCase):
|
||||
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
|
||||
@require_torch
|
||||
@require_torch_accelerator
|
||||
@require_torchao_version_greater_or_equal("0.14.0")
|
||||
@require_torchao_version_greater_or_equal("0.15.0")
|
||||
class TorchAoSerializationTest(unittest.TestCase):
|
||||
model_name = "hf-internal-testing/tiny-flux-pipe"
|
||||
|
||||
@@ -650,7 +647,7 @@ class TorchAoSerializationTest(unittest.TestCase):
|
||||
self._check_serialization_expected_slice(quant_type, expected_slice, device)
|
||||
|
||||
|
||||
@require_torchao_version_greater_or_equal("0.14.0")
|
||||
@require_torchao_version_greater_or_equal("0.15.0")
|
||||
class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):
|
||||
@property
|
||||
def quantization_config(self):
|
||||
@@ -696,7 +693,7 @@ class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):
|
||||
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
|
||||
@require_torch
|
||||
@require_torch_accelerator
|
||||
@require_torchao_version_greater_or_equal("0.14.0")
|
||||
@require_torchao_version_greater_or_equal("0.15.0")
|
||||
@slow
|
||||
@nightly
|
||||
class SlowTorchAoTests(unittest.TestCase):
|
||||
@@ -854,7 +851,7 @@ class SlowTorchAoTests(unittest.TestCase):
|
||||
|
||||
@require_torch
|
||||
@require_torch_accelerator
|
||||
@require_torchao_version_greater_or_equal("0.14.0")
|
||||
@require_torchao_version_greater_or_equal("0.15.0")
|
||||
@slow
|
||||
@nightly
|
||||
class SlowTorchAoPreserializedModelTests(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user