mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-20 03:14:43 +08:00
Compare commits
1 Commits
qwenimage-
...
torchao-of
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
03584a8174 |
@@ -73,7 +73,7 @@ if is_torchao_available():
|
|||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_torch_accelerator
|
@require_torch_accelerator
|
||||||
@require_torchao_version_greater_or_equal("0.7.0")
|
@require_torchao_version_greater_or_equal("0.14.0")
|
||||||
class TorchAoConfigTest(unittest.TestCase):
|
class TorchAoConfigTest(unittest.TestCase):
|
||||||
def test_to_dict(self):
|
def test_to_dict(self):
|
||||||
"""
|
"""
|
||||||
@@ -131,7 +131,7 @@ class TorchAoConfigTest(unittest.TestCase):
|
|||||||
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
|
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_torch_accelerator
|
@require_torch_accelerator
|
||||||
@require_torchao_version_greater_or_equal("0.7.0")
|
@require_torchao_version_greater_or_equal("0.14.0")
|
||||||
class TorchAoTest(unittest.TestCase):
|
class TorchAoTest(unittest.TestCase):
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
gc.collect()
|
gc.collect()
|
||||||
@@ -540,7 +540,7 @@ class TorchAoTest(unittest.TestCase):
|
|||||||
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
|
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_torch_accelerator
|
@require_torch_accelerator
|
||||||
@require_torchao_version_greater_or_equal("0.7.0")
|
@require_torchao_version_greater_or_equal("0.14.0")
|
||||||
class TorchAoSerializationTest(unittest.TestCase):
|
class TorchAoSerializationTest(unittest.TestCase):
|
||||||
model_name = "hf-internal-testing/tiny-flux-pipe"
|
model_name = "hf-internal-testing/tiny-flux-pipe"
|
||||||
|
|
||||||
@@ -651,23 +651,22 @@ class TorchAoSerializationTest(unittest.TestCase):
|
|||||||
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
|
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
|
||||||
|
|
||||||
|
|
||||||
@require_torchao_version_greater_or_equal("0.7.0")
|
@require_torchao_version_greater_or_equal("0.14.0")
|
||||||
class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):
|
class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):
|
||||||
@property
|
@property
|
||||||
def quantization_config(self):
|
def quantization_config(self):
|
||||||
return PipelineQuantizationConfig(
|
return PipelineQuantizationConfig(
|
||||||
quant_mapping={
|
quant_mapping={"transformer": TorchAoConfig(Int8WeightOnlyConfig())},
|
||||||
"transformer": TorchAoConfig(quant_type="int8_weight_only"),
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@unittest.skip(
|
|
||||||
"Changing the device of AQT tensor with module._apply (called from doing module.to() in accelerate) does not work "
|
|
||||||
"when compiling."
|
|
||||||
)
|
|
||||||
def test_torch_compile_with_cpu_offload(self):
|
def test_torch_compile_with_cpu_offload(self):
|
||||||
|
pipe = self._init_pipeline(self.quantization_config, torch.bfloat16)
|
||||||
|
pipe.enable_model_cpu_offload()
|
||||||
|
# No compilation because it fails with:
|
||||||
# RuntimeError: _apply(): Couldn't swap Linear.weight
|
# RuntimeError: _apply(): Couldn't swap Linear.weight
|
||||||
super().test_torch_compile_with_cpu_offload()
|
|
||||||
|
# small resolutions to ensure speedy execution.
|
||||||
|
pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
|
||||||
|
|
||||||
@parameterized.expand([False, True])
|
@parameterized.expand([False, True])
|
||||||
@unittest.skip(
|
@unittest.skip(
|
||||||
@@ -698,7 +697,7 @@ class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):
|
|||||||
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
|
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_torch_accelerator
|
@require_torch_accelerator
|
||||||
@require_torchao_version_greater_or_equal("0.7.0")
|
@require_torchao_version_greater_or_equal("0.14.0")
|
||||||
@slow
|
@slow
|
||||||
@nightly
|
@nightly
|
||||||
class SlowTorchAoTests(unittest.TestCase):
|
class SlowTorchAoTests(unittest.TestCase):
|
||||||
@@ -857,7 +856,7 @@ class SlowTorchAoTests(unittest.TestCase):
|
|||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_torch_accelerator
|
@require_torch_accelerator
|
||||||
@require_torchao_version_greater_or_equal("0.7.0")
|
@require_torchao_version_greater_or_equal("0.14.0")
|
||||||
@slow
|
@slow
|
||||||
@nightly
|
@nightly
|
||||||
class SlowTorchAoPreserializedModelTests(unittest.TestCase):
|
class SlowTorchAoPreserializedModelTests(unittest.TestCase):
|
||||||
|
|||||||
Reference in New Issue
Block a user