Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-22 20:34:51 +08:00)

Compare commits: memory-opt...flax-fix-i (11 commits)

| SHA1 |
|---|
| a47d1f07d8 |
| 19b3de46ce |
| 74128b2adc |
| 44546ef516 |
| da3311d2ff |
| 1f0117d3ee |
| a2766393ef |
| 9d10981805 |
| 7753431398 |
| 6bf1983cb4 |
| 1eb9024e43 |
@@ -189,7 +189,7 @@ class FlaxModelMixin:
         ```"""
         return self._cast_floating_to(params, jnp.float16, mask)
 
-    def init_weights(self, rng: jax.random.PRNGKey) -> Dict:
+    def init_weights(self, rng: jax.random.KeyArray) -> Dict:
         raise NotImplementedError(f"init_weights method has to be implemented for {self}")
 
     @classmethod
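The switch from `jax.random.PRNGKey` to `jax.random.KeyArray` in this and the following hunks only changes the type annotation; keys are still created the same way. A minimal sketch of the calling convention (assumed usage, not part of the diff):

```python
import jax

# jax.random.PRNGKey(seed) still creates the key; KeyArray is just the annotation for the result.
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
# params = model.init_weights(init_rng)  # hypothetical model; matches the annotated signature
```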
@@ -112,7 +112,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
     flip_sin_to_cos: bool = True
     freq_shift: int = 0
 
-    def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
+    def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
         # init input tensors
         sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
         sample = jnp.zeros(sample_shape, dtype=jnp.float32)
@@ -806,7 +806,7 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
             dtype=self.dtype,
         )
 
-    def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
+    def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
         # init input tensors
         sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
         sample = jnp.zeros(sample_shape, dtype=jnp.float32)
@@ -17,7 +17,7 @@
 import importlib
 import inspect
 import os
-from typing import Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 import numpy as np
 
@@ -475,6 +475,51 @@ class FlaxDiffusionPipeline(ConfigMixin):
         model = pipeline_class(**init_kwargs, dtype=dtype)
         return model, params
 
+    @staticmethod
+    def _get_signature_keys(obj):
+        parameters = inspect.signature(obj.__init__).parameters
+        required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
+        optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
+        expected_modules = set(required_parameters.keys()) - set(["self"])
+        return expected_modules, optional_parameters
+
+    @property
+    def components(self) -> Dict[str, Any]:
+        r"""
+
+        The `self.components` property can be useful to run different pipelines with the same weights and
+        configurations to not have to re-allocate memory.
+
+        Examples:
+
+        ```py
+        >>> from diffusers import (
+        ...     FlaxStableDiffusionPipeline,
+        ...     FlaxStableDiffusionImg2ImgPipeline,
+        ... )
+
+        >>> text2img = FlaxStableDiffusionPipeline.from_pretrained(
+        ...     "runwayml/stable-diffusion-v1-5", revision="bf16", dtype=jnp.bfloat16
+        ... )
+        >>> img2img = FlaxStableDiffusionImg2ImgPipeline(**text2img.components)
+        ```
+
+        Returns:
+            A dictionary containing all the modules needed to initialize the pipeline.
+        """
+        expected_modules, optional_parameters = self._get_signature_keys(self)
+        components = {
+            k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters
+        }
+
+        if set(components.keys()) != expected_modules:
+            raise ValueError(
+                f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected"
+                f" {expected_modules} to be defined, but {components} are defined."
+            )
+
+        return components
+
     @staticmethod
     def numpy_to_pil(images):
         """
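To see what `_get_signature_keys` computes, a small standalone sketch (the `DummyPipeline` class is hypothetical and used only for illustration):

```python
import inspect

class DummyPipeline:
    # hypothetical pipeline; parameters without defaults become expected_modules
    def __init__(self, unet, scheduler, dtype=None):
        pass

parameters = inspect.signature(DummyPipeline.__init__).parameters
required = {k for k, v in parameters.items() if v.default is inspect.Parameter.empty} - {"self"}
optional = {k for k, v in parameters.items() if v.default is not inspect.Parameter.empty}
print(required)  # expected_modules: {'unet', 'scheduler'} (set order may vary)
print(optional)  # optional_parameters: {'dtype'}
```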
@@ -764,7 +764,7 @@ class DiffusionPipeline(ConfigMixin):
         ```
 
         Returns:
-            A dictionaly containing all the modules needed to initialize the pipeline.
+            A dictionary containing all the modules needed to initialize the pipeline.
         """
         expected_modules, optional_parameters = self._get_signature_keys(self)
         components = {
@@ -184,18 +184,14 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
         self,
         prompt_ids: jnp.array,
         params: Union[Dict, FrozenDict],
-        prng_seed: jax.random.PRNGKey,
-        num_inference_steps: int = 50,
-        height: Optional[int] = None,
-        width: Optional[int] = None,
-        guidance_scale: float = 7.5,
+        prng_seed: jax.random.KeyArray,
+        num_inference_steps: int,
+        height: int,
+        width: int,
+        guidance_scale: float,
         latents: Optional[jnp.array] = None,
-        neg_prompt_ids: jnp.array = None,
+        neg_prompt_ids: Optional[jnp.array] = None,
     ):
-        # 0. Default height and width to unet
-        height = height or self.unet.config.sample_size * self.vae_scale_factor
-        width = width or self.unet.config.sample_size * self.vae_scale_factor
-
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
 
@@ -281,15 +277,15 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
         self,
         prompt_ids: jnp.array,
         params: Union[Dict, FrozenDict],
-        prng_seed: jax.random.PRNGKey,
+        prng_seed: jax.random.KeyArray,
         num_inference_steps: int = 50,
         height: Optional[int] = None,
         width: Optional[int] = None,
         guidance_scale: Union[float, jnp.array] = 7.5,
         latents: jnp.array = None,
+        neg_prompt_ids: jnp.array = None,
         return_dict: bool = True,
         jit: bool = False,
-        neg_prompt_ids: jnp.array = None,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
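A hedged usage sketch of the reordered `__call__` signature with the `neg_prompt_ids` argument; the variable names and the replicate/shard pattern are assumed from the usual Flax examples, not taken from this diff:

```python
import jax
from flax.jax_utils import replicate
from flax.training.common_utils import shard

# `pipeline` and `params` are assumed to come from FlaxStableDiffusionPipeline.from_pretrained(...)
prompt_ids = pipeline.prepare_inputs(["a photo of an astronaut riding a horse"])
neg_prompt_ids = pipeline.prepare_inputs(["blurry, low quality"])  # negative prompt, tokenized the same way

rng = jax.random.split(jax.random.PRNGKey(0), jax.device_count())
images = pipeline(
    shard(prompt_ids),
    replicate(params),
    rng,
    neg_prompt_ids=shard(neg_prompt_ids),
    jit=True,
).images
```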
@@ -14,7 +14,7 @@
 
 import warnings
 from functools import partial
-from typing import Dict, List, Union
+from typing import Dict, List, Optional, Union
 
 import numpy as np
 
@@ -41,6 +41,9 @@ from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+# Set to True to use python for loop instead of jax.fori_loop for easier debugging
+DEBUG = False
+
 
 class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
     r"""
@@ -106,6 +109,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
         )
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
 
     def prepare_inputs(self, prompt: Union[str, List[str]], image: Union[Image.Image, List[Image.Image]]):
         if not isinstance(prompt, (str, list)):
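A quick check of the scale-factor formula, assuming the standard Stable Diffusion VAE config with four blocks (other configs give other factors):

```python
block_out_channels = (128, 256, 512, 512)  # assumed standard SD VAE config
vae_scale_factor = 2 ** (len(block_out_channels) - 1)
print(vae_scale_factor)         # 8
print(512 // vae_scale_factor)  # 64 -> latent spatial size for a 512px input
```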
@@ -116,10 +120,8 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
 
         if isinstance(image, Image.Image):
             image = [image]
-        processed_image = []
-        for img in image:
-            processed_image.append(preprocess(img, self.dtype))
-        processed_image = jnp.array(processed_image).squeeze()
+
+        processed_images = jnp.concatenate([preprocess(img, jnp.float32) for img in image])
 
         text_input = self.tokenizer(
             prompt,
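The rewritten preprocessing assumes each `preprocess(img, dtype)` call returns a `(1, 3, H, W)` array, so concatenating along the leading axis yields one `(N, 3, H, W)` batch without the lossy `squeeze`. A minimal shape check under that assumption:

```python
import jax.numpy as jnp

per_image = [jnp.zeros((1, 3, 512, 512), dtype=jnp.float32) for _ in range(4)]  # stand-ins for preprocess() outputs
batch = jnp.concatenate(per_image)
print(batch.shape)  # (4, 3, 512, 512)
```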
@@ -128,7 +130,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
             truncation=True,
             return_tensors="np",
         )
-        return text_input.input_ids, processed_image
+        return text_input.input_ids, processed_images
 
     def _get_has_nsfw_concepts(self, features, params):
         has_nsfw_concepts = self.safety_checker(features, params)
@@ -164,12 +166,11 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
 
         return images, has_nsfw_concepts
 
-    def get_timestep_start(self, num_inference_steps, strength, scheduler_state):
+    def get_timestep_start(self, num_inference_steps, strength):
         # get the original timestep using init_timestep
-        offset = self.scheduler.config.get("steps_offset", 0)
-        init_timestep = int(num_inference_steps * strength) + offset
-        init_timestep = min(init_timestep, num_inference_steps)
-        t_start = max(num_inference_steps - init_timestep + offset, 0)
+        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+        t_start = max(num_inference_steps - init_timestep, 0)
 
         return t_start
 
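A worked example of the simplified start-index computation (values chosen for illustration):

```python
num_inference_steps, strength = 50, 0.3
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 15
t_start = max(num_inference_steps - init_timestep, 0)                          # 35
# The denoising loop later runs from index 35 to 50, i.e. only the last 15 scheduler steps,
# so a lower strength keeps the result closer to the input image.
```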
@@ -178,13 +179,14 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
         prompt_ids: jnp.array,
         image: jnp.array,
         params: Union[Dict, FrozenDict],
-        prng_seed: jax.random.PRNGKey,
-        strength: float = 0.8,
-        num_inference_steps: int = 50,
-        height: int = 512,
-        width: int = 512,
-        guidance_scale: float = 7.5,
-        debug: bool = False,
+        prng_seed: jax.random.KeyArray,
+        start_timestep: int,
+        num_inference_steps: int,
+        height: int,
+        width: int,
+        guidance_scale: float,
+        noise: Optional[jnp.array] = None,
+        neg_prompt_ids: Optional[jnp.array] = None,
     ):
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -197,18 +199,32 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
         batch_size = prompt_ids.shape[0]
 
         max_length = prompt_ids.shape[-1]
 
-        uncond_input = self.tokenizer(
-            [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
-        )
-        uncond_embeddings = self.text_encoder(uncond_input.input_ids, params=params["text_encoder"])[0]
+        if neg_prompt_ids is None:
+            uncond_input = self.tokenizer(
+                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
+            ).input_ids
+        else:
+            uncond_input = neg_prompt_ids
+        uncond_embeddings = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
         context = jnp.concatenate([uncond_embeddings, text_embeddings])
 
+        latents_shape = (
+            batch_size,
+            self.unet.in_channels,
+            height // self.vae_scale_factor,
+            width // self.vae_scale_factor,
+        )
+        if noise is None:
+            noise = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
+        else:
+            if noise.shape != latents_shape:
+                raise ValueError(f"Unexpected latents shape, got {noise.shape}, expected {latents_shape}")
+
         # Create init_latents
         init_latent_dist = self.vae.apply({"params": params["vae"]}, image, method=self.vae.encode).latent_dist
         init_latents = init_latent_dist.sample(key=prng_seed).transpose((0, 3, 1, 2))
         init_latents = 0.18215 * init_latents
-        latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
-        noise = jax.random.normal(prng_seed, shape=latents_shape, dtype=self.dtype)
 
         def loop_body(step, args):
             latents, scheduler_state = args
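One motivation for the explicit `noise` argument is reproducibility: the same latent noise can be reused across prompts. A hedged sketch (shapes assume a 512px image with a scale factor of 8; the commented pipeline calls are illustrative only):

```python
import jax
import jax.numpy as jnp

latents_shape = (1, 4, 64, 64)  # (batch, unet.in_channels, height // 8, width // 8), assumed
noise = jax.random.normal(jax.random.PRNGKey(0), latents_shape, dtype=jnp.float32)
# out_a = pipeline(prompt_ids_a, image, params, rng, strength=0.6, noise=noise)
# out_b = pipeline(prompt_ids_b, image, params, rng, strength=0.6, noise=noise)  # same noise, new prompt
```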
@@ -241,19 +257,19 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
             params["scheduler"], num_inference_steps=num_inference_steps, shape=latents_shape
         )
 
-        t_start = self.get_timestep_start(num_inference_steps, strength, scheduler_state)
-        latent_timestep = scheduler_state.timesteps[t_start : t_start + 1].repeat(batch_size)
-        init_latents = self.scheduler.add_noise(init_latents, noise, latent_timestep)
-        latents = init_latents
+        latent_timestep = scheduler_state.timesteps[start_timestep : start_timestep + 1].repeat(batch_size)
+
+        latents = self.scheduler.add_noise(params["scheduler"], init_latents, noise, latent_timestep)
+
+        # scale the initial noise by the standard deviation required by the scheduler
+        latents = latents * params["scheduler"].init_noise_sigma
 
-        if debug:
+        if DEBUG:
             # run with python for loop
-            for i in range(t_start, len(scheduler_state.timesteps)):
+            for i in range(start_timestep, num_inference_steps):
                 latents, scheduler_state = loop_body(i, (latents, scheduler_state))
         else:
-            latents, _ = jax.lax.fori_loop(
-                t_start, len(scheduler_state.timesteps), loop_body, (latents, scheduler_state)
-            )
+            latents, _ = jax.lax.fori_loop(start_timestep, num_inference_steps, loop_body, (latents, scheduler_state))
 
         # scale and decode the image latents with vae
         latents = 1 / 0.18215 * latents
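For readers unfamiliar with `jax.lax.fori_loop`, the body takes `(index, carry)` and returns the new carry, so the Python loop under `DEBUG = True` is semantically equivalent. A minimal standalone illustration:

```python
import jax

def body(i, carry):
    return carry + i  # carry plays the role of (latents, scheduler_state) above

total = jax.lax.fori_loop(0, 10, body, 0)
print(total)  # 45, same as accumulating with `for i in range(0, 10)`
```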
@@ -268,14 +284,15 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
         image: jnp.array,
         params: Union[Dict, FrozenDict],
         prng_seed: jax.random.KeyArray,
-        num_inference_steps: int = 50,
-        height: int = 512,
-        width: int = 512,
-        guidance_scale: float = 7.5,
         strength: float = 0.8,
+        num_inference_steps: int = 50,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        guidance_scale: Union[float, jnp.array] = 7.5,
+        noise: jnp.array = None,
+        neg_prompt_ids: jnp.array = None,
         return_dict: bool = True,
         jit: bool = False,
-        debug: bool = False,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -287,12 +304,17 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
                 Array representing an image batch, that will be used as the starting point for the process.
             params (`Dict` or `FrozenDict`): Dictionary containing the model parameters/weights
             prng_seed (`jax.random.KeyArray` or `jax.Array`): Array containing random number generator key
+            strength (`float`, *optional*, defaults to 0.8):
+                Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
+                will be used as a starting point, adding more noise to it the larger the `strength`. The number of
+                denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
+                be maximum and the denoising process will run for the full number of iterations specified in
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
-            height (`int`, *optional*, defaults to 512):
+            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                 The height in pixels of the generated image.
-            width (`int`, *optional*, defaults to 512):
+            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                 The width in pixels of the generated image.
             guidance_scale (`float`, *optional*, defaults to 7.5):
                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
@@ -300,18 +322,17 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                 usually at the expense of lower image quality.
-            strength (`float`, *optional*, defaults to 0.8):
-                Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
-                will be used as a starting point, adding more noise to it the larger the `strength`. The number of
-                denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
-                be maximum and the denoising process will run for the full number of iterations specified in
+            noise (`jnp.array`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. The tensor will be
+                generated by sampling using the supplied random `generator`.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
                 a plain tuple.
             jit (`bool`, defaults to `False`):
                 Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument
                 exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release.
-            debug (`bool`, *optional*, defaults to `False`): Whether to make use of python forloop or lax.fori_loop
         Returns:
             [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
             [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
@@ -319,76 +340,109 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
                 element is a list of `bool`s denoting whether the corresponding generated image likely represents
                 "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
         """
+        # 0. Default height and width to unet
+        height = height or self.unet.config.sample_size * self.vae_scale_factor
+        width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+        if isinstance(guidance_scale, float):
+            # Convert to a tensor so each device gets a copy. Follow the prompt_ids for
+            # shape information, as they may be sharded (when `jit` is `True`), or not.
+            guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0])
+            if len(prompt_ids.shape) > 2:
+                # Assume sharded
+                guidance_scale = guidance_scale[:, None]
+
+        start_timestep = self.get_timestep_start(num_inference_steps, strength)
+
         if jit:
-            image = _p_generate(
+            images = _p_generate(
                 self,
                 prompt_ids,
                 image,
                 params,
                 prng_seed,
-                strength,
+                start_timestep,
                 num_inference_steps,
                 height,
                 width,
                 guidance_scale,
-                debug,
+                noise,
+                neg_prompt_ids,
             )
         else:
-            image = self._generate(
+            images = self._generate(
                 prompt_ids,
                 image,
                 params,
                 prng_seed,
-                strength,
+                start_timestep,
                 num_inference_steps,
                 height,
                 width,
                 guidance_scale,
-                debug,
+                noise,
+                neg_prompt_ids,
            )
 
         if self.safety_checker is not None:
             safety_params = params["safety_checker"]
-            image_uint8_casted = (image * 255).round().astype("uint8")
-            num_devices, batch_size = image.shape[:2]
+            images_uint8_casted = (images * 255).round().astype("uint8")
+            num_devices, batch_size = images.shape[:2]
 
-            image_uint8_casted = np.asarray(image_uint8_casted).reshape(num_devices * batch_size, height, width, 3)
-            image_uint8_casted, has_nsfw_concept = self._run_safety_checker(image_uint8_casted, safety_params, jit)
-            image = np.asarray(image)
+            images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3)
+            images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit)
+            images = np.asarray(images)
 
             # block images
             if any(has_nsfw_concept):
                 for i, is_nsfw in enumerate(has_nsfw_concept):
                     if is_nsfw:
-                        image[i] = np.asarray(image_uint8_casted[i])
+                        images[i] = np.asarray(images_uint8_casted[i])
 
-            image = image.reshape(num_devices, batch_size, height, width, 3)
+            images = images.reshape(num_devices, batch_size, height, width, 3)
         else:
+            images = np.asarray(images)
             has_nsfw_concept = False
 
         if not return_dict:
-            return (image, has_nsfw_concept)
+            return (images, has_nsfw_concept)
 
-        return FlaxStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+        return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)
 
 
-# TODO: maybe use a config dict instead of so many static argnums
-@partial(jax.pmap, static_broadcasted_argnums=(0, 5, 6, 7, 8, 9, 10))
+# Static argnums are pipe, start_timestep, num_inference_steps, height, width. A change would trigger recompilation.
+# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`).
+@partial(
+    jax.pmap,
+    in_axes=(None, 0, 0, 0, 0, None, None, None, None, 0, 0, 0),
+    static_broadcasted_argnums=(0, 5, 6, 7, 8),
+)
 def _p_generate(
     pipe,
     prompt_ids,
     image,
     params,
     prng_seed,
-    strength,
+    start_timestep,
     num_inference_steps,
     height,
     width,
     guidance_scale,
-    debug,
+    noise,
+    neg_prompt_ids,
 ):
     return pipe._generate(
-        prompt_ids, image, params, prng_seed, strength, num_inference_steps, height, width, guidance_scale, debug
+        prompt_ids,
+        image,
+        params,
+        prng_seed,
+        start_timestep,
+        num_inference_steps,
+        height,
+        width,
+        guidance_scale,
+        noise,
+        neg_prompt_ids,
     )
 
 
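A minimal illustration of the `pmap` pattern adopted for `_p_generate` (toy function, not from the PR): `in_axes=0` shards an argument along its first axis, `in_axes=None` broadcasts it unchanged to every device, and `static_broadcasted_argnums` marks plain Python values that are baked into the compiled function, so changing them triggers recompilation:

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.pmap, in_axes=(0, None, None), static_broadcasted_argnums=(2,))
def scale_and_shift(x, shift, power):
    # x is sharded per device, shift is replicated, power is a static Python int
    return x**power + shift

n = jax.device_count()
x = jnp.ones((n, 4))      # one row per device
shift = jnp.float32(1.0)  # broadcast unchanged to all devices
print(scale_and_shift(x, shift, 2).shape)  # (n, 4); a different `power` recompiles
```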
@@ -87,7 +87,7 @@ class FlaxStableDiffusionSafetyChecker(FlaxPreTrainedModel):
         module = self.module_class(config=config, dtype=dtype, **kwargs)
         super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
 
-    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
+    def init_weights(self, rng: jax.random.KeyArray, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # init input tensor
         clip_input = jax.random.normal(rng, input_shape)
 