Mirror of https://github.com/huggingface/diffusers.git
Synced 2025-12-20 03:14:43 +08:00

Compare commits (4 commits)
edit-pypi-... ... sanitize-p...
| Author | SHA1 | Date |
|---|---|---|
| | 045f3ade68 | |
| | 7ea98accb0 | |
| | 43cae1a613 | |
| | 1d9bf41cf9 | |
@@ -14,6 +14,7 @@ import torch
 import torch.nn as nn
 from huggingface_hub import ModelCard, delete_repo
 from huggingface_hub.utils import is_jinja_available
+from parameterized import parameterized
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
 
 import diffusers
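The `parameterized` import added above backs the rewritten offloading test further down, which uses `@parameterized.expand` to run once per offload type. A minimal, self-contained sketch of that pattern (illustrative only, not part of this diff):

```python
import unittest

from parameterized import parameterized


class OffloadTypeTest(unittest.TestCase):
    # Each entry becomes its own test case, so "block_level" and "leaf_level"
    # report (and can fail) independently, as in the rewritten diffusers test.
    @parameterized.expand([("block_level",), ("leaf_level",)])
    def test_offload_type_is_recognized(self, offload_type):
        self.assertIn(offload_type, {"block_level", "leaf_level"})


if __name__ == "__main__":
    unittest.main()
```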
@@ -32,7 +33,6 @@ from diffusers import (
     UNet2DConditionModel,
     apply_faster_cache,
 )
-from diffusers.hooks import apply_group_offloading
 from diffusers.hooks.faster_cache import FasterCacheBlockHook, FasterCacheDenoiserHook
 from diffusers.hooks.first_block_cache import FirstBlockCacheConfig
 from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook
@@ -2244,80 +2244,6 @@ class PipelineTesterMixin:
         inputs = self.get_dummy_inputs(torch_device)
         _ = pipe(**inputs)[0]
 
-    @require_torch_accelerator
-    def test_group_offloading_inference(self):
-        if not self.test_group_offloading:
-            return
-
-        def create_pipe():
-            torch.manual_seed(0)
-            components = self.get_dummy_components()
-            pipe = self.pipeline_class(**components)
-            pipe.set_progress_bar_config(disable=None)
-            return pipe
-
-        def enable_group_offload_on_component(pipe, group_offloading_kwargs):
-            # We intentionally don't test VAE's here. This is because some tests enable tiling on the VAE. If
-            # tiling is enabled and a forward pass is run, when accelerator streams are used, the execution order of
-            # the layers is not traced correctly. This causes errors. For apply group offloading to VAE, a
-            # warmup forward pass (even with dummy small inputs) is recommended.
-            for component_name in [
-                "text_encoder",
-                "text_encoder_2",
-                "text_encoder_3",
-                "transformer",
-                "unet",
-                "controlnet",
-            ]:
-                if not hasattr(pipe, component_name):
-                    continue
-                component = getattr(pipe, component_name)
-                if not getattr(component, "_supports_group_offloading", True):
-                    continue
-                if hasattr(component, "enable_group_offload"):
-                    # For diffusers ModelMixin implementations
-                    component.enable_group_offload(torch.device(torch_device), **group_offloading_kwargs)
-                else:
-                    # For other models not part of diffusers
-                    apply_group_offloading(
-                        component, onload_device=torch.device(torch_device), **group_offloading_kwargs
-                    )
-                self.assertTrue(
-                    all(
-                        module._diffusers_hook.get_hook("group_offloading") is not None
-                        for module in component.modules()
-                        if hasattr(module, "_diffusers_hook")
-                    )
-                )
-            for component_name in ["vae", "vqvae", "image_encoder"]:
-                component = getattr(pipe, component_name, None)
-                if isinstance(component, torch.nn.Module):
-                    component.to(torch_device)
-
-        def run_forward(pipe):
-            torch.manual_seed(0)
-            inputs = self.get_dummy_inputs(torch_device)
-            return pipe(**inputs)[0]
-
-        pipe = create_pipe().to(torch_device)
-        output_without_group_offloading = run_forward(pipe)
-
-        pipe = create_pipe()
-        enable_group_offload_on_component(pipe, {"offload_type": "block_level", "num_blocks_per_group": 1})
-        output_with_group_offloading1 = run_forward(pipe)
-
-        pipe = create_pipe()
-        enable_group_offload_on_component(pipe, {"offload_type": "leaf_level"})
-        output_with_group_offloading2 = run_forward(pipe)
-
-        if torch.is_tensor(output_without_group_offloading):
-            output_without_group_offloading = output_without_group_offloading.detach().cpu().numpy()
-            output_with_group_offloading1 = output_with_group_offloading1.detach().cpu().numpy()
-            output_with_group_offloading2 = output_with_group_offloading2.detach().cpu().numpy()
-
-        self.assertTrue(np.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-4))
-        self.assertTrue(np.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-4))
-
     def test_torch_dtype_dict(self):
         components = self.get_dummy_components()
         if not components:
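The method removed above applied group offloading per component: `enable_group_offload` for diffusers `ModelMixin` models, `apply_group_offloading` for everything else, followed by a check that a `group_offloading` hook was installed on the hooked submodules. A standalone sketch of that functional API (illustrative only; assumes a CUDA device and a diffusers release that ships `diffusers.hooks.apply_group_offloading`):

```python
import torch
import torch.nn as nn

from diffusers.hooks import apply_group_offloading

# A stand-in module; the removed test applied the same call to components that are
# not diffusers ModelMixin subclasses.
model = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 16))

apply_group_offloading(
    model,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
)

# Same hook check the test performed: every hooked submodule should expose a
# "group_offloading" hook through its _diffusers_hook registry.
assert all(
    m._diffusers_hook.get_hook("group_offloading") is not None
    for m in model.modules()
    if hasattr(m, "_diffusers_hook")
)

# Inputs are expected on the onload device; weights are streamed there group by group.
with torch.no_grad():
    _ = model(torch.randn(2, 16, device="cuda"))
```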
@@ -2364,7 +2290,7 @@
         self.assertLess(max_diff, expected_max_difference)
 
     @require_torch_accelerator
-    def test_pipeline_level_group_offloading_sanity_checks(self):
+    def test_group_offloading_sanity_checks(self):
         components = self.get_dummy_components()
         pipe: DiffusionPipeline = self.pipeline_class(**components)
 
@@ -2394,41 +2320,61 @@
             component_device = next(component.parameters())[0].device
             self.assertTrue(torch.device(component_device).type == torch.device(offload_device).type)
 
+    @parameterized.expand([("block_level"), ("leaf_level")])
     @require_torch_accelerator
-    def test_pipeline_level_group_offloading_inference(self, expected_max_difference=1e-4):
-        components = self.get_dummy_components()
-        pipe: DiffusionPipeline = self.pipeline_class(**components)
+    def test_group_offloading_inference(self, offload_type: str = "block_level"):
+        if not self.test_group_offloading:
+            pytest.skip("`test_group_offloading` is disabled hence skipping.")
 
-        for name, component in pipe.components.items():
-            if hasattr(component, "_supports_group_offloading"):
-                if not component._supports_group_offloading:
-                    pytest.skip(f"{self.pipeline_class.__name__} is not suitable for this test.")
+        def create_pipe():
+            torch.manual_seed(0)
+            components = self.get_dummy_components()
+            pipe = self.pipeline_class(**components)
+            pipe.set_progress_bar_config(disable=None)
+            return pipe
 
-        # Regular inference.
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-        torch.manual_seed(0)
-        inputs = self.get_dummy_inputs(torch_device)
-        inputs["generator"] = torch.manual_seed(0)
-        out = pipe(**inputs)[0]
+        def enable_group_offload(pipe, group_offloading_kwargs):
+            # We intentionally don't test VAE's here. This is because some tests enable tiling on the VAE. If
+            # tiling is enabled and a forward pass is run, when accelerator streams are used, the execution order of
+            # the layers is not traced correctly. This causes errors. For apply group offloading to VAE, a
+            # warmup forward pass (even with dummy small inputs) is recommended.
+            exclude_modules = {"vae", "vqvae", "image_encoder"}
+            exclude_modules = list(exclude_modules & set(pipe.components.keys()))
+            pipe.enable_group_offload(
+                exclude_modules=exclude_modules, onload_device=torch_device, **group_offloading_kwargs
+            )
+            for component_name, component in pipe.components.items():
+                if component_name in exclude_modules:
+                    continue
+                elif isinstance(component, torch.nn.Module):
+                    assert all(
+                        module._diffusers_hook.get_hook("group_offloading") is not None
+                        for module in component.modules()
+                        if hasattr(module, "_diffusers_hook")
+                    )
 
-        pipe.to("cpu")
-        del pipe
+        def run_forward(pipe):
+            torch.manual_seed(0)
+            inputs = self.get_dummy_inputs(torch_device)
+            return pipe(**inputs)[0]
 
-        # Inference with offloading
-        pipe: DiffusionPipeline = self.pipeline_class(**components)
-        offload_device = "cpu"
-        pipe.enable_group_offload(
-            onload_device=torch_device,
-            offload_device=offload_device,
-            offload_type="leaf_level",
-        )
-        pipe.set_progress_bar_config(disable=None)
-        inputs["generator"] = torch.manual_seed(0)
-        out_offload = pipe(**inputs)[0]
+        pipe = create_pipe().to(torch_device)
+        output_without_group_offloading = run_forward(pipe)
 
-        max_diff = np.abs(to_np(out) - to_np(out_offload)).max()
-        self.assertLess(max_diff, expected_max_difference)
+        pipe = create_pipe()
+        if offload_type == "block_level":
+            offloading_kwargs = {"offload_type": "block_level", "num_blocks_per_group": 1}
+        else:
+            offloading_kwargs = {"offload_type": "leaf_level"}
+        enable_group_offload(pipe, offloading_kwargs)
+
+        output_with_group_offloading = run_forward(pipe)
+
+        if torch.is_tensor(output_without_group_offloading):
+            output_without_group_offloading = output_without_group_offloading.detach().cpu().numpy()
+            output_with_group_offloading = output_with_group_offloading.detach().cpu().numpy()
+
+        assert np.allclose(output_without_group_offloading, output_with_group_offloading, atol=1e-4)
 
 
 @is_staging_test
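The new `test_group_offloading_inference` drives the pipeline-level entry point instead of per-component calls, excluding the VAE for the tiling reason noted in the comment. A usage sketch outside the test harness (illustrative only; the checkpoint, devices, and explicit VAE placement are assumptions, with the `exclude_modules` argument used as in the diff):

```python
import torch

from diffusers import DiffusionPipeline

# Checkpoint chosen purely for illustration; any pipeline whose components support
# group offloading should work the same way.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)

# Offload supported components at leaf level, mirroring the test's exclude list.
pipe.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    exclude_modules=["vae"],
)
# Depending on the diffusers version, excluded modules may need to be placed on the
# accelerator manually (the removed component-level test did this explicitly).
pipe.vae.to("cuda")

image = pipe("an astronaut riding a horse", num_inference_steps=20).images[0]
image.save("astronaut.png")
```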