Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-07 04:54:47 +08:00)

Compare commits: v0.27.1-pa...debug

3 commits:

- c16ecac3c7
- 2fedbbf9af
- 234600ce03
```diff
@@ -945,6 +945,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
 
         down_block_res_samples = (sample,)
+
+        print("emb", emb.abs().sum())
+        print("sample", sample.abs().sum())
         for downsample_block in self.down_blocks:
             if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                 # For t2i-adapter CrossAttnDownBlock2D
```
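These debug commits compare the PyTorch and Flax SDXL pipelines by printing `abs().sum()` at matching points: the sum acts as a cheap tensor fingerprint, and the first sum that disagrees between frameworks localizes the diverging op. A minimal sketch of the same idea as a reusable helper (the `fingerprint` name and signature are hypothetical, not part of diffusers):

```python
import numpy as np

def fingerprint(label: str, tensor) -> None:
    # np.asarray accepts CPU torch tensors, JAX arrays, and numpy arrays alike,
    # so the same helper works on both sides of the comparison.
    values = np.asarray(tensor, dtype=np.float64)
    print(f"{label}: abs_sum={np.abs(values).sum():.4f} shape={values.shape}")
```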
```diff
@@ -134,8 +134,18 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
 
         added_cond_kwargs = None
         if self.addition_embed_type == "text_time":
-            # TODO: how to get this from the config? It's no longer cross_attention_dim
-            text_embeds_dim = 1280
+            # we retrieve the expected `text_embeds_dim` by first checking if the architecture is a refiner
+            # or non-refiner architecture and then by "reverse-computing" from `projection_class_embeddings_input_dim`
+            is_refiner = (
+                5 * self.config.addition_time_embed_dim + self.config.cross_attention_dim
+                == self.config.projection_class_embeddings_input_dim
+            )
+            num_micro_conditions = 5 if is_refiner else 6
+
+            text_embeds_dim = self.config.projection_class_embeddings_input_dim - (
+                num_micro_conditions * self.config.addition_time_embed_dim
+            )
+
             time_ids_channels = self.projection_class_embeddings_input_dim - text_embeds_dim
             time_ids_dims = time_ids_channels // self.addition_time_embed_dim
             added_cond_kwargs = {
```
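The reverse computation can be sanity-checked against the published SDXL UNet configs. The sketch below restates the logic standalone; the config numbers (`addition_time_embed_dim=256`, `cross_attention_dim` 2048/1280 and `projection_class_embeddings_input_dim` 2816/2560 for base/refiner) are quoted from memory of the SDXL checkpoints and should be verified against the actual configs:

```python
def expected_text_embeds_dim(addition_time_embed_dim, cross_attention_dim, projection_dim):
    # The refiner embeds 5 micro-conditioning values (original size, crop,
    # aesthetic score); the base model embeds 6 (original size, crop, target size).
    is_refiner = 5 * addition_time_embed_dim + cross_attention_dim == projection_dim
    num_micro_conditions = 5 if is_refiner else 6
    return projection_dim - num_micro_conditions * addition_time_embed_dim

assert expected_text_embeds_dim(256, 2048, 2816) == 1280  # SDXL base
assert expected_text_embeds_dim(256, 1280, 2560) == 1280  # SDXL refiner
```

Both cases recover the 1280 that the removed TODO line had hardcoded.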
```diff
@@ -367,6 +377,11 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
         sample = jnp.transpose(sample, (0, 2, 3, 1))
         sample = self.conv_in(sample)
 
+        if not isinstance(t_emb, jax._src.interpreters.partial_eval.DynamicJaxprTracer):
+            import torch; import numpy as np
+            print("t_emb", torch.from_numpy(np.asarray(t_emb)).abs().sum())
+            print("sample", torch.from_numpy(np.asarray(sample)).abs().sum())
+
         # 3. down
         down_block_res_samples = (sample,)
         for down_block in self.down_blocks:
```
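The `isinstance` guard is needed because these prints only make sense on concrete arrays: under `jax.jit` or `pmap`, `t_emb` is an abstract tracer with no values to sum. Note that `jax._src.interpreters.partial_eval.DynamicJaxprTracer` is a private path that can move between JAX versions; a sketch of an alternative that works even while tracing, using the public `jax.debug.print`:

```python
import jax
import jax.numpy as jnp

@jax.jit
def forward(sample):
    # jax.debug.print defers formatting until runtime values exist, so it is
    # safe inside jit-compiled code where a plain print() would only see tracers.
    jax.debug.print("sample abs_sum: {s}", s=jnp.abs(sample).sum())
    return sample * 2.0

forward(jnp.ones((2, 4)))
```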
```diff
@@ -37,6 +37,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
 
 # Set to True to use python for loop instead of jax.fori_loop for easier debugging
 DEBUG = False
+DEBUG = True
 
 
 class FlaxStableDiffusionXLPipeline(FlaxDiffusionPipeline):
```
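Flipping `DEBUG` to `True` is what makes the in-loop prints below observable: the denoising loop runs as a plain Python loop instead of a single compiled `jax.lax.fori_loop`. A self-contained sketch of the pattern (the `run_loop` helper is hypothetical; the pipeline inlines the same branch):

```python
import jax

def run_loop(body, num_steps, init_args, debug=False):
    if debug:
        # Python loop: each step executes eagerly, so print() and breakpoints
        # inside `body` behave normally.
        args = init_args
        for step in range(num_steps):
            args = body(step, args)
        return args
    # Compiled loop: one XLA while-op; Python side effects inside `body` are
    # traced once and do not fire per iteration.
    return jax.lax.fori_loop(0, num_steps, body, init_args)
```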
```diff
@@ -216,15 +217,21 @@ class FlaxStableDiffusionXLPipeline(FlaxDiffusionPipeline):
         if latents.shape != latents_shape:
             raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
         # scale the initial noise by the standard deviation required by the scheduler
-        latents = latents * params["scheduler"].init_noise_sigma
 
         # Prepare scheduler state
         scheduler_state = self.scheduler.set_timesteps(
             params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape
         )
 
+        import torch; import numpy as np
+        latents = latents * scheduler_state.init_noise_sigma
+
         added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
 
+        print("prompt_embeds", torch.from_numpy(np.asarray(prompt_embeds)).abs().sum())
+        print("text_embeds", torch.from_numpy(np.asarray(add_text_embeds)).abs().sum())
+        print("add_time_ids", add_time_ids)
+
         # Denoising loop
         def loop_body(step, args):
             latents, scheduler_state = args
```
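Besides the added prints, this hunk carries the branch's one functional change: the initial latents are scaled by `scheduler_state.init_noise_sigma` after `set_timesteps`, rather than by the value on the freshly created `params["scheduler"]` state. The sketch below illustrates why the order can matter for Euler-style schedulers, whose initial sigma derives from the chosen timestep schedule (the beta schedule and formulas are simplified for illustration only; the real scheduler applies further adjustments):

```python
import numpy as np

# Training-time sigma table (simplified SD-style beta schedule).
betas = np.linspace(0.00085**0.5, 0.012**0.5, 1000) ** 2
alphas_cumprod = np.cumprod(1.0 - betas)
sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5

# set_timesteps picks a sub-schedule; the usable init sigma comes from it.
inference_timesteps = np.linspace(0.0, 999.0, 50)
schedule_sigmas = np.interp(inference_timesteps, np.arange(1000), sigmas)
init_noise_sigma = schedule_sigmas.max()

latents = np.random.randn(1, 128, 128, 4) * init_noise_sigma  # scale AFTER set_timesteps
```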
```diff
@@ -236,8 +243,12 @@ class FlaxStableDiffusionXLPipeline(FlaxDiffusionPipeline):
             t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
             timestep = jnp.broadcast_to(t, latents_input.shape[0])
 
+            print("latents_input 1", torch.from_numpy(np.asarray(latents_input)).abs().sum())
+
             latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t)
 
+            print("latents_input 2", torch.from_numpy(np.asarray(latents_input)).abs().sum())
+
             # predict the noise residual
             noise_pred = self.unet.apply(
                 {"params": params["unet"]},
```
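The prints bracket `scale_model_input`, which normalizes the sigma-inflated latents before the UNet call. What it computes depends on the scheduler: for DDIM-style schedulers it is the identity, while for Euler-style schedulers it is the division sketched below (a minimal sketch of that case, not the pipeline's actual call):

```python
import jax.numpy as jnp

def scale_model_input(latents, sigma):
    # The UNet is trained on roughly unit-variance inputs; latents carrying
    # noise at scale sigma are rescaled by 1 / sqrt(sigma^2 + 1).
    return latents / jnp.sqrt(sigma**2 + 1.0)
```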
```diff
@@ -250,8 +261,12 @@ class FlaxStableDiffusionXLPipeline(FlaxDiffusionPipeline):
             noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
             noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
 
+            print("noise_pred", torch.from_numpy(np.asarray(noise_pred)).abs().sum())
+
             # compute the previous noisy sample x_t -> x_t-1
             latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
+
+            print("latents_input 3", torch.from_numpy(np.asarray(latents)).abs().sum())
             return latents, scheduler_state
 
         if DEBUG:
```
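The two context lines at the top of this hunk implement classifier-free guidance: the batch holds the unconditional and text-conditioned predictions stacked along axis 0, and the final prediction extrapolates from the former toward the latter. The same computation, standalone:

```python
import jax.numpy as jnp

def classifier_free_guidance(noise_pred, guidance_scale):
    # First half of the batch: unconditional prediction; second half: text-conditioned.
    uncond, text = jnp.split(noise_pred, 2, axis=0)
    return uncond + guidance_scale * (text - uncond)
```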
```diff
@@ -818,6 +818,10 @@ class StableDiffusionXLPipeline(
         add_text_embeds = add_text_embeds.to(device)
         add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
 
+        print("prompt embeds", prompt_embeds.abs().sum())
+        print("text_embeds", add_text_embeds.abs().sum())
+        print("add_time_ids", add_time_ids)
+
         # 8. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
 
```
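The `add_time_ids` printed here holds SDXL's micro-conditioning: for the base model, six integers per sample (original size, crop top-left corner, target size), repeated across the batch and embedded alongside the pooled text embeds. A sketch of how the tensor is assembled (the values are illustrative defaults, not taken from this diff):

```python
import torch

original_size, crops_coords_top_left, target_size = (1024, 1024), (0, 0), (1024, 1024)
add_time_ids = torch.tensor([list(original_size + crops_coords_top_left + target_size)])

batch_size, num_images_per_prompt = 1, 1
add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
print(add_time_ids)  # tensor([[1024, 1024, 0, 0, 1024, 1024]])
```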
```diff
@@ -837,8 +841,12 @@ class StableDiffusionXLPipeline(
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
 
+                print("latent_model_input 1", latent_model_input.abs().sum())
+
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
+                print("latent_model_input 2", latent_model_input.abs().sum())
+
                 # predict the noise residual
                 added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
                 noise_pred = self.unet(
```
```diff
@@ -859,9 +867,13 @@ class StableDiffusionXLPipeline(
                     # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                     noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
 
+                print("noise pred", noise_pred.abs().sum())
+
                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
 
+                print("latent_model_input 3", latents.abs().sum())
+
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
```
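`rescale_noise_cfg` in the last hunk implements §3.4 of arXiv:2305.08891: classifier-free guidance inflates the standard deviation of the prediction, so the guided prediction is rescaled toward the std of the text-conditioned one and blended back in by `guidance_rescale`. A sketch matching that description (verify against the diffusers source before relying on the exact form):

```python
import torch

def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    # Per-sample std over all non-batch dimensions.
    dims = list(range(1, noise_pred_text.ndim))
    std_text = noise_pred_text.std(dim=dims, keepdim=True)
    std_cfg = noise_cfg.std(dim=dims, keepdim=True)
    rescaled = noise_cfg * (std_text / std_cfg)  # undo the std inflation from guidance
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg
```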