Compare commits

...

11 Commits

Author SHA1 Message Date
Pedro Cuenca
a47d1f07d8 Apply doc-builder code style. 2022-12-29 13:42:09 +01:00
Simon Kirsten
19b3de46ce Remove unnecessary optional types in _generate 2022-12-29 13:28:31 +01:00
Simon Kirsten
74128b2adc Remove "static" comment 2022-12-29 13:28:31 +01:00
Simon Kirsten
44546ef516 Fix typo 2022-12-29 13:28:31 +01:00
Simon Kirsten
da3311d2ff latents.shape -> latents_shape 2022-12-29 13:28:31 +01:00
Simon Kirsten
1f0117d3ee Fix processed_images dimen 2022-12-29 13:28:31 +01:00
Simon Kirsten
a2766393ef Fix preprocess images 2022-12-29 13:28:31 +01:00
Simon Kirsten
9d10981805 Refactor strength to start_timestep 2022-12-29 13:28:31 +01:00
Simon Kirsten
7753431398 Flax: Fix PRNGKey type 2022-12-29 13:28:31 +01:00
Simon Kirsten
6bf1983cb4 Flax: Fix img2img and align with other pipeline 2022-12-29 13:28:31 +01:00
Simon Kirsten
1eb9024e43 Flax: Add components function 2022-12-29 13:28:31 +01:00
8 changed files with 179 additions and 84 deletions

View File

@@ -189,7 +189,7 @@ class FlaxModelMixin:
```"""
return self._cast_floating_to(params, jnp.float16, mask)
def init_weights(self, rng: jax.random.PRNGKey) -> Dict:
def init_weights(self, rng: jax.random.KeyArray) -> Dict:
raise NotImplementedError(f"init_weights method has to be implemented for {self}")
@classmethod

View File

@@ -112,7 +112,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
flip_sin_to_cos: bool = True
freq_shift: int = 0
def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
# init input tensors
sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
sample = jnp.zeros(sample_shape, dtype=jnp.float32)

View File

@@ -806,7 +806,7 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
dtype=self.dtype,
)
def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
# init input tensors
sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
sample = jnp.zeros(sample_shape, dtype=jnp.float32)

View File

@@ -17,7 +17,7 @@
import importlib
import inspect
import os
from typing import Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union
import numpy as np
@@ -475,6 +475,51 @@ class FlaxDiffusionPipeline(ConfigMixin):
model = pipeline_class(**init_kwargs, dtype=dtype)
return model, params
@staticmethod
def _get_signature_keys(obj):
    """Split the `__init__` parameters of `obj` into required and optional names.

    Args:
        obj: Object whose `__init__` signature is inspected.

    Returns:
        A tuple `(expected_modules, optional_parameters)` of two sets: the names of the parameters that declare
        no default value (with `self` removed), and the names of the parameters that declare a default.
    """
    parameters = inspect.signature(obj.__init__).parameters
    # Compare with the "no default" sentinel by identity. Using `==` / `!=` would dispatch to the default
    # value's own `__eq__`, which can misclassify a parameter or raise (e.g. for array-valued defaults).
    empty = inspect.Parameter.empty
    expected_modules = {name for name, param in parameters.items() if param.default is empty} - {"self"}
    optional_parameters = {name for name, param in parameters.items() if param.default is not empty}
    return expected_modules, optional_parameters
@property
def components(self) -> Dict[str, Any]:
    r"""
    The `self.components` property can be useful to run different pipelines with the same weights and
    configurations to not have to re-allocate memory.

    Examples:

    ```py
    >>> import jax.numpy as jnp
    >>> from diffusers import (
    ...     FlaxStableDiffusionPipeline,
    ...     FlaxStableDiffusionImg2ImgPipeline,
    ... )

    >>> text2img = FlaxStableDiffusionPipeline.from_pretrained(
    ...     "runwayml/stable-diffusion-v1-5", revision="bf16", dtype=jnp.bfloat16
    ... )
    >>> img2img = FlaxStableDiffusionImg2ImgPipeline(**text2img.components)
    ```

    Returns:
        A dictionary containing all the modules needed to initialize the pipeline.

    Raises:
        ValueError: If the names collected from `self.config` do not match the required `__init__` parameters.
    """
    expected_modules, optional_parameters = self._get_signature_keys(self)
    # Collect every public config entry that is not an optional `__init__` argument.
    components = {
        k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters
    }

    if set(components.keys()) != expected_modules:
        # Report names only: interpolating the dict itself would dump the repr of every loaded module.
        raise ValueError(
            f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected"
            f" {sorted(expected_modules)} to be defined, but {sorted(components)} are defined."
        )

    return components
@staticmethod
def numpy_to_pil(images):
"""

View File

@@ -764,7 +764,7 @@ class DiffusionPipeline(ConfigMixin):
```
Returns:
A dictionaly containing all the modules needed to initialize the pipeline.
A dictionary containing all the modules needed to initialize the pipeline.
"""
expected_modules, optional_parameters = self._get_signature_keys(self)
components = {

View File

@@ -184,18 +184,14 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
self,
prompt_ids: jnp.array,
params: Union[Dict, FrozenDict],
prng_seed: jax.random.PRNGKey,
num_inference_steps: int = 50,
height: Optional[int] = None,
width: Optional[int] = None,
guidance_scale: float = 7.5,
prng_seed: jax.random.KeyArray,
num_inference_steps: int,
height: int,
width: int,
guidance_scale: float,
latents: Optional[jnp.array] = None,
neg_prompt_ids: jnp.array = None,
neg_prompt_ids: Optional[jnp.array] = None,
):
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -281,15 +277,15 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
self,
prompt_ids: jnp.array,
params: Union[Dict, FrozenDict],
prng_seed: jax.random.PRNGKey,
prng_seed: jax.random.KeyArray,
num_inference_steps: int = 50,
height: Optional[int] = None,
width: Optional[int] = None,
guidance_scale: Union[float, jnp.array] = 7.5,
latents: jnp.array = None,
neg_prompt_ids: jnp.array = None,
return_dict: bool = True,
jit: bool = False,
neg_prompt_ids: jnp.array = None,
):
r"""
Function invoked when calling the pipeline for generation.

View File

@@ -14,7 +14,7 @@
import warnings
from functools import partial
from typing import Dict, List, Union
from typing import Dict, List, Optional, Union
import numpy as np
@@ -41,6 +41,9 @@ from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# Set to True to use python for loop instead of jax.fori_loop for easier debugging
DEBUG = False
class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
r"""
@@ -106,6 +109,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
def prepare_inputs(self, prompt: Union[str, List[str]], image: Union[Image.Image, List[Image.Image]]):
if not isinstance(prompt, (str, list)):
@@ -116,10 +120,8 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
if isinstance(image, Image.Image):
image = [image]
processed_image = []
for img in image:
processed_image.append(preprocess(img, self.dtype))
processed_image = jnp.array(processed_image).squeeze()
processed_images = jnp.concatenate([preprocess(img, jnp.float32) for img in image])
text_input = self.tokenizer(
prompt,
@@ -128,7 +130,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
truncation=True,
return_tensors="np",
)
return text_input.input_ids, processed_image
return text_input.input_ids, processed_images
def _get_has_nsfw_concepts(self, features, params):
has_nsfw_concepts = self.safety_checker(features, params)
@@ -164,12 +166,11 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
return images, has_nsfw_concepts
def get_timestep_start(self, num_inference_steps, strength, scheduler_state):
def get_timestep_start(self, num_inference_steps, strength):
# get the original timestep using init_timestep
offset = self.scheduler.config.get("steps_offset", 0)
init_timestep = int(num_inference_steps * strength) + offset
init_timestep = min(init_timestep, num_inference_steps)
t_start = max(num_inference_steps - init_timestep + offset, 0)
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
t_start = max(num_inference_steps - init_timestep, 0)
return t_start
@@ -178,13 +179,14 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
prompt_ids: jnp.array,
image: jnp.array,
params: Union[Dict, FrozenDict],
prng_seed: jax.random.PRNGKey,
strength: float = 0.8,
num_inference_steps: int = 50,
height: int = 512,
width: int = 512,
guidance_scale: float = 7.5,
debug: bool = False,
prng_seed: jax.random.KeyArray,
start_timestep: int,
num_inference_steps: int,
height: int,
width: int,
guidance_scale: float,
noise: Optional[jnp.array] = None,
neg_prompt_ids: Optional[jnp.array] = None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -197,18 +199,32 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
batch_size = prompt_ids.shape[0]
max_length = prompt_ids.shape[-1]
uncond_input = self.tokenizer(
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
)
uncond_embeddings = self.text_encoder(uncond_input.input_ids, params=params["text_encoder"])[0]
if neg_prompt_ids is None:
uncond_input = self.tokenizer(
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
).input_ids
else:
uncond_input = neg_prompt_ids
uncond_embeddings = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
context = jnp.concatenate([uncond_embeddings, text_embeddings])
latents_shape = (
batch_size,
self.unet.in_channels,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
if noise is None:
noise = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
else:
if noise.shape != latents_shape:
raise ValueError(f"Unexpected latents shape, got {noise.shape}, expected {latents_shape}")
# Create init_latents
init_latent_dist = self.vae.apply({"params": params["vae"]}, image, method=self.vae.encode).latent_dist
init_latents = init_latent_dist.sample(key=prng_seed).transpose((0, 3, 1, 2))
init_latents = 0.18215 * init_latents
latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
noise = jax.random.normal(prng_seed, shape=latents_shape, dtype=self.dtype)
def loop_body(step, args):
latents, scheduler_state = args
@@ -241,19 +257,19 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
params["scheduler"], num_inference_steps=num_inference_steps, shape=latents_shape
)
t_start = self.get_timestep_start(num_inference_steps, strength, scheduler_state)
latent_timestep = scheduler_state.timesteps[t_start : t_start + 1].repeat(batch_size)
init_latents = self.scheduler.add_noise(init_latents, noise, latent_timestep)
latents = init_latents
latent_timestep = scheduler_state.timesteps[start_timestep : start_timestep + 1].repeat(batch_size)
if debug:
latents = self.scheduler.add_noise(params["scheduler"], init_latents, noise, latent_timestep)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * params["scheduler"].init_noise_sigma
if DEBUG:
# run with python for loop
for i in range(t_start, len(scheduler_state.timesteps)):
for i in range(start_timestep, num_inference_steps):
latents, scheduler_state = loop_body(i, (latents, scheduler_state))
else:
latents, _ = jax.lax.fori_loop(
t_start, len(scheduler_state.timesteps), loop_body, (latents, scheduler_state)
)
latents, _ = jax.lax.fori_loop(start_timestep, num_inference_steps, loop_body, (latents, scheduler_state))
# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
@@ -268,14 +284,15 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
image: jnp.array,
params: Union[Dict, FrozenDict],
prng_seed: jax.random.KeyArray,
num_inference_steps: int = 50,
height: int = 512,
width: int = 512,
guidance_scale: float = 7.5,
strength: float = 0.8,
num_inference_steps: int = 50,
height: Optional[int] = None,
width: Optional[int] = None,
guidance_scale: Union[float, jnp.array] = 7.5,
noise: jnp.array = None,
neg_prompt_ids: jnp.array = None,
return_dict: bool = True,
jit: bool = False,
debug: bool = False,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -287,12 +304,17 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
Array representing an image batch, that will be used as the starting point for the process.
params (`Dict` or `FrozenDict`): Dictionary containing the model parameters/weights
prng_seed (`jax.random.KeyArray` or `jax.Array`): Array containing random number generator key
strength (`float`, *optional*, defaults to 0.8):
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
will be used as a starting point, adding more noise to it the larger the `strength`. The number of
denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
be maximum and the denoising process will run for the full number of iterations specified in
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
height (`int`, *optional*, defaults to 512):
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to 512):
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
@@ -300,18 +322,17 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
strength (`float`, *optional*, defaults to 0.8):
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
will be used as a starting point, adding more noise to it the larger the `strength`. The number of
denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
be maximum and the denoising process will run for the full number of iterations specified in
noise (`jnp.array`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a noise tensor
will be generated by sampling using the supplied `prng_seed`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
a plain tuple.
jit (`bool`, defaults to `False`):
Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument
exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release.
debug (`bool`, *optional*, defaults to `False`): Whether to make use of python forloop or lax.fori_loop
Returns:
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
@@ -319,76 +340,109 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
element is a list of `bool`s denoting whether the corresponding generated image likely represents
"not-safe-for-work" (nsfw) content, according to the `safety_checker`.
"""
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
if isinstance(guidance_scale, float):
# Convert to a tensor so each device gets a copy. Follow the prompt_ids for
# shape information, as they may be sharded (when `jit` is `True`), or not.
guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0])
if len(prompt_ids.shape) > 2:
# Assume sharded
guidance_scale = guidance_scale[:, None]
start_timestep = self.get_timestep_start(num_inference_steps, strength)
if jit:
image = _p_generate(
images = _p_generate(
self,
prompt_ids,
image,
params,
prng_seed,
strength,
start_timestep,
num_inference_steps,
height,
width,
guidance_scale,
debug,
noise,
neg_prompt_ids,
)
else:
image = self._generate(
images = self._generate(
prompt_ids,
image,
params,
prng_seed,
strength,
start_timestep,
num_inference_steps,
height,
width,
guidance_scale,
debug,
noise,
neg_prompt_ids,
)
if self.safety_checker is not None:
safety_params = params["safety_checker"]
image_uint8_casted = (image * 255).round().astype("uint8")
num_devices, batch_size = image.shape[:2]
images_uint8_casted = (images * 255).round().astype("uint8")
num_devices, batch_size = images.shape[:2]
image_uint8_casted = np.asarray(image_uint8_casted).reshape(num_devices * batch_size, height, width, 3)
image_uint8_casted, has_nsfw_concept = self._run_safety_checker(image_uint8_casted, safety_params, jit)
image = np.asarray(image)
images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3)
images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit)
images = np.asarray(images)
# block images
if any(has_nsfw_concept):
for i, is_nsfw in enumerate(has_nsfw_concept):
if is_nsfw:
image[i] = np.asarray(image_uint8_casted[i])
images[i] = np.asarray(images_uint8_casted[i])
image = image.reshape(num_devices, batch_size, height, width, 3)
images = images.reshape(num_devices, batch_size, height, width, 3)
else:
images = np.asarray(images)
has_nsfw_concept = False
if not return_dict:
return (image, has_nsfw_concept)
return (images, has_nsfw_concept)
return FlaxStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)
# TODO: maybe use a config dict instead of so many static argnums
@partial(jax.pmap, static_broadcasted_argnums=(0, 5, 6, 7, 8, 9, 10))
# Static argnums are pipe, start_timestep, num_inference_steps, height, width. A change would trigger recompilation.
# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`).
@partial(
jax.pmap,
in_axes=(None, 0, 0, 0, 0, None, None, None, None, 0, 0, 0),
static_broadcasted_argnums=(0, 5, 6, 7, 8),
)
def _p_generate(
pipe,
prompt_ids,
image,
params,
prng_seed,
strength,
start_timestep,
num_inference_steps,
height,
width,
guidance_scale,
debug,
noise,
neg_prompt_ids,
):
return pipe._generate(
prompt_ids, image, params, prng_seed, strength, num_inference_steps, height, width, guidance_scale, debug
prompt_ids,
image,
params,
prng_seed,
start_timestep,
num_inference_steps,
height,
width,
guidance_scale,
noise,
neg_prompt_ids,
)

View File

@@ -87,7 +87,7 @@ class FlaxStableDiffusionSafetyChecker(FlaxPreTrainedModel):
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
def init_weights(self, rng: jax.random.KeyArray, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensor
clip_input = jax.random.normal(rng, input_shape)