Compare commits

...

19 Commits

Author | SHA1 | Message | Date
Patrick von Platen | a4bc76ce75 | make style | 2023-03-01 16:20:52 +00:00
Patrick von Platen | a03d41f344 | Apply suggestions from code review | 2023-03-01 17:16:28 +01:00
Patrick von Platen | 8e7b8c218e | Apply suggestions from code review | 2023-03-01 17:16:04 +01:00
Patrick von Platen | 6adfedfabf | Apply suggestions from code review | 2023-03-01 17:15:49 +01:00
Patrick von Platen | 94781b6e7d | make style | 2023-03-01 17:13:23 +01:00
Patrick von Platen | 928c6d324f | up | 2023-03-01 17:12:28 +01:00
Patrick von Platen | 541f27517d | up | 2023-03-01 17:07:11 +01:00
Patrick von Platen | 20387d0319 | up | 2023-03-01 17:05:09 +01:00
Patrick von Platen | 2a403a173e | fix conflicts | 2023-03-01 16:56:08 +01:00
Ilmari Heikkinen | 307fd12567 | tiled vae tests, remove tiling test from fast tests | 2022-12-04 03:24:17 +08:00
Ilmari Heikkinen | 2b0454d8ad | tiled vae tests, use smaller test image | 2022-12-04 03:05:53 +08:00
Ilmari Heikkinen | 0a96a8184f | tiled vae tests, use channels_last memory format | 2022-12-04 02:59:39 +08:00
Ilmari Heikkinen | c99dbb6ff8 | Merge branch 'sd-vae-tiled' of https://github.com/kig/diffusers into sd-vae-tiled | 2022-12-04 02:34:03 +08:00
Ilmari Heikkinen | ac8b1c2eeb | tiled vae docs, disable tiling for images that would have only one tile | 2022-12-04 02:32:16 +08:00
Ilmari Heikkinen | 4b6536d49a | enable_vae_tiling API and tests | 2022-12-04 02:19:11 +08:00
Ilmari Heikkinen | 14215bf3fd | Merge branch 'main' into sd-vae-tiled | 2022-12-04 01:40:32 +08:00
Ilmari Heikkinen | 626fb88a9d | Merge branch 'main' into sd-vae-tiled | 2022-11-30 16:26:38 +08:00
Ilmari Heikkinen | 63d5661b47 | vae tiling, fix formatting | 2022-11-28 10:56:16 +08:00
Ilmari Heikkinen | 49b61c8e5b | Tiled VAE for high-res text2img and img2img | 2022-11-27 19:02:24 +08:00
6 changed files with 277 additions and 17 deletions

View File

@@ -36,4 +36,6 @@ Available Checkpoints are:
- enable_vae_slicing
- disable_vae_slicing
- enable_xformers_memory_efficient_attention
- disable_xformers_memory_efficient_attention
- enable_vae_tiling
- disable_vae_tiling

View File

@@ -133,6 +133,33 @@ images = pipe([prompt] * 32).images
You may see a small performance boost in VAE decode on multi-image batches. There should be no performance impact on single-image batches.
## Tiled VAE decode and encode for large images
Tiled VAE processing makes it possible to work with large images on limited VRAM, for example generating 4K images with 8GB of VRAM. The tiled VAE decoder splits the image into overlapping tiles, decodes each tile, and blends the outputs together into the final image.
You may want to pair this with [`~StableDiffusionPipeline.enable_attention_slicing`] or [`~StableDiffusionPipeline.enable_xformers_memory_efficient_attention`] to further minimize memory use.
To use tiled VAE processing, invoke [`~StableDiffusionPipeline.enable_vae_tiling`] in your pipeline before inference. For example:
```Python
import torch
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")
prompt = "a beautiful landscape photograph"
pipe.enable_vae_tiling()
pipe.enable_xformers_memory_efficient_attention()
images = pipe([prompt], width=3840, height=2224).images
```
The output image will show some tile-to-tile tone variation because the tiles are decoded separately, but you shouldn't see sharp seams between them. Tiling is automatically disabled for images that are 512x512 or smaller.
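Tiling can be switched off again at any time. A short sketch, reusing the `pipe` from the example above:
```Python
# Revert to single-pass VAE decoding
pipe.disable_vae_tiling()
images = pipe([prompt], width=512, height=512).images
```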
<a name="sequential_offloading"></a>
## Offloading to CPU with accelerate for memory savings

View File

@@ -109,8 +109,40 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
self.use_slicing = False
def enable_tiling(self, use_tiling: bool = True):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful to save a large amount of memory and to allow
the processing of larger images.
"""
self.use_tiling = use_tiling
def disable_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously invoked, this method will go back to
computing decoding in one step.
"""
self.enable_tiling(False)
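Since these toggles live on the model itself, tiling can also be used with a standalone `AutoencoderKL`. A minimal sketch (the checkpoint path and input size are assumptions); note that `encode` only takes the tiled path for inputs larger than 512 pixels per side:
```Python
import torch
from diffusers import AutoencoderKL

# Sketch: tiling on a standalone VAE; checkpoint and sizes are assumptions.
vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
vae.enable_tiling()
with torch.no_grad():
    # 1024 > 512 per side, so encode() dispatches to tiled_encode()
    posterior = vae.encode(torch.randn(1, 3, 1024, 1024)).latent_dist
vae.disable_tiling()
```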
def enable_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
@apply_forward_hook
def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
if self.use_tiling and (x.shape[-1] > 512 or x.shape[-2] > 512):
return self.tiled_encode(x, return_dict=return_dict)
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
@@ -121,6 +153,9 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
if self.use_tiling and (z.shape[-1] > 64 or z.shape[-2] > 64):
return self.tiled_decode(z, return_dict=return_dict)
z = self.post_quant_conv(z)
dec = self.decoder(z)
@@ -129,22 +164,6 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
return DecoderOutput(sample=dec)
def enable_slicing(self):
r"""
Enable sliced VAE decoding.
When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
@apply_forward_hook
def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
if self.use_slicing and z.shape[0] > 1:
@@ -158,6 +177,100 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
return DecoderOutput(sample=decoded)
def blend_v(self, a, b, blend_width):
    # Linearly cross-fade the top `blend_width` rows of `b` with the bottom rows of `a`.
    for y in range(blend_width):
        b[:, :, y, :] = a[:, :, -blend_width + y, :] * (1 - y / blend_width) + b[:, :, y, :] * (y / blend_width)
    return b
def blend_h(self, a, b, blend_width):
    # Linearly cross-fade the left `blend_width` columns of `b` with the right columns of `a`.
    for x in range(blend_width):
        b[:, :, :, x] = a[:, :, :, -blend_width + x] * (1 - x / blend_width) + b[:, :, :, x] * (x / blend_width)
    return b
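The two helpers above cross-fade a strip of width `blend_width` between neighboring tiles with linear weights. A standalone toy check of the vertical blend (the tensor sizes are arbitrary):
```Python
import torch

# Toy check of blend_v: the first blend_width rows of `b` fade from a's
# bottom rows (weight 1 -> 0) into b's own rows (weight 0 -> 1).
a = torch.ones(1, 1, 8, 4)   # "upper" tile, all ones
b = torch.zeros(1, 1, 8, 4)  # "lower" tile, all zeros
blend_width = 4
for y in range(blend_width):
    b[:, :, y, :] = a[:, :, -blend_width + y, :] * (1 - y / blend_width) + b[:, :, y, :] * (y / blend_width)
print(b[0, 0, :, 0])  # tensor([1.0000, 0.7500, 0.5000, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000])
```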
def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
    r"""Encode a batch of images using a tiled encoder.
    When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
    steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding
    is different from non-tiled encoding because each tile is encoded separately. To avoid tiling artifacts, the
    tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
    look of the output, but they should be much less noticeable.
    Args:
        x (`torch.FloatTensor`): Input batch of images.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`AutoencoderKLOutput`] instead of a plain tuple.
    """
    # Split the image into overlapping 512x512 tiles (384 px stride, 128 px overlap) and encode them separately.
    rows = []
    for i in range(0, x.shape[2], 384):
        row = []
        for j in range(0, x.shape[3], 384):
            tile = x[:, :, i : i + 512, j : j + 512]
            tile = self.encoder(tile)
            tile = self.quant_conv(tile)
            row.append(tile)
        rows.append(row)
    result_rows = []
    for i, row in enumerate(rows):
        result_row = []
        for j, tile in enumerate(row):
            # Blend the tile above and the tile to the left into the current tile,
            # then add the current tile to the result row.
            if i > 0:
                tile = self.blend_v(rows[i - 1][j], tile, 16)
            if j > 0:
                tile = self.blend_h(row[j - 1], tile, 16)
            result_row.append(tile[:, :, :48, :48])
        result_rows.append(torch.cat(result_row, dim=3))
    moments = torch.cat(result_rows, dim=2)
    posterior = DiagonalGaussianDistribution(moments)
    if not return_dict:
        return (posterior,)
    return AutoencoderKLOutput(latent_dist=posterior)
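The constants fit together: 512-pixel tiles with a 384-pixel stride overlap by 128 pixels; after the encoder's 8x downsampling each tile is 64x64 latents, the 16-latent blend width matches the 128-pixel overlap, and the 48-latent crop matches the 384-pixel stride, so the assembled latent has the same size as a non-tiled encode. A small sketch of the grid math (the helper is hypothetical; the constants come from the code above):
```Python
import math

# Hypothetical helper mirroring tiled_encode's constants:
# 512 px tiles, 384 px stride, 8x downsampling, 48-latent crop.
def encode_tile_grid(height_px, width_px):
    stride = 384
    rows = math.ceil(height_px / stride)  # tiles start at 0, 384, 768, ...
    cols = math.ceil(width_px / stride)
    latent_shape = (height_px // 8, width_px // 8)  # same as non-tiled encode
    return rows, cols, latent_shape

print(encode_tile_grid(2224, 3840))  # (6, 10, (278, 480)) for the 3840x2224 docs example
```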
def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
    r"""Decode a batch of images using a tiled decoder.
    When this option is enabled, the VAE will split the input tensor into tiles to compute decoding in several
    steps. This is useful to keep memory use constant regardless of image size. The end result of tiled decoding
    is different from non-tiled decoding because each tile is decoded separately. To avoid tiling artifacts, the
    tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
    look of the output, but they should be much less noticeable.
    Args:
        z (`torch.FloatTensor`): Input batch of latent vectors.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
    """
    # Split z into overlapping 64x64 tiles (48-latent stride) and decode them separately.
    # The tiles overlap to avoid seams between tiles.
    rows = []
    for i in range(0, z.shape[2], 48):
        row = []
        for j in range(0, z.shape[3], 48):
            tile = z[:, :, i : i + 64, j : j + 64]
            tile = self.post_quant_conv(tile)
            decoded = self.decoder(tile)
            row.append(decoded)
        rows.append(row)
    result_rows = []
    for i, row in enumerate(rows):
        result_row = []
        for j, tile in enumerate(row):
            # Blend the tile above and the tile to the left into the current tile,
            # then add the current tile to the result row.
            if i > 0:
                tile = self.blend_v(rows[i - 1][j], tile, 128)
            if j > 0:
                tile = self.blend_h(row[j - 1], tile, 128)
            result_row.append(tile[:, :, :384, :384])
        result_rows.append(torch.cat(result_row, dim=3))
    dec = torch.cat(result_rows, dim=2)
    if not return_dict:
        return (dec,)
    return DecoderOutput(sample=dec)
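Decoding mirrors the encode geometry: 64-latent tiles with a 48-latent stride overlap by 16 latents, which becomes the 128-pixel blend width after 8x upsampling, and the 384-pixel crop matches the stride. A quick way to see the blending in action is to decode the same latent with and without tiling (a sketch; the checkpoint path and sizes are assumptions, and the fast test further down this page makes the same comparison through the full pipeline):
```Python
import torch
from diffusers import AutoencoderKL

# Sketch: compare tiled vs. non-tiled decode; checkpoint is an assumption.
vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
latents = torch.randn(1, 4, 96, 96)  # 96 > 64 latents per side, so the tiled path is taken
with torch.no_grad():
    vae.enable_tiling()
    tiled = vae.decode(latents).sample
    vae.disable_tiling()
    full = vae.decode(latents).sample
print((tiled - full).abs().max())  # differences concentrate near the blended seams
```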
def forward(
self,
sample: torch.FloatTensor,

View File

@@ -183,6 +183,22 @@ class AltDiffusionPipeline(DiffusionPipeline):
"""
self.vae.disable_slicing()
def enable_vae_tiling(self):
r"""
Enable tiled VAE decoding.
When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
"""
self.vae.enable_tiling()
def disable_vae_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
computing decoding in one step.
"""
self.vae.disable_tiling()
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,

View File

@@ -186,6 +186,22 @@ class StableDiffusionPipeline(DiffusionPipeline):
"""
self.vae.disable_slicing()
def enable_vae_tiling(self):
r"""
Enable tiled VAE decoding.
When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
"""
self.vae.enable_tiling()
def disable_vae_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
computing decoding in one step.
"""
self.vae.disable_tiling()
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,

View File

@@ -419,6 +419,40 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
# there is a small discrepancy at image borders vs. full batch decode
assert np.abs(output_2.images.flatten() - output_1.images.flatten()).max() < 3e-3
def test_stable_diffusion_vae_tiling(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
unet = self.dummy_cond_unet
scheduler = LMSDiscreteScheduler.from_pretrained("CompVis/stable-diffusion-v1-4")
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
# make sure here that pndm scheduler skips prk
sd_pipe = StableDiffusionPipeline(
unet=unet,
scheduler=scheduler,
vae=vae,
text_encoder=bert,
tokenizer=tokenizer,
safety_checker=None,
feature_extractor=self.dummy_extractor,
)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
prompt = "A painting of a squirrel eating a burger"
# Test that tiled decode at 512x512 yields the same result as the non-tiled decode
generator = torch.Generator(device=device).manual_seed(0)
output_1 = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
# make sure tiled vae decode yields the same result
sd_pipe.enable_vae_tiling()
generator = torch.Generator(device=device).manual_seed(0)
output_2 = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
assert np.abs(output_2.images.flatten() - output_1.images.flatten()).max() < 1e-4
def test_stable_diffusion_negative_prompt(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
@@ -699,6 +733,58 @@ class StableDiffusionPipelineSlowTests(unittest.TestCase):
# There is a small discrepancy at the image borders vs. a fully batched version.
assert np.abs(image_sliced - image).max() < 1e-2
def test_stable_diffusion_vae_tiling(self):
torch.cuda.reset_peak_memory_stats()
model_id = "CompVis/stable-diffusion-v1-4"
pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
pipe.unet = pipe.unet.to(memory_format=torch.channels_last)
pipe.vae = pipe.vae.to(memory_format=torch.channels_last)
prompt = "a photograph of an astronaut riding a horse"
# enable vae tiling
pipe.enable_vae_tiling()
generator = torch.Generator(device=torch_device).manual_seed(0)
with torch.autocast(torch_device):
output_chunked = pipe(
[prompt],
width=640,
height=640,
generator=generator,
guidance_scale=7.5,
num_inference_steps=2,
output_type="numpy",
)
image_chunked = output_chunked.images
mem_bytes = torch.cuda.max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
# make sure that less than 4 GB is allocated
assert mem_bytes < 4e9
# disable vae tiling
pipe.disable_vae_tiling()
generator = torch.Generator(device=torch_device).manual_seed(0)
with torch.autocast(torch_device):
output = pipe(
[prompt],
width=640,
height=640,
generator=generator,
guidance_scale=7.5,
num_inference_steps=2,
output_type="numpy",
)
image = output.images
# make sure that more than 4 GB is allocated
mem_bytes = torch.cuda.max_memory_allocated()
assert mem_bytes > 4e9
assert np.abs(image_chunked.flatten() - image.flatten()).max() < 1e-2
def test_stable_diffusion_fp16_vs_autocast(self):
# this test makes sure that the original model with autocast
# and the new model with fp16 yield the same result