Compare commits

...

2 Commits

Author SHA1 Message Date
YiYi Xu
7eb2d2208e Merge branch 'main' into fix-test 2024-03-31 22:07:28 -10:00
yiyixu
d97bca56ab fix 2024-04-01 07:52:45 +00:00

View File

@@ -1144,20 +1144,24 @@ class PipelineTesterMixin:
self.assertLess(
max_diff, expected_max_diff, "running CPU offloading 2nd time should not affect the inference results"
)
offloaded_modules = [
v
offloaded_modules = {
k: v
for k, v in pipe.components.items()
if isinstance(v, torch.nn.Module) and k not in pipe._exclude_from_cpu_offload
]
(
self.assertTrue(all(v.device.type == "cpu" for v in offloaded_modules)),
f"Not offloaded: {[v for v in offloaded_modules if v.device.type != 'cpu']}",
}
self.assertTrue(
all(v.device.type == "cpu" for v in offloaded_modules.values()),
f"Not offloaded: {[k for k, v in offloaded_modules.items() if v.device.type != 'cpu']}",
)
offloaded_modules_with_hooks = [v for v in offloaded_modules if hasattr(v, "_hf_hook")]
(
self.assertTrue(all(isinstance(v, accelerate.hooks.CpuOffload) for v in offloaded_modules_with_hooks)),
f"Not installed correct hook: {[v for v in offloaded_modules_with_hooks if not isinstance(v, accelerate.hooks.CpuOffload)]}",
offloaded_modules_with_incorrect_hooks = {}
for k, v in offloaded_modules.items():
if hasattr(v, "_hf_hook") and not isinstance(v._hf_hook, accelerate.hooks.CpuOffload):
offloaded_modules_with_incorrect_hooks[k] = type(v._hf_hook)
self.assertTrue(
len(offloaded_modules_with_incorrect_hooks) == 0,
f"Not installed correct hook: {offloaded_modules_with_incorrect_hooks}",
)
@unittest.skipIf(
@@ -1189,22 +1193,23 @@ class PipelineTesterMixin:
self.assertLess(
max_diff, expected_max_diff, "running sequential offloading second time should have the inference results"
)
offloaded_modules = [
v
offloaded_modules = {
k: v
for k, v in pipe.components.items()
if isinstance(v, torch.nn.Module) and k not in pipe._exclude_from_cpu_offload
]
(
self.assertTrue(all(v.device.type == "meta" for v in offloaded_modules)),
f"Not offloaded: {[v for v in offloaded_modules if v.device.type != 'meta']}",
}
self.assertTrue(
all(v.device.type == "meta" for v in offloaded_modules.values()),
f"Not offloaded: {[k for k, v in offloaded_modules.items() if v.device.type != 'meta']}",
)
offloaded_modules_with_incorrect_hooks = {}
for k, v in offloaded_modules.items():
if hasattr(v, "_hf_hook") and not isinstance(v._hf_hook, accelerate.hooks.AlignDevicesHook):
offloaded_modules_with_incorrect_hooks[k] = type(v._hf_hook)
offloaded_modules_with_hooks = [v for v in offloaded_modules if hasattr(v, "_hf_hook")]
(
self.assertTrue(
all(isinstance(v, accelerate.hooks.AlignDevicesHook) for v in offloaded_modules_with_hooks)
),
f"Not installed correct hook: {[v for v in offloaded_modules_with_hooks if not isinstance(v, accelerate.hooks.AlignDevicesHook)]}",
self.assertTrue(
len(offloaded_modules_with_incorrect_hooks) == 0,
f"Not installed correct hook: {offloaded_modules_with_incorrect_hooks}",
)
@unittest.skipIf(