Compare commits

...

22 Commits

Author SHA1 Message Date
yiyixuxu
7a04604553 style 2025-03-21 17:01:04 +01:00
Junsong Chen
a220997e11 [SANA-Sprint] remove used multi-scale bin (#11131)
* change sample prompt;

* only 1024px is supported;
2025-03-21 05:57:29 -10:00
YiYi Xu
94d87d5186 Apply suggestions from code review 2025-03-20 17:04:17 -10:00
yiyixuxu
c3e107f0a9 up 2025-03-21 03:54:33 +01:00
yiyixuxu
c4d049c054 add tests 2025-03-21 03:06:56 +01:00
yiyixuxu
eae8ed71a2 update docstring example 2025-03-21 00:10:54 +01:00
yiyixuxu
8c07fccb6d up 2025-03-21 00:05:45 +01:00
YiYi Xu
3734af8eac Apply suggestions from code review
Co-authored-by: Aryan <aryan@huggingface.co>
2025-03-20 12:56:21 -10:00
yiyixuxu
1de087e16f add to torctree 2025-03-20 23:54:04 +01:00
Sayak Paul
8e4f71177e [docs] add a note about max_timesteps (#11122)
add a note about max_timesteps
2025-03-20 06:38:52 -10:00
github-actions[bot]
9cd5f1e66d Apply style fixes 2025-03-20 10:49:03 +00:00
yiyixuxu
da3c9172df Merge branch 'sana-sprint' of github.com:huggingface/diffusers into sana-sprint 2025-03-20 11:45:01 +01:00
yiyixuxu
be73b5960c up upp 2025-03-20 11:44:45 +01:00
yiyixuxu
8070495df1 up 2025-03-20 11:40:13 +01:00
Junsong Chen
4e5a9efdc2 update conversion script for SANA-1.5 and SANA-Sprint (#11082)
* 1. update conversion script for sana1.5;
2. add conversion script for sana-sprint;

* seperate sana and sana-sprint conversion scripts;

* update for upstream

* fix the } bug

* add a doc for SanaSprintPipeline;

* minor update;

* make style && make quality
2025-03-19 22:25:46 -10:00
yiyixuxu
398ca0c938 remove unused __init__ arg for scm scheduler 2025-03-20 02:55:00 +01:00
yiyixuxu
4eef82b2c9 pipeline_sana_scm -> pipeline_sana_sprint 2025-03-20 02:44:06 +01:00
yiyixuxu
5b19b22685 copies 2025-03-19 21:37:57 +01:00
yiyixuxu
0d6309ae00 style 2025-03-19 21:33:06 +01:00
yiyixuxu
ae4c3fda10 add conversion sript 2025-03-19 21:32:52 +01:00
Junsong Chen
9714187c30 change name from SanaSCMPipeline to SanaSprintPipeline. (#11076) 2025-03-16 20:39:59 -10:00
yiyixuxu
c952370cb4 first commit 2025-03-17 02:14:55 +01:00
15 changed files with 1996 additions and 124 deletions

View File

@@ -496,6 +496,8 @@
title: PixArt-Σ
- local: api/pipelines/sana
title: Sana
- local: api/pipelines/sana_sprint
title: Sana Sprint
- local: api/pipelines/self_attention_guidance
title: Self-Attention Guidance
- local: api/pipelines/semantic_stable_diffusion

View File

@@ -0,0 +1,100 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->
# SanaSprintPipeline
<div class="flex flex-wrap space-x-1">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
</div>
[SANA-Sprint: One-Step Diffusion with Continuous-Time Consistency Distillation](https://huggingface.co/papers/2503.09641) from NVIDIA, MIT HAN Lab, and Hugging Face by Junsong Chen, Shuchen Xue, Yuyang Zhao, Jincheng Yu, Sayak Paul, Junyu Chen, Han Cai, Enze Xie, Song Han
The abstract from the paper is:
*This paper presents SANA-Sprint, an efficient diffusion model for ultra-fast text-to-image (T2I) generation. SANA-Sprint is built on a pre-trained foundation model and augmented with hybrid distillation, dramatically reducing inference steps from 20 to 1-4. We introduce three key innovations: (1) We propose a training-free approach that transforms a pre-trained flow-matching model for continuous-time consistency distillation (sCM), eliminating costly training from scratch and achieving high training efficiency. Our hybrid distillation strategy combines sCM with latent adversarial distillation (LADD): sCM ensures alignment with the teacher model, while LADD enhances single-step generation fidelity. (2) SANA-Sprint is a unified step-adaptive model that achieves high-quality generation in 1-4 steps, eliminating step-specific training and improving efficiency. (3) We integrate ControlNet with SANA-Sprint for real-time interactive image generation, enabling instant visual feedback for user interaction. SANA-Sprint establishes a new Pareto frontier in speed-quality tradeoffs, achieving state-of-the-art performance with 7.59 FID and 0.74 GenEval in only 1 step — outperforming FLUX-schnell (7.94 FID / 0.71 GenEval) while being 10× faster (0.1s vs 1.1s on H100). It also achieves 0.1s (T2I) and 0.25s (ControlNet) latency for 1024×1024 images on H100, and 0.31s (T2I) on an RTX 4090, showcasing its exceptional efficiency and potential for AI-powered consumer applications (AIPC). Code and pre-trained models will be open-sourced.*
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
This pipeline was contributed by [lawrence-cj](https://github.com/lawrence-cj), [shuchen Xue](https://github.com/scxue) and [Enze Xie](https://github.com/xieenze). The original codebase can be found [here](https://github.com/NVlabs/Sana). The original weights can be found under [hf.co/Efficient-Large-Model](https://huggingface.co/Efficient-Large-Model/).
Available models:
| Model | Recommended dtype |
|:-------------------------------------------------------------------------------------------------------------------------------------------:|:-----------------:|
| [`Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers) | `torch.bfloat16` |
| [`Efficient-Large-Model/Sana_Sprint_0.6B_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_Sprint_0.6B_1024px_diffusers) | `torch.bfloat16` |
Refer to [this](https://huggingface.co/collections/Efficient-Large-Model/sana-sprint-67d6810d65235085b3b17c76) collection for more information.
Note: The recommended dtype mentioned is for the transformer weights. The text encoder must stay in `torch.bfloat16` and VAE weights must stay in `torch.bfloat16` or `torch.float32` for the model to work correctly. Please refer to the inference example below to see how to load the model with the recommended dtype.
## Quantization
Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`SanaSprintPipeline`] for inference with bitsandbytes.
```py
import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, SanaTransformer2DModel, SanaSprintPipeline
from transformers import BitsAndBytesConfig as BitsAndBytesConfig, AutoModel
quant_config = BitsAndBytesConfig(load_in_8bit=True)
text_encoder_8bit = AutoModel.from_pretrained(
"Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
subfolder="text_encoder",
quantization_config=quant_config,
torch_dtype=torch.bfloat16,
)
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
transformer_8bit = SanaTransformer2DModel.from_pretrained(
"Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
subfolder="transformer",
quantization_config=quant_config,
torch_dtype=torch.bfloat16,
)
pipeline = SanaSprintPipeline.from_pretrained(
"Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
text_encoder=text_encoder_8bit,
transformer=transformer_8bit,
torch_dtype=torch.bfloat16,
device_map="balanced",
)
prompt = "a tiny astronaut hatching from an egg on the moon"
image = pipeline(prompt).images[0]
image.save("sana.png")
```
## Setting `max_timesteps`
Users can tweak the `max_timesteps` value for experimenting with the visual quality of the generated outputs. The default `max_timesteps` value was obtained with an inference-time search process. For more details about it, check out the paper.
## SanaSprintPipeline
[[autodoc]] SanaSprintPipeline
- all
- __call__
## SanaPipelineOutput
[[autodoc]] pipelines.sana.pipeline_output.SanaPipelineOutput

View File

@@ -16,7 +16,9 @@ from diffusers import (
DPMSolverMultistepScheduler,
FlowMatchEulerDiscreteScheduler,
SanaPipeline,
SanaSprintPipeline,
SanaTransformer2DModel,
SCMScheduler,
)
from diffusers.models.modeling_utils import load_model_dict_into_meta
from diffusers.utils.import_utils import is_accelerate_available
@@ -25,6 +27,7 @@ from diffusers.utils.import_utils import is_accelerate_available
CTX = init_empty_weights if is_accelerate_available else nullcontext
ckpt_ids = [
"Efficient-Large-Model/SANA1.5_4.8B_1024px/checkpoints/SANA1.5_4.8B_1024px.pth",
"Efficient-Large-Model/Sana_1600M_4Kpx_BF16/checkpoints/Sana_1600M_4Kpx_BF16.pth",
"Efficient-Large-Model/Sana_1600M_2Kpx_BF16/checkpoints/Sana_1600M_2Kpx_BF16.pth",
"Efficient-Large-Model/Sana_1600M_1024px_MultiLing/checkpoints/Sana_1600M_1024px_MultiLing.pth",
@@ -72,15 +75,42 @@ def main(args):
converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight")
converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias")
# AdaLN-single LN
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = state_dict.pop(
"t_embedder.mlp.0.weight"
)
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias")
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = state_dict.pop(
"t_embedder.mlp.2.weight"
)
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias")
# Handle different time embedding structure based on model type
if args.model_type in ["SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"]:
# For Sana Sprint, the time embedding structure is different
converted_state_dict["time_embed.timestep_embedder.linear_1.weight"] = state_dict.pop(
"t_embedder.mlp.0.weight"
)
converted_state_dict["time_embed.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias")
converted_state_dict["time_embed.timestep_embedder.linear_2.weight"] = state_dict.pop(
"t_embedder.mlp.2.weight"
)
converted_state_dict["time_embed.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias")
# Guidance embedder for Sana Sprint
converted_state_dict["time_embed.guidance_embedder.linear_1.weight"] = state_dict.pop(
"cfg_embedder.mlp.0.weight"
)
converted_state_dict["time_embed.guidance_embedder.linear_1.bias"] = state_dict.pop("cfg_embedder.mlp.0.bias")
converted_state_dict["time_embed.guidance_embedder.linear_2.weight"] = state_dict.pop(
"cfg_embedder.mlp.2.weight"
)
converted_state_dict["time_embed.guidance_embedder.linear_2.bias"] = state_dict.pop("cfg_embedder.mlp.2.bias")
else:
# Original Sana time embedding structure
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = state_dict.pop(
"t_embedder.mlp.0.weight"
)
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = state_dict.pop(
"t_embedder.mlp.0.bias"
)
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = state_dict.pop(
"t_embedder.mlp.2.weight"
)
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = state_dict.pop(
"t_embedder.mlp.2.bias"
)
# Shared norm.
converted_state_dict["time_embed.linear.weight"] = state_dict.pop("t_block.1.weight")
@@ -96,14 +126,22 @@ def main(args):
flow_shift = 3.0
# model config
if args.model_type == "SanaMS_1600M_P1_D20":
if args.model_type in ["SanaMS_1600M_P1_D20", "SanaSprint_1600M_P1_D20", "SanaMS1.5_1600M_P1_D20"]:
layer_num = 20
elif args.model_type == "SanaMS_600M_P1_D28":
elif args.model_type in ["SanaMS_600M_P1_D28", "SanaSprint_600M_P1_D28"]:
layer_num = 28
elif args.model_type == "SanaMS_4800M_P1_D60":
layer_num = 60
else:
raise ValueError(f"{args.model_type} is not supported.")
# Positional embedding interpolation scale.
interpolation_scale = {512: None, 1024: None, 2048: 1.0, 4096: 2.0}
qk_norm = (
"rms_norm_across_heads"
if args.model_type
in ["SanaMS1.5_1600M_P1_D20", "SanaMS1.5_4800M_P1_D60", "SanaSprint_600M_P1_D28", "SanaSprint_1600M_P1_D20"]
else None
)
for depth in range(layer_num):
# Transformer blocks.
@@ -117,6 +155,14 @@ def main(args):
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
if qk_norm is not None:
# Add Q/K normalization for self-attention (attn1) - needed for Sana-Sprint and Sana-1.5
converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_q.weight"] = state_dict.pop(
f"blocks.{depth}.attn.q_norm.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_k.weight"] = state_dict.pop(
f"blocks.{depth}.attn.k_norm.weight"
)
# Projection.
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict.pop(
f"blocks.{depth}.attn.proj.weight"
@@ -154,6 +200,14 @@ def main(args):
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.bias"] = k_bias
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.weight"] = v
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.bias"] = v_bias
if qk_norm is not None:
# Add Q/K normalization for cross-attention (attn2) - needed for Sana-Sprint and Sana-1.5
converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_q.weight"] = state_dict.pop(
f"blocks.{depth}.cross_attn.q_norm.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_k.weight"] = state_dict.pop(
f"blocks.{depth}.cross_attn.k_norm.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.weight"] = state_dict.pop(
f"blocks.{depth}.cross_attn.proj.weight"
@@ -169,24 +223,37 @@ def main(args):
# Transformer
with CTX():
transformer = SanaTransformer2DModel(
in_channels=32,
out_channels=32,
num_attention_heads=model_kwargs[args.model_type]["num_attention_heads"],
attention_head_dim=model_kwargs[args.model_type]["attention_head_dim"],
num_layers=model_kwargs[args.model_type]["num_layers"],
num_cross_attention_heads=model_kwargs[args.model_type]["num_cross_attention_heads"],
cross_attention_head_dim=model_kwargs[args.model_type]["cross_attention_head_dim"],
cross_attention_dim=model_kwargs[args.model_type]["cross_attention_dim"],
caption_channels=2304,
mlp_ratio=2.5,
attention_bias=False,
sample_size=args.image_size // 32,
patch_size=1,
norm_elementwise_affine=False,
norm_eps=1e-6,
interpolation_scale=interpolation_scale[args.image_size],
)
transformer_kwargs = {
"in_channels": 32,
"out_channels": 32,
"num_attention_heads": model_kwargs[args.model_type]["num_attention_heads"],
"attention_head_dim": model_kwargs[args.model_type]["attention_head_dim"],
"num_layers": model_kwargs[args.model_type]["num_layers"],
"num_cross_attention_heads": model_kwargs[args.model_type]["num_cross_attention_heads"],
"cross_attention_head_dim": model_kwargs[args.model_type]["cross_attention_head_dim"],
"cross_attention_dim": model_kwargs[args.model_type]["cross_attention_dim"],
"caption_channels": 2304,
"mlp_ratio": 2.5,
"attention_bias": False,
"sample_size": args.image_size // 32,
"patch_size": 1,
"norm_elementwise_affine": False,
"norm_eps": 1e-6,
"interpolation_scale": interpolation_scale[args.image_size],
}
# Add qk_norm parameter for Sana Sprint
if args.model_type in [
"SanaMS1.5_1600M_P1_D20",
"SanaMS1.5_4800M_P1_D60",
"SanaSprint_600M_P1_D28",
"SanaSprint_1600M_P1_D20",
]:
transformer_kwargs["qk_norm"] = "rms_norm_across_heads"
if args.model_type in ["SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"]:
transformer_kwargs["guidance_embeds"] = True
transformer = SanaTransformer2DModel(**transformer_kwargs)
if is_accelerate_available():
load_model_dict_into_meta(transformer, converted_state_dict)
@@ -196,6 +263,8 @@ def main(args):
try:
state_dict.pop("y_embedder.y_embedding")
state_dict.pop("pos_embed")
state_dict.pop("logvar_linear.weight")
state_dict.pop("logvar_linear.bias")
except KeyError:
print("y_embedder.y_embedding or pos_embed not found in the state_dict")
@@ -210,47 +279,75 @@ def main(args):
print(
colored(
f"Only saving transformer model of {args.model_type}. "
f"Set --save_full_pipeline to save the whole SanaPipeline",
f"Set --save_full_pipeline to save the whole Pipeline",
"green",
attrs=["bold"],
)
)
transformer.save_pretrained(
os.path.join(args.dump_path, "transformer"), safe_serialization=True, max_shard_size="5GB", variant=variant
os.path.join(args.dump_path, "transformer"), safe_serialization=True, max_shard_size="5GB"
)
else:
print(colored(f"Saving the whole SanaPipeline containing {args.model_type}", "green", attrs=["bold"]))
print(colored(f"Saving the whole Pipeline containing {args.model_type}", "green", attrs=["bold"]))
# VAE
ae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers", torch_dtype=torch.float32)
ae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers", torch_dtype=torch.float32)
# Text Encoder
text_encoder_model_path = "google/gemma-2-2b-it"
text_encoder_model_path = "Efficient-Large-Model/gemma-2-2b-it"
tokenizer = AutoTokenizer.from_pretrained(text_encoder_model_path)
tokenizer.padding_side = "right"
text_encoder = AutoModelForCausalLM.from_pretrained(
text_encoder_model_path, torch_dtype=torch.bfloat16
).get_decoder()
# Scheduler
if args.scheduler_type == "flow-dpm_solver":
scheduler = DPMSolverMultistepScheduler(
flow_shift=flow_shift,
use_flow_sigmas=True,
prediction_type="flow_prediction",
)
elif args.scheduler_type == "flow-euler":
scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift)
else:
raise ValueError(f"Scheduler type {args.scheduler_type} is not supported")
# Choose the appropriate pipeline and scheduler based on model type
if args.model_type in ["SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"]:
# Force SCM Scheduler for Sana Sprint regardless of scheduler_type
if args.scheduler_type != "scm":
print(
colored(
f"Warning: Overriding scheduler_type '{args.scheduler_type}' to 'scm' for SanaSprint model",
"yellow",
attrs=["bold"],
)
)
pipe = SanaPipeline(
tokenizer=tokenizer,
text_encoder=text_encoder,
transformer=transformer,
vae=ae,
scheduler=scheduler,
)
pipe.save_pretrained(args.dump_path, safe_serialization=True, max_shard_size="5GB", variant=variant)
# SCM Scheduler for Sana Sprint
scheduler_config = {
"num_train_timesteps": 1000,
"prediction_type": "trigflow",
"sigma_data": 0.5,
}
scheduler = SCMScheduler(**scheduler_config)
pipe = SanaSprintPipeline(
tokenizer=tokenizer,
text_encoder=text_encoder,
transformer=transformer,
vae=ae,
scheduler=scheduler,
)
else:
# Original Sana scheduler
if args.scheduler_type == "flow-dpm_solver":
scheduler = DPMSolverMultistepScheduler(
flow_shift=flow_shift,
use_flow_sigmas=True,
prediction_type="flow_prediction",
)
elif args.scheduler_type == "flow-euler":
scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift)
else:
raise ValueError(f"Scheduler type {args.scheduler_type} is not supported")
pipe = SanaPipeline(
tokenizer=tokenizer,
text_encoder=text_encoder,
transformer=transformer,
vae=ae,
scheduler=scheduler,
)
pipe.save_pretrained(args.dump_path, safe_serialization=True, max_shard_size="5GB")
DTYPE_MAPPING = {
@@ -259,12 +356,6 @@ DTYPE_MAPPING = {
"bf16": torch.bfloat16,
}
VARIANT_MAPPING = {
"fp32": None,
"fp16": "fp16",
"bf16": "bf16",
}
if __name__ == "__main__":
parser = argparse.ArgumentParser()
@@ -281,10 +372,23 @@ if __name__ == "__main__":
help="Image size of pretrained model, 512, 1024, 2048 or 4096.",
)
parser.add_argument(
"--model_type", default="SanaMS_1600M_P1_D20", type=str, choices=["SanaMS_1600M_P1_D20", "SanaMS_600M_P1_D28"]
"--model_type",
default="SanaMS_1600M_P1_D20",
type=str,
choices=[
"SanaMS_1600M_P1_D20",
"SanaMS_600M_P1_D28",
"SanaMS_4800M_P1_D60",
"SanaSprint_1600M_P1_D20",
"SanaSprint_600M_P1_D28",
],
)
parser.add_argument(
"--scheduler_type", default="flow-dpm_solver", type=str, choices=["flow-dpm_solver", "flow-euler"]
"--scheduler_type",
default="flow-dpm_solver",
type=str,
choices=["flow-dpm_solver", "flow-euler", "scm"],
help="Scheduler type to use. Use 'scm' for Sana Sprint models.",
)
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipelien elemets in one.")
@@ -309,10 +413,41 @@ if __name__ == "__main__":
"cross_attention_dim": 1152,
"num_layers": 28,
},
"SanaMS1.5_1600M_P1_D20": {
"num_attention_heads": 70,
"attention_head_dim": 32,
"num_cross_attention_heads": 20,
"cross_attention_head_dim": 112,
"cross_attention_dim": 2240,
"num_layers": 20,
},
"SanaMS1.5__4800M_P1_D60": {
"num_attention_heads": 70,
"attention_head_dim": 32,
"num_cross_attention_heads": 20,
"cross_attention_head_dim": 112,
"cross_attention_dim": 2240,
"num_layers": 60,
},
"SanaSprint_600M_P1_D28": {
"num_attention_heads": 36,
"attention_head_dim": 32,
"num_cross_attention_heads": 16,
"cross_attention_head_dim": 72,
"cross_attention_dim": 1152,
"num_layers": 28,
},
"SanaSprint_1600M_P1_D20": {
"num_attention_heads": 70,
"attention_head_dim": 32,
"num_cross_attention_heads": 20,
"cross_attention_head_dim": 112,
"cross_attention_dim": 2240,
"num_layers": 20,
},
}
device = "cuda" if torch.cuda.is_available() else "cpu"
weight_dtype = DTYPE_MAPPING[args.dtype]
variant = VARIANT_MAPPING[args.dtype]
main(args)

View File

@@ -271,6 +271,7 @@ else:
"RePaintScheduler",
"SASolverScheduler",
"SchedulerMixin",
"SCMScheduler",
"ScoreSdeVeScheduler",
"TCDScheduler",
"UnCLIPScheduler",
@@ -421,6 +422,7 @@ else:
"ReduxImageEncoder",
"SanaPAGPipeline",
"SanaPipeline",
"SanaSprintPipeline",
"SemanticStableDiffusionPipeline",
"ShapEImg2ImgPipeline",
"ShapEPipeline",
@@ -834,6 +836,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
RePaintScheduler,
SASolverScheduler,
SchedulerMixin,
SCMScheduler,
ScoreSdeVeScheduler,
TCDScheduler,
UnCLIPScheduler,
@@ -965,6 +968,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
ReduxImageEncoder,
SanaPAGPipeline,
SanaPipeline,
SanaSprintPipeline,
SemanticStableDiffusionPipeline,
ShapEImg2ImgPipeline,
ShapEPipeline,

View File

@@ -6020,6 +6020,11 @@ class SanaLinearAttnProcessor2_0:
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
query = query.transpose(1, 2).unflatten(1, (attn.heads, -1))
key = key.transpose(1, 2).unflatten(1, (attn.heads, -1)).transpose(2, 3)
value = value.transpose(1, 2).unflatten(1, (attn.heads, -1))

View File

@@ -15,6 +15,7 @@
from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
@@ -23,10 +24,9 @@ from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_
from ..attention_processor import (
Attention,
AttentionProcessor,
AttnProcessor2_0,
SanaLinearAttnProcessor2_0,
)
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, TimestepEmbedding, Timesteps
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormSingle, RMSNorm
@@ -96,6 +96,95 @@ class SanaModulatedNorm(nn.Module):
return hidden_states
class SanaCombinedTimestepGuidanceEmbeddings(nn.Module):
def __init__(self, embedding_dim):
super().__init__()
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
self.guidance_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
def forward(self, timestep: torch.Tensor, guidance: torch.Tensor = None, hidden_dtype: torch.dtype = None):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
guidance_proj = self.guidance_condition_proj(guidance)
guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=hidden_dtype))
conditioning = timesteps_emb + guidance_emb
return self.linear(self.silu(conditioning)), conditioning
class SanaAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("SanaAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class SanaTransformerBlock(nn.Module):
r"""
Transformer block introduced in [Sana](https://huggingface.co/papers/2410.10629).
@@ -115,6 +204,7 @@ class SanaTransformerBlock(nn.Module):
norm_eps: float = 1e-6,
attention_out_bias: bool = True,
mlp_ratio: float = 2.5,
qk_norm: Optional[str] = None,
) -> None:
super().__init__()
@@ -124,6 +214,8 @@ class SanaTransformerBlock(nn.Module):
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
kv_heads=num_attention_heads if qk_norm is not None else None,
qk_norm=qk_norm,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=None,
@@ -135,13 +227,15 @@ class SanaTransformerBlock(nn.Module):
self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.attn2 = Attention(
query_dim=dim,
qk_norm=qk_norm,
kv_heads=num_cross_attention_heads if qk_norm is not None else None,
cross_attention_dim=cross_attention_dim,
heads=num_cross_attention_heads,
dim_head=cross_attention_head_dim,
dropout=dropout,
bias=True,
out_bias=attention_out_bias,
processor=AttnProcessor2_0(),
processor=SanaAttnProcessor2_0(),
)
# 3. Feed-forward
@@ -258,6 +352,9 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-6,
interpolation_scale: Optional[int] = None,
guidance_embeds: bool = False,
guidance_embeds_scale: float = 0.1,
qk_norm: Optional[str] = None,
) -> None:
super().__init__()
@@ -276,7 +373,10 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
)
# 2. Additional condition embeddings
self.time_embed = AdaLayerNormSingle(inner_dim)
if guidance_embeds:
self.time_embed = SanaCombinedTimestepGuidanceEmbeddings(inner_dim)
else:
self.time_embed = AdaLayerNormSingle(inner_dim)
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
self.caption_norm = RMSNorm(inner_dim, eps=1e-5, elementwise_affine=True)
@@ -296,6 +396,7 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
mlp_ratio=mlp_ratio,
qk_norm=qk_norm,
)
for _ in range(num_layers)
]
@@ -372,7 +473,8 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
timestep: torch.LongTensor,
timestep: torch.Tensor,
guidance: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -423,9 +525,14 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
hidden_states = self.patch_embed(hidden_states)
timestep, embedded_timestep = self.time_embed(
timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
)
if guidance is not None:
timestep, embedded_timestep = self.time_embed(
timestep, guidance=guidance, hidden_dtype=hidden_states.dtype
)
else:
timestep, embedded_timestep = self.time_embed(
timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
)
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])

View File

@@ -280,7 +280,7 @@ else:
_import_structure["paint_by_example"] = ["PaintByExamplePipeline"]
_import_structure["pia"] = ["PIAPipeline"]
_import_structure["pixart_alpha"] = ["PixArtAlphaPipeline", "PixArtSigmaPipeline"]
_import_structure["sana"] = ["SanaPipeline"]
_import_structure["sana"] = ["SanaPipeline", "SanaSprintPipeline"]
_import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"]
_import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"]
_import_structure["stable_audio"] = [
@@ -651,7 +651,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .paint_by_example import PaintByExamplePipeline
from .pia import PIAPipeline
from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
from .sana import SanaPipeline
from .sana import SanaPipeline, SanaSprintPipeline
from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
from .stable_audio import StableAudioPipeline, StableAudioProjectionModel

View File

@@ -23,6 +23,7 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_sana"] = ["SanaPipeline"]
_import_structure["pipeline_sana_sprint"] = ["SanaSprintPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -33,6 +34,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_sana import SanaPipeline
from .pipeline_sana_sprint import SanaSprintPipeline
else:
import sys

View File

@@ -248,6 +248,64 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
"""
self.vae.disable_tiling()
def _get_gemma_prompt_embeds(
self,
prompt: Union[str, List[str]],
device: torch.device,
dtype: torch.dtype,
clean_caption: bool = False,
max_sequence_length: int = 300,
complex_human_instruction: Optional[List[str]] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`, *optional*):
torch device to place the resulting embeddings on
clean_caption (`bool`, defaults to `False`):
If `True`, the function will preprocess and clean the provided caption before encoding.
max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`):
If `complex_human_instruction` is not empty, the function will use the complex Human instruction for
the prompt.
"""
prompt = [prompt] if isinstance(prompt, str) else prompt
if getattr(self, "tokenizer", None) is not None:
self.tokenizer.padding_side = "right"
prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
# prepare complex human instruction
if not complex_human_instruction:
max_length_all = max_sequence_length
else:
chi_prompt = "\n".join(complex_human_instruction)
prompt = [chi_prompt + p for p in prompt]
num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
max_length_all = num_chi_prompt_tokens + max_sequence_length - 2
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_length_all,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
prompt_attention_mask = text_inputs.attention_mask
prompt_attention_mask = prompt_attention_mask.to(device)
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
prompt_embeds = prompt_embeds[0].to(dtype=dtype, device=device)
return prompt_embeds, prompt_attention_mask
def encode_prompt(
self,
prompt: Union[str, List[str]],
@@ -296,6 +354,13 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
if device is None:
device = self._execution_device
if self.transformer is not None:
dtype = self.transformer.dtype
elif self.text_encoder is not None:
dtype = self.text_encoder.dtype
else:
dtype = None
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin):
@@ -320,43 +385,18 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
select_index = [0] + list(range(-max_length + 1, 0))
if prompt_embeds is None:
prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
# prepare complex human instruction
if not complex_human_instruction:
max_length_all = max_length
else:
chi_prompt = "\n".join(complex_human_instruction)
prompt = [chi_prompt + p for p in prompt]
num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
max_length_all = num_chi_prompt_tokens + max_length - 2
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_length_all,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
prompt=prompt,
device=device,
dtype=dtype,
clean_caption=clean_caption,
max_sequence_length=max_sequence_length,
complex_human_instruction=complex_human_instruction,
)
text_input_ids = text_inputs.input_ids
prompt_attention_mask = text_inputs.attention_mask
prompt_attention_mask = prompt_attention_mask.to(device)
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
prompt_embeds = prompt_embeds[0][:, select_index]
prompt_embeds = prompt_embeds[:, select_index]
prompt_attention_mask = prompt_attention_mask[:, select_index]
if self.transformer is not None:
dtype = self.transformer.dtype
elif self.text_encoder is not None:
dtype = self.text_encoder.dtype
else:
dtype = None
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -366,25 +406,15 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_attention_mask=True,
add_special_tokens=True,
return_tensors="pt",
negative_prompt = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds(
prompt=negative_prompt,
device=device,
dtype=dtype,
clean_caption=clean_caption,
max_sequence_length=max_sequence_length,
complex_human_instruction=False,
)
negative_prompt_attention_mask = uncond_input.attention_mask
negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)
negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask
)
negative_prompt_embeds = negative_prompt_embeds[0]
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method

View File

@@ -0,0 +1,889 @@
# Copyright 2024 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import html
import inspect
import re
import urllib.parse as ul
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PixArtImageProcessor
from ...loaders import SanaLoraLoaderMixin
from ...models import AutoencoderDC, SanaTransformer2DModel
from ...schedulers import DPMSolverMultistepScheduler
from ...utils import (
BACKENDS_MAPPING,
USE_PEFT_BACKEND,
is_bs4_available,
is_ftfy_available,
is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..pixart_alpha.pipeline_pixart_alpha import ASPECT_RATIO_1024_BIN
from .pipeline_output import SanaPipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
if is_bs4_available():
from bs4 import BeautifulSoup
if is_ftfy_available():
import ftfy
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import SanaSprintPipeline
>>> pipe = SanaSprintPipeline.from_pretrained(
... "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers", torch_dtype=torch.bfloat16
... )
>>> pipe.to("cuda")
>>> image = pipe(prompt="a tiny astronaut hatching from an egg on the moon")[0]
>>> image[0].save("output.png")
```
"""
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
r"""
Pipeline for text-to-image generation using [SANA-Sprint](https://huggingface.co/papers/2503.09641).
"""
# fmt: off
bad_punct_regex = re.compile(r"[" + "#®•©™&@·º½¾¿¡§~" + r"\)" + r"\(" + r"\]" + r"\[" + r"\}" + r"\{" + r"\|" + "\\" + r"\/" + r"\*" + r"]{1,}")
# fmt: on
model_cpu_offload_seq = "text_encoder->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds"]
def __init__(
self,
tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
text_encoder: Gemma2PreTrainedModel,
vae: AutoencoderDC,
transformer: SanaTransformer2DModel,
scheduler: DPMSolverMultistepScheduler,
):
super().__init__()
self.register_modules(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
self.vae_scale_factor = (
2 ** (len(self.vae.config.encoder_block_out_channels) - 1)
if hasattr(self, "vae") and self.vae is not None
else 32
)
self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
def enable_vae_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.vae.enable_slicing()
def disable_vae_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_slicing()
def enable_vae_tiling(self):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.vae.enable_tiling()
def disable_vae_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_tiling()
# Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline._get_gemma_prompt_embeds
def _get_gemma_prompt_embeds(
self,
prompt: Union[str, List[str]],
device: torch.device,
dtype: torch.dtype,
clean_caption: bool = False,
max_sequence_length: int = 300,
complex_human_instruction: Optional[List[str]] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`, *optional*):
torch device to place the resulting embeddings on
clean_caption (`bool`, defaults to `False`):
If `True`, the function will preprocess and clean the provided caption before encoding.
max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`):
If `complex_human_instruction` is not empty, the function will use the complex Human instruction for
the prompt.
"""
prompt = [prompt] if isinstance(prompt, str) else prompt
if getattr(self, "tokenizer", None) is not None:
self.tokenizer.padding_side = "right"
prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
# prepare complex human instruction
if not complex_human_instruction:
max_length_all = max_sequence_length
else:
chi_prompt = "\n".join(complex_human_instruction)
prompt = [chi_prompt + p for p in prompt]
num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
max_length_all = num_chi_prompt_tokens + max_sequence_length - 2
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_length_all,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
prompt_attention_mask = text_inputs.attention_mask
prompt_attention_mask = prompt_attention_mask.to(device)
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
prompt_embeds = prompt_embeds[0].to(dtype=dtype, device=device)
return prompt_embeds, prompt_attention_mask
def encode_prompt(
self,
prompt: Union[str, List[str]],
num_images_per_prompt: int = 1,
device: Optional[torch.device] = None,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
clean_caption: bool = False,
max_sequence_length: int = 300,
complex_human_instruction: Optional[List[str]] = None,
lora_scale: Optional[float] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
num_images_per_prompt (`int`, *optional*, defaults to 1):
number of images that should be generated per prompt
device: (`torch.device`, *optional*):
torch device to place the resulting embeddings on
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
clean_caption (`bool`, defaults to `False`):
If `True`, the function will preprocess and clean the provided caption before encoding.
max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`):
If `complex_human_instruction` is not empty, the function will use the complex Human instruction for
the prompt.
"""
if device is None:
device = self._execution_device
if self.transformer is not None:
dtype = self.transformer.dtype
elif self.text_encoder is not None:
dtype = self.text_encoder.dtype
else:
dtype = None
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if self.text_encoder is not None and USE_PEFT_BACKEND:
scale_lora_layers(self.text_encoder, lora_scale)
if getattr(self, "tokenizer", None) is not None:
self.tokenizer.padding_side = "right"
# See Section 3.1. of the paper.
max_length = max_sequence_length
select_index = [0] + list(range(-max_length + 1, 0))
if prompt_embeds is None:
prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
prompt=prompt,
device=device,
dtype=dtype,
clean_caption=clean_caption,
max_sequence_length=max_sequence_length,
complex_human_instruction=complex_human_instruction,
)
prompt_embeds = prompt_embeds[:, select_index]
prompt_attention_mask = prompt_attention_mask[:, select_index]
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
if self.text_encoder is not None:
if isinstance(self, SanaLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder, lora_scale)
return prompt_embeds, prompt_attention_mask
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def check_inputs(
self,
prompt,
height,
width,
num_inference_steps,
timesteps,
max_timesteps,
intermediate_timesteps,
callback_on_step_end_tensor_inputs=None,
prompt_embeds=None,
prompt_attention_mask=None,
):
if height % 32 != 0 or width % 32 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if prompt_embeds is not None and prompt_attention_mask is None:
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
if timesteps is not None and len(timesteps) != num_inference_steps + 1:
raise ValueError("If providing custom timesteps, `timesteps` must be of length `num_inference_steps + 1`.")
if timesteps is not None and max_timesteps is not None:
raise ValueError("If providing custom timesteps, `max_timesteps` should not be provided.")
if timesteps is None and max_timesteps is None:
raise ValueError("Should provide either `timesteps` or `max_timesteps`.")
if intermediate_timesteps is not None and num_inference_steps != 2:
raise ValueError("Intermediate timesteps for SCM is not supported when num_inference_steps != 2.")
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
def _text_preprocessing(self, text, clean_caption=False):
if clean_caption and not is_bs4_available():
logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
logger.warning("Setting `clean_caption` to False...")
clean_caption = False
if clean_caption and not is_ftfy_available():
logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
logger.warning("Setting `clean_caption` to False...")
clean_caption = False
if not isinstance(text, (tuple, list)):
text = [text]
def process(text: str):
if clean_caption:
text = self._clean_caption(text)
text = self._clean_caption(text)
else:
text = text.lower().strip()
return text
return [process(t) for t in text]
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
def _clean_caption(self, caption):
caption = str(caption)
caption = ul.unquote_plus(caption)
caption = caption.strip().lower()
caption = re.sub("<person>", "person", caption)
# urls:
caption = re.sub(
r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
"",
caption,
) # regex for urls
caption = re.sub(
r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
"",
caption,
) # regex for urls
# html:
caption = BeautifulSoup(caption, features="html.parser").text
# @<nickname>
caption = re.sub(r"@[\w\d]+\b", "", caption)
# 31C0—31EF CJK Strokes
# 31F0—31FF Katakana Phonetic Extensions
# 3200—32FF Enclosed CJK Letters and Months
# 3300—33FF CJK Compatibility
# 3400—4DBF CJK Unified Ideographs Extension A
# 4DC0—4DFF Yijing Hexagram Symbols
# 4E00—9FFF CJK Unified Ideographs
caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
#######################################################
# все виды тире / all types of dash --> "-"
caption = re.sub(
r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
"-",
caption,
)
# кавычки к одному стандарту
caption = re.sub(r"[`´«»“”¨]", '"', caption)
caption = re.sub(r"[]", "'", caption)
# &quot;
caption = re.sub(r"&quot;?", "", caption)
# &amp
caption = re.sub(r"&amp", "", caption)
# ip adresses:
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
# article ids:
caption = re.sub(r"\d:\d\d\s+$", "", caption)
# \n
caption = re.sub(r"\\n", " ", caption)
# "#123"
caption = re.sub(r"#\d{1,3}\b", "", caption)
# "#12345.."
caption = re.sub(r"#\d{5,}\b", "", caption)
# "123456.."
caption = re.sub(r"\b\d{6,}\b", "", caption)
# filenames:
caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
#
caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
# this-is-my-cute-cat / this_is_my_cute_cat
regex2 = re.compile(r"(?:\-|\_)")
if len(re.findall(regex2, caption)) > 3:
caption = re.sub(regex2, " ", caption)
caption = ftfy.fix_text(caption)
caption = html.unescape(html.unescape(caption))
caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
caption = re.sub(r"\bpage\s+\d+\b", "", caption)
caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
caption = re.sub(r"\b\s+\:\s+", r": ", caption)
caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
caption = re.sub(r"\s+", " ", caption)
caption.strip()
caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
caption = re.sub(r"^\.\S+$", "", caption)
return caption.strip()
# Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
if latents is not None:
return latents.to(device=device, dtype=dtype)
shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
return latents
@property
def guidance_scale(self):
return self._guidance_scale
@property
def attention_kwargs(self):
return self._attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
num_inference_steps: int = 2,
timesteps: List[int] = None,
max_timesteps: float = 1.57080,
intermediate_timesteps: float = 1.3,
guidance_scale: float = 4.5,
num_images_per_prompt: Optional[int] = 1,
height: int = 1024,
width: int = 1024,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
clean_caption: bool = False,
use_resolution_binning: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 300,
complex_human_instruction: List[str] = [
"Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:",
"- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.",
"- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.",
"Here are examples of how to transform or refine prompts:",
"- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.",
"- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.",
"Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:",
"User Prompt: ",
],
) -> Union[SanaPipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
num_inference_steps (`int`, *optional*, defaults to 20):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
max_timesteps (`float`, *optional*, defaults to 1.57080):
The maximum timestep value used in the SCM scheduler.
intermediate_timesteps (`float`, *optional*, defaults to 1.3):
The intermediate timestep value used in SCM scheduler (only used when num_inference_steps=2).
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 4.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
height (`int`, *optional*, defaults to self.unet.config.sample_size):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size):
The width in pixels of the generated image.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
attention_kwargs:
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
clean_caption (`bool`, *optional*, defaults to `True`):
Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
be installed. If the dependencies are not installed, the embeddings will be created from the raw
prompt.
use_resolution_binning (`bool` defaults to `True`):
If set to `True`, the requested height and width are first mapped to the closest resolutions using
`ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to
the requested resolution. Useful for generating non-square images.
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to `300`):
Maximum sequence length to use with the `prompt`.
complex_human_instruction (`List[str]`, *optional*):
Instructions for complex human attention:
https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55.
Examples:
Returns:
[`~pipelines.sana.pipeline_output.SanaPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] is returned,
otherwise a `tuple` is returned where the first element is a list with the generated images
"""
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
# 1. Check inputs. Raise error if not correct
if use_resolution_binning:
if self.transformer.config.sample_size == 32:
aspect_ratio_bin = ASPECT_RATIO_1024_BIN
else:
raise ValueError("Invalid sample size")
orig_height, orig_width = height, width
height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
self.check_inputs(
prompt=prompt,
height=height,
width=width,
num_inference_steps=num_inference_steps,
timesteps=timesteps,
max_timesteps=max_timesteps,
intermediate_timesteps=intermediate_timesteps,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
prompt_embeds=prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
)
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
self._interrupt = False
# 2. Default height and width to transformer
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None
# 3. Encode input prompt
(
prompt_embeds,
prompt_attention_mask,
) = self.encode_prompt(
prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
prompt_embeds=prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
clean_caption=clean_caption,
max_sequence_length=max_sequence_length,
complex_human_instruction=complex_human_instruction,
lora_scale=lora_scale,
)
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
timesteps,
sigmas=None,
max_timesteps=max_timesteps,
intermediate_timesteps=intermediate_timesteps,
)
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(0)
# 5. Prepare latents.
latent_channels = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
latent_channels,
height,
width,
torch.float32,
device,
generator,
latents,
)
latents = latents * self.scheduler.config.sigma_data
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
guidance = guidance.expand(latents.shape[0]).to(prompt_embeds.dtype)
guidance = guidance * self.transformer.config.guidance_embeds_scale
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Denoising loop
timesteps = timesteps[:-1]
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(prompt_embeds.dtype)
latents_model_input = latents / self.scheduler.config.sigma_data
scm_timestep = torch.sin(timestep) / (torch.cos(timestep) + torch.sin(timestep))
scm_timestep_expanded = scm_timestep.view(-1, 1, 1, 1)
latent_model_input = latents_model_input * torch.sqrt(
scm_timestep_expanded**2 + (1 - scm_timestep_expanded) ** 2
)
latent_model_input = latent_model_input.to(prompt_embeds.dtype)
# predict noise model_output
noise_pred = self.transformer(
latent_model_input,
encoder_hidden_states=prompt_embeds,
encoder_attention_mask=prompt_attention_mask,
guidance=guidance,
timestep=scm_timestep,
return_dict=False,
attention_kwargs=self.attention_kwargs,
)[0]
noise_pred = (
(1 - 2 * scm_timestep_expanded) * latent_model_input
+ (1 - 2 * scm_timestep_expanded + 2 * scm_timestep_expanded**2) * noise_pred
) / torch.sqrt(scm_timestep_expanded**2 + (1 - scm_timestep_expanded) ** 2)
noise_pred = noise_pred.float() * self.scheduler.config.sigma_data
# compute previous image: x_t -> x_t-1
latents, denoised = self.scheduler.step(
noise_pred, timestep, latents, **extra_step_kwargs, return_dict=False
)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
latents = denoised / self.scheduler.config.sigma_data
if output_type == "latent":
image = latents
else:
latents = latents.to(self.vae.dtype)
try:
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
except torch.cuda.OutOfMemoryError as e:
warnings.warn(
f"{e}. \n"
f"Try to use VAE tiling for large images. For example: \n"
f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)"
)
if use_resolution_binning:
image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
if not output_type == "latent":
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return SanaPipelineOutput(images=image)

View File

@@ -68,6 +68,7 @@ else:
_import_structure["scheduling_pndm"] = ["PNDMScheduler"]
_import_structure["scheduling_repaint"] = ["RePaintScheduler"]
_import_structure["scheduling_sasolver"] = ["SASolverScheduler"]
_import_structure["scheduling_scm"] = ["SCMScheduler"]
_import_structure["scheduling_sde_ve"] = ["ScoreSdeVeScheduler"]
_import_structure["scheduling_tcd"] = ["TCDScheduler"]
_import_structure["scheduling_unclip"] = ["UnCLIPScheduler"]
@@ -168,13 +169,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .scheduling_pndm import PNDMScheduler
from .scheduling_repaint import RePaintScheduler
from .scheduling_sasolver import SASolverScheduler
from .scheduling_scm import SCMScheduler
from .scheduling_sde_ve import ScoreSdeVeScheduler
from .scheduling_tcd import TCDScheduler
from .scheduling_unclip import UnCLIPScheduler
from .scheduling_unipc_multistep import UniPCMultistepScheduler
from .scheduling_utils import AysSchedules, KarrasDiffusionSchedulers, SchedulerMixin
from .scheduling_vq_diffusion import VQDiffusionScheduler
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()

View File

@@ -0,0 +1,265 @@
# # Copyright 2024 Sana-Sprint Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
# and https://github.com/hojonathanho/diffusion
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..schedulers.scheduling_utils import SchedulerMixin
from ..utils import BaseOutput, logging
from ..utils.torch_utils import randn_tensor
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->SCM
class SCMSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's `step` function output.
Args:
prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
`pred_original_sample` can be used to preview progress or for guidance.
"""
prev_sample: torch.Tensor
pred_original_sample: Optional[torch.Tensor] = None
class SCMScheduler(SchedulerMixin, ConfigMixin):
"""
`SCMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
non-Markovian guidance. This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass
documentation for the generic methods the library implements for all schedulers such as loading and saving.
Args:
num_train_timesteps (`int`, defaults to 1000):
The number of diffusion steps to train the model.
prediction_type (`str`, defaults to `trigflow`):
Prediction type of the scheduler function. Currently only supports "trigflow".
sigma_data (`float`, defaults to 0.5):
The standard deviation of the noise added during multi-step inference.
"""
# _compatibles = [e.name for e in KarrasDiffusionSchedulers]
order = 1
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
prediction_type: str = "trigflow",
sigma_data: float = 0.5,
):
"""
Initialize the SCM scheduler.
Args:
num_train_timesteps (`int`, defaults to 1000):
The number of diffusion steps to train the model.
prediction_type (`str`, defaults to `trigflow`):
Prediction type of the scheduler function. Currently only supports "trigflow".
sigma_data (`float`, defaults to 0.5):
The standard deviation of the noise added during multi-step inference.
"""
# standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0
# setable values
self.num_inference_steps = None
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
self._step_index = None
self._begin_index = None
@property
def step_index(self):
return self._step_index
@property
def begin_index(self):
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
def set_timesteps(
self,
num_inference_steps: int,
timesteps: torch.Tensor = None,
device: Union[str, torch.device] = None,
max_timesteps: float = 1.57080,
intermediate_timesteps: float = 1.3,
):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args:
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model.
timesteps (`torch.Tensor`, *optional*):
Custom timesteps to use for the denoising process.
max_timesteps (`float`, defaults to 1.57080):
The maximum timestep value used in the SCM scheduler.
intermediate_timesteps (`float`, *optional*, defaults to 1.3):
The intermediate timestep value used in SCM scheduler (only used when num_inference_steps=2).
"""
if num_inference_steps > self.config.num_train_timesteps:
raise ValueError(
f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
f" maximal {self.config.num_train_timesteps} timesteps."
)
if timesteps is not None and len(timesteps) != num_inference_steps + 1:
raise ValueError("If providing custom timesteps, `timesteps` must be of length `num_inference_steps + 1`.")
if timesteps is not None and max_timesteps is not None:
raise ValueError("If providing custom timesteps, `max_timesteps` should not be provided.")
if timesteps is None and max_timesteps is None:
raise ValueError("Should provide either `timesteps` or `max_timesteps`.")
if intermediate_timesteps is not None and num_inference_steps != 2:
raise ValueError("Intermediate timesteps for SCM is not supported when num_inference_steps != 2.")
self.num_inference_steps = num_inference_steps
if timesteps is not None:
if isinstance(timesteps, list):
self.timesteps = torch.tensor(timesteps, device=device).float()
elif isinstance(timesteps, torch.Tensor):
self.timesteps = timesteps.to(device).float()
else:
raise ValueError(f"Unsupported timesteps type: {type(timesteps)}")
elif intermediate_timesteps is not None:
self.timesteps = torch.tensor([max_timesteps, intermediate_timesteps, 0], device=device).float()
else:
# max_timesteps=arctan(80/0.5)=1.56454 is the default from sCM paper, we choose a different value here
self.timesteps = torch.linspace(max_timesteps, 0, num_inference_steps + 1, device=device).float()
print(f"Set timesteps: {self.timesteps}")
self._step_index = None
self._begin_index = None
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def _init_step_index(self, timestep):
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
indices = (schedule_timesteps == timestep).nonzero()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
pos = 1 if len(indices) > 1 else 0
return indices[pos].item()
def step(
self,
model_output: torch.FloatTensor,
timestep: float,
sample: torch.FloatTensor,
generator: torch.Generator = None,
return_dict: bool = True,
) -> Union[SCMSchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor`):
The direct output from learned diffusion model.
timestep (`float`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~schedulers.scheduling_scm.SCMSchedulerOutput`] or `tuple`.
Returns:
[`~schedulers.scheduling_utils.SCMSchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_scm.SCMSchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
"""
if self.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
if self.step_index is None:
self._init_step_index(timestep)
# 2. compute alphas, betas
t = self.timesteps[self.step_index + 1]
s = self.timesteps[self.step_index]
# 4. Different Parameterization:
parameterization = self.config.prediction_type
if parameterization == "trigflow":
pred_x0 = torch.cos(s) * sample - torch.sin(s) * model_output
else:
raise ValueError(f"Unsupported parameterization: {parameterization}")
# 5. Sample z ~ N(0, I), For MultiStep Inference
# Noise is not used for one-step sampling.
if len(self.timesteps) > 1:
noise = (
randn_tensor(model_output.shape, device=model_output.device, generator=generator)
* self.config.sigma_data
)
prev_sample = torch.cos(t) * pred_x0 + torch.sin(t) * noise
else:
prev_sample = pred_x0
self._step_index += 1
if not return_dict:
return (prev_sample, pred_x0)
return SCMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_x0)
def __len__(self):
return self.config.num_train_timesteps

View File

@@ -1834,6 +1834,21 @@ class SchedulerMixin(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class SCMScheduler(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class ScoreSdeVeScheduler(metaclass=DummyObject):
_backends = ["torch"]

View File

@@ -1502,6 +1502,21 @@ class SanaPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class SanaSprintPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class SemanticStableDiffusionPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

@@ -0,0 +1,302 @@
# Copyright 2024 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import unittest
import numpy as np
import torch
from transformers import Gemma2Config, Gemma2Model, GemmaTokenizer
from diffusers import AutoencoderDC, SanaSprintPipeline, SanaTransformer2DModel, SCMScheduler
from diffusers.utils.testing_utils import (
enable_full_determinism,
torch_device,
)
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
enable_full_determinism()
class SanaSprintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = SanaSprintPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "negative_prompt", "negative_prompt_embeds"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS - {"negative_prompt"}
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS - {"negative_prompt"}
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
required_optional_params = frozenset(
[
"num_inference_steps",
"generator",
"latents",
"return_dict",
"callback_on_step_end",
"callback_on_step_end_tensor_inputs",
]
)
test_xformers_attention = False
test_layerwise_casting = True
test_group_offloading = True
def get_dummy_components(self):
torch.manual_seed(0)
transformer = SanaTransformer2DModel(
patch_size=1,
in_channels=4,
out_channels=4,
num_layers=1,
num_attention_heads=2,
attention_head_dim=4,
num_cross_attention_heads=2,
cross_attention_head_dim=4,
cross_attention_dim=8,
caption_channels=8,
sample_size=32,
qk_norm="rms_norm_across_heads",
guidance_embeds=True,
)
torch.manual_seed(0)
vae = AutoencoderDC(
in_channels=3,
latent_channels=4,
attention_head_dim=2,
encoder_block_types=(
"ResBlock",
"EfficientViTBlock",
),
decoder_block_types=(
"ResBlock",
"EfficientViTBlock",
),
encoder_block_out_channels=(8, 8),
decoder_block_out_channels=(8, 8),
encoder_qkv_multiscales=((), (5,)),
decoder_qkv_multiscales=((), (5,)),
encoder_layers_per_block=(1, 1),
decoder_layers_per_block=[1, 1],
downsample_block_type="conv",
upsample_block_type="interpolate",
decoder_norm_types="rms_norm",
decoder_act_fns="silu",
scaling_factor=0.41407,
)
torch.manual_seed(0)
scheduler = SCMScheduler()
torch.manual_seed(0)
text_encoder_config = Gemma2Config(
head_dim=16,
hidden_size=8,
initializer_range=0.02,
intermediate_size=64,
max_position_embeddings=8192,
model_type="gemma2",
num_attention_heads=2,
num_hidden_layers=1,
num_key_value_heads=2,
vocab_size=8,
attn_implementation="eager",
)
text_encoder = Gemma2Model(text_encoder_config)
tokenizer = GemmaTokenizer.from_pretrained("hf-internal-testing/dummy-gemma")
components = {
"transformer": transformer,
"vae": vae,
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 6.0,
"height": 32,
"width": 32,
"max_sequence_length": 16,
"output_type": "pt",
"complex_human_instruction": None,
}
return inputs
def test_inference(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs)[0]
generated_image = image[0]
self.assertEqual(generated_image.shape, (3, 32, 32))
expected_image = torch.randn(3, 32, 32)
max_diff = np.abs(generated_image - expected_image).max()
self.assertLessEqual(max_diff, 1e10)
def test_callback_inputs(self):
sig = inspect.signature(self.pipeline_class.__call__)
has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
has_callback_step_end = "callback_on_step_end" in sig.parameters
if not (has_callback_tensor_inputs and has_callback_step_end):
return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
self.assertTrue(
hasattr(pipe, "_callback_tensor_inputs"),
f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
)
def callback_inputs_subset(pipe, i, t, callback_kwargs):
# iterate over callback args
for tensor_name, tensor_value in callback_kwargs.items():
# check that we're only passing in allowed tensor inputs
assert tensor_name in pipe._callback_tensor_inputs
return callback_kwargs
def callback_inputs_all(pipe, i, t, callback_kwargs):
for tensor_name in pipe._callback_tensor_inputs:
assert tensor_name in callback_kwargs
# iterate over callback args
for tensor_name, tensor_value in callback_kwargs.items():
# check that we're only passing in allowed tensor inputs
assert tensor_name in pipe._callback_tensor_inputs
return callback_kwargs
inputs = self.get_dummy_inputs(torch_device)
# Test passing in a subset
inputs["callback_on_step_end"] = callback_inputs_subset
inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
output = pipe(**inputs)[0]
# Test passing in a everything
inputs["callback_on_step_end"] = callback_inputs_all
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
output = pipe(**inputs)[0]
def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
is_last = i == (pipe.num_timesteps - 1)
if is_last:
callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
return callback_kwargs
inputs["callback_on_step_end"] = callback_inputs_change_tensor
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
output = pipe(**inputs)[0]
assert output.abs().sum() < 1e10
def test_attention_slicing_forward_pass(
self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
):
if not self.test_attention_slicing:
return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
generator_device = "cpu"
inputs = self.get_dummy_inputs(generator_device)
output_without_slicing = pipe(**inputs)[0]
pipe.enable_attention_slicing(slice_size=1)
inputs = self.get_dummy_inputs(generator_device)
output_with_slicing1 = pipe(**inputs)[0]
pipe.enable_attention_slicing(slice_size=2)
inputs = self.get_dummy_inputs(generator_device)
output_with_slicing2 = pipe(**inputs)[0]
if test_max_difference:
max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
self.assertLess(
max(max_diff1, max_diff2),
expected_max_diff,
"Attention slicing should not affect the inference results",
)
def test_vae_tiling(self, expected_diff_max: float = 0.2):
generator_device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to("cpu")
pipe.set_progress_bar_config(disable=None)
# Without tiling
inputs = self.get_dummy_inputs(generator_device)
inputs["height"] = inputs["width"] = 128
output_without_tiling = pipe(**inputs)[0]
# With tiling
pipe.vae.enable_tiling(
tile_sample_min_height=96,
tile_sample_min_width=96,
tile_sample_stride_height=64,
tile_sample_stride_width=64,
)
inputs = self.get_dummy_inputs(generator_device)
inputs["height"] = inputs["width"] = 128
output_with_tiling = pipe(**inputs)[0]
self.assertLess(
(to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
expected_diff_max,
"VAE tiling should not affect the inference results",
)
# TODO(aryan): Create a dummy gemma model with smol vocab size
@unittest.skip(
"A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error."
)
def test_inference_batch_consistent(self):
pass
@unittest.skip(
"A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error."
)
def test_inference_batch_single_identical(self):
pass
def test_float16_inference(self):
# Requires higher tolerance as model seems very sensitive to dtype
super().test_float16_inference(expected_max_diff=0.08)