Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-17 18:04:25 +08:00)

Compare commits: flux2-fix...cache-docs (4 commits)

Commits in this comparison (author and date columns were not captured):
- d76b744ac3
- b26867b628
- e3f441648c
- c6cfc5ce1d
@@ -29,7 +29,7 @@ Cache methods speedup diffusion transformers by storing and reusing intermediate
 [[autodoc]] apply_faster_cache

-### FirstBlockCacheConfig
+## FirstBlockCacheConfig

 [[autodoc]] FirstBlockCacheConfig

@@ -66,4 +66,8 @@ config = FasterCacheConfig(
     tensor_format="BFCHW",
 )
 pipeline.transformer.enable_cache(config)
 ```
+
+## FirstBlockCache
+
+[FirstBlock Cache](https://huggingface.co/docs/diffusers/main/en/api/cache#diffusers.FirstBlockCacheConfig) builds on the ideas of [TeaCache](https://huggingface.co/papers/2411.19108). It is much simpler to implement generically for a wide range of models and has been integrated first for experimental purposes.
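Enabling the new cache follows the same `enable_cache` pattern as the FasterCache example above. A minimal usage sketch (the `threshold` value and the CogVideoX checkpoint are illustrative choices, not part of this diff):

```python
import torch
from diffusers import CogVideoXPipeline, FirstBlockCacheConfig

pipeline = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipeline.to("cuda")

# Reuse cached block outputs whenever the first transformer block's residual
# changes by less than the threshold between timesteps; larger values skip more work.
pipeline.transformer.enable_cache(FirstBlockCacheConfig(threshold=0.2))
```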
@@ -41,9 +41,11 @@ class CacheMixin:
         Enable caching techniques on the model.

         Args:
-            config (`Union[PyramidAttentionBroadcastConfig]`):
+            config (`Union[PyramidAttentionBroadcastConfig, FasterCacheConfig, FirstBlockCacheConfig]`):
                 The configuration for applying the caching technique. Currently supported caching techniques are:
                     - [`~hooks.PyramidAttentionBroadcastConfig`]
+                    - [`~hooks.FasterCacheConfig`]
+                    - [`~hooks.FirstBlockCacheConfig`]

         Example:
@@ -69,10 +69,7 @@ class TimestepEmbedder(nn.Module):

     def forward(self, t):
         t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
-        weight_dtype = self.mlp[0].weight.dtype
-        if weight_dtype.is_floating_point:
-            t_freq = t_freq.to(weight_dtype)
-        t_emb = self.mlp(t_freq)
+        t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype))
         return t_emb
@@ -129,10 +126,6 @@ class ZSingleStreamAttnProcessor:
         dtype = query.dtype
         query, key = query.to(dtype), key.to(dtype)

-        # From [batch, seq_len] to [batch, 1, 1, seq_len] -> broadcast to [batch, heads, seq_len, seq_len]
-        if attention_mask is not None and attention_mask.ndim == 2:
-            attention_mask = attention_mask[:, None, None, :]
-
         # Compute joint attention
         hidden_states = dispatch_attention_fn(
             query,
@@ -313,10 +306,6 @@ class RopeEmbedder:
         if self.freqs_cis is None:
             self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
             self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
-        else:
-            # Ensure freqs_cis are on the same device as ids
-            if self.freqs_cis[0].device != device:
-                self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]

         result = []
         for i in range(len(self.axes_dims)):
@@ -328,7 +317,6 @@ class RopeEmbedder:
 class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
     _supports_gradient_checkpointing = True
     _no_split_modules = ["ZImageTransformerBlock"]
-    _skip_layerwise_casting_patterns = ["t_embedder", "cap_embedder"]  # precision sensitive layers

     @register_to_config
     def __init__(
@@ -565,6 +553,8 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
         t = t * self.t_scale
         t = self.t_embedder(t)

+        adaln_input = t
+
         (
             x,
             cap_feats,
@@ -582,9 +572,6 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):

         x = torch.cat(x, dim=0)
         x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x)
-
-        # Match t_embedder output dtype to x for layerwise casting compatibility
-        adaln_input = t.type_as(x)
         x[torch.cat(x_inner_pad_mask)] = self.x_pad_token
         x = list(x.split(x_item_seqlens, dim=0))
         x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))
@@ -861,6 +861,7 @@ class Flux2Pipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
         if output_type == "latent":
             image = latents
         else:
+            torch.save({"pred": latents}, "pred_d.pt")
             latents = self._unpack_latents_with_ids(latents, latent_ids)

             latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
@@ -165,16 +165,21 @@ class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin):
         self,
         prompt: Union[str, List[str]],
         device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+        num_images_per_prompt: int = 1,
         do_classifier_free_guidance: bool = True,
         negative_prompt: Optional[Union[str, List[str]]] = None,
         prompt_embeds: Optional[List[torch.FloatTensor]] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         max_sequence_length: int = 512,
+        lora_scale: Optional[float] = None,
     ):
         prompt = [prompt] if isinstance(prompt, str) else prompt
         prompt_embeds = self._encode_prompt(
             prompt=prompt,
             device=device,
+            dtype=dtype,
+            num_images_per_prompt=num_images_per_prompt,
             prompt_embeds=prompt_embeds,
             max_sequence_length=max_sequence_length,
         )
@@ -188,6 +193,8 @@ class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin):
             negative_prompt_embeds = self._encode_prompt(
                 prompt=negative_prompt,
                 device=device,
+                dtype=dtype,
+                num_images_per_prompt=num_images_per_prompt,
                 prompt_embeds=negative_prompt_embeds,
                 max_sequence_length=max_sequence_length,
             )
@@ -199,9 +206,12 @@ class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin):
         self,
         prompt: Union[str, List[str]],
         device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+        num_images_per_prompt: int = 1,
         prompt_embeds: Optional[List[torch.FloatTensor]] = None,
         max_sequence_length: int = 512,
     ) -> List[torch.FloatTensor]:
+        assert num_images_per_prompt == 1
         device = device or self._execution_device

         if prompt_embeds is not None:
@@ -407,6 +417,8 @@ class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin):
                 f"Please adjust the width to a multiple of {vae_scale}."
             )

+        assert self.dtype == torch.bfloat16
+        dtype = self.dtype
         device = self._execution_device

         self._guidance_scale = guidance_scale
@@ -422,6 +434,10 @@ class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin):
         else:
             batch_size = len(prompt_embeds)

+        lora_scale = (
+            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+        )
+
         # If prompt_embeds is provided and prompt is None, skip encoding
         if prompt_embeds is not None and prompt is None:
             if self.do_classifier_free_guidance and negative_prompt_embeds is None:
@@ -439,8 +455,11 @@ class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin):
                 do_classifier_free_guidance=self.do_classifier_free_guidance,
                 prompt_embeds=prompt_embeds,
                 negative_prompt_embeds=negative_prompt_embeds,
+                dtype=dtype,
                 device=device,
+                num_images_per_prompt=num_images_per_prompt,
                 max_sequence_length=max_sequence_length,
+                lora_scale=lora_scale,
             )

         # 4. Prepare latent variables
@@ -456,14 +475,6 @@ class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin):
             generator,
             latents,
         )
-
-        # Repeat prompt_embeds for num_images_per_prompt
-        if num_images_per_prompt > 1:
-            prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)]
-            if self.do_classifier_free_guidance and negative_prompt_embeds:
-                negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)]
-
-        actual_batch_size = batch_size * num_images_per_prompt
         image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2)

         # 5. Prepare timesteps
@@ -512,12 +523,12 @@ class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin):
                 apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0

                 if apply_cfg:
-                    latents_typed = latents.to(self.transformer.dtype)
+                    latents_typed = latents if latents.dtype == dtype else latents.to(dtype)
                     latent_model_input = latents_typed.repeat(2, 1, 1, 1)
                     prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds
                     timestep_model_input = timestep.repeat(2)
                 else:
-                    latent_model_input = latents.to(self.transformer.dtype)
+                    latent_model_input = latents if latents.dtype == dtype else latents.to(dtype)
                     prompt_embeds_model_input = prompt_embeds
                     timestep_model_input = timestep

@@ -532,11 +543,11 @@ class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin):

                 if apply_cfg:
                     # Perform CFG
-                    pos_out = model_out_list[:actual_batch_size]
-                    neg_out = model_out_list[actual_batch_size:]
+                    pos_out = model_out_list[:batch_size]
+                    neg_out = model_out_list[batch_size:]

                     noise_pred = []
-                    for j in range(actual_batch_size):
+                    for j in range(batch_size):
                         pos = pos_out[j].float()
                         neg = neg_out[j].float()

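The `actual_batch_size` to `batch_size` rename above touches the usual classifier-free guidance bookkeeping: the conditional and unconditional batches are stacked into one forward pass, then split and recombined. A self-contained sketch of that pattern (the function name and the standard guidance formula are illustrative, not lifted from this pipeline):

```python
import torch

def cfg_combine(model_out: torch.Tensor, batch_size: int, guidance_scale: float) -> torch.Tensor:
    # The latents were repeated along dim 0, so the first half of the output is
    # the conditional (positive) branch and the second half the unconditional one.
    pos, neg = model_out[:batch_size], model_out[batch_size:]
    # Standard classifier-free guidance: extrapolate away from the unconditional output.
    return neg + guidance_scale * (pos - neg)

out = torch.randn(4, 16, 8, 8)  # 2 conditional + 2 unconditional samples
pred = cfg_combine(out, batch_size=2, guidance_scale=4.0)
assert pred.shape == (2, 16, 8, 8)
```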
@@ -577,11 +588,11 @@ class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin):
             if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                 progress_bar.update()

+        latents = latents.to(dtype)
         if output_type == "latent":
             image = latents

         else:
-            latents = latents.to(self.vae.dtype)
             latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor

             image = self.vae.decode(latents, return_dict=False)[0]
@@ -429,22 +429,7 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         return x_t

     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
-    def index_for_timestep(
-        self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
-    ) -> int:
-        """
-        Find the index for a given timestep in the schedule.
-
-        Args:
-            timestep (`int` or `torch.Tensor`):
-                The timestep for which to find the index.
-            schedule_timesteps (`torch.Tensor`, *optional*):
-                The timestep schedule to search in. If `None`, uses `self.timesteps`.
-
-        Returns:
-            `int`:
-                The index of the timestep in the schedule.
-        """
+    def index_for_timestep(self, timestep, schedule_timesteps=None):
         if schedule_timesteps is None:
             schedule_timesteps = self.timesteps
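The body kept by this hunk is shared across schedulers via the `# Copied from` mechanism. For reference, a standalone sketch of the lookup it performs (simplified from the upstream implementation; preferring the second duplicate match is the upstream convention for schedules that repeat a timestep after a restart):

```python
import torch

def index_for_timestep(timestep: torch.Tensor, schedule_timesteps: torch.Tensor) -> int:
    # All positions in the schedule that equal this timestep.
    index_candidates = (schedule_timesteps == timestep).nonzero()
    if len(index_candidates) == 0:
        # Timestep not found: clamp to the final step.
        return len(schedule_timesteps) - 1
    # With duplicates, the first match could belong to a step already consumed,
    # so take the second occurrence.
    pos = 1 if len(index_candidates) > 1 else 0
    return index_candidates[pos].item()

schedule = torch.tensor([999, 749, 499, 249])
print(index_for_timestep(torch.tensor(499), schedule))  # 2
```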
@@ -467,10 +452,6 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
     def _init_step_index(self, timestep):
         """
         Initialize the step_index counter for the scheduler.
-
-        Args:
-            timestep (`int` or `torch.Tensor`):
-                The current timestep for which to initialize the step index.
         """

         if self.begin_index is None:
@@ -401,17 +401,6 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):

     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
     def _sigma_to_alpha_sigma_t(self, sigma):
-        """
-        Convert sigma values to alpha_t and sigma_t values.
-
-        Args:
-            sigma (`torch.Tensor`):
-                The sigma value(s) to convert.
-
-        Returns:
-            `Tuple[torch.Tensor, torch.Tensor]`:
-                A tuple containing (alpha_t, sigma_t) values.
-        """
         if self.config.use_flow_sigmas:
             alpha_t = 1 - sigma
             sigma_t = sigma
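For reference, the conversion performed by the retained body: with flow-matching sigmas, alpha_t = 1 - sigma and sigma_t = sigma; otherwise alpha_t = 1/sqrt(sigma^2 + 1) and sigma_t = sigma * alpha_t, so that alpha_t^2 + sigma_t^2 = 1. A standalone sketch (the flag is passed explicitly here; upstream reads it from `self.config`):

```python
import torch

def sigma_to_alpha_sigma_t(sigma: torch.Tensor, use_flow_sigmas: bool):
    if use_flow_sigmas:
        # Flow matching: x_t = (1 - sigma) * x_0 + sigma * noise
        alpha_t = 1 - sigma
        sigma_t = sigma
    else:
        # Variance-preserving parameterization: sigma is a noise-to-signal
        # ratio, so alpha_t**2 + sigma_t**2 == 1.
        alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
        sigma_t = sigma * alpha_t
    return alpha_t, sigma_t

a, s = sigma_to_alpha_sigma_t(torch.tensor(1.0), use_flow_sigmas=False)
assert torch.isclose(a**2 + s**2, torch.tensor(1.0))
```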
@@ -819,22 +808,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
         raise NotImplementedError("only support log-rho multistep deis now")

     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
-    def index_for_timestep(
-        self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
-    ) -> int:
-        """
-        Find the index for a given timestep in the schedule.
-
-        Args:
-            timestep (`int` or `torch.Tensor`):
-                The timestep for which to find the index.
-            schedule_timesteps (`torch.Tensor`, *optional*):
-                The timestep schedule to search in. If `None`, uses `self.timesteps`.
-
-        Returns:
-            `int`:
-                The index of the timestep in the schedule.
-        """
+    def index_for_timestep(self, timestep, schedule_timesteps=None):
         if schedule_timesteps is None:
             schedule_timesteps = self.timesteps
@@ -857,10 +831,6 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
     def _init_step_index(self, timestep):
         """
         Initialize the step_index counter for the scheduler.
-
-        Args:
-            timestep (`int` or `torch.Tensor`):
-                The current timestep for which to initialize the step index.
        """

         if self.begin_index is None:
@@ -957,21 +927,6 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
         noise: torch.Tensor,
         timesteps: torch.IntTensor,
     ) -> torch.Tensor:
-        """
-        Add noise to the original samples according to the noise schedule at the specified timesteps.
-
-        Args:
-            original_samples (`torch.Tensor`):
-                The original samples without noise.
-            noise (`torch.Tensor`):
-                The noise to add to the samples.
-            timesteps (`torch.IntTensor`):
-                The timesteps at which to add noise to the samples.
-
-        Returns:
-            `torch.Tensor`:
-                The noisy samples.
-        """
         # Make sure sigmas and timesteps have the same device and dtype as original_samples
         sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
         if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
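The forward-noising rule the removed docstring described is the standard mixing of sample and noise with the per-timestep (alpha_t, sigma_t) coefficients: noisy = alpha_t * original + sigma_t * noise. A minimal sketch (coefficients are passed in directly; upstream looks them up from its sigma schedule at the given timesteps):

```python
import torch

def add_noise(original: torch.Tensor, noise: torch.Tensor,
              alpha_t: torch.Tensor, sigma_t: torch.Tensor) -> torch.Tensor:
    # Broadcast per-sample coefficients over [batch, channels, height, width].
    while alpha_t.ndim < original.ndim:
        alpha_t = alpha_t.unsqueeze(-1)
        sigma_t = sigma_t.unsqueeze(-1)
    return alpha_t * original + sigma_t * noise

x0 = torch.randn(2, 3, 8, 8)
eps = torch.randn_like(x0)
noisy = add_noise(x0, eps, torch.tensor([0.9, 0.5]), torch.tensor([0.436, 0.866]))
```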
@@ -127,17 +127,18 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
             The starting `beta` value of inference.
         beta_end (`float`, defaults to 0.02):
             The final `beta` value.
-        beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`):
-            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model.
+        beta_schedule (`str`, defaults to `"linear"`):
+            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
         trained_betas (`np.ndarray`, *optional*):
             Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
         solver_order (`int`, defaults to 2):
             The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
             sampling, and `solver_order=3` for unconditional sampling.
-        prediction_type (`"epsilon"`, `"sample"`, `"v_prediction"`, or `"flow_prediction"`, defaults to `"epsilon"`):
-            Prediction type of the scheduler function. `epsilon` predicts the noise of the diffusion process, `sample`
-            directly predicts the noisy sample, `v_prediction` predicts the velocity (see section 2.4 of [Imagen
-            Video](https://huggingface.co/papers/2210.02303) paper), and `flow_prediction` predicts the flow.
+        prediction_type (`str`, defaults to `epsilon`, *optional*):
+            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
+            `sample` (directly predicts the noisy sample), `v_prediction` (see section 2.4 of [Imagen
+            Video](https://imagen.research.google/video/paper.pdf) paper), or `flow_prediction`.
         thresholding (`bool`, defaults to `False`):
             Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
             as Stable Diffusion.
@@ -146,14 +147,15 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         sample_max_value (`float`, defaults to 1.0):
             The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
             `algorithm_type="dpmsolver++"`.
-        algorithm_type (`"dpmsolver"`, `"dpmsolver++"`, `"sde-dpmsolver"`, or `"sde-dpmsolver++"`, defaults to `"dpmsolver++"`):
-            Algorithm type for the solver. The `dpmsolver` type implements the algorithms in the
-            [DPMSolver](https://huggingface.co/papers/2206.00927) paper, and the `dpmsolver++` type implements the
-            algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use
-            `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
-        solver_type (`"midpoint"` or `"heun"`, defaults to `"midpoint"`):
-            Solver type for the second-order solver. The solver type slightly affects the sample quality, especially
-            for a small number of steps. It is recommended to use `midpoint` solvers.
+        algorithm_type (`str`, defaults to `dpmsolver++`):
+            Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The
+            `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927)
+            paper, and the `dpmsolver++` type implements the algorithms in the
+            [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or
+            `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
+        solver_type (`str`, defaults to `midpoint`):
+            Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
+            sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
         lower_order_final (`bool`, defaults to `True`):
             Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
             stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
@@ -177,16 +179,16 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
             Whether to use flow sigmas for step sizes in the noise schedule during the sampling process.
         flow_shift (`float`, *optional*, defaults to 1.0):
             The shift value for the timestep schedule for flow matching.
-        final_sigmas_type (`"zero"` or `"sigma_min"`, *optional*, defaults to `"zero"`):
+        final_sigmas_type (`str`, defaults to `"zero"`):
             The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
-            sigma is the same as the last sigma in the training schedule. If `"zero"`, the final sigma is set to 0.
+            sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
         lambda_min_clipped (`float`, defaults to `-inf`):
             Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
             cosine (`squaredcos_cap_v2`) noise schedule.
-        variance_type (`"learned"` or `"learned_range"`, *optional*):
-            Set to `"learned"` or `"learned_range"` for diffusion models that predict variance. If set, the model's
-            output contains the predicted Gaussian variance.
-        timestep_spacing (`"linspace"`, `"leading"`, or `"trailing"`, defaults to `"linspace"`):
+        variance_type (`str`, *optional*):
+            Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output
+            contains the predicted Gaussian variance.
+        timestep_spacing (`str`, defaults to `"linspace"`):
             The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
             Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
         steps_offset (`int`, defaults to 0):
@@ -195,10 +197,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
             Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
             dark samples instead of limiting it to samples with medium brightness. Loosely related to
             [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
-        use_dynamic_shifting (`bool`, defaults to `False`):
-            Whether to use dynamic shifting for the timestep schedule.
-        time_shift_type (`"exponential"`, defaults to `"exponential"`):
-            The type of time shift to apply when using dynamic shifting.
     """

     _compatibles = [e.name for e in KarrasDiffusionSchedulers]
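Since several of the options documented above interact, a short construction example may help; the values below are a typical guided-sampling configuration, not defaults introduced by this diff:

```python
from diffusers import DPMSolverMultistepScheduler

# Second-order multistep solver with the DPMSolver++ algorithm, a common
# choice for guided sampling in Stable Diffusion style pipelines.
scheduler = DPMSolverMultistepScheduler(
    num_train_timesteps=1000,
    beta_schedule="scaled_linear",
    prediction_type="epsilon",
    algorithm_type="dpmsolver++",
    solver_type="midpoint",
    solver_order=2,
    use_karras_sigmas=True,
)
```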
@@ -210,15 +208,15 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         num_train_timesteps: int = 1000,
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
-        beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
+        beta_schedule: str = "linear",
         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         solver_order: int = 2,
-        prediction_type: Literal["epsilon", "sample", "v_prediction", "flow_prediction"] = "epsilon",
+        prediction_type: str = "epsilon",
         thresholding: bool = False,
         dynamic_thresholding_ratio: float = 0.995,
         sample_max_value: float = 1.0,
-        algorithm_type: Literal["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"] = "dpmsolver++",
-        solver_type: Literal["midpoint", "heun"] = "midpoint",
+        algorithm_type: str = "dpmsolver++",
+        solver_type: str = "midpoint",
         lower_order_final: bool = True,
         euler_at_final: bool = False,
         use_karras_sigmas: Optional[bool] = False,
@@ -227,14 +225,14 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         use_lu_lambdas: Optional[bool] = False,
         use_flow_sigmas: Optional[bool] = False,
         flow_shift: Optional[float] = 1.0,
-        final_sigmas_type: Optional[Literal["zero", "sigma_min"]] = "zero",
+        final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
         lambda_min_clipped: float = -float("inf"),
-        variance_type: Optional[Literal["learned", "learned_range"]] = None,
-        timestep_spacing: Literal["linspace", "leading", "trailing"] = "linspace",
+        variance_type: Optional[str] = None,
+        timestep_spacing: str = "linspace",
         steps_offset: int = 0,
         rescale_betas_zero_snr: bool = False,
         use_dynamic_shifting: bool = False,
-        time_shift_type: Literal["exponential"] = "exponential",
+        time_shift_type: str = "exponential",
     ):
         if self.config.use_beta_sigmas and not is_scipy_available():
             raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -333,22 +331,19 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):

     def set_timesteps(
         self,
-        num_inference_steps: Optional[int] = None,
-        device: Optional[Union[str, torch.device]] = None,
+        num_inference_steps: int = None,
+        device: Union[str, torch.device] = None,
         mu: Optional[float] = None,
         timesteps: Optional[List[int]] = None,
-    ) -> None:
+    ):
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).

         Args:
-            num_inference_steps (`int`, *optional*):
+            num_inference_steps (`int`):
                 The number of diffusion steps used when generating samples with a pre-trained model.
             device (`str` or `torch.device`, *optional*):
                 The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
-            mu (`float`, *optional*):
-                The mu parameter for dynamic shifting. If provided, requires `use_dynamic_shifting=True` and
-                `time_shift_type="exponential"`.
             timesteps (`List[int]`, *optional*):
                 Custom timesteps used to support arbitrary timesteps schedule. If `None`, timesteps will be generated
                 based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas`
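For orientation, `set_timesteps` is called once before the denoising loop; a minimal sketch:

```python
import torch
from diffusers import DPMSolverMultistepScheduler

scheduler = DPMSolverMultistepScheduler()
# Build a 25-step inference schedule and move it to the target device.
device = "cuda" if torch.cuda.is_available() else "cpu"
scheduler.set_timesteps(num_inference_steps=25, device=device)
print(scheduler.timesteps[:5])
```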
@@ -508,7 +503,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         return sample

     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
-    def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
+    def _sigma_to_t(self, sigma, log_sigmas):
         """
         Convert sigma values to corresponding timestep values through interpolation.

@@ -544,18 +539,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         t = t.reshape(sigma.shape)
         return t

-    def _sigma_to_alpha_sigma_t(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Convert sigma values to alpha_t and sigma_t values.
-
-        Args:
-            sigma (`torch.Tensor`):
-                The sigma value(s) to convert.
-
-        Returns:
-            `Tuple[torch.Tensor, torch.Tensor]`:
-                A tuple containing (alpha_t, sigma_t) values.
-        """
+    def _sigma_to_alpha_sigma_t(self, sigma):
         if self.config.use_flow_sigmas:
             alpha_t = 1 - sigma
             sigma_t = sigma
@@ -604,21 +588,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
         return sigmas

-    def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
-        """
-        Construct the noise schedule as proposed in [DPM-Solver: A Fast ODE Solver for Diffusion Probabilistic Model
-        Sampling in Around 10 Steps](https://huggingface.co/papers/2206.00927) by Lu et al. (2022).
-
-        Args:
-            in_lambdas (`torch.Tensor`):
-                The input lambda values to be converted.
-            num_inference_steps (`int`):
-                The number of inference steps to generate the noise schedule for.
-
-        Returns:
-            `torch.Tensor`:
-                The converted lambda values following the Lu noise schedule.
-        """
+    def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps) -> torch.Tensor:
+        """Constructs the noise schedule of Lu et al. (2022)."""

         lambda_min: float = in_lambdas[-1].item()
         lambda_max: float = in_lambdas[0].item()
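The shortened docstring still points at the right construction: the Lu et al. (2022) schedule interpolates between lambda_max and lambda_min with rho = 1, i.e. uniformly in lambda space, mirroring the Karras-style rho interpolation visible in the context lines above. A standalone sketch that follows the upstream structure (rho = 1.0 is the upstream choice; the input values are illustrative):

```python
import numpy as np

def convert_to_lu(in_lambdas: np.ndarray, num_inference_steps: int) -> np.ndarray:
    lambda_min = float(in_lambdas[-1])
    lambda_max = float(in_lambdas[0])
    rho = 1.0  # rho = 1 makes the ramp linear in lambda space
    ramp = np.linspace(0, 1, num_inference_steps)
    min_inv_rho = lambda_min ** (1 / rho)
    max_inv_rho = lambda_max ** (1 / rho)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

lambdas = convert_to_lu(np.linspace(4.0, -6.0, 1000), num_inference_steps=10)
```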
@@ -1098,22 +1069,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         )
         return x_t

-    def index_for_timestep(
-        self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
-    ) -> int:
-        """
-        Find the index for a given timestep in the schedule.
-
-        Args:
-            timestep (`int` or `torch.Tensor`):
-                The timestep for which to find the index.
-            schedule_timesteps (`torch.Tensor`, *optional*):
-                The timestep schedule to search in. If `None`, uses `self.timesteps`.
-
-        Returns:
-            `int`:
-                The index of the timestep in the schedule.
-        """
+    def index_for_timestep(self, timestep, schedule_timesteps=None):
         if schedule_timesteps is None:
             schedule_timesteps = self.timesteps
@@ -1132,13 +1088,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):

         return step_index

-    def _init_step_index(self, timestep: Union[int, torch.Tensor]) -> None:
+    def _init_step_index(self, timestep):
         """
         Initialize the step_index counter for the scheduler.
-
-        Args:
-            timestep (`int` or `torch.Tensor`):
-                The current timestep for which to initialize the step index.
         """

         if self.begin_index is None:
@@ -1153,7 +1105,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         model_output: torch.Tensor,
         timestep: Union[int, torch.Tensor],
         sample: torch.Tensor,
-        generator: Optional[torch.Generator] = None,
+        generator=None,
         variance_noise: Optional[torch.Tensor] = None,
         return_dict: bool = True,
     ) -> Union[SchedulerOutput, Tuple]:
@@ -1163,22 +1115,22 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):

         Args:
             model_output (`torch.Tensor`):
-                The direct output from the learned diffusion model.
-            timestep (`int` or `torch.Tensor`):
+                The direct output from learned diffusion model.
+            timestep (`int`):
                 The current discrete timestep in the diffusion chain.
             sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             generator (`torch.Generator`, *optional*):
                 A random number generator.
-            variance_noise (`torch.Tensor`, *optional*):
+            variance_noise (`torch.Tensor`):
                 Alternative to generating noise with `generator` by directly providing the noise for the variance
                 itself. Useful for methods such as [`LEdits++`].
-            return_dict (`bool`, defaults to `True`):
+            return_dict (`bool`):
                 Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.

         Returns:
             [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
-                If `return_dict` is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
+                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                 tuple is returned where the first element is the sample tensor.

         """
@@ -1258,21 +1210,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         noise: torch.Tensor,
         timesteps: torch.IntTensor,
     ) -> torch.Tensor:
-        """
-        Add noise to the original samples according to the noise schedule at the specified timesteps.
-
-        Args:
-            original_samples (`torch.Tensor`):
-                The original samples without noise.
-            noise (`torch.Tensor`):
-                The noise to add to the samples.
-            timesteps (`torch.IntTensor`):
-                The timesteps at which to add noise to the samples.
-
-        Returns:
-            `torch.Tensor`:
-                The noisy samples.
-        """
         # Make sure sigmas and timesteps have the same device and dtype as original_samples
         sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
         if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
@@ -413,17 +413,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):

     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
     def _sigma_to_alpha_sigma_t(self, sigma):
-        """
-        Convert sigma values to alpha_t and sigma_t values.
-
-        Args:
-            sigma (`torch.Tensor`):
-                The sigma value(s) to convert.
-
-        Returns:
-            `Tuple[torch.Tensor, torch.Tensor]`:
-                A tuple containing (alpha_t, sigma_t) values.
-        """
         if self.config.use_flow_sigmas:
             alpha_t = 1 - sigma
             sigma_t = sigma
@@ -491,17 +491,6 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):

     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
     def _sigma_to_alpha_sigma_t(self, sigma):
-        """
-        Convert sigma values to alpha_t and sigma_t values.
-
-        Args:
-            sigma (`torch.Tensor`):
-                The sigma value(s) to convert.
-
-        Returns:
-            `Tuple[torch.Tensor, torch.Tensor]`:
-                A tuple containing (alpha_t, sigma_t) values.
-        """
         if self.config.use_flow_sigmas:
             alpha_t = 1 - sigma
             sigma_t = sigma
@@ -1090,22 +1079,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         raise ValueError(f"Order must be 1, 2, 3, got {order}")

     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
-    def index_for_timestep(
-        self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
-    ) -> int:
-        """
-        Find the index for a given timestep in the schedule.
-
-        Args:
-            timestep (`int` or `torch.Tensor`):
-                The timestep for which to find the index.
-            schedule_timesteps (`torch.Tensor`, *optional*):
-                The timestep schedule to search in. If `None`, uses `self.timesteps`.
-
-        Returns:
-            `int`:
-                The index of the timestep in the schedule.
-        """
+    def index_for_timestep(self, timestep, schedule_timesteps=None):
         if schedule_timesteps is None:
             schedule_timesteps = self.timesteps
@@ -1128,10 +1102,6 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
     def _init_step_index(self, timestep):
         """
         Initialize the step_index counter for the scheduler.
-
-        Args:
-            timestep (`int` or `torch.Tensor`):
-                The current timestep for which to initialize the step index.
         """

         if self.begin_index is None:
@@ -1234,21 +1204,6 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         noise: torch.Tensor,
         timesteps: torch.IntTensor,
     ) -> torch.Tensor:
-        """
-        Add noise to the original samples according to the noise schedule at the specified timesteps.
-
-        Args:
-            original_samples (`torch.Tensor`):
-                The original samples without noise.
-            noise (`torch.Tensor`):
-                The noise to add to the samples.
-            timesteps (`torch.IntTensor`):
-                The timesteps at which to add noise to the samples.
-
-        Returns:
-            `torch.Tensor`:
-                The noisy samples.
-        """
         # Make sure sigmas and timesteps have the same device and dtype as original_samples
         sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
         if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|||||||
@@ -578,22 +578,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
return x_t
|
return x_t
|
||||||
|
|
||||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
||||||
def index_for_timestep(
|
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||||
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
|
|
||||||
) -> int:
|
|
||||||
"""
|
|
||||||
Find the index for a given timestep in the schedule.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
timestep (`int` or `torch.Tensor`):
|
|
||||||
The timestep for which to find the index.
|
|
||||||
schedule_timesteps (`torch.Tensor`, *optional*):
|
|
||||||
The timestep schedule to search in. If `None`, uses `self.timesteps`.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
`int`:
|
|
||||||
The index of the timestep in the schedule.
|
|
||||||
"""
|
|
||||||
if schedule_timesteps is None:
|
if schedule_timesteps is None:
|
||||||
schedule_timesteps = self.timesteps
|
schedule_timesteps = self.timesteps
|
||||||
|
|
||||||
@@ -616,10 +601,6 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
     def _init_step_index(self, timestep):
         """
         Initialize the step_index counter for the scheduler.
-
-        Args:
-            timestep (`int` or `torch.Tensor`):
-                The current timestep for which to initialize the step index.
         """

         if self.begin_index is None:
|||||||
@@ -423,17 +423,6 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
|
|
||||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
|
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
|
||||||
def _sigma_to_alpha_sigma_t(self, sigma):
|
def _sigma_to_alpha_sigma_t(self, sigma):
|
||||||
"""
|
|
||||||
Convert sigma values to alpha_t and sigma_t values.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
sigma (`torch.Tensor`):
|
|
||||||
The sigma value(s) to convert.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
`Tuple[torch.Tensor, torch.Tensor]`:
|
|
||||||
A tuple containing (alpha_t, sigma_t) values.
|
|
||||||
"""
|
|
||||||
if self.config.use_flow_sigmas:
|
if self.config.use_flow_sigmas:
|
||||||
alpha_t = 1 - sigma
|
alpha_t = 1 - sigma
|
||||||
sigma_t = sigma
|
sigma_t = sigma
|
||||||
@@ -1114,22 +1103,7 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
         return x_t

     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
-    def index_for_timestep(
-        self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
-    ) -> int:
-        """
-        Find the index for a given timestep in the schedule.
-
-        Args:
-            timestep (`int` or `torch.Tensor`):
-                The timestep for which to find the index.
-            schedule_timesteps (`torch.Tensor`, *optional*):
-                The timestep schedule to search in. If `None`, uses `self.timesteps`.
-
-        Returns:
-            `int`:
-                The index of the timestep in the schedule.
-        """
+    def index_for_timestep(self, timestep, schedule_timesteps=None):
         if schedule_timesteps is None:
             schedule_timesteps = self.timesteps
@@ -1152,10 +1126,6 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
     def _init_step_index(self, timestep):
         """
         Initialize the step_index counter for the scheduler.
-
-        Args:
-            timestep (`int` or `torch.Tensor`):
-                The current timestep for which to initialize the step index.
         """

         if self.begin_index is None:
|||||||
@@ -513,17 +513,6 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
|
|
||||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
|
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
|
||||||
def _sigma_to_alpha_sigma_t(self, sigma):
|
def _sigma_to_alpha_sigma_t(self, sigma):
|
||||||
"""
|
|
||||||
Convert sigma values to alpha_t and sigma_t values.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
sigma (`torch.Tensor`):
|
|
||||||
The sigma value(s) to convert.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
`Tuple[torch.Tensor, torch.Tensor]`:
|
|
||||||
A tuple containing (alpha_t, sigma_t) values.
|
|
||||||
"""
|
|
||||||
if self.config.use_flow_sigmas:
|
if self.config.use_flow_sigmas:
|
||||||
alpha_t = 1 - sigma
|
alpha_t = 1 - sigma
|
||||||
sigma_t = sigma
|
sigma_t = sigma
|
||||||
@@ -995,22 +984,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
         return x_t

     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
-    def index_for_timestep(
-        self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
-    ) -> int:
-        """
-        Find the index for a given timestep in the schedule.
-
-        Args:
-            timestep (`int` or `torch.Tensor`):
-                The timestep for which to find the index.
-            schedule_timesteps (`torch.Tensor`, *optional*):
-                The timestep schedule to search in. If `None`, uses `self.timesteps`.
-
-        Returns:
-            `int`:
-                The index of the timestep in the schedule.
-        """
+    def index_for_timestep(self, timestep, schedule_timesteps=None):
         if schedule_timesteps is None:
             schedule_timesteps = self.timesteps
@@ -1033,10 +1007,6 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
     def _init_step_index(self, timestep):
         """
         Initialize the step_index counter for the scheduler.
-
-        Args:
-            timestep (`int` or `torch.Tensor`):
-                The current timestep for which to initialize the step index.
         """

         if self.begin_index is None:
@@ -1149,21 +1119,6 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
         noise: torch.Tensor,
         timesteps: torch.IntTensor,
     ) -> torch.Tensor:
-        """
-        Add noise to the original samples according to the noise schedule at the specified timesteps.
-
-        Args:
-            original_samples (`torch.Tensor`):
-                The original samples without noise.
-            noise (`torch.Tensor`):
-                The noise to add to the samples.
-            timesteps (`torch.IntTensor`):
-                The timesteps at which to add noise to the samples.
-
-        Returns:
-            `torch.Tensor`:
-                The noisy samples.
-        """
         # Make sure sigmas and timesteps have the same device and dtype as original_samples
         sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
         if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
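`add_noise` keeps the standard scheduler contract even without its docstring: it mixes clean samples with noise at the requested timesteps. A usage sketch with dummy shapes (the scheduler calls are the public API; everything else is illustrative):

```python
import torch
from diffusers import UniPCMultistepScheduler

scheduler = UniPCMultistepScheduler()
scheduler.set_timesteps(num_inference_steps=10)

original_samples = torch.randn(1, 4, 64, 64)  # clean latents
noise = torch.randn_like(original_samples)
timesteps = scheduler.timesteps[:1].long()    # noise level to apply

noisy = scheduler.add_noise(original_samples, noise, timesteps)
print(noisy.shape)  # torch.Size([1, 4, 64, 64])
```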
@@ -1,306 +0,0 @@
-# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import os
-import unittest
-
-import numpy as np
-import torch
-from transformers import Qwen2Tokenizer, Qwen3Config, Qwen3Model
-
-from diffusers import (
-    AutoencoderKL,
-    FlowMatchEulerDiscreteScheduler,
-    ZImagePipeline,
-    ZImageTransformer2DModel,
-)
-
-from ...testing_utils import torch_device
-from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin, to_np
-
-
-# Z-Image requires torch.use_deterministic_algorithms(False) due to complex64 RoPE operations
-# Cannot use enable_full_determinism() which sets it to True
-os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
-os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
-torch.use_deterministic_algorithms(False)
-torch.backends.cudnn.deterministic = True
-torch.backends.cudnn.benchmark = False
-if hasattr(torch.backends, "cuda"):
-    torch.backends.cuda.matmul.allow_tf32 = False
-
-# Note: Some tests (test_float16_inference, test_save_load_float16) may fail in full suite
-# due to RopeEmbedder cache state pollution between tests. They pass when run individually.
-# This is a known test isolation issue, not a functional bug.
-
-
-class ZImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
-    pipeline_class = ZImagePipeline
-    params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
-    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
-    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
-    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
-    required_optional_params = frozenset(
-        [
-            "num_inference_steps",
-            "generator",
-            "latents",
-            "return_dict",
-            "callback_on_step_end",
-            "callback_on_step_end_tensor_inputs",
-        ]
-    )
-    supports_dduf = False
-    test_xformers_attention = False
-    test_layerwise_casting = True
-    test_group_offloading = True
-
-    def setUp(self):
-        gc.collect()
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-            torch.cuda.synchronize()
-        torch.manual_seed(0)
-        if torch.cuda.is_available():
-            torch.cuda.manual_seed_all(0)
-
-    def tearDown(self):
-        super().tearDown()
-        gc.collect()
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-            torch.cuda.synchronize()
-        torch.manual_seed(0)
-        if torch.cuda.is_available():
-            torch.cuda.manual_seed_all(0)
-
-    def get_dummy_components(self):
-        torch.manual_seed(0)
-        transformer = ZImageTransformer2DModel(
-            all_patch_size=(2,),
-            all_f_patch_size=(1,),
-            in_channels=16,
-            dim=32,
-            n_layers=2,
-            n_refiner_layers=1,
-            n_heads=2,
-            n_kv_heads=2,
-            norm_eps=1e-5,
-            qk_norm=True,
-            cap_feat_dim=16,
-            rope_theta=256.0,
-            t_scale=1000.0,
-            axes_dims=[8, 4, 4],
-            axes_lens=[256, 32, 32],
-        )
-
-        torch.manual_seed(0)
-        vae = AutoencoderKL(
-            in_channels=3,
-            out_channels=3,
-            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
-            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
-            block_out_channels=[32, 64],
-            layers_per_block=1,
-            latent_channels=16,
-            norm_num_groups=32,
-            sample_size=32,
-            scaling_factor=0.3611,
-            shift_factor=0.1159,
-        )
-
-        torch.manual_seed(0)
-        scheduler = FlowMatchEulerDiscreteScheduler()
-
-        torch.manual_seed(0)
-        config = Qwen3Config(
-            hidden_size=16,
-            intermediate_size=16,
-            num_hidden_layers=2,
-            num_attention_heads=2,
-            num_key_value_heads=2,
-            vocab_size=151936,
-            max_position_embeddings=512,
-        )
-        text_encoder = Qwen3Model(config)
-        tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")
-
-        components = {
-            "transformer": transformer,
-            "vae": vae,
-            "scheduler": scheduler,
-            "text_encoder": text_encoder,
-            "tokenizer": tokenizer,
-        }
-        return components
-
-    def get_dummy_inputs(self, device, seed=0):
-        if str(device).startswith("mps"):
-            generator = torch.manual_seed(seed)
-        else:
-            generator = torch.Generator(device=device).manual_seed(seed)
-
-        inputs = {
-            "prompt": "dance monkey",
-            "negative_prompt": "bad quality",
-            "generator": generator,
-            "num_inference_steps": 2,
-            "guidance_scale": 3.0,
-            "cfg_normalization": False,
-            "cfg_truncation": 1.0,
-            "height": 32,
-            "width": 32,
-            "max_sequence_length": 16,
-            "output_type": "pt",
-        }
-
-        return inputs
-
-    def test_inference(self):
-        device = "cpu"
-
-        components = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe.to(device)
-        pipe.set_progress_bar_config(disable=None)
-
-        inputs = self.get_dummy_inputs(device)
-        image = pipe(**inputs).images
-        generated_image = image[0]
-        self.assertEqual(generated_image.shape, (3, 32, 32))
-
-        # fmt: off
-        expected_slice = torch.tensor([0.4521, 0.4512, 0.4693, 0.5115, 0.5250, 0.5271, 0.4776, 0.4688, 0.2765, 0.2164, 0.5656, 0.6909, 0.3831, 0.5431, 0.5493, 0.4732])
-        # fmt: on
-
-        generated_slice = generated_image.flatten()
-        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
-        self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=5e-2))
-
-    def test_inference_batch_single_identical(self):
-        self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-1)
-
-    def test_num_images_per_prompt(self):
-        import inspect
-
-        sig = inspect.signature(self.pipeline_class.__call__)
-
-        if "num_images_per_prompt" not in sig.parameters:
-            return
-
-        components = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-
-        batch_sizes = [1, 2]
-        num_images_per_prompts = [1, 2]
-
-        for batch_size in batch_sizes:
-            for num_images_per_prompt in num_images_per_prompts:
-                inputs = self.get_dummy_inputs(torch_device)
-
-                for key in inputs.keys():
-                    if key in self.batch_params:
-                        inputs[key] = batch_size * [inputs[key]]
-
-                images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt)[0]
-
-                assert images.shape[0] == batch_size * num_images_per_prompt
-
-        del pipe
-        gc.collect()
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-            torch.cuda.synchronize()
-
-    def test_attention_slicing_forward_pass(
-        self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
-    ):
-        if not self.test_attention_slicing:
-            return
-
-        components = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        for component in pipe.components.values():
-            if hasattr(component, "set_default_attn_processor"):
-                component.set_default_attn_processor()
-        pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-
-        generator_device = "cpu"
-        inputs = self.get_dummy_inputs(generator_device)
-        output_without_slicing = pipe(**inputs)[0]
-
-        pipe.enable_attention_slicing(slice_size=1)
-        inputs = self.get_dummy_inputs(generator_device)
-        output_with_slicing1 = pipe(**inputs)[0]
-
-        pipe.enable_attention_slicing(slice_size=2)
-        inputs = self.get_dummy_inputs(generator_device)
-        output_with_slicing2 = pipe(**inputs)[0]
-
-        if test_max_difference:
-            max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
-            max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
-            self.assertLess(
-                max(max_diff1, max_diff2),
-                expected_max_diff,
-                "Attention slicing should not affect the inference results",
-            )
-
-    def test_vae_tiling(self, expected_diff_max: float = 0.2):
-        generator_device = "cpu"
-        components = self.get_dummy_components()
-
-        pipe = self.pipeline_class(**components)
-        pipe.to("cpu")
-        pipe.set_progress_bar_config(disable=None)
-
-        # Without tiling
-        inputs = self.get_dummy_inputs(generator_device)
-        inputs["height"] = inputs["width"] = 128
-        output_without_tiling = pipe(**inputs)[0]
-
-        # With tiling (standard AutoencoderKL doesn't accept parameters)
-        pipe.vae.enable_tiling()
-        inputs = self.get_dummy_inputs(generator_device)
-        inputs["height"] = inputs["width"] = 128
-        output_with_tiling = pipe(**inputs)[0]
-
-        self.assertLess(
-            (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
-            expected_diff_max,
-            "VAE tiling should not affect the inference results",
-        )
-
-    def test_pipeline_with_accelerator_device_map(self, expected_max_difference=5e-4):
-        # Z-Image RoPE embeddings (complex64) have slightly higher numerical tolerance
-        super().test_pipeline_with_accelerator_device_map(expected_max_difference=expected_max_difference)
-
-    def test_group_offloading_inference(self):
-        # Block-level offloading conflicts with RoPE cache. Pipeline-level offloading (tested separately) works fine.
-        self.skipTest("Using test_pipeline_level_group_offloading_inference instead")
-
-    def test_save_load_float16(self, expected_max_diff=1e-2):
-        gc.collect()
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-            torch.cuda.synchronize()
-        torch.manual_seed(0)
-        if torch.cuda.is_available():
-            torch.cuda.manual_seed_all(0)
-        super().test_save_load_float16(expected_max_diff=expected_max_diff)
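One part of the deleted file worth flagging for anyone restoring it is the module-level preamble. Per its own comments, Z-Image's complex64 RoPE ops have no deterministic CUDA kernels, and `torch.use_deterministic_algorithms(True)` makes PyTorch raise on such ops, so the tests pinned seeds and backend flags instead. A hedged sketch of that setup:

```python
import os
import torch

# Mirrors the deleted preamble (illustrative): global determinism stays off
# because complex64 RoPE ops would otherwise raise, while the cuBLAS/cuDNN
# knobs and seeds are still pinned for reproducibility.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
torch.use_deterministic_algorithms(False)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.manual_seed(0)
```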