Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-11 23:14:37 +08:00)

Compare commits: torch-regr...v0.4.2 (4 commits)

| Author | SHA1 | Date |
|---|---|---|
| | b2c9b5469a | |
| | 7a6cf8912c | |
| | 27455268fe | |
| | 2bdde4dd83 | |
```diff
@@ -19,7 +19,7 @@ specific language governing permissions and limitations under the License.
 - Mac computer with Apple silicon (M1/M2) hardware.
 - macOS 12.3 or later.
 - arm64 version of Python.
-- PyTorch [Preview (Nightly)](https://pytorch.org/get-started/locally/), version `1.13.0.dev20220830` or later.
+- PyTorch [Preview (Nightly)](https://pytorch.org/get-started/locally/), version `1.14.0.dev20221007` or later.
 
 ## Inference Pipeline
 
```
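A quick way to check the updated requirement on an Apple silicon machine (a sketch, not part of the diff; it assumes PyTorch is already installed):

```python
# Sketch: verify the installed PyTorch build and MPS support on Apple silicon.
# The version string in the docs above is only a minimum.
import torch

print(torch.__version__)                  # should be >= 1.14.0.dev20221007 per the updated docs
print(torch.backends.mps.is_built())      # True if this build was compiled with MPS support
print(torch.backends.mps.is_available())  # True on macOS 12.3+ with Apple silicon
```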
setup.py

```diff
@@ -211,7 +211,7 @@ install_requires = [
 
 setup(
     name="diffusers",
-    version="0.4.1",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
+    version="0.4.2",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
     description="Diffusers",
     long_description=open("README.md", "r", encoding="utf-8").read(),
     long_description_content_type="text/markdown",
```
```diff
@@ -9,7 +9,7 @@ from .utils import (
 )
 
 
-__version__ = "0.4.1"
+__version__ = "0.4.2"
 
 from .configuration_utils import ConfigMixin
 from .onnx_utils import OnnxRuntimeModel
```
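After installing this release, the version bump can be confirmed directly (a minimal sketch, assuming diffusers 0.4.2 is installed):

```python
# Sketch: confirm the installed package matches the bumped version.
import diffusers

print(diffusers.__version__)  # expected: 0.4.2
```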
```diff
@@ -337,12 +337,16 @@ class DiagonalGaussianDistribution(object):
         self.std = torch.exp(0.5 * self.logvar)
         self.var = torch.exp(self.logvar)
         if self.deterministic:
-            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
+            self.var = self.std = torch.zeros_like(
+                self.mean, device=self.parameters.device, dtype=self.parameters.dtype
+            )
 
     def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
         device = self.parameters.device
         sample_device = "cpu" if device.type == "mps" else device
-        sample = torch.randn(self.mean.shape, generator=generator, device=sample_device).to(device)
+        sample = torch.randn(self.mean.shape, generator=generator, device=sample_device)
+        # make sure sample is on the same device as the parameters and has same dtype
+        sample = sample.to(device=device, dtype=self.parameters.dtype)
         x = self.mean + self.std * sample
         return x
 
```
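The change above draws the random sample on the CPU when the target device is `mps`, then moves and casts it to the parameters' device and dtype, so half-precision latents stay in half precision. A standalone sketch of the same pattern (function and variable names here are illustrative, not library API):

```python
# Sketch of the sampling pattern above: draw noise on the CPU when the target
# device is "mps", then move and cast it to the parameters' device and dtype.
from typing import Optional

import torch


def sample_like(mean: torch.Tensor, std: torch.Tensor, generator: Optional[torch.Generator] = None) -> torch.Tensor:
    device = mean.device
    sample_device = "cpu" if device.type == "mps" else device
    noise = torch.randn(mean.shape, generator=generator, device=sample_device)
    noise = noise.to(device=device, dtype=mean.dtype)  # align device and dtype (e.g. float16)
    return mean + std * noise


print(sample_like(torch.zeros(2, 4), torch.ones(2, 4), torch.Generator().manual_seed(0)).shape)
```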
```diff
@@ -218,8 +218,10 @@ class StableDiffusionPipeline(DiffusionPipeline):
             text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
         text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
 
-        # duplicate text embeddings for each generation per prompt
-        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        bs_embed, seq_len, _ = text_embeddings.shape
+        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
+        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
 
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -256,8 +258,10 @@ class StableDiffusionPipeline(DiffusionPipeline):
             )
             uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
 
-            # duplicate unconditional embeddings for each generation per prompt
-            uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size * num_images_per_prompt, dim=0)
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = uncond_embeddings.shape[1]
+            uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)
+            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
 
             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
```
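Per the updated comments, `repeat_interleave` is swapped for a `repeat` + `view` combination that produces the same duplication of the embeddings while using mps-friendlier ops. A quick CPU check of the equivalence (a sketch with made-up shapes):

```python
# Sketch: the mps-friendly repeat + view duplication is equivalent to
# repeat_interleave along the batch dimension. Shapes are illustrative.
import torch

bs, seq_len, dim, num_images_per_prompt = 2, 5, 8, 3
emb = torch.randn(bs, seq_len, dim)

a = emb.repeat_interleave(num_images_per_prompt, dim=0)

b = emb.repeat(1, num_images_per_prompt, 1)
b = b.view(bs * num_images_per_prompt, seq_len, -1)

assert torch.equal(a, b)  # identical result, without repeat_interleave
```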
```diff
@@ -217,26 +217,6 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
         if isinstance(init_image, PIL.Image.Image):
             init_image = preprocess(init_image)
 
-        # encode the init image into latents and scale the latents
-        init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist
-        init_latents = init_latent_dist.sample(generator=generator)
-        init_latents = 0.18215 * init_latents
-
-        # expand init_latents for batch_size
-        init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
-
-        # get the original timestep using init_timestep
-        offset = self.scheduler.config.get("steps_offset", 0)
-        init_timestep = int(num_inference_steps * strength) + offset
-        init_timestep = min(init_timestep, num_inference_steps)
-
-        timesteps = self.scheduler.timesteps[-init_timestep]
-        timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
-
-        # add noise to latents using the timesteps
-        noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
-        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
-
         # get prompt text embeddings
         text_inputs = self.tokenizer(
             prompt,
@@ -297,6 +277,28 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
             # to avoid doing two forward passes
             text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
 
+        # encode the init image into latents and scale the latents
+        latents_dtype = text_embeddings.dtype
+        init_image = init_image.to(device=self.device, dtype=latents_dtype)
+        init_latent_dist = self.vae.encode(init_image).latent_dist
+        init_latents = init_latent_dist.sample(generator=generator)
+        init_latents = 0.18215 * init_latents
+
+        # expand init_latents for batch_size
+        init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
+
+        # get the original timestep using init_timestep
+        offset = self.scheduler.config.get("steps_offset", 0)
+        init_timestep = int(num_inference_steps * strength) + offset
+        init_timestep = min(init_timestep, num_inference_steps)
+
+        timesteps = self.scheduler.timesteps[-init_timestep]
+        timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
+
+        # add noise to latents using the timesteps
+        noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
+        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
+
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
         # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
@@ -341,7 +343,9 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
         image = image.cpu().permute(0, 2, 3, 1).numpy()
 
         safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
-        image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
+        image, has_nsfw_concept = self.safety_checker(
+            images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
+        )
 
         if output_type == "pil":
             image = self.numpy_to_pil(image)
```
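These img2img hunks move the image-to-latents encoding after the text encoding so the working dtype can be read from `text_embeddings`, and the init image, the added noise, and the safety-checker input are all cast to that dtype, keeping an fp16 pipeline entirely in fp16. A standalone sketch of the dtype-alignment idea (the tensors below are stand-ins for the pipeline's real inputs):

```python
# Sketch: derive the working dtype from the text embeddings and cast image and
# noise to it, so half-precision pipelines never mix float32 tensors.
import torch

text_embeddings = torch.randn(2, 77, 768, dtype=torch.float16)
init_image = torch.rand(1, 3, 512, 512)              # preprocessed image, float32

latents_dtype = text_embeddings.dtype
init_image = init_image.to(dtype=latents_dtype)       # cast before the VAE encode
noise = torch.randn(1, 4, 64, 64, dtype=latents_dtype)

assert init_image.dtype == noise.dtype == torch.float16
```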
```diff
@@ -234,43 +234,6 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
         # set timesteps
         self.scheduler.set_timesteps(num_inference_steps)
 
-        # preprocess image
-        if not isinstance(init_image, torch.FloatTensor):
-            init_image = preprocess_image(init_image)
-        init_image = init_image.to(self.device)
-
-        # encode the init image into latents and scale the latents
-        init_latent_dist = self.vae.encode(init_image).latent_dist
-        init_latents = init_latent_dist.sample(generator=generator)
-
-        init_latents = 0.18215 * init_latents
-
-        # Expand init_latents for batch_size and num_images_per_prompt
-        init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
-        init_latents_orig = init_latents
-
-        # preprocess mask
-        if not isinstance(mask_image, torch.FloatTensor):
-            mask_image = preprocess_mask(mask_image)
-        mask_image = mask_image.to(self.device)
-        mask = torch.cat([mask_image] * batch_size * num_images_per_prompt)
-
-        # check sizes
-        if not mask.shape == init_latents.shape:
-            raise ValueError("The mask and init_image should be the same size!")
-
-        # get the original timestep using init_timestep
-        offset = self.scheduler.config.get("steps_offset", 0)
-        init_timestep = int(num_inference_steps * strength) + offset
-        init_timestep = min(init_timestep, num_inference_steps)
-
-        timesteps = self.scheduler.timesteps[-init_timestep]
-        timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
-
-        # add noise to latents using the timesteps
-        noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
-        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
-
         # get prompt text embeddings
         text_inputs = self.tokenizer(
             prompt,
@@ -335,6 +298,43 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
             # to avoid doing two forward passes
             text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
 
+        # preprocess image
+        if not isinstance(init_image, torch.FloatTensor):
+            init_image = preprocess_image(init_image)
+
+        # encode the init image into latents and scale the latents
+        latents_dtype = text_embeddings.dtype
+        init_image = init_image.to(device=self.device, dtype=latents_dtype)
+        init_latent_dist = self.vae.encode(init_image).latent_dist
+        init_latents = init_latent_dist.sample(generator=generator)
+        init_latents = 0.18215 * init_latents
+
+        # Expand init_latents for batch_size and num_images_per_prompt
+        init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
+        init_latents_orig = init_latents
+
+        # preprocess mask
+        if not isinstance(mask_image, torch.FloatTensor):
+            mask_image = preprocess_mask(mask_image)
+        mask_image = mask_image.to(device=self.device, dtype=latents_dtype)
+        mask = torch.cat([mask_image] * batch_size * num_images_per_prompt)
+
+        # check sizes
+        if not mask.shape == init_latents.shape:
+            raise ValueError("The mask and init_image should be the same size!")
+
+        # get the original timestep using init_timestep
+        offset = self.scheduler.config.get("steps_offset", 0)
+        init_timestep = int(num_inference_steps * strength) + offset
+        init_timestep = min(init_timestep, num_inference_steps)
+
+        timesteps = self.scheduler.timesteps[-init_timestep]
+        timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
+
+        # add noise to latents using the timesteps
+        noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
+        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
+
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
         # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
```
```diff
@@ -300,11 +300,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         noise: torch.FloatTensor,
         timesteps: torch.IntTensor,
     ) -> torch.FloatTensor:
-        if self.alphas_cumprod.device != original_samples.device:
-            self.alphas_cumprod = self.alphas_cumprod.to(original_samples.device)
-
-        if timesteps.device != original_samples.device:
-            timesteps = timesteps.to(original_samples.device)
+        # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
+        self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+        timesteps = timesteps.to(original_samples.device)
 
         sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
         sqrt_alpha_prod = sqrt_alpha_prod.flatten()
```
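With `alphas_cumprod` cast to the samples' device and dtype, `add_noise` keeps half-precision latents in half precision instead of silently upcasting them. A small check (a sketch; it assumes diffusers 0.4.2 or later and a PyTorch build that supports these fp16 ops on the chosen device):

```python
# Sketch: after this change, add_noise preserves the samples' dtype (e.g. float16).
import torch
from diffusers import DDIMScheduler

scheduler = DDIMScheduler()  # default config, 1000 training timesteps
samples = torch.randn(2, 4, 8, 8, dtype=torch.float16)
noise = torch.randn_like(samples)
timesteps = torch.tensor([10, 500])

noisy = scheduler.add_noise(samples, noise, timesteps)
assert noisy.dtype == torch.float16
```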
```diff
@@ -294,11 +294,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         noise: torch.FloatTensor,
         timesteps: torch.IntTensor,
     ) -> torch.FloatTensor:
-        if self.alphas_cumprod.device != original_samples.device:
-            self.alphas_cumprod = self.alphas_cumprod.to(original_samples.device)
-
-        if timesteps.device != original_samples.device:
-            timesteps = timesteps.to(original_samples.device)
+        # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
+        self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+        timesteps = timesteps.to(original_samples.device)
 
         sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
         sqrt_alpha_prod = sqrt_alpha_prod.flatten()
```
```diff
@@ -257,9 +257,13 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         noise: torch.FloatTensor,
         timesteps: torch.FloatTensor,
     ) -> torch.FloatTensor:
-        sigmas = self.sigmas.to(original_samples.device)
-        schedule_timesteps = self.timesteps.to(original_samples.device)
+        # Make sure sigmas and timesteps have the same device and dtype as original_samples
+        self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
+        self.timesteps = self.timesteps.to(original_samples.device)
         timesteps = timesteps.to(original_samples.device)
+
+        schedule_timesteps = self.timesteps
+
         if isinstance(timesteps, torch.IntTensor) or isinstance(timesteps, torch.LongTensor):
             deprecate(
                 "timesteps as indices",
@@ -273,7 +277,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         else:
             step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
 
-        sigma = sigmas[step_indices].flatten()
+        sigma = self.sigmas[step_indices].flatten()
         while len(sigma.shape) < len(original_samples.shape):
             sigma = sigma.unsqueeze(-1)
 
```
```diff
@@ -400,11 +400,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         noise: torch.FloatTensor,
         timesteps: torch.IntTensor,
     ) -> torch.Tensor:
-        if self.alphas_cumprod.device != original_samples.device:
-            self.alphas_cumprod = self.alphas_cumprod.to(original_samples.device)
-
-        if timesteps.device != original_samples.device:
-            timesteps = timesteps.to(original_samples.device)
+        # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
+        self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+        timesteps = timesteps.to(original_samples.device)
 
         sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
         sqrt_alpha_prod = sqrt_alpha_prod.flatten()
```
```diff
@@ -1005,6 +1005,124 @@ class PipelineFastTests(unittest.TestCase):
 
         assert images.shape == (batch_size * num_images_per_prompt, 32, 32, 3)
 
+    @unittest.skipIf(torch_device == "cpu", "This test requires a GPU")
+    def test_stable_diffusion_fp16(self):
+        """Test that stable diffusion works with fp16"""
+        unet = self.dummy_cond_unet
+        scheduler = PNDMScheduler(skip_prk_steps=True)
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        # put models in fp16
+        unet = unet.half()
+        vae = vae.half()
+        bert = bert.half()
+
+        # make sure here that pndm scheduler skips prk
+        sd_pipe = StableDiffusionPipeline(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=self.dummy_safety_checker,
+            feature_extractor=self.dummy_extractor,
+        )
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        prompt = "A painting of a squirrel eating a burger"
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        image = sd_pipe([prompt], generator=generator, num_inference_steps=2, output_type="np").images
+
+        assert image.shape == (1, 128, 128, 3)
+
+    @unittest.skipIf(torch_device == "cpu", "This test requires a GPU")
+    def test_stable_diffusion_img2img_fp16(self):
+        """Test that stable diffusion img2img works with fp16"""
+        unet = self.dummy_cond_unet
+        scheduler = PNDMScheduler(skip_prk_steps=True)
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        init_image = self.dummy_image.to(torch_device)
+
+        # put models in fp16
+        unet = unet.half()
+        vae = vae.half()
+        bert = bert.half()
+
+        # make sure here that pndm scheduler skips prk
+        sd_pipe = StableDiffusionImg2ImgPipeline(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=self.dummy_safety_checker,
+            feature_extractor=self.dummy_extractor,
+        )
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        prompt = "A painting of a squirrel eating a burger"
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        image = sd_pipe(
+            [prompt],
+            generator=generator,
+            num_inference_steps=2,
+            output_type="np",
+            init_image=init_image,
+        ).images
+
+        assert image.shape == (1, 32, 32, 3)
+
+    @unittest.skipIf(torch_device == "cpu", "This test requires a GPU")
+    def test_stable_diffusion_inpaint_fp16(self):
+        """Test that stable diffusion inpaint works with fp16"""
+        unet = self.dummy_cond_unet
+        scheduler = PNDMScheduler(skip_prk_steps=True)
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
+        init_image = Image.fromarray(np.uint8(image)).convert("RGB")
+        mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))
+
+        # put models in fp16
+        unet = unet.half()
+        vae = vae.half()
+        bert = bert.half()
+
+        # make sure here that pndm scheduler skips prk
+        sd_pipe = StableDiffusionInpaintPipeline(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=self.dummy_safety_checker,
+            feature_extractor=self.dummy_extractor,
+        )
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        prompt = "A painting of a squirrel eating a burger"
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        image = sd_pipe(
+            [prompt],
+            generator=generator,
+            num_inference_steps=2,
+            output_type="np",
+            init_image=init_image,
+            mask_image=mask_image,
+        ).images
+
+        assert image.shape == (1, 32, 32, 3)
+
 
 class PipelineTesterMixin(unittest.TestCase):
     def tearDown(self):
```
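The new tests exercise the pipelines with every model component cast to fp16 via `.half()`. Outside the test suite, the same code paths are hit by loading a pipeline in half precision; a hedged sketch (the model id, the `fp16` revision, and a CUDA GPU are assumptions, and the checkpoint may require accepting its license and logging in to the Hub):

```python
# Sketch: run a pipeline in fp16, which exercises the dtype handling fixed in 0.4.2.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="fp16",
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")

image = pipe("A painting of a squirrel eating a burger", num_inference_steps=25).images[0]
image.save("squirrel_fp16.png")
```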