mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-06 12:34:13 +08:00
Prx (#12525)
* rename photon to prx
* rename photon into prx
* Revert .gitignore to state before commit b7fb0fe9d6
* rename photon to prx
* rename photon into prx
* Revert .gitignore to state before commit b7fb0fe9d6
* make fix-copies
```diff
@@ -541,12 +541,12 @@
       title: PAG
     - local: api/pipelines/paint_by_example
       title: Paint by Example
-    - local: api/pipelines/photon
-      title: Photon
     - local: api/pipelines/pixart
       title: PixArt-α
     - local: api/pipelines/pixart_sigma
       title: PixArt-Σ
+    - local: api/pipelines/prx
+      title: PRX
     - local: api/pipelines/qwenimage
       title: QwenImage
     - local: api/pipelines/sana
```
docs/source/en/api/pipelines/photon.md
@@ -1,131 +0,0 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->

# Photon

Photon generates high-quality images from text using a simplified MMDiT architecture in which text tokens are not updated through the transformer blocks. It employs flow matching with discrete scheduling for efficient sampling and uses Google's T5Gemma-2B-2B-UL2 model for multilingual text encoding. The ~1.3B-parameter transformer delivers fast inference without sacrificing quality. You can choose between the Flux VAE (8x compression, 16 latent channels) for balanced quality and speed, or DC-AE (32x compression, 32 latent channels) for higher latent compression and faster processing.
## Available models

Photon offers multiple variants with different VAE configurations, each optimized for a specific resolution. The base models excel with detailed prompts, capturing complex compositions and subtle details. The fine-tuned models, trained on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist), improve aesthetic quality, especially with simpler prompts.
| Model | Resolution | Fine-tuned | Distilled | Description | Suggested prompts | Suggested parameters | Recommended dtype |
|:-----:|:----------:|:----------:|:---------:|:-----------:|:-----------------:|:--------------------:|:-----------------:|
| [`Photoroom/photon-256-t2i`](https://huggingface.co/Photoroom/photon-256-t2i) | 256 | No | No | Base model pre-trained at 256 with the Flux VAE | Works best with detailed prompts in natural language | 28 steps, cfg=5.0 | `torch.bfloat16` |
| [`Photoroom/photon-256-t2i-sft`](https://huggingface.co/Photoroom/photon-256-t2i-sft) | 256 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) with the Flux VAE | Can handle less detailed prompts | 28 steps, cfg=5.0 | `torch.bfloat16` |
| [`Photoroom/photon-512-t2i`](https://huggingface.co/Photoroom/photon-512-t2i) | 512 | No | No | Base model pre-trained at 512 with the Flux VAE | Works best with detailed prompts in natural language | 28 steps, cfg=5.0 | `torch.bfloat16` |
| [`Photoroom/photon-512-t2i-sft`](https://huggingface.co/Photoroom/photon-512-t2i-sft) | 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) with the Flux VAE | Can handle less detailed prompts in natural language | 28 steps, cfg=5.0 | `torch.bfloat16` |
| [`Photoroom/photon-512-t2i-sft-distilled`](https://huggingface.co/Photoroom/photon-512-t2i-sft-distilled) | 512 | Yes | Yes | 8-step model distilled from [`Photoroom/photon-512-t2i-sft`](https://huggingface.co/Photoroom/photon-512-t2i-sft) | Can handle less detailed prompts in natural language | 8 steps, cfg=1.0 | `torch.bfloat16` |
| [`Photoroom/photon-512-t2i-dc-ae`](https://huggingface.co/Photoroom/photon-512-t2i-dc-ae) | 512 | No | No | Base model pre-trained at 512 with the [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae) | Works best with detailed prompts in natural language | 28 steps, cfg=5.0 | `torch.bfloat16` |
| [`Photoroom/photon-512-t2i-dc-ae-sft`](https://huggingface.co/Photoroom/photon-512-t2i-dc-ae-sft) | 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) with the [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae) | Can handle less detailed prompts in natural language | 28 steps, cfg=5.0 | `torch.bfloat16` |
| [`Photoroom/photon-512-t2i-dc-ae-sft-distilled`](https://huggingface.co/Photoroom/photon-512-t2i-dc-ae-sft-distilled) | 512 | Yes | Yes | 8-step model distilled from [`Photoroom/photon-512-t2i-dc-ae-sft`](https://huggingface.co/Photoroom/photon-512-t2i-dc-ae-sft) | Can handle less detailed prompts in natural language | 8 steps, cfg=1.0 | `torch.bfloat16` |
Refer to [this collection](https://huggingface.co/collections/Photoroom/photon-models-68e66254c202ebfab99ad38e) for more information.

## Loading the pipeline

Load the pipeline with [`~DiffusionPipeline.from_pretrained`].
```py
import torch

from diffusers.pipelines.photon import PhotonPipeline

# Load the pipeline - the VAE and text encoder are downloaded from the Hub
pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i-sft", torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "A front-facing portrait of a lion in the golden savanna at sunset."
image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0]
image.save("photon_output.png")
```
### Manual Component Loading

Load components individually to customize the pipeline, for instance to use quantized models.
```py
import torch

from diffusers.pipelines.photon import PhotonPipeline
from diffusers.models import AutoencoderKL, AutoencoderDC
from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from transformers import T5GemmaModel, GemmaTokenizerFast
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig

# Use the diffusers config for diffusers models and the transformers config for the text encoder
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
text_encoder_quant_config = TransformersBitsAndBytesConfig(load_in_8bit=True)

# Load transformer
transformer = PhotonTransformer2DModel.from_pretrained(
    "Photoroom/photon-512-t2i-sft",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)

# Load scheduler
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
    "Photoroom/photon-512-t2i-sft", subfolder="scheduler"
)

# Load T5Gemma text encoder
t5gemma_model = T5GemmaModel.from_pretrained(
    "google/t5gemma-2b-2b-ul2",
    quantization_config=text_encoder_quant_config,
    torch_dtype=torch.bfloat16,
)
text_encoder = t5gemma_model.encoder
tokenizer = GemmaTokenizerFast.from_pretrained("google/t5gemma-2b-2b-ul2")
tokenizer.model_max_length = 256

# Load VAE - choose either the Flux VAE or DC-AE
# Flux VAE
vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="vae",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)

pipe = PhotonPipeline(
    transformer=transformer,
    scheduler=scheduler,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    vae=vae,
)
pipe.to("cuda")
```
## Memory Optimization

For memory-constrained environments:
```py
import torch
from diffusers.pipelines.photon import PhotonPipeline

pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i-sft", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()  # Offload components to CPU when not in use

# Or use sequential CPU offload for even lower memory
pipe.enable_sequential_cpu_offload()
```
## PhotonPipeline

[[autodoc]] PhotonPipeline
  - all
  - __call__

## PhotonPipelineOutput

[[autodoc]] pipelines.photon.pipeline_output.PhotonPipelineOutput
131 docs/source/en/api/pipelines/prx.md Normal file
@@ -0,0 +1,131 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->

# PRX

PRX generates high-quality images from text using a simplified MMDiT architecture in which text tokens are not updated through the transformer blocks. It employs flow matching with discrete scheduling for efficient sampling and uses Google's T5Gemma-2B-2B-UL2 model for multilingual text encoding. The ~1.3B-parameter transformer delivers fast inference without sacrificing quality. You can choose between the Flux VAE (8x compression, 16 latent channels) for balanced quality and speed, or DC-AE (32x compression, 32 latent channels) for higher latent compression and faster processing.
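For intuition about the VAE trade-off, the two options imply very different latent grids for the same image. A minimal sketch of the arithmetic (the helper below is illustrative, not part of the diffusers API):

```py
# Latent grid implied by each VAE for a 512x512 image (illustrative helper, not library code)
def latent_shape(image_size: int, compression: int, channels: int) -> tuple:
    return (channels, image_size // compression, image_size // compression)

print(latent_shape(512, 8, 16))   # Flux VAE -> (16, 64, 64): 65,536 latent values
print(latent_shape(512, 32, 32))  # DC-AE    -> (32, 16, 16): 8,192 latent values
```

Fewer latent tokens per image is what makes the DC-AE variants faster to process.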
## Available models

PRX offers multiple variants with different VAE configurations, each optimized for a specific resolution. The base models excel with detailed prompts, capturing complex compositions and subtle details. The fine-tuned models, trained on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist), improve aesthetic quality, especially with simpler prompts.
| Model | Resolution | Fine-tuned | Distilled | Description | Suggested prompts | Suggested parameters | Recommended dtype |
|:-----:|:----------:|:----------:|:---------:|:-----------:|:-----------------:|:--------------------:|:-----------------:|
| [`Photoroom/prx-256-t2i`](https://huggingface.co/Photoroom/prx-256-t2i) | 256 | No | No | Base model pre-trained at 256 with the Flux VAE | Works best with detailed prompts in natural language | 28 steps, cfg=5.0 | `torch.bfloat16` |
| [`Photoroom/prx-256-t2i-sft`](https://huggingface.co/Photoroom/prx-256-t2i-sft) | 256 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) with the Flux VAE | Can handle less detailed prompts | 28 steps, cfg=5.0 | `torch.bfloat16` |
| [`Photoroom/prx-512-t2i`](https://huggingface.co/Photoroom/prx-512-t2i) | 512 | No | No | Base model pre-trained at 512 with the Flux VAE | Works best with detailed prompts in natural language | 28 steps, cfg=5.0 | `torch.bfloat16` |
| [`Photoroom/prx-512-t2i-sft`](https://huggingface.co/Photoroom/prx-512-t2i-sft) | 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) with the Flux VAE | Can handle less detailed prompts in natural language | 28 steps, cfg=5.0 | `torch.bfloat16` |
| [`Photoroom/prx-512-t2i-sft-distilled`](https://huggingface.co/Photoroom/prx-512-t2i-sft-distilled) | 512 | Yes | Yes | 8-step model distilled from [`Photoroom/prx-512-t2i-sft`](https://huggingface.co/Photoroom/prx-512-t2i-sft) | Can handle less detailed prompts in natural language | 8 steps, cfg=1.0 | `torch.bfloat16` |
| [`Photoroom/prx-512-t2i-dc-ae`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae) | 512 | No | No | Base model pre-trained at 512 with the [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae) | Works best with detailed prompts in natural language | 28 steps, cfg=5.0 | `torch.bfloat16` |
| [`Photoroom/prx-512-t2i-dc-ae-sft`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft) | 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) with the [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae) | Can handle less detailed prompts in natural language | 28 steps, cfg=5.0 | `torch.bfloat16` |
| [`Photoroom/prx-512-t2i-dc-ae-sft-distilled`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft-distilled) | 512 | Yes | Yes | 8-step model distilled from [`Photoroom/prx-512-t2i-dc-ae-sft`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft) | Can handle less detailed prompts in natural language | 8 steps, cfg=1.0 | `torch.bfloat16` |
Refer to [this collection](https://huggingface.co/collections/Photoroom/prx-models-68e66254c202ebfab99ad38e) for more information.
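The distilled checkpoints are meant to run with the settings from the table above (8 steps, cfg=1.0, i.e. guidance effectively disabled). A minimal sketch, assuming the same API as the standard pipeline:

```py
import torch
from diffusers.pipelines.prx import PRXPipeline

# Distilled variant: 8 steps, guidance_scale=1.0 per the table above
pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft-distilled", torch_dtype=torch.bfloat16)
pipe.to("cuda")

image = pipe("A lighthouse on a cliff at dawn", num_inference_steps=8, guidance_scale=1.0).images[0]
image.save("prx_distilled_output.png")
```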
## Loading the pipeline

Load the pipeline with [`~DiffusionPipeline.from_pretrained`].
```py
import torch

from diffusers.pipelines.prx import PRXPipeline

# Load the pipeline - the VAE and text encoder are downloaded from the Hub
pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft", torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "A front-facing portrait of a lion in the golden savanna at sunset."
image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0]
image.save("prx_output.png")
```
### Manual Component Loading

Load components individually to customize the pipeline, for instance to use quantized models.
```py
import torch

from diffusers.pipelines.prx import PRXPipeline
from diffusers.models import AutoencoderKL, AutoencoderDC
from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from transformers import T5GemmaModel, GemmaTokenizerFast
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig

# Use the diffusers config for diffusers models and the transformers config for the text encoder
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
text_encoder_quant_config = TransformersBitsAndBytesConfig(load_in_8bit=True)

# Load transformer
transformer = PRXTransformer2DModel.from_pretrained(
    "Photoroom/prx-512-t2i-sft",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)

# Load scheduler
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
    "Photoroom/prx-512-t2i-sft", subfolder="scheduler"
)

# Load T5Gemma text encoder
t5gemma_model = T5GemmaModel.from_pretrained(
    "google/t5gemma-2b-2b-ul2",
    quantization_config=text_encoder_quant_config,
    torch_dtype=torch.bfloat16,
)
text_encoder = t5gemma_model.encoder
tokenizer = GemmaTokenizerFast.from_pretrained("google/t5gemma-2b-2b-ul2")
tokenizer.model_max_length = 256

# Load VAE - choose either the Flux VAE or DC-AE
# Flux VAE
vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="vae",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)

pipe = PRXPipeline(
    transformer=transformer,
    scheduler=scheduler,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    vae=vae,
)
pipe.to("cuda")
```
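The block above wires in the Flux VAE; for the `-dc-ae` checkpoints you would pass an `AutoencoderDC` instead. A hedged sketch, assuming the DC-AE weights ship in the checkpoint's `vae` subfolder:

```py
from diffusers.models import AutoencoderDC

# DC-AE alternative (assumption: the VAE is stored under the "vae" subfolder of the dc-ae repo)
vae = AutoencoderDC.from_pretrained(
    "Photoroom/prx-512-t2i-dc-ae-sft",
    subfolder="vae",
    torch_dtype=torch.bfloat16,
)
```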
## Memory Optimization

For memory-constrained environments:
```py
import torch
from diffusers.pipelines.prx import PRXPipeline

pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()  # Offload components to CPU when not in use

# Or use sequential CPU offload for even lower memory
pipe.enable_sequential_cpu_offload()
```
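On top of offloading, the VAE decode step is often the memory peak. When the checkpoint uses the Flux VAE (an `AutoencoderKL`), its standard slicing and tiling switches should also apply; a small sketch under that assumption:

```py
# Decode in slices/tiles to reduce peak VAE memory (standard AutoencoderKL methods)
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
```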
## PRXPipeline

[[autodoc]] PRXPipeline
  - all
  - __call__

## PRXPipelineOutput

[[autodoc]] pipelines.prx.pipeline_output.PRXPipelineOutput
```diff
@@ -1,6 +1,6 @@
 #!/usr/bin/env python3
 """
-Script to convert Photon checkpoint from original codebase to diffusers format.
+Script to convert PRX checkpoint from original codebase to diffusers format.
 """
 
 import argparse
```
```diff
@@ -13,15 +13,15 @@ from typing import Dict, Tuple
 import torch
 from safetensors.torch import save_file
 
-from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel
-from diffusers.pipelines.photon import PhotonPipeline
+from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
+from diffusers.pipelines.prx import PRXPipeline
 
 
 DEFAULT_RESOLUTION = 512
 
 
 @dataclass(frozen=True)
-class PhotonBase:
+class PRXBase:
     context_in_dim: int = 2304
     hidden_size: int = 1792
     mlp_ratio: float = 3.5
```
```diff
@@ -34,22 +34,22 @@ class PhotonBase:
 
 
 @dataclass(frozen=True)
-class PhotonFlux(PhotonBase):
+class PRXFlux(PRXBase):
     in_channels: int = 16
     patch_size: int = 2
 
 
 @dataclass(frozen=True)
-class PhotonDCAE(PhotonBase):
+class PRXDCAE(PRXBase):
     in_channels: int = 32
     patch_size: int = 1
 
 
 def build_config(vae_type: str) -> Tuple[dict, int]:
     if vae_type == "flux":
-        cfg = PhotonFlux()
+        cfg = PRXFlux()
     elif vae_type == "dc-ae":
-        cfg = PhotonDCAE()
+        cfg = PRXDCAE()
     else:
         raise ValueError(f"Unsupported VAE type: {vae_type}. Use 'flux' or 'dc-ae'")
```
```diff
@@ -64,7 +64,7 @@ def create_parameter_mapping(depth: int) -> dict:
     # Key mappings for structural changes
     mapping = {}
 
-    # Map old structure (layers in PhotonBlock) to new structure (layers in PhotonAttention)
+    # Map old structure (layers in PRXBlock) to new structure (layers in PRXAttention)
     for i in range(depth):
         # QKV projections moved to attention module
         mapping[f"blocks.{i}.img_qkv_proj.weight"] = f"blocks.{i}.attention.img_qkv_proj.weight"
```
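For readers following the conversion logic: once the mapping above is built, applying it to a checkpoint is a plain key rewrite. A hypothetical sketch (`remap_state_dict` is not the script's actual helper):

```py
# Rename checkpoint keys per the old->new mapping; keys without an entry pass through unchanged
def remap_state_dict(old_state_dict: dict, mapping: dict) -> dict:
    return {mapping.get(key, key): value for key, value in old_state_dict.items()}
```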
```diff
@@ -108,8 +108,8 @@ def convert_checkpoint_parameters(old_state_dict: Dict[str, torch.Tensor], depth
     return converted_state_dict
 
 
-def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> PhotonTransformer2DModel:
-    """Create and load PhotonTransformer2DModel from old checkpoint."""
+def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> PRXTransformer2DModel:
+    """Create and load PRXTransformer2DModel from old checkpoint."""
 
     print(f"Loading checkpoint from: {checkpoint_path}")
 
```
```diff
@@ -137,8 +137,8 @@ def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> P
     converted_state_dict = convert_checkpoint_parameters(state_dict, depth=model_depth)
 
     # Create transformer with config
-    print("Creating PhotonTransformer2DModel...")
-    transformer = PhotonTransformer2DModel(**config)
+    print("Creating PRXTransformer2DModel...")
+    transformer = PRXTransformer2DModel(**config)
 
     # Load state dict
     print("Loading converted parameters...")
```
```diff
@@ -221,14 +221,14 @@ def create_model_index(vae_type: str, default_image_size: int, output_path: str)
         vae_class = "AutoencoderDC"
 
     model_index = {
-        "_class_name": "PhotonPipeline",
+        "_class_name": "PRXPipeline",
         "_diffusers_version": "0.31.0.dev0",
         "_name_or_path": os.path.basename(output_path),
         "default_sample_size": default_image_size,
         "scheduler": ["diffusers", "FlowMatchEulerDiscreteScheduler"],
-        "text_encoder": ["photon", "T5GemmaEncoder"],
+        "text_encoder": ["prx", "T5GemmaEncoder"],
         "tokenizer": ["transformers", "GemmaTokenizerFast"],
-        "transformer": ["diffusers", "PhotonTransformer2DModel"],
+        "transformer": ["diffusers", "PRXTransformer2DModel"],
         "vae": ["diffusers", vae_class],
     }
 
```
```diff
@@ -275,7 +275,7 @@ def main(args):
 
     # Verify the pipeline can be loaded
     try:
-        pipeline = PhotonPipeline.from_pretrained(args.output_path)
+        pipeline = PRXPipeline.from_pretrained(args.output_path)
         print("Pipeline loaded successfully!")
         print(f"Transformer: {type(pipeline.transformer).__name__}")
         print(f"VAE: {type(pipeline.vae).__name__}")
```
```diff
@@ -298,10 +298,10 @@ def main(args):
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Convert Photon checkpoint to diffusers format")
+    parser = argparse.ArgumentParser(description="Convert PRX checkpoint to diffusers format")
 
     parser.add_argument(
-        "--checkpoint_path", type=str, required=True, help="Path to the original Photon checkpoint (.pth file )"
+        "--checkpoint_path", type=str, required=True, help="Path to the original PRX checkpoint (.pth file )"
     )
 
     parser.add_argument(
```
```diff
@@ -232,9 +232,9 @@ else:
         "MultiControlNetModel",
         "OmniGenTransformer2DModel",
         "ParallelConfig",
-        "PhotonTransformer2DModel",
         "PixArtTransformer2DModel",
         "PriorTransformer",
+        "PRXTransformer2DModel",
         "QwenImageControlNetModel",
         "QwenImageMultiControlNetModel",
         "QwenImageTransformer2DModel",
```
```diff
@@ -516,11 +516,11 @@ else:
         "MusicLDMPipeline",
         "OmniGenPipeline",
         "PaintByExamplePipeline",
-        "PhotonPipeline",
         "PIAPipeline",
         "PixArtAlphaPipeline",
         "PixArtSigmaPAGPipeline",
         "PixArtSigmaPipeline",
+        "PRXPipeline",
         "QwenImageControlNetInpaintPipeline",
         "QwenImageControlNetPipeline",
         "QwenImageEditInpaintPipeline",
```
```diff
@@ -928,9 +928,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         MultiControlNetModel,
         OmniGenTransformer2DModel,
         ParallelConfig,
-        PhotonTransformer2DModel,
         PixArtTransformer2DModel,
         PriorTransformer,
+        PRXTransformer2DModel,
         QwenImageControlNetModel,
         QwenImageMultiControlNetModel,
         QwenImageTransformer2DModel,
```
```diff
@@ -1182,11 +1182,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         MusicLDMPipeline,
         OmniGenPipeline,
         PaintByExamplePipeline,
-        PhotonPipeline,
         PIAPipeline,
         PixArtAlphaPipeline,
         PixArtSigmaPAGPipeline,
         PixArtSigmaPipeline,
+        PRXPipeline,
         QwenImageControlNetInpaintPipeline,
         QwenImageControlNetPipeline,
         QwenImageEditInpaintPipeline,
```
```diff
@@ -96,7 +96,7 @@ if is_torch_available():
     _import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
     _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
     _import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
-    _import_structure["transformers.transformer_photon"] = ["PhotonTransformer2DModel"]
+    _import_structure["transformers.transformer_prx"] = ["PRXTransformer2DModel"]
     _import_structure["transformers.transformer_qwenimage"] = ["QwenImageTransformer2DModel"]
     _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
     _import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"]
```
```diff
@@ -191,9 +191,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         LuminaNextDiT2DModel,
         MochiTransformer3DModel,
         OmniGenTransformer2DModel,
-        PhotonTransformer2DModel,
         PixArtTransformer2DModel,
         PriorTransformer,
+        PRXTransformer2DModel,
         QwenImageTransformer2DModel,
         SanaTransformer2DModel,
         SD3Transformer2DModel,
```
```diff
@@ -32,7 +32,7 @@ if is_torch_available():
     from .transformer_lumina2 import Lumina2Transformer2DModel
     from .transformer_mochi import MochiTransformer3DModel
     from .transformer_omnigen import OmniGenTransformer2DModel
-    from .transformer_photon import PhotonTransformer2DModel
+    from .transformer_prx import PRXTransformer2DModel
     from .transformer_qwenimage import QwenImageTransformer2DModel
     from .transformer_sd3 import SD3Transformer2DModel
     from .transformer_skyreels_v2 import SkyReelsV2Transformer3DModel
```
```diff
@@ -80,9 +80,9 @@ def apply_rope(xq: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
     return xq_out.reshape(*xq.shape).type_as(xq)
 
 
-class PhotonAttnProcessor2_0:
+class PRXAttnProcessor2_0:
     r"""
-    Processor for implementing Photon-style attention with multi-source tokens and RoPE. Supports multiple attention
+    Processor for implementing PRX-style attention with multi-source tokens and RoPE. Supports multiple attention
     backends (Flash Attention, Sage Attention, etc.) via dispatch_attention_fn.
     """
 
```
```diff
@@ -91,11 +91,11 @@ class PhotonAttnProcessor2_0:
 
     def __init__(self):
         if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
-            raise ImportError("PhotonAttnProcessor2_0 requires PyTorch 2.0, please upgrade PyTorch to 2.0.")
+            raise ImportError("PRXAttnProcessor2_0 requires PyTorch 2.0, please upgrade PyTorch to 2.0.")
 
     def __call__(
         self,
-        attn: "PhotonAttention",
+        attn: "PRXAttention",
         hidden_states: torch.Tensor,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
```
```diff
@@ -103,10 +103,10 @@ class PhotonAttnProcessor2_0:
         **kwargs,
     ) -> torch.Tensor:
         """
-        Apply Photon attention using PhotonAttention module.
+        Apply PRX attention using PRXAttention module.
 
         Args:
-            attn: PhotonAttention module containing projection layers
+            attn: PRXAttention module containing projection layers
             hidden_states: Image tokens [B, L_img, D]
             encoder_hidden_states: Text tokens [B, L_txt, D]
             attention_mask: Boolean mask for text tokens [B, L_txt]
```
```diff
@@ -114,7 +114,7 @@ class PhotonAttnProcessor2_0:
         """
 
         if encoder_hidden_states is None:
-            raise ValueError("PhotonAttnProcessor2_0 requires 'encoder_hidden_states' containing text tokens.")
+            raise ValueError("PRXAttnProcessor2_0 requires 'encoder_hidden_states' containing text tokens.")
 
         # Project image tokens to Q, K, V
         img_qkv = attn.img_qkv_proj(hidden_states)
```
```diff
@@ -190,14 +190,14 @@ class PhotonAttnProcessor2_0:
         return attn_output
 
 
-class PhotonAttention(nn.Module, AttentionModuleMixin):
+class PRXAttention(nn.Module, AttentionModuleMixin):
     r"""
-    Photon-style attention module that handles multi-source tokens and RoPE. Similar to FluxAttention but adapted for
-    Photon's architecture.
+    PRX-style attention module that handles multi-source tokens and RoPE. Similar to FluxAttention but adapted for
+    PRX's architecture.
     """
 
-    _default_processor_cls = PhotonAttnProcessor2_0
-    _available_processors = [PhotonAttnProcessor2_0]
+    _default_processor_cls = PRXAttnProcessor2_0
+    _available_processors = [PRXAttnProcessor2_0]
 
     def __init__(
         self,
```
```diff
@@ -251,7 +251,7 @@ class PhotonAttention(nn.Module, AttentionModuleMixin):
 
 
 # inspired from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
-class PhotonEmbedND(nn.Module):
+class PRXEmbedND(nn.Module):
     r"""
     N-dimensional rotary positional embedding.
 
```
```diff
@@ -347,7 +347,7 @@ class Modulation(nn.Module):
         return tuple(out[:3]), tuple(out[3:])
 
 
-class PhotonBlock(nn.Module):
+class PRXBlock(nn.Module):
     r"""
     Multimodal transformer block with text–image cross-attention, modulation, and MLP.
 
```
```diff
@@ -364,7 +364,7 @@ class PhotonBlock(nn.Module):
     Attributes:
         img_pre_norm (`nn.LayerNorm`):
             Pre-normalization applied to image tokens before attention.
-        attention (`PhotonAttention`):
+        attention (`PRXAttention`):
             Multi-head attention module with built-in QKV projections and normalizations for cross-attention between
             image and text tokens.
         post_attention_layernorm (`nn.LayerNorm`):
```
```diff
@@ -400,15 +400,15 @@ class PhotonBlock(nn.Module):
         # Pre-attention normalization for image tokens
         self.img_pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
 
-        # PhotonAttention module with built-in projections and norms
-        self.attention = PhotonAttention(
+        # PRXAttention module with built-in projections and norms
+        self.attention = PRXAttention(
             query_dim=hidden_size,
             heads=num_heads,
             dim_head=self.head_dim,
             bias=False,
             out_bias=False,
             eps=1e-6,
-            processor=PhotonAttnProcessor2_0(),
+            processor=PRXAttnProcessor2_0(),
         )
 
         # mlp
```
|
||||
return fold(seq.transpose(1, 2), shape, kernel_size=patch_size, stride=patch_size)
|
||||
|
||||
|
||||
class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
|
||||
class PRXTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
|
||||
r"""
|
||||
Transformer-based 2D model for text to image generation.
|
||||
|
||||
@@ -595,7 +595,7 @@ class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
|
||||
txt_in (`nn.Linear`):
|
||||
Projection layer for text conditioning.
|
||||
blocks (`nn.ModuleList`):
|
||||
Stack of transformer blocks (`PhotonBlock`).
|
||||
Stack of transformer blocks (`PRXBlock`).
|
||||
final_layer (`LastLayer`):
|
||||
Projection layer mapping hidden tokens back to patch outputs.
|
||||
|
||||
```diff
@@ -661,14 +661,14 @@ class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
 
         self.hidden_size = hidden_size
         self.num_heads = num_heads
-        self.pe_embedder = PhotonEmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim)
+        self.pe_embedder = PRXEmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim)
         self.img_in = nn.Linear(self.in_channels * self.patch_size**2, self.hidden_size, bias=True)
         self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
         self.txt_in = nn.Linear(context_in_dim, self.hidden_size)
 
         self.blocks = nn.ModuleList(
             [
-                PhotonBlock(
+                PRXBlock(
                     self.hidden_size,
                     self.num_heads,
                     mlp_ratio=mlp_ratio,
```
```diff
@@ -702,7 +702,7 @@ class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
         return_dict: bool = True,
     ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
         r"""
-        Forward pass of the PhotonTransformer2DModel.
+        Forward pass of the PRXTransformer2DModel.
 
         The latent image is split into patch tokens, combined with text conditioning, and processed through a stack of
         transformer blocks modulated by the timestep. The output is reconstructed into the latent image space.
```
```diff
@@ -144,7 +144,7 @@ else:
         "FluxKontextPipeline",
         "FluxKontextInpaintPipeline",
     ]
-    _import_structure["photon"] = ["PhotonPipeline"]
+    _import_structure["prx"] = ["PRXPipeline"]
     _import_structure["audioldm"] = ["AudioLDMPipeline"]
     _import_structure["audioldm2"] = [
         "AudioLDM2Pipeline",
```
```diff
@@ -718,9 +718,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         StableDiffusionXLPAGPipeline,
     )
     from .paint_by_example import PaintByExamplePipeline
-    from .photon import PhotonPipeline
     from .pia import PIAPipeline
     from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
+    from .prx import PRXPipeline
     from .qwenimage import (
         QwenImageControlNetInpaintPipeline,
         QwenImageControlNetPipeline,
```
```diff
@@ -12,7 +12,7 @@ from ...utils import (
 
 _dummy_objects = {}
 _additional_imports = {}
-_import_structure = {"pipeline_output": ["PhotonPipelineOutput"]}
+_import_structure = {"pipeline_output": ["PRXPipelineOutput"]}
 
 try:
     if not (is_transformers_available() and is_torch_available()):
```
```diff
@@ -22,7 +22,7 @@ except OptionalDependencyNotAvailable:
 
     _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
 else:
-    _import_structure["pipeline_photon"] = ["PhotonPipeline"]
+    _import_structure["pipeline_prx"] = ["PRXPipeline"]
 
     # Import T5GemmaEncoder for pipeline loading compatibility
     try:
```
```diff
@@ -44,8 +44,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     except OptionalDependencyNotAvailable:
         from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
     else:
-        from .pipeline_output import PhotonPipelineOutput
-        from .pipeline_photon import PhotonPipeline
+        from .pipeline_output import PRXPipelineOutput
+        from .pipeline_prx import PRXPipeline
 
 else:
     import sys
```
```diff
@@ -22,9 +22,9 @@ from ...utils import BaseOutput
 
 
 @dataclass
-class PhotonPipelineOutput(BaseOutput):
+class PRXPipelineOutput(BaseOutput):
     """
-    Output class for Photon pipelines.
+    Output class for PRX pipelines.
 
     Args:
         images (`List[PIL.Image.Image]` or `np.ndarray`)
```
```diff
@@ -30,9 +30,9 @@ from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder
 from diffusers.image_processor import PixArtImageProcessor
 from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from diffusers.models import AutoencoderDC, AutoencoderKL
-from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel
-from diffusers.pipelines.photon.pipeline_output import PhotonPipelineOutput
+from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.prx.pipeline_output import PRXPipelineOutput
 from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
 from diffusers.utils import (
     logging,
```
```diff
@@ -73,7 +73,7 @@ logger = logging.get_logger(__name__)
 
 
 class TextPreprocessor:
-    """Text preprocessing utility for PhotonPipeline."""
+    """Text preprocessing utility for PRXPipeline."""
 
     def __init__(self):
         """Initialize text preprocessor."""
```
````diff
@@ -203,34 +203,34 @@ EXAMPLE_DOC_STRING = """
     Examples:
         ```py
         >>> import torch
-        >>> from diffusers import PhotonPipeline
+        >>> from diffusers import PRXPipeline
 
         >>> # Load pipeline with from_pretrained
-        >>> pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i-sft")
+        >>> pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft")
         >>> pipe.to("cuda")
 
         >>> prompt = "A digital painting of a rusty, vintage tram on a sandy beach"
         >>> image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0]
-        >>> image.save("photon_output.png")
+        >>> image.save("prx_output.png")
         ```
 """
 
 
-class PhotonPipeline(
+class PRXPipeline(
     DiffusionPipeline,
     LoraLoaderMixin,
     FromSingleFileMixin,
     TextualInversionLoaderMixin,
 ):
     r"""
-    Pipeline for text-to-image generation using Photon Transformer.
+    Pipeline for text-to-image generation using PRX Transformer.
 
     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
 
     Args:
-        transformer ([`PhotonTransformer2DModel`]):
-            The Photon transformer model to denoise the encoded image latents.
+        transformer ([`PRXTransformer2DModel`]):
+            The PRX transformer model to denoise the encoded image latents.
         scheduler ([`FlowMatchEulerDiscreteScheduler`]):
             A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
         text_encoder ([`T5GemmaEncoder`]):
````
```diff
@@ -248,7 +248,7 @@ class PhotonPipeline(
 
     def __init__(
         self,
-        transformer: PhotonTransformer2DModel,
+        transformer: PRXTransformer2DModel,
         scheduler: FlowMatchEulerDiscreteScheduler,
         text_encoder: T5GemmaEncoder,
         tokenizer: Union[T5TokenizerFast, GemmaTokenizerFast, AutoTokenizer],
```
```diff
@@ -257,9 +257,9 @@ class PhotonPipeline(
     ):
         super().__init__()
 
-        if PhotonTransformer2DModel is None:
+        if PRXTransformer2DModel is None:
             raise ImportError(
-                "PhotonTransformer2DModel is not available. Please ensure the transformer_photon module is properly installed."
+                "PRXTransformer2DModel is not available. Please ensure the transformer_prx module is properly installed."
             )
 
         self.text_preprocessor = TextPreprocessor()
```
```diff
@@ -567,7 +567,7 @@ class PhotonPipeline(
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~pipelines.photon.PhotonPipelineOutput`] instead of a plain tuple.
+                Whether or not to return a [`~pipelines.prx.PRXPipelineOutput`] instead of a plain tuple.
             use_resolution_binning (`bool`, *optional*, defaults to `True`):
                 If set to `True`, the requested height and width are first mapped to the closest resolutions using
                 predefined aspect ratio bins. After the produced latents are decoded into images, they are resized back
```
```diff
@@ -585,9 +585,8 @@ class PhotonPipeline(
         Examples:
 
         Returns:
-            [`~pipelines.photon.PhotonPipelineOutput`] or `tuple`: [`~pipelines.photon.PhotonPipelineOutput`] if
-            `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
-            generated images.
+            [`~pipelines.prx.PRXPipelineOutput`] or `tuple`: [`~pipelines.prx.PRXPipelineOutput`] if `return_dict` is
+            True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images.
         """
 
         # 0. Set height and width
```
```diff
@@ -765,4 +764,4 @@ class PhotonPipeline(
         if not return_dict:
             return (image,)
 
-        return PhotonPipelineOutput(images=image)
+        return PRXPipelineOutput(images=image)
```
```diff
@@ -1098,21 +1098,6 @@ class ParallelConfig(metaclass=DummyObject):
         requires_backends(cls, ["torch"])
 
 
-class PhotonTransformer2DModel(metaclass=DummyObject):
-    _backends = ["torch"]
-
-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["torch"])
-
-    @classmethod
-    def from_config(cls, *args, **kwargs):
-        requires_backends(cls, ["torch"])
-
-    @classmethod
-    def from_pretrained(cls, *args, **kwargs):
-        requires_backends(cls, ["torch"])
-
-
 class PixArtTransformer2DModel(metaclass=DummyObject):
     _backends = ["torch"]
 
```
```diff
@@ -1143,6 +1128,21 @@ class PriorTransformer(metaclass=DummyObject):
         requires_backends(cls, ["torch"])
 
 
+class PRXTransformer2DModel(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 class QwenImageControlNetModel(metaclass=DummyObject):
     _backends = ["torch"]
 
```
```diff
@@ -1847,21 +1847,6 @@ class PaintByExamplePipeline(metaclass=DummyObject):
         requires_backends(cls, ["torch", "transformers"])
 
 
-class PhotonPipeline(metaclass=DummyObject):
-    _backends = ["torch", "transformers"]
-
-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["torch", "transformers"])
-
-    @classmethod
-    def from_config(cls, *args, **kwargs):
-        requires_backends(cls, ["torch", "transformers"])
-
-    @classmethod
-    def from_pretrained(cls, *args, **kwargs):
-        requires_backends(cls, ["torch", "transformers"])
-
-
 class PIAPipeline(metaclass=DummyObject):
     _backends = ["torch", "transformers"]
 
```
```diff
@@ -1922,6 +1907,21 @@ class PixArtSigmaPipeline(metaclass=DummyObject):
         requires_backends(cls, ["torch", "transformers"])
 
 
+class PRXPipeline(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
 class QwenImageControlNetInpaintPipeline(metaclass=DummyObject):
     _backends = ["torch", "transformers"]
 
```
```diff
@@ -17,7 +17,7 @@ import unittest
 
 import torch
 
-from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel
+from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
 
 from ...testing_utils import enable_full_determinism, torch_device
 from ..test_modeling_common import ModelTesterMixin
```
```diff
@@ -26,8 +26,8 @@ from ..test_modeling_common import ModelTesterMixin
 enable_full_determinism()
 
 
-class PhotonTransformerTests(ModelTesterMixin, unittest.TestCase):
-    model_class = PhotonTransformer2DModel
+class PRXTransformerTests(ModelTesterMixin, unittest.TestCase):
+    model_class = PRXTransformer2DModel
     main_input_name = "hidden_states"
     uses_custom_attn_processor = True
 
```
```diff
@@ -75,7 +75,7 @@ class PhotonTransformerTests(ModelTesterMixin, unittest.TestCase):
         return init_dict, inputs_dict
 
     def test_gradient_checkpointing_is_applied(self):
-        expected_set = {"PhotonTransformer2DModel"}
+        expected_set = {"PRXTransformer2DModel"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
 
 
```
```diff
@@ -8,8 +8,8 @@ from transformers.models.t5gemma.configuration_t5gemma import T5GemmaConfig, T5G
 from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder
 
 from diffusers.models import AutoencoderDC, AutoencoderKL
-from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel
-from diffusers.pipelines.photon.pipeline_photon import PhotonPipeline
+from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
+from diffusers.pipelines.prx.pipeline_prx import PRXPipeline
 from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
 from diffusers.utils import is_transformers_version
 
```
```diff
@@ -22,8 +22,8 @@ from ..test_pipelines_common import PipelineTesterMixin
     reason="See https://github.com/huggingface/diffusers/pull/12456#issuecomment-3424228544",
     strict=False,
 )
-class PhotonPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
-    pipeline_class = PhotonPipeline
+class PRXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+    pipeline_class = PRXPipeline
     params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
     batch_params = frozenset(["prompt", "negative_prompt", "num_images_per_prompt"])
     test_xformers_attention = False
```
```diff
@@ -32,16 +32,16 @@ class PhotonPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
 
     @classmethod
     def setUpClass(cls):
-        # Ensure PhotonPipeline has an _execution_device property expected by __call__
-        if not isinstance(getattr(PhotonPipeline, "_execution_device", None), property):
+        # Ensure PRXPipeline has an _execution_device property expected by __call__
+        if not isinstance(getattr(PRXPipeline, "_execution_device", None), property):
             try:
-                setattr(PhotonPipeline, "_execution_device", property(lambda self: torch.device("cpu")))
+                setattr(PRXPipeline, "_execution_device", property(lambda self: torch.device("cpu")))
             except Exception:
                 pass
 
     def get_dummy_components(self):
         torch.manual_seed(0)
-        transformer = PhotonTransformer2DModel(
+        transformer = PRXTransformer2DModel(
             patch_size=1,
             in_channels=4,
             context_in_dim=8,
```
```diff
@@ -129,7 +129,7 @@ class PhotonPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     def test_inference(self):
         device = "cpu"
         components = self.get_dummy_components()
-        pipe = PhotonPipeline(**components)
+        pipe = PRXPipeline(**components)
         pipe.to(device)
         pipe.set_progress_bar_config(disable=None)
         try:
```
```diff
@@ -148,7 +148,7 @@ class PhotonPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
 
     def test_callback_inputs(self):
         components = self.get_dummy_components()
-        pipe = PhotonPipeline(**components)
+        pipe = PRXPipeline(**components)
         pipe = pipe.to("cpu")
         pipe.set_progress_bar_config(disable=None)
         try:
```
```diff
@@ -157,7 +157,7 @@ class PhotonPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
             pass
         self.assertTrue(
             hasattr(pipe, "_callback_tensor_inputs"),
-            f" {PhotonPipeline} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
+            f" {PRXPipeline} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
         )
 
         def callback_inputs_subset(pipe, i, t, callback_kwargs):
```
```diff
@@ -216,7 +216,7 @@ class PhotonPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         self.assertLess(max(max_diff1, max_diff2), expected_max_diff)
 
     def test_inference_with_autoencoder_dc(self):
-        """Test PhotonPipeline with AutoencoderDC (DCAE) instead of AutoencoderKL."""
+        """Test PRXPipeline with AutoencoderDC (DCAE) instead of AutoencoderKL."""
         device = "cpu"
 
         components = self.get_dummy_components()
```
```diff
@@ -248,7 +248,7 @@ class PhotonPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
 
         components["vae"] = vae_dc
 
-        pipe = PhotonPipeline(**components)
+        pipe = PRXPipeline(**components)
         pipe.to(device)
         pipe.set_progress_bar_config(disable=None)
 
```