Compare commits

...

50 Commits

Author SHA1 Message Date
sayakpaul
53a2a7aff5 fix nits. 2025-10-21 04:31:48 -10:00
David Bertoin
8de7b9247a Update tests/pipelines/photon/test_pipeline_photon.py
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
2025-10-21 07:29:30 +00:00
David Bertoin
5c54baacb7 Update tests/pipelines/photon/test_pipeline_photon.py
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
2025-10-21 07:29:30 +00:00
David Bertoin
9e8279e1fe restrict the version of transformers
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
2025-10-21 07:29:30 +00:00
David Bertoin
aed1f19396 Use Tuple instead of tuple
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
2025-10-21 07:29:30 +00:00
DavidBert
1354f450e6 make fix-copies 2025-10-21 07:29:30 +00:00
David Bertoin
fdc8e34533 Add PhotonTransformer2DModel to TYPE_CHECKING imports 2025-10-21 07:29:30 +00:00
David Bertoin
d5ffd35d70 Update docs/source/en/api/pipelines/photon.md
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
2025-10-21 07:29:30 +00:00
David Bertoin
adeb45e0b3 make fix copy 2025-10-21 07:29:30 +00:00
DavidBert
7d12474c24 naming changes 2025-10-21 07:29:30 +00:00
DavidBert
0ef0dc6837 use dispatch_attention_fn for multiple attention backend support 2025-10-21 07:29:30 +00:00
David Bertoin
836cd12a18 Update docs/source/en/api/pipelines/photon.md
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2025-10-21 07:29:30 +00:00
David Bertoin
caf64407cb Update docs/source/en/api/pipelines/photon.md
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2025-10-21 07:29:30 +00:00
David Bertoin
c469a7a916 Update docs/source/en/api/pipelines/photon.md
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2025-10-21 07:29:30 +00:00
DavidBert
9aa47ce6c3 added doc to toctree 2025-10-21 07:29:30 +00:00
DavidBert
a8fa52ba2a quantization example 2025-10-21 07:27:15 +00:00
David Bertoin
34a74928ac Update docs/source/en/api/pipelines/photon.md
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2025-10-21 07:27:15 +00:00
DavidBert
0fdfd27ecb renaming and remove unecessary attributes setting 2025-10-21 07:27:15 +00:00
DavidBert
574f8fd10a parameter names match the standard diffusers conventions 2025-10-21 07:27:15 +00:00
DavidBert
6e05172682 Revert accidental .gitignore change 2025-10-21 07:27:15 +00:00
DavidBert
d0c029f15d built-in RMSNorm 2025-10-21 07:27:15 +00:00
DavidBert
015774399e Refactor PhotonAttention to match Flux pattern 2025-10-21 07:27:14 +00:00
DavidBert
5f99168def utility function that determines the default resolution given the VAE 2025-10-21 07:27:14 +00:00
DavidBert
c329c8f667 add pipeline test + corresponding fixes 2025-10-21 07:27:14 +00:00
David Bertoin
bb36735379 make quality + style 2025-10-21 07:27:14 +00:00
David Bertoin
8ee17d20b3 Use _import_structure for lazy loading 2025-10-21 07:27:14 +00:00
David Bertoin
be1d14658e add negative prompts 2025-10-21 07:27:14 +00:00
David Bertoin
c951adef45 move xattention conditionning out computation out of the denoising loop 2025-10-21 07:27:14 +00:00
David Bertoin
a74e0b726a support prompt_embeds in call 2025-10-21 07:27:14 +00:00
David Bertoin
ffe3501c1c rename vae_spatial_compression_ratio for vae_scale_factor 2025-10-21 07:27:14 +00:00
David Bertoin
2077252947 remove lora related code 2025-10-21 07:27:14 +00:00
David Bertoin
de1ceaf07a renam LastLayer for FinalLayer 2025-10-21 07:27:14 +00:00
David Bertoin
3c60c9230e put _attn_forward and _ffn_forward logic in PhotonBlock's forward 2025-10-21 07:27:14 +00:00
David Bertoin
af8882d7e6 remove modulation dataclass 2025-10-21 07:27:14 +00:00
David Bertoin
5f0bf0181f Rename EmbedND for PhotoEmbedND 2025-10-21 07:27:14 +00:00
davidb
ae44d845b6 remove lora support from doc 2025-10-21 07:27:14 +00:00
davidb
12dbabe607 fix timestep shift 2025-10-21 07:27:14 +00:00
davidb
ec70e3fdc0 fix T5Gemma loading from hub 2025-10-21 07:27:14 +00:00
davidb
6634113ef6 update doc 2025-10-21 07:27:14 +00:00
davidb
b07d1c8799 update doc 2025-10-21 07:27:14 +00:00
davidb
a9e301366a unify the structure of the forward block 2025-10-21 07:27:14 +00:00
davidb
5886925346 remove einops dependency and now inherits from AttentionMixin 2025-10-21 07:27:14 +00:00
davidb
25a0061d65 move PhotonAttnProcessor2_0 in transformer_photon 2025-10-21 07:27:14 +00:00
davidb
6284b9d062 remove enhance vae and use vae.config directly when possible 2025-10-21 07:27:14 +00:00
davidb
60d918d79b conditioned CFG 2025-10-21 07:27:14 +00:00
David Briand
b327b36ad9 BF16 example 2025-10-21 07:27:14 +00:00
davidb
14903ee599 remove autocast for text encoder forwad 2025-10-21 07:27:14 +00:00
davidb
d71ddd0079 enhance_vae_properties if vae is provided only 2025-10-21 07:27:14 +00:00
davidb
6a66fbd2c4 just store the T5Gemma encoder 2025-10-21 07:27:13 +00:00
davidb
e487660e05 Add Photon model and pipeline support
This commit adds support for the Photon image generation model:
- PhotonTransformer2DModel: Core transformer architecture
- PhotonPipeline: Text-to-image generation pipeline
- Attention processor updates for Photon-specific attention mechanism
- Conversion script for loading Photon checkpoints
- Documentation and tests
2025-10-21 07:27:13 +00:00
16 changed files with 2499 additions and 0 deletions

View File

@@ -544,6 +544,8 @@
title: PAG
- local: api/pipelines/paint_by_example
title: Paint by Example
- local: api/pipelines/photon
title: Photon
- local: api/pipelines/pixart
title: PixArt-α
- local: api/pipelines/pixart_sigma

View File

@@ -0,0 +1,131 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->
# Photon
Photon generates high-quality images from text using a simplified MMDiT architecture in which text tokens are not updated through the transformer blocks. It employs flow matching with discrete scheduling for efficient sampling and uses Google's T5Gemma-2B-2B-UL2 model for multilingual text encoding. The ~1.3B parameter transformer delivers fast inference without sacrificing quality. You can choose between the Flux VAE (8x compression, 16 latent channels) for balanced quality and speed or DC-AE (32x compression, 32 latent channels) for more aggressive latent compression and faster processing.
## Available models
Photon offers multiple variants with different VAE configurations, each optimized for specific resolutions. Base models excel with detailed prompts, capturing complex compositions and subtle details. Fine-tuned models trained on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) improve aesthetic quality, especially with simpler prompts.
| Model | Resolution | Fine-tuned | Distilled | Description | Suggested prompts | Suggested parameters | Recommended dtype |
|:-----:|:-----------------:|:----------:|:----------:|:----------:|:----------:|:----------:|:----------:|
| [`Photoroom/photon-256-t2i`](https://huggingface.co/Photoroom/photon-256-t2i)| 256 | No | No | Base model pre-trained at 256 with Flux VAE|Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/photon-256-t2i-sft`](https://huggingface.co/Photoroom/photon-256-t2i-sft)| 256 | Yes | No | Fine-tuned on the [Alchemist](https://huggingface.co/datasets/yandex/alchemist) dataset with Flux VAE | Can handle less detailed prompts|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/photon-512-t2i`](https://huggingface.co/Photoroom/photon-512-t2i)| 512 | No | No | Base model pre-trained at 512 with Flux VAE |Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/photon-512-t2i-sft`](https://huggingface.co/Photoroom/photon-512-t2i-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist](https://huggingface.co/datasets/yandex/alchemist) dataset with Flux VAE | Can handle less detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/photon-512-t2i-sft-distilled`](https://huggingface.co/Photoroom/photon-512-t2i-sft-distilled)| 512 | Yes | Yes | 8-step distilled model from [`Photoroom/photon-512-t2i-sft`](https://huggingface.co/Photoroom/photon-512-t2i-sft) | Can handle less detailed prompts in natural language|8 steps, cfg=1.0| `torch.bfloat16` |
| [`Photoroom/photon-512-t2i-dc-ae`](https://huggingface.co/Photoroom/photon-512-t2i-dc-ae)| 512 | No | No | Base model pre-trained at 512 with [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae)|Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/photon-512-t2i-dc-ae-sft`](https://huggingface.co/Photoroom/photon-512-t2i-dc-ae-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist](https://huggingface.co/datasets/yandex/alchemist) dataset with [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae) | Can handle less detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/photon-512-t2i-dc-ae-sft-distilled`](https://huggingface.co/Photoroom/photon-512-t2i-dc-ae-sft-distilled)| 512 | Yes | Yes | 8-step distilled model from [`Photoroom/photon-512-t2i-dc-ae-sft`](https://huggingface.co/Photoroom/photon-512-t2i-dc-ae-sft) | Can handle less detailed prompts in natural language|8 steps, cfg=1.0| `torch.bfloat16` |
Refer to [this collection](https://huggingface.co/collections/Photoroom/photon-models-68e66254c202ebfab99ad38e) for more information.
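The distilled checkpoints are meant to be sampled with the parameters suggested in the table above (8 steps, `guidance_scale=1.0`, i.e. classifier-free guidance disabled). A minimal sketch using those settings:
```py
import torch
from diffusers import PhotonPipeline

pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i-sft-distilled", torch_dtype=torch.bfloat16)
pipe.to("cuda")
# 8 steps and cfg=1.0 are the suggested parameters for the distilled checkpoints
image = pipe("a cozy cabin in a snowy forest at dusk", num_inference_steps=8, guidance_scale=1.0).images[0]
```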
## Loading the pipeline
Load the pipeline with [`~DiffusionPipeline.from_pretrained`].
```py
import torch
from diffusers import PhotonPipeline
# Load the pipeline; the VAE and text encoder are downloaded from the Hugging Face Hub
pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i-sft", torch_dtype=torch.bfloat16)
pipe.to("cuda")
prompt = "A front-facing portrait of a lion in the golden savanna at sunset."
image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0]
image.save("photon_output.png")
```
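The pipeline also supports negative prompts (added in one of the commits above). A short sketch, assuming the argument follows the usual diffusers `negative_prompt` convention:
```py
# Assumption: the standard diffusers `negative_prompt` argument name
image = pipe(
    prompt,
    negative_prompt="low quality, blurry, deformed",
    num_inference_steps=28,
    guidance_scale=5.0,
).images[0]
```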
### Manual Component Loading
Load components individually to customize the pipeline, for instance to use quantized models.
```py
import torch
from diffusers.pipelines.photon import PhotonPipeline
from diffusers.models import AutoencoderKL, AutoencoderDC
from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from transformers import T5GemmaModel, GemmaTokenizerFast
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
# diffusers models (transformer, VAE) take the diffusers quantization config;
# transformers models (text encoder) take the transformers one
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
text_encoder_quant_config = TransformersBitsAndBytesConfig(load_in_8bit=True)
# Load transformer
transformer = PhotonTransformer2DModel.from_pretrained(
"checkpoints/photon-512-t2i-sft",
subfolder="transformer",
quantization_config=quant_config,
torch_dtype=torch.bfloat16,
)
# Load scheduler
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
"checkpoints/photon-512-t2i-sft", subfolder="scheduler"
)
# Load T5Gemma text encoder
t5gemma_model = T5GemmaModel.from_pretrained("google/t5gemma-2b-2b-ul2",
quantization_config=quant_config,
torch_dtype=torch.bfloat16)
text_encoder = t5gemma_model.encoder.to(dtype=torch.bfloat16)
tokenizer = GemmaTokenizerFast.from_pretrained("google/t5gemma-2b-2b-ul2")
tokenizer.model_max_length = 256
# Load VAE - choose either Flux VAE or DC-AE
# Flux VAE
vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev",
subfolder="vae",
quantization_config=quant_config,
torch_dtype=torch.bfloat16)
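# DC-AE alternative for the *-dc-ae checkpoints (assumed repo id: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers)
# vae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers",
#                                     quantization_config=quant_config,
#                                     torch_dtype=torch.bfloat16)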
pipe = PhotonPipeline(
transformer=transformer,
scheduler=scheduler,
text_encoder=text_encoder,
tokenizer=tokenizer,
vae=vae
)
pipe.to("cuda")
```
## Memory Optimization
For memory-constrained environments:
```py
import torch
from diffusers.pipelines.photon import PhotonPipeline
pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i-sft", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload() # Offload components to CPU when not in use
# Or use sequential CPU offload for even lower memory
pipe.enable_sequential_cpu_offload()
```
## PhotonPipeline
[[autodoc]] PhotonPipeline
- all
- __call__
## PhotonPipelineOutput
[[autodoc]] pipelines.photon.pipeline_output.PhotonPipelineOutput

View File

@@ -0,0 +1,345 @@
#!/usr/bin/env python3
"""
Script to convert Photon checkpoint from original codebase to diffusers format.
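Example invocation (the script filename here is illustrative; all flags are defined by the argparse parser below):
    python convert_photon_to_diffusers.py \
        --checkpoint_path /path/to/photon_checkpoint.pth \
        --output_path ./photon-512-t2i-diffusers \
        --vae_type flux \
        --resolution 512 \
        --shift 3.0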
"""
import argparse
import json
import os
import sys
from dataclasses import asdict, dataclass
from typing import Dict, Tuple
import torch
from safetensors.torch import save_file
from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel
from diffusers.pipelines.photon import PhotonPipeline
DEFAULT_RESOLUTION = 512
@dataclass(frozen=True)
class PhotonBase:
context_in_dim: int = 2304
hidden_size: int = 1792
mlp_ratio: float = 3.5
num_heads: int = 28
depth: int = 16
axes_dim: Tuple[int, int] = (32, 32)
theta: int = 10_000
time_factor: float = 1000.0
time_max_period: int = 10_000
@dataclass(frozen=True)
class PhotonFlux(PhotonBase):
in_channels: int = 16
patch_size: int = 2
@dataclass(frozen=True)
class PhotonDCAE(PhotonBase):
in_channels: int = 32
patch_size: int = 1
def build_config(vae_type: str) -> dict:
if vae_type == "flux":
cfg = PhotonFlux()
elif vae_type == "dc-ae":
cfg = PhotonDCAE()
else:
raise ValueError(f"Unsupported VAE type: {vae_type}. Use 'flux' or 'dc-ae'")
config_dict = asdict(cfg)
config_dict["axes_dim"] = list(config_dict["axes_dim"]) # type: ignore[index]
return config_dict
def create_parameter_mapping(depth: int) -> dict:
"""Create mapping from old parameter names to new diffusers names."""
# Key mappings for structural changes
mapping = {}
# Map old structure (layers in PhotonBlock) to new structure (layers in PhotonAttention)
for i in range(depth):
# QKV projections moved to attention module
mapping[f"blocks.{i}.img_qkv_proj.weight"] = f"blocks.{i}.attention.img_qkv_proj.weight"
mapping[f"blocks.{i}.txt_kv_proj.weight"] = f"blocks.{i}.attention.txt_kv_proj.weight"
# QK norm moved to attention module and renamed to match Attention's qk_norm structure
mapping[f"blocks.{i}.qk_norm.query_norm.scale"] = f"blocks.{i}.attention.norm_q.weight"
mapping[f"blocks.{i}.qk_norm.key_norm.scale"] = f"blocks.{i}.attention.norm_k.weight"
mapping[f"blocks.{i}.qk_norm.query_norm.weight"] = f"blocks.{i}.attention.norm_q.weight"
mapping[f"blocks.{i}.qk_norm.key_norm.weight"] = f"blocks.{i}.attention.norm_k.weight"
# K norm for text tokens moved to attention module
mapping[f"blocks.{i}.k_norm.scale"] = f"blocks.{i}.attention.norm_added_k.weight"
mapping[f"blocks.{i}.k_norm.weight"] = f"blocks.{i}.attention.norm_added_k.weight"
# Attention output projection
mapping[f"blocks.{i}.attn_out.weight"] = f"blocks.{i}.attention.to_out.0.weight"
return mapping
def convert_checkpoint_parameters(old_state_dict: Dict[str, torch.Tensor], depth: int) -> Dict[str, torch.Tensor]:
"""Convert old checkpoint parameters to new diffusers format."""
print("Converting checkpoint parameters...")
mapping = create_parameter_mapping(depth)
converted_state_dict = {}
for key, value in old_state_dict.items():
new_key = key
# Apply specific mappings if needed
if key in mapping:
new_key = mapping[key]
print(f" Mapped: {key} -> {new_key}")
converted_state_dict[new_key] = value
print(f"✓ Converted {len(converted_state_dict)} parameters")
return converted_state_dict
def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> PhotonTransformer2DModel:
"""Create and load PhotonTransformer2DModel from old checkpoint."""
print(f"Loading checkpoint from: {checkpoint_path}")
# Load old checkpoint
if not os.path.exists(checkpoint_path):
raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
old_checkpoint = torch.load(checkpoint_path, map_location="cpu")
# Handle different checkpoint formats
if isinstance(old_checkpoint, dict):
if "model" in old_checkpoint:
state_dict = old_checkpoint["model"]
elif "state_dict" in old_checkpoint:
state_dict = old_checkpoint["state_dict"]
else:
state_dict = old_checkpoint
else:
state_dict = old_checkpoint
print(f"✓ Loaded checkpoint with {len(state_dict)} parameters")
# Convert parameter names if needed
model_depth = int(config.get("depth", 16))
converted_state_dict = convert_checkpoint_parameters(state_dict, depth=model_depth)
# Create transformer with config
print("Creating PhotonTransformer2DModel...")
transformer = PhotonTransformer2DModel(**config)
# Load state dict
print("Loading converted parameters...")
missing_keys, unexpected_keys = transformer.load_state_dict(converted_state_dict, strict=False)
if missing_keys:
print(f"⚠ Missing keys: {missing_keys}")
if unexpected_keys:
print(f"⚠ Unexpected keys: {unexpected_keys}")
if not missing_keys and not unexpected_keys:
print("✓ All parameters loaded successfully!")
return transformer
def create_scheduler_config(output_path: str, shift: float):
"""Create FlowMatchEulerDiscreteScheduler config."""
scheduler_config = {"_class_name": "FlowMatchEulerDiscreteScheduler", "num_train_timesteps": 1000, "shift": shift}
scheduler_path = os.path.join(output_path, "scheduler")
os.makedirs(scheduler_path, exist_ok=True)
with open(os.path.join(scheduler_path, "scheduler_config.json"), "w") as f:
json.dump(scheduler_config, f, indent=2)
print("✓ Created scheduler config")
def download_and_save_vae(vae_type: str, output_path: str):
"""Download and save VAE to local directory."""
from diffusers import AutoencoderDC, AutoencoderKL
vae_path = os.path.join(output_path, "vae")
os.makedirs(vae_path, exist_ok=True)
if vae_type == "flux":
print("Downloading FLUX VAE from black-forest-labs/FLUX.1-dev...")
vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae")
else: # dc-ae
print("Downloading DC-AE VAE from mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers...")
vae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers")
vae.save_pretrained(vae_path)
print(f"✓ Saved VAE to {vae_path}")
def download_and_save_text_encoder(output_path: str):
"""Download and save T5Gemma text encoder and tokenizer."""
from transformers import GemmaTokenizerFast
from transformers.models.t5gemma.modeling_t5gemma import T5GemmaModel
text_encoder_path = os.path.join(output_path, "text_encoder")
tokenizer_path = os.path.join(output_path, "tokenizer")
os.makedirs(text_encoder_path, exist_ok=True)
os.makedirs(tokenizer_path, exist_ok=True)
print("Downloading T5Gemma model from google/t5gemma-2b-2b-ul2...")
t5gemma_model = T5GemmaModel.from_pretrained("google/t5gemma-2b-2b-ul2")
# Extract and save only the encoder
t5gemma_encoder = t5gemma_model.encoder
t5gemma_encoder.save_pretrained(text_encoder_path)
print(f"✓ Saved T5GemmaEncoder to {text_encoder_path}")
print("Downloading tokenizer from google/t5gemma-2b-2b-ul2...")
tokenizer = GemmaTokenizerFast.from_pretrained("google/t5gemma-2b-2b-ul2")
tokenizer.model_max_length = 256
tokenizer.save_pretrained(tokenizer_path)
print(f"✓ Saved tokenizer to {tokenizer_path}")
def create_model_index(vae_type: str, default_image_size: int, output_path: str):
"""Create model_index.json for the pipeline."""
if vae_type == "flux":
vae_class = "AutoencoderKL"
else: # dc-ae
vae_class = "AutoencoderDC"
model_index = {
"_class_name": "PhotonPipeline",
"_diffusers_version": "0.31.0.dev0",
"_name_or_path": os.path.basename(output_path),
"default_sample_size": default_image_size,
"scheduler": ["diffusers", "FlowMatchEulerDiscreteScheduler"],
"text_encoder": ["photon", "T5GemmaEncoder"],
"tokenizer": ["transformers", "GemmaTokenizerFast"],
"transformer": ["diffusers", "PhotonTransformer2DModel"],
"vae": ["diffusers", vae_class],
}
model_index_path = os.path.join(output_path, "model_index.json")
with open(model_index_path, "w") as f:
json.dump(model_index, f, indent=2)
def main(args):
# Validate inputs
if not os.path.exists(args.checkpoint_path):
raise FileNotFoundError(f"Checkpoint not found: {args.checkpoint_path}")
config = build_config(args.vae_type)
# Create output directory
os.makedirs(args.output_path, exist_ok=True)
print(f"✓ Output directory: {args.output_path}")
# Create transformer from checkpoint
transformer = create_transformer_from_checkpoint(args.checkpoint_path, config)
# Save transformer
transformer_path = os.path.join(args.output_path, "transformer")
os.makedirs(transformer_path, exist_ok=True)
# Save config
with open(os.path.join(transformer_path, "config.json"), "w") as f:
json.dump(config, f, indent=2)
# Save model weights as safetensors
state_dict = transformer.state_dict()
save_file(state_dict, os.path.join(transformer_path, "diffusion_pytorch_model.safetensors"))
print(f"✓ Saved transformer to {transformer_path}")
# Create scheduler config
create_scheduler_config(args.output_path, args.shift)
download_and_save_vae(args.vae_type, args.output_path)
download_and_save_text_encoder(args.output_path)
# Create model_index.json
create_model_index(args.vae_type, args.resolution, args.output_path)
# Verify the pipeline can be loaded
try:
pipeline = PhotonPipeline.from_pretrained(args.output_path)
print("Pipeline loaded successfully!")
print(f"Transformer: {type(pipeline.transformer).__name__}")
print(f"VAE: {type(pipeline.vae).__name__}")
print(f"Text Encoder: {type(pipeline.text_encoder).__name__}")
print(f"Scheduler: {type(pipeline.scheduler).__name__}")
# Display model info
num_params = sum(p.numel() for p in pipeline.transformer.parameters())
print(f"✓ Transformer parameters: {num_params:,}")
except Exception as e:
print(f"Pipeline verification failed: {e}")
return False
print("Conversion completed successfully!")
print(f"Converted pipeline saved to: {args.output_path}")
print(f"VAE type: {args.vae_type}")
return True
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert Photon checkpoint to diffusers format")
parser.add_argument(
"--checkpoint_path", type=str, required=True, help="Path to the original Photon checkpoint (.pth file )"
)
parser.add_argument(
"--output_path", type=str, required=True, help="Output directory for the converted diffusers pipeline"
)
parser.add_argument(
"--vae_type",
type=str,
choices=["flux", "dc-ae"],
required=True,
help="VAE type to use: 'flux' for AutoencoderKL (16 channels) or 'dc-ae' for AutoencoderDC (32 channels)",
)
parser.add_argument(
"--resolution",
type=int,
choices=[256, 512, 1024],
default=DEFAULT_RESOLUTION,
help="Target resolution for the model (256, 512, or 1024). Affects the transformer's sample_size.",
)
parser.add_argument(
"--shift",
type=float,
default=3.0,
help="Shift for the scheduler",
)
args = parser.parse_args()
try:
success = main(args)
if not success:
sys.exit(1)
except Exception as e:
print(f"Conversion failed: {e}")
import traceback
traceback.print_exc()
sys.exit(1)

View File

@@ -232,6 +232,7 @@ else:
"MultiControlNetModel",
"OmniGenTransformer2DModel",
"ParallelConfig",
"PhotonTransformer2DModel",
"PixArtTransformer2DModel",
"PriorTransformer",
"QwenImageControlNetModel",
@@ -515,6 +516,7 @@ else:
"MusicLDMPipeline",
"OmniGenPipeline",
"PaintByExamplePipeline",
"PhotonPipeline",
"PIAPipeline",
"PixArtAlphaPipeline",
"PixArtSigmaPAGPipeline",
@@ -926,6 +928,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
MultiControlNetModel,
OmniGenTransformer2DModel,
ParallelConfig,
PhotonTransformer2DModel,
PixArtTransformer2DModel,
PriorTransformer,
QwenImageControlNetModel,
@@ -1179,6 +1182,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
MusicLDMPipeline,
OmniGenPipeline,
PaintByExamplePipeline,
PhotonPipeline,
PIAPipeline,
PixArtAlphaPipeline,
PixArtSigmaPAGPipeline,

View File

@@ -96,6 +96,7 @@ if is_torch_available():
_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
_import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
_import_structure["transformers.transformer_photon"] = ["PhotonTransformer2DModel"]
_import_structure["transformers.transformer_qwenimage"] = ["QwenImageTransformer2DModel"]
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
_import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"]
@@ -190,6 +191,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LuminaNextDiT2DModel,
MochiTransformer3DModel,
OmniGenTransformer2DModel,
PhotonTransformer2DModel,
PixArtTransformer2DModel,
PriorTransformer,
QwenImageTransformer2DModel,

View File

@@ -32,6 +32,7 @@ if is_torch_available():
from .transformer_lumina2 import Lumina2Transformer2DModel
from .transformer_mochi import MochiTransformer3DModel
from .transformer_omnigen import OmniGenTransformer2DModel
from .transformer_photon import PhotonTransformer2DModel
from .transformer_qwenimage import QwenImageTransformer2DModel
from .transformer_sd3 import SD3Transformer2DModel
from .transformer_skyreels_v2 import SkyReelsV2Transformer3DModel

View File

@@ -0,0 +1,768 @@
# Copyright 2025 The Photoroom and The HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from torch import Tensor, nn
from torch.nn.functional import fold, unfold
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ..attention import AttentionMixin, AttentionModuleMixin
from ..attention_dispatch import dispatch_attention_fn
from ..embeddings import get_timestep_embedding
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import RMSNorm
logger = logging.get_logger(__name__)
def get_image_ids(batch_size: int, height: int, width: int, patch_size: int, device: torch.device) -> torch.Tensor:
r"""
Generates 2D patch coordinate indices for a batch of images.
Args:
batch_size (`int`):
Number of images in the batch.
height (`int`):
Height of the input images (in pixels).
width (`int`):
Width of the input images (in pixels).
patch_size (`int`):
Size of the square patches that the image is divided into.
device (`torch.device`):
The device on which to create the tensor.
Returns:
`torch.Tensor`:
Tensor of shape `(batch_size, num_patches, 2)` containing the (row, col) coordinates of each patch in the
image grid.
"""
img_ids = torch.zeros(height // patch_size, width // patch_size, 2, device=device)
img_ids[..., 0] = torch.arange(height // patch_size, device=device)[:, None]
img_ids[..., 1] = torch.arange(width // patch_size, device=device)[None, :]
return img_ids.reshape((height // patch_size) * (width // patch_size), 2).unsqueeze(0).repeat(batch_size, 1, 1)
def apply_rope(xq: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
r"""
Applies rotary positional embeddings (RoPE) to a query tensor.
Args:
xq (`torch.Tensor`):
Input tensor of shape `(..., dim)` representing the queries.
freqs_cis (`torch.Tensor`):
Precomputed rotary frequency components of shape `(..., dim/2, 2)` containing cosine and sine pairs.
Returns:
`torch.Tensor`:
Tensor of the same shape as `xq` with rotary embeddings applied.
"""
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
# Ensure freqs_cis is on the same device as queries to avoid device mismatches with offloading
freqs_cis = freqs_cis.to(device=xq.device, dtype=xq_.dtype)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq)
class PhotonAttnProcessor2_0:
r"""
Processor for implementing Photon-style attention with multi-source tokens and RoPE. Supports multiple attention
backends (Flash Attention, Sage Attention, etc.) via dispatch_attention_fn.
"""
_attention_backend = None
_parallel_config = None
def __init__(self):
if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
raise ImportError("PhotonAttnProcessor2_0 requires PyTorch 2.0, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: "PhotonAttention",
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
"""
Apply Photon attention using PhotonAttention module.
Args:
attn: PhotonAttention module containing projection layers
hidden_states: Image tokens [B, L_img, D]
encoder_hidden_states: Text tokens [B, L_txt, D]
attention_mask: Boolean mask for text tokens [B, L_txt]
image_rotary_emb: Rotary positional embeddings [B, 1, L_img, head_dim//2, 2, 2]
"""
if encoder_hidden_states is None:
raise ValueError("PhotonAttnProcessor2_0 requires 'encoder_hidden_states' containing text tokens.")
# Project image tokens to Q, K, V
img_qkv = attn.img_qkv_proj(hidden_states)
B, L_img, _ = img_qkv.shape
img_qkv = img_qkv.reshape(B, L_img, 3, attn.heads, attn.head_dim)
img_qkv = img_qkv.permute(2, 0, 3, 1, 4) # [3, B, H, L_img, D]
img_q, img_k, img_v = img_qkv[0], img_qkv[1], img_qkv[2]
# Apply QK normalization to image tokens
img_q = attn.norm_q(img_q)
img_k = attn.norm_k(img_k)
# Project text tokens to K, V
txt_kv = attn.txt_kv_proj(encoder_hidden_states)
B, L_txt, _ = txt_kv.shape
txt_kv = txt_kv.reshape(B, L_txt, 2, attn.heads, attn.head_dim)
txt_kv = txt_kv.permute(2, 0, 3, 1, 4) # [2, B, H, L_txt, D]
txt_k, txt_v = txt_kv[0], txt_kv[1]
# Apply K normalization to text tokens
txt_k = attn.norm_added_k(txt_k)
# Apply RoPE to image queries and keys
if image_rotary_emb is not None:
img_q = apply_rope(img_q, image_rotary_emb)
img_k = apply_rope(img_k, image_rotary_emb)
# Concatenate text and image keys/values
k = torch.cat((txt_k, img_k), dim=2) # [B, H, L_txt + L_img, D]
v = torch.cat((txt_v, img_v), dim=2) # [B, H, L_txt + L_img, D]
# Build attention mask if provided
attn_mask_tensor = None
if attention_mask is not None:
bs, _, l_img, _ = img_q.shape
l_txt = txt_k.shape[2]
if attention_mask.dim() != 2:
raise ValueError(f"Unsupported attention_mask shape: {attention_mask.shape}")
if attention_mask.shape[-1] != l_txt:
raise ValueError(f"attention_mask last dim {attention_mask.shape[-1]} must equal text length {l_txt}")
device = img_q.device
ones_img = torch.ones((bs, l_img), dtype=torch.bool, device=device)
attention_mask = attention_mask.to(device=device, dtype=torch.bool)
joint_mask = torch.cat([attention_mask, ones_img], dim=-1)
attn_mask_tensor = joint_mask[:, None, None, :].expand(-1, attn.heads, l_img, -1)
# Apply attention using dispatch_attention_fn for backend support
# Reshape to match dispatch_attention_fn expectations: [B, L, H, D]
query = img_q.transpose(1, 2) # [B, L_img, H, D]
key = k.transpose(1, 2) # [B, L_txt + L_img, H, D]
value = v.transpose(1, 2) # [B, L_txt + L_img, H, D]
attn_output = dispatch_attention_fn(
query,
key,
value,
attn_mask=attn_mask_tensor,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
# Reshape from [B, L_img, H, D] to [B, L_img, H*D]
batch_size, seq_len, num_heads, head_dim = attn_output.shape
attn_output = attn_output.reshape(batch_size, seq_len, num_heads * head_dim)
# Apply output projection
attn_output = attn.to_out[0](attn_output)
if len(attn.to_out) > 1:
attn_output = attn.to_out[1](attn_output) # dropout if present
return attn_output
class PhotonAttention(nn.Module, AttentionModuleMixin):
r"""
Photon-style attention module that handles multi-source tokens and RoPE. Similar to FluxAttention but adapted for
Photon's architecture.
"""
_default_processor_cls = PhotonAttnProcessor2_0
_available_processors = [PhotonAttnProcessor2_0]
def __init__(
self,
query_dim: int,
heads: int = 8,
dim_head: int = 64,
bias: bool = False,
out_bias: bool = False,
eps: float = 1e-6,
processor=None,
):
super().__init__()
self.heads = heads
self.head_dim = dim_head
self.inner_dim = dim_head * heads
self.query_dim = query_dim
self.img_qkv_proj = nn.Linear(query_dim, query_dim * 3, bias=bias)
self.norm_q = RMSNorm(self.head_dim, eps=eps, elementwise_affine=True)
self.norm_k = RMSNorm(self.head_dim, eps=eps, elementwise_affine=True)
self.txt_kv_proj = nn.Linear(query_dim, query_dim * 2, bias=bias)
self.norm_added_k = RMSNorm(self.head_dim, eps=eps, elementwise_affine=True)
self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(self.inner_dim, query_dim, bias=out_bias))
self.to_out.append(nn.Dropout(0.0))
if processor is None:
processor = self._default_processor_cls()
self.set_processor(processor)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
return self.processor(
self,
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
**kwargs,
)
# inspired from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
class PhotonEmbedND(nn.Module):
r"""
N-dimensional rotary positional embedding.
This module creates rotary embeddings (RoPE) across multiple axes, where each axis can have its own embedding
dimension. The embeddings are combined and returned as a single tensor.
Args:
dim (int):
Base embedding dimension (must be even).
theta (int):
Scaling factor that controls the frequency spectrum of the rotary embeddings.
axes_dim (list[int]):
List of embedding dimensions for each axis (each must be even).
"""
def __init__(self, dim: int, theta: int, axes_dim: List[int]):
super().__init__()
self.dim = dim
self.theta = theta
self.axes_dim = axes_dim
def rope(self, pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
assert dim % 2 == 0
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
omega = 1.0 / (theta**scale)
out = pos.unsqueeze(-1) * omega.unsqueeze(0)
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
# Native PyTorch equivalent of: Rearrange("b n d (i j) -> b n d i j", i=2, j=2)
# out shape: (b, n, d, 4) -> reshape to (b, n, d, 2, 2)
out = out.reshape(*out.shape[:-1], 2, 2)
return out.float()
def forward(self, ids: torch.Tensor) -> torch.Tensor:
n_axes = ids.shape[-1]
emb = torch.cat(
[self.rope(ids[:, :, i], self.axes_dim[i], self.theta) for i in range(n_axes)],
dim=-3,
)
return emb.unsqueeze(1)
class MLPEmbedder(nn.Module):
r"""
A simple 2-layer MLP used for embedding inputs.
Args:
in_dim (`int`):
Dimensionality of the input features.
hidden_dim (`int`):
Dimensionality of the hidden and output embedding space.
Returns:
`torch.Tensor`:
Tensor of shape `(..., hidden_dim)` containing the embedded representations.
"""
def __init__(self, in_dim: int, hidden_dim: int):
super().__init__()
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
self.silu = nn.SiLU()
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.out_layer(self.silu(self.in_layer(x)))
class Modulation(nn.Module):
r"""
Modulation network that generates scale, shift, and gating parameters.
Given an input vector, the module projects it through a linear layer to produce six chunks, which are grouped into
two tuples `(shift, scale, gate)`.
Args:
dim (`int`):
Dimensionality of the input vector. The output will have `6 * dim` features internally.
Returns:
((`torch.Tensor`, `torch.Tensor`, `torch.Tensor`), (`torch.Tensor`, `torch.Tensor`, `torch.Tensor`)):
Two tuples `(shift, scale, gate)`.
"""
def __init__(self, dim: int):
super().__init__()
self.lin = nn.Linear(dim, 6 * dim, bias=True)
nn.init.constant_(self.lin.weight, 0)
nn.init.constant_(self.lin.bias, 0)
def forward(self, vec: torch.Tensor) -> Tuple[Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tensor, Tensor]]:
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(6, dim=-1)
return tuple(out[:3]), tuple(out[3:])
class PhotonBlock(nn.Module):
r"""
Multimodal transformer block with text-image cross-attention, modulation, and MLP.
Args:
hidden_size (`int`):
Dimension of the hidden representations.
num_heads (`int`):
Number of attention heads.
mlp_ratio (`float`, *optional*, defaults to 4.0):
Expansion ratio for the hidden dimension inside the MLP.
qk_scale (`float`, *optional*):
Scale factor for queries and keys. If not provided, defaults to ``head_dim**-0.5``.
Attributes:
img_pre_norm (`nn.LayerNorm`):
Pre-normalization applied to image tokens before attention.
attention (`PhotonAttention`):
Multi-head attention module with built-in QKV projections and normalizations for cross-attention between
image and text tokens.
post_attention_layernorm (`nn.LayerNorm`):
Normalization applied after attention.
gate_proj / up_proj / down_proj (`nn.Linear`):
Feedforward layers forming the gated MLP.
mlp_act (`nn.GELU`):
Nonlinear activation used in the MLP.
modulation (`Modulation`):
Produces scale/shift/gating parameters for modulated layers.
Methods:
The forward method performs cross-attention and the MLP with modulation.
"""
def __init__(
self,
hidden_size: int,
num_heads: int,
mlp_ratio: float = 4.0,
qk_scale: Optional[float] = None,
):
super().__init__()
self.hidden_dim = hidden_size
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
self.scale = qk_scale or self.head_dim**-0.5
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.hidden_size = hidden_size
# Pre-attention normalization for image tokens
self.img_pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
# PhotonAttention module with built-in projections and norms
self.attention = PhotonAttention(
query_dim=hidden_size,
heads=num_heads,
dim_head=self.head_dim,
bias=False,
out_bias=False,
eps=1e-6,
processor=PhotonAttnProcessor2_0(),
)
# mlp
self.post_attention_layernorm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.gate_proj = nn.Linear(hidden_size, self.mlp_hidden_dim, bias=False)
self.up_proj = nn.Linear(hidden_size, self.mlp_hidden_dim, bias=False)
self.down_proj = nn.Linear(self.mlp_hidden_dim, hidden_size, bias=False)
self.mlp_act = nn.GELU(approximate="tanh")
self.modulation = Modulation(hidden_size)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: torch.Tensor,
attention_mask: Optional[Tensor] = None,
**kwargs: Dict[str, Any],
) -> torch.Tensor:
r"""
Runs modulation-gated cross-attention and MLP, with residual connections.
Args:
hidden_states (`torch.Tensor`):
Image tokens of shape `(B, L_img, hidden_size)`.
encoder_hidden_states (`torch.Tensor`):
Text tokens of shape `(B, L_txt, hidden_size)`.
temb (`torch.Tensor`):
Conditioning vector used by `Modulation` to produce scale/shift/gates, shape `(B, hidden_size)` (or
broadcastable).
image_rotary_emb (`torch.Tensor`):
Rotary positional embeddings applied inside attention.
attention_mask (`torch.Tensor`, *optional*):
Boolean mask for text tokens of shape `(B, L_txt)`, where `0` marks padding.
**kwargs:
Additional keyword arguments for API compatibility.
Returns:
`torch.Tensor`:
Updated image tokens of shape `(B, L_img, hidden_size)`.
"""
mod_attn, mod_mlp = self.modulation(temb)
attn_shift, attn_scale, attn_gate = mod_attn
mlp_shift, mlp_scale, mlp_gate = mod_mlp
hidden_states_mod = (1 + attn_scale) * self.img_pre_norm(hidden_states) + attn_shift
attn_out = self.attention(
hidden_states=hidden_states_mod,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
)
hidden_states = hidden_states + attn_gate * attn_out
x = (1 + mlp_scale) * self.post_attention_layernorm(hidden_states) + mlp_shift
hidden_states = hidden_states + mlp_gate * (self.down_proj(self.mlp_act(self.gate_proj(x)) * self.up_proj(x)))
return hidden_states
class FinalLayer(nn.Module):
r"""
Final projection layer with adaptive LayerNorm modulation.
This layer applies a normalized and modulated transformation to input tokens and projects them into patch-level
outputs.
Args:
hidden_size (`int`):
Dimensionality of the input tokens.
patch_size (`int`):
Size of the square image patches.
out_channels (`int`):
Number of output channels per pixel (e.g. RGB = 3).
Forward Inputs:
x (`torch.Tensor`):
Input tokens of shape `(B, L, hidden_size)`, where `L` is the number of patches.
vec (`torch.Tensor`):
Conditioning vector of shape `(B, hidden_size)` used to generate shift and scale parameters for adaptive
LayerNorm.
Returns:
`torch.Tensor`:
Projected patch outputs of shape `(B, L, patch_size * patch_size * out_channels)`.
"""
def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
def forward(self, x: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
x = self.linear(x)
return x
def img2seq(img: torch.Tensor, patch_size: int) -> torch.Tensor:
r"""
Flattens an image tensor into a sequence of non-overlapping patches.
Args:
img (`torch.Tensor`):
Input image tensor of shape `(B, C, H, W)`.
patch_size (`int`):
Size of each square patch. Must evenly divide both `H` and `W`.
Returns:
`torch.Tensor`:
Flattened patch sequence of shape `(B, L, C * patch_size * patch_size)`, where `L = (H // patch_size) * (W
// patch_size)` is the number of patches.
"""
return unfold(img, kernel_size=patch_size, stride=patch_size).transpose(1, 2)
def seq2img(seq: torch.Tensor, patch_size: int, shape: Union[Tuple[int, ...], torch.Tensor]) -> torch.Tensor:
r"""
Reconstructs an image tensor from a sequence of patches (inverse of `img2seq`).
Args:
seq (`torch.Tensor`):
Patch sequence of shape `(B, L, C * patch_size * patch_size)`, where `L = (H // patch_size) * (W //
patch_size)`.
patch_size (`int`):
Size of each square patch.
shape (`tuple` or `torch.Tensor`):
The original image spatial shape `(H, W)`. If a tensor is provided, the first two values are interpreted as
height and width.
Returns:
`torch.Tensor`:
Reconstructed image tensor of shape `(B, C, H, W)`.
"""
if isinstance(shape, tuple):
shape = shape[-2:]
elif isinstance(shape, torch.Tensor):
shape = (int(shape[0]), int(shape[1]))
else:
raise NotImplementedError(f"shape type {type(shape)} not supported")
return fold(seq.transpose(1, 2), shape, kernel_size=patch_size, stride=patch_size)
class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
r"""
Transformer-based 2D model for text to image generation.
Args:
in_channels (`int`, *optional*, defaults to 16):
Number of input channels in the latent image.
patch_size (`int`, *optional*, defaults to 2):
Size of the square patches used to flatten the input image.
context_in_dim (`int`, *optional*, defaults to 2304):
Dimensionality of the text conditioning input.
hidden_size (`int`, *optional*, defaults to 1792):
Dimension of the hidden representation.
mlp_ratio (`float`, *optional*, defaults to 3.5):
Expansion ratio for the hidden dimension inside MLP blocks.
num_heads (`int`, *optional*, defaults to 28):
Number of attention heads.
depth (`int`, *optional*, defaults to 16):
Number of transformer blocks.
axes_dim (`list[int]`, *optional*):
List of dimensions for each positional embedding axis. Defaults to `[32, 32]`.
theta (`int`, *optional*, defaults to 10000):
Frequency scaling factor for rotary embeddings.
time_factor (`float`, *optional*, defaults to 1000.0):
Scaling factor applied in timestep embeddings.
time_max_period (`int`, *optional*, defaults to 10000):
Maximum frequency period for timestep embeddings.
Attributes:
pe_embedder (`PhotonEmbedND`):
Multi-axis rotary embedding generator for positional encodings.
img_in (`nn.Linear`):
Projection layer for image patch tokens.
time_in (`MLPEmbedder`):
Embedding layer for timestep embeddings.
txt_in (`nn.Linear`):
Projection layer for text conditioning.
blocks (`nn.ModuleList`):
Stack of transformer blocks (`PhotonBlock`).
final_layer (`FinalLayer`):
Projection layer mapping hidden tokens back to patch outputs.
Methods:
attn_processors:
Returns a dictionary of all attention processors in the model.
set_attn_processor(processor):
Replaces attention processors across all attention layers.
forward(hidden_states, timestep, encoder_hidden_states, attention_mask=None, attention_kwargs=None,
return_dict=True):
Full forward pass from the latent input to the reconstructed output latent.
Returns:
`Transformer2DModelOutput` if `return_dict=True` (default), otherwise a tuple containing:
- `sample` (`torch.Tensor`): Reconstructed image of shape `(B, C, H, W)`.
"""
config_name = "config.json"
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
in_channels: int = 16,
patch_size: int = 2,
context_in_dim: int = 2304,
hidden_size: int = 1792,
mlp_ratio: float = 3.5,
num_heads: int = 28,
depth: int = 16,
axes_dim: Optional[List[int]] = None,
theta: int = 10000,
time_factor: float = 1000.0,
time_max_period: int = 10000,
):
super().__init__()
if axes_dim is None:
axes_dim = [32, 32]
# Store parameters directly
self.in_channels = in_channels
self.patch_size = patch_size
self.out_channels = self.in_channels * self.patch_size**2
self.time_factor = time_factor
self.time_max_period = time_max_period
if hidden_size % num_heads != 0:
raise ValueError(f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}")
pe_dim = hidden_size // num_heads
if sum(axes_dim) != pe_dim:
raise ValueError(f"Got {axes_dim} but expected positional dim {pe_dim}")
self.hidden_size = hidden_size
self.num_heads = num_heads
self.pe_embedder = PhotonEmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim)
self.img_in = nn.Linear(self.in_channels * self.patch_size**2, self.hidden_size, bias=True)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
self.txt_in = nn.Linear(context_in_dim, self.hidden_size)
self.blocks = nn.ModuleList(
[
PhotonBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=mlp_ratio,
)
for i in range(depth)
]
)
self.final_layer = FinalLayer(self.hidden_size, 1, self.out_channels)
self.gradient_checkpointing = False
def _compute_timestep_embedding(self, timestep: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
return self.time_in(
get_timestep_embedding(
timesteps=timestep,
embedding_dim=256,
max_period=self.time_max_period,
scale=self.time_factor,
flip_sin_to_cos=True, # Match original cos, sin order
).to(dtype)
)
def forward(
self,
hidden_states: torch.Tensor,
timestep: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[Tensor] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
r"""
Forward pass of the PhotonTransformer2DModel.
The latent image is split into patch tokens, combined with text conditioning, and processed through a stack of
transformer blocks modulated by the timestep. The output is reconstructed into the latent image space.
Args:
hidden_states (`torch.Tensor`):
Input latent image tensor of shape `(B, C, H, W)`.
timestep (`torch.Tensor`):
Timestep tensor of shape `(B,)` or `(1,)`, used for temporal conditioning.
encoder_hidden_states (`torch.Tensor`):
Text conditioning tensor of shape `(B, L_txt, context_in_dim)`.
attention_mask (`torch.Tensor`, *optional*):
Boolean mask of shape `(B, L_txt)`, where `0` marks padding in the text sequence.
attention_kwargs (`dict`, *optional*):
Additional arguments passed to attention layers.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a `Transformer2DModelOutput` or a tuple.
Returns:
`Transformer2DModelOutput` if `return_dict=True`, otherwise a tuple:
- `sample` (`torch.Tensor`): Output latent image of shape `(B, C, H, W)`.
"""
# Process text conditioning
txt = self.txt_in(encoder_hidden_states)
# Convert image to sequence and embed
img = img2seq(hidden_states, self.patch_size)
img = self.img_in(img)
# Generate positional embeddings
bs, _, h, w = hidden_states.shape
img_ids = get_image_ids(bs, h, w, patch_size=self.patch_size, device=hidden_states.device)
pe = self.pe_embedder(img_ids)
# Compute time embedding
vec = self._compute_timestep_embedding(timestep, dtype=img.dtype)
# Apply transformer blocks
for block in self.blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
img = self._gradient_checkpointing_func(
block.__call__,
img,
txt,
vec,
pe,
attention_mask,
)
else:
img = block(
hidden_states=img,
encoder_hidden_states=txt,
temb=vec,
image_rotary_emb=pe,
attention_mask=attention_mask,
)
# Final layer and convert back to image
img = self.final_layer(img, vec)
output = seq2img(img, self.patch_size, hidden_states.shape)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
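A minimal shape-check sketch for the forward pass documented above, using a deliberately tiny, illustrative configuration (the released checkpoints use the much larger defaults registered in `__init__`):
```py
import torch
from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel

# Tiny illustrative config: hidden_size must be divisible by num_heads,
# and sum(axes_dim) must equal hidden_size // num_heads.
model = PhotonTransformer2DModel(
    in_channels=16, patch_size=2, context_in_dim=32,
    hidden_size=128, num_heads=4, depth=2, axes_dim=[16, 16],
)
latents = torch.randn(1, 16, 32, 32)       # (B, C, H, W) latent image
text = torch.randn(1, 8, 32)               # (B, L_txt, context_in_dim) text conditioning
mask = torch.ones(1, 8, dtype=torch.bool)  # (B, L_txt) text attention mask
timestep = torch.tensor([0.5])
out = model(latents, timestep, text, attention_mask=mask).sample
print(out.shape)  # torch.Size([1, 16, 32, 32])
```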

View File

@@ -144,6 +144,7 @@ else:
"FluxKontextPipeline",
"FluxKontextInpaintPipeline",
]
_import_structure["photon"] = ["PhotonPipeline"]
_import_structure["audioldm"] = ["AudioLDMPipeline"]
_import_structure["audioldm2"] = [
"AudioLDM2Pipeline",
@@ -717,6 +718,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionXLPAGPipeline,
)
from .paint_by_example import PaintByExamplePipeline
from .photon import PhotonPipeline
from .pia import PIAPipeline
from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
from .qwenimage import (

View File

@@ -0,0 +1,63 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_additional_imports = {}
_import_structure = {"pipeline_output": ["PhotonPipelineOutput"]}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_photon"] = ["PhotonPipeline"]
# Import T5GemmaEncoder for pipeline loading compatibility
try:
if is_transformers_available():
import transformers
from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder
_additional_imports["T5GemmaEncoder"] = T5GemmaEncoder
# Patch transformers module directly for serialization
if not hasattr(transformers, "T5GemmaEncoder"):
transformers.T5GemmaEncoder = T5GemmaEncoder
except ImportError:
pass
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .pipeline_output import PhotonPipelineOutput
from .pipeline_photon import PhotonPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
for name, value in _additional_imports.items():
setattr(sys.modules[__name__], name, value)

View File

@@ -0,0 +1,35 @@
# Copyright 2025 The Photoroom and the HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import List, Union
import numpy as np
import PIL.Image
from ...utils import BaseOutput
@dataclass
class PhotonPipelineOutput(BaseOutput):
"""
Output class for Photon pipelines.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`):
List of denoised PIL images of length `batch_size` or a NumPy array of shape `(batch_size, height, width,
num_channels)` representing the denoised images of the diffusion pipeline.
"""
images: Union[List[PIL.Image.Image], np.ndarray]

View File

@@ -0,0 +1,768 @@
# Copyright 2025 The Photoroom and The HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import html
import inspect
import re
import urllib.parse as ul
from typing import Callable, Dict, List, Optional, Union
import ftfy
import torch
from transformers import (
AutoTokenizer,
GemmaTokenizerFast,
T5TokenizerFast,
)
from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder
from diffusers.image_processor import PixArtImageProcessor
from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderDC, AutoencoderKL
from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel
from diffusers.pipelines.photon.pipeline_output import PhotonPipelineOutput
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import (
logging,
replace_example_docstring,
)
from diffusers.utils.torch_utils import randn_tensor
DEFAULT_RESOLUTION = 512
ASPECT_RATIO_256_BIN = {
"0.46": [160, 352],
"0.6": [192, 320],
"0.78": [224, 288],
"1.0": [256, 256],
"1.29": [288, 224],
"1.67": [320, 192],
"2.2": [352, 160],
}
ASPECT_RATIO_512_BIN = {
"0.5": [352, 704],
"0.57": [384, 672],
"0.6": [384, 640],
"0.68": [416, 608],
"0.78": [448, 576],
"0.88": [480, 544],
"1.0": [512, 512],
"1.13": [544, 480],
"1.29": [576, 448],
"1.46": [608, 416],
"1.67": [640, 384],
"1.75": [672, 384],
"2.0": [704, 352],
}
logger = logging.get_logger(__name__)
class TextPreprocessor:
"""Text preprocessing utility for PhotonPipeline."""
def __init__(self):
"""Initialize text preprocessor."""
self.bad_punct_regex = re.compile(
r"["
+ "#®•©™&@·º½¾¿¡§~"
+ r"\)"
+ r"\("
+ r"\]"
+ r"\["
+ r"\}"
+ r"\{"
+ r"\|"
+ r"\\"
+ r"\/"
+ r"\*"
+ r"]{1,}"
)
def clean_text(self, text: str) -> str:
"""Clean text using comprehensive text processing logic."""
# See Deepfloyd https://github.com/deep-floyd/IF/blob/develop/deepfloyd_if/modules/t5.py
text = str(text)
text = ul.unquote_plus(text)
text = text.strip().lower()
text = re.sub("<person>", "person", text)
# Remove all urls:
text = re.sub(
r"\b((?:https?|www):(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@))",
"",
text,
) # regex for urls
# @<nickname>
text = re.sub(r"@[\w\d]+\b", "", text)
# 31C0—31EF CJK Strokes through 4E00—9FFF CJK Unified Ideographs
text = re.sub(r"[\u31c0-\u31ef]+", "", text)
text = re.sub(r"[\u31f0-\u31ff]+", "", text)
text = re.sub(r"[\u3200-\u32ff]+", "", text)
text = re.sub(r"[\u3300-\u33ff]+", "", text)
text = re.sub(r"[\u3400-\u4dbf]+", "", text)
text = re.sub(r"[\u4dc0-\u4dff]+", "", text)
text = re.sub(r"[\u4e00-\u9fff]+", "", text)
# all types of dash --> "-"
text = re.sub(
r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",
"-",
text,
)
# normalize all quotation marks to a standard form
text = re.sub(r"[`´«»“”¨]", '"', text)
text = re.sub(r"[‘’]", "'", text)
# &quot; and &amp
text = re.sub(r"&quot;?", "", text)
text = re.sub(r"&amp", "", text)
# ip addresses:
text = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", text)
# article ids:
text = re.sub(r"\d:\d\d\s+$", "", text)
# \n
text = re.sub(r"\\n", " ", text)
# "#123", "#12345..", "123456.."
text = re.sub(r"#\d{1,3}\b", "", text)
text = re.sub(r"#\d{5,}\b", "", text)
text = re.sub(r"\b\d{6,}\b", "", text)
# filenames:
text = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", text)
# Clean punctuation
text = re.sub(r"[\"\']{2,}", r'"', text) # """AUSVERKAUFT"""
text = re.sub(r"[\.]{2,}", r" ", text)
text = re.sub(self.bad_punct_regex, r" ", text) # ***AUSVERKAUFT***, #AUSVERKAUFT
text = re.sub(r"\s+\.\s+", r" ", text) # " . "
# this-is-my-cute-cat / this_is_my_cute_cat
regex2 = re.compile(r"(?:\-|\_)")
if len(re.findall(regex2, text)) > 3:
text = re.sub(regex2, " ", text)
# Basic cleaning
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
text = text.strip()
# Clean alphanumeric patterns
text = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", text) # jc6640
text = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", text) # jc6640vc
text = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", text) # 6640vc231
# Common spam patterns
text = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", text)
text = re.sub(r"(free\s)?download(\sfree)?", "", text)
text = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", text)
text = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", text)
text = re.sub(r"\bpage\s+\d+\b", "", text)
text = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", text) # j2d1a2a...
text = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", text)
# Final cleanup
text = re.sub(r"\b\s+\:\s+", r": ", text)
text = re.sub(r"(\D[,\./])\b", r"\1 ", text)
text = re.sub(r"\s+", " ", text)
text = text.strip()
text = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", text)
text = re.sub(r"^[\'\_,\-\:;]", r"", text)
text = re.sub(r"[\'\_,\-\:\-\+]$", r"", text)
text = re.sub(r"^\.\S+$", "", text)
return text.strip()
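A quick illustration of what `clean_text` does to a noisy caption; the input string is made up for demonstration, and the import path assumes the module location used by the tests added later in this diff (`diffusers.pipelines.photon.pipeline_photon`).

```py
from diffusers.pipelines.photon.pipeline_photon import TextPreprocessor

preprocessor = TextPreprocessor()
# URLs, "#<id>" tags and spam phrases such as "free shipping" are stripped,
# the text is lower-cased, and whitespace is collapsed.
print(preprocessor.clean_text("Check https://example.com for FREE shipping!! #123456"))
```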
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import PhotonPipeline
>>> # Load pipeline with from_pretrained
>>> pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i-sft")
>>> pipe.to("cuda")
>>> prompt = "A digital painting of a rusty, vintage tram on a sandy beach"
>>> image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0]
>>> image.save("photon_output.png")
```
"""
class PhotonPipeline(
DiffusionPipeline,
LoraLoaderMixin,
FromSingleFileMixin,
TextualInversionLoaderMixin,
):
r"""
Pipeline for text-to-image generation using Photon Transformer.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
transformer ([`PhotonTransformer2DModel`]):
The Photon transformer model to denoise the encoded image latents.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
text_encoder ([`T5GemmaEncoder`]):
Text encoder model for encoding prompts.
tokenizer ([`T5TokenizerFast` or `GemmaTokenizerFast`]):
Tokenizer for the text encoder.
vae ([`AutoencoderKL`] or [`AutoencoderDC`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
Supports both AutoencoderKL (8x compression) and AutoencoderDC (32x compression).
"""
model_cpu_offload_seq = "text_encoder->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds"]
_optional_components = ["vae"]
def __init__(
self,
transformer: PhotonTransformer2DModel,
scheduler: FlowMatchEulerDiscreteScheduler,
text_encoder: T5GemmaEncoder,
tokenizer: Union[T5TokenizerFast, GemmaTokenizerFast, AutoTokenizer],
vae: Optional[Union[AutoencoderKL, AutoencoderDC]] = None,
default_sample_size: Optional[int] = DEFAULT_RESOLUTION,
):
super().__init__()
if PhotonTransformer2DModel is None:
raise ImportError(
"PhotonTransformer2DModel is not available. Please ensure the transformer_photon module is properly installed."
)
self.text_preprocessor = TextPreprocessor()
self.default_sample_size = default_sample_size
self._guidance_scale = 1.0
self.register_modules(
transformer=transformer,
scheduler=scheduler,
text_encoder=text_encoder,
tokenizer=tokenizer,
vae=vae,
)
self.register_to_config(default_sample_size=self.default_sample_size)
if vae is not None:
self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
else:
self.image_processor = None
@property
def vae_scale_factor(self):
if self.vae is None:
return 8
if hasattr(self.vae, "spatial_compression_ratio"):
return self.vae.spatial_compression_ratio
else: # Flux VAE
return 2 ** (len(self.vae.config.block_out_channels) - 1)
@property
def do_classifier_free_guidance(self):
"""Check if classifier-free guidance is enabled based on guidance scale."""
return self._guidance_scale > 1.0
@property
def guidance_scale(self):
return self._guidance_scale
def get_default_resolution(self):
"""Determine the default resolution to use for generation.
Returns:
`int`: The default sample size (height/width), taken from the config's `default_sample_size` when set, otherwise `DEFAULT_RESOLUTION`.
"""
default_from_config = getattr(self.config, "default_sample_size", None)
if default_from_config is not None:
return default_from_config
return DEFAULT_RESOLUTION
def prepare_latents(
self,
batch_size: int,
num_channels_latents: int,
height: int,
width: int,
dtype: torch.dtype,
device: torch.device,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.Tensor] = None,
):
"""Prepare initial latents for the diffusion process."""
if latents is None:
spatial_compression = self.vae_scale_factor
latent_height, latent_width = (
height // spatial_compression,
width // spatial_compression,
)
shape = (batch_size, num_channels_latents, latent_height, latent_width)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device)
return latents
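A worked example of the shape arithmetic above, assuming a 512x512 request, a batch of 2, 16 latent channels, and an 8x VAE; the numbers are illustrative only.

```py
batch_size, num_channels_latents = 2, 16
height, width, vae_scale_factor = 512, 512, 8

# Same computation as prepare_latents: spatial dims shrink by the VAE compression factor.
shape = (batch_size, num_channels_latents, height // vae_scale_factor, width // vae_scale_factor)
print(shape)  # (2, 16, 64, 64), the shape handed to randn_tensor
```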
def encode_prompt(
self,
prompt: Union[str, List[str]],
device: Optional[torch.device] = None,
do_classifier_free_guidance: bool = True,
negative_prompt: str = "",
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
prompt_attention_mask: Optional[torch.BoolTensor] = None,
negative_prompt_attention_mask: Optional[torch.BoolTensor] = None,
):
"""Encode text prompt using standard text encoder and tokenizer, or use precomputed embeddings."""
if device is None:
device = self._execution_device
if prompt_embeds is None:
if isinstance(prompt, str):
prompt = [prompt]
# Encode the prompts
prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = (
self._encode_prompt_standard(prompt, device, do_classifier_free_guidance, negative_prompt)
)
# Duplicate embeddings for each generation per prompt
if num_images_per_prompt > 1:
# Repeat prompt embeddings
bs_embed, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
if prompt_attention_mask is not None:
prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
# Repeat negative embeddings if using CFG
if do_classifier_free_guidance and negative_prompt_embeds is not None:
bs_embed, seq_len, _ = negative_prompt_embeds.shape
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
if negative_prompt_attention_mask is not None:
negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
return (
prompt_embeds,
prompt_attention_mask,
negative_prompt_embeds if do_classifier_free_guidance else None,
negative_prompt_attention_mask if do_classifier_free_guidance else None,
)
def _tokenize_prompts(self, prompts: List[str], device: torch.device):
"""Tokenize and clean prompts."""
cleaned = [self.text_preprocessor.clean_text(text) for text in prompts]
tokens = self.tokenizer(
cleaned,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_attention_mask=True,
return_tensors="pt",
)
return tokens["input_ids"].to(device), tokens["attention_mask"].bool().to(device)
def _encode_prompt_standard(
self,
prompt: List[str],
device: torch.device,
do_classifier_free_guidance: bool = True,
negative_prompt: str = "",
):
"""Encode prompt using standard text encoder and tokenizer with batch processing."""
batch_size = len(prompt)
if do_classifier_free_guidance:
if isinstance(negative_prompt, str):
negative_prompt = [negative_prompt] * batch_size
prompts_to_encode = negative_prompt + prompt
else:
prompts_to_encode = prompt
input_ids, attention_mask = self._tokenize_prompts(prompts_to_encode, device)
with torch.no_grad():
embeddings = self.text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
)["last_hidden_state"]
if do_classifier_free_guidance:
uncond_text_embeddings, text_embeddings = embeddings.split(batch_size, dim=0)
uncond_cross_attn_mask, cross_attn_mask = attention_mask.split(batch_size, dim=0)
else:
text_embeddings = embeddings
cross_attn_mask = attention_mask
uncond_text_embeddings = None
uncond_cross_attn_mask = None
return text_embeddings, cross_attn_mask, uncond_text_embeddings, uncond_cross_attn_mask
def check_inputs(
self,
prompt: Union[str, List[str]],
height: int,
width: int,
guidance_scale: float,
callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
):
"""Check that all inputs are in correct format."""
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
if prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if prompt_embeds is not None and guidance_scale > 1.0 and negative_prompt_embeds is None:
raise ValueError(
"When `prompt_embeds` is provided and `guidance_scale > 1.0`, "
"`negative_prompt_embeds` must also be provided for classifier-free guidance."
)
spatial_compression = self.vae_scale_factor
if height % spatial_compression != 0 or width % spatial_compression != 0:
raise ValueError(
f"`height` and `width` have to be divisible by {spatial_compression} but are {height} and {width}."
)
if guidance_scale < 1.0:
raise ValueError(f"guidance_scale has to be >= 1.0 but is {guidance_scale}")
if callback_on_step_end_tensor_inputs is not None and not isinstance(callback_on_step_end_tensor_inputs, list):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be a list but is {callback_on_step_end_tensor_inputs}"
)
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
negative_prompt: str = "",
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 28,
timesteps: List[int] = None,
guidance_scale: float = 4.0,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
prompt_attention_mask: Optional[torch.BoolTensor] = None,
negative_prompt_attention_mask: Optional[torch.BoolTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
use_resolution_binning: bool = True,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
):
"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
instead.
negative_prompt (`str`, *optional*, defaults to `""`):
The prompt not to guide the image generation. Ignored when classifier-free guidance is not used (i.e.,
when `guidance_scale` is less than or equal to `1`).
height (`int`, *optional*, defaults to the pipeline's default resolution):
The height in pixels of the generated image. If not provided, the value from `get_default_resolution()`
(the config's `default_sample_size`, falling back to 512) is used.
width (`int`, *optional*, defaults to the pipeline's default resolution):
The width in pixels of the generated image. If not provided, the same default as `height` is used.
num_inference_steps (`int`, *optional*, defaults to 28):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 4.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided and `guidance_scale > 1`, negative embeddings will be generated from an
empty string.
prompt_attention_mask (`torch.BoolTensor`, *optional*):
Pre-generated attention mask for `prompt_embeds`. If not provided, attention mask will be generated
from `prompt` input argument.
negative_prompt_attention_mask (`torch.BoolTensor`, *optional*):
Pre-generated attention mask for `negative_prompt_embeds`. If not provided and `guidance_scale > 1`,
attention mask will be generated from an empty string.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.photon.PhotonPipelineOutput`] instead of a plain tuple.
use_resolution_binning (`bool`, *optional*, defaults to `True`):
If set to `True`, the requested height and width are first mapped to the closest resolutions using
predefined aspect ratio bins. After the produced latents are decoded into images, they are resized back
to the requested resolution. Useful for generating non-square images at optimal resolutions.
callback_on_step_end (`Callable`, *optional*):
A function called at the end of each denoising step during inference. The function is called
with the following arguments: `callback_on_step_end(self, step, timestep, callback_kwargs)`.
`callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include tensors that are listed
in the `._callback_tensor_inputs` attribute.
Examples:
Returns:
[`~pipelines.photon.PhotonPipelineOutput`] or `tuple`: [`~pipelines.photon.PhotonPipelineOutput`] if
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
generated images.
"""
# 0. Set height and width
default_resolution = self.get_default_resolution()
height = height or default_resolution
width = width or default_resolution
if use_resolution_binning:
if self.image_processor is None:
raise ValueError(
"Resolution binning requires a VAE with image_processor, but VAE is not available. "
"Set use_resolution_binning=False or provide a VAE."
)
if self.default_sample_size <= 256:
aspect_ratio_bin = ASPECT_RATIO_256_BIN
else:
aspect_ratio_bin = ASPECT_RATIO_512_BIN
# Store original dimensions
orig_height, orig_width = height, width
# Map to closest resolution in the bin
height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
# 1. Check inputs
self.check_inputs(
prompt,
height,
width,
guidance_scale,
callback_on_step_end_tensor_inputs,
prompt_embeds,
negative_prompt_embeds,
)
if self.vae is None and output_type not in ["latent", "pt"]:
raise ValueError(
f"VAE is required for output_type='{output_type}' but it is not available. "
"Either provide a VAE or set output_type='latent' or 'pt' to get latent outputs."
)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
# Use execution device (handles offloading scenarios including group offloading)
device = self._execution_device
self._guidance_scale = guidance_scale
# 2. Encode input prompt
text_embeddings, cross_attn_mask, uncond_text_embeddings, uncond_cross_attn_mask = self.encode_prompt(
prompt,
device,
do_classifier_free_guidance=self.do_classifier_free_guidance,
negative_prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
)
# Expose standard names for callbacks parity
prompt_embeds = text_embeddings
negative_prompt_embeds = uncond_text_embeddings
# 3. Prepare timesteps
if timesteps is not None:
self.scheduler.set_timesteps(timesteps=timesteps, device=device)
timesteps = self.scheduler.timesteps
num_inference_steps = len(timesteps)
else:
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
self.num_timesteps = len(timesteps)
# 4. Prepare latent variables
if self.vae is not None:
num_channels_latents = self.vae.config.latent_channels
else:
# When vae is None, get latent channels from transformer
num_channels_latents = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
text_embeddings.dtype,
device,
generator,
latents,
)
# 5. Prepare extra step kwargs
extra_step_kwargs = {}
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_eta:
extra_step_kwargs["eta"] = 0.0
# 6. Prepare cross-attention embeddings and masks
if self.do_classifier_free_guidance:
ca_embed = torch.cat([uncond_text_embeddings, text_embeddings], dim=0)
ca_mask = None
if cross_attn_mask is not None and uncond_cross_attn_mask is not None:
ca_mask = torch.cat([uncond_cross_attn_mask, cross_attn_mask], dim=0)
else:
ca_embed = text_embeddings
ca_mask = cross_attn_mask
# 7. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# Duplicate latents if using classifier-free guidance
if self.do_classifier_free_guidance:
latents_in = torch.cat([latents, latents], dim=0)
# Normalize timestep for the transformer
t_cont = (t.float() / self.scheduler.config.num_train_timesteps).view(1).repeat(2).to(device)
else:
latents_in = latents
# Normalize timestep for the transformer
t_cont = (t.float() / self.scheduler.config.num_train_timesteps).view(1).to(device)
# Forward through transformer
noise_pred = self.transformer(
hidden_states=latents_in,
timestep=t_cont,
encoder_hidden_states=ca_embed,
attention_mask=ca_mask,
return_dict=False,
)[0]
# Apply CFG
if self.do_classifier_free_guidance:
noise_uncond, noise_text = noise_pred.chunk(2, dim=0)
noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)
# Compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_on_step_end(self, i, t, callback_kwargs)
# Update the progress bar
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
# 8. Post-processing
if output_type == "latent" or (output_type == "pt" and self.vae is None):
image = latents
else:
# Unscale latents for VAE (supports both AutoencoderKL and AutoencoderDC)
scaling_factor = getattr(self.vae.config, "scaling_factor", 0.18215)
shift_factor = getattr(self.vae.config, "shift_factor", 0.0)
latents = (latents / scaling_factor) + shift_factor
# Decode using VAE (AutoencoderKL or AutoencoderDC)
image = self.vae.decode(latents, return_dict=False)[0]
# Resize back to original resolution if using binning
if use_resolution_binning:
image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
# Use standard image processor for post-processing
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return PhotonPipelineOutput(images=image)
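Beyond the basic example in `EXAMPLE_DOC_STRING`, a sketch of two usage patterns the arguments above support: reusing precomputed prompt embeddings and inspecting intermediate latents with a step callback. The checkpoint and prompt come from the example docstring; the negative prompt and the callback body are illustrative assumptions.

```py
from diffusers import PhotonPipeline

pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i-sft").to("cuda")
prompt = "A digital painting of a rusty, vintage tram on a sandy beach"

# 1) Precompute embeddings once and reuse them across calls.
embeds, mask, neg_embeds, neg_mask = pipe.encode_prompt(
    prompt,
    do_classifier_free_guidance=True,
    negative_prompt="blurry, low quality",
)
image = pipe(
    prompt_embeds=embeds,
    prompt_attention_mask=mask,
    negative_prompt_embeds=neg_embeds,
    negative_prompt_attention_mask=neg_mask,
    guidance_scale=5.0,
).images[0]

# 2) Inspect intermediate latents; only names listed in `_callback_tensor_inputs` may be requested.
def log_latent_std(pipeline, step, timestep, callback_kwargs):
    print(f"step {step}: latent std {callback_kwargs['latents'].std().item():.3f}")
    return callback_kwargs

image = pipe(
    prompt,
    callback_on_step_end=log_latent_std,
    callback_on_step_end_tensor_inputs=["latents"],
).images[0]
```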

View File

@@ -1098,6 +1098,21 @@ class ParallelConfig(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class PhotonTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class PixArtTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]

View File

@@ -1847,6 +1847,21 @@ class PaintByExamplePipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class PhotonPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class PIAPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

@@ -0,0 +1,83 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel
from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
class PhotonTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = PhotonTransformer2DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
@property
def dummy_input(self):
return self.prepare_dummy_input()
@property
def input_shape(self):
return (16, 16, 16)
@property
def output_shape(self):
return (16, 16, 16)
def prepare_dummy_input(self, height=16, width=16):
batch_size = 1
num_latent_channels = 16
sequence_length = 16
embedding_dim = 1792
hidden_states = torch.randn((batch_size, num_latent_channels, height, width)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
return {
"hidden_states": hidden_states,
"timestep": timestep,
"encoder_hidden_states": encoder_hidden_states,
}
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"in_channels": 16,
"patch_size": 2,
"context_in_dim": 1792,
"hidden_size": 1792,
"mlp_ratio": 3.5,
"num_heads": 28,
"depth": 4, # Smaller depth for testing
"axes_dim": [32, 32],
"theta": 10_000,
}
inputs_dict = self.prepare_dummy_input()
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {"PhotonTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
if __name__ == "__main__":
unittest.main()
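For a quick standalone forward pass outside the test harness, a sketch that reuses the small configuration from `prepare_init_args_and_inputs_for_common` above; the call signature mirrors how `PhotonPipeline` invokes the transformer.

```py
import torch

from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel

model = PhotonTransformer2DModel(
    in_channels=16, patch_size=2, context_in_dim=1792, hidden_size=1792,
    mlp_ratio=3.5, num_heads=28, depth=4, axes_dim=[32, 32], theta=10_000,
).eval()

with torch.no_grad():
    sample = model(
        hidden_states=torch.randn(1, 16, 16, 16),
        timestep=torch.tensor([1.0]),
        encoder_hidden_states=torch.randn(1, 16, 1792),
        return_dict=False,
    )[0]
print(sample.shape)  # per-sample shape should match the (16, 16, 16) output_shape used above
```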

View File

View File

@@ -0,0 +1,265 @@
import unittest
import numpy as np
import pytest
import torch
from transformers import AutoTokenizer
from transformers.models.t5gemma.configuration_t5gemma import T5GemmaConfig, T5GemmaModuleConfig
from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder
from diffusers.models import AutoencoderDC, AutoencoderKL
from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel
from diffusers.pipelines.photon.pipeline_photon import PhotonPipeline
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import is_transformers_version
from ..pipeline_params import TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
@pytest.mark.xfail(
condition=is_transformers_version(">", "4.57.1"),
reason="See https://github.com/huggingface/diffusers/pull/12456#issuecomment-3424228544",
strict=False,
)
class PhotonPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = PhotonPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = frozenset(["prompt", "negative_prompt", "num_images_per_prompt"])
test_xformers_attention = False
test_layerwise_casting = True
test_group_offloading = True
@classmethod
def setUpClass(cls):
# Ensure PhotonPipeline has an _execution_device property expected by __call__
if not isinstance(getattr(PhotonPipeline, "_execution_device", None), property):
try:
setattr(PhotonPipeline, "_execution_device", property(lambda self: torch.device("cpu")))
except Exception:
pass
def get_dummy_components(self):
torch.manual_seed(0)
transformer = PhotonTransformer2DModel(
patch_size=1,
in_channels=4,
context_in_dim=8,
hidden_size=8,
mlp_ratio=2.0,
num_heads=2,
depth=1,
axes_dim=[2, 2],
)
torch.manual_seed(0)
vae = AutoencoderKL(
sample_size=32,
in_channels=3,
out_channels=3,
block_out_channels=(4,),
layers_per_block=1,
latent_channels=4,
norm_num_groups=1,
use_quant_conv=False,
use_post_quant_conv=False,
shift_factor=0.0,
scaling_factor=1.0,
).eval()
torch.manual_seed(0)
scheduler = FlowMatchEulerDiscreteScheduler()
torch.manual_seed(0)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/dummy-gemma")
tokenizer.model_max_length = 64
torch.manual_seed(0)
encoder_params = {
"vocab_size": tokenizer.vocab_size,
"hidden_size": 8,
"intermediate_size": 16,
"num_hidden_layers": 1,
"num_attention_heads": 2,
"num_key_value_heads": 1,
"head_dim": 4,
"max_position_embeddings": 64,
"layer_types": ["full_attention"],
"attention_bias": False,
"attention_dropout": 0.0,
"dropout_rate": 0.0,
"hidden_activation": "gelu_pytorch_tanh",
"rms_norm_eps": 1e-06,
"attn_logit_softcapping": 50.0,
"final_logit_softcapping": 30.0,
"query_pre_attn_scalar": 4,
"rope_theta": 10000.0,
"sliding_window": 4096,
}
encoder_config = T5GemmaModuleConfig(**encoder_params)
text_encoder_config = T5GemmaConfig(encoder=encoder_config, is_encoder_decoder=False, **encoder_params)
text_encoder = T5GemmaEncoder(text_encoder_config)
return {
"transformer": transformer,
"vae": vae,
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
}
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
return {
"prompt": "",
"negative_prompt": "",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 1.0,
"height": 32,
"width": 32,
"output_type": "pt",
"use_resolution_binning": False,
}
def test_inference(self):
device = "cpu"
components = self.get_dummy_components()
pipe = PhotonPipeline(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
try:
pipe.register_to_config(_execution_device="cpu")
except Exception:
pass
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs)[0]
generated_image = image[0]
self.assertEqual(generated_image.shape, (3, 32, 32))
expected_image = torch.zeros(3, 32, 32)
max_diff = np.abs(generated_image - expected_image).max()
self.assertLessEqual(max_diff, 1e10)
def test_callback_inputs(self):
components = self.get_dummy_components()
pipe = PhotonPipeline(**components)
pipe = pipe.to("cpu")
pipe.set_progress_bar_config(disable=None)
try:
pipe.register_to_config(_execution_device="cpu")
except Exception:
pass
self.assertTrue(
hasattr(pipe, "_callback_tensor_inputs"),
f" {PhotonPipeline} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
)
def callback_inputs_subset(pipe, i, t, callback_kwargs):
for tensor_name in callback_kwargs.keys():
assert tensor_name in pipe._callback_tensor_inputs
return callback_kwargs
def callback_inputs_all(pipe, i, t, callback_kwargs):
for tensor_name in pipe._callback_tensor_inputs:
assert tensor_name in callback_kwargs
for tensor_name in callback_kwargs.keys():
assert tensor_name in pipe._callback_tensor_inputs
return callback_kwargs
inputs = self.get_dummy_inputs("cpu")
inputs["callback_on_step_end"] = callback_inputs_subset
inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
_ = pipe(**inputs)[0]
inputs["callback_on_step_end"] = callback_inputs_all
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
_ = pipe(**inputs)[0]
def test_attention_slicing_forward_pass(self, expected_max_diff=1e-3):
if not self.test_attention_slicing:
return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to("cpu")
pipe.set_progress_bar_config(disable=None)
def to_np_local(tensor):
if isinstance(tensor, torch.Tensor):
return tensor.detach().cpu().numpy()
return tensor
generator_device = "cpu"
inputs = self.get_dummy_inputs(generator_device)
output_without_slicing = pipe(**inputs)[0]
pipe.enable_attention_slicing(slice_size=1)
inputs = self.get_dummy_inputs(generator_device)
output_with_slicing1 = pipe(**inputs)[0]
pipe.enable_attention_slicing(slice_size=2)
inputs = self.get_dummy_inputs(generator_device)
output_with_slicing2 = pipe(**inputs)[0]
max_diff1 = np.abs(to_np_local(output_with_slicing1) - to_np_local(output_without_slicing)).max()
max_diff2 = np.abs(to_np_local(output_with_slicing2) - to_np_local(output_without_slicing)).max()
self.assertLess(max(max_diff1, max_diff2), expected_max_diff)
def test_inference_with_autoencoder_dc(self):
"""Test PhotonPipeline with AutoencoderDC (DCAE) instead of AutoencoderKL."""
device = "cpu"
components = self.get_dummy_components()
torch.manual_seed(0)
vae_dc = AutoencoderDC(
in_channels=3,
latent_channels=4,
attention_head_dim=2,
encoder_block_types=(
"ResBlock",
"EfficientViTBlock",
),
decoder_block_types=(
"ResBlock",
"EfficientViTBlock",
),
encoder_block_out_channels=(8, 8),
decoder_block_out_channels=(8, 8),
encoder_qkv_multiscales=((), (5,)),
decoder_qkv_multiscales=((), (5,)),
encoder_layers_per_block=(1, 1),
decoder_layers_per_block=(1, 1),
upsample_block_type="interpolate",
downsample_block_type="stride_conv",
decoder_norm_types="rms_norm",
decoder_act_fns="silu",
).eval()
components["vae"] = vae_dc
pipe = PhotonPipeline(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
expected_scale_factor = vae_dc.spatial_compression_ratio
self.assertEqual(pipe.vae_scale_factor, expected_scale_factor)
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs)[0]
generated_image = image[0]
self.assertEqual(generated_image.shape, (3, 32, 32))
expected_image = torch.zeros(3, 32, 32)
max_diff = np.abs(generated_image - expected_image).max()
self.assertLessEqual(max_diff, 1e10)
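The `vae_scale_factor` assertion above exercises the resolution logic defined in the pipeline; restated here as a standalone helper for reference, mirroring the property's branches exactly.

```py
def infer_vae_scale_factor(vae) -> int:
    """Mirror of PhotonPipeline.vae_scale_factor, for illustration only."""
    if vae is None:
        return 8  # pipeline default when no VAE is attached
    if hasattr(vae, "spatial_compression_ratio"):
        return vae.spatial_compression_ratio  # AutoencoderDC exposes the ratio directly
    # KL-style VAEs: one 2x downsample per block transition
    return 2 ** (len(vae.config.block_out_channels) - 1)
```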