Compare commits

...

3 Commits

Author SHA1 Message Date
Patrick von Platen
c16ecac3c7 debug 2023-09-26 16:53:20 +02:00
Patrick von Platen
2fedbbf9af finish 2023-09-26 15:57:10 +02:00
Patrick von Platen
234600ce03 fix SDXL flax init 2023-09-26 15:54:11 +02:00
4 changed files with 49 additions and 4 deletions

View File

@@ -945,6 +945,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
down_block_res_samples = (sample,)
print("emb", emb.abs().sum())
print("sample", sample.abs().sum())
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
# For t2i-adapter CrossAttnDownBlock2D

View File

@@ -134,8 +134,18 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
added_cond_kwargs = None
if self.addition_embed_type == "text_time":
# TODO: how to get this from the config? It's no longer cross_attention_dim
text_embeds_dim = 1280
# Derive the expected `text_embeds_dim` from the config: first detect whether this is a refiner (5 micro-conditions)
# or a non-refiner (6 micro-conditions) architecture, then subtract the micro-condition embedding width from `projection_class_embeddings_input_dim`.
is_refiner = (
5 * self.config.addition_time_embed_dim + self.config.cross_attention_dim
== self.config.projection_class_embeddings_input_dim
)
num_micro_conditions = 5 if is_refiner else 6
text_embeds_dim = self.config.projection_class_embeddings_input_dim - (
num_micro_conditions * self.config.addition_time_embed_dim
)
time_ids_channels = self.projection_class_embeddings_input_dim - text_embeds_dim
time_ids_dims = time_ids_channels // self.addition_time_embed_dim
added_cond_kwargs = {
@@ -367,6 +377,11 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
sample = jnp.transpose(sample, (0, 2, 3, 1))
sample = self.conv_in(sample)
if not isinstance(t_emb, jax._src.interpreters.partial_eval.DynamicJaxprTracer):
import torch; import numpy as np
print("t_emb", torch.from_numpy(np.asarray(t_emb)).abs().sum())
print("sample", torch.from_numpy(np.asarray(sample)).abs().sum())
# 3. down
down_block_res_samples = (sample,)
for down_block in self.down_blocks:

View File

@@ -37,6 +37,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# Set to True to use a Python for loop instead of jax.lax.fori_loop for easier debugging
DEBUG = False
DEBUG = True
class FlaxStableDiffusionXLPipeline(FlaxDiffusionPipeline):
@@ -216,15 +217,21 @@ class FlaxStableDiffusionXLPipeline(FlaxDiffusionPipeline):
if latents.shape != latents_shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * params["scheduler"].init_noise_sigma
# Prepare scheduler state
scheduler_state = self.scheduler.set_timesteps(
params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape
)
import torch; import numpy as np
latents = latents * scheduler_state.init_noise_sigma
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
print("prompt_embeds", torch.from_numpy(np.asarray(prompt_embeds)).abs().sum())
print("text_embeds", torch.from_numpy(np.asarray(add_text_embeds)).abs().sum())
print("add_time_ids", add_time_ids)
# Denoising loop
def loop_body(step, args):
latents, scheduler_state = args
@@ -236,8 +243,12 @@ class FlaxStableDiffusionXLPipeline(FlaxDiffusionPipeline):
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
timestep = jnp.broadcast_to(t, latents_input.shape[0])
print("latents_input 1", torch.from_numpy(np.asarray(latents_input)).abs().sum())
latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t)
print("latents_input 2", torch.from_numpy(np.asarray(latents_input)).abs().sum())
# predict the noise residual
noise_pred = self.unet.apply(
{"params": params["unet"]},
@@ -250,8 +261,12 @@ class FlaxStableDiffusionXLPipeline(FlaxDiffusionPipeline):
noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
print("noise_pred", torch.from_numpy(np.asarray(noise_pred)).abs().sum())
# compute the previous noisy sample x_t -> x_t-1
latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
print("latents_input 3", torch.from_numpy(np.asarray(latents)).abs().sum())
return latents, scheduler_state
if DEBUG:

View File

@@ -818,6 +818,10 @@ class StableDiffusionXLPipeline(
add_text_embeds = add_text_embeds.to(device)
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
print("prompt embeds", prompt_embeds.abs().sum())
print("text_embeds", add_text_embeds.abs().sum())
print("add_time_ids", add_time_ids)
# 8. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
@@ -837,8 +841,12 @@ class StableDiffusionXLPipeline(
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
print("latent_model_input 1", latent_model_input.abs().sum())
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
print("latent_model_input 2", latent_model_input.abs().sum())
# predict the noise residual
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
noise_pred = self.unet(
@@ -859,9 +867,13 @@ class StableDiffusionXLPipeline(
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
print("noise pred", noise_pred.abs().sum())
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
print("latent_model_input 3", latents.abs().sum())
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()