mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-10 22:44:38 +08:00)

Compare commits: custom-blo... to Photoroom-... (50 commits)
@@ -544,6 +544,8 @@
  title: PAG
- local: api/pipelines/paint_by_example
  title: Paint by Example
- local: api/pipelines/photon
  title: Photon
- local: api/pipelines/pixart
  title: PixArt-α
- local: api/pipelines/pixart_sigma

131
docs/source/en/api/pipelines/photon.md
Normal file
@@ -0,0 +1,131 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->

# Photon

Photon generates high-quality images from text using a simplified MMDiT architecture in which text tokens are not updated through the transformer blocks. It uses flow matching with discrete scheduling for efficient sampling, and Google's T5Gemma-2B-2B-UL2 model for multilingual text encoding. The ~1.3B-parameter transformer delivers fast inference without sacrificing quality. You can choose between the Flux VAE (8x compression, 16 latent channels) for balanced quality and speed, or DC-AE (32x compression, 32 latent channels) for more aggressive latent compression and faster processing.

## Available models

Photon offers multiple variants with different VAE configurations, each optimized for specific resolutions. Base models excel with detailed prompts, capturing complex compositions and subtle details. Fine-tuned models trained on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) improve aesthetic quality, especially with simpler prompts.

| Model | Resolution | Fine-tuned | Distilled | Description | Suggested prompts | Suggested parameters | Recommended dtype |
|:-----:|:----------:|:----------:|:---------:|:-----------:|:-----------------:|:--------------------:|:-----------------:|
| [`Photoroom/photon-256-t2i`](https://huggingface.co/Photoroom/photon-256-t2i) | 256 | No | No | Base model pre-trained at 256 with the Flux VAE | Works best with detailed prompts in natural language | 28 steps, cfg=5.0 | `torch.bfloat16` |
| [`Photoroom/photon-256-t2i-sft`](https://huggingface.co/Photoroom/photon-256-t2i-sft) | 256 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) with the Flux VAE | Can handle less detailed prompts | 28 steps, cfg=5.0 | `torch.bfloat16` |
| [`Photoroom/photon-512-t2i`](https://huggingface.co/Photoroom/photon-512-t2i) | 512 | No | No | Base model pre-trained at 512 with the Flux VAE | Works best with detailed prompts in natural language | 28 steps, cfg=5.0 | `torch.bfloat16` |
| [`Photoroom/photon-512-t2i-sft`](https://huggingface.co/Photoroom/photon-512-t2i-sft) | 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) with the Flux VAE | Can handle less detailed prompts in natural language | 28 steps, cfg=5.0 | `torch.bfloat16` |
| [`Photoroom/photon-512-t2i-sft-distilled`](https://huggingface.co/Photoroom/photon-512-t2i-sft-distilled) | 512 | Yes | Yes | 8-step distilled model from [`Photoroom/photon-512-t2i-sft`](https://huggingface.co/Photoroom/photon-512-t2i-sft) | Can handle less detailed prompts in natural language | 8 steps, cfg=1.0 | `torch.bfloat16` |
| [`Photoroom/photon-512-t2i-dc-ae`](https://huggingface.co/Photoroom/photon-512-t2i-dc-ae) | 512 | No | No | Base model pre-trained at 512 with the [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae) | Works best with detailed prompts in natural language | 28 steps, cfg=5.0 | `torch.bfloat16` |
| [`Photoroom/photon-512-t2i-dc-ae-sft`](https://huggingface.co/Photoroom/photon-512-t2i-dc-ae-sft) | 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) with the [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae) | Can handle less detailed prompts in natural language | 28 steps, cfg=5.0 | `torch.bfloat16` |
| [`Photoroom/photon-512-t2i-dc-ae-sft-distilled`](https://huggingface.co/Photoroom/photon-512-t2i-dc-ae-sft-distilled) | 512 | Yes | Yes | 8-step distilled model from [`Photoroom/photon-512-t2i-dc-ae-sft`](https://huggingface.co/Photoroom/photon-512-t2i-dc-ae-sft) | Can handle less detailed prompts in natural language | 8 steps, cfg=1.0 | `torch.bfloat16` |

Refer to [this collection](https://huggingface.co/collections/Photoroom/photon-models-68e66254c202ebfab99ad38e) for more information on the available checkpoints.

## Loading the pipeline

Load the pipeline with [`~DiffusionPipeline.from_pretrained`].

```py
import torch

from diffusers.pipelines.photon import PhotonPipeline

# Load the pipeline - the VAE and text encoder are downloaded from the Hugging Face Hub
pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i-sft", torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "A front-facing portrait of a lion in the golden savanna at sunset."
image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0]
image.save("photon_output.png")
```

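The distilled checkpoints are meant to be sampled with far fewer steps and with guidance disabled, per the table above. A minimal sketch using the same API:

```py
import torch

from diffusers.pipelines.photon import PhotonPipeline

# Distilled variants: 8 steps, guidance disabled (cfg=1.0)
pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i-sft-distilled", torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "A cozy wooden cabin in a snowy forest at dusk"
image = pipe(prompt, num_inference_steps=8, guidance_scale=1.0).images[0]
image.save("photon_distilled_output.png")
```
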
### Manual Component Loading

Load components individually to customize the pipeline, for instance to use quantized models.

```py
import torch

from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from diffusers.models import AutoencoderDC, AutoencoderKL
from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel
from diffusers.pipelines.photon import PhotonPipeline
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
from transformers import GemmaTokenizerFast, T5GemmaModel

# Use the diffusers config for diffusers models and the transformers config for the text encoder
diffusers_quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
transformers_quant_config = TransformersBitsAndBytesConfig(load_in_8bit=True)

# Load transformer
transformer = PhotonTransformer2DModel.from_pretrained(
    "checkpoints/photon-512-t2i-sft",
    subfolder="transformer",
    quantization_config=diffusers_quant_config,
    torch_dtype=torch.bfloat16,
)

# Load scheduler
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
    "checkpoints/photon-512-t2i-sft", subfolder="scheduler"
)

# Load T5Gemma text encoder (only the encoder is needed)
t5gemma_model = T5GemmaModel.from_pretrained(
    "google/t5gemma-2b-2b-ul2",
    quantization_config=transformers_quant_config,
    torch_dtype=torch.bfloat16,
)
text_encoder = t5gemma_model.encoder
tokenizer = GemmaTokenizerFast.from_pretrained("google/t5gemma-2b-2b-ul2")
tokenizer.model_max_length = 256

# Load VAE - choose either the Flux VAE or DC-AE
# Flux VAE
vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="vae",
    quantization_config=diffusers_quant_config,
    torch_dtype=torch.bfloat16,
)

pipe = PhotonPipeline(
    transformer=transformer,
    scheduler=scheduler,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    vae=vae,
)
pipe.to("cuda")
```

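The snippet above uses the Flux VAE. For the `*-dc-ae` checkpoints, the DC-AE autoencoder can be swapped in instead; a short sketch, using the same DC-AE weights that the conversion script in this PR downloads:

```py
import torch
from diffusers.models import AutoencoderDC

# DC-AE (32x compression, 32 latent channels) for the *-dc-ae checkpoints
vae = AutoencoderDC.from_pretrained(
    "mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers",
    torch_dtype=torch.bfloat16,
)
```
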
## Memory Optimization

For memory-constrained environments:

```py
import torch

from diffusers.pipelines.photon import PhotonPipeline

pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i-sft", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()  # Offload components to CPU when not in use

# Or use sequential CPU offload for even lower memory:
# pipe.enable_sequential_cpu_offload()
```

## PhotonPipeline

[[autodoc]] PhotonPipeline
  - all
  - __call__

## PhotonPipelineOutput

[[autodoc]] pipelines.photon.pipeline_output.PhotonPipelineOutput

345
scripts/convert_photon_to_diffusers.py
Normal file
@@ -0,0 +1,345 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Script to convert Photon checkpoint from original codebase to diffusers format.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import torch
|
||||
from safetensors.torch import save_file
|
||||
|
||||
from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel
|
||||
from diffusers.pipelines.photon import PhotonPipeline
|
||||
|
||||
|
||||
DEFAULT_RESOLUTION = 512
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PhotonBase:
|
||||
context_in_dim: int = 2304
|
||||
hidden_size: int = 1792
|
||||
mlp_ratio: float = 3.5
|
||||
num_heads: int = 28
|
||||
depth: int = 16
|
||||
axes_dim: Tuple[int, int] = (32, 32)
|
||||
theta: int = 10_000
|
||||
time_factor: float = 1000.0
|
||||
time_max_period: int = 10_000
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PhotonFlux(PhotonBase):
|
||||
in_channels: int = 16
|
||||
patch_size: int = 2
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PhotonDCAE(PhotonBase):
|
||||
in_channels: int = 32
|
||||
patch_size: int = 1
|
||||
|
||||
|
||||
def build_config(vae_type: str) -> dict:
|
||||
if vae_type == "flux":
|
||||
cfg = PhotonFlux()
|
||||
elif vae_type == "dc-ae":
|
||||
cfg = PhotonDCAE()
|
||||
else:
|
||||
raise ValueError(f"Unsupported VAE type: {vae_type}. Use 'flux' or 'dc-ae'")
|
||||
|
||||
config_dict = asdict(cfg)
|
||||
config_dict["axes_dim"] = list(config_dict["axes_dim"]) # type: ignore[index]
|
||||
return config_dict
|
||||
|
||||
|
||||
def create_parameter_mapping(depth: int) -> dict:
|
||||
"""Create mapping from old parameter names to new diffusers names."""
|
||||
|
||||
# Key mappings for structural changes
|
||||
mapping = {}
|
||||
|
||||
# Map old structure (layers in PhotonBlock) to new structure (layers in PhotonAttention)
|
||||
for i in range(depth):
|
||||
# QKV projections moved to attention module
|
||||
mapping[f"blocks.{i}.img_qkv_proj.weight"] = f"blocks.{i}.attention.img_qkv_proj.weight"
|
||||
mapping[f"blocks.{i}.txt_kv_proj.weight"] = f"blocks.{i}.attention.txt_kv_proj.weight"
|
||||
|
||||
# QK norm moved to attention module and renamed to match Attention's qk_norm structure
|
||||
mapping[f"blocks.{i}.qk_norm.query_norm.scale"] = f"blocks.{i}.attention.norm_q.weight"
|
||||
mapping[f"blocks.{i}.qk_norm.key_norm.scale"] = f"blocks.{i}.attention.norm_k.weight"
|
||||
mapping[f"blocks.{i}.qk_norm.query_norm.weight"] = f"blocks.{i}.attention.norm_q.weight"
|
||||
mapping[f"blocks.{i}.qk_norm.key_norm.weight"] = f"blocks.{i}.attention.norm_k.weight"
|
||||
|
||||
# K norm for text tokens moved to attention module
|
||||
mapping[f"blocks.{i}.k_norm.scale"] = f"blocks.{i}.attention.norm_added_k.weight"
|
||||
mapping[f"blocks.{i}.k_norm.weight"] = f"blocks.{i}.attention.norm_added_k.weight"
|
||||
|
||||
# Attention output projection
|
||||
mapping[f"blocks.{i}.attn_out.weight"] = f"blocks.{i}.attention.to_out.0.weight"
|
||||
|
||||
return mapping
|
||||
|
||||
|
||||
def convert_checkpoint_parameters(old_state_dict: Dict[str, torch.Tensor], depth: int) -> Dict[str, torch.Tensor]:
|
||||
"""Convert old checkpoint parameters to new diffusers format."""
|
||||
|
||||
print("Converting checkpoint parameters...")
|
||||
|
||||
mapping = create_parameter_mapping(depth)
|
||||
converted_state_dict = {}
|
||||
|
||||
for key, value in old_state_dict.items():
|
||||
new_key = key
|
||||
|
||||
# Apply specific mappings if needed
|
||||
if key in mapping:
|
||||
new_key = mapping[key]
|
||||
print(f" Mapped: {key} -> {new_key}")
|
||||
|
||||
converted_state_dict[new_key] = value
|
||||
|
||||
print(f"✓ Converted {len(converted_state_dict)} parameters")
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> PhotonTransformer2DModel:
|
||||
"""Create and load PhotonTransformer2DModel from old checkpoint."""
|
||||
|
||||
print(f"Loading checkpoint from: {checkpoint_path}")
|
||||
|
||||
# Load old checkpoint
|
||||
if not os.path.exists(checkpoint_path):
|
||||
raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
|
||||
|
||||
old_checkpoint = torch.load(checkpoint_path, map_location="cpu")
|
||||
|
||||
# Handle different checkpoint formats
|
||||
if isinstance(old_checkpoint, dict):
|
||||
if "model" in old_checkpoint:
|
||||
state_dict = old_checkpoint["model"]
|
||||
elif "state_dict" in old_checkpoint:
|
||||
state_dict = old_checkpoint["state_dict"]
|
||||
else:
|
||||
state_dict = old_checkpoint
|
||||
else:
|
||||
state_dict = old_checkpoint
|
||||
|
||||
print(f"✓ Loaded checkpoint with {len(state_dict)} parameters")
|
||||
|
||||
# Convert parameter names if needed
|
||||
model_depth = int(config.get("depth", 16))
|
||||
converted_state_dict = convert_checkpoint_parameters(state_dict, depth=model_depth)
|
||||
|
||||
# Create transformer with config
|
||||
print("Creating PhotonTransformer2DModel...")
|
||||
transformer = PhotonTransformer2DModel(**config)
|
||||
|
||||
# Load state dict
|
||||
print("Loading converted parameters...")
|
||||
missing_keys, unexpected_keys = transformer.load_state_dict(converted_state_dict, strict=False)
|
||||
|
||||
if missing_keys:
|
||||
print(f"⚠ Missing keys: {missing_keys}")
|
||||
if unexpected_keys:
|
||||
print(f"⚠ Unexpected keys: {unexpected_keys}")
|
||||
|
||||
if not missing_keys and not unexpected_keys:
|
||||
print("✓ All parameters loaded successfully!")
|
||||
|
||||
return transformer
|
||||
|
||||
|
||||
def create_scheduler_config(output_path: str, shift: float):
|
||||
"""Create FlowMatchEulerDiscreteScheduler config."""
|
||||
|
||||
scheduler_config = {"_class_name": "FlowMatchEulerDiscreteScheduler", "num_train_timesteps": 1000, "shift": shift}
|
||||
|
||||
scheduler_path = os.path.join(output_path, "scheduler")
|
||||
os.makedirs(scheduler_path, exist_ok=True)
|
||||
|
||||
with open(os.path.join(scheduler_path, "scheduler_config.json"), "w") as f:
|
||||
json.dump(scheduler_config, f, indent=2)
|
||||
|
||||
print("✓ Created scheduler config")
|
||||
|
||||
|
||||
def download_and_save_vae(vae_type: str, output_path: str):
|
||||
"""Download and save VAE to local directory."""
|
||||
from diffusers import AutoencoderDC, AutoencoderKL
|
||||
|
||||
vae_path = os.path.join(output_path, "vae")
|
||||
os.makedirs(vae_path, exist_ok=True)
|
||||
|
||||
if vae_type == "flux":
|
||||
print("Downloading FLUX VAE from black-forest-labs/FLUX.1-dev...")
|
||||
vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae")
|
||||
else: # dc-ae
|
||||
print("Downloading DC-AE VAE from mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers...")
|
||||
vae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers")
|
||||
|
||||
vae.save_pretrained(vae_path)
|
||||
print(f"✓ Saved VAE to {vae_path}")
|
||||
|
||||
|
||||
def download_and_save_text_encoder(output_path: str):
|
||||
"""Download and save T5Gemma text encoder and tokenizer."""
|
||||
from transformers import GemmaTokenizerFast
|
||||
from transformers.models.t5gemma.modeling_t5gemma import T5GemmaModel
|
||||
|
||||
text_encoder_path = os.path.join(output_path, "text_encoder")
|
||||
tokenizer_path = os.path.join(output_path, "tokenizer")
|
||||
os.makedirs(text_encoder_path, exist_ok=True)
|
||||
os.makedirs(tokenizer_path, exist_ok=True)
|
||||
|
||||
print("Downloading T5Gemma model from google/t5gemma-2b-2b-ul2...")
|
||||
t5gemma_model = T5GemmaModel.from_pretrained("google/t5gemma-2b-2b-ul2")
|
||||
|
||||
# Extract and save only the encoder
|
||||
t5gemma_encoder = t5gemma_model.encoder
|
||||
t5gemma_encoder.save_pretrained(text_encoder_path)
|
||||
print(f"✓ Saved T5GemmaEncoder to {text_encoder_path}")
|
||||
|
||||
print("Downloading tokenizer from google/t5gemma-2b-2b-ul2...")
|
||||
tokenizer = GemmaTokenizerFast.from_pretrained("google/t5gemma-2b-2b-ul2")
|
||||
tokenizer.model_max_length = 256
|
||||
tokenizer.save_pretrained(tokenizer_path)
|
||||
print(f"✓ Saved tokenizer to {tokenizer_path}")
|
||||
|
||||
|
||||
def create_model_index(vae_type: str, default_image_size: int, output_path: str):
|
||||
"""Create model_index.json for the pipeline."""
|
||||
|
||||
if vae_type == "flux":
|
||||
vae_class = "AutoencoderKL"
|
||||
else: # dc-ae
|
||||
vae_class = "AutoencoderDC"
|
||||
|
||||
model_index = {
|
||||
"_class_name": "PhotonPipeline",
|
||||
"_diffusers_version": "0.31.0.dev0",
|
||||
"_name_or_path": os.path.basename(output_path),
|
||||
"default_sample_size": default_image_size,
|
||||
"scheduler": ["diffusers", "FlowMatchEulerDiscreteScheduler"],
|
||||
"text_encoder": ["photon", "T5GemmaEncoder"],
|
||||
"tokenizer": ["transformers", "GemmaTokenizerFast"],
|
||||
"transformer": ["diffusers", "PhotonTransformer2DModel"],
|
||||
"vae": ["diffusers", vae_class],
|
||||
}
|
||||
|
||||
model_index_path = os.path.join(output_path, "model_index.json")
|
||||
with open(model_index_path, "w") as f:
|
||||
json.dump(model_index, f, indent=2)
|
||||
|
||||
|
||||
def main(args):
|
||||
# Validate inputs
|
||||
if not os.path.exists(args.checkpoint_path):
|
||||
raise FileNotFoundError(f"Checkpoint not found: {args.checkpoint_path}")
|
||||
|
||||
config = build_config(args.vae_type)
|
||||
|
||||
# Create output directory
|
||||
os.makedirs(args.output_path, exist_ok=True)
|
||||
print(f"✓ Output directory: {args.output_path}")
|
||||
|
||||
# Create transformer from checkpoint
|
||||
transformer = create_transformer_from_checkpoint(args.checkpoint_path, config)
|
||||
|
||||
# Save transformer
|
||||
transformer_path = os.path.join(args.output_path, "transformer")
|
||||
os.makedirs(transformer_path, exist_ok=True)
|
||||
|
||||
# Save config
|
||||
with open(os.path.join(transformer_path, "config.json"), "w") as f:
|
||||
json.dump(config, f, indent=2)
|
||||
|
||||
# Save model weights as safetensors
|
||||
state_dict = transformer.state_dict()
|
||||
save_file(state_dict, os.path.join(transformer_path, "diffusion_pytorch_model.safetensors"))
|
||||
print(f"✓ Saved transformer to {transformer_path}")
|
||||
|
||||
# Create scheduler config
|
||||
create_scheduler_config(args.output_path, args.shift)
|
||||
|
||||
download_and_save_vae(args.vae_type, args.output_path)
|
||||
download_and_save_text_encoder(args.output_path)
|
||||
|
||||
# Create model_index.json
|
||||
create_model_index(args.vae_type, args.resolution, args.output_path)
|
||||
|
||||
# Verify the pipeline can be loaded
|
||||
try:
|
||||
pipeline = PhotonPipeline.from_pretrained(args.output_path)
|
||||
print("Pipeline loaded successfully!")
|
||||
print(f"Transformer: {type(pipeline.transformer).__name__}")
|
||||
print(f"VAE: {type(pipeline.vae).__name__}")
|
||||
print(f"Text Encoder: {type(pipeline.text_encoder).__name__}")
|
||||
print(f"Scheduler: {type(pipeline.scheduler).__name__}")
|
||||
|
||||
# Display model info
|
||||
num_params = sum(p.numel() for p in pipeline.transformer.parameters())
|
||||
print(f"✓ Transformer parameters: {num_params:,}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Pipeline verification failed: {e}")
|
||||
return False
|
||||
|
||||
print("Conversion completed successfully!")
|
||||
print(f"Converted pipeline saved to: {args.output_path}")
|
||||
print(f"VAE type: {args.vae_type}")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Convert Photon checkpoint to diffusers format")
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint_path", type=str, required=True, help="Path to the original Photon checkpoint (.pth file )"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output_path", type=str, required=True, help="Output directory for the converted diffusers pipeline"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--vae_type",
|
||||
type=str,
|
||||
choices=["flux", "dc-ae"],
|
||||
required=True,
|
||||
help="VAE type to use: 'flux' for AutoencoderKL (16 channels) or 'dc-ae' for AutoencoderDC (32 channels)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--resolution",
|
||||
type=int,
|
||||
choices=[256, 512, 1024],
|
||||
default=DEFAULT_RESOLUTION,
|
||||
help="Target resolution for the model (256, 512, or 1024). Affects the transformer's sample_size.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--shift",
|
||||
type=float,
|
||||
default=3.0,
|
||||
help="Shift for the scheduler",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
success = main(args)
|
||||
if not success:
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f"Conversion failed: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
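Once converted, the output directory behaves like any other diffusers checkpoint and can be loaded directly. A short sketch, where `./photon-converted` is a placeholder for the `--output_path` passed to the script:

```py
import torch

from diffusers.pipelines.photon import PhotonPipeline

# "./photon-converted" is a placeholder for the conversion --output_path
pipe = PhotonPipeline.from_pretrained("./photon-converted", torch_dtype=torch.bfloat16)
pipe.to("cuda")

image = pipe("A red bicycle leaning against a brick wall", num_inference_steps=28, guidance_scale=5.0).images[0]
image.save("photon_converted_sample.png")
```
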
@@ -232,6 +232,7 @@ else:
|
||||
"MultiControlNetModel",
|
||||
"OmniGenTransformer2DModel",
|
||||
"ParallelConfig",
|
||||
"PhotonTransformer2DModel",
|
||||
"PixArtTransformer2DModel",
|
||||
"PriorTransformer",
|
||||
"QwenImageControlNetModel",
|
||||
@@ -515,6 +516,7 @@ else:
|
||||
"MusicLDMPipeline",
|
||||
"OmniGenPipeline",
|
||||
"PaintByExamplePipeline",
|
||||
"PhotonPipeline",
|
||||
"PIAPipeline",
|
||||
"PixArtAlphaPipeline",
|
||||
"PixArtSigmaPAGPipeline",
|
||||
@@ -926,6 +928,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
MultiControlNetModel,
|
||||
OmniGenTransformer2DModel,
|
||||
ParallelConfig,
|
||||
PhotonTransformer2DModel,
|
||||
PixArtTransformer2DModel,
|
||||
PriorTransformer,
|
||||
QwenImageControlNetModel,
|
||||
@@ -1179,6 +1182,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
MusicLDMPipeline,
|
||||
OmniGenPipeline,
|
||||
PaintByExamplePipeline,
|
||||
PhotonPipeline,
|
||||
PIAPipeline,
|
||||
PixArtAlphaPipeline,
|
||||
PixArtSigmaPAGPipeline,
|
||||
|
||||
@@ -96,6 +96,7 @@ if is_torch_available():
|
||||
_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_photon"] = ["PhotonTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_qwenimage"] = ["QwenImageTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"]
|
||||
@@ -190,6 +191,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LuminaNextDiT2DModel,
|
||||
MochiTransformer3DModel,
|
||||
OmniGenTransformer2DModel,
|
||||
PhotonTransformer2DModel,
|
||||
PixArtTransformer2DModel,
|
||||
PriorTransformer,
|
||||
QwenImageTransformer2DModel,
|
||||
|
||||
@@ -32,6 +32,7 @@ if is_torch_available():
|
||||
from .transformer_lumina2 import Lumina2Transformer2DModel
|
||||
from .transformer_mochi import MochiTransformer3DModel
|
||||
from .transformer_omnigen import OmniGenTransformer2DModel
|
||||
from .transformer_photon import PhotonTransformer2DModel
|
||||
from .transformer_qwenimage import QwenImageTransformer2DModel
|
||||
from .transformer_sd3 import SD3Transformer2DModel
|
||||
from .transformer_skyreels_v2 import SkyReelsV2Transformer3DModel
|
||||
|
||||
768
src/diffusers/models/transformers/transformer_photon.py
Normal file
@@ -0,0 +1,768 @@
|
||||
# Copyright 2025 The Photoroom and The HuggingFace Teams. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from torch.nn.functional import fold, unfold
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import logging
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..embeddings import get_timestep_embedding
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import RMSNorm
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def get_image_ids(batch_size: int, height: int, width: int, patch_size: int, device: torch.device) -> torch.Tensor:
|
||||
r"""
|
||||
Generates 2D patch coordinate indices for a batch of images.
|
||||
|
||||
Args:
|
||||
batch_size (`int`):
|
||||
Number of images in the batch.
|
||||
height (`int`):
|
||||
Height of the input images (in pixels).
|
||||
width (`int`):
|
||||
Width of the input images (in pixels).
|
||||
patch_size (`int`):
|
||||
Size of the square patches that the image is divided into.
|
||||
device (`torch.device`):
|
||||
The device on which to create the tensor.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
Tensor of shape `(batch_size, num_patches, 2)` containing the (row, col) coordinates of each patch in the
|
||||
image grid.
|
||||
"""
|
||||
|
||||
img_ids = torch.zeros(height // patch_size, width // patch_size, 2, device=device)
|
||||
img_ids[..., 0] = torch.arange(height // patch_size, device=device)[:, None]
|
||||
img_ids[..., 1] = torch.arange(width // patch_size, device=device)[None, :]
|
||||
return img_ids.reshape((height // patch_size) * (width // patch_size), 2).unsqueeze(0).repeat(batch_size, 1, 1)
|
||||
|
||||
|
||||
def apply_rope(xq: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
||||
r"""
|
||||
Applies rotary positional embeddings (RoPE) to a query tensor.
|
||||
|
||||
Args:
|
||||
xq (`torch.Tensor`):
|
||||
Input tensor of shape `(..., dim)` representing the queries.
|
||||
freqs_cis (`torch.Tensor`):
|
||||
Precomputed rotary frequency components of shape `(..., dim/2, 2)` containing cosine and sine pairs.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
Tensor of the same shape as `xq` with rotary embeddings applied.
|
||||
"""
|
||||
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
|
||||
# Ensure freqs_cis is on the same device as queries to avoid device mismatches with offloading
|
||||
freqs_cis = freqs_cis.to(device=xq.device, dtype=xq_.dtype)
|
||||
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
||||
return xq_out.reshape(*xq.shape).type_as(xq)
|
||||
|
||||
|
||||
class PhotonAttnProcessor2_0:
|
||||
r"""
|
||||
Processor for implementing Photon-style attention with multi-source tokens and RoPE. Supports multiple attention
|
||||
backends (Flash Attention, Sage Attention, etc.) via dispatch_attention_fn.
|
||||
"""
|
||||
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
|
||||
raise ImportError("PhotonAttnProcessor2_0 requires PyTorch 2.0, please upgrade PyTorch to 2.0.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: "PhotonAttention",
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Apply Photon attention using PhotonAttention module.
|
||||
|
||||
Args:
|
||||
attn: PhotonAttention module containing projection layers
|
||||
hidden_states: Image tokens [B, L_img, D]
|
||||
encoder_hidden_states: Text tokens [B, L_txt, D]
|
||||
attention_mask: Boolean mask for text tokens [B, L_txt]
|
||||
image_rotary_emb: Rotary positional embeddings [B, 1, L_img, head_dim//2, 2, 2]
|
||||
"""
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
raise ValueError("PhotonAttnProcessor2_0 requires 'encoder_hidden_states' containing text tokens.")
|
||||
|
||||
# Project image tokens to Q, K, V
|
||||
img_qkv = attn.img_qkv_proj(hidden_states)
|
||||
B, L_img, _ = img_qkv.shape
|
||||
img_qkv = img_qkv.reshape(B, L_img, 3, attn.heads, attn.head_dim)
|
||||
img_qkv = img_qkv.permute(2, 0, 3, 1, 4) # [3, B, H, L_img, D]
|
||||
img_q, img_k, img_v = img_qkv[0], img_qkv[1], img_qkv[2]
|
||||
|
||||
# Apply QK normalization to image tokens
|
||||
img_q = attn.norm_q(img_q)
|
||||
img_k = attn.norm_k(img_k)
|
||||
|
||||
# Project text tokens to K, V
|
||||
txt_kv = attn.txt_kv_proj(encoder_hidden_states)
|
||||
B, L_txt, _ = txt_kv.shape
|
||||
txt_kv = txt_kv.reshape(B, L_txt, 2, attn.heads, attn.head_dim)
|
||||
txt_kv = txt_kv.permute(2, 0, 3, 1, 4) # [2, B, H, L_txt, D]
|
||||
txt_k, txt_v = txt_kv[0], txt_kv[1]
|
||||
|
||||
# Apply K normalization to text tokens
|
||||
txt_k = attn.norm_added_k(txt_k)
|
||||
|
||||
# Apply RoPE to image queries and keys
|
||||
if image_rotary_emb is not None:
|
||||
img_q = apply_rope(img_q, image_rotary_emb)
|
||||
img_k = apply_rope(img_k, image_rotary_emb)
|
||||
|
||||
# Concatenate text and image keys/values
|
||||
k = torch.cat((txt_k, img_k), dim=2) # [B, H, L_txt + L_img, D]
|
||||
v = torch.cat((txt_v, img_v), dim=2) # [B, H, L_txt + L_img, D]
|
||||
|
||||
# Build attention mask if provided
|
||||
attn_mask_tensor = None
|
||||
if attention_mask is not None:
|
||||
bs, _, l_img, _ = img_q.shape
|
||||
l_txt = txt_k.shape[2]
|
||||
|
||||
if attention_mask.dim() != 2:
|
||||
raise ValueError(f"Unsupported attention_mask shape: {attention_mask.shape}")
|
||||
if attention_mask.shape[-1] != l_txt:
|
||||
raise ValueError(f"attention_mask last dim {attention_mask.shape[-1]} must equal text length {l_txt}")
|
||||
|
||||
device = img_q.device
|
||||
ones_img = torch.ones((bs, l_img), dtype=torch.bool, device=device)
|
||||
attention_mask = attention_mask.to(device=device, dtype=torch.bool)
|
||||
joint_mask = torch.cat([attention_mask, ones_img], dim=-1)
|
||||
attn_mask_tensor = joint_mask[:, None, None, :].expand(-1, attn.heads, l_img, -1)
|
||||
|
||||
# Apply attention using dispatch_attention_fn for backend support
|
||||
# Reshape to match dispatch_attention_fn expectations: [B, L, H, D]
|
||||
query = img_q.transpose(1, 2) # [B, L_img, H, D]
|
||||
key = k.transpose(1, 2) # [B, L_txt + L_img, H, D]
|
||||
value = v.transpose(1, 2) # [B, L_txt + L_img, H, D]
|
||||
|
||||
attn_output = dispatch_attention_fn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attn_mask_tensor,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
|
||||
# Reshape from [B, L_img, H, D] to [B, L_img, H*D]
|
||||
batch_size, seq_len, num_heads, head_dim = attn_output.shape
|
||||
attn_output = attn_output.reshape(batch_size, seq_len, num_heads * head_dim)
|
||||
|
||||
# Apply output projection
|
||||
attn_output = attn.to_out[0](attn_output)
|
||||
if len(attn.to_out) > 1:
|
||||
attn_output = attn.to_out[1](attn_output) # dropout if present
|
||||
|
||||
return attn_output
|
||||
|
||||
|
||||
class PhotonAttention(nn.Module, AttentionModuleMixin):
|
||||
r"""
|
||||
Photon-style attention module that handles multi-source tokens and RoPE. Similar to FluxAttention but adapted for
|
||||
Photon's architecture.
|
||||
"""
|
||||
|
||||
_default_processor_cls = PhotonAttnProcessor2_0
|
||||
_available_processors = [PhotonAttnProcessor2_0]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
query_dim: int,
|
||||
heads: int = 8,
|
||||
dim_head: int = 64,
|
||||
bias: bool = False,
|
||||
out_bias: bool = False,
|
||||
eps: float = 1e-6,
|
||||
processor=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.heads = heads
|
||||
self.head_dim = dim_head
|
||||
self.inner_dim = dim_head * heads
|
||||
self.query_dim = query_dim
|
||||
|
||||
self.img_qkv_proj = nn.Linear(query_dim, query_dim * 3, bias=bias)
|
||||
|
||||
self.norm_q = RMSNorm(self.head_dim, eps=eps, elementwise_affine=True)
|
||||
self.norm_k = RMSNorm(self.head_dim, eps=eps, elementwise_affine=True)
|
||||
|
||||
self.txt_kv_proj = nn.Linear(query_dim, query_dim * 2, bias=bias)
|
||||
self.norm_added_k = RMSNorm(self.head_dim, eps=eps, elementwise_affine=True)
|
||||
|
||||
self.to_out = nn.ModuleList([])
|
||||
self.to_out.append(nn.Linear(self.inner_dim, query_dim, bias=out_bias))
|
||||
self.to_out.append(nn.Dropout(0.0))
|
||||
|
||||
if processor is None:
|
||||
processor = self._default_processor_cls()
|
||||
self.set_processor(processor)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
return self.processor(
|
||||
self,
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
# inspired from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
|
||||
class PhotonEmbedND(nn.Module):
|
||||
r"""
|
||||
N-dimensional rotary positional embedding.
|
||||
|
||||
This module creates rotary embeddings (RoPE) across multiple axes, where each axis can have its own embedding
|
||||
dimension. The embeddings are combined and returned as a single tensor.
|
||||
|
||||
Args:
|
||||
dim (int):
|
||||
Base embedding dimension (must be even).
|
||||
theta (int):
|
||||
Scaling factor that controls the frequency spectrum of the rotary embeddings.
|
||||
axes_dim (list[int]):
|
||||
List of embedding dimensions for each axis (each must be even).
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int, theta: int, axes_dim: List[int]):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
|
||||
def rope(self, pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
|
||||
assert dim % 2 == 0
|
||||
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
|
||||
omega = 1.0 / (theta**scale)
|
||||
out = pos.unsqueeze(-1) * omega.unsqueeze(0)
|
||||
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
|
||||
# Native PyTorch equivalent of: Rearrange("b n d (i j) -> b n d i j", i=2, j=2)
|
||||
# out shape: (b, n, d, 4) -> reshape to (b, n, d, 2, 2)
|
||||
out = out.reshape(*out.shape[:-1], 2, 2)
|
||||
return out.float()
|
||||
|
||||
def forward(self, ids: torch.Tensor) -> torch.Tensor:
|
||||
n_axes = ids.shape[-1]
|
||||
emb = torch.cat(
|
||||
[self.rope(ids[:, :, i], self.axes_dim[i], self.theta) for i in range(n_axes)],
|
||||
dim=-3,
|
||||
)
|
||||
return emb.unsqueeze(1)
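# The returned tensor has shape (B, 1, L, sum(axes_dim) // 2, 2, 2), which is the
# `image_rotary_emb` layout consumed by `apply_rope` in PhotonAttnProcessor2_0.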
|
||||
|
||||
|
||||
class MLPEmbedder(nn.Module):
|
||||
r"""
|
||||
A simple 2-layer MLP used for embedding inputs.
|
||||
|
||||
Args:
|
||||
in_dim (`int`):
|
||||
Dimensionality of the input features.
|
||||
hidden_dim (`int`):
|
||||
Dimensionality of the hidden and output embedding space.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
Tensor of shape `(..., hidden_dim)` containing the embedded representations.
|
||||
"""
|
||||
|
||||
def __init__(self, in_dim: int, hidden_dim: int):
|
||||
super().__init__()
|
||||
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
|
||||
self.silu = nn.SiLU()
|
||||
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.out_layer(self.silu(self.in_layer(x)))
|
||||
|
||||
|
||||
class Modulation(nn.Module):
|
||||
r"""
|
||||
Modulation network that generates scale, shift, and gating parameters.
|
||||
|
||||
Given an input vector, the module projects it through a linear layer to produce six chunks, which are grouped into
|
||||
two tuples `(shift, scale, gate)`.
|
||||
|
||||
Args:
|
||||
dim (`int`):
|
||||
Dimensionality of the input vector. The output will have `6 * dim` features internally.
|
||||
|
||||
Returns:
|
||||
((`torch.Tensor`, `torch.Tensor`, `torch.Tensor`), (`torch.Tensor`, `torch.Tensor`, `torch.Tensor`)):
|
||||
Two tuples `(shift, scale, gate)`.
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int):
|
||||
super().__init__()
|
||||
self.lin = nn.Linear(dim, 6 * dim, bias=True)
|
||||
nn.init.constant_(self.lin.weight, 0)
|
||||
nn.init.constant_(self.lin.bias, 0)
|
||||
|
||||
def forward(self, vec: torch.Tensor) -> Tuple[Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tensor, Tensor]]:
|
||||
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(6, dim=-1)
|
||||
return tuple(out[:3]), tuple(out[3:])
|
||||
|
||||
|
||||
class PhotonBlock(nn.Module):
|
||||
r"""
|
||||
Multimodal transformer block with text–image cross-attention, modulation, and MLP.
|
||||
|
||||
Args:
|
||||
hidden_size (`int`):
|
||||
Dimension of the hidden representations.
|
||||
num_heads (`int`):
|
||||
Number of attention heads.
|
||||
mlp_ratio (`float`, *optional*, defaults to 4.0):
|
||||
Expansion ratio for the hidden dimension inside the MLP.
|
||||
qk_scale (`float`, *optional*):
|
||||
Scale factor for queries and keys. If not provided, defaults to ``head_dim**-0.5``.
|
||||
|
||||
Attributes:
|
||||
img_pre_norm (`nn.LayerNorm`):
|
||||
Pre-normalization applied to image tokens before attention.
|
||||
attention (`PhotonAttention`):
|
||||
Multi-head attention module with built-in QKV projections and normalizations for cross-attention between
|
||||
image and text tokens.
|
||||
post_attention_layernorm (`nn.LayerNorm`):
|
||||
Normalization applied after attention.
|
||||
gate_proj / up_proj / down_proj (`nn.Linear`):
|
||||
Feedforward layers forming the gated MLP.
|
||||
mlp_act (`nn.GELU`):
|
||||
Nonlinear activation used in the MLP.
|
||||
modulation (`Modulation`):
|
||||
Produces scale/shift/gating parameters for modulated layers.
|
||||
|
||||
Methods:
|
||||
The forward method performs cross-attention and the MLP with modulation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
qk_scale: Optional[float] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.hidden_dim = hidden_size
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = hidden_size // num_heads
|
||||
self.scale = qk_scale or self.head_dim**-0.5
|
||||
|
||||
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
# Pre-attention normalization for image tokens
|
||||
self.img_pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
|
||||
# PhotonAttention module with built-in projections and norms
|
||||
self.attention = PhotonAttention(
|
||||
query_dim=hidden_size,
|
||||
heads=num_heads,
|
||||
dim_head=self.head_dim,
|
||||
bias=False,
|
||||
out_bias=False,
|
||||
eps=1e-6,
|
||||
processor=PhotonAttnProcessor2_0(),
|
||||
)
|
||||
|
||||
# mlp
|
||||
self.post_attention_layernorm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.gate_proj = nn.Linear(hidden_size, self.mlp_hidden_dim, bias=False)
|
||||
self.up_proj = nn.Linear(hidden_size, self.mlp_hidden_dim, bias=False)
|
||||
self.down_proj = nn.Linear(self.mlp_hidden_dim, hidden_size, bias=False)
|
||||
self.mlp_act = nn.GELU(approximate="tanh")
|
||||
|
||||
self.modulation = Modulation(hidden_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: torch.Tensor,
|
||||
attention_mask: Optional[Tensor] = None,
|
||||
**kwargs: dict[str, Any],
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Runs modulation-gated cross-attention and MLP, with residual connections.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.Tensor`):
|
||||
Image tokens of shape `(B, L_img, hidden_size)`.
|
||||
encoder_hidden_states (`torch.Tensor`):
|
||||
Text tokens of shape `(B, L_txt, hidden_size)`.
|
||||
temb (`torch.Tensor`):
|
||||
Conditioning vector used by `Modulation` to produce scale/shift/gates, shape `(B, hidden_size)` (or
|
||||
broadcastable).
|
||||
image_rotary_emb (`torch.Tensor`):
|
||||
Rotary positional embeddings applied inside attention.
|
||||
attention_mask (`torch.Tensor`, *optional*):
|
||||
Boolean mask for text tokens of shape `(B, L_txt)`, where `0` marks padding.
|
||||
**kwargs:
|
||||
Additional keyword arguments for API compatibility.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
Updated image tokens of shape `(B, L_img, hidden_size)`.
|
||||
"""
|
||||
|
||||
mod_attn, mod_mlp = self.modulation(temb)
|
||||
attn_shift, attn_scale, attn_gate = mod_attn
|
||||
mlp_shift, mlp_scale, mlp_gate = mod_mlp
|
||||
|
||||
hidden_states_mod = (1 + attn_scale) * self.img_pre_norm(hidden_states) + attn_shift
|
||||
|
||||
attn_out = self.attention(
|
||||
hidden_states=hidden_states_mod,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states + attn_gate * attn_out
|
||||
|
||||
x = (1 + mlp_scale) * self.post_attention_layernorm(hidden_states) + mlp_shift
|
||||
hidden_states = hidden_states + mlp_gate * (self.down_proj(self.mlp_act(self.gate_proj(x)) * self.up_proj(x)))
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
r"""
|
||||
Final projection layer with adaptive LayerNorm modulation.
|
||||
|
||||
This layer applies a normalized and modulated transformation to input tokens and projects them into patch-level
|
||||
outputs.
|
||||
|
||||
Args:
|
||||
hidden_size (`int`):
|
||||
Dimensionality of the input tokens.
|
||||
patch_size (`int`):
|
||||
Size of the square image patches.
|
||||
out_channels (`int`):
|
||||
Number of output channels per pixel (e.g. RGB = 3).
|
||||
|
||||
Forward Inputs:
|
||||
x (`torch.Tensor`):
|
||||
Input tokens of shape `(B, L, hidden_size)`, where `L` is the number of patches.
|
||||
vec (`torch.Tensor`):
|
||||
Conditioning vector of shape `(B, hidden_size)` used to generate shift and scale parameters for adaptive
|
||||
LayerNorm.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
Projected patch outputs of shape `(B, L, patch_size * patch_size * out_channels)`.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
|
||||
super().__init__()
|
||||
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
|
||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
|
||||
|
||||
def forward(self, x: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
|
||||
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
|
||||
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
def img2seq(img: torch.Tensor, patch_size: int) -> torch.Tensor:
|
||||
r"""
|
||||
Flattens an image tensor into a sequence of non-overlapping patches.
|
||||
|
||||
Args:
|
||||
img (`torch.Tensor`):
|
||||
Input image tensor of shape `(B, C, H, W)`.
|
||||
patch_size (`int`):
|
||||
Size of each square patch. Must evenly divide both `H` and `W`.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
Flattened patch sequence of shape `(B, L, C * patch_size * patch_size)`, where `L = (H // patch_size) * (W
|
||||
// patch_size)` is the number of patches.
|
||||
"""
|
||||
return unfold(img, kernel_size=patch_size, stride=patch_size).transpose(1, 2)
|
||||
|
||||
|
||||
def seq2img(seq: torch.Tensor, patch_size: int, shape: torch.Tensor) -> torch.Tensor:
|
||||
r"""
|
||||
Reconstructs an image tensor from a sequence of patches (inverse of `img2seq`).
|
||||
|
||||
Args:
|
||||
seq (`torch.Tensor`):
|
||||
Patch sequence of shape `(B, L, C * patch_size * patch_size)`, where `L = (H // patch_size) * (W //
|
||||
patch_size)`.
|
||||
patch_size (`int`):
|
||||
Size of each square patch.
|
||||
shape (`tuple` or `torch.Tensor`):
|
||||
The original image spatial shape `(H, W)`. If a tensor is provided, the first two values are interpreted as
|
||||
height and width.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
Reconstructed image tensor of shape `(B, C, H, W)`.
|
||||
"""
|
||||
if isinstance(shape, tuple):
|
||||
shape = shape[-2:]
|
||||
elif isinstance(shape, torch.Tensor):
|
||||
shape = (int(shape[0]), int(shape[1]))
|
||||
else:
|
||||
raise NotImplementedError(f"shape type {type(shape)} not supported")
|
||||
return fold(seq.transpose(1, 2), shape, kernel_size=patch_size, stride=patch_size)
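# Usage sketch (shapes only): for latents of shape (B, C, H, W) and patch_size p,
#   seq = img2seq(img, p)            # -> (B, (H // p) * (W // p), C * p * p)
#   img = seq2img(seq, p, img.shape) # -> (B, C, H, W); seq2img is the exact inverse of img2seq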
|
||||
|
||||
|
||||
class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
|
||||
r"""
|
||||
Transformer-based 2D model for text to image generation.
|
||||
|
||||
Args:
|
||||
in_channels (`int`, *optional*, defaults to 16):
|
||||
Number of input channels in the latent image.
|
||||
patch_size (`int`, *optional*, defaults to 2):
|
||||
Size of the square patches used to flatten the input image.
|
||||
context_in_dim (`int`, *optional*, defaults to 2304):
|
||||
Dimensionality of the text conditioning input.
|
||||
hidden_size (`int`, *optional*, defaults to 1792):
|
||||
Dimension of the hidden representation.
|
||||
mlp_ratio (`float`, *optional*, defaults to 3.5):
|
||||
Expansion ratio for the hidden dimension inside MLP blocks.
|
||||
num_heads (`int`, *optional*, defaults to 28):
|
||||
Number of attention heads.
|
||||
depth (`int`, *optional*, defaults to 16):
|
||||
Number of transformer blocks.
|
||||
axes_dim (`list[int]`, *optional*):
|
||||
List of dimensions for each positional embedding axis. Defaults to `[32, 32]`.
|
||||
theta (`int`, *optional*, defaults to 10000):
|
||||
Frequency scaling factor for rotary embeddings.
|
||||
time_factor (`float`, *optional*, defaults to 1000.0):
|
||||
Scaling factor applied in timestep embeddings.
|
||||
time_max_period (`int`, *optional*, defaults to 10000):
|
||||
Maximum frequency period for timestep embeddings.
|
||||
|
||||
Attributes:
|
||||
pe_embedder (`PhotonEmbedND`):
|
||||
Multi-axis rotary embedding generator for positional encodings.
|
||||
img_in (`nn.Linear`):
|
||||
Projection layer for image patch tokens.
|
||||
time_in (`MLPEmbedder`):
|
||||
Embedding layer for timestep embeddings.
|
||||
txt_in (`nn.Linear`):
|
||||
Projection layer for text conditioning.
|
||||
blocks (`nn.ModuleList`):
|
||||
Stack of transformer blocks (`PhotonBlock`).
|
||||
final_layer (`FinalLayer`):
|
||||
Projection layer mapping hidden tokens back to patch outputs.
|
||||
|
||||
Methods:
attn_processors:
Returns a dictionary of all attention processors in the model.
set_attn_processor(processor):
Replaces attention processors across all attention layers.
_compute_timestep_embedding(timestep, dtype):
Creates a 256-dimensional timestep embedding, scaled and projected to the hidden size.
forward(hidden_states, timestep, encoder_hidden_states, attention_mask=None, attention_kwargs=None,
return_dict=True):
Full forward pass from the latent input to the reconstructed latent output.
|
||||
|
||||
Returns:
|
||||
`Transformer2DModelOutput` if `return_dict=True` (default), otherwise a tuple containing:
|
||||
- `sample` (`torch.Tensor`): Reconstructed image of shape `(B, C, H, W)`.
|
||||
"""
|
||||
|
||||
config_name = "config.json"
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 16,
|
||||
patch_size: int = 2,
|
||||
context_in_dim: int = 2304,
|
||||
hidden_size: int = 1792,
|
||||
mlp_ratio: float = 3.5,
|
||||
num_heads: int = 28,
|
||||
depth: int = 16,
|
||||
axes_dim: list = None,
|
||||
theta: int = 10000,
|
||||
time_factor: float = 1000.0,
|
||||
time_max_period: int = 10000,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if axes_dim is None:
|
||||
axes_dim = [32, 32]
|
||||
|
||||
# Store parameters directly
|
||||
self.in_channels = in_channels
|
||||
self.patch_size = patch_size
|
||||
self.out_channels = self.in_channels * self.patch_size**2
|
||||
|
||||
self.time_factor = time_factor
|
||||
self.time_max_period = time_max_period
|
||||
|
||||
if hidden_size % num_heads != 0:
|
||||
raise ValueError(f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}")
|
||||
|
||||
pe_dim = hidden_size // num_heads
|
||||
|
||||
if sum(axes_dim) != pe_dim:
|
||||
raise ValueError(f"Got {axes_dim} but expected positional dim {pe_dim}")
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.num_heads = num_heads
|
||||
self.pe_embedder = PhotonEmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim)
|
||||
self.img_in = nn.Linear(self.in_channels * self.patch_size**2, self.hidden_size, bias=True)
|
||||
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
|
||||
self.txt_in = nn.Linear(context_in_dim, self.hidden_size)
|
||||
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
PhotonBlock(
|
||||
self.hidden_size,
|
||||
self.num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
)
|
||||
for i in range(depth)
|
||||
]
|
||||
)
|
||||
|
||||
self.final_layer = FinalLayer(self.hidden_size, 1, self.out_channels)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def _compute_timestep_embedding(self, timestep: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
|
||||
return self.time_in(
|
||||
get_timestep_embedding(
|
||||
timesteps=timestep,
|
||||
embedding_dim=256,
|
||||
max_period=self.time_max_period,
|
||||
scale=self.time_factor,
|
||||
flip_sin_to_cos=True, # Match original cos, sin order
|
||||
).to(dtype)
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[Tensor] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
|
||||
r"""
|
||||
Forward pass of the PhotonTransformer2DModel.
|
||||
|
||||
The latent image is split into patch tokens, combined with text conditioning, and processed through a stack of
|
||||
transformer blocks modulated by the timestep. The output is reconstructed into the latent image space.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.Tensor`):
|
||||
Input latent image tensor of shape `(B, C, H, W)`.
|
||||
timestep (`torch.Tensor`):
|
||||
Timestep tensor of shape `(B,)` or `(1,)`, used for temporal conditioning.
|
||||
encoder_hidden_states (`torch.Tensor`):
|
||||
Text conditioning tensor of shape `(B, L_txt, context_in_dim)`.
|
||||
attention_mask (`torch.Tensor`, *optional*):
|
||||
Boolean mask of shape `(B, L_txt)`, where `0` marks padding in the text sequence.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
Additional arguments passed to attention layers.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return a `Transformer2DModelOutput` or a tuple.
|
||||
|
||||
Returns:
|
||||
`Transformer2DModelOutput` if `return_dict=True`, otherwise a tuple:
|
||||
|
||||
- `sample` (`torch.Tensor`): Output latent image of shape `(B, C, H, W)`.
|
||||
"""
|
||||
# Process text conditioning
|
||||
txt = self.txt_in(encoder_hidden_states)
|
||||
|
||||
# Convert image to sequence and embed
|
||||
img = img2seq(hidden_states, self.patch_size)
|
||||
img = self.img_in(img)
|
||||
|
||||
# Generate positional embeddings
|
||||
bs, _, h, w = hidden_states.shape
|
||||
img_ids = get_image_ids(bs, h, w, patch_size=self.patch_size, device=hidden_states.device)
|
||||
pe = self.pe_embedder(img_ids)
|
||||
|
||||
# Compute time embedding
|
||||
vec = self._compute_timestep_embedding(timestep, dtype=img.dtype)
|
||||
|
||||
# Apply transformer blocks
|
||||
for block in self.blocks:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
img = self._gradient_checkpointing_func(
|
||||
block.__call__,
|
||||
img,
|
||||
txt,
|
||||
vec,
|
||||
pe,
|
||||
attention_mask,
|
||||
)
|
||||
else:
|
||||
img = block(
|
||||
hidden_states=img,
|
||||
encoder_hidden_states=txt,
|
||||
temb=vec,
|
||||
image_rotary_emb=pe,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
|
||||
# Final layer and convert back to image
|
||||
img = self.final_layer(img, vec)
|
||||
output = seq2img(img, self.patch_size, hidden_states.shape)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
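As a quick sanity check of the shapes involved, a forward pass with random tensors might look like the following. This is only a sketch, not part of the PR; the tiny configuration is chosen for illustration (the released checkpoints use hidden_size=1792, num_heads=28, depth=16, context_in_dim=2304).

```py
import torch

from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel

# Tiny illustrative configuration; axes_dim must sum to hidden_size // num_heads
model = PhotonTransformer2DModel(
    in_channels=16,
    patch_size=2,
    context_in_dim=64,
    hidden_size=128,
    num_heads=4,
    depth=2,
    axes_dim=[16, 16],
)

latents = torch.randn(1, 16, 32, 32)            # (B, C, H, W) latent image
timestep = torch.tensor([0.5])                  # flow-matching timestep
text_embeds = torch.randn(1, 8, 64)             # (B, L_txt, context_in_dim)
text_mask = torch.ones(1, 8, dtype=torch.bool)  # no padding

with torch.no_grad():
    out = model(
        hidden_states=latents,
        timestep=timestep,
        encoder_hidden_states=text_embeds,
        attention_mask=text_mask,
    ).sample

print(out.shape)  # torch.Size([1, 16, 32, 32]) -- same shape as the input latents
```
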
@@ -144,6 +144,7 @@ else:
|
||||
"FluxKontextPipeline",
|
||||
"FluxKontextInpaintPipeline",
|
||||
]
|
||||
_import_structure["photon"] = ["PhotonPipeline"]
|
||||
_import_structure["audioldm"] = ["AudioLDMPipeline"]
|
||||
_import_structure["audioldm2"] = [
|
||||
"AudioLDM2Pipeline",
|
||||
@@ -717,6 +718,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
StableDiffusionXLPAGPipeline,
|
||||
)
|
||||
from .paint_by_example import PaintByExamplePipeline
|
||||
from .photon import PhotonPipeline
|
||||
from .pia import PIAPipeline
|
||||
from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
|
||||
from .qwenimage import (
|
||||
|
||||
63
src/diffusers/pipelines/photon/__init__.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_additional_imports = {}
|
||||
_import_structure = {"pipeline_output": ["PhotonPipelineOutput"]}
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_photon"] = ["PhotonPipeline"]
|
||||
|
||||
# Import T5GemmaEncoder for pipeline loading compatibility
|
||||
try:
|
||||
if is_transformers_available():
|
||||
import transformers
|
||||
from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder
|
||||
|
||||
_additional_imports["T5GemmaEncoder"] = T5GemmaEncoder
|
||||
# Patch transformers module directly for serialization
|
||||
if not hasattr(transformers, "T5GemmaEncoder"):
|
||||
transformers.T5GemmaEncoder = T5GemmaEncoder
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .pipeline_output import PhotonPipelineOutput
|
||||
from .pipeline_photon import PhotonPipeline
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
for name, value in _additional_imports.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
src/diffusers/pipelines/photon/pipeline_output.py (new file, 35 lines)
@@ -0,0 +1,35 @@
# Copyright 2025 The Photoroom and the HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import List, Union

import numpy as np
import PIL.Image

from ...utils import BaseOutput


@dataclass
class PhotonPipelineOutput(BaseOutput):
    """
    Output class for Photon pipelines.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`)
            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
            num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
    """

    images: Union[List[PIL.Image.Image], np.ndarray]

src/diffusers/pipelines/photon/pipeline_photon.py (new file, 768 lines)
@@ -0,0 +1,768 @@
|
||||
# Copyright 2025 The Photoroom and The HuggingFace Teams. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import html
|
||||
import inspect
|
||||
import re
|
||||
import urllib.parse as ul
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import ftfy
|
||||
import torch
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
GemmaTokenizerFast,
|
||||
T5TokenizerFast,
|
||||
)
|
||||
from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder
|
||||
|
||||
from diffusers.image_processor import PixArtImageProcessor
|
||||
from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from diffusers.models import AutoencoderDC, AutoencoderKL
|
||||
from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel
|
||||
from diffusers.pipelines.photon.pipeline_output import PhotonPipelineOutput
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from diffusers.utils import (
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
|
||||
DEFAULT_RESOLUTION = 512
|
||||
|
||||
ASPECT_RATIO_256_BIN = {
|
||||
"0.46": [160, 352],
|
||||
"0.6": [192, 320],
|
||||
"0.78": [224, 288],
|
||||
"1.0": [256, 256],
|
||||
"1.29": [288, 224],
|
||||
"1.67": [320, 192],
|
||||
"2.2": [352, 160],
|
||||
}
|
||||
|
||||
ASPECT_RATIO_512_BIN = {
|
||||
"0.5": [352, 704],
|
||||
"0.57": [384, 672],
|
||||
"0.6": [384, 640],
|
||||
"0.68": [416, 608],
|
||||
"0.78": [448, 576],
|
||||
"0.88": [480, 544],
|
||||
"1.0": [512, 512],
|
||||
"1.13": [544, 480],
|
||||
"1.29": [576, 448],
|
||||
"1.46": [608, 416],
|
||||
"1.67": [640, 384],
|
||||
"1.75": [672, 384],
|
||||
"2.0": [704, 352],
|
||||
}
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class TextPreprocessor:
|
||||
"""Text preprocessing utility for PhotonPipeline."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize text preprocessor."""
|
||||
self.bad_punct_regex = re.compile(
|
||||
r"["
|
||||
+ "#®•©™&@·º½¾¿¡§~"
|
||||
+ r"\)"
|
||||
+ r"\("
|
||||
+ r"\]"
|
||||
+ r"\["
|
||||
+ r"\}"
|
||||
+ r"\{"
|
||||
+ r"\|"
|
||||
+ r"\\"
|
||||
+ r"\/"
|
||||
+ r"\*"
|
||||
+ r"]{1,}"
|
||||
)
|
||||
|
||||
def clean_text(self, text: str) -> str:
|
||||
"""Clean text using comprehensive text processing logic."""
|
||||
# See Deepfloyd https://github.com/deep-floyd/IF/blob/develop/deepfloyd_if/modules/t5.py
|
||||
text = str(text)
|
||||
text = ul.unquote_plus(text)
|
||||
text = text.strip().lower()
|
||||
text = re.sub("<person>", "person", text)
|
||||
|
||||
# Remove all urls:
|
||||
text = re.sub(
|
||||
r"\b((?:https?|www):(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@))",
|
||||
"",
|
||||
text,
|
||||
) # regex for urls
|
||||
|
||||
# @<nickname>
|
||||
text = re.sub(r"@[\w\d]+\b", "", text)
|
||||
|
||||
# 31C0—31EF CJK Strokes through 4E00—9FFF CJK Unified Ideographs
|
||||
text = re.sub(r"[\u31c0-\u31ef]+", "", text)
|
||||
text = re.sub(r"[\u31f0-\u31ff]+", "", text)
|
||||
text = re.sub(r"[\u3200-\u32ff]+", "", text)
|
||||
text = re.sub(r"[\u3300-\u33ff]+", "", text)
|
||||
text = re.sub(r"[\u3400-\u4dbf]+", "", text)
|
||||
text = re.sub(r"[\u4dc0-\u4dff]+", "", text)
|
||||
text = re.sub(r"[\u4e00-\u9fff]+", "", text)
|
||||
|
||||
# все виды тире / all types of dash --> "-"
|
||||
text = re.sub(
|
||||
r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",
|
||||
"-",
|
||||
text,
|
||||
)
|
||||
|
||||
# normalize quotes to one standard
text = re.sub(r"[`´«»“”¨]", '"', text)
text = re.sub(r"[‘’]", "'", text)

# &quot; and &amp
text = re.sub(r"&quot;?", "", text)
text = re.sub(r"&amp", "", text)
|
||||
|
||||
# ip addresses:
|
||||
text = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", text)
|
||||
|
||||
# article ids:
|
||||
text = re.sub(r"\d:\d\d\s+$", "", text)
|
||||
|
||||
# \n
|
||||
text = re.sub(r"\\n", " ", text)
|
||||
|
||||
# "#123", "#12345..", "123456.."
|
||||
text = re.sub(r"#\d{1,3}\b", "", text)
|
||||
text = re.sub(r"#\d{5,}\b", "", text)
|
||||
text = re.sub(r"\b\d{6,}\b", "", text)
|
||||
|
||||
# filenames:
|
||||
text = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", text)
|
||||
|
||||
# Clean punctuation
|
||||
text = re.sub(r"[\"\']{2,}", r'"', text) # """AUSVERKAUFT"""
|
||||
text = re.sub(r"[\.]{2,}", r" ", text)
|
||||
|
||||
text = re.sub(self.bad_punct_regex, r" ", text) # ***AUSVERKAUFT***, #AUSVERKAUFT
|
||||
text = re.sub(r"\s+\.\s+", r" ", text) # " . "
|
||||
|
||||
# this-is-my-cute-cat / this_is_my_cute_cat
|
||||
regex2 = re.compile(r"(?:\-|\_)")
|
||||
if len(re.findall(regex2, text)) > 3:
|
||||
text = re.sub(regex2, " ", text)
|
||||
|
||||
# Basic cleaning
|
||||
text = ftfy.fix_text(text)
|
||||
text = html.unescape(html.unescape(text))
|
||||
text = text.strip()
|
||||
|
||||
# Clean alphanumeric patterns
|
||||
text = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", text) # jc6640
|
||||
text = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", text) # jc6640vc
|
||||
text = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", text) # 6640vc231
|
||||
|
||||
# Common spam patterns
|
||||
text = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", text)
|
||||
text = re.sub(r"(free\s)?download(\sfree)?", "", text)
|
||||
text = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", text)
|
||||
text = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", text)
|
||||
text = re.sub(r"\bpage\s+\d+\b", "", text)
|
||||
|
||||
text = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", text) # j2d1a2a...
|
||||
text = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", text)
|
||||
|
||||
# Final cleanup
|
||||
text = re.sub(r"\b\s+\:\s+", r": ", text)
|
||||
text = re.sub(r"(\D[,\./])\b", r"\1 ", text)
|
||||
text = re.sub(r"\s+", " ", text)
text = text.strip()
|
||||
|
||||
text = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", text)
|
||||
text = re.sub(r"^[\'\_,\-\:;]", r"", text)
|
||||
text = re.sub(r"[\'\_,\-\:\-\+]$", r"", text)
|
||||
text = re.sub(r"^\.\S+$", "", text)
|
||||
|
||||
return text.strip()
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import PhotonPipeline
|
||||
|
||||
>>> # Load pipeline with from_pretrained
|
||||
>>> pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i-sft")
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> prompt = "A digital painting of a rusty, vintage tram on a sandy beach"
|
||||
>>> image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0]
|
||||
>>> image.save("photon_output.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class PhotonPipeline(
|
||||
DiffusionPipeline,
|
||||
LoraLoaderMixin,
|
||||
FromSingleFileMixin,
|
||||
TextualInversionLoaderMixin,
|
||||
):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Photon Transformer.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
transformer ([`PhotonTransformer2DModel`]):
|
||||
The Photon transformer model to denoise the encoded image latents.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
text_encoder ([`T5GemmaEncoder`]):
|
||||
Text encoder model for encoding prompts.
|
||||
tokenizer ([`T5TokenizerFast` or `GemmaTokenizerFast`]):
|
||||
Tokenizer for the text encoder.
|
||||
vae ([`AutoencoderKL`] or [`AutoencoderDC`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
Supports both AutoencoderKL (8x compression) and AutoencoderDC (32x compression).
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
||||
_optional_components = ["vae"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transformer: PhotonTransformer2DModel,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
text_encoder: T5GemmaEncoder,
|
||||
tokenizer: Union[T5TokenizerFast, GemmaTokenizerFast, AutoTokenizer],
|
||||
vae: Optional[Union[AutoencoderKL, AutoencoderDC]] = None,
|
||||
default_sample_size: Optional[int] = DEFAULT_RESOLUTION,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if PhotonTransformer2DModel is None:
|
||||
raise ImportError(
|
||||
"PhotonTransformer2DModel is not available. Please ensure the transformer_photon module is properly installed."
|
||||
)
|
||||
|
||||
self.text_preprocessor = TextPreprocessor()
|
||||
self.default_sample_size = default_sample_size
|
||||
self._guidance_scale = 1.0
|
||||
|
||||
self.register_modules(
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
vae=vae,
|
||||
)
|
||||
|
||||
self.register_to_config(default_sample_size=self.default_sample_size)
|
||||
|
||||
if vae is not None:
|
||||
self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
else:
|
||||
self.image_processor = None
|
||||
|
||||
@property
|
||||
def vae_scale_factor(self):
|
||||
if self.vae is None:
|
||||
return 8
|
||||
if hasattr(self.vae, "spatial_compression_ratio"):
|
||||
return self.vae.spatial_compression_ratio
|
||||
else: # Flux VAE
|
||||
return 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
"""Check if classifier-free guidance is enabled based on guidance scale."""
|
||||
return self._guidance_scale > 1.0
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
def get_default_resolution(self):
|
||||
"""Determine the default resolution based on the loaded VAE and config.
|
||||
|
||||
Returns:
|
||||
int: The default sample size (height/width) to use for generation.
|
||||
"""
|
||||
default_from_config = getattr(self.config, "default_sample_size", None)
|
||||
if default_from_config is not None:
|
||||
return default_from_config
|
||||
|
||||
return DEFAULT_RESOLUTION
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size: int,
|
||||
num_channels_latents: int,
|
||||
height: int,
|
||||
width: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
):
|
||||
"""Prepare initial latents for the diffusion process."""
|
||||
if latents is None:
|
||||
spatial_compression = self.vae_scale_factor
|
||||
latent_height, latent_width = (
|
||||
height // spatial_compression,
|
||||
width // spatial_compression,
|
||||
)
|
||||
shape = (batch_size, num_channels_latents, latent_height, latent_width)
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device)
|
||||
return latents
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
device: Optional[torch.device] = None,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
negative_prompt: str = "",
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
prompt_attention_mask: Optional[torch.BoolTensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.BoolTensor] = None,
|
||||
):
|
||||
"""Encode text prompt using standard text encoder and tokenizer, or use precomputed embeddings."""
|
||||
if device is None:
|
||||
device = self._execution_device
|
||||
|
||||
if prompt_embeds is None:
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
# Encode the prompts
|
||||
prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = (
|
||||
self._encode_prompt_standard(prompt, device, do_classifier_free_guidance, negative_prompt)
|
||||
)
|
||||
|
||||
# Duplicate embeddings for each generation per prompt
|
||||
if num_images_per_prompt > 1:
|
||||
# Repeat prompt embeddings
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
if prompt_attention_mask is not None:
|
||||
prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
|
||||
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
||||
|
||||
# Repeat negative embeddings if using CFG
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is not None:
|
||||
bs_embed, seq_len, _ = negative_prompt_embeds.shape
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
if negative_prompt_attention_mask is not None:
|
||||
negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
|
||||
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
||||
|
||||
return (
|
||||
prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
negative_prompt_embeds if do_classifier_free_guidance else None,
|
||||
negative_prompt_attention_mask if do_classifier_free_guidance else None,
|
||||
)
|
||||
|
||||
def _tokenize_prompts(self, prompts: List[str], device: torch.device):
|
||||
"""Tokenize and clean prompts."""
|
||||
cleaned = [self.text_preprocessor.clean_text(text) for text in prompts]
|
||||
tokens = self.tokenizer(
|
||||
cleaned,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_attention_mask=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
return tokens["input_ids"].to(device), tokens["attention_mask"].bool().to(device)
|
||||
|
||||
def _encode_prompt_standard(
|
||||
self,
|
||||
prompt: List[str],
|
||||
device: torch.device,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
negative_prompt: str = "",
|
||||
):
|
||||
"""Encode prompt using standard text encoder and tokenizer with batch processing."""
|
||||
batch_size = len(prompt)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
if isinstance(negative_prompt, str):
|
||||
negative_prompt = [negative_prompt] * batch_size
|
||||
|
||||
prompts_to_encode = negative_prompt + prompt
|
||||
else:
|
||||
prompts_to_encode = prompt
|
||||
|
||||
input_ids, attention_mask = self._tokenize_prompts(prompts_to_encode, device)
|
||||
|
||||
with torch.no_grad():
|
||||
embeddings = self.text_encoder(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
output_hidden_states=True,
|
||||
)["last_hidden_state"]
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
uncond_text_embeddings, text_embeddings = embeddings.split(batch_size, dim=0)
|
||||
uncond_cross_attn_mask, cross_attn_mask = attention_mask.split(batch_size, dim=0)
|
||||
else:
|
||||
text_embeddings = embeddings
|
||||
cross_attn_mask = attention_mask
|
||||
uncond_text_embeddings = None
|
||||
uncond_cross_attn_mask = None
|
||||
|
||||
return text_embeddings, cross_attn_mask, uncond_text_embeddings, uncond_cross_attn_mask
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
height: int,
|
||||
width: int,
|
||||
guidance_scale: float,
|
||||
callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
"""Check that all inputs are in correct format."""
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
|
||||
if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if prompt_embeds is not None and guidance_scale > 1.0 and negative_prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"When `prompt_embeds` is provided and `guidance_scale > 1.0`, "
|
||||
"`negative_prompt_embeds` must also be provided for classifier-free guidance."
|
||||
)
|
||||
|
||||
spatial_compression = self.vae_scale_factor
|
||||
if height % spatial_compression != 0 or width % spatial_compression != 0:
|
||||
raise ValueError(
|
||||
f"`height` and `width` have to be divisible by {spatial_compression} but are {height} and {width}."
|
||||
)
|
||||
|
||||
if guidance_scale < 1.0:
|
||||
raise ValueError(f"guidance_scale has to be >= 1.0 but is {guidance_scale}")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not isinstance(callback_on_step_end_tensor_inputs, list):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be a list but is {callback_on_step_end_tensor_inputs}"
|
||||
)
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
negative_prompt: str = "",
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 28,
|
||||
timesteps: List[int] = None,
|
||||
guidance_scale: float = 4.0,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
prompt_attention_mask: Optional[torch.BoolTensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.BoolTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
use_resolution_binning: bool = True,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
):
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
|
||||
instead.
|
||||
negative_prompt (`str`, *optional*, defaults to `""`):
|
||||
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
||||
if `guidance_scale` is less than `1`).
|
||||
height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 28):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
||||
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
||||
passed will be used. Must be in descending order.
|
||||
guidance_scale (`float`, *optional*, defaults to 4.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will be generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided and `guidance_scale > 1`, negative embeddings will be generated from an
|
||||
empty string.
|
||||
prompt_attention_mask (`torch.BoolTensor`, *optional*):
|
||||
Pre-generated attention mask for `prompt_embeds`. If not provided, attention mask will be generated
|
||||
from `prompt` input argument.
|
||||
negative_prompt_attention_mask (`torch.BoolTensor`, *optional*):
|
||||
Pre-generated attention mask for `negative_prompt_embeds`. If not provided and `guidance_scale > 1`,
|
||||
attention mask will be generated from an empty string.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.photon.PhotonPipelineOutput`] instead of a plain tuple.
|
||||
use_resolution_binning (`bool`, *optional*, defaults to `True`):
|
||||
If set to `True`, the requested height and width are first mapped to the closest resolutions using
|
||||
predefined aspect ratio bins. After the produced latents are decoded into images, they are resized back
|
||||
to the requested resolution. Useful for generating non-square images at optimal resolutions.
|
||||
callback_on_step_end (`Callable`, *optional*):
A function that is called at the end of each denoising step during inference. The function is called
with the following arguments: `callback_on_step_end(self, step, timestep, callback_kwargs)`.
|
||||
`callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include tensors that are listed
|
||||
in the `._callback_tensor_inputs` attribute.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.photon.PhotonPipelineOutput`] or `tuple`: [`~pipelines.photon.PhotonPipelineOutput`] if
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
generated images.
|
||||
"""
|
||||
|
||||
# 0. Set height and width
|
||||
default_resolution = self.get_default_resolution()
|
||||
height = height or default_resolution
|
||||
width = width or default_resolution
|
||||
|
||||
if use_resolution_binning:
|
||||
if self.image_processor is None:
|
||||
raise ValueError(
|
||||
"Resolution binning requires a VAE with image_processor, but VAE is not available. "
|
||||
"Set use_resolution_binning=False or provide a VAE."
|
||||
)
|
||||
if self.default_sample_size <= 256:
|
||||
aspect_ratio_bin = ASPECT_RATIO_256_BIN
|
||||
else:
|
||||
aspect_ratio_bin = ASPECT_RATIO_512_BIN
|
||||
|
||||
# Store original dimensions
|
||||
orig_height, orig_width = height, width
|
||||
# Map to closest resolution in the bin
|
||||
height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
|
||||
|
||||
# 1. Check inputs
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
guidance_scale,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
)
|
||||
|
||||
if self.vae is None and output_type not in ["latent", "pt"]:
|
||||
raise ValueError(
|
||||
f"VAE is required for output_type='{output_type}' but it is not available. "
|
||||
"Either provide a VAE or set output_type='latent' or 'pt' to get latent outputs."
|
||||
)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
# Use execution device (handles offloading scenarios including group offloading)
|
||||
device = self._execution_device
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
|
||||
# 2. Encode input prompt
|
||||
text_embeddings, cross_attn_mask, uncond_text_embeddings, uncond_cross_attn_mask = self.encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
)
|
||||
# Expose standard names for callbacks parity
|
||||
prompt_embeds = text_embeddings
|
||||
negative_prompt_embeds = uncond_text_embeddings
|
||||
|
||||
# 3. Prepare timesteps
|
||||
if timesteps is not None:
|
||||
self.scheduler.set_timesteps(timesteps=timesteps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
self.num_timesteps = len(timesteps)
|
||||
|
||||
# 4. Prepare latent variables
|
||||
if self.vae is not None:
|
||||
num_channels_latents = self.vae.config.latent_channels
|
||||
else:
|
||||
# When vae is None, get latent channels from transformer
|
||||
num_channels_latents = self.transformer.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
text_embeddings.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 5. Prepare extra step kwargs
|
||||
extra_step_kwargs = {}
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = 0.0
|
||||
|
||||
# 6. Prepare cross-attention embeddings and masks
|
||||
if self.do_classifier_free_guidance:
|
||||
ca_embed = torch.cat([uncond_text_embeddings, text_embeddings], dim=0)
|
||||
ca_mask = None
|
||||
if cross_attn_mask is not None and uncond_cross_attn_mask is not None:
|
||||
ca_mask = torch.cat([uncond_cross_attn_mask, cross_attn_mask], dim=0)
|
||||
else:
|
||||
ca_embed = text_embeddings
|
||||
ca_mask = cross_attn_mask
|
||||
|
||||
# 7. Denoising loop
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# Duplicate latents if using classifier-free guidance
|
||||
if self.do_classifier_free_guidance:
|
||||
latents_in = torch.cat([latents, latents], dim=0)
|
||||
# Normalize timestep for the transformer
|
||||
t_cont = (t.float() / self.scheduler.config.num_train_timesteps).view(1).repeat(2).to(device)
|
||||
else:
|
||||
latents_in = latents
|
||||
# Normalize timestep for the transformer
|
||||
t_cont = (t.float() / self.scheduler.config.num_train_timesteps).view(1).to(device)
|
||||
|
||||
# Forward through transformer
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latents_in,
|
||||
timestep=t_cont,
|
||||
encoder_hidden_states=ca_embed,
|
||||
attention_mask=ca_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# Apply CFG
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_uncond, noise_text = noise_pred.chunk(2, dim=0)
|
||||
noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)
|
||||
|
||||
# Compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
# Call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
# 8. Post-processing
|
||||
if output_type == "latent" or (output_type == "pt" and self.vae is None):
|
||||
image = latents
|
||||
else:
|
||||
# Unscale latents for VAE (supports both AutoencoderKL and AutoencoderDC)
|
||||
scaling_factor = getattr(self.vae.config, "scaling_factor", 0.18215)
|
||||
shift_factor = getattr(self.vae.config, "shift_factor", 0.0)
|
||||
latents = (latents / scaling_factor) + shift_factor
|
||||
# Decode using VAE (AutoencoderKL or AutoencoderDC)
|
||||
image = self.vae.decode(latents, return_dict=False)[0]
|
||||
# Resize back to original resolution if using binning
|
||||
if use_resolution_binning:
|
||||
image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
|
||||
|
||||
# Use standard image processor for post-processing
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return PhotonPipelineOutput(images=image)
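Before moving on to the dummy objects and tests, a small self-contained sketch of the shape and conditioning arithmetic performed inside `__call__` may help. The 8x / 16-channel values assume the Flux-style `AutoencoderKL` variant, and the 1000 train timesteps assume the default `FlowMatchEulerDiscreteScheduler` configuration.

```py
import torch

# Latent shape: prepare_latents divides the pixel resolution by vae_scale_factor.
height, width = 512, 512
vae_scale_factor = 8        # a DC-AE VAE would give 32 via spatial_compression_ratio
num_channels_latents = 16   # vae.config.latent_channels for the Flux-VAE variant
latents = torch.randn(1, num_channels_latents, height // vae_scale_factor, width // vae_scale_factor)
print(latents.shape)  # torch.Size([1, 16, 64, 64])

# Timestep conditioning: the scheduler timestep is normalized to [0, 1] before the transformer call.
num_train_timesteps = 1000
t = torch.tensor(750.0)
t_cont = (t / num_train_timesteps).view(1)  # tensor([0.7500])

# Classifier-free guidance: the batch is ordered [uncond, cond] and recombined after the forward pass.
guidance_scale = 4.0
noise_pred = torch.randn(2, num_channels_latents, 64, 64)
noise_uncond, noise_text = noise_pred.chunk(2, dim=0)
guided = noise_uncond + guidance_scale * (noise_text - noise_uncond)
```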
@@ -1098,6 +1098,21 @@ class ParallelConfig(metaclass=DummyObject):
        requires_backends(cls, ["torch"])


class PhotonTransformer2DModel(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class PixArtTransformer2DModel(metaclass=DummyObject):
    _backends = ["torch"]


@@ -1847,6 +1847,21 @@ class PaintByExamplePipeline(metaclass=DummyObject):
        requires_backends(cls, ["torch", "transformers"])


class PhotonPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class PIAPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]
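These dummy classes follow the library's usual optional-dependency pattern: when `torch`/`transformers` are missing, the Photon names still import, but any attempt to use them raises a descriptive error. A rough illustration, assuming the standard `requires_backends` behavior:

```py
# The placeholder lives in diffusers.utils.dummy_torch_and_transformers_objects.
from diffusers.utils import dummy_torch_and_transformers_objects as dummies

PhotonPipelinePlaceholder = dummies.PhotonPipeline  # DummyObject subclass added by this diff
# PhotonPipelinePlaceholder()  # would raise, pointing the user to install torch and transformers
```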
tests/models/transformers/test_models_transformer_photon.py (new file, 83 lines)
@@ -0,0 +1,83 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class PhotonTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = PhotonTransformer2DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
return self.prepare_dummy_input()
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (16, 16, 16)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (16, 16, 16)
|
||||
|
||||
def prepare_dummy_input(self, height=16, width=16):
|
||||
batch_size = 1
|
||||
num_latent_channels = 16
|
||||
sequence_length = 16
|
||||
embedding_dim = 1792
|
||||
|
||||
hidden_states = torch.randn((batch_size, num_latent_channels, height, width)).to(torch_device)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
|
||||
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"timestep": timestep,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
}
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"in_channels": 16,
|
||||
"patch_size": 2,
|
||||
"context_in_dim": 1792,
|
||||
"hidden_size": 1792,
|
||||
"mlp_ratio": 3.5,
|
||||
"num_heads": 28,
|
||||
"depth": 4, # Smaller depth for testing
|
||||
"axes_dim": [32, 32],
|
||||
"theta": 10_000,
|
||||
}
|
||||
inputs_dict = self.prepare_dummy_input()
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"PhotonTransformer2DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
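The `input_shape`/`output_shape` of `(16, 16, 16)` above correspond to the latent `(C, H, W)` seen by the model. Assuming `img2seq` patchifies the latent the way Flux-style transformers do, the token count and per-token dimension for this test configuration work out as follows (illustrative arithmetic only):

```py
channels, height, width = 16, 16, 16
patch_size = 2

num_tokens = (height // patch_size) * (width // patch_size)  # 64 image tokens
token_dim = channels * patch_size * patch_size               # 64 values per token, projected by img_in
print(num_tokens, token_dim)
```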
tests/pipelines/photon/__init__.py (new, empty file)

tests/pipelines/photon/test_pipeline_photon.py (new file, 265 lines)
@@ -0,0 +1,265 @@
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
from transformers.models.t5gemma.configuration_t5gemma import T5GemmaConfig, T5GemmaModuleConfig
|
||||
from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder
|
||||
|
||||
from diffusers.models import AutoencoderDC, AutoencoderKL
|
||||
from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel
|
||||
from diffusers.pipelines.photon.pipeline_photon import PhotonPipeline
|
||||
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from diffusers.utils import is_transformers_version
|
||||
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
@pytest.mark.xfail(
|
||||
condition=is_transformers_version(">", "4.57.1"),
|
||||
reason="See https://github.com/huggingface/diffusers/pull/12456#issuecomment-3424228544",
|
||||
strict=False,
|
||||
)
|
||||
class PhotonPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = PhotonPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
|
||||
batch_params = frozenset(["prompt", "negative_prompt", "num_images_per_prompt"])
|
||||
test_xformers_attention = False
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
# Ensure PhotonPipeline has an _execution_device property expected by __call__
|
||||
if not isinstance(getattr(PhotonPipeline, "_execution_device", None), property):
|
||||
try:
|
||||
setattr(PhotonPipeline, "_execution_device", property(lambda self: torch.device("cpu")))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
transformer = PhotonTransformer2DModel(
|
||||
patch_size=1,
|
||||
in_channels=4,
|
||||
context_in_dim=8,
|
||||
hidden_size=8,
|
||||
mlp_ratio=2.0,
|
||||
num_heads=2,
|
||||
depth=1,
|
||||
axes_dim=[2, 2],
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKL(
|
||||
sample_size=32,
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
block_out_channels=(4,),
|
||||
layers_per_block=1,
|
||||
latent_channels=4,
|
||||
norm_num_groups=1,
|
||||
use_quant_conv=False,
|
||||
use_post_quant_conv=False,
|
||||
shift_factor=0.0,
|
||||
scaling_factor=1.0,
|
||||
).eval()
|
||||
|
||||
torch.manual_seed(0)
|
||||
scheduler = FlowMatchEulerDiscreteScheduler()
|
||||
|
||||
torch.manual_seed(0)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/dummy-gemma")
|
||||
tokenizer.model_max_length = 64
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
encoder_params = {
|
||||
"vocab_size": tokenizer.vocab_size,
|
||||
"hidden_size": 8,
|
||||
"intermediate_size": 16,
|
||||
"num_hidden_layers": 1,
|
||||
"num_attention_heads": 2,
|
||||
"num_key_value_heads": 1,
|
||||
"head_dim": 4,
|
||||
"max_position_embeddings": 64,
|
||||
"layer_types": ["full_attention"],
|
||||
"attention_bias": False,
|
||||
"attention_dropout": 0.0,
|
||||
"dropout_rate": 0.0,
|
||||
"hidden_activation": "gelu_pytorch_tanh",
|
||||
"rms_norm_eps": 1e-06,
|
||||
"attn_logit_softcapping": 50.0,
|
||||
"final_logit_softcapping": 30.0,
|
||||
"query_pre_attn_scalar": 4,
|
||||
"rope_theta": 10000.0,
|
||||
"sliding_window": 4096,
|
||||
}
|
||||
encoder_config = T5GemmaModuleConfig(**encoder_params)
|
||||
text_encoder_config = T5GemmaConfig(encoder=encoder_config, is_encoder_decoder=False, **encoder_params)
|
||||
text_encoder = T5GemmaEncoder(text_encoder_config)
|
||||
|
||||
return {
|
||||
"transformer": transformer,
|
||||
"vae": vae,
|
||||
"scheduler": scheduler,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
return {
|
||||
"prompt": "",
|
||||
"negative_prompt": "",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 1.0,
|
||||
"height": 32,
|
||||
"width": 32,
|
||||
"output_type": "pt",
|
||||
"use_resolution_binning": False,
|
||||
}
|
||||
|
||||
def test_inference(self):
|
||||
device = "cpu"
|
||||
components = self.get_dummy_components()
|
||||
pipe = PhotonPipeline(**components)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
try:
|
||||
pipe.register_to_config(_execution_device="cpu")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = pipe(**inputs)[0]
|
||||
generated_image = image[0]
|
||||
|
||||
self.assertEqual(generated_image.shape, (3, 32, 32))
|
||||
expected_image = torch.zeros(3, 32, 32)
|
||||
max_diff = np.abs(generated_image - expected_image).max()
|
||||
self.assertLessEqual(max_diff, 1e10)
|
||||
|
||||
def test_callback_inputs(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe = PhotonPipeline(**components)
|
||||
pipe = pipe.to("cpu")
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
try:
|
||||
pipe.register_to_config(_execution_device="cpu")
|
||||
except Exception:
|
||||
pass
|
||||
self.assertTrue(
|
||||
hasattr(pipe, "_callback_tensor_inputs"),
|
||||
f" {PhotonPipeline} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
|
||||
)
|
||||
|
||||
def callback_inputs_subset(pipe, i, t, callback_kwargs):
|
||||
for tensor_name in callback_kwargs.keys():
|
||||
assert tensor_name in pipe._callback_tensor_inputs
|
||||
return callback_kwargs
|
||||
|
||||
def callback_inputs_all(pipe, i, t, callback_kwargs):
|
||||
for tensor_name in pipe._callback_tensor_inputs:
|
||||
assert tensor_name in callback_kwargs
|
||||
for tensor_name in callback_kwargs.keys():
|
||||
assert tensor_name in pipe._callback_tensor_inputs
|
||||
return callback_kwargs
|
||||
|
||||
inputs = self.get_dummy_inputs("cpu")
|
||||
|
||||
inputs["callback_on_step_end"] = callback_inputs_subset
|
||||
inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
|
||||
_ = pipe(**inputs)[0]
|
||||
|
||||
inputs["callback_on_step_end"] = callback_inputs_all
|
||||
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
|
||||
_ = pipe(**inputs)[0]
|
||||
|
||||
def test_attention_slicing_forward_pass(self, expected_max_diff=1e-3):
|
||||
if not self.test_attention_slicing:
|
||||
return
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
for component in pipe.components.values():
|
||||
if hasattr(component, "set_default_attn_processor"):
|
||||
component.set_default_attn_processor()
|
||||
pipe.to("cpu")
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
def to_np_local(tensor):
|
||||
if isinstance(tensor, torch.Tensor):
|
||||
return tensor.detach().cpu().numpy()
|
||||
return tensor
|
||||
|
||||
generator_device = "cpu"
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
output_without_slicing = pipe(**inputs)[0]
|
||||
|
||||
pipe.enable_attention_slicing(slice_size=1)
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
output_with_slicing1 = pipe(**inputs)[0]
|
||||
|
||||
pipe.enable_attention_slicing(slice_size=2)
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
output_with_slicing2 = pipe(**inputs)[0]
|
||||
|
||||
max_diff1 = np.abs(to_np_local(output_with_slicing1) - to_np_local(output_without_slicing)).max()
|
||||
max_diff2 = np.abs(to_np_local(output_with_slicing2) - to_np_local(output_without_slicing)).max()
|
||||
self.assertLess(max(max_diff1, max_diff2), expected_max_diff)
|
||||
|
||||
def test_inference_with_autoencoder_dc(self):
|
||||
"""Test PhotonPipeline with AutoencoderDC (DCAE) instead of AutoencoderKL."""
|
||||
device = "cpu"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
|
||||
torch.manual_seed(0)
|
||||
vae_dc = AutoencoderDC(
|
||||
in_channels=3,
|
||||
latent_channels=4,
|
||||
attention_head_dim=2,
|
||||
encoder_block_types=(
|
||||
"ResBlock",
|
||||
"EfficientViTBlock",
|
||||
),
|
||||
decoder_block_types=(
|
||||
"ResBlock",
|
||||
"EfficientViTBlock",
|
||||
),
|
||||
encoder_block_out_channels=(8, 8),
|
||||
decoder_block_out_channels=(8, 8),
|
||||
encoder_qkv_multiscales=((), (5,)),
|
||||
decoder_qkv_multiscales=((), (5,)),
|
||||
encoder_layers_per_block=(1, 1),
|
||||
decoder_layers_per_block=(1, 1),
|
||||
upsample_block_type="interpolate",
|
||||
downsample_block_type="stride_conv",
|
||||
decoder_norm_types="rms_norm",
|
||||
decoder_act_fns="silu",
|
||||
).eval()
|
||||
|
||||
components["vae"] = vae_dc
|
||||
|
||||
pipe = PhotonPipeline(**components)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
expected_scale_factor = vae_dc.spatial_compression_ratio
|
||||
self.assertEqual(pipe.vae_scale_factor, expected_scale_factor)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = pipe(**inputs)[0]
|
||||
generated_image = image[0]
|
||||
|
||||
self.assertEqual(generated_image.shape, (3, 32, 32))
|
||||
expected_image = torch.zeros(3, 32, 32)
|
||||
max_diff = np.abs(generated_image - expected_image).max()
|
||||
self.assertLessEqual(max_diff, 1e10)