Mirror of https://github.com/huggingface/diffusers.git, synced 2025-12-06 20:44:33 +08:00

Compare commits (3 commits)
| Author | SHA1 | Date |
|---|---|---|
| | c16ecac3c7 | |
| | 2fedbbf9af | |
| | 234600ce03 | |
@@ -945,6 +945,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
 
         down_block_res_samples = (sample,)
+        print("emb", emb.abs().sum())
+        print("sample", sample.abs().sum())
+
         for downsample_block in self.down_blocks:
             if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                 # For t2i-adapter CrossAttnDownBlock2D
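The two `print` calls added above reduce each intermediate tensor to a single scalar with `abs().sum()`, so the PyTorch and Flax UNets can be compared line by line from their logs. A minimal standalone sketch of that idea (the `checksum` helper is an illustrative name, not diffusers API):

```python
import torch


def checksum(name: str, tensor: torch.Tensor) -> None:
    # One L1-style scalar per tensor is enough to spot where two runs diverge.
    print(name, tensor.abs().sum().item())


sample = torch.ones(2, 4, 64, 64)
emb = torch.full((2, 1280), 0.5)
checksum("sample", sample)  # sample 32768.0
checksum("emb", emb)        # emb 1280.0
```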
@@ -134,8 +134,18 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
 
         added_cond_kwargs = None
         if self.addition_embed_type == "text_time":
-            # TODO: how to get this from the config? It's no longer cross_attention_dim
-            text_embeds_dim = 1280
+            # we retrieve the expected `text_embeds_dim` by first checking if the architecture is a refiner
+            # or non-refiner architecture and then by "reverse-computing" from `projection_class_embeddings_input_dim`
+            is_refiner = (
+                5 * self.config.addition_time_embed_dim + self.config.cross_attention_dim
+                == self.config.projection_class_embeddings_input_dim
+            )
+            num_micro_conditions = 5 if is_refiner else 6
+
+            text_embeds_dim = self.config.projection_class_embeddings_input_dim - (
+                num_micro_conditions * self.config.addition_time_embed_dim
+            )
+
             time_ids_channels = self.projection_class_embeddings_input_dim - text_embeds_dim
             time_ids_dims = time_ids_channels // self.addition_time_embed_dim
             added_cond_kwargs = {
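This hunk replaces the hardcoded `text_embeds_dim = 1280` with a value derived from the config. A standalone version of the arithmetic, using what appear to be the usual SDXL config values (an assumption, not taken from this diff):

```python
def infer_text_embeds_dim(addition_time_embed_dim, cross_attention_dim, projection_class_embeddings_input_dim):
    # The refiner conditions on 5 micro-condition values, the base model on 6,
    # each embedded with `addition_time_embed_dim` channels.
    is_refiner = 5 * addition_time_embed_dim + cross_attention_dim == projection_class_embeddings_input_dim
    num_micro_conditions = 5 if is_refiner else 6
    return projection_class_embeddings_input_dim - num_micro_conditions * addition_time_embed_dim


# Assumed SDXL config values (not part of this diff):
print(infer_text_embeds_dim(256, 2048, 2816))  # base    -> 1280
print(infer_text_embeds_dim(256, 1280, 2560))  # refiner -> 1280
```

Both architectures land on 1280, which is why the earlier hardcoded value happened to work for both checkpoints.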
@@ -367,6 +377,11 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
         sample = jnp.transpose(sample, (0, 2, 3, 1))
         sample = self.conv_in(sample)
 
+        if not isinstance(t_emb, jax._src.interpreters.partial_eval.DynamicJaxprTracer):
+            import torch; import numpy as np
+            print("t_emb", torch.from_numpy(np.asarray(t_emb)).abs().sum())
+            print("sample", torch.from_numpy(np.asarray(sample)).abs().sum())
+
         # 3. down
         down_block_res_samples = (sample,)
         for down_block in self.down_blocks:
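The added guard only prints when `t_emb` is a concrete array, i.e. when the module is not being traced. A sketch of the same idea against the broader `jax.core.Tracer` base class (an assumption about intent; the private `DynamicJaxprTracer` check above is what the commit actually uses):

```python
import jax
import jax.numpy as jnp


def debug_checksum(name, value):
    # Inside jit/fori_loop the value is an abstract tracer with nothing to print.
    # (If jax.core.Tracer is unavailable in your JAX version, the private check
    # used in the diff above illustrates the same idea.)
    if isinstance(value, jax.core.Tracer):
        return
    print(name, float(jnp.abs(value).sum()))


x = jnp.ones((2, 4, 64, 64))
debug_checksum("sample", x)  # prints: sample 32768.0


@jax.jit
def f(v):
    debug_checksum("sample (traced)", v)  # silently skipped: v is a tracer here
    return v * 2


_ = f(x)
```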
@@ -37,6 +37,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
 
 # Set to True to use python for loop instead of jax.fori_loop for easier debugging
 DEBUG = False
+DEBUG = True
 
 
 class FlaxStableDiffusionXLPipeline(FlaxDiffusionPipeline):
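Flipping `DEBUG` to `True` makes the pipeline run its denoising loop as a plain Python `for` loop instead of `jax.lax.fori_loop`, so the debug prints in the loop body execute on concrete values. A minimal sketch of that switch with a simplified, stand-in loop body:

```python
import jax
import jax.numpy as jnp

DEBUG = True


def loop_body(step, latents):
    # Stand-in for one denoising step.
    return latents * 0.99


latents = jnp.ones((1, 4, 8, 8))
if DEBUG:
    # Eager Python loop: print() inside the body would show concrete values.
    for step in range(10):
        latents = loop_body(step, latents)
else:
    # Compiled loop: the body only ever sees tracers.
    latents = jax.lax.fori_loop(0, 10, loop_body, latents)
```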
@@ -216,15 +217,21 @@ class FlaxStableDiffusionXLPipeline(FlaxDiffusionPipeline):
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
-        # scale the initial noise by the standard deviation required by the scheduler
-        latents = latents * params["scheduler"].init_noise_sigma
+
 
         # Prepare scheduler state
         scheduler_state = self.scheduler.set_timesteps(
             params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape
         )
 
+        import torch; import numpy as np
+        latents = latents * scheduler_state.init_noise_sigma
+
         added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
 
+        print("prompt_embeds", torch.from_numpy(np.asarray(prompt_embeds)).abs().sum())
+        print("text_embeds", torch.from_numpy(np.asarray(add_text_embeds)).abs().sum())
+        print("add_time_ids", add_time_ids)
+
         # Denoising loop
         def loop_body(step, args):
             latents, scheduler_state = args
@@ -236,8 +243,12 @@ class FlaxStableDiffusionXLPipeline(FlaxDiffusionPipeline):
             t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
             timestep = jnp.broadcast_to(t, latents_input.shape[0])
 
+            print("latents_input 1", torch.from_numpy(np.asarray(latents_input)).abs().sum())
+
             latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t)
 
+            print("latents_input 2", torch.from_numpy(np.asarray(latents_input)).abs().sum())
+
             # predict the noise residual
             noise_pred = self.unet.apply(
                 {"params": params["unet"]},
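The two checksums bracket `scale_model_input`, which rescales the latents before they reach the UNet, so the values are expected to differ. A sketch of the kind of sigma-dependent scaling an Euler-style scheduler applies (an illustrative assumption; the exact rule lives in the scheduler, not in this diff):

```python
import jax.numpy as jnp


def scale_model_input(sample, sigma):
    # Divide by sqrt(sigma^2 + 1) so the UNet sees inputs at a consistent scale
    # (illustrative Euler-style rule, not taken from this diff).
    return sample / jnp.sqrt(sigma**2 + 1.0)


latents_input = jnp.ones((2, 4, 64, 64))
print(float(jnp.abs(latents_input).sum()))                                 # 32768.0
print(float(jnp.abs(scale_model_input(latents_input, 3.0 ** 0.5)).sum()))  # 16384.0 (factor is exactly 2)
```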
@@ -250,8 +261,12 @@ class FlaxStableDiffusionXLPipeline(FlaxDiffusionPipeline):
             noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
             noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
 
+            print("noise_pred", torch.from_numpy(np.asarray(noise_pred)).abs().sum())
+
             # compute the previous noisy sample x_t -> x_t-1
             latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
+
+            print("latents_input 3", torch.from_numpy(np.asarray(latents)).abs().sum())
             return latents, scheduler_state
 
         if DEBUG:
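This hunk also shows the classifier-free guidance combination: the doubled batch is split into its unconditional and text-conditioned halves and recombined with `guidance_scale`. A standalone sketch with a toy input:

```python
import jax.numpy as jnp


def apply_cfg(noise_pred, guidance_scale):
    # The first half of the batch is the unconditional prediction, the second
    # half the text-conditioned one; guidance extrapolates from one to the other.
    noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
    return noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)


# Toy input: uncond predicts 0.0 everywhere, text-conditioned predicts 1.0 everywhere.
noise_pred = jnp.concatenate([jnp.zeros((1, 4, 8, 8)), jnp.ones((1, 4, 8, 8))], axis=0)
print(float(apply_cfg(noise_pred, guidance_scale=7.5).mean()))  # 7.5
```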
@@ -818,6 +818,10 @@ class StableDiffusionXLPipeline(
         add_text_embeds = add_text_embeds.to(device)
         add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
 
+        print("prompt embeds", prompt_embeds.abs().sum())
+        print("text_embeds", add_text_embeds.abs().sum())
+        print("add_time_ids", add_time_ids)
+
         # 8. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
 
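Unlike the embeddings, `add_time_ids` is printed verbatim: it holds SDXL's micro-conditioning (original size, crop coordinates, target size), a handful of values per image. An illustrative construction (not the pipeline's own helper):

```python
import torch


def make_add_time_ids(original_size, crops_coords_top_left, target_size, batch_size, dtype=torch.float32):
    # Six values per image: (orig_h, orig_w, crop_top, crop_left, target_h, target_w).
    add_time_ids = list(original_size + crops_coords_top_left + target_size)
    return torch.tensor([add_time_ids], dtype=dtype).repeat(batch_size, 1)


print(make_add_time_ids((1024, 1024), (0, 0), (1024, 1024), batch_size=2))
# tensor([[1024., 1024., 0., 0., 1024., 1024.],
#         [1024., 1024., 0., 0., 1024., 1024.]])
```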
@@ -837,8 +841,12 @@ class StableDiffusionXLPipeline(
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
 
+                print("latent_model_input 1", latent_model_input.abs().sum())
+
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
+                print("latent_model_input 2", latent_model_input.abs().sum())
+
                 # predict the noise residual
                 added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
                 noise_pred = self.unet(
@@ -859,9 +867,13 @@ class StableDiffusionXLPipeline(
                     # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                     noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
 
+                print("noise pred", noise_pred.abs().sum())
+
                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+                print("latent_model_input 3", latents.abs().sum())
 
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
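`rescale_noise_cfg` refers to the guidance-rescale trick from Section 3.4 of arXiv:2305.08891: rescale the CFG output so its per-sample standard deviation matches the text-conditioned prediction, then blend the two with `guidance_rescale`. A sketch of that formula as I read it (a reimplementation for illustration, not copied from diffusers):

```python
import torch


def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    # Match the per-sample std of the CFG output to the text-conditioned
    # prediction, then blend the rescaled and original outputs.
    dims = list(range(1, noise_pred_text.ndim))
    std_text = noise_pred_text.std(dim=dims, keepdim=True)
    std_cfg = noise_cfg.std(dim=dims, keepdim=True)
    rescaled = noise_cfg * (std_text / std_cfg)
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg


noise_pred_text = torch.randn(2, 4, 8, 8)
noise_cfg = 7.5 * noise_pred_text  # exaggerated CFG output with inflated std
out = rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=1.0)
print(out.std().item(), noise_pred_text.std().item())  # the two stds now match
```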