mirror of https://github.com/huggingface/diffusers.git
synced 2025-12-15 00:44:51 +08:00

Compare commits: diffusers-... vs. kig/sd-vae (19 commits)

Commit SHA1s:
- a4bc76ce75
- a03d41f344
- 8e7b8c218e
- 6adfedfabf
- 94781b6e7d
- 928c6d324f
- 541f27517d
- 20387d0319
- 2a403a173e
- 307fd12567
- 2b0454d8ad
- 0a96a8184f
- c99dbb6ff8
- ac8b1c2eeb
- 4b6536d49a
- 14215bf3fd
- 626fb88a9d
- 63d5661b47
- 49b61c8e5b
@@ -37,3 +37,5 @@ Available Checkpoints are:
 - disable_vae_slicing
 - enable_xformers_memory_efficient_attention
 - disable_xformers_memory_efficient_attention
+- enable_vae_tiling
+- disable_vae_tiling

@@ -133,6 +133,33 @@ images = pipe([prompt] * 32).images
 
 You may see a small performance boost in VAE decode on multi-image batches. There should be no performance impact on single-image batches.
 
+## Tiled VAE decode and encode for large images
+
+Tiled VAE processing makes it possible to work with large images on limited VRAM, for example generating 4k images in 8GB of VRAM. The tiled VAE decoder splits the image into overlapping tiles, decodes each tile, and blends the outputs to assemble the final image.
+
+You may want to pair this with [`~StableDiffusionPipeline.enable_attention_slicing`] or [`~StableDiffusionPipeline.enable_xformers_memory_efficient_attention`] to further minimize memory use.
+
+To use tiled VAE processing, invoke [`~StableDiffusionPipeline.enable_vae_tiling`] on your pipeline before inference. For example:
+
+```Python
+import torch
+from diffusers import StableDiffusionPipeline
+
+pipe = StableDiffusionPipeline.from_pretrained(
+    "runwayml/stable-diffusion-v1-5",
+    torch_dtype=torch.float16,
+)
+pipe = pipe.to("cuda")
+
+prompt = "a beautiful landscape photograph"
+pipe.enable_vae_tiling()
+pipe.enable_xformers_memory_efficient_attention()
+
+images = pipe([prompt], width=3840, height=2224).images
+```
+
+The output image will show some tile-to-tile tone variation because the tiles are decoded separately, but you shouldn't see sharp seams between them. Tiling is turned off for images that are 512x512 or smaller.
+
 <a name="sequential_offloading"></a>
 ## Offloading to CPU with accelerate for memory savings
 
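To make the 4k example above concrete, here is a rough back-of-the-envelope sketch (not taken from the diff) of how such an image gets tiled, assuming the usual 8x VAE downsampling factor and the 64x64-latent tiles with stride 48 that the `tiled_decode` hunk further down introduces:

```Python
import math

# Pixel size from the documentation example above.
width, height = 3840, 2224

# The SD VAE downsamples by 8x, so the latent grid is 480 x 278 (width x height).
latent_w, latent_h = width // 8, height // 8

# tiled_decode in this PR walks the latent in 64x64 tiles with a stride of 48,
# i.e. 16 latents (128 pixels) of overlap between neighbouring tiles.
stride = 48

tiles_x = math.ceil(latent_w / stride)  # 10
tiles_y = math.ceil(latent_h / stride)  # 6

print(f"{tiles_x} x {tiles_y} = {tiles_x * tiles_y} tiles")  # 10 x 6 = 60 tiles
# Each tile is decoded independently, so peak decoder memory is bounded by a single
# 64x64-latent (512x512-pixel) tile instead of the full 480x278-latent image.
```
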
@@ -109,8 +109,40 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
         self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
         self.use_slicing = False
 
+    def enable_tiling(self, use_tiling: bool = True):
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful to save a large amount of memory and to allow
+        the processing of larger images.
+        """
+        self.use_tiling = use_tiling
+
+    def disable_tiling(self):
+        r"""
+        Disable tiled VAE decoding. If `enable_tiling` was previously invoked, this method will go back to computing
+        decoding in one step.
+        """
+        self.enable_tiling(False)
+
+    def enable_slicing(self):
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.use_slicing = True
+
+    def disable_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing
+        decoding in one step.
+        """
+        self.use_slicing = False
+
     @apply_forward_hook
     def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
+        if self.use_tiling and (x.shape[-1] > 512 or x.shape[-2] > 512):
+            return self.tiled_encode(x, return_dict=return_dict)
+
         h = self.encoder(x)
         moments = self.quant_conv(h)
         posterior = DiagonalGaussianDistribution(moments)
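The new switches live on the model itself, so they can also be toggled on a standalone `AutoencoderKL` outside of any pipeline. A minimal sketch under that assumption; the checkpoint, dtype, and input size are illustrative, not from the diff:

```Python
# Sketch (not part of the diff): using the new enable_tiling() flag on a standalone
# AutoencoderKL. Checkpoint and image size are illustrative only.
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="vae", torch_dtype=torch.float16
).to("cuda")
vae.enable_tiling()  # encode()/_decode() now route large inputs to tiled_encode()/tiled_decode()

# 2048x2048 is well above the 512-pixel threshold, so encode() takes the tiled path.
image = torch.randn(1, 3, 2048, 2048, dtype=torch.float16, device="cuda")
with torch.no_grad():
    posterior = vae.encode(image).latent_dist
    latents = posterior.sample()
print(latents.shape)  # torch.Size([1, 4, 256, 256]) for the SD 1.x VAE

vae.disable_tiling()  # back to single-pass encoding and decoding
```
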
@@ -121,6 +153,9 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
         return AutoencoderKLOutput(latent_dist=posterior)
 
     def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+        if self.use_tiling and (z.shape[-1] > 64 or z.shape[-2] > 64):
+            return self.tiled_decode(z, return_dict=return_dict)
+
         z = self.post_quant_conv(z)
         dec = self.decoder(z)
 
@@ -129,22 +164,6 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
 
         return DecoderOutput(sample=dec)
 
-    def enable_slicing(self):
-        r"""
-        Enable sliced VAE decoding.
-
-        When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
-        steps. This is useful to save some memory and allow larger batch sizes.
-        """
-        self.use_slicing = True
-
-    def disable_slicing(self):
-        r"""
-        Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing
-        decoding in one step.
-        """
-        self.use_slicing = False
-
     @apply_forward_hook
     def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
         if self.use_slicing and z.shape[0] > 1:
@@ -158,6 +177,100 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
 
         return DecoderOutput(sample=decoded)
 
+    def blend_v(self, a, b, blend_width):
+        for y in range(blend_width):
+            b[:, :, y, :] = a[:, :, -blend_width + y, :] * (1 - y / blend_width) + b[:, :, y, :] * (y / blend_width)
+        return b
+
+    def blend_h(self, a, b, blend_width):
+        for x in range(blend_width):
+            b[:, :, :, x] = a[:, :, :, -blend_width + x] * (1 - x / blend_width) + b[:, :, :, x] * (x / blend_width)
+        return b
+
+    def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
+        r"""Encode a batch of images using a tiled encoder.
+
+        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
+        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding
+        is different from non-tiled encoding because each tile is encoded separately. To avoid tiling artifacts, the
+        tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
+        look of the output, but they should be much less noticeable.
+
+        Args:
+            x (`torch.FloatTensor`): Input batch of images.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return an [`AutoencoderKLOutput`] instead of a plain tuple.
+        """
+        # Split the image into overlapping 512x512 tiles and encode them separately.
+        rows = []
+        for i in range(0, x.shape[2], 384):
+            row = []
+            for j in range(0, x.shape[3], 384):
+                tile = x[:, :, i : i + 512, j : j + 512]
+                tile = self.encoder(tile)
+                tile = self.quant_conv(tile)
+                row.append(tile)
+            rows.append(row)
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, 16)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, 16)
+                result_row.append(tile[:, :, :48, :48])
+            result_rows.append(torch.cat(result_row, dim=3))
+
+        moments = torch.cat(result_rows, dim=2)
+        posterior = DiagonalGaussianDistribution(moments)
+
+        if not return_dict:
+            return (posterior,)
+
+        return AutoencoderKLOutput(latent_dist=posterior)
+
+    def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+        r"""Decode a batch of images using a tiled decoder.
+
+        When this option is enabled, the VAE will split the input tensor into tiles to compute decoding in several
+        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled decoding
+        is different from non-tiled decoding because each tile is decoded separately. To avoid tiling artifacts, the
+        tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
+        look of the output, but they should be much less noticeable.
+
+        Args:
+            z (`torch.FloatTensor`): Input batch of latent vectors.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+        """
+        # Split z into overlapping 64x64 tiles and decode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, z.shape[2], 48):
+            row = []
+            for j in range(0, z.shape[3], 48):
+                tile = z[:, :, i : i + 64, j : j + 64]
+                tile = self.post_quant_conv(tile)
+                decoded = self.decoder(tile)
+                row.append(decoded)
+            rows.append(row)
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, 128)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, 128)
+                result_row.append(tile[:, :, :384, :384])
+            result_rows.append(torch.cat(result_row, dim=3))
+
+        dec = torch.cat(result_rows, dim=2)
+        if not return_dict:
+            return (dec,)
+
+        return DecoderOutput(sample=dec)
+
     def forward(
         self,
         sample: torch.FloatTensor,
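The `blend_v`/`blend_h` helpers are a plain linear cross-fade over the overlap region. A toy example, not part of the diff, that replays the `blend_h` arithmetic on two constant tensors shows the ramp it produces:

```Python
# Toy illustration (not part of the diff) of the linear cross-fade used by blend_h.
# Two constant "tiles" (zeros and ones) are blended over an 8-column overlap,
# producing a smooth ramp instead of a hard seam.
import torch

def blend_h(a, b, blend_width):
    # Same arithmetic as AutoencoderKL.blend_h in this PR, written as a free function.
    for x in range(blend_width):
        b[:, :, :, x] = a[:, :, :, -blend_width + x] * (1 - x / blend_width) + b[:, :, :, x] * (x / blend_width)
    return b

left = torch.zeros(1, 1, 1, 16)   # right edge of the tile to the left
right = torch.ones(1, 1, 1, 16)   # current tile, overlapping the left one by 8 columns
blended = blend_h(left, right, 8)

print(blended[0, 0, 0, :8])  # linear ramp 0.000, 0.125, 0.250, ..., 0.875
print(blended[0, 0, 0, 8:])  # interior of the tile stays untouched (all ones)
```
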
@@ -183,6 +183,22 @@ class AltDiffusionPipeline(DiffusionPipeline):
         """
         self.vae.disable_slicing()
 
+    def enable_vae_tiling(self):
+        r"""
+        Enable tiled VAE decoding.
+
+        When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
+        several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
+        """
+        self.vae.enable_tiling()
+
+    def disable_vae_tiling(self):
+        r"""
+        Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_tiling()
+
     def enable_sequential_cpu_offload(self, gpu_id=0):
         r"""
         Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
@@ -186,6 +186,22 @@ class StableDiffusionPipeline(DiffusionPipeline):
         """
         self.vae.disable_slicing()
 
+    def enable_vae_tiling(self):
+        r"""
+        Enable tiled VAE decoding.
+
+        When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
+        several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
+        """
+        self.vae.enable_tiling()
+
+    def disable_vae_tiling(self):
+        r"""
+        Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_tiling()
+
     def enable_sequential_cpu_offload(self, gpu_id=0):
         r"""
         Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
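Because `decode` still slices the batch before `_decode` checks the tiling flag, the new pipeline switch composes with the existing memory options. A hedged usage sketch; the checkpoint, prompt, and sizes are illustrative, not from the diff:

```Python
# Sketch (not part of the diff): the new enable_vae_tiling() composes with the
# existing memory switches, since decode() slices the batch before _decode() tiles
# each slice. Checkpoint, prompt, and sizes are illustrative only.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

pipe.enable_attention_slicing()  # slice attention inside the UNet
pipe.enable_vae_slicing()        # decode latents one image at a time
pipe.enable_vae_tiling()         # decode each image tile by tile

prompt = "a panoramic photo of a mountain range"
images = pipe([prompt] * 4, width=1536, height=768).images
```
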
@@ -419,6 +419,40 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         # there is a small discrepancy at image borders vs. full batch decode
         assert np.abs(output_2.images.flatten() - output_1.images.flatten()).max() < 3e-3
 
+    def test_stable_diffusion_vae_tiling(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        unet = self.dummy_cond_unet
+        scheduler = LMSDiscreteScheduler.from_pretrained("CompVis/stable-diffusion-v1-4")
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        # make sure here that pndm scheduler skips prk
+        sd_pipe = StableDiffusionPipeline(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=None,
+            feature_extractor=self.dummy_extractor,
+        )
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        prompt = "A painting of a squirrel eating a burger"
+
+        # Test that tiled decode at 512x512 yields the same result as the non-tiled decode
+        generator = torch.Generator(device=device).manual_seed(0)
+        output_1 = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
+
+        # make sure tiled vae decode yields the same result
+        sd_pipe.enable_vae_tiling()
+        generator = torch.Generator(device=device).manual_seed(0)
+        output_2 = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
+
+        assert np.abs(output_2.images.flatten() - output_1.images.flatten()).max() < 1e-4
+
     def test_stable_diffusion_negative_prompt(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()
@@ -699,6 +733,58 @@ class StableDiffusionPipelineSlowTests(unittest.TestCase):
         # There is a small discrepancy at the image borders vs. a fully batched version.
         assert np.abs(image_sliced - image).max() < 1e-2
 
+    def test_stable_diffusion_vae_tiling(self):
+        torch.cuda.reset_peak_memory_stats()
+        model_id = "CompVis/stable-diffusion-v1-4"
+        pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16)
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        pipe.enable_attention_slicing()
+        pipe.unet = pipe.unet.to(memory_format=torch.channels_last)
+        pipe.vae = pipe.vae.to(memory_format=torch.channels_last)
+
+        prompt = "a photograph of an astronaut riding a horse"
+
+        # enable vae tiling
+        pipe.enable_vae_tiling()
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        with torch.autocast(torch_device):
+            output_chunked = pipe(
+                [prompt],
+                width=640,
+                height=640,
+                generator=generator,
+                guidance_scale=7.5,
+                num_inference_steps=2,
+                output_type="numpy",
+            )
+            image_chunked = output_chunked.images
+
+        mem_bytes = torch.cuda.max_memory_allocated()
+        torch.cuda.reset_peak_memory_stats()
+        # make sure that less than 4 GB is allocated
+        assert mem_bytes < 4e9
+
+        # disable vae tiling
+        pipe.disable_vae_tiling()
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        with torch.autocast(torch_device):
+            output = pipe(
+                [prompt],
+                width=640,
+                height=640,
+                generator=generator,
+                guidance_scale=7.5,
+                num_inference_steps=2,
+                output_type="numpy",
+            )
+            image = output.images
+
+        # make sure that more than 4 GB is allocated
+        mem_bytes = torch.cuda.max_memory_allocated()
+        assert mem_bytes > 4e9
+        assert np.abs(image_chunked.flatten() - image.flatten()).max() < 1e-2
+
     def test_stable_diffusion_fp16_vs_autocast(self):
         # this test makes sure that the original model with autocast
         # and the new model with fp16 yield the same result