Compare commits

...

4 Commits

Author SHA1 Message Date
Suraj Patil
b2c9b5469a [img2img, inpainting] fix fp16 inference (#769)
* handle dtype in vae and image2image pipeline

* fix inpaint in fp16

* dtype should be handled in add_noise

* style

* address review comments

* add simple fast tests to check fp16

* fix test name

* put mask in fp16
2022-10-12 00:27:51 +02:00
Patrick von Platen
7a6cf8912c Release: v0.4.2 2022-10-12 00:13:11 +02:00
Pedro Cuenca
27455268fe mps: Alternative implementation for repeat_interleave (#766)
* mps: alt. implementation for repeat_interleave

* style

* Bump mps version of PyTorch in the documentation.

* Apply suggestions from code review

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Simplify: do not check for device.

* style

* Fix repeat dimensions:

- The unconditional embeddings are always created from a single prompt.
- I was shadowing the batch_size var.

* Split long lines as suggested by Suraj.

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
2022-10-12 00:07:44 +02:00
Suraj Patil
2bdde4dd83 [schedulers] hanlde dtype in add_noise (#767)
* handle dtype in vae and image2image pipeline

* handle dtype in add noise

* don't modify vae and pipeline

* remove the if
2022-10-12 00:07:17 +02:00
12 changed files with 213 additions and 85 deletions

View File

@@ -19,7 +19,7 @@ specific language governing permissions and limitations under the License.
- Mac computer with Apple silicon (M1/M2) hardware.
- macOS 12.3 or later.
- arm64 version of Python.
- PyTorch [Preview (Nightly)](https://pytorch.org/get-started/locally/), version `1.13.0.dev20220830` or later.
- PyTorch [Preview (Nightly)](https://pytorch.org/get-started/locally/), version `1.14.0.dev20221007` or later.
## Inference Pipeline

View File

@@ -211,7 +211,7 @@ install_requires = [
setup(
name="diffusers",
version="0.4.1", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="0.4.2", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
description="Diffusers",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",

View File

@@ -9,7 +9,7 @@ from .utils import (
)
__version__ = "0.4.1"
__version__ = "0.4.2"
from .configuration_utils import ConfigMixin
from .onnx_utils import OnnxRuntimeModel

View File

@@ -337,12 +337,16 @@ class DiagonalGaussianDistribution(object):
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
self.var = self.std = torch.zeros_like(
self.mean, device=self.parameters.device, dtype=self.parameters.dtype
)
def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
device = self.parameters.device
sample_device = "cpu" if device.type == "mps" else device
sample = torch.randn(self.mean.shape, generator=generator, device=sample_device).to(device)
sample = torch.randn(self.mean.shape, generator=generator, device=sample_device)
# make sure sample is on the same device as the parameters and has same dtype
sample = sample.to(device=device, dtype=self.parameters.dtype)
x = self.mean + self.std * sample
return x

View File

@@ -218,8 +218,10 @@ class StableDiffusionPipeline(DiffusionPipeline):
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
# duplicate text embeddings for each generation per prompt
text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -256,8 +258,10 @@ class StableDiffusionPipeline(DiffusionPipeline):
)
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
# duplicate unconditional embeddings for each generation per prompt
uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size * num_images_per_prompt, dim=0)
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch

View File

@@ -217,26 +217,6 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
if isinstance(init_image, PIL.Image.Image):
init_image = preprocess(init_image)
# encode the init image into latents and scale the latents
init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist
init_latents = init_latent_dist.sample(generator=generator)
init_latents = 0.18215 * init_latents
# expand init_latents for batch_size
init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
# get the original timestep using init_timestep
offset = self.scheduler.config.get("steps_offset", 0)
init_timestep = int(num_inference_steps * strength) + offset
init_timestep = min(init_timestep, num_inference_steps)
timesteps = self.scheduler.timesteps[-init_timestep]
timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
# add noise to latents using the timesteps
noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
# get prompt text embeddings
text_inputs = self.tokenizer(
prompt,
@@ -297,6 +277,28 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
# encode the init image into latents and scale the latents
latents_dtype = text_embeddings.dtype
init_image = init_image.to(device=self.device, dtype=latents_dtype)
init_latent_dist = self.vae.encode(init_image).latent_dist
init_latents = init_latent_dist.sample(generator=generator)
init_latents = 0.18215 * init_latents
# expand init_latents for batch_size
init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
# get the original timestep using init_timestep
offset = self.scheduler.config.get("steps_offset", 0)
init_timestep = int(num_inference_steps * strength) + offset
init_timestep = min(init_timestep, num_inference_steps)
timesteps = self.scheduler.timesteps[-init_timestep]
timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
# add noise to latents using the timesteps
noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
@@ -341,7 +343,9 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
image = image.cpu().permute(0, 2, 3, 1).numpy()
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
)
if output_type == "pil":
image = self.numpy_to_pil(image)

View File

@@ -234,43 +234,6 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
# set timesteps
self.scheduler.set_timesteps(num_inference_steps)
# preprocess image
if not isinstance(init_image, torch.FloatTensor):
init_image = preprocess_image(init_image)
init_image = init_image.to(self.device)
# encode the init image into latents and scale the latents
init_latent_dist = self.vae.encode(init_image).latent_dist
init_latents = init_latent_dist.sample(generator=generator)
init_latents = 0.18215 * init_latents
# Expand init_latents for batch_size and num_images_per_prompt
init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
init_latents_orig = init_latents
# preprocess mask
if not isinstance(mask_image, torch.FloatTensor):
mask_image = preprocess_mask(mask_image)
mask_image = mask_image.to(self.device)
mask = torch.cat([mask_image] * batch_size * num_images_per_prompt)
# check sizes
if not mask.shape == init_latents.shape:
raise ValueError("The mask and init_image should be the same size!")
# get the original timestep using init_timestep
offset = self.scheduler.config.get("steps_offset", 0)
init_timestep = int(num_inference_steps * strength) + offset
init_timestep = min(init_timestep, num_inference_steps)
timesteps = self.scheduler.timesteps[-init_timestep]
timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
# add noise to latents using the timesteps
noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
# get prompt text embeddings
text_inputs = self.tokenizer(
prompt,
@@ -335,6 +298,43 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
# preprocess image
if not isinstance(init_image, torch.FloatTensor):
init_image = preprocess_image(init_image)
# encode the init image into latents and scale the latents
latents_dtype = text_embeddings.dtype
init_image = init_image.to(device=self.device, dtype=latents_dtype)
init_latent_dist = self.vae.encode(init_image).latent_dist
init_latents = init_latent_dist.sample(generator=generator)
init_latents = 0.18215 * init_latents
# Expand init_latents for batch_size and num_images_per_prompt
init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
init_latents_orig = init_latents
# preprocess mask
if not isinstance(mask_image, torch.FloatTensor):
mask_image = preprocess_mask(mask_image)
mask_image = mask_image.to(device=self.device, dtype=latents_dtype)
mask = torch.cat([mask_image] * batch_size * num_images_per_prompt)
# check sizes
if not mask.shape == init_latents.shape:
raise ValueError("The mask and init_image should be the same size!")
# get the original timestep using init_timestep
offset = self.scheduler.config.get("steps_offset", 0)
init_timestep = int(num_inference_steps * strength) + offset
init_timestep = min(init_timestep, num_inference_steps)
timesteps = self.scheduler.timesteps[-init_timestep]
timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
# add noise to latents using the timesteps
noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502

View File

@@ -300,11 +300,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
noise: torch.FloatTensor,
timesteps: torch.IntTensor,
) -> torch.FloatTensor:
if self.alphas_cumprod.device != original_samples.device:
self.alphas_cumprod = self.alphas_cumprod.to(original_samples.device)
if timesteps.device != original_samples.device:
timesteps = timesteps.to(original_samples.device)
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
timesteps = timesteps.to(original_samples.device)
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()

View File

@@ -294,11 +294,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
noise: torch.FloatTensor,
timesteps: torch.IntTensor,
) -> torch.FloatTensor:
if self.alphas_cumprod.device != original_samples.device:
self.alphas_cumprod = self.alphas_cumprod.to(original_samples.device)
if timesteps.device != original_samples.device:
timesteps = timesteps.to(original_samples.device)
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
timesteps = timesteps.to(original_samples.device)
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()

View File

@@ -257,9 +257,13 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
noise: torch.FloatTensor,
timesteps: torch.FloatTensor,
) -> torch.FloatTensor:
sigmas = self.sigmas.to(original_samples.device)
schedule_timesteps = self.timesteps.to(original_samples.device)
# Make sure sigmas and timesteps have the same device and dtype as original_samples
self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
self.timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
schedule_timesteps = self.timesteps
if isinstance(timesteps, torch.IntTensor) or isinstance(timesteps, torch.LongTensor):
deprecate(
"timesteps as indices",
@@ -273,7 +277,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
else:
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = sigmas[step_indices].flatten()
sigma = self.sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):
sigma = sigma.unsqueeze(-1)

View File

@@ -400,11 +400,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
noise: torch.FloatTensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
if self.alphas_cumprod.device != original_samples.device:
self.alphas_cumprod = self.alphas_cumprod.to(original_samples.device)
if timesteps.device != original_samples.device:
timesteps = timesteps.to(original_samples.device)
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
timesteps = timesteps.to(original_samples.device)
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()

View File

@@ -1005,6 +1005,124 @@ class PipelineFastTests(unittest.TestCase):
assert images.shape == (batch_size * num_images_per_prompt, 32, 32, 3)
@unittest.skipIf(torch_device == "cpu", "This test requires a GPU")
def test_stable_diffusion_fp16(self):
"""Test that stable diffusion works with fp16"""
unet = self.dummy_cond_unet
scheduler = PNDMScheduler(skip_prk_steps=True)
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
# put models in fp16
unet = unet.half()
vae = vae.half()
bert = bert.half()
# make sure here that pndm scheduler skips prk
sd_pipe = StableDiffusionPipeline(
unet=unet,
scheduler=scheduler,
vae=vae,
text_encoder=bert,
tokenizer=tokenizer,
safety_checker=self.dummy_safety_checker,
feature_extractor=self.dummy_extractor,
)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
prompt = "A painting of a squirrel eating a burger"
generator = torch.Generator(device=torch_device).manual_seed(0)
image = sd_pipe([prompt], generator=generator, num_inference_steps=2, output_type="np").images
assert image.shape == (1, 128, 128, 3)
@unittest.skipIf(torch_device == "cpu", "This test requires a GPU")
def test_stable_diffusion_img2img_fp16(self):
"""Test that stable diffusion img2img works with fp16"""
unet = self.dummy_cond_unet
scheduler = PNDMScheduler(skip_prk_steps=True)
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
init_image = self.dummy_image.to(torch_device)
# put models in fp16
unet = unet.half()
vae = vae.half()
bert = bert.half()
# make sure here that pndm scheduler skips prk
sd_pipe = StableDiffusionImg2ImgPipeline(
unet=unet,
scheduler=scheduler,
vae=vae,
text_encoder=bert,
tokenizer=tokenizer,
safety_checker=self.dummy_safety_checker,
feature_extractor=self.dummy_extractor,
)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
prompt = "A painting of a squirrel eating a burger"
generator = torch.Generator(device=torch_device).manual_seed(0)
image = sd_pipe(
[prompt],
generator=generator,
num_inference_steps=2,
output_type="np",
init_image=init_image,
).images
assert image.shape == (1, 32, 32, 3)
@unittest.skipIf(torch_device == "cpu", "This test requires a GPU")
def test_stable_diffusion_inpaint_fp16(self):
"""Test that stable diffusion inpaint works with fp16"""
unet = self.dummy_cond_unet
scheduler = PNDMScheduler(skip_prk_steps=True)
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
init_image = Image.fromarray(np.uint8(image)).convert("RGB")
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))
# put models in fp16
unet = unet.half()
vae = vae.half()
bert = bert.half()
# make sure here that pndm scheduler skips prk
sd_pipe = StableDiffusionInpaintPipeline(
unet=unet,
scheduler=scheduler,
vae=vae,
text_encoder=bert,
tokenizer=tokenizer,
safety_checker=self.dummy_safety_checker,
feature_extractor=self.dummy_extractor,
)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
prompt = "A painting of a squirrel eating a burger"
generator = torch.Generator(device=torch_device).manual_seed(0)
image = sd_pipe(
[prompt],
generator=generator,
num_inference_steps=2,
output_type="np",
init_image=init_image,
mask_image=mask_image,
).images
assert image.shape == (1, 32, 32, 3)
class PipelineTesterMixin(unittest.TestCase):
def tearDown(self):