mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-13 16:04:41 +08:00
Compare commits
2 Commits
single-fil
...
fix-test
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7eb2d2208e | ||
|
|
d97bca56ab |
@@ -1144,20 +1144,24 @@ class PipelineTesterMixin:
|
|||||||
self.assertLess(
|
self.assertLess(
|
||||||
max_diff, expected_max_diff, "running CPU offloading 2nd time should not affect the inference results"
|
max_diff, expected_max_diff, "running CPU offloading 2nd time should not affect the inference results"
|
||||||
)
|
)
|
||||||
offloaded_modules = [
|
offloaded_modules = {
|
||||||
v
|
k: v
|
||||||
for k, v in pipe.components.items()
|
for k, v in pipe.components.items()
|
||||||
if isinstance(v, torch.nn.Module) and k not in pipe._exclude_from_cpu_offload
|
if isinstance(v, torch.nn.Module) and k not in pipe._exclude_from_cpu_offload
|
||||||
]
|
}
|
||||||
(
|
self.assertTrue(
|
||||||
self.assertTrue(all(v.device.type == "cpu" for v in offloaded_modules)),
|
all(v.device.type == "cpu" for v in offloaded_modules.values()),
|
||||||
f"Not offloaded: {[v for v in offloaded_modules if v.device.type != 'cpu']}",
|
f"Not offloaded: {[k for k, v in offloaded_modules.items() if v.device.type != 'cpu']}",
|
||||||
)
|
)
|
||||||
|
|
||||||
offloaded_modules_with_hooks = [v for v in offloaded_modules if hasattr(v, "_hf_hook")]
|
offloaded_modules_with_incorrect_hooks = {}
|
||||||
(
|
for k, v in offloaded_modules.items():
|
||||||
self.assertTrue(all(isinstance(v, accelerate.hooks.CpuOffload) for v in offloaded_modules_with_hooks)),
|
if hasattr(v, "_hf_hook") and not isinstance(v._hf_hook, accelerate.hooks.CpuOffload):
|
||||||
f"Not installed correct hook: {[v for v in offloaded_modules_with_hooks if not isinstance(v, accelerate.hooks.CpuOffload)]}",
|
offloaded_modules_with_incorrect_hooks[k] = type(v._hf_hook)
|
||||||
|
|
||||||
|
self.assertTrue(
|
||||||
|
len(offloaded_modules_with_incorrect_hooks) == 0,
|
||||||
|
f"Not installed correct hook: {offloaded_modules_with_incorrect_hooks}",
|
||||||
)
|
)
|
||||||
|
|
||||||
@unittest.skipIf(
|
@unittest.skipIf(
|
||||||
@@ -1189,22 +1193,23 @@ class PipelineTesterMixin:
|
|||||||
self.assertLess(
|
self.assertLess(
|
||||||
max_diff, expected_max_diff, "running sequential offloading second time should have the inference results"
|
max_diff, expected_max_diff, "running sequential offloading second time should have the inference results"
|
||||||
)
|
)
|
||||||
offloaded_modules = [
|
offloaded_modules = {
|
||||||
v
|
k: v
|
||||||
for k, v in pipe.components.items()
|
for k, v in pipe.components.items()
|
||||||
if isinstance(v, torch.nn.Module) and k not in pipe._exclude_from_cpu_offload
|
if isinstance(v, torch.nn.Module) and k not in pipe._exclude_from_cpu_offload
|
||||||
]
|
}
|
||||||
(
|
self.assertTrue(
|
||||||
self.assertTrue(all(v.device.type == "meta" for v in offloaded_modules)),
|
all(v.device.type == "meta" for v in offloaded_modules.values()),
|
||||||
f"Not offloaded: {[v for v in offloaded_modules if v.device.type != 'meta']}",
|
f"Not offloaded: {[k for k, v in offloaded_modules.items() if v.device.type != 'meta']}",
|
||||||
)
|
)
|
||||||
|
offloaded_modules_with_incorrect_hooks = {}
|
||||||
|
for k, v in offloaded_modules.items():
|
||||||
|
if hasattr(v, "_hf_hook") and not isinstance(v._hf_hook, accelerate.hooks.AlignDevicesHook):
|
||||||
|
offloaded_modules_with_incorrect_hooks[k] = type(v._hf_hook)
|
||||||
|
|
||||||
offloaded_modules_with_hooks = [v for v in offloaded_modules if hasattr(v, "_hf_hook")]
|
self.assertTrue(
|
||||||
(
|
len(offloaded_modules_with_incorrect_hooks) == 0,
|
||||||
self.assertTrue(
|
f"Not installed correct hook: {offloaded_modules_with_incorrect_hooks}",
|
||||||
all(isinstance(v, accelerate.hooks.AlignDevicesHook) for v in offloaded_modules_with_hooks)
|
|
||||||
),
|
|
||||||
f"Not installed correct hook: {[v for v in offloaded_modules_with_hooks if not isinstance(v, accelerate.hooks.AlignDevicesHook)]}",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@unittest.skipIf(
|
@unittest.skipIf(
|
||||||
|
|||||||
Reference in New Issue
Block a user