Compare commits

1 Commits

Author  SHA1        Message  Date
Aryan   207fb07977  update   2025-02-18 17:36:06 +01:00

@@ -1507,6 +1507,49 @@ class ModelTesterMixin:
        self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5))
        self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5))

    def test_error_when_disk_offload_run_together_with_group_offloading(self):
        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model1 = self.model_class(**config).eval()
        model1 = model1.to(torch_device)

        def has_accelerate_hooks(module):
            # True if any submodule carries an accelerate device-placement
            # or CPU-offload hook.
            from accelerate.hooks import AlignDevicesHook, CpuOffload

            count = 0
            for name, submodule in module.named_modules():
                if not hasattr(submodule, "_hf_hook"):
                    continue
                if isinstance(submodule._hf_hook, (AlignDevicesHook, CpuOffload)):
                    print(f"Found {name} with hook {submodule._hf_hook}")
                    count += 1
            return count > 0

        model_size = compute_module_sizes(model1)[""]
        with tempfile.TemporaryDirectory() as tmp_dir:
            model1.cpu().save_pretrained(tmp_dir)

            # Cap both GPU and CPU memory so that part of the checkpoint is
            # offloaded to disk when loading with device_map="auto".
            max_size = int(self.model_split_percents[0] * model_size)
            max_memory = {0: max_size, "cpu": max_size}
            new_model = self.model_class.from_pretrained(
                tmp_dir, device_map="auto", offload_folder=tmp_dir, max_memory=max_memory
            )

            self.check_device_map_is_respected(new_model, new_model.hf_device_map)
            assert has_accelerate_hooks(new_model)

            del model1
            torch.cuda.synchronize()
            torch.cpu.synchronize()
            torch.cuda.empty_cache()
            gc.collect()

            model2 = self.model_class(**config)
            # ===============================
            # FIXME: we still have accelerate hooks on model2 in some cases.
            assert has_accelerate_hooks(model2)
            # ===============================
            model2.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1)

@is_staging_test
class ModelPushToHubTester(unittest.TestCase):
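
For context, the group offloading enabled at the end of the test moves groups of parameters between the offload and onload devices on demand, and appears intended for a model that does not already carry accelerate hooks. A minimal standalone sketch of the call under test, assuming a diffusers model class such as AutoencoderKL and an available CUDA device (both assumptions for illustration, not taken from this diff):

    import torch
    from diffusers import AutoencoderKL

    model = AutoencoderKL()  # hypothetical model choice for illustration
    # Onload one block of parameters at a time onto the GPU, keeping the
    # rest offloaded, mirroring the call made in the test above.
    model.enable_group_offload(torch.device("cuda"), offload_type="block_level", num_blocks_per_group=1)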