mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-07 13:04:15 +08:00
Compare commits
16 Commits
qwen-edit-
...
cache-docs
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d76b744ac3 | ||
|
|
94c9613f99 | ||
|
|
b91e8c0d0b | ||
|
|
ac7864624b | ||
|
|
5ffb73d4ae | ||
|
|
4088e8a851 | ||
|
|
d33d9f6715 | ||
|
|
dde8754ba2 | ||
|
|
fbcd3ba6b2 | ||
|
|
d176f61fcf | ||
|
|
354d35adb0 | ||
|
|
544ba677dd | ||
|
|
6f1042e36c | ||
|
|
b26867b628 | ||
|
|
e3f441648c | ||
|
|
c6cfc5ce1d |
@@ -349,6 +349,8 @@
|
||||
title: DiTTransformer2DModel
|
||||
- local: api/models/easyanimate_transformer3d
|
||||
title: EasyAnimateTransformer3DModel
|
||||
- local: api/models/flux2_transformer
|
||||
title: Flux2Transformer2DModel
|
||||
- local: api/models/flux_transformer
|
||||
title: FluxTransformer2DModel
|
||||
- local: api/models/hidream_image_transformer
|
||||
@@ -525,6 +527,8 @@
|
||||
title: EasyAnimate
|
||||
- local: api/pipelines/flux
|
||||
title: Flux
|
||||
- local: api/pipelines/flux2
|
||||
title: Flux2
|
||||
- local: api/pipelines/control_flux_inpaint
|
||||
title: FluxControlInpaint
|
||||
- local: api/pipelines/hidream
|
||||
|
||||
@@ -29,7 +29,7 @@ Cache methods speedup diffusion transformers by storing and reusing intermediate
|
||||
|
||||
[[autodoc]] apply_faster_cache
|
||||
|
||||
### FirstBlockCacheConfig
|
||||
## FirstBlockCacheConfig
|
||||
|
||||
[[autodoc]] FirstBlockCacheConfig
|
||||
|
||||
|
||||
@@ -30,7 +30,8 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
|
||||
- [`CogView4LoraLoaderMixin`] provides similar functions for [CogView4](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogview4).
|
||||
- [`AmusedLoraLoaderMixin`] is for the [`AmusedPipeline`].
|
||||
- [`HiDreamImageLoraLoaderMixin`] provides similar functions for [HiDream Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hidream)
|
||||
- [`QwenImageLoraLoaderMixin`] provides similar functions for [Qwen Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/qwen)
|
||||
- [`QwenImageLoraLoaderMixin`] provides similar functions for [Qwen Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/qwen).
|
||||
- [`Flux2LoraLoaderMixin`] provides similar functions for [Flux2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux2).
|
||||
- [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, unload, LoRAs and more.
|
||||
|
||||
> [!TIP]
|
||||
@@ -56,6 +57,10 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.FluxLoraLoaderMixin
|
||||
|
||||
## Flux2LoraLoaderMixin
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.Flux2LoraLoaderMixin
|
||||
|
||||
## CogVideoXLoraLoaderMixin
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.CogVideoXLoraLoaderMixin
|
||||
|
||||
19
docs/source/en/api/models/flux2_transformer.md
Normal file
19
docs/source/en/api/models/flux2_transformer.md
Normal file
@@ -0,0 +1,19 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Flux2Transformer2DModel
|
||||
|
||||
A Transformer model for image-like data from [Flux2](https://hf.co/black-forest-labs/FLUX.2-dev).
|
||||
|
||||
## Flux2Transformer2DModel
|
||||
|
||||
[[autodoc]] Flux2Transformer2DModel
|
||||
33
docs/source/en/api/pipelines/flux2.md
Normal file
33
docs/source/en/api/pipelines/flux2.md
Normal file
@@ -0,0 +1,33 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Flux2
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
|
||||
<img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22">
|
||||
</div>
|
||||
|
||||
Flux.2 is the recent series of image generation models from Black Forest Labs, preceded by the [Flux.1](./flux.md) series. It is an entirely new model with a new architecture and pre-training done from scratch!
|
||||
|
||||
Original model checkpoints for Flux can be found [here](https://huggingface.co/black-forest-labs). Original inference code can be found [here](https://github.com/black-forest-labs/flux2).
|
||||
|
||||
> [!TIP]
|
||||
> Flux2 can be quite expensive to run on consumer hardware devices. However, you can perform a suite of optimizations to run it faster and in a more memory-friendly manner. Check out [this section](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more details. Additionally, Flux can benefit from quantization for memory efficiency with a trade-off in inference latency. Refer to [this blog post](https://huggingface.co/blog/quanto-diffusers) to learn more.
|
||||
>
|
||||
> [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.
|
||||
|
||||
## Flux2Pipeline
|
||||
|
||||
[[autodoc]] Flux2Pipeline
|
||||
- all
|
||||
- __call__
|
||||
@@ -43,11 +43,13 @@ Note: The recommended dtype mentioned is for the transformer weights. The text e
|
||||
<hfoptions id="generation pipelines">`
|
||||
<hfoption id="Text-to-Video">
|
||||
|
||||
The example below demonstrates how to use the text-to-video pipeline to generate a video using a text descriptio and a starting frame.
|
||||
The example below demonstrates how to use the text-to-video pipeline to generate a video using a text description.
|
||||
|
||||
```python
|
||||
model_id =
|
||||
pipe = SanaVideoPipeline.from_pretrained("Efficient-Large-Model/SANA-Video_2B_480p_diffusers", torch_dtype=torch.bfloat16)
|
||||
pipe = SanaVideoPipeline.from_pretrained(
|
||||
"Efficient-Large-Model/SANA-Video_2B_480p_diffusers",
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
pipe.text_encoder.to(torch.bfloat16)
|
||||
pipe.vae.to(torch.float32)
|
||||
pipe.to("cuda")
|
||||
@@ -75,12 +77,11 @@ export_to_video(video, "sana_video.mp4", fps=16)
|
||||
</hfoption>
|
||||
<hfoption id="Image-to-Video">
|
||||
|
||||
The example below demonstrates how to use the image-to-video pipeline to generate a video using a text descriptio and a starting frame.
|
||||
The example below demonstrates how to use the image-to-video pipeline to generate a video using a text description and a starting frame.
|
||||
|
||||
```python
|
||||
model_id = "Efficient-Large-Model/SANA-Video_2B_480p_diffusers"
|
||||
pipe = SanaImageToVideoPipeline.from_pretrained(
|
||||
model_id,
|
||||
"Efficient-Large-Model/SANA-Video_2B_480p_diffusers",
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config, flow_shift=8.0)
|
||||
|
||||
@@ -139,12 +139,14 @@ Refer to the table below for a complete list of available attention backends and
|
||||
| `_native_npu` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | NPU-optimized attention |
|
||||
| `_native_xla` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | XLA-optimized attention |
|
||||
| `flash` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 |
|
||||
| `flash_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 from kernels |
|
||||
| `flash_varlen` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention |
|
||||
| `aiter` | [AI Tensor Engine for ROCm](https://github.com/ROCm/aiter) | FlashAttention for AMD ROCm |
|
||||
| `_flash_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 |
|
||||
| `_flash_varlen_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 |
|
||||
| `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from kernels |
|
||||
| `sage` | [SageAttention](https://github.com/thu-ml/SageAttention) | Quantized attention (INT8 QK) |
|
||||
| `sage_hub` | [SageAttention](https://github.com/thu-ml/SageAttention) | Quantized attention (INT8 QK) from kernels |
|
||||
| `sage_varlen` | [SageAttention](https://github.com/thu-ml/SageAttention) | Variable length SageAttention |
|
||||
| `_sage_qk_int8_pv_fp8_cuda` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP8 PV (CUDA) |
|
||||
| `_sage_qk_int8_pv_fp8_cuda_sm90` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP8 PV (SM90) |
|
||||
|
||||
@@ -66,4 +66,8 @@ config = FasterCacheConfig(
|
||||
tensor_format="BFCHW",
|
||||
)
|
||||
pipeline.transformer.enable_cache(config)
|
||||
```
|
||||
```
|
||||
|
||||
## FirstBlockCache
|
||||
|
||||
[FirstBlock Cache](https://huggingface.co/docs/diffusers/main/en/api/cache#diffusers.FirstBlockCacheConfig) builds on the ideas of [TeaCache](https://huggingface.co/papers/2411.19108). It is much simpler to implement generically for a wide range of models and has been integrated first for experimental purposes.
|
||||
@@ -1,8 +1,10 @@
|
||||
- sections:
|
||||
- local: index
|
||||
title: 🧨 Diffusers
|
||||
- local: quicktour
|
||||
title: Tour rápido
|
||||
- local: installation
|
||||
title: Instalação
|
||||
- local: index
|
||||
title: Diffusers
|
||||
- local: installation
|
||||
title: Instalação
|
||||
- local: quicktour
|
||||
title: Tour rápido
|
||||
- local: stable_diffusion
|
||||
title: Desempenho básico
|
||||
title: Primeiros passos
|
||||
|
||||
@@ -18,11 +18,11 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# Diffusers
|
||||
|
||||
🤗 Diffusers é uma biblioteca de modelos de difusão de última geração para geração de imagens, áudio e até mesmo estruturas 3D de moléculas. Se você está procurando uma solução de geração simples ou queira treinar seu próprio modelo de difusão, 🤗 Diffusers é uma modular caixa de ferramentas que suporta ambos. Nossa biblioteca é desenhada com foco em [usabilidade em vez de desempenho](conceptual/philosophy#usability-over-performance), [simples em vez de fácil](conceptual/philosophy#simple-over-easy) e [customizável em vez de abstrações](conceptual/philosophy#tweakable-contributorfriendly-over-abstraction).
|
||||
🤗 Diffusers é uma biblioteca de modelos de difusão de última geração para geração de imagens, áudio e até mesmo estruturas 3D de moléculas. Se você está procurando uma solução de geração simples ou quer treinar seu próprio modelo de difusão, 🤗 Diffusers é uma caixa de ferramentas modular que suporta ambos. Nossa biblioteca é desenhada com foco em [usabilidade em vez de desempenho](conceptual/philosophy#usability-over-performance), [simples em vez de fácil](conceptual/philosophy#simple-over-easy) e [customizável em vez de abstrações](conceptual/philosophy#tweakable-contributorfriendly-over-abstraction).
|
||||
|
||||
A Biblioteca tem três componentes principais:
|
||||
|
||||
- Pipelines de última geração para a geração em poucas linhas de código. Têm muitos pipelines no 🤗 Diffusers, veja a tabela no pipeline [Visão geral](api/pipelines/overview) para uma lista completa de pipelines disponíveis e as tarefas que eles resolvem.
|
||||
- Pipelines de última geração para a geração em poucas linhas de código. Há muitos pipelines no 🤗 Diffusers, veja a tabela no pipeline [Visão geral](api/pipelines/overview) para uma lista completa de pipelines disponíveis e as tarefas que eles resolvem.
|
||||
- Intercambiáveis [agendadores de ruído](api/schedulers/overview) para balancear as compensações entre velocidade e qualidade de geração.
|
||||
- [Modelos](api/models) pré-treinados que podem ser usados como se fossem blocos de construção, e combinados com agendadores, para criar seu próprio sistema de difusão de ponta a ponta.
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
Recomenda-se instalar 🤗 Diffusers em um [ambiente virtual](https://docs.python.org/3/library/venv.html).
|
||||
Se você não está familiarizado com ambiente virtuals, veja o [guia](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).
|
||||
Um ambiente virtual deixa mais fácil gerenciar diferentes projetos e evitar problemas de compatibilidade entre dependências.
|
||||
Um ambiente virtual facilita gerenciar diferentes projetos e evitar problemas de compatibilidade entre dependências.
|
||||
|
||||
Comece criando um ambiente virtual no diretório do projeto:
|
||||
|
||||
@@ -100,12 +100,12 @@ pip install -e ".[flax]"
|
||||
</jax>
|
||||
</frameworkcontent>
|
||||
|
||||
Esses comandos irá linkar a pasta que você clonou o repositório e os caminhos das suas bibliotecas Python.
|
||||
Esses comandos irão vincular a pasta que você clonou o repositório e os caminhos das suas bibliotecas Python.
|
||||
Python então irá procurar dentro da pasta que você clonou além dos caminhos normais das bibliotecas.
|
||||
Por exemplo, se o pacote python for tipicamente instalado no `~/anaconda3/envs/main/lib/python3.10/site-packages/`, o Python também irá procurar na pasta `~/diffusers/` que você clonou.
|
||||
|
||||
> [!WARNING]
|
||||
> Você deve deixar a pasta `diffusers` se você quiser continuar usando a biblioteca.
|
||||
> Você deve manter a pasta `diffusers` se quiser continuar usando a biblioteca.
|
||||
|
||||
Agora você pode facilmente atualizar seu clone para a última versão do 🤗 Diffusers com o seguinte comando:
|
||||
|
||||
|
||||
132
docs/source/pt/stable_diffusion.md
Normal file
132
docs/source/pt/stable_diffusion.md
Normal file
@@ -0,0 +1,132 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
[[open-in-colab]]
|
||||
|
||||
# Desempenho básico
|
||||
|
||||
Difusão é um processo aleatório que demanda muito processamento. Você pode precisar executar o [`DiffusionPipeline`] várias vezes antes de obter o resultado desejado. Por isso é importante equilibrar cuidadosamente a velocidade de geração e o uso de memória para iterar mais rápido.
|
||||
|
||||
Este guia recomenda algumas dicas básicas de desempenho para usar o [`DiffusionPipeline`]. Consulte a seção de documentação sobre Otimização de Inferência, como [Acelerar inferência](./optimization/fp16) ou [Reduzir uso de memória](./optimization/memory) para guias de desempenho mais detalhados.
|
||||
|
||||
## Uso de memória
|
||||
|
||||
Reduzir a quantidade de memória usada indiretamente acelera a geração e pode ajudar um modelo a caber no dispositivo.
|
||||
|
||||
O método [`~DiffusionPipeline.enable_model_cpu_offload`] move um modelo para a CPU quando não está em uso para economizar memória da GPU.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="cuda"
|
||||
)
|
||||
pipeline.enable_model_cpu_offload()
|
||||
|
||||
prompt = """
|
||||
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
|
||||
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
|
||||
"""
|
||||
pipeline(prompt).images[0]
|
||||
print(f"Memória máxima reservada: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
|
||||
```
|
||||
|
||||
## Velocidade de inferência
|
||||
|
||||
O processo de remoção de ruído é o mais exigente computacionalmente durante a difusão. Métodos que otimizam este processo aceleram a velocidade de inferência. Experimente os seguintes métodos para acelerar.
|
||||
|
||||
- Adicione `device_map="cuda"` para colocar o pipeline em uma GPU. Colocar um modelo em um acelerador, como uma GPU, aumenta a velocidade porque realiza computações em paralelo.
|
||||
- Defina `torch_dtype=torch.bfloat16` para executar o pipeline em meia-precisão. Reduzir a precisão do tipo de dado aumenta a velocidade porque leva menos tempo para realizar computações em precisão mais baixa.
|
||||
|
||||
```py
|
||||
import torch
|
||||
import time
|
||||
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="cuda"
|
||||
)
|
||||
```
|
||||
|
||||
- Use um agendador mais rápido, como [`DPMSolverMultistepScheduler`], que requer apenas ~20-25 passos.
|
||||
- Defina `num_inference_steps` para um valor menor. Reduzir o número de passos de inferência reduz o número total de computações. No entanto, isso pode resultar em menor qualidade de geração.
|
||||
|
||||
```py
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
|
||||
|
||||
prompt = """
|
||||
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
|
||||
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
|
||||
"""
|
||||
|
||||
start_time = time.perf_counter()
|
||||
image = pipeline(prompt).images[0]
|
||||
end_time = time.perf_counter()
|
||||
|
||||
print(f"Geração de imagem levou {end_time - start_time:.3f} segundos")
|
||||
```
|
||||
|
||||
## Qualidade de geração
|
||||
|
||||
Muitos modelos de difusão modernos entregam imagens de alta qualidade imediatamente. No entanto, você ainda pode melhorar a qualidade de geração experimentando o seguinte.
|
||||
|
||||
- Experimente um prompt mais detalhado e descritivo. Inclua detalhes como o meio da imagem, assunto, estilo e estética. Um prompt negativo também pode ajudar, guiando um modelo para longe de características indesejáveis usando palavras como baixa qualidade ou desfocado.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="cuda"
|
||||
)
|
||||
|
||||
prompt = """
|
||||
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
|
||||
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
|
||||
"""
|
||||
negative_prompt = "low quality, blurry, ugly, poor details"
|
||||
pipeline(prompt, negative_prompt=negative_prompt).images[0]
|
||||
```
|
||||
|
||||
Para mais detalhes sobre como criar prompts melhores, consulte a documentação sobre [Técnicas de prompt](./using-diffusers/weighted_prompts).
|
||||
|
||||
- Experimente um agendador diferente, como [`HeunDiscreteScheduler`] ou [`LMSDiscreteScheduler`], que sacrifica velocidade de geração por qualidade.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline, HeunDiscreteScheduler
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="cuda"
|
||||
)
|
||||
pipeline.scheduler = HeunDiscreteScheduler.from_config(pipeline.scheduler.config)
|
||||
|
||||
prompt = """
|
||||
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
|
||||
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
|
||||
"""
|
||||
negative_prompt = "low quality, blurry, ugly, poor details"
|
||||
pipeline(prompt, negative_prompt=negative_prompt).images[0]
|
||||
```
|
||||
|
||||
## Próximos passos
|
||||
|
||||
Diffusers oferece otimizações mais avançadas e poderosas, como [group-offloading](./optimization/memory#group-offloading) e [compilação regional](./optimization/fp16#regional-compilation). Para saber mais sobre como maximizar o desempenho, consulte a seção sobre Otimização de Inferência.
|
||||
@@ -1168,12 +1168,12 @@ class FluxControlNetFillInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin,
|
||||
generator,
|
||||
)
|
||||
|
||||
mask_imagee = self.mask_processor.preprocess(mask_image, height=height, width=width)
|
||||
masked_imagee = init_image * (1 - mask_imagee)
|
||||
masked_imagee = masked_imagee.to(dtype=self.vae.dtype, device=device)
|
||||
maskkk, masked_image_latentsss = self.prepare_mask_latents_fill(
|
||||
mask_imagee,
|
||||
masked_imagee,
|
||||
mask_image_fill = self.mask_processor.preprocess(mask_image, height=height, width=width)
|
||||
masked_image_fill = init_image * (1 - mask_image_fill)
|
||||
masked_image_fill = masked_image_fill.to(dtype=self.vae.dtype, device=device)
|
||||
mask_fill, masked_latents_fill = self.prepare_mask_latents_fill(
|
||||
mask_image_fill,
|
||||
masked_image_fill,
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
num_images_per_prompt,
|
||||
@@ -1243,7 +1243,7 @@ class FluxControlNetFillInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin,
|
||||
else:
|
||||
guidance = None
|
||||
|
||||
masked_image_latents_fill = torch.cat((masked_image_latentsss, maskkk), dim=-1)
|
||||
masked_image_latents_fill = torch.cat((masked_latents_fill, mask_fill), dim=-1)
|
||||
latent_model_input = torch.cat([latents, masked_image_latents_fill], dim=2)
|
||||
|
||||
noise_pred = self.transformer(
|
||||
|
||||
315
examples/dreambooth/README_flux2.md
Normal file
315
examples/dreambooth/README_flux2.md
Normal file
@@ -0,0 +1,315 @@
|
||||
# DreamBooth training example for FLUX.2 [dev]
|
||||
|
||||
[DreamBooth](https://huggingface.co/papers/2208.12242) is a method to personalize image generation models given just a few (3~5) images of a subject/concept.
|
||||
|
||||
The `train_dreambooth_lora_flux2.py` script shows how to implement the training procedure for [LoRAs](https://huggingface.co/blog/lora) and adapt it for [FLUX.2 [dev]](https://github.com/black-forest-labs/flux2).
|
||||
|
||||
> [!NOTE]
|
||||
> **Memory consumption**
|
||||
>
|
||||
> Flux can be quite expensive to run on consumer hardware devices and as a result finetuning it comes with high memory requirements -
|
||||
> a LoRA with a rank of 16 can exceed XXGB of VRAM for training. below we provide some tips and tricks to reduce memory consumption during training.
|
||||
|
||||
> For more tips & guidance on training on a resource-constrained device and general good practices please check out these great guides and trainers for FLUX:
|
||||
> 1) [`@bghira`'s guide](https://github.com/bghira/SimpleTuner/blob/main/documentation/quickstart/FLUX2.md)
|
||||
> 2) [`ostris`'s guide](https://github.com/ostris/ai-toolkit?tab=readme-ov-file#flux2-training)
|
||||
|
||||
> [!NOTE]
|
||||
> **Gated model**
|
||||
>
|
||||
> As the model is gated, before using it with diffusers you first need to go to the [FLUX.2 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.2-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:
|
||||
|
||||
```bash
|
||||
hf auth login
|
||||
```
|
||||
|
||||
This will also allow us to push the trained model parameters to the Hugging Face Hub platform.
|
||||
|
||||
## Running locally with PyTorch
|
||||
|
||||
### Installing the dependencies
|
||||
|
||||
Before running the scripts, make sure to install the library's training dependencies:
|
||||
|
||||
**Important**
|
||||
|
||||
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/diffusers
|
||||
cd diffusers
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
Then cd in the `examples/dreambooth` folder and run
|
||||
```bash
|
||||
pip install -r requirements_flux.txt
|
||||
```
|
||||
|
||||
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
Or for a default accelerate configuration without answering questions about your environment
|
||||
|
||||
```bash
|
||||
accelerate config default
|
||||
```
|
||||
|
||||
Or if your environment doesn't support an interactive shell (e.g., a notebook)
|
||||
|
||||
```python
|
||||
from accelerate.utils import write_basic_config
|
||||
write_basic_config()
|
||||
```
|
||||
|
||||
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
|
||||
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
|
||||
|
||||
|
||||
### Dog toy example
|
||||
|
||||
Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.
|
||||
|
||||
Let's first download it locally:
|
||||
|
||||
```python
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
local_dir = "./dog"
|
||||
snapshot_download(
|
||||
"diffusers/dog-example",
|
||||
local_dir=local_dir, repo_type="dataset",
|
||||
ignore_patterns=".gitattributes",
|
||||
)
|
||||
```
|
||||
|
||||
This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.
|
||||
|
||||
As mentioned, Flux2 LoRA training is *very* memory intensive. Here are memory optimizations we can use (some still experimental) for a more memory efficient training:
|
||||
|
||||
## Memory Optimizations
|
||||
> [!NOTE] many of these techniques complement each other and can be used together to further reduce memory consumption.
|
||||
> However some techniques may be mutually exclusive so be sure to check before launching a training run.
|
||||
### Remote Text Encoder
|
||||
Flux.2 uses Mistral Small 3.1 as text encoder which is quite large and can take up a lot of memory. To mitigate this, we can use the `--remote_text_encoder` flag to enable remote computation of the prompt embeddings using the HuggingFace Inference API.
|
||||
This way, the text encoder model is not loaded into memory during training.
|
||||
> [!NOTE]
|
||||
> to enable remote text encoding you must either be logged in to your HuggingFace account (`hf auth login`) OR pass a token with `--hub_token`.
|
||||
### CPU Offloading
|
||||
To offload parts of the model to CPU memory, you can use `--offload` flag. This will offload the vae and text encoder to CPU memory and only move them to GPU when needed.
|
||||
### Latent Caching
|
||||
Pre-encode the training images with the vae, and then delete it to free up some memory. To enable `latent_caching` simply pass `--cache_latents`.
|
||||
### QLoRA: Low Precision Training with Quantization
|
||||
Perform low precision training using 8-bit or 4-bit quantization to reduce memory usage. You can use the following flags:
|
||||
- **FP8 training** with `torchao`:
|
||||
enable FP8 training by passing `--do_fp8_training`.
|
||||
> [!IMPORTANT] Since we are utilizing FP8 tensor cores we need CUDA GPUs with compute capability at least 8.9 or greater.
|
||||
> If you're looking for memory-efficient training on relatively older cards, we encourage you to check out other trainers like SimpleTuner, ai-toolkit, etc.
|
||||
- **NF4 training** with `bitsandbytes`:
|
||||
Alternatively, you can use 8-bit or 4-bit quantization with `bitsandbytes` by passing:
|
||||
`--bnb_quantization_config_path` to enable 4-bit NF4 quantization.
|
||||
### Gradient Checkpointing and Accumulation
|
||||
* `--gradient accumulation` refers to the number of updates steps to accumulate before performing a backward/update pass.
|
||||
by passing a value > 1 you can reduce the amount of backward/update passes and hence also memory reqs.
|
||||
* with `--gradient checkpointing` we can save memory by not storing all intermediate activations during the forward pass.
|
||||
Instead, only a subset of these activations (the checkpoints) are stored and the rest is recomputed as needed during the backward pass. Note that this comes at the expanse of a slower backward pass.
|
||||
### 8-bit-Adam Optimizer
|
||||
When training with `AdamW`(doesn't apply to `prodigy`) You can pass `--use_8bit_adam` to reduce the memory requirements of training.
|
||||
Make sure to install `bitsandbytes` if you want to do so.
|
||||
### Image Resolution
|
||||
An easy way to mitigate some of the memory requirements is through `--resolution`. `--resolution` refers to the resolution for input images, all the images in the train/validation dataset are resized to this.
|
||||
Note that by default, images are resized to resolution of 512, but it's good to keep in mind in case you're accustomed to training on higher resolutions.
|
||||
### Precision of saved LoRA layers
|
||||
By default, trained transformer layers are saved in the precision dtype in which training was performed. E.g. when training in mixed precision is enabled with `--mixed_precision="bf16"`, final finetuned layers will be saved in `torch.bfloat16` as well.
|
||||
This reduces memory requirements significantly w/o a significant quality loss. Note that if you do wish to save the final layers in float32 at the expanse of more memory usage, you can do so by passing `--upcast_before_saving`.
|
||||
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="black-forest-labs/FLUX.2-dev"
|
||||
export INSTANCE_DIR="dog"
|
||||
export OUTPUT_DIR="trained-flux2"
|
||||
|
||||
accelerate launch train_dreambooth_lora_flux2.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--instance_data_dir=$INSTANCE_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--do_fp8_training \
|
||||
--gradient_checkpointing \
|
||||
--remote_text_encoder \
|
||||
--cache_latents \
|
||||
--instance_prompt="a photo of sks dog" \
|
||||
--resolution=1024 \
|
||||
--train_batch_size=1 \
|
||||
--guidance_scale=1 \
|
||||
--use_8bit_adam \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--optimizer="adamW" \
|
||||
--learning_rate=1e-4 \
|
||||
--report_to="wandb" \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=100 \
|
||||
--max_train_steps=500 \
|
||||
--validation_prompt="A photo of sks dog in a bucket" \
|
||||
--validation_epochs=25 \
|
||||
--seed="0" \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
To better track our training experiments, we're using the following flags in the command above:
|
||||
|
||||
* `report_to="wandb` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login <your_api_key>` before training if you haven't done it before.
|
||||
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
|
||||
|
||||
> [!NOTE]
|
||||
> If you want to train using long prompts with the T5 text encoder, you can use `--max_sequence_length` to set the token limit. The default is 77, but it can be increased to as high as 512. Note that this will use more resources and may slow down the training in some cases.
|
||||
|
||||
## LoRA + DreamBooth
|
||||
|
||||
[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning like performance but with a fraction of learnable parameters.
|
||||
|
||||
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
|
||||
|
||||
### Prodigy Optimizer
|
||||
Prodigy is an adaptive optimizer that dynamically adjusts the learning rate learned parameters based on past gradients, allowing for more efficient convergence.
|
||||
By using prodigy we can "eliminate" the need for manual learning rate tuning. read more [here](https://huggingface.co/blog/sdxl_lora_advanced_script#adaptive-optimizers).
|
||||
|
||||
to use prodigy, first make sure to install the prodigyopt library: `pip install prodigyopt`, and then specify -
|
||||
```bash
|
||||
--optimizer="prodigy"
|
||||
```
|
||||
> [!TIP]
|
||||
> When using prodigy it's generally good practice to set- `--learning_rate=1.0`
|
||||
|
||||
To perform DreamBooth with LoRA, run:
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="black-forest-labs/FLUX.2-dev"
|
||||
export INSTANCE_DIR="dog"
|
||||
export OUTPUT_DIR="trained-flux2-lora"
|
||||
|
||||
accelerate launch train_dreambooth_lora_flux2.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--instance_data_dir=$INSTANCE_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--do_fp8_training \
|
||||
--gradient_checkpointing \
|
||||
--remote_text_encoder \
|
||||
--cache_latents \
|
||||
--instance_prompt="a photo of sks dog" \
|
||||
--resolution=512 \
|
||||
--train_batch_size=1 \
|
||||
--guidance_scale=1 \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--optimizer="prodigy" \
|
||||
--learning_rate=1. \
|
||||
--report_to="wandb" \
|
||||
--lr_scheduler="constant_with_warmup" \
|
||||
--lr_warmup_steps=100 \
|
||||
--max_train_steps=500 \
|
||||
--validation_prompt="A photo of sks dog in a bucket" \
|
||||
--validation_epochs=25 \
|
||||
--seed="0" \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
### LoRA Rank and Alpha
|
||||
Two key LoRA hyperparameters are LoRA rank and LoRA alpha.
|
||||
- `--rank`: Defines the dimension of the trainable LoRA matrices. A higher rank means more expressiveness and capacity to learn (and more parameters).
|
||||
- `--lora_alpha`: A scaling factor for the LoRA's output. The LoRA update is scaled by lora_alpha / lora_rank.
|
||||
- lora_alpha vs. rank:
|
||||
This ratio dictates the LoRA's effective strength:
|
||||
lora_alpha == rank: Scaling factor is 1. The LoRA is applied with its learned strength. (e.g., alpha=16, rank=16)
|
||||
lora_alpha < rank: Scaling factor < 1. Reduces the LoRA's impact. Useful for subtle changes or to prevent overpowering the base model. (e.g., alpha=8, rank=16)
|
||||
lora_alpha > rank: Scaling factor > 1. Amplifies the LoRA's impact. Allows a lower rank LoRA to have a stronger effect. (e.g., alpha=32, rank=16)
|
||||
|
||||
> [!TIP]
|
||||
> A common starting point is to set `lora_alpha` equal to `rank`.
|
||||
> Some also set `lora_alpha` to be twice the `rank` (e.g., lora_alpha=32 for lora_rank=16)
|
||||
> to give the LoRA updates more influence without increasing parameter count.
|
||||
> If you find your LoRA is "overcooking" or learning too aggressively, consider setting `lora_alpha` to half of `rank`
|
||||
> (e.g., lora_alpha=8 for rank=16). Experimentation is often key to finding the optimal balance for your use case.
|
||||
|
||||
### Target Modules
|
||||
When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the Unet that relate the image representations with the prompts that describe them.
|
||||
More recently, SOTA text-to-image diffusion models replaced the Unet with a diffusion Transformer(DiT). With this change, we may also want to explore
|
||||
applying LoRA training onto different types of layers and blocks. To allow more flexibility and control over the targeted modules we added `--lora_layers`- in which you can specify in a comma separated string
|
||||
the exact modules for LoRA training. Here are some examples of target modules you can provide:
|
||||
- for attention only layers: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0"`
|
||||
- to train the same modules as in the fal trainer: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2"`
|
||||
- to train the same modules as in ostris ai-toolkit / replicate trainer: `--lora_blocks="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2,norm1_context.linear, norm1.linear,norm.linear,proj_mlp,proj_out"`
|
||||
> [!NOTE]
|
||||
> `--lora_layers` can also be used to specify which **blocks** to apply LoRA training to. To do so, simply add a block prefix to each layer in the comma separated string:
|
||||
> **single DiT blocks**: to target the ith single transformer block, add the prefix `single_transformer_blocks.i`, e.g. - `single_transformer_blocks.i.attn.to_k`
|
||||
> **MMDiT blocks**: to target the ith MMDiT block, add the prefix `transformer_blocks.i`, e.g. - `transformer_blocks.i.attn.to_k`
|
||||
> [!NOTE]
|
||||
> keep in mind that while training more layers can improve quality and expressiveness, it also increases the size of the output LoRA weights.
|
||||
|
||||
|
||||
|
||||
## Training Image-to-Image
|
||||
|
||||
Flux.2 lets us perform image editing as well as image generation. We provide a simple script for image-to-image(I2I) LoRA fine-tuning in [train_dreambooth_lora_flux2_img2img.py](./train_dreambooth_lora_flux2_img2img.py) for both T2I and I2I. The optimizations discussed above apply this script, too.
|
||||
|
||||
**important**
|
||||
|
||||
**Important**
|
||||
To make sure you can successfully run the latest version of the image-to-image example script, we highly recommend installing from source, specifically from the commit mentioned below. To do this, execute the following steps in a new virtual environment:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/diffusers
|
||||
cd diffusers
|
||||
pip install -e .
|
||||
|
||||
To start, you must have a dataset containing triplets:
|
||||
|
||||
* Condition image - the input image to be transformed.
|
||||
* Target image - the desired output image after transformation.
|
||||
* Instruction - a text prompt describing the transformation from the condition image to the target image.
|
||||
|
||||
[kontext-community/relighting](https://huggingface.co/datasets/kontext-community/relighting) is a good example of such a dataset. If you are using such a dataset, you can use the command below to launch training:
|
||||
|
||||
```bash
|
||||
accelerate launch train_dreambooth_lora_flux2_img2img.py \
|
||||
--pretrained_model_name_or_path=black-forest-labs/FLUX.2-dev \
|
||||
--output_dir="flux2-i2i" \
|
||||
--dataset_name="kontext-community/relighting" \
|
||||
--image_column="output" --cond_image_column="file_name" --caption_column="instruction" \
|
||||
--do_fp8_training \
|
||||
--gradient_checkpointing \
|
||||
--remote_text_encoder \
|
||||
--cache_latents \
|
||||
--resolution=1024 \
|
||||
--train_batch_size=1 \
|
||||
--guidance_scale=1 \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--gradient_checkpointing \
|
||||
--optimizer="adamw" \
|
||||
--use_8bit_adam \
|
||||
--cache_latents \
|
||||
--learning_rate=1e-4 \
|
||||
--lr_scheduler="constant_with_warmup" \
|
||||
--lr_warmup_steps=200 \
|
||||
--max_train_steps=1000 \
|
||||
--rank=16\
|
||||
--seed="0"
|
||||
```
|
||||
|
||||
More generally, when performing I2I fine-tuning, we expect you to:
|
||||
|
||||
* Have a dataset `kontext-community/relighting`
|
||||
* Supply `image_column`, `cond_image_column`, and `caption_column` values when launching training
|
||||
|
||||
### Misc notes
|
||||
|
||||
* By default, we use `mode` as the value of `--vae_encode_mode` argument. This is because Kontext uses `mode()` of the distribution predicted by the VAE instead of sampling from it.
|
||||
### Aspect Ratio Bucketing
|
||||
we've added aspect ratio bucketing support which allows training on images with different aspect ratios without cropping them to a single square resolution. This technique helps preserve the original composition of training images and can improve training efficiency.
|
||||
|
||||
To enable aspect ratio bucketing, pass `--aspect_ratio_buckets` argument with a semicolon-separated list of height,width pairs, such as:
|
||||
|
||||
`--aspect_ratio_buckets="672,1568;688,1504;720,1456;752,1392;800,1328;832,1248;880,1184;944,1104;1024,1024;1104,944;1184,880;1248,832;1328,800;1392,752;1456,720;1504,688;1568,672"
|
||||
`
|
||||
Since Flux.2 finetuning is still an experimental phase, we encourage you to explore different settings and share your insights! 🤗
|
||||
262
examples/dreambooth/test_dreambooth_lora_flux2.py
Normal file
262
examples/dreambooth/test_dreambooth_lora_flux2.py
Normal file
@@ -0,0 +1,262 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
import safetensors
|
||||
|
||||
from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
|
||||
|
||||
|
||||
sys.path.append("..")
|
||||
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
logger = logging.getLogger()
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
|
||||
class DreamBoothLoRAFlux2(ExamplesTestsAccelerate):
|
||||
instance_data_dir = "docs/source/en/imgs"
|
||||
instance_prompt = "dog"
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2"
|
||||
script_path = "examples/dreambooth/train_dreambooth_lora_flux2.py"
|
||||
transformer_layer_type = "single_transformer_blocks.0.attn.to_qkv_mlp_proj"
|
||||
|
||||
def test_dreambooth_lora_flux2(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--max_sequence_length 8
|
||||
--text_encoder_out_layers 1
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# when not training the text encoder, all the parameters in the state dict should start
|
||||
# with `"transformer"` in their names.
|
||||
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
|
||||
self.assertTrue(starts_with_transformer)
|
||||
|
||||
def test_dreambooth_lora_latent_caching(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--cache_latents
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--max_sequence_length 8
|
||||
--text_encoder_out_layers 1
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# when not training the text encoder, all the parameters in the state dict should start
|
||||
# with `"transformer"` in their names.
|
||||
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
|
||||
self.assertTrue(starts_with_transformer)
|
||||
|
||||
def test_dreambooth_lora_layers(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--cache_latents
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lora_layers {self.transformer_layer_type}
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--max_sequence_length 8
|
||||
--text_encoder_out_layers 1
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# when not training the text encoder, all the parameters in the state dict should start
|
||||
# with `"transformer"` in their names. In this test, we only params of
|
||||
# transformer.single_transformer_blocks.0.attn.to_k should be in the state dict
|
||||
starts_with_transformer = all(
|
||||
key.startswith(f"transformer.{self.transformer_layer_type}") for key in lora_state_dict.keys()
|
||||
)
|
||||
self.assertTrue(starts_with_transformer)
|
||||
|
||||
def test_dreambooth_lora_flux2_checkpointing_checkpoints_total_limit(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--instance_prompt={self.instance_prompt}
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=6
|
||||
--checkpoints_total_limit=2
|
||||
--max_sequence_length 8
|
||||
--checkpointing_steps=2
|
||||
--text_encoder_out_layers 1
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-4", "checkpoint-6"},
|
||||
)
|
||||
|
||||
def test_dreambooth_lora_flux2_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--instance_prompt={self.instance_prompt}
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=4
|
||||
--checkpointing_steps=2
|
||||
--max_sequence_length 8
|
||||
--text_encoder_out_layers 1
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})
|
||||
|
||||
resume_run_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--instance_prompt={self.instance_prompt}
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=8
|
||||
--checkpointing_steps=2
|
||||
--resume_from_checkpoint=checkpoint-4
|
||||
--checkpoints_total_limit=2
|
||||
--max_sequence_length 8
|
||||
--text_encoder_out_layers 1
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + resume_run_args)
|
||||
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
|
||||
|
||||
def test_dreambooth_lora_with_metadata(self):
|
||||
# Use a `lora_alpha` that is different from `rank`.
|
||||
lora_alpha = 8
|
||||
rank = 4
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--lora_alpha={lora_alpha}
|
||||
--rank={rank}
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--max_sequence_length 8
|
||||
--text_encoder_out_layers 1
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
|
||||
self.assertTrue(os.path.isfile(state_dict_file))
|
||||
|
||||
# Check if the metadata was properly serialized.
|
||||
with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f:
|
||||
metadata = f.metadata() or {}
|
||||
|
||||
metadata.pop("format", None)
|
||||
raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
|
||||
if raw:
|
||||
raw = json.loads(raw)
|
||||
|
||||
loaded_lora_alpha = raw["transformer.lora_alpha"]
|
||||
self.assertTrue(loaded_lora_alpha == lora_alpha)
|
||||
loaded_lora_rank = raw["transformer.r"]
|
||||
self.assertTrue(loaded_lora_rank == rank)
|
||||
1914
examples/dreambooth/train_dreambooth_lora_flux2.py
Normal file
1914
examples/dreambooth/train_dreambooth_lora_flux2.py
Normal file
File diff suppressed because it is too large
Load Diff
1831
examples/dreambooth/train_dreambooth_lora_flux2_img2img.py
Normal file
1831
examples/dreambooth/train_dreambooth_lora_flux2_img2img.py
Normal file
File diff suppressed because it is too large
Load Diff
475
scripts/convert_flux2_to_diffusers.py
Normal file
475
scripts/convert_flux2_to_diffusers.py
Normal file
@@ -0,0 +1,475 @@
|
||||
import argparse
|
||||
from contextlib import nullcontext
|
||||
from typing import Any, Dict, Tuple
|
||||
|
||||
import safetensors.torch
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import AutoProcessor, GenerationConfig, Mistral3ForConditionalGeneration
|
||||
|
||||
from diffusers import AutoencoderKLFlux2, FlowMatchEulerDiscreteScheduler, Flux2Pipeline, Flux2Transformer2DModel
|
||||
from diffusers.utils.import_utils import is_accelerate_available
|
||||
|
||||
|
||||
"""
|
||||
# VAE
|
||||
|
||||
python scripts/convert_flux2_to_diffusers.py \
|
||||
--original_state_dict_repo_id "diffusers-internal-dev/new-model-image" \
|
||||
--vae_filename "flux2-vae.sft" \
|
||||
--output_path "/raid/yiyi/dummy-flux2-diffusers" \
|
||||
--vae
|
||||
|
||||
# DiT
|
||||
|
||||
python scripts/convert_flux2_to_diffusers.py \
|
||||
--original_state_dict_repo_id diffusers-internal-dev/new-model-image \
|
||||
--dit_filename flux-dev-dummy.sft \
|
||||
--dit \
|
||||
--output_path .
|
||||
|
||||
# Full pipe
|
||||
|
||||
python scripts/convert_flux2_to_diffusers.py \
|
||||
--original_state_dict_repo_id diffusers-internal-dev/new-model-image \
|
||||
--dit_filename flux-dev-dummy.sft \
|
||||
--vae_filename "flux2-vae.sft" \
|
||||
--dit --vae --full_pipe \
|
||||
--output_path .
|
||||
"""
|
||||
|
||||
CTX = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
|
||||
parser.add_argument("--vae_filename", default="flux2-vae.sft", type=str)
|
||||
parser.add_argument("--dit_filename", default="flux-dev-dummy.sft", type=str)
|
||||
parser.add_argument("--vae", action="store_true")
|
||||
parser.add_argument("--dit", action="store_true")
|
||||
parser.add_argument("--vae_dtype", type=str, default="fp32")
|
||||
parser.add_argument("--dit_dtype", type=str, default="bf16")
|
||||
parser.add_argument("--checkpoint_path", default=None, type=str)
|
||||
parser.add_argument("--full_pipe", action="store_true")
|
||||
parser.add_argument("--output_path", type=str)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
def load_original_checkpoint(args, filename):
|
||||
if args.original_state_dict_repo_id is not None:
|
||||
ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=filename)
|
||||
elif args.checkpoint_path is not None:
|
||||
ckpt_path = args.checkpoint_path
|
||||
else:
|
||||
raise ValueError(" please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")
|
||||
|
||||
original_state_dict = safetensors.torch.load_file(ckpt_path)
|
||||
return original_state_dict
|
||||
|
||||
|
||||
DIFFUSERS_VAE_TO_FLUX2_MAPPING = {
|
||||
"encoder.conv_in.weight": "encoder.conv_in.weight",
|
||||
"encoder.conv_in.bias": "encoder.conv_in.bias",
|
||||
"encoder.conv_out.weight": "encoder.conv_out.weight",
|
||||
"encoder.conv_out.bias": "encoder.conv_out.bias",
|
||||
"encoder.conv_norm_out.weight": "encoder.norm_out.weight",
|
||||
"encoder.conv_norm_out.bias": "encoder.norm_out.bias",
|
||||
"decoder.conv_in.weight": "decoder.conv_in.weight",
|
||||
"decoder.conv_in.bias": "decoder.conv_in.bias",
|
||||
"decoder.conv_out.weight": "decoder.conv_out.weight",
|
||||
"decoder.conv_out.bias": "decoder.conv_out.bias",
|
||||
"decoder.conv_norm_out.weight": "decoder.norm_out.weight",
|
||||
"decoder.conv_norm_out.bias": "decoder.norm_out.bias",
|
||||
"quant_conv.weight": "encoder.quant_conv.weight",
|
||||
"quant_conv.bias": "encoder.quant_conv.bias",
|
||||
"post_quant_conv.weight": "decoder.post_quant_conv.weight",
|
||||
"post_quant_conv.bias": "decoder.post_quant_conv.bias",
|
||||
"bn.running_mean": "bn.running_mean",
|
||||
"bn.running_var": "bn.running_var",
|
||||
}
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear
|
||||
def conv_attn_to_linear(checkpoint):
|
||||
keys = list(checkpoint.keys())
|
||||
attn_keys = ["query.weight", "key.weight", "value.weight"]
|
||||
for key in keys:
|
||||
if ".".join(key.split(".")[-2:]) in attn_keys:
|
||||
if checkpoint[key].ndim > 2:
|
||||
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
||||
elif "proj_attn.weight" in key:
|
||||
if checkpoint[key].ndim > 2:
|
||||
checkpoint[key] = checkpoint[key][:, :, 0]
|
||||
|
||||
|
||||
def update_vae_resnet_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping):
|
||||
for ldm_key in keys:
|
||||
diffusers_key = ldm_key.replace(mapping["old"], mapping["new"]).replace("nin_shortcut", "conv_shortcut")
|
||||
new_checkpoint[diffusers_key] = checkpoint.get(ldm_key)
|
||||
|
||||
|
||||
def update_vae_attentions_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping):
|
||||
for ldm_key in keys:
|
||||
diffusers_key = (
|
||||
ldm_key.replace(mapping["old"], mapping["new"])
|
||||
.replace("norm.weight", "group_norm.weight")
|
||||
.replace("norm.bias", "group_norm.bias")
|
||||
.replace("q.weight", "to_q.weight")
|
||||
.replace("q.bias", "to_q.bias")
|
||||
.replace("k.weight", "to_k.weight")
|
||||
.replace("k.bias", "to_k.bias")
|
||||
.replace("v.weight", "to_v.weight")
|
||||
.replace("v.bias", "to_v.bias")
|
||||
.replace("proj_out.weight", "to_out.0.weight")
|
||||
.replace("proj_out.bias", "to_out.0.bias")
|
||||
)
|
||||
new_checkpoint[diffusers_key] = checkpoint.get(ldm_key)
|
||||
|
||||
# proj_attn.weight has to be converted from conv 1D to linear
|
||||
shape = new_checkpoint[diffusers_key].shape
|
||||
|
||||
if len(shape) == 3:
|
||||
new_checkpoint[diffusers_key] = new_checkpoint[diffusers_key][:, :, 0]
|
||||
elif len(shape) == 4:
|
||||
new_checkpoint[diffusers_key] = new_checkpoint[diffusers_key][:, :, 0, 0]
|
||||
|
||||
|
||||
def convert_flux2_vae_checkpoint_to_diffusers(vae_state_dict, config):
|
||||
new_checkpoint = {}
|
||||
for diffusers_key, ldm_key in DIFFUSERS_VAE_TO_FLUX2_MAPPING.items():
|
||||
if ldm_key not in vae_state_dict:
|
||||
continue
|
||||
new_checkpoint[diffusers_key] = vae_state_dict[ldm_key]
|
||||
|
||||
# Retrieves the keys for the encoder down blocks only
|
||||
num_down_blocks = len(config["down_block_types"])
|
||||
down_blocks = {
|
||||
layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
|
||||
}
|
||||
|
||||
for i in range(num_down_blocks):
|
||||
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
|
||||
update_vae_resnet_ldm_to_diffusers(
|
||||
resnets,
|
||||
new_checkpoint,
|
||||
vae_state_dict,
|
||||
mapping={"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"},
|
||||
)
|
||||
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
|
||||
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.get(
|
||||
f"encoder.down.{i}.downsample.conv.weight"
|
||||
)
|
||||
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.get(
|
||||
f"encoder.down.{i}.downsample.conv.bias"
|
||||
)
|
||||
|
||||
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
|
||||
num_mid_res_blocks = 2
|
||||
for i in range(1, num_mid_res_blocks + 1):
|
||||
resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
|
||||
update_vae_resnet_ldm_to_diffusers(
|
||||
resnets,
|
||||
new_checkpoint,
|
||||
vae_state_dict,
|
||||
mapping={"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"},
|
||||
)
|
||||
|
||||
mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
|
||||
update_vae_attentions_ldm_to_diffusers(
|
||||
mid_attentions, new_checkpoint, vae_state_dict, mapping={"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
||||
)
|
||||
|
||||
# Retrieves the keys for the decoder up blocks only
|
||||
num_up_blocks = len(config["up_block_types"])
|
||||
up_blocks = {
|
||||
layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
|
||||
}
|
||||
|
||||
for i in range(num_up_blocks):
|
||||
block_id = num_up_blocks - 1 - i
|
||||
resnets = [
|
||||
key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
|
||||
]
|
||||
update_vae_resnet_ldm_to_diffusers(
|
||||
resnets,
|
||||
new_checkpoint,
|
||||
vae_state_dict,
|
||||
mapping={"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"},
|
||||
)
|
||||
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
|
||||
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
|
||||
f"decoder.up.{block_id}.upsample.conv.weight"
|
||||
]
|
||||
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
|
||||
f"decoder.up.{block_id}.upsample.conv.bias"
|
||||
]
|
||||
|
||||
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
|
||||
num_mid_res_blocks = 2
|
||||
for i in range(1, num_mid_res_blocks + 1):
|
||||
resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
|
||||
update_vae_resnet_ldm_to_diffusers(
|
||||
resnets,
|
||||
new_checkpoint,
|
||||
vae_state_dict,
|
||||
mapping={"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"},
|
||||
)
|
||||
|
||||
mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
|
||||
update_vae_attentions_ldm_to_diffusers(
|
||||
mid_attentions, new_checkpoint, vae_state_dict, mapping={"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
||||
)
|
||||
conv_attn_to_linear(new_checkpoint)
|
||||
|
||||
return new_checkpoint
|
||||
|
||||
|
||||
FLUX2_TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
# Image and text input projections
|
||||
"img_in": "x_embedder",
|
||||
"txt_in": "context_embedder",
|
||||
# Timestep and guidance embeddings
|
||||
"time_in.in_layer": "time_guidance_embed.timestep_embedder.linear_1",
|
||||
"time_in.out_layer": "time_guidance_embed.timestep_embedder.linear_2",
|
||||
"guidance_in.in_layer": "time_guidance_embed.guidance_embedder.linear_1",
|
||||
"guidance_in.out_layer": "time_guidance_embed.guidance_embedder.linear_2",
|
||||
# Modulation parameters
|
||||
"double_stream_modulation_img.lin": "double_stream_modulation_img.linear",
|
||||
"double_stream_modulation_txt.lin": "double_stream_modulation_txt.linear",
|
||||
"single_stream_modulation.lin": "single_stream_modulation.linear",
|
||||
# Final output layer
|
||||
# "final_layer.adaLN_modulation.1": "norm_out.linear", # Handle separately since we need to swap mod params
|
||||
"final_layer.linear": "proj_out",
|
||||
}
|
||||
|
||||
|
||||
FLUX2_TRANSFORMER_ADA_LAYER_NORM_KEY_MAP = {
|
||||
"final_layer.adaLN_modulation.1": "norm_out.linear",
|
||||
}
|
||||
|
||||
|
||||
FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP = {
|
||||
# Handle fused QKV projections separately as we need to break into Q, K, V projections
|
||||
"img_attn.norm.query_norm": "attn.norm_q",
|
||||
"img_attn.norm.key_norm": "attn.norm_k",
|
||||
"img_attn.proj": "attn.to_out.0",
|
||||
"img_mlp.0": "ff.linear_in",
|
||||
"img_mlp.2": "ff.linear_out",
|
||||
"txt_attn.norm.query_norm": "attn.norm_added_q",
|
||||
"txt_attn.norm.key_norm": "attn.norm_added_k",
|
||||
"txt_attn.proj": "attn.to_add_out",
|
||||
"txt_mlp.0": "ff_context.linear_in",
|
||||
"txt_mlp.2": "ff_context.linear_out",
|
||||
}
|
||||
|
||||
|
||||
FLUX2_TRANSFORMER_SINGLE_BLOCK_KEY_MAP = {
|
||||
"linear1": "attn.to_qkv_mlp_proj",
|
||||
"linear2": "attn.to_out",
|
||||
"norm.query_norm": "attn.norm_q",
|
||||
"norm.key_norm": "attn.norm_k",
|
||||
}
|
||||
|
||||
|
||||
# in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale;
|
||||
# while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use
|
||||
# diffusers implementation
|
||||
def swap_scale_shift(weight):
|
||||
shift, scale = weight.chunk(2, dim=0)
|
||||
new_weight = torch.cat([scale, shift], dim=0)
|
||||
return new_weight
|
||||
|
||||
|
||||
def convert_ada_layer_norm_weights(key: str, state_dict: Dict[str, Any]) -> None:
|
||||
# Skip if not a weight
|
||||
if ".weight" not in key:
|
||||
return
|
||||
|
||||
# If adaLN_modulation is in the key, swap scale and shift parameters
|
||||
# Original implementation is (shift, scale); diffusers implementation is (scale, shift)
|
||||
if "adaLN_modulation" in key:
|
||||
key_without_param_type, param_type = key.rsplit(".", maxsplit=1)
|
||||
# Assume all such keys are in the AdaLayerNorm key map
|
||||
new_key_without_param_type = FLUX2_TRANSFORMER_ADA_LAYER_NORM_KEY_MAP[key_without_param_type]
|
||||
new_key = ".".join([new_key_without_param_type, param_type])
|
||||
|
||||
swapped_weight = swap_scale_shift(state_dict.pop(key))
|
||||
state_dict[new_key] = swapped_weight
|
||||
return
|
||||
|
||||
|
||||
def convert_flux2_double_stream_blocks(key: str, state_dict: Dict[str, Any]) -> None:
|
||||
# Skip if not a weight, bias, or scale
|
||||
if ".weight" not in key and ".bias" not in key and ".scale" not in key:
|
||||
return
|
||||
|
||||
new_prefix = "transformer_blocks"
|
||||
if "double_blocks." in key:
|
||||
parts = key.split(".")
|
||||
block_idx = parts[1]
|
||||
modality_block_name = parts[2] # img_attn, img_mlp, txt_attn, txt_mlp
|
||||
within_block_name = ".".join(parts[2:-1])
|
||||
param_type = parts[-1]
|
||||
|
||||
if param_type == "scale":
|
||||
param_type = "weight"
|
||||
|
||||
if "qkv" in within_block_name:
|
||||
fused_qkv_weight = state_dict.pop(key)
|
||||
to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0)
|
||||
if "img" in modality_block_name:
|
||||
# double_blocks.{N}.img_attn.qkv --> transformer_blocks.{N}.attn.{to_q|to_k|to_v}
|
||||
to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0)
|
||||
new_q_name = "attn.to_q"
|
||||
new_k_name = "attn.to_k"
|
||||
new_v_name = "attn.to_v"
|
||||
elif "txt" in modality_block_name:
|
||||
# double_blocks.{N}.txt_attn.qkv --> transformer_blocks.{N}.attn.{add_q_proj|add_k_proj|add_v_proj}
|
||||
to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0)
|
||||
new_q_name = "attn.add_q_proj"
|
||||
new_k_name = "attn.add_k_proj"
|
||||
new_v_name = "attn.add_v_proj"
|
||||
new_q_key = ".".join([new_prefix, block_idx, new_q_name, param_type])
|
||||
new_k_key = ".".join([new_prefix, block_idx, new_k_name, param_type])
|
||||
new_v_key = ".".join([new_prefix, block_idx, new_v_name, param_type])
|
||||
state_dict[new_q_key] = to_q_weight
|
||||
state_dict[new_k_key] = to_k_weight
|
||||
state_dict[new_v_key] = to_v_weight
|
||||
else:
|
||||
new_within_block_name = FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP[within_block_name]
|
||||
new_key = ".".join([new_prefix, block_idx, new_within_block_name, param_type])
|
||||
|
||||
param = state_dict.pop(key)
|
||||
state_dict[new_key] = param
|
||||
return
|
||||
|
||||
|
||||
def convert_flux2_single_stream_blocks(key: str, state_dict: Dict[str, Any]) -> None:
|
||||
# Skip if not a weight, bias, or scale
|
||||
if ".weight" not in key and ".bias" not in key and ".scale" not in key:
|
||||
return
|
||||
|
||||
# Mapping:
|
||||
# - single_blocks.{N}.linear1 --> single_transformer_blocks.{N}.attn.to_qkv_mlp_proj
|
||||
# - single_blocks.{N}.linear2 --> single_transformer_blocks.{N}.attn.to_out
|
||||
# - single_blocks.{N}.norm.query_norm.scale --> single_transformer_blocks.{N}.attn.norm_q.weight
|
||||
# - single_blocks.{N}.norm.key_norm.scale --> single_transformer_blocks.{N}.attn.norm_k.weight
|
||||
new_prefix = "single_transformer_blocks"
|
||||
if "single_blocks." in key:
|
||||
parts = key.split(".")
|
||||
block_idx = parts[1]
|
||||
within_block_name = ".".join(parts[2:-1])
|
||||
param_type = parts[-1]
|
||||
|
||||
if param_type == "scale":
|
||||
param_type = "weight"
|
||||
|
||||
new_within_block_name = FLUX2_TRANSFORMER_SINGLE_BLOCK_KEY_MAP[within_block_name]
|
||||
new_key = ".".join([new_prefix, block_idx, new_within_block_name, param_type])
|
||||
|
||||
param = state_dict.pop(key)
|
||||
state_dict[new_key] = param
|
||||
return
|
||||
|
||||
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP = {
|
||||
"adaLN_modulation": convert_ada_layer_norm_weights,
|
||||
"double_blocks": convert_flux2_double_stream_blocks,
|
||||
"single_blocks": convert_flux2_single_stream_blocks,
|
||||
}
|
||||
|
||||
|
||||
def update_state_dict(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
|
||||
state_dict[new_key] = state_dict.pop(old_key)
|
||||
|
||||
|
||||
def get_flux2_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
|
||||
if model_type == "test" or model_type == "dummy-flux2":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/dummy-flux2",
|
||||
"diffusers_config": {
|
||||
"patch_size": 1,
|
||||
"in_channels": 128,
|
||||
"num_layers": 8,
|
||||
"num_single_layers": 48,
|
||||
"attention_head_dim": 128,
|
||||
"num_attention_heads": 48,
|
||||
"joint_attention_dim": 15360,
|
||||
"timestep_guidance_channels": 256,
|
||||
"mlp_ratio": 3.0,
|
||||
"axes_dims_rope": (32, 32, 32, 32),
|
||||
"rope_theta": 2000,
|
||||
"eps": 1e-6,
|
||||
},
|
||||
}
|
||||
rename_dict = FLUX2_TRANSFORMER_KEYS_RENAME_DICT
|
||||
special_keys_remap = TRANSFORMER_SPECIAL_KEYS_REMAP
|
||||
return config, rename_dict, special_keys_remap
|
||||
|
||||
|
||||
def convert_flux2_transformer_to_diffusers(original_state_dict: Dict[str, torch.Tensor], model_type: str):
|
||||
config, rename_dict, special_keys_remap = get_flux2_transformer_config(model_type)
|
||||
|
||||
diffusers_config = config["diffusers_config"]
|
||||
|
||||
with init_empty_weights():
|
||||
transformer = Flux2Transformer2DModel.from_config(diffusers_config)
|
||||
|
||||
# Handle official code --> diffusers key remapping via the remap dict
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in rename_dict.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict(original_state_dict, key, new_key)
|
||||
|
||||
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
|
||||
# special_keys_remap
|
||||
for key in list(original_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in special_keys_remap.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, original_state_dict)
|
||||
|
||||
transformer.load_state_dict(original_state_dict, strict=True, assign=True)
|
||||
return transformer
|
||||
|
||||
|
||||
def main(args):
|
||||
if args.vae:
|
||||
original_vae_ckpt = load_original_checkpoint(args, filename=args.vae_filename)
|
||||
vae = AutoencoderKLFlux2()
|
||||
converted_vae_state_dict = convert_flux2_vae_checkpoint_to_diffusers(original_vae_ckpt, vae.config)
|
||||
vae.load_state_dict(converted_vae_state_dict, strict=True)
|
||||
if not args.full_pipe:
|
||||
vae_dtype = torch.bfloat16 if args.vae_dtype == "bf16" else torch.float32
|
||||
vae.to(vae_dtype).save_pretrained(f"{args.output_path}/vae")
|
||||
|
||||
if args.dit:
|
||||
original_dit_ckpt = load_original_checkpoint(args, filename=args.dit_filename)
|
||||
transformer = convert_flux2_transformer_to_diffusers(original_dit_ckpt, "test")
|
||||
if not args.full_pipe:
|
||||
dit_dtype = torch.bfloat16 if args.dit_dtype == "bf16" else torch.float32
|
||||
transformer.to(dit_dtype).save_pretrained(f"{args.output_path}/transformer")
|
||||
|
||||
if args.full_pipe:
|
||||
tokenizer_id = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
|
||||
text_encoder_id = "mistralai/Mistral-Small-3.2-24B-Instruct-2506"
|
||||
generate_config = GenerationConfig.from_pretrained(text_encoder_id)
|
||||
generate_config.do_sample = True
|
||||
text_encoder = Mistral3ForConditionalGeneration.from_pretrained(
|
||||
text_encoder_id, generation_config=generate_config, torch_dtype=torch.bfloat16
|
||||
)
|
||||
tokenizer = AutoProcessor.from_pretrained(tokenizer_id)
|
||||
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev", subfolder="scheduler"
|
||||
)
|
||||
|
||||
pipe = Flux2Pipeline(
|
||||
vae=vae, transformer=transformer, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler
|
||||
)
|
||||
pipe.save_pretrained(args.output_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(args)
|
||||
@@ -186,6 +186,7 @@ else:
|
||||
"AutoencoderKLAllegro",
|
||||
"AutoencoderKLCogVideoX",
|
||||
"AutoencoderKLCosmos",
|
||||
"AutoencoderKLFlux2",
|
||||
"AutoencoderKLHunyuanImage",
|
||||
"AutoencoderKLHunyuanImageRefiner",
|
||||
"AutoencoderKLHunyuanVideo",
|
||||
@@ -215,6 +216,7 @@ else:
|
||||
"CosmosTransformer3DModel",
|
||||
"DiTTransformer2DModel",
|
||||
"EasyAnimateTransformer3DModel",
|
||||
"Flux2Transformer2DModel",
|
||||
"FluxControlNetModel",
|
||||
"FluxMultiControlNetModel",
|
||||
"FluxTransformer2DModel",
|
||||
@@ -271,6 +273,7 @@ else:
|
||||
"WanAnimateTransformer3DModel",
|
||||
"WanTransformer3DModel",
|
||||
"WanVACETransformer3DModel",
|
||||
"ZImageTransformer2DModel",
|
||||
"attention_backend",
|
||||
]
|
||||
)
|
||||
@@ -457,6 +460,7 @@ else:
|
||||
"EasyAnimateControlPipeline",
|
||||
"EasyAnimateInpaintPipeline",
|
||||
"EasyAnimatePipeline",
|
||||
"Flux2Pipeline",
|
||||
"FluxControlImg2ImgPipeline",
|
||||
"FluxControlInpaintPipeline",
|
||||
"FluxControlNetImg2ImgPipeline",
|
||||
@@ -647,6 +651,7 @@ else:
|
||||
"WuerstchenCombinedPipeline",
|
||||
"WuerstchenDecoderPipeline",
|
||||
"WuerstchenPriorPipeline",
|
||||
"ZImagePipeline",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -900,6 +905,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderKLAllegro,
|
||||
AutoencoderKLCogVideoX,
|
||||
AutoencoderKLCosmos,
|
||||
AutoencoderKLFlux2,
|
||||
AutoencoderKLHunyuanImage,
|
||||
AutoencoderKLHunyuanImageRefiner,
|
||||
AutoencoderKLHunyuanVideo,
|
||||
@@ -929,6 +935,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
CosmosTransformer3DModel,
|
||||
DiTTransformer2DModel,
|
||||
EasyAnimateTransformer3DModel,
|
||||
Flux2Transformer2DModel,
|
||||
FluxControlNetModel,
|
||||
FluxMultiControlNetModel,
|
||||
FluxTransformer2DModel,
|
||||
@@ -1141,6 +1148,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
EasyAnimateControlPipeline,
|
||||
EasyAnimateInpaintPipeline,
|
||||
EasyAnimatePipeline,
|
||||
Flux2Pipeline,
|
||||
FluxControlImg2ImgPipeline,
|
||||
FluxControlInpaintPipeline,
|
||||
FluxControlNetImg2ImgPipeline,
|
||||
@@ -1329,6 +1337,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
WuerstchenCombinedPipeline,
|
||||
WuerstchenDecoderPipeline,
|
||||
WuerstchenPriorPipeline,
|
||||
ZImagePipeline,
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@@ -111,6 +111,7 @@ def _register_attention_processors_metadata():
|
||||
from ..models.transformers.transformer_hunyuanimage import HunyuanImageAttnProcessor
|
||||
from ..models.transformers.transformer_qwenimage import QwenDoubleStreamAttnProcessor2_0
|
||||
from ..models.transformers.transformer_wan import WanAttnProcessor2_0
|
||||
from ..models.transformers.transformer_z_image import ZSingleStreamAttnProcessor
|
||||
|
||||
# AttnProcessor2_0
|
||||
AttentionProcessorRegistry.register(
|
||||
@@ -158,6 +159,14 @@ def _register_attention_processors_metadata():
|
||||
),
|
||||
)
|
||||
|
||||
# ZSingleStreamAttnProcessor
|
||||
AttentionProcessorRegistry.register(
|
||||
model_class=ZSingleStreamAttnProcessor,
|
||||
metadata=AttentionProcessorMetadata(
|
||||
skip_processor_output_fn=_skip_proc_output_fn_Attention_ZSingleStreamAttnProcessor,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _register_transformer_blocks_metadata():
|
||||
from ..models.attention import BasicTransformerBlock
|
||||
@@ -179,6 +188,7 @@ def _register_transformer_blocks_metadata():
|
||||
from ..models.transformers.transformer_mochi import MochiTransformerBlock
|
||||
from ..models.transformers.transformer_qwenimage import QwenImageTransformerBlock
|
||||
from ..models.transformers.transformer_wan import WanTransformerBlock
|
||||
from ..models.transformers.transformer_z_image import ZImageTransformerBlock
|
||||
|
||||
# BasicTransformerBlock
|
||||
TransformerBlockRegistry.register(
|
||||
@@ -312,6 +322,15 @@ def _register_transformer_blocks_metadata():
|
||||
),
|
||||
)
|
||||
|
||||
# ZImage
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=ZImageTransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=None,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# fmt: off
|
||||
def _skip_attention___ret___hidden_states(self, *args, **kwargs):
|
||||
@@ -338,4 +357,5 @@ _skip_proc_output_fn_Attention_WanAttnProcessor2_0 = _skip_attention___ret___hid
|
||||
_skip_proc_output_fn_Attention_FluxAttnProcessor = _skip_attention___ret___hidden_states
|
||||
_skip_proc_output_fn_Attention_QwenDoubleStreamAttnProcessor2_0 = _skip_attention___ret___hidden_states
|
||||
_skip_proc_output_fn_Attention_HunyuanImageAttnProcessor = _skip_attention___ret___hidden_states
|
||||
_skip_proc_output_fn_Attention_ZSingleStreamAttnProcessor = _skip_attention___ret___hidden_states
|
||||
# fmt: on
|
||||
|
||||
@@ -81,6 +81,7 @@ if is_torch_available():
|
||||
"HiDreamImageLoraLoaderMixin",
|
||||
"SkyReelsV2LoraLoaderMixin",
|
||||
"QwenImageLoraLoaderMixin",
|
||||
"Flux2LoraLoaderMixin",
|
||||
]
|
||||
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
|
||||
_import_structure["ip_adapter"] = [
|
||||
@@ -113,6 +114,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AuraFlowLoraLoaderMixin,
|
||||
CogVideoXLoraLoaderMixin,
|
||||
CogView4LoraLoaderMixin,
|
||||
Flux2LoraLoaderMixin,
|
||||
FluxLoraLoaderMixin,
|
||||
HiDreamImageLoraLoaderMixin,
|
||||
HunyuanVideoLoraLoaderMixin,
|
||||
|
||||
@@ -2265,3 +2265,89 @@ def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict):
|
||||
|
||||
converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
|
||||
converted_state_dict = {}
|
||||
|
||||
prefix = "diffusion_model."
|
||||
original_state_dict = {k[len(prefix) :]: v for k, v in state_dict.items()}
|
||||
|
||||
num_double_layers = 8
|
||||
num_single_layers = 48
|
||||
lora_keys = ("lora_A", "lora_B")
|
||||
attn_types = ("img_attn", "txt_attn")
|
||||
|
||||
for sl in range(num_single_layers):
|
||||
single_block_prefix = f"single_blocks.{sl}"
|
||||
attn_prefix = f"single_transformer_blocks.{sl}.attn"
|
||||
|
||||
for lora_key in lora_keys:
|
||||
converted_state_dict[f"{attn_prefix}.to_qkv_mlp_proj.{lora_key}.weight"] = original_state_dict.pop(
|
||||
f"{single_block_prefix}.linear1.{lora_key}.weight"
|
||||
)
|
||||
|
||||
converted_state_dict[f"{attn_prefix}.to_out.{lora_key}.weight"] = original_state_dict.pop(
|
||||
f"{single_block_prefix}.linear2.{lora_key}.weight"
|
||||
)
|
||||
|
||||
for dl in range(num_double_layers):
|
||||
transformer_block_prefix = f"transformer_blocks.{dl}"
|
||||
|
||||
for lora_key in lora_keys:
|
||||
for attn_type in attn_types:
|
||||
attn_prefix = f"{transformer_block_prefix}.attn"
|
||||
qkv_key = f"double_blocks.{dl}.{attn_type}.qkv.{lora_key}.weight"
|
||||
fused_qkv_weight = original_state_dict.pop(qkv_key)
|
||||
|
||||
if lora_key == "lora_A":
|
||||
diff_attn_proj_keys = (
|
||||
["to_q", "to_k", "to_v"]
|
||||
if attn_type == "img_attn"
|
||||
else ["add_q_proj", "add_k_proj", "add_v_proj"]
|
||||
)
|
||||
for proj_key in diff_attn_proj_keys:
|
||||
converted_state_dict[f"{attn_prefix}.{proj_key}.{lora_key}.weight"] = torch.cat(
|
||||
[fused_qkv_weight]
|
||||
)
|
||||
else:
|
||||
sample_q, sample_k, sample_v = torch.chunk(fused_qkv_weight, 3, dim=0)
|
||||
|
||||
if attn_type == "img_attn":
|
||||
converted_state_dict[f"{attn_prefix}.to_q.{lora_key}.weight"] = torch.cat([sample_q])
|
||||
converted_state_dict[f"{attn_prefix}.to_k.{lora_key}.weight"] = torch.cat([sample_k])
|
||||
converted_state_dict[f"{attn_prefix}.to_v.{lora_key}.weight"] = torch.cat([sample_v])
|
||||
else:
|
||||
converted_state_dict[f"{attn_prefix}.add_q_proj.{lora_key}.weight"] = torch.cat([sample_q])
|
||||
converted_state_dict[f"{attn_prefix}.add_k_proj.{lora_key}.weight"] = torch.cat([sample_k])
|
||||
converted_state_dict[f"{attn_prefix}.add_v_proj.{lora_key}.weight"] = torch.cat([sample_v])
|
||||
|
||||
proj_mappings = [
|
||||
("img_attn.proj", "attn.to_out.0"),
|
||||
("txt_attn.proj", "attn.to_add_out"),
|
||||
]
|
||||
for org_proj, diff_proj in proj_mappings:
|
||||
for lora_key in lora_keys:
|
||||
original_key = f"double_blocks.{dl}.{org_proj}.{lora_key}.weight"
|
||||
diffusers_key = f"{transformer_block_prefix}.{diff_proj}.{lora_key}.weight"
|
||||
converted_state_dict[diffusers_key] = original_state_dict.pop(original_key)
|
||||
|
||||
mlp_mappings = [
|
||||
("img_mlp.0", "ff.linear_in"),
|
||||
("img_mlp.2", "ff.linear_out"),
|
||||
("txt_mlp.0", "ff_context.linear_in"),
|
||||
("txt_mlp.2", "ff_context.linear_out"),
|
||||
]
|
||||
for org_mlp, diff_mlp in mlp_mappings:
|
||||
for lora_key in lora_keys:
|
||||
original_key = f"double_blocks.{dl}.{org_mlp}.{lora_key}.weight"
|
||||
diffusers_key = f"{transformer_block_prefix}.{diff_mlp}.{lora_key}.weight"
|
||||
converted_state_dict[diffusers_key] = original_state_dict.pop(original_key)
|
||||
|
||||
if len(original_state_dict) > 0:
|
||||
raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.")
|
||||
|
||||
for key in list(converted_state_dict.keys()):
|
||||
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
@@ -45,6 +45,7 @@ from .lora_conversion_utils import (
|
||||
_convert_hunyuan_video_lora_to_diffusers,
|
||||
_convert_kohya_flux_lora_to_diffusers,
|
||||
_convert_musubi_wan_lora_to_diffusers,
|
||||
_convert_non_diffusers_flux2_lora_to_diffusers,
|
||||
_convert_non_diffusers_hidream_lora_to_diffusers,
|
||||
_convert_non_diffusers_lora_to_diffusers,
|
||||
_convert_non_diffusers_ltxv_lora_to_diffusers,
|
||||
@@ -5084,6 +5085,209 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
|
||||
|
||||
class Flux2LoraLoaderMixin(LoraBaseMixin):
|
||||
r"""
|
||||
Load LoRA layers into [`Flux2Transformer2DModel`]. Specific to [`Flux2Pipeline`].
|
||||
"""
|
||||
|
||||
_lora_loadable_modules = ["transformer"]
|
||||
transformer_name = TRANSFORMER_NAME
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def lora_state_dict(
|
||||
cls,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
|
||||
"""
|
||||
# Load the main state dict first which has the LoRA layers for either of
|
||||
# transformer and text encoder or both.
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
weight_name = kwargs.pop("weight_name", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
|
||||
|
||||
allow_pickle = False
|
||||
if use_safetensors is None:
|
||||
use_safetensors = True
|
||||
allow_pickle = True
|
||||
|
||||
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
|
||||
|
||||
state_dict, metadata = _fetch_state_dict(
|
||||
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
|
||||
weight_name=weight_name,
|
||||
use_safetensors=use_safetensors,
|
||||
local_files_only=local_files_only,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
allow_pickle=allow_pickle,
|
||||
)
|
||||
|
||||
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
|
||||
if is_dora_scale_present:
|
||||
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
|
||||
logger.warning(warn_msg)
|
||||
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
||||
|
||||
is_ai_toolkit = any(k.startswith("diffusion_model.") for k in state_dict)
|
||||
if is_ai_toolkit:
|
||||
state_dict = _convert_non_diffusers_flux2_lora_to_diffusers(state_dict)
|
||||
|
||||
out = (state_dict, metadata) if return_lora_metadata else state_dict
|
||||
return out
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
|
||||
def load_lora_weights(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
adapter_name: Optional[str] = None,
|
||||
hotswap: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
|
||||
# if a dict is passed, copy it instead of modifying it inplace
|
||||
if isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
|
||||
|
||||
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
||||
kwargs["return_lora_metadata"] = True
|
||||
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
||||
adapter_name=adapter_name,
|
||||
metadata=metadata,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogView4Transformer2DModel
|
||||
def load_lora_into_transformer(
|
||||
cls,
|
||||
state_dict,
|
||||
transformer,
|
||||
adapter_name=None,
|
||||
_pipeline=None,
|
||||
low_cpu_mem_usage=False,
|
||||
hotswap: bool = False,
|
||||
metadata=None,
|
||||
):
|
||||
"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
|
||||
"""
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
|
||||
# Load the layers corresponding to transformer.
|
||||
logger.info(f"Loading {cls.transformer_name}.")
|
||||
transformer.load_lora_adapter(
|
||||
state_dict,
|
||||
network_alphas=None,
|
||||
adapter_name=adapter_name,
|
||||
metadata=metadata,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
|
||||
def save_lora_weights(
|
||||
cls,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
is_main_process: bool = True,
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
safe_serialization: bool = True,
|
||||
transformer_lora_adapter_metadata: Optional[dict] = None,
|
||||
):
|
||||
r"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
|
||||
"""
|
||||
lora_layers = {}
|
||||
lora_metadata = {}
|
||||
|
||||
if transformer_lora_layers:
|
||||
lora_layers[cls.transformer_name] = transformer_lora_layers
|
||||
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
|
||||
|
||||
if not lora_layers:
|
||||
raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
|
||||
|
||||
cls._save_lora_weights(
|
||||
save_directory=save_directory,
|
||||
lora_layers=lora_layers,
|
||||
lora_metadata=lora_metadata,
|
||||
is_main_process=is_main_process,
|
||||
weight_name=weight_name,
|
||||
save_function=save_function,
|
||||
safe_serialization=safe_serialization,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = ["transformer"],
|
||||
lora_scale: float = 1.0,
|
||||
safe_fusing: bool = False,
|
||||
adapter_names: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
|
||||
"""
|
||||
super().fuse_lora(
|
||||
components=components,
|
||||
lora_scale=lora_scale,
|
||||
safe_fusing=safe_fusing,
|
||||
adapter_names=adapter_names,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
|
||||
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
||||
r"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
|
||||
"""
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
|
||||
|
||||
class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."
|
||||
|
||||
@@ -62,6 +62,7 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
|
||||
"WanVACETransformer3DModel": lambda model_cls, weights: weights,
|
||||
"ChromaTransformer2DModel": lambda model_cls, weights: weights,
|
||||
"QwenImageTransformer2DModel": lambda model_cls, weights: weights,
|
||||
"Flux2Transformer2DModel": lambda model_cls, weights: weights,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -34,6 +34,7 @@ from .single_file_utils import (
|
||||
convert_chroma_transformer_checkpoint_to_diffusers,
|
||||
convert_controlnet_checkpoint,
|
||||
convert_cosmos_transformer_checkpoint_to_diffusers,
|
||||
convert_flux2_transformer_checkpoint_to_diffusers,
|
||||
convert_flux_transformer_checkpoint_to_diffusers,
|
||||
convert_hidream_transformer_to_diffusers,
|
||||
convert_hunyuan_video_transformer_to_diffusers,
|
||||
@@ -162,6 +163,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
|
||||
"checkpoint_mapping_fn": lambda x: x,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"Flux2Transformer2DModel": {
|
||||
"checkpoint_mapping_fn": convert_flux2_transformer_checkpoint_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -140,6 +140,7 @@ CHECKPOINT_KEY_NAMES = {
|
||||
"net.blocks.0.self_attn.q_proj.weight",
|
||||
"net.pos_embedder.dim_spatial_range",
|
||||
],
|
||||
"flux2": ["model.diffusion_model.single_stream_modulation.lin.weight", "single_stream_modulation.lin.weight"],
|
||||
}
|
||||
|
||||
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
||||
@@ -189,6 +190,7 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
||||
"flux-fill": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Fill-dev"},
|
||||
"flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"},
|
||||
"flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
|
||||
"flux-2-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.2-dev"},
|
||||
"ltx-video": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.0"},
|
||||
"ltx-video-0.9.1": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.1"},
|
||||
"ltx-video-0.9.5": {"pretrained_model_name_or_path": "Lightricks/LTX-Video-0.9.5"},
|
||||
@@ -649,6 +651,9 @@ def infer_diffusers_model_type(checkpoint):
|
||||
else:
|
||||
model_type = "animatediff_v3"
|
||||
|
||||
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["flux2"]):
|
||||
model_type = "flux-2-dev"
|
||||
|
||||
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["flux"]):
|
||||
if any(
|
||||
g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"]
|
||||
@@ -3647,3 +3652,168 @@ def convert_cosmos_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
handler_fn_inplace(key, converted_state_dict)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def convert_flux2_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
FLUX2_TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
# Image and text input projections
|
||||
"img_in": "x_embedder",
|
||||
"txt_in": "context_embedder",
|
||||
# Timestep and guidance embeddings
|
||||
"time_in.in_layer": "time_guidance_embed.timestep_embedder.linear_1",
|
||||
"time_in.out_layer": "time_guidance_embed.timestep_embedder.linear_2",
|
||||
"guidance_in.in_layer": "time_guidance_embed.guidance_embedder.linear_1",
|
||||
"guidance_in.out_layer": "time_guidance_embed.guidance_embedder.linear_2",
|
||||
# Modulation parameters
|
||||
"double_stream_modulation_img.lin": "double_stream_modulation_img.linear",
|
||||
"double_stream_modulation_txt.lin": "double_stream_modulation_txt.linear",
|
||||
"single_stream_modulation.lin": "single_stream_modulation.linear",
|
||||
# Final output layer
|
||||
# "final_layer.adaLN_modulation.1": "norm_out.linear", # Handle separately since we need to swap mod params
|
||||
"final_layer.linear": "proj_out",
|
||||
}
|
||||
|
||||
FLUX2_TRANSFORMER_ADA_LAYER_NORM_KEY_MAP = {
|
||||
"final_layer.adaLN_modulation.1": "norm_out.linear",
|
||||
}
|
||||
|
||||
FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP = {
|
||||
# Handle fused QKV projections separately as we need to break into Q, K, V projections
|
||||
"img_attn.norm.query_norm": "attn.norm_q",
|
||||
"img_attn.norm.key_norm": "attn.norm_k",
|
||||
"img_attn.proj": "attn.to_out.0",
|
||||
"img_mlp.0": "ff.linear_in",
|
||||
"img_mlp.2": "ff.linear_out",
|
||||
"txt_attn.norm.query_norm": "attn.norm_added_q",
|
||||
"txt_attn.norm.key_norm": "attn.norm_added_k",
|
||||
"txt_attn.proj": "attn.to_add_out",
|
||||
"txt_mlp.0": "ff_context.linear_in",
|
||||
"txt_mlp.2": "ff_context.linear_out",
|
||||
}
|
||||
|
||||
FLUX2_TRANSFORMER_SINGLE_BLOCK_KEY_MAP = {
|
||||
"linear1": "attn.to_qkv_mlp_proj",
|
||||
"linear2": "attn.to_out",
|
||||
"norm.query_norm": "attn.norm_q",
|
||||
"norm.key_norm": "attn.norm_k",
|
||||
}
|
||||
|
||||
def convert_flux2_single_stream_blocks(key: str, state_dict: dict[str, object]) -> None:
|
||||
# Skip if not a weight, bias, or scale
|
||||
if ".weight" not in key and ".bias" not in key and ".scale" not in key:
|
||||
return
|
||||
|
||||
# Mapping:
|
||||
# - single_blocks.{N}.linear1 --> single_transformer_blocks.{N}.attn.to_qkv_mlp_proj
|
||||
# - single_blocks.{N}.linear2 --> single_transformer_blocks.{N}.attn.to_out
|
||||
# - single_blocks.{N}.norm.query_norm.scale --> single_transformer_blocks.{N}.attn.norm_q.weight
|
||||
# - single_blocks.{N}.norm.key_norm.scale --> single_transformer_blocks.{N}.attn.norm_k.weight
|
||||
new_prefix = "single_transformer_blocks"
|
||||
if "single_blocks." in key:
|
||||
parts = key.split(".")
|
||||
block_idx = parts[1]
|
||||
within_block_name = ".".join(parts[2:-1])
|
||||
param_type = parts[-1]
|
||||
|
||||
if param_type == "scale":
|
||||
param_type = "weight"
|
||||
|
||||
new_within_block_name = FLUX2_TRANSFORMER_SINGLE_BLOCK_KEY_MAP[within_block_name]
|
||||
new_key = ".".join([new_prefix, block_idx, new_within_block_name, param_type])
|
||||
|
||||
param = state_dict.pop(key)
|
||||
state_dict[new_key] = param
|
||||
|
||||
return
|
||||
|
||||
def convert_ada_layer_norm_weights(key: str, state_dict: dict[str, object]) -> None:
|
||||
# Skip if not a weight
|
||||
if ".weight" not in key:
|
||||
return
|
||||
|
||||
# If adaLN_modulation is in the key, swap scale and shift parameters
|
||||
# Original implementation is (shift, scale); diffusers implementation is (scale, shift)
|
||||
if "adaLN_modulation" in key:
|
||||
key_without_param_type, param_type = key.rsplit(".", maxsplit=1)
|
||||
# Assume all such keys are in the AdaLayerNorm key map
|
||||
new_key_without_param_type = FLUX2_TRANSFORMER_ADA_LAYER_NORM_KEY_MAP[key_without_param_type]
|
||||
new_key = ".".join([new_key_without_param_type, param_type])
|
||||
|
||||
swapped_weight = swap_scale_shift(state_dict.pop(key), 0)
|
||||
state_dict[new_key] = swapped_weight
|
||||
|
||||
return
|
||||
|
||||
def convert_flux2_double_stream_blocks(key: str, state_dict: dict[str, object]) -> None:
|
||||
# Skip if not a weight, bias, or scale
|
||||
if ".weight" not in key and ".bias" not in key and ".scale" not in key:
|
||||
return
|
||||
|
||||
new_prefix = "transformer_blocks"
|
||||
if "double_blocks." in key:
|
||||
parts = key.split(".")
|
||||
block_idx = parts[1]
|
||||
modality_block_name = parts[2] # img_attn, img_mlp, txt_attn, txt_mlp
|
||||
within_block_name = ".".join(parts[2:-1])
|
||||
param_type = parts[-1]
|
||||
|
||||
if param_type == "scale":
|
||||
param_type = "weight"
|
||||
|
||||
if "qkv" in within_block_name:
|
||||
fused_qkv_weight = state_dict.pop(key)
|
||||
to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0)
|
||||
if "img" in modality_block_name:
|
||||
# double_blocks.{N}.img_attn.qkv --> transformer_blocks.{N}.attn.{to_q|to_k|to_v}
|
||||
to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0)
|
||||
new_q_name = "attn.to_q"
|
||||
new_k_name = "attn.to_k"
|
||||
new_v_name = "attn.to_v"
|
||||
elif "txt" in modality_block_name:
|
||||
# double_blocks.{N}.txt_attn.qkv --> transformer_blocks.{N}.attn.{add_q_proj|add_k_proj|add_v_proj}
|
||||
to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0)
|
||||
new_q_name = "attn.add_q_proj"
|
||||
new_k_name = "attn.add_k_proj"
|
||||
new_v_name = "attn.add_v_proj"
|
||||
new_q_key = ".".join([new_prefix, block_idx, new_q_name, param_type])
|
||||
new_k_key = ".".join([new_prefix, block_idx, new_k_name, param_type])
|
||||
new_v_key = ".".join([new_prefix, block_idx, new_v_name, param_type])
|
||||
state_dict[new_q_key] = to_q_weight
|
||||
state_dict[new_k_key] = to_k_weight
|
||||
state_dict[new_v_key] = to_v_weight
|
||||
else:
|
||||
new_within_block_name = FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP[within_block_name]
|
||||
new_key = ".".join([new_prefix, block_idx, new_within_block_name, param_type])
|
||||
|
||||
param = state_dict.pop(key)
|
||||
state_dict[new_key] = param
|
||||
return
|
||||
|
||||
def update_state_dict(state_dict: dict[str, object], old_key: str, new_key: str) -> None:
|
||||
state_dict[new_key] = state_dict.pop(old_key)
|
||||
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP = {
|
||||
"adaLN_modulation": convert_ada_layer_norm_weights,
|
||||
"double_blocks": convert_flux2_double_stream_blocks,
|
||||
"single_blocks": convert_flux2_single_stream_blocks,
|
||||
}
|
||||
|
||||
converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
|
||||
|
||||
# Handle official code --> diffusers key remapping via the remap dict
|
||||
for key in list(converted_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in FLUX2_TRANSFORMER_KEYS_RENAME_DICT.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
|
||||
update_state_dict(converted_state_dict, key, new_key)
|
||||
|
||||
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
|
||||
# special_keys_remap
|
||||
for key in list(converted_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, converted_state_dict)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
@@ -35,6 +35,7 @@ if is_torch_available():
|
||||
_import_structure["autoencoders.autoencoder_kl_allegro"] = ["AutoencoderKLAllegro"]
|
||||
_import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
|
||||
_import_structure["autoencoders.autoencoder_kl_cosmos"] = ["AutoencoderKLCosmos"]
|
||||
_import_structure["autoencoders.autoencoder_kl_flux2"] = ["AutoencoderKLFlux2"]
|
||||
_import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = ["AutoencoderKLHunyuanVideo"]
|
||||
_import_structure["autoencoders.autoencoder_kl_hunyuanimage"] = ["AutoencoderKLHunyuanImage"]
|
||||
_import_structure["autoencoders.autoencoder_kl_hunyuanimage_refiner"] = ["AutoencoderKLHunyuanImageRefiner"]
|
||||
@@ -92,6 +93,7 @@ if is_torch_available():
|
||||
_import_structure["transformers.transformer_cosmos"] = ["CosmosTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_easyanimate"] = ["EasyAnimateTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_flux2"] = ["Flux2Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"]
|
||||
@@ -110,6 +112,7 @@ if is_torch_available():
|
||||
_import_structure["transformers.transformer_wan"] = ["WanTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_wan_animate"] = ["WanAnimateTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_wan_vace"] = ["WanVACETransformer3DModel"]
|
||||
_import_structure["transformers.transformer_z_image"] = ["ZImageTransformer2DModel"]
|
||||
_import_structure["unets.unet_1d"] = ["UNet1DModel"]
|
||||
_import_structure["unets.unet_2d"] = ["UNet2DModel"]
|
||||
_import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"]
|
||||
@@ -140,6 +143,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderKLAllegro,
|
||||
AutoencoderKLCogVideoX,
|
||||
AutoencoderKLCosmos,
|
||||
AutoencoderKLFlux2,
|
||||
AutoencoderKLHunyuanImage,
|
||||
AutoencoderKLHunyuanImageRefiner,
|
||||
AutoencoderKLHunyuanVideo,
|
||||
@@ -190,6 +194,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
DiTTransformer2DModel,
|
||||
DualTransformer2DModel,
|
||||
EasyAnimateTransformer3DModel,
|
||||
Flux2Transformer2DModel,
|
||||
FluxTransformer2DModel,
|
||||
HiDreamImageTransformer2DModel,
|
||||
HunyuanDiT2DModel,
|
||||
@@ -218,6 +223,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
WanAnimateTransformer3DModel,
|
||||
WanTransformer3DModel,
|
||||
WanVACETransformer3DModel,
|
||||
ZImageTransformer2DModel,
|
||||
)
|
||||
from .unets import (
|
||||
I2VGenXLUNet,
|
||||
|
||||
@@ -105,7 +105,7 @@ class AttentionMixin:
|
||||
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
||||
|
||||
for module in self.modules():
|
||||
if isinstance(module, AttentionModuleMixin):
|
||||
if isinstance(module, AttentionModuleMixin) and module._supports_qkv_fusion:
|
||||
module.fuse_projections()
|
||||
|
||||
def unfuse_qkv_projections(self):
|
||||
@@ -114,13 +114,14 @@ class AttentionMixin:
|
||||
> [!WARNING] > This API is 🧪 experimental.
|
||||
"""
|
||||
for module in self.modules():
|
||||
if isinstance(module, AttentionModuleMixin):
|
||||
if isinstance(module, AttentionModuleMixin) and module._supports_qkv_fusion:
|
||||
module.unfuse_projections()
|
||||
|
||||
|
||||
class AttentionModuleMixin:
|
||||
_default_processor_cls = None
|
||||
_available_processors = []
|
||||
_supports_qkv_fusion = True
|
||||
fused_projections = False
|
||||
|
||||
def set_processor(self, processor: AttentionProcessor) -> None:
|
||||
@@ -248,6 +249,14 @@ class AttentionModuleMixin:
|
||||
"""
|
||||
Fuse the query, key, and value projections into a single projection for efficiency.
|
||||
"""
|
||||
# Skip if the AttentionModuleMixin subclass does not support fusion (for example, the QKV projections in Flux2
|
||||
# single stream blocks are always fused)
|
||||
if not self._supports_qkv_fusion:
|
||||
logger.debug(
|
||||
f"{self.__class__.__name__} does not support fusing QKV projections, so `fuse_projections` will no-op."
|
||||
)
|
||||
return
|
||||
|
||||
# Skip if already fused
|
||||
if getattr(self, "fused_projections", False):
|
||||
return
|
||||
@@ -307,6 +316,11 @@ class AttentionModuleMixin:
|
||||
"""
|
||||
Unfuse the query, key, and value projections back to separate projections.
|
||||
"""
|
||||
# Skip if the AttentionModuleMixin subclass does not support fusion (for example, the QKV projections in Flux2
|
||||
# single stream blocks are always fused)
|
||||
if not self._supports_qkv_fusion:
|
||||
return
|
||||
|
||||
# Skip if not fused
|
||||
if not getattr(self, "fused_projections", False):
|
||||
return
|
||||
|
||||
@@ -18,7 +18,7 @@ import inspect
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -160,16 +160,13 @@ logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
# - CP with sage attention, flex, xformers, other missing backends
|
||||
# - Add support for normal and CP training with backends that don't support it yet
|
||||
|
||||
_SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"]
|
||||
_SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"]
|
||||
_SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"]
|
||||
|
||||
|
||||
class AttentionBackendName(str, Enum):
|
||||
# EAGER = "eager"
|
||||
|
||||
# `flash-attn`
|
||||
FLASH = "flash"
|
||||
FLASH_HUB = "flash_hub"
|
||||
FLASH_VARLEN = "flash_varlen"
|
||||
_FLASH_3 = "_flash_3"
|
||||
_FLASH_VARLEN_3 = "_flash_varlen_3"
|
||||
@@ -191,6 +188,7 @@ class AttentionBackendName(str, Enum):
|
||||
|
||||
# `sageattention`
|
||||
SAGE = "sage"
|
||||
SAGE_HUB = "sage_hub"
|
||||
SAGE_VARLEN = "sage_varlen"
|
||||
_SAGE_QK_INT8_PV_FP8_CUDA = "_sage_qk_int8_pv_fp8_cuda"
|
||||
_SAGE_QK_INT8_PV_FP8_CUDA_SM90 = "_sage_qk_int8_pv_fp8_cuda_sm90"
|
||||
@@ -264,7 +262,13 @@ _HUB_KERNELS_REGISTRY: Dict["AttentionBackendName", _HubKernelConfig] = {
|
||||
# TODO: temporary revision for now. Remove when merged upstream into `main`.
|
||||
AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
|
||||
repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs"
|
||||
)
|
||||
),
|
||||
AttentionBackendName.FLASH_HUB: _HubKernelConfig(
|
||||
repo_id="kernels-community/flash-attn2", function_attr="flash_attn_func", revision=None
|
||||
),
|
||||
AttentionBackendName.SAGE_HUB: _HubKernelConfig(
|
||||
repo_id="kernels-community/sage_attention", function_attr="sageattn", revision=None
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@@ -420,8 +424,8 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
|
||||
f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
|
||||
)
|
||||
|
||||
# TODO: add support Hub variant of FA3 varlen later
|
||||
elif backend in [AttentionBackendName._FLASH_3_HUB]:
|
||||
# TODO: add support Hub variant of varlen later
|
||||
elif backend in [AttentionBackendName._FLASH_3_HUB, AttentionBackendName.FLASH_HUB, AttentionBackendName.SAGE_HUB]:
|
||||
if not is_kernels_available():
|
||||
raise RuntimeError(
|
||||
f"Backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
|
||||
@@ -1350,6 +1354,38 @@ def _flash_attention(
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.FLASH_HUB,
|
||||
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
supports_context_parallel=False,
|
||||
)
|
||||
def _flash_attention_hub(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
lse = None
|
||||
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].kernel_fn
|
||||
out = func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
dropout_p=dropout_p,
|
||||
softmax_scale=scale,
|
||||
causal=is_causal,
|
||||
return_attn_probs=return_lse,
|
||||
)
|
||||
if return_lse:
|
||||
out, lse, *_ = out
|
||||
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.FLASH_VARLEN,
|
||||
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
@@ -1431,6 +1467,7 @@ def _flash_attention_3(
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName._FLASH_3_HUB,
|
||||
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
supports_context_parallel=False,
|
||||
)
|
||||
def _flash_attention_3_hub(
|
||||
query: torch.Tensor,
|
||||
@@ -1444,6 +1481,9 @@ def _flash_attention_3_hub(
|
||||
return_attn_probs: bool = False,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
if _parallel_config:
|
||||
raise NotImplementedError(f"{AttentionBackendName._FLASH_3_HUB.value} is not implemented for parallelism yet.")
|
||||
|
||||
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
|
||||
out = func(
|
||||
q=query,
|
||||
@@ -1938,6 +1978,38 @@ def _sage_attention(
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.SAGE_HUB,
|
||||
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
supports_context_parallel=False,
|
||||
)
|
||||
def _sage_attention_hub(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
lse = None
|
||||
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn
|
||||
if _parallel_config is None:
|
||||
out = func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
tensor_layout="NHD",
|
||||
is_causal=is_causal,
|
||||
sm_scale=scale,
|
||||
return_lse=return_lse,
|
||||
)
|
||||
if return_lse:
|
||||
out, lse, *_ = out
|
||||
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.SAGE_VARLEN,
|
||||
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
|
||||
@@ -4,6 +4,7 @@ from .autoencoder_kl import AutoencoderKL
|
||||
from .autoencoder_kl_allegro import AutoencoderKLAllegro
|
||||
from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
|
||||
from .autoencoder_kl_cosmos import AutoencoderKLCosmos
|
||||
from .autoencoder_kl_flux2 import AutoencoderKLFlux2
|
||||
from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo
|
||||
from .autoencoder_kl_hunyuanimage import AutoencoderKLHunyuanImage
|
||||
from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner
|
||||
|
||||
546
src/diffusers/models/autoencoders/autoencoder_kl_flux2.py
Normal file
546
src/diffusers/models/autoencoders/autoencoder_kl_flux2.py
Normal file
@@ -0,0 +1,546 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import math
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...loaders.single_file_model import FromOriginalModelMixin
|
||||
from ...utils import deprecate
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
CROSS_ATTENTION_PROCESSORS,
|
||||
Attention,
|
||||
AttentionProcessor,
|
||||
AttnAddedKVProcessor,
|
||||
AttnProcessor,
|
||||
FusedAttnProcessor2_0,
|
||||
)
|
||||
from ..modeling_outputs import AutoencoderKLOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .vae import AutoencoderMixin, Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
|
||||
|
||||
|
||||
class AutoencoderKLFlux2(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
|
||||
r"""
|
||||
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
|
||||
|
||||
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
||||
for all models (such as downloading or saving).
|
||||
|
||||
Parameters:
|
||||
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
|
||||
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
|
||||
Tuple of downsample block types.
|
||||
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
|
||||
Tuple of upsample block types.
|
||||
block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
|
||||
Tuple of block output channels.
|
||||
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
||||
latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
|
||||
sample_size (`int`, *optional*, defaults to `32`): Sample input size.
|
||||
force_upcast (`bool`, *optional*, default to `True`):
|
||||
If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
|
||||
can be fine-tuned / trained to a lower range without losing too much precision in which case `force_upcast`
|
||||
can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
|
||||
mid_block_add_attention (`bool`, *optional*, default to `True`):
|
||||
If enabled, the mid_block of the Encoder and Decoder will have attention blocks. If set to false, the
|
||||
mid_block will only have resnet blocks
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
"DownEncoderBlock2D",
|
||||
"DownEncoderBlock2D",
|
||||
"DownEncoderBlock2D",
|
||||
"DownEncoderBlock2D",
|
||||
),
|
||||
up_block_types: Tuple[str, ...] = (
|
||||
"UpDecoderBlock2D",
|
||||
"UpDecoderBlock2D",
|
||||
"UpDecoderBlock2D",
|
||||
"UpDecoderBlock2D",
|
||||
),
|
||||
block_out_channels: Tuple[int, ...] = (
|
||||
128,
|
||||
256,
|
||||
512,
|
||||
512,
|
||||
),
|
||||
layers_per_block: int = 2,
|
||||
act_fn: str = "silu",
|
||||
latent_channels: int = 32,
|
||||
norm_num_groups: int = 32,
|
||||
sample_size: int = 1024, # YiYi notes: not sure
|
||||
force_upcast: bool = True,
|
||||
use_quant_conv: bool = True,
|
||||
use_post_quant_conv: bool = True,
|
||||
mid_block_add_attention: bool = True,
|
||||
batch_norm_eps: float = 1e-4,
|
||||
batch_norm_momentum: float = 0.1,
|
||||
patch_size: Tuple[int, int] = (2, 2),
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# pass init params to Encoder
|
||||
self.encoder = Encoder(
|
||||
in_channels=in_channels,
|
||||
out_channels=latent_channels,
|
||||
down_block_types=down_block_types,
|
||||
block_out_channels=block_out_channels,
|
||||
layers_per_block=layers_per_block,
|
||||
act_fn=act_fn,
|
||||
norm_num_groups=norm_num_groups,
|
||||
double_z=True,
|
||||
mid_block_add_attention=mid_block_add_attention,
|
||||
)
|
||||
|
||||
# pass init params to Decoder
|
||||
self.decoder = Decoder(
|
||||
in_channels=latent_channels,
|
||||
out_channels=out_channels,
|
||||
up_block_types=up_block_types,
|
||||
block_out_channels=block_out_channels,
|
||||
layers_per_block=layers_per_block,
|
||||
norm_num_groups=norm_num_groups,
|
||||
act_fn=act_fn,
|
||||
mid_block_add_attention=mid_block_add_attention,
|
||||
)
|
||||
|
||||
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None
|
||||
self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) if use_post_quant_conv else None
|
||||
|
||||
self.bn = nn.BatchNorm2d(
|
||||
math.prod(patch_size) * latent_channels,
|
||||
eps=batch_norm_eps,
|
||||
momentum=batch_norm_momentum,
|
||||
affine=False,
|
||||
track_running_stats=True,
|
||||
)
|
||||
|
||||
self.use_slicing = False
|
||||
self.use_tiling = False
|
||||
|
||||
# only relevant if vae tiling is enabled
|
||||
self.tile_sample_min_size = self.config.sample_size
|
||||
sample_size = (
|
||||
self.config.sample_size[0]
|
||||
if isinstance(self.config.sample_size, (list, tuple))
|
||||
else self.config.sample_size
|
||||
)
|
||||
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
|
||||
self.tile_overlap_factor = 0.25
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
def set_default_attn_processor(self):
|
||||
"""
|
||||
Disables custom attention processors and sets the default attention implementation.
|
||||
"""
|
||||
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
||||
processor = AttnAddedKVProcessor()
|
||||
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
||||
processor = AttnProcessor()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
||||
)
|
||||
|
||||
self.set_attn_processor(processor)
|
||||
|
||||
def _encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, num_channels, height, width = x.shape
|
||||
|
||||
if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size):
|
||||
return self._tiled_encode(x)
|
||||
|
||||
enc = self.encoder(x)
|
||||
if self.quant_conv is not None:
|
||||
enc = self.quant_conv(enc)
|
||||
|
||||
return enc
|
||||
|
||||
@apply_forward_hook
|
||||
def encode(
|
||||
self, x: torch.Tensor, return_dict: bool = True
|
||||
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
||||
"""
|
||||
Encode a batch of images into latents.
|
||||
|
||||
Args:
|
||||
x (`torch.Tensor`): Input batch of images.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
The latent representations of the encoded images. If `return_dict` is True, a
|
||||
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
|
||||
"""
|
||||
if self.use_slicing and x.shape[0] > 1:
|
||||
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
|
||||
h = torch.cat(encoded_slices)
|
||||
else:
|
||||
h = self._encode(x)
|
||||
|
||||
posterior = DiagonalGaussianDistribution(h)
|
||||
|
||||
if not return_dict:
|
||||
return (posterior,)
|
||||
|
||||
return AutoencoderKLOutput(latent_dist=posterior)
|
||||
|
||||
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
|
||||
return self.tiled_decode(z, return_dict=return_dict)
|
||||
|
||||
if self.post_quant_conv is not None:
|
||||
z = self.post_quant_conv(z)
|
||||
|
||||
dec = self.decoder(z)
|
||||
|
||||
if not return_dict:
|
||||
return (dec,)
|
||||
|
||||
return DecoderOutput(sample=dec)
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(
|
||||
self, z: torch.FloatTensor, return_dict: bool = True, generator=None
|
||||
) -> Union[DecoderOutput, torch.FloatTensor]:
|
||||
"""
|
||||
Decode a batch of images.
|
||||
|
||||
Args:
|
||||
z (`torch.Tensor`): Input batch of latent vectors.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.vae.DecoderOutput`] or `tuple`:
|
||||
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
||||
returned.
|
||||
|
||||
"""
|
||||
if self.use_slicing and z.shape[0] > 1:
|
||||
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
|
||||
decoded = torch.cat(decoded_slices)
|
||||
else:
|
||||
decoded = self._decode(z).sample
|
||||
|
||||
if not return_dict:
|
||||
return (decoded,)
|
||||
|
||||
return DecoderOutput(sample=decoded)
|
||||
|
||||
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
||||
blend_extent = min(a.shape[2], b.shape[2], blend_extent)
|
||||
for y in range(blend_extent):
|
||||
b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
|
||||
return b
|
||||
|
||||
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
||||
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
|
||||
for x in range(blend_extent):
|
||||
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
|
||||
return b
|
||||
|
||||
def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
r"""Encode a batch of images using a tiled encoder.
|
||||
|
||||
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
|
||||
steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
|
||||
different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
|
||||
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
|
||||
output, but they should be much less noticeable.
|
||||
|
||||
Args:
|
||||
x (`torch.Tensor`): Input batch of images.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The latent representation of the encoded videos.
|
||||
"""
|
||||
|
||||
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
|
||||
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
|
||||
row_limit = self.tile_latent_min_size - blend_extent
|
||||
|
||||
# Split the image into 512x512 tiles and encode them separately.
|
||||
rows = []
|
||||
for i in range(0, x.shape[2], overlap_size):
|
||||
row = []
|
||||
for j in range(0, x.shape[3], overlap_size):
|
||||
tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
|
||||
tile = self.encoder(tile)
|
||||
if self.config.use_quant_conv:
|
||||
tile = self.quant_conv(tile)
|
||||
row.append(tile)
|
||||
rows.append(row)
|
||||
result_rows = []
|
||||
for i, row in enumerate(rows):
|
||||
result_row = []
|
||||
for j, tile in enumerate(row):
|
||||
# blend the above tile and the left tile
|
||||
# to the current tile and add the current tile to the result row
|
||||
if i > 0:
|
||||
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
||||
if j > 0:
|
||||
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
||||
result_row.append(tile[:, :, :row_limit, :row_limit])
|
||||
result_rows.append(torch.cat(result_row, dim=3))
|
||||
|
||||
enc = torch.cat(result_rows, dim=2)
|
||||
return enc
|
||||
|
||||
def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderKLOutput:
|
||||
r"""Encode a batch of images using a tiled encoder.
|
||||
|
||||
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
|
||||
steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
|
||||
different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
|
||||
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
|
||||
output, but they should be much less noticeable.
|
||||
|
||||
Args:
|
||||
x (`torch.Tensor`): Input batch of images.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
|
||||
If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
|
||||
`tuple` is returned.
|
||||
"""
|
||||
deprecation_message = (
|
||||
"The tiled_encode implementation supporting the `return_dict` parameter is deprecated. In the future, the "
|
||||
"implementation of this method will be replaced with that of `_tiled_encode` and you will no longer be able "
|
||||
"to pass `return_dict`. You will also have to create a `DiagonalGaussianDistribution()` from the returned value."
|
||||
)
|
||||
deprecate("tiled_encode", "1.0.0", deprecation_message, standard_warn=False)
|
||||
|
||||
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
|
||||
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
|
||||
row_limit = self.tile_latent_min_size - blend_extent
|
||||
|
||||
# Split the image into 512x512 tiles and encode them separately.
|
||||
rows = []
|
||||
for i in range(0, x.shape[2], overlap_size):
|
||||
row = []
|
||||
for j in range(0, x.shape[3], overlap_size):
|
||||
tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
|
||||
tile = self.encoder(tile)
|
||||
if self.config.use_quant_conv:
|
||||
tile = self.quant_conv(tile)
|
||||
row.append(tile)
|
||||
rows.append(row)
|
||||
result_rows = []
|
||||
for i, row in enumerate(rows):
|
||||
result_row = []
|
||||
for j, tile in enumerate(row):
|
||||
# blend the above tile and the left tile
|
||||
# to the current tile and add the current tile to the result row
|
||||
if i > 0:
|
||||
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
||||
if j > 0:
|
||||
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
||||
result_row.append(tile[:, :, :row_limit, :row_limit])
|
||||
result_rows.append(torch.cat(result_row, dim=3))
|
||||
|
||||
moments = torch.cat(result_rows, dim=2)
|
||||
posterior = DiagonalGaussianDistribution(moments)
|
||||
|
||||
if not return_dict:
|
||||
return (posterior,)
|
||||
|
||||
return AutoencoderKLOutput(latent_dist=posterior)
|
||||
|
||||
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
r"""
|
||||
Decode a batch of images using a tiled decoder.
|
||||
|
||||
Args:
|
||||
z (`torch.Tensor`): Input batch of latent vectors.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.vae.DecoderOutput`] or `tuple`:
|
||||
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
||||
returned.
|
||||
"""
|
||||
overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
|
||||
blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
|
||||
row_limit = self.tile_sample_min_size - blend_extent
|
||||
|
||||
# Split z into overlapping 64x64 tiles and decode them separately.
|
||||
# The tiles have an overlap to avoid seams between tiles.
|
||||
rows = []
|
||||
for i in range(0, z.shape[2], overlap_size):
|
||||
row = []
|
||||
for j in range(0, z.shape[3], overlap_size):
|
||||
tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
|
||||
if self.config.use_post_quant_conv:
|
||||
tile = self.post_quant_conv(tile)
|
||||
decoded = self.decoder(tile)
|
||||
row.append(decoded)
|
||||
rows.append(row)
|
||||
result_rows = []
|
||||
for i, row in enumerate(rows):
|
||||
result_row = []
|
||||
for j, tile in enumerate(row):
|
||||
# blend the above tile and the left tile
|
||||
# to the current tile and add the current tile to the result row
|
||||
if i > 0:
|
||||
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
||||
if j > 0:
|
||||
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
||||
result_row.append(tile[:, :, :row_limit, :row_limit])
|
||||
result_rows.append(torch.cat(result_row, dim=3))
|
||||
|
||||
dec = torch.cat(result_rows, dim=2)
|
||||
if not return_dict:
|
||||
return (dec,)
|
||||
|
||||
return DecoderOutput(sample=dec)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
r"""
|
||||
Args:
|
||||
sample (`torch.Tensor`): Input sample.
|
||||
sample_posterior (`bool`, *optional*, defaults to `False`):
|
||||
Whether to sample from the posterior.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
|
||||
"""
|
||||
x = sample
|
||||
posterior = self.encode(x).latent_dist
|
||||
if sample_posterior:
|
||||
z = posterior.sample(generator=generator)
|
||||
else:
|
||||
z = posterior.mode()
|
||||
dec = self.decode(z).sample
|
||||
|
||||
if not return_dict:
|
||||
return (dec,)
|
||||
|
||||
return DecoderOutput(sample=dec)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
|
||||
def fuse_qkv_projections(self):
|
||||
"""
|
||||
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
|
||||
are fused. For cross-attention modules, key and value projection matrices are fused.
|
||||
|
||||
> [!WARNING] > This API is 🧪 experimental.
|
||||
"""
|
||||
self.original_attn_processors = None
|
||||
|
||||
for _, attn_processor in self.attn_processors.items():
|
||||
if "Added" in str(attn_processor.__class__.__name__):
|
||||
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
||||
|
||||
self.original_attn_processors = self.attn_processors
|
||||
|
||||
for module in self.modules():
|
||||
if isinstance(module, Attention):
|
||||
module.fuse_projections(fuse=True)
|
||||
|
||||
self.set_attn_processor(FusedAttnProcessor2_0())
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
||||
def unfuse_qkv_projections(self):
|
||||
"""Disables the fused QKV projection if enabled.
|
||||
|
||||
> [!WARNING] > This API is 🧪 experimental.
|
||||
|
||||
"""
|
||||
if self.original_attn_processors is not None:
|
||||
self.set_attn_processor(self.original_attn_processors)
|
||||
@@ -41,9 +41,11 @@ class CacheMixin:
|
||||
Enable caching techniques on the model.
|
||||
|
||||
Args:
|
||||
config (`Union[PyramidAttentionBroadcastConfig]`):
|
||||
config (`Union[PyramidAttentionBroadcastConfig, FasterCacheConfig, FirstBlockCacheConfig]`):
|
||||
The configuration for applying the caching technique. Currently supported caching techniques are:
|
||||
- [`~hooks.PyramidAttentionBroadcastConfig`]
|
||||
- [`~hooks.FasterCacheConfig`]
|
||||
- [`~hooks.FirstBlockCacheConfig`]
|
||||
|
||||
Example:
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ if is_torch_available():
|
||||
from .transformer_cosmos import CosmosTransformer3DModel
|
||||
from .transformer_easyanimate import EasyAnimateTransformer3DModel
|
||||
from .transformer_flux import FluxTransformer2DModel
|
||||
from .transformer_flux2 import Flux2Transformer2DModel
|
||||
from .transformer_hidream_image import HiDreamImageTransformer2DModel
|
||||
from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
|
||||
from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel
|
||||
@@ -44,3 +45,4 @@ if is_torch_available():
|
||||
from .transformer_wan import WanTransformer3DModel
|
||||
from .transformer_wan_animate import WanAnimateTransformer3DModel
|
||||
from .transformer_wan_vace import WanVACETransformer3DModel
|
||||
from .transformer_z_image import ZImageTransformer2DModel
|
||||
|
||||
@@ -67,7 +67,7 @@ def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: t
|
||||
return key_img, value_img
|
||||
|
||||
|
||||
# Copied from diffusers.models.transformers.transformer_wan.WanAttnProcessor
|
||||
# modified from diffusers.models.transformers.transformer_wan.WanAttnProcessor
|
||||
class WanAttnProcessor:
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
@@ -137,7 +137,8 @@ class WanAttnProcessor:
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
# Reference: https://github.com/huggingface/diffusers/pull/12660
|
||||
parallel_config=None,
|
||||
)
|
||||
hidden_states_img = hidden_states_img.flatten(2, 3)
|
||||
hidden_states_img = hidden_states_img.type_as(query)
|
||||
@@ -150,7 +151,8 @@ class WanAttnProcessor:
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
# Reference: https://github.com/huggingface/diffusers/pull/12660
|
||||
parallel_config=(self._parallel_config if encoder_hidden_states is None else None),
|
||||
)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.type_as(query)
|
||||
@@ -568,9 +570,11 @@ class ChronoEditTransformer3DModel(
|
||||
"blocks.0": {
|
||||
"hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
},
|
||||
"blocks.*": {
|
||||
"encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
},
|
||||
# Reference: https://github.com/huggingface/diffusers/pull/12660
|
||||
# We need to disable the splitting of encoder_hidden_states because
|
||||
# the image_encoder consistently generates 257 tokens for image_embed. This causes
|
||||
# the shape of encoder_hidden_states—whose token count is always 769 (512 + 257)
|
||||
# after concatenation—to be indivisible by the number of devices in the CP.
|
||||
"proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
|
||||
}
|
||||
|
||||
|
||||
908
src/diffusers/models/transformers/transformer_flux2.py
Normal file
908
src/diffusers/models/transformers/transformer_flux2.py
Normal file
@@ -0,0 +1,908 @@
|
||||
# Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, is_torch_npu_available, logging, scale_lora_layers, unscale_lora_layers
|
||||
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import (
|
||||
TimestepEmbedding,
|
||||
Timesteps,
|
||||
apply_rotary_emb,
|
||||
get_1d_rotary_pos_embed,
|
||||
)
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormContinuous
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def _get_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None):
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
|
||||
encoder_query = encoder_key = encoder_value = None
|
||||
if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
|
||||
encoder_query = attn.add_q_proj(encoder_hidden_states)
|
||||
encoder_key = attn.add_k_proj(encoder_hidden_states)
|
||||
encoder_value = attn.add_v_proj(encoder_hidden_states)
|
||||
|
||||
return query, key, value, encoder_query, encoder_key, encoder_value
|
||||
|
||||
|
||||
def _get_fused_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None):
|
||||
query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
|
||||
|
||||
encoder_query = encoder_key = encoder_value = (None,)
|
||||
if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
|
||||
encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1)
|
||||
|
||||
return query, key, value, encoder_query, encoder_key, encoder_value
|
||||
|
||||
|
||||
def _get_qkv_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None):
|
||||
if attn.fused_projections:
|
||||
return _get_fused_projections(attn, hidden_states, encoder_hidden_states)
|
||||
return _get_projections(attn, hidden_states, encoder_hidden_states)
|
||||
|
||||
|
||||
class Flux2SwiGLU(nn.Module):
|
||||
"""
|
||||
Flux 2 uses a SwiGLU-style activation in the transformer feedforward sub-blocks, but with the linear projection
|
||||
layer fused into the first linear layer of the FF sub-block. Thus, this module has no trainable parameters.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.gate_fn = nn.SiLU()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x1, x2 = x.chunk(2, dim=-1)
|
||||
x = self.gate_fn(x1) * x2
|
||||
return x
|
||||
|
||||
|
||||
class Flux2FeedForward(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
dim_out: Optional[int] = None,
|
||||
mult: float = 3.0,
|
||||
inner_dim: Optional[int] = None,
|
||||
bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
if inner_dim is None:
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = dim_out or dim
|
||||
|
||||
# Flux2SwiGLU will reduce the dimension by half
|
||||
self.linear_in = nn.Linear(dim, inner_dim * 2, bias=bias)
|
||||
self.act_fn = Flux2SwiGLU()
|
||||
self.linear_out = nn.Linear(inner_dim, dim_out, bias=bias)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.linear_in(x)
|
||||
x = self.act_fn(x)
|
||||
x = self.linear_out(x)
|
||||
return x
|
||||
|
||||
|
||||
class Flux2AttnProcessor:
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: "Flux2Attention",
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
|
||||
attn, hidden_states, encoder_hidden_states
|
||||
)
|
||||
|
||||
query = query.unflatten(-1, (attn.heads, -1))
|
||||
key = key.unflatten(-1, (attn.heads, -1))
|
||||
value = value.unflatten(-1, (attn.heads, -1))
|
||||
|
||||
query = attn.norm_q(query)
|
||||
key = attn.norm_k(key)
|
||||
|
||||
if attn.added_kv_proj_dim is not None:
|
||||
encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
|
||||
encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
|
||||
encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
|
||||
|
||||
encoder_query = attn.norm_added_q(encoder_query)
|
||||
encoder_key = attn.norm_added_k(encoder_key)
|
||||
|
||||
query = torch.cat([encoder_query, query], dim=1)
|
||||
key = torch.cat([encoder_key, key], dim=1)
|
||||
value = torch.cat([encoder_value, value], dim=1)
|
||||
|
||||
if image_rotary_emb is not None:
|
||||
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
|
||||
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
|
||||
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
|
||||
[encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
|
||||
)
|
||||
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
||||
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
return hidden_states, encoder_hidden_states
|
||||
else:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Flux2Attention(torch.nn.Module, AttentionModuleMixin):
|
||||
_default_processor_cls = Flux2AttnProcessor
|
||||
_available_processors = [Flux2AttnProcessor]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
query_dim: int,
|
||||
heads: int = 8,
|
||||
dim_head: int = 64,
|
||||
dropout: float = 0.0,
|
||||
bias: bool = False,
|
||||
added_kv_proj_dim: Optional[int] = None,
|
||||
added_proj_bias: Optional[bool] = True,
|
||||
out_bias: bool = True,
|
||||
eps: float = 1e-5,
|
||||
out_dim: int = None,
|
||||
elementwise_affine: bool = True,
|
||||
processor=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.head_dim = dim_head
|
||||
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
|
||||
self.query_dim = query_dim
|
||||
self.out_dim = out_dim if out_dim is not None else query_dim
|
||||
self.heads = out_dim // dim_head if out_dim is not None else heads
|
||||
|
||||
self.use_bias = bias
|
||||
self.dropout = dropout
|
||||
|
||||
self.added_kv_proj_dim = added_kv_proj_dim
|
||||
self.added_proj_bias = added_proj_bias
|
||||
|
||||
self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
|
||||
self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
|
||||
self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
|
||||
|
||||
# QK Norm
|
||||
self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
||||
self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
||||
|
||||
self.to_out = torch.nn.ModuleList([])
|
||||
self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
|
||||
self.to_out.append(torch.nn.Dropout(dropout))
|
||||
|
||||
if added_kv_proj_dim is not None:
|
||||
self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
|
||||
self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
|
||||
self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
|
||||
self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
|
||||
self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
|
||||
self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)
|
||||
|
||||
if processor is None:
|
||||
processor = self._default_processor_cls()
|
||||
self.set_processor(processor)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
|
||||
unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters]
|
||||
if len(unused_kwargs) > 0:
|
||||
logger.warning(
|
||||
f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
|
||||
)
|
||||
kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
|
||||
return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
|
||||
|
||||
|
||||
class Flux2ParallelSelfAttnProcessor:
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: "Flux2ParallelSelfAttention",
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
# Parallel in (QKV + MLP in) projection
|
||||
hidden_states = attn.to_qkv_mlp_proj(hidden_states)
|
||||
qkv, mlp_hidden_states = torch.split(
|
||||
hidden_states, [3 * attn.inner_dim, attn.mlp_hidden_dim * attn.mlp_mult_factor], dim=-1
|
||||
)
|
||||
|
||||
# Handle the attention logic
|
||||
query, key, value = qkv.chunk(3, dim=-1)
|
||||
|
||||
query = query.unflatten(-1, (attn.heads, -1))
|
||||
key = key.unflatten(-1, (attn.heads, -1))
|
||||
value = value.unflatten(-1, (attn.heads, -1))
|
||||
|
||||
query = attn.norm_q(query)
|
||||
key = attn.norm_k(key)
|
||||
|
||||
if image_rotary_emb is not None:
|
||||
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
|
||||
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
|
||||
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# Handle the feedforward (FF) logic
|
||||
mlp_hidden_states = attn.mlp_act_fn(mlp_hidden_states)
|
||||
|
||||
# Concatenate and parallel output projection
|
||||
hidden_states = torch.cat([hidden_states, mlp_hidden_states], dim=-1)
|
||||
hidden_states = attn.to_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Flux2ParallelSelfAttention(torch.nn.Module, AttentionModuleMixin):
|
||||
"""
|
||||
Flux 2 parallel self-attention for the Flux 2 single-stream transformer blocks.
|
||||
|
||||
This implements a parallel transformer block, where the attention QKV projections are fused to the feedforward (FF)
|
||||
input projections, and the attention output projections are fused to the FF output projections. See the [ViT-22B
|
||||
paper](https://arxiv.org/abs/2302.05442) for a visual depiction of this type of transformer block.
|
||||
"""
|
||||
|
||||
_default_processor_cls = Flux2ParallelSelfAttnProcessor
|
||||
_available_processors = [Flux2ParallelSelfAttnProcessor]
|
||||
# Does not support QKV fusion as the QKV projections are always fused
|
||||
_supports_qkv_fusion = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
query_dim: int,
|
||||
heads: int = 8,
|
||||
dim_head: int = 64,
|
||||
dropout: float = 0.0,
|
||||
bias: bool = False,
|
||||
out_bias: bool = True,
|
||||
eps: float = 1e-5,
|
||||
out_dim: int = None,
|
||||
elementwise_affine: bool = True,
|
||||
mlp_ratio: float = 4.0,
|
||||
mlp_mult_factor: int = 2,
|
||||
processor=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.head_dim = dim_head
|
||||
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
|
||||
self.query_dim = query_dim
|
||||
self.out_dim = out_dim if out_dim is not None else query_dim
|
||||
self.heads = out_dim // dim_head if out_dim is not None else heads
|
||||
|
||||
self.use_bias = bias
|
||||
self.dropout = dropout
|
||||
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.mlp_hidden_dim = int(query_dim * self.mlp_ratio)
|
||||
self.mlp_mult_factor = mlp_mult_factor
|
||||
|
||||
# Fused QKV projections + MLP input projection
|
||||
self.to_qkv_mlp_proj = torch.nn.Linear(
|
||||
self.query_dim, self.inner_dim * 3 + self.mlp_hidden_dim * self.mlp_mult_factor, bias=bias
|
||||
)
|
||||
self.mlp_act_fn = Flux2SwiGLU()
|
||||
|
||||
# QK Norm
|
||||
self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
||||
self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
||||
|
||||
# Fused attention output projection + MLP output projection
|
||||
self.to_out = torch.nn.Linear(self.inner_dim + self.mlp_hidden_dim, self.out_dim, bias=out_bias)
|
||||
|
||||
if processor is None:
|
||||
processor = self._default_processor_cls()
|
||||
self.set_processor(processor)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
|
||||
unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters]
|
||||
if len(unused_kwargs) > 0:
|
||||
logger.warning(
|
||||
f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
|
||||
)
|
||||
kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
|
||||
return self.processor(self, hidden_states, attention_mask, image_rotary_emb, **kwargs)
|
||||
|
||||
|
||||
class Flux2SingleTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
mlp_ratio: float = 3.0,
|
||||
eps: float = 1e-6,
|
||||
bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
|
||||
|
||||
# Note that the MLP in/out linear layers are fused with the attention QKV/out projections, respectively; this
|
||||
# is often called a "parallel" transformer block. See the [ViT-22B paper](https://arxiv.org/abs/2302.05442)
|
||||
# for a visual depiction of this type of transformer block.
|
||||
self.attn = Flux2ParallelSelfAttention(
|
||||
query_dim=dim,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
out_dim=dim,
|
||||
bias=bias,
|
||||
out_bias=bias,
|
||||
eps=eps,
|
||||
mlp_ratio=mlp_ratio,
|
||||
mlp_mult_factor=2,
|
||||
processor=Flux2ParallelSelfAttnProcessor(),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor],
|
||||
temb_mod_params: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
split_hidden_states: bool = False,
|
||||
text_seq_len: Optional[int] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# If encoder_hidden_states is None, hidden_states is assumed to have encoder_hidden_states already
|
||||
# concatenated
|
||||
if encoder_hidden_states is not None:
|
||||
text_seq_len = encoder_hidden_states.shape[1]
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
mod_shift, mod_scale, mod_gate = temb_mod_params
|
||||
|
||||
norm_hidden_states = self.norm(hidden_states)
|
||||
norm_hidden_states = (1 + mod_scale) * norm_hidden_states + mod_shift
|
||||
|
||||
joint_attention_kwargs = joint_attention_kwargs or {}
|
||||
attn_output = self.attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
**joint_attention_kwargs,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states + mod_gate * attn_output
|
||||
if hidden_states.dtype == torch.float16:
|
||||
hidden_states = hidden_states.clip(-65504, 65504)
|
||||
|
||||
if split_hidden_states:
|
||||
encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:]
|
||||
return encoder_hidden_states, hidden_states
|
||||
else:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Flux2TransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
mlp_ratio: float = 3.0,
|
||||
eps: float = 1e-6,
|
||||
bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
|
||||
self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
|
||||
self.norm1_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
|
||||
|
||||
self.attn = Flux2Attention(
|
||||
query_dim=dim,
|
||||
added_kv_proj_dim=dim,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
out_dim=dim,
|
||||
bias=bias,
|
||||
added_proj_bias=bias,
|
||||
out_bias=bias,
|
||||
eps=eps,
|
||||
processor=Flux2AttnProcessor(),
|
||||
)
|
||||
|
||||
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
|
||||
self.ff = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias)
|
||||
|
||||
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
|
||||
self.ff_context = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb_mod_params_img: Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...],
|
||||
temb_mod_params_txt: Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...],
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
joint_attention_kwargs = joint_attention_kwargs or {}
|
||||
|
||||
# Modulation parameters shape: [1, 1, self.dim]
|
||||
(shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp) = temb_mod_params_img
|
||||
(c_shift_msa, c_scale_msa, c_gate_msa), (c_shift_mlp, c_scale_mlp, c_gate_mlp) = temb_mod_params_txt
|
||||
|
||||
# Img stream
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
norm_hidden_states = (1 + scale_msa) * norm_hidden_states + shift_msa
|
||||
|
||||
# Conditioning txt stream
|
||||
norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states)
|
||||
norm_encoder_hidden_states = (1 + c_scale_msa) * norm_encoder_hidden_states + c_shift_msa
|
||||
|
||||
# Attention on concatenated img + txt stream
|
||||
attention_outputs = self.attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
**joint_attention_kwargs,
|
||||
)
|
||||
|
||||
attn_output, context_attn_output = attention_outputs
|
||||
|
||||
# Process attention outputs for the image stream (`hidden_states`).
|
||||
attn_output = gate_msa * attn_output
|
||||
hidden_states = hidden_states + attn_output
|
||||
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
|
||||
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
hidden_states = hidden_states + gate_mlp * ff_output
|
||||
|
||||
# Process attention outputs for the text stream (`encoder_hidden_states`).
|
||||
context_attn_output = c_gate_msa * context_attn_output
|
||||
encoder_hidden_states = encoder_hidden_states + context_attn_output
|
||||
|
||||
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
|
||||
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp
|
||||
|
||||
context_ff_output = self.ff_context(norm_encoder_hidden_states)
|
||||
encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output
|
||||
if encoder_hidden_states.dtype == torch.float16:
|
||||
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
|
||||
|
||||
return encoder_hidden_states, hidden_states
|
||||
|
||||
|
||||
class Flux2PosEmbed(nn.Module):
|
||||
# modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
|
||||
def __init__(self, theta: int, axes_dim: List[int]):
|
||||
super().__init__()
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
|
||||
def forward(self, ids: torch.Tensor) -> torch.Tensor:
|
||||
# Expected ids shape: [S, len(self.axes_dim)]
|
||||
cos_out = []
|
||||
sin_out = []
|
||||
pos = ids.float()
|
||||
is_mps = ids.device.type == "mps"
|
||||
is_npu = ids.device.type == "npu"
|
||||
freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
|
||||
# Unlike Flux 1, loop over len(self.axes_dim) rather than ids.shape[-1]
|
||||
for i in range(len(self.axes_dim)):
|
||||
cos, sin = get_1d_rotary_pos_embed(
|
||||
self.axes_dim[i],
|
||||
pos[..., i],
|
||||
theta=self.theta,
|
||||
repeat_interleave_real=True,
|
||||
use_real=True,
|
||||
freqs_dtype=freqs_dtype,
|
||||
)
|
||||
cos_out.append(cos)
|
||||
sin_out.append(sin)
|
||||
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
|
||||
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
|
||||
return freqs_cos, freqs_sin
|
||||
|
||||
|
||||
class Flux2TimestepGuidanceEmbeddings(nn.Module):
|
||||
def __init__(self, in_channels: int = 256, embedding_dim: int = 6144, bias: bool = False):
|
||||
super().__init__()
|
||||
|
||||
self.time_proj = Timesteps(num_channels=in_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.timestep_embedder = TimestepEmbedding(
|
||||
in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
|
||||
)
|
||||
|
||||
self.guidance_embedder = TimestepEmbedding(
|
||||
in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
|
||||
)
|
||||
|
||||
def forward(self, timestep: torch.Tensor, guidance: torch.Tensor) -> torch.Tensor:
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(timestep.dtype)) # (N, D)
|
||||
|
||||
guidance_proj = self.time_proj(guidance)
|
||||
guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype)) # (N, D)
|
||||
|
||||
time_guidance_emb = timesteps_emb + guidance_emb
|
||||
|
||||
return time_guidance_emb
|
||||
|
||||
|
||||
class Flux2Modulation(nn.Module):
|
||||
def __init__(self, dim: int, mod_param_sets: int = 2, bias: bool = False):
|
||||
super().__init__()
|
||||
self.mod_param_sets = mod_param_sets
|
||||
|
||||
self.linear = nn.Linear(dim, dim * 3 * self.mod_param_sets, bias=bias)
|
||||
self.act_fn = nn.SiLU()
|
||||
|
||||
def forward(self, temb: torch.Tensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]:
|
||||
mod = self.act_fn(temb)
|
||||
mod = self.linear(mod)
|
||||
|
||||
if mod.ndim == 2:
|
||||
mod = mod.unsqueeze(1)
|
||||
mod_params = torch.chunk(mod, 3 * self.mod_param_sets, dim=-1)
|
||||
# Return tuple of 3-tuples of modulation params shift/scale/gate
|
||||
return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(self.mod_param_sets))
|
||||
|
||||
|
||||
class Flux2Transformer2DModel(
|
||||
ModelMixin,
|
||||
ConfigMixin,
|
||||
PeftAdapterMixin,
|
||||
FromOriginalModelMixin,
|
||||
FluxTransformer2DLoadersMixin,
|
||||
CacheMixin,
|
||||
AttentionMixin,
|
||||
):
|
||||
"""
|
||||
The Transformer model introduced in Flux 2.
|
||||
|
||||
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
|
||||
|
||||
Args:
|
||||
patch_size (`int`, defaults to `1`):
|
||||
Patch size to turn the input data into small patches.
|
||||
in_channels (`int`, defaults to `128`):
|
||||
The number of channels in the input.
|
||||
out_channels (`int`, *optional*, defaults to `None`):
|
||||
The number of channels in the output. If not specified, it defaults to `in_channels`.
|
||||
num_layers (`int`, defaults to `8`):
|
||||
The number of layers of dual stream DiT blocks to use.
|
||||
num_single_layers (`int`, defaults to `48`):
|
||||
The number of layers of single stream DiT blocks to use.
|
||||
attention_head_dim (`int`, defaults to `128`):
|
||||
The number of dimensions to use for each attention head.
|
||||
num_attention_heads (`int`, defaults to `48`):
|
||||
The number of attention heads to use.
|
||||
joint_attention_dim (`int`, defaults to `15360`):
|
||||
The number of dimensions to use for the joint attention (embedding/channel dimension of
|
||||
`encoder_hidden_states`).
|
||||
pooled_projection_dim (`int`, defaults to `768`):
|
||||
The number of dimensions to use for the pooled projection.
|
||||
guidance_embeds (`bool`, defaults to `True`):
|
||||
Whether to use guidance embeddings for guidance-distilled variant of the model.
|
||||
axes_dims_rope (`Tuple[int]`, defaults to `(32, 32, 32, 32)`):
|
||||
The dimensions to use for the rotary positional embeddings.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"]
|
||||
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
|
||||
_repeated_blocks = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"]
|
||||
_cp_plan = {
|
||||
"": {
|
||||
"hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
"encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
"img_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False),
|
||||
"txt_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False),
|
||||
},
|
||||
"proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
|
||||
}
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 1,
|
||||
in_channels: int = 128,
|
||||
out_channels: Optional[int] = None,
|
||||
num_layers: int = 8,
|
||||
num_single_layers: int = 48,
|
||||
attention_head_dim: int = 128,
|
||||
num_attention_heads: int = 48,
|
||||
joint_attention_dim: int = 15360,
|
||||
timestep_guidance_channels: int = 256,
|
||||
mlp_ratio: float = 3.0,
|
||||
axes_dims_rope: Tuple[int, ...] = (32, 32, 32, 32),
|
||||
rope_theta: int = 2000,
|
||||
eps: float = 1e-6,
|
||||
):
|
||||
super().__init__()
|
||||
self.out_channels = out_channels or in_channels
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
# 1. Sinusoidal positional embedding for RoPE on image and text tokens
|
||||
self.pos_embed = Flux2PosEmbed(theta=rope_theta, axes_dim=axes_dims_rope)
|
||||
|
||||
# 2. Combined timestep + guidance embedding
|
||||
self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings(
|
||||
in_channels=timestep_guidance_channels, embedding_dim=self.inner_dim, bias=False
|
||||
)
|
||||
|
||||
# 3. Modulation (double stream and single stream blocks share modulation parameters, resp.)
|
||||
# Two sets of shift/scale/gate modulation parameters for the double stream attn and FF sub-blocks
|
||||
self.double_stream_modulation_img = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False)
|
||||
self.double_stream_modulation_txt = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False)
|
||||
# Only one set of modulation parameters as the attn and FF sub-blocks are run in parallel for single stream
|
||||
self.single_stream_modulation = Flux2Modulation(self.inner_dim, mod_param_sets=1, bias=False)
|
||||
|
||||
# 4. Input projections
|
||||
self.x_embedder = nn.Linear(in_channels, self.inner_dim, bias=False)
|
||||
self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim, bias=False)
|
||||
|
||||
# 5. Double Stream Transformer Blocks
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
Flux2TransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
mlp_ratio=mlp_ratio,
|
||||
eps=eps,
|
||||
bias=False,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 6. Single Stream Transformer Blocks
|
||||
self.single_transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
Flux2SingleTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
mlp_ratio=mlp_ratio,
|
||||
eps=eps,
|
||||
bias=False,
|
||||
)
|
||||
for _ in range(num_single_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 7. Output layers
|
||||
self.norm_out = AdaLayerNormContinuous(
|
||||
self.inner_dim, self.inner_dim, elementwise_affine=False, eps=eps, bias=False
|
||||
)
|
||||
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
timestep: torch.LongTensor = None,
|
||||
img_ids: torch.Tensor = None,
|
||||
txt_ids: torch.Tensor = None,
|
||||
guidance: torch.Tensor = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.Tensor, Transformer2DModelOutput]:
|
||||
"""
|
||||
The [`FluxTransformer2DModel`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
|
||||
Input `hidden_states`.
|
||||
encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
|
||||
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
||||
timestep ( `torch.LongTensor`):
|
||||
Used to indicate denoising step.
|
||||
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
|
||||
A list of tensors that if specified are added to the residuals of transformer blocks.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
# 0. Handle input arguments
|
||||
if joint_attention_kwargs is not None:
|
||||
joint_attention_kwargs = joint_attention_kwargs.copy()
|
||||
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
num_txt_tokens = encoder_hidden_states.shape[1]
|
||||
|
||||
# 1. Calculate timestep embedding and modulation parameters
|
||||
timestep = timestep.to(hidden_states.dtype) * 1000
|
||||
guidance = guidance.to(hidden_states.dtype) * 1000
|
||||
|
||||
temb = self.time_guidance_embed(timestep, guidance)
|
||||
|
||||
double_stream_mod_img = self.double_stream_modulation_img(temb)
|
||||
double_stream_mod_txt = self.double_stream_modulation_txt(temb)
|
||||
single_stream_mod = self.single_stream_modulation(temb)[0]
|
||||
|
||||
# 2. Input projection for image (hidden_states) and conditioning text (encoder_hidden_states)
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
||||
|
||||
# 3. Calculate RoPE embeddings from image and text tokens
|
||||
# NOTE: the below logic means that we can't support batched inference with images of different resolutions or
|
||||
# text prompts of differents lengths. Is this a use case we want to support?
|
||||
if img_ids.ndim == 3:
|
||||
img_ids = img_ids[0]
|
||||
if txt_ids.ndim == 3:
|
||||
txt_ids = txt_ids[0]
|
||||
|
||||
if is_torch_npu_available():
|
||||
freqs_cos_image, freqs_sin_image = self.pos_embed(img_ids.cpu())
|
||||
image_rotary_emb = (freqs_cos_image.npu(), freqs_sin_image.npu())
|
||||
freqs_cos_text, freqs_sin_text = self.pos_embed(txt_ids.cpu())
|
||||
text_rotary_emb = (freqs_cos_text.npu(), freqs_sin_text.npu())
|
||||
else:
|
||||
image_rotary_emb = self.pos_embed(img_ids)
|
||||
text_rotary_emb = self.pos_embed(txt_ids)
|
||||
concat_rotary_emb = (
|
||||
torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),
|
||||
torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
|
||||
)
|
||||
|
||||
# 4. Double Stream Transformer Blocks
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
double_stream_mod_img,
|
||||
double_stream_mod_txt,
|
||||
concat_rotary_emb,
|
||||
joint_attention_kwargs,
|
||||
)
|
||||
else:
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb_mod_params_img=double_stream_mod_img,
|
||||
temb_mod_params_txt=double_stream_mod_txt,
|
||||
image_rotary_emb=concat_rotary_emb,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
)
|
||||
# Concatenate text and image streams for single-block inference
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
# 5. Single Stream Transformer Blocks
|
||||
for index_block, block in enumerate(self.single_transformer_blocks):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states,
|
||||
None,
|
||||
single_stream_mod,
|
||||
concat_rotary_emb,
|
||||
joint_attention_kwargs,
|
||||
)
|
||||
else:
|
||||
hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
temb_mod_params=single_stream_mod,
|
||||
image_rotary_emb=concat_rotary_emb,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
)
|
||||
# Remove text tokens from concatenated stream
|
||||
hidden_states = hidden_states[:, num_txt_tokens:, ...]
|
||||
|
||||
# 6. Output layers
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
output = self.proj_out(hidden_states)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
647
src/diffusers/models/transformers/transformer_z_image.py
Normal file
647
src/diffusers/models/transformers/transformer_z_image.py
Normal file
@@ -0,0 +1,647 @@
|
||||
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...models.attention_processor import Attention
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
from ...models.normalization import RMSNorm
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
|
||||
|
||||
ADALN_EMBED_DIM = 256
|
||||
SEQ_MULTI_OF = 32
|
||||
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
def __init__(self, out_size, mid_size=None, frequency_embedding_size=256):
|
||||
super().__init__()
|
||||
if mid_size is None:
|
||||
mid_size = out_size
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(
|
||||
frequency_embedding_size,
|
||||
mid_size,
|
||||
bias=True,
|
||||
),
|
||||
nn.SiLU(),
|
||||
nn.Linear(
|
||||
mid_size,
|
||||
out_size,
|
||||
bias=True,
|
||||
),
|
||||
)
|
||||
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
@staticmethod
|
||||
def timestep_embedding(t, dim, max_period=10000):
|
||||
with torch.amp.autocast("cuda", enabled=False):
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
|
||||
)
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding
|
||||
|
||||
def forward(self, t):
|
||||
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
||||
t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype))
|
||||
return t_emb
|
||||
|
||||
|
||||
class ZSingleStreamAttnProcessor:
|
||||
"""
|
||||
Processor for Z-Image single stream attention that adapts the existing Attention class to match the behavior of the
|
||||
original Z-ImageAttention module.
|
||||
"""
|
||||
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(
|
||||
"ZSingleStreamAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
freqs_cis: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
|
||||
query = query.unflatten(-1, (attn.heads, -1))
|
||||
key = key.unflatten(-1, (attn.heads, -1))
|
||||
value = value.unflatten(-1, (attn.heads, -1))
|
||||
|
||||
# Apply Norms
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# Apply RoPE
|
||||
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
||||
with torch.amp.autocast("cuda", enabled=False):
|
||||
x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
|
||||
freqs_cis = freqs_cis.unsqueeze(2)
|
||||
x_out = torch.view_as_real(x * freqs_cis).flatten(3)
|
||||
return x_out.type_as(x_in) # todo
|
||||
|
||||
if freqs_cis is not None:
|
||||
query = apply_rotary_emb(query, freqs_cis)
|
||||
key = apply_rotary_emb(key, freqs_cis)
|
||||
|
||||
# Cast to correct dtype
|
||||
dtype = query.dtype
|
||||
query, key = query.to(dtype), key.to(dtype)
|
||||
|
||||
# Compute joint attention
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
|
||||
# Reshape back
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.to(dtype)
|
||||
|
||||
output = attn.to_out[0](hidden_states)
|
||||
if len(attn.to_out) > 1: # dropout
|
||||
output = attn.to_out[1](output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim: int, hidden_dim: int):
|
||||
super().__init__()
|
||||
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
|
||||
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
|
||||
def _forward_silu_gating(self, x1, x3):
|
||||
return F.silu(x1) * x3
|
||||
|
||||
def forward(self, x):
|
||||
return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class ZImageTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
layer_id: int,
|
||||
dim: int,
|
||||
n_heads: int,
|
||||
n_kv_heads: int,
|
||||
norm_eps: float,
|
||||
qk_norm: bool,
|
||||
modulation=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.head_dim = dim // n_heads
|
||||
|
||||
# Refactored to use diffusers Attention with custom processor
|
||||
# Original Z-Image params: dim, n_heads, n_kv_heads, qk_norm
|
||||
self.attention = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=None,
|
||||
dim_head=dim // n_heads,
|
||||
heads=n_heads,
|
||||
qk_norm="rms_norm" if qk_norm else None,
|
||||
eps=1e-5,
|
||||
bias=False,
|
||||
out_bias=False,
|
||||
processor=ZSingleStreamAttnProcessor(),
|
||||
)
|
||||
|
||||
self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8))
|
||||
self.layer_id = layer_id
|
||||
|
||||
self.attention_norm1 = RMSNorm(dim, eps=norm_eps)
|
||||
self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
|
||||
|
||||
self.attention_norm2 = RMSNorm(dim, eps=norm_eps)
|
||||
self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)
|
||||
|
||||
self.modulation = modulation
|
||||
if modulation:
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
attn_mask: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
adaln_input: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if self.modulation:
|
||||
assert adaln_input is not None
|
||||
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2)
|
||||
gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
|
||||
scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
|
||||
|
||||
# Attention block
|
||||
attn_out = self.attention(
|
||||
self.attention_norm1(x) * scale_msa,
|
||||
attention_mask=attn_mask,
|
||||
freqs_cis=freqs_cis,
|
||||
)
|
||||
x = x + gate_msa * self.attention_norm2(attn_out)
|
||||
|
||||
# FFN block
|
||||
x = x + gate_mlp * self.ffn_norm2(
|
||||
self.feed_forward(
|
||||
self.ffn_norm1(x) * scale_mlp,
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Attention block
|
||||
attn_out = self.attention(
|
||||
self.attention_norm1(x),
|
||||
attention_mask=attn_mask,
|
||||
freqs_cis=freqs_cis,
|
||||
)
|
||||
x = x + self.attention_norm2(attn_out)
|
||||
|
||||
# FFN block
|
||||
x = x + self.ffn_norm2(
|
||||
self.feed_forward(
|
||||
self.ffn_norm1(x),
|
||||
)
|
||||
)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
def __init__(self, hidden_size, out_channels):
|
||||
super().__init__()
|
||||
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
|
||||
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True),
|
||||
)
|
||||
|
||||
def forward(self, x, c):
|
||||
scale = 1.0 + self.adaLN_modulation(c)
|
||||
x = self.norm_final(x) * scale.unsqueeze(1)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
class RopeEmbedder:
|
||||
def __init__(
|
||||
self,
|
||||
theta: float = 256.0,
|
||||
axes_dims: List[int] = (16, 56, 56),
|
||||
axes_lens: List[int] = (64, 128, 128),
|
||||
):
|
||||
self.theta = theta
|
||||
self.axes_dims = axes_dims
|
||||
self.axes_lens = axes_lens
|
||||
assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must have the same length"
|
||||
self.freqs_cis = None
|
||||
|
||||
@staticmethod
|
||||
def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = 256.0):
|
||||
with torch.device("cpu"):
|
||||
freqs_cis = []
|
||||
for i, (d, e) in enumerate(zip(dim, end)):
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d))
|
||||
timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)
|
||||
freqs = torch.outer(timestep, freqs).float()
|
||||
freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64) # complex64
|
||||
freqs_cis.append(freqs_cis_i)
|
||||
|
||||
return freqs_cis
|
||||
|
||||
def __call__(self, ids: torch.Tensor):
|
||||
assert ids.ndim == 2
|
||||
assert ids.shape[-1] == len(self.axes_dims)
|
||||
device = ids.device
|
||||
|
||||
if self.freqs_cis is None:
|
||||
self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
|
||||
self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
|
||||
|
||||
result = []
|
||||
for i in range(len(self.axes_dims)):
|
||||
index = ids[:, i]
|
||||
result.append(self.freqs_cis[i][index])
|
||||
return torch.cat(result, dim=-1)
|
||||
|
||||
|
||||
class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
||||
_supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["ZImageTransformerBlock"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
all_patch_size=(2,),
|
||||
all_f_patch_size=(1,),
|
||||
in_channels=16,
|
||||
dim=3840,
|
||||
n_layers=30,
|
||||
n_refiner_layers=2,
|
||||
n_heads=30,
|
||||
n_kv_heads=30,
|
||||
norm_eps=1e-5,
|
||||
qk_norm=True,
|
||||
cap_feat_dim=2560,
|
||||
rope_theta=256.0,
|
||||
t_scale=1000.0,
|
||||
axes_dims=[32, 48, 48],
|
||||
axes_lens=[1024, 512, 512],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels
|
||||
self.all_patch_size = all_patch_size
|
||||
self.all_f_patch_size = all_f_patch_size
|
||||
self.dim = dim
|
||||
self.n_heads = n_heads
|
||||
|
||||
self.rope_theta = rope_theta
|
||||
self.t_scale = t_scale
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
assert len(all_patch_size) == len(all_f_patch_size)
|
||||
|
||||
all_x_embedder = {}
|
||||
all_final_layer = {}
|
||||
for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)):
|
||||
x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * in_channels, dim, bias=True)
|
||||
all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
|
||||
|
||||
final_layer = FinalLayer(dim, patch_size * patch_size * f_patch_size * self.out_channels)
|
||||
all_final_layer[f"{patch_size}-{f_patch_size}"] = final_layer
|
||||
|
||||
self.all_x_embedder = nn.ModuleDict(all_x_embedder)
|
||||
self.all_final_layer = nn.ModuleDict(all_final_layer)
|
||||
self.noise_refiner = nn.ModuleList(
|
||||
[
|
||||
ZImageTransformerBlock(
|
||||
1000 + layer_id,
|
||||
dim,
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
modulation=True,
|
||||
)
|
||||
for layer_id in range(n_refiner_layers)
|
||||
]
|
||||
)
|
||||
self.context_refiner = nn.ModuleList(
|
||||
[
|
||||
ZImageTransformerBlock(
|
||||
layer_id,
|
||||
dim,
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
modulation=False,
|
||||
)
|
||||
for layer_id in range(n_refiner_layers)
|
||||
]
|
||||
)
|
||||
self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024)
|
||||
self.cap_embedder = nn.Sequential(
|
||||
RMSNorm(cap_feat_dim, eps=norm_eps),
|
||||
nn.Linear(cap_feat_dim, dim, bias=True),
|
||||
)
|
||||
|
||||
self.x_pad_token = nn.Parameter(torch.empty((1, dim)))
|
||||
self.cap_pad_token = nn.Parameter(torch.empty((1, dim)))
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
ZImageTransformerBlock(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm)
|
||||
for layer_id in range(n_layers)
|
||||
]
|
||||
)
|
||||
head_dim = dim // n_heads
|
||||
assert head_dim == sum(axes_dims)
|
||||
self.axes_dims = axes_dims
|
||||
self.axes_lens = axes_lens
|
||||
|
||||
self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens)
|
||||
|
||||
def unpatchify(self, x: List[torch.Tensor], size: List[Tuple], patch_size, f_patch_size) -> List[torch.Tensor]:
|
||||
pH = pW = patch_size
|
||||
pF = f_patch_size
|
||||
bsz = len(x)
|
||||
assert len(size) == bsz
|
||||
for i in range(bsz):
|
||||
F, H, W = size[i]
|
||||
ori_len = (F // pF) * (H // pH) * (W // pW)
|
||||
# "f h w pf ph pw c -> c (f pf) (h ph) (w pw)"
|
||||
x[i] = (
|
||||
x[i][:ori_len]
|
||||
.view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels)
|
||||
.permute(6, 0, 3, 1, 4, 2, 5)
|
||||
.reshape(self.out_channels, F, H, W)
|
||||
)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def create_coordinate_grid(size, start=None, device=None):
|
||||
if start is None:
|
||||
start = (0 for _ in size)
|
||||
|
||||
axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)]
|
||||
grids = torch.meshgrid(axes, indexing="ij")
|
||||
return torch.stack(grids, dim=-1)
|
||||
|
||||
def patchify_and_embed(
|
||||
self,
|
||||
all_image: List[torch.Tensor],
|
||||
all_cap_feats: List[torch.Tensor],
|
||||
patch_size: int,
|
||||
f_patch_size: int,
|
||||
):
|
||||
pH = pW = patch_size
|
||||
pF = f_patch_size
|
||||
device = all_image[0].device
|
||||
|
||||
all_image_out = []
|
||||
all_image_size = []
|
||||
all_image_pos_ids = []
|
||||
all_image_pad_mask = []
|
||||
all_cap_pos_ids = []
|
||||
all_cap_pad_mask = []
|
||||
all_cap_feats_out = []
|
||||
|
||||
for i, (image, cap_feat) in enumerate(zip(all_image, all_cap_feats)):
|
||||
### Process Caption
|
||||
cap_ori_len = len(cap_feat)
|
||||
cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF
|
||||
# padded position ids
|
||||
cap_padded_pos_ids = self.create_coordinate_grid(
|
||||
size=(cap_ori_len + cap_padding_len, 1, 1),
|
||||
start=(1, 0, 0),
|
||||
device=device,
|
||||
).flatten(0, 2)
|
||||
all_cap_pos_ids.append(cap_padded_pos_ids)
|
||||
# pad mask
|
||||
all_cap_pad_mask.append(
|
||||
torch.cat(
|
||||
[
|
||||
torch.zeros((cap_ori_len,), dtype=torch.bool, device=device),
|
||||
torch.ones((cap_padding_len,), dtype=torch.bool, device=device),
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
)
|
||||
# padded feature
|
||||
cap_padded_feat = torch.cat(
|
||||
[cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)],
|
||||
dim=0,
|
||||
)
|
||||
all_cap_feats_out.append(cap_padded_feat)
|
||||
|
||||
### Process Image
|
||||
C, F, H, W = image.size()
|
||||
all_image_size.append((F, H, W))
|
||||
F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
|
||||
|
||||
image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
|
||||
# "c f pf h ph w pw -> (f h w) (pf ph pw c)"
|
||||
image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
|
||||
|
||||
image_ori_len = len(image)
|
||||
image_padding_len = (-image_ori_len) % SEQ_MULTI_OF
|
||||
|
||||
image_ori_pos_ids = self.create_coordinate_grid(
|
||||
size=(F_tokens, H_tokens, W_tokens),
|
||||
start=(cap_ori_len + cap_padding_len + 1, 0, 0),
|
||||
device=device,
|
||||
).flatten(0, 2)
|
||||
image_padding_pos_ids = (
|
||||
self.create_coordinate_grid(
|
||||
size=(1, 1, 1),
|
||||
start=(0, 0, 0),
|
||||
device=device,
|
||||
)
|
||||
.flatten(0, 2)
|
||||
.repeat(image_padding_len, 1)
|
||||
)
|
||||
image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0)
|
||||
all_image_pos_ids.append(image_padded_pos_ids)
|
||||
# pad mask
|
||||
all_image_pad_mask.append(
|
||||
torch.cat(
|
||||
[
|
||||
torch.zeros((image_ori_len,), dtype=torch.bool, device=device),
|
||||
torch.ones((image_padding_len,), dtype=torch.bool, device=device),
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
)
|
||||
# padded feature
|
||||
image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0)
|
||||
all_image_out.append(image_padded_feat)
|
||||
|
||||
return (
|
||||
all_image_out,
|
||||
all_cap_feats_out,
|
||||
all_image_size,
|
||||
all_image_pos_ids,
|
||||
all_cap_pos_ids,
|
||||
all_image_pad_mask,
|
||||
all_cap_pad_mask,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: List[torch.Tensor],
|
||||
t,
|
||||
cap_feats: List[torch.Tensor],
|
||||
patch_size=2,
|
||||
f_patch_size=1,
|
||||
):
|
||||
assert patch_size in self.all_patch_size
|
||||
assert f_patch_size in self.all_f_patch_size
|
||||
|
||||
bsz = len(x)
|
||||
device = x[0].device
|
||||
t = t * self.t_scale
|
||||
t = self.t_embedder(t)
|
||||
|
||||
adaln_input = t
|
||||
|
||||
(
|
||||
x,
|
||||
cap_feats,
|
||||
x_size,
|
||||
x_pos_ids,
|
||||
cap_pos_ids,
|
||||
x_inner_pad_mask,
|
||||
cap_inner_pad_mask,
|
||||
) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)
|
||||
|
||||
# x embed & refine
|
||||
x_item_seqlens = [len(_) for _ in x]
|
||||
assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens)
|
||||
x_max_item_seqlen = max(x_item_seqlens)
|
||||
|
||||
x = torch.cat(x, dim=0)
|
||||
x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x)
|
||||
x[torch.cat(x_inner_pad_mask)] = self.x_pad_token
|
||||
x = list(x.split(x_item_seqlens, dim=0))
|
||||
x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))
|
||||
|
||||
x = pad_sequence(x, batch_first=True, padding_value=0.0)
|
||||
x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
|
||||
x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(x_item_seqlens):
|
||||
x_attn_mask[i, :seq_len] = 1
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
for layer in self.noise_refiner:
|
||||
x = self._gradient_checkpointing_func(layer, x, x_attn_mask, x_freqs_cis, adaln_input)
|
||||
else:
|
||||
for layer in self.noise_refiner:
|
||||
x = layer(x, x_attn_mask, x_freqs_cis, adaln_input)
|
||||
|
||||
# cap embed & refine
|
||||
cap_item_seqlens = [len(_) for _ in cap_feats]
|
||||
assert all(_ % SEQ_MULTI_OF == 0 for _ in cap_item_seqlens)
|
||||
cap_max_item_seqlen = max(cap_item_seqlens)
|
||||
|
||||
cap_feats = torch.cat(cap_feats, dim=0)
|
||||
cap_feats = self.cap_embedder(cap_feats)
|
||||
cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token
|
||||
cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
|
||||
cap_freqs_cis = list(self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0))
|
||||
|
||||
cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0)
|
||||
cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
|
||||
cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(cap_item_seqlens):
|
||||
cap_attn_mask[i, :seq_len] = 1
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
for layer in self.context_refiner:
|
||||
cap_feats = self._gradient_checkpointing_func(layer, cap_feats, cap_attn_mask, cap_freqs_cis)
|
||||
else:
|
||||
for layer in self.context_refiner:
|
||||
cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis)
|
||||
|
||||
# unified
|
||||
unified = []
|
||||
unified_freqs_cis = []
|
||||
for i in range(bsz):
|
||||
x_len = x_item_seqlens[i]
|
||||
cap_len = cap_item_seqlens[i]
|
||||
unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]]))
|
||||
unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]]))
|
||||
unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)]
|
||||
assert unified_item_seqlens == [len(_) for _ in unified]
|
||||
unified_max_item_seqlen = max(unified_item_seqlens)
|
||||
|
||||
unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
|
||||
unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0)
|
||||
unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(unified_item_seqlens):
|
||||
unified_attn_mask[i, :seq_len] = 1
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
for layer in self.layers:
|
||||
unified = self._gradient_checkpointing_func(
|
||||
layer, unified, unified_attn_mask, unified_freqs_cis, adaln_input
|
||||
)
|
||||
else:
|
||||
for layer in self.layers:
|
||||
unified = layer(unified, unified_attn_mask, unified_freqs_cis, adaln_input)
|
||||
|
||||
unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input)
|
||||
unified = list(unified.unbind(dim=0))
|
||||
x = self.unpatchify(unified, x_size, patch_size, f_patch_size)
|
||||
|
||||
return x, {}
|
||||
@@ -610,6 +610,7 @@ class QwenImageEditRoPEInputsStep(ModularPipelineBlocks):
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
# for edit, image size can be different from the target size (height/width)
|
||||
|
||||
block_state.img_shapes = [
|
||||
[
|
||||
(
|
||||
@@ -639,37 +640,6 @@ class QwenImageEditRoPEInputsStep(ModularPipelineBlocks):
|
||||
return components, state
|
||||
|
||||
|
||||
class QwenImageEditPlusRoPEInputsStep(QwenImageEditRoPEInputsStep):
|
||||
model_name = "qwenimage-edit-plus"
|
||||
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
vae_scale_factor = components.vae_scale_factor
|
||||
block_state.img_shapes = [
|
||||
[
|
||||
(1, block_state.height // vae_scale_factor // 2, block_state.width // vae_scale_factor // 2),
|
||||
*[
|
||||
(1, vae_height // vae_scale_factor // 2, vae_width // vae_scale_factor // 2)
|
||||
for vae_height, vae_width in zip(block_state.image_height, block_state.image_width)
|
||||
],
|
||||
]
|
||||
] * block_state.batch_size
|
||||
|
||||
block_state.txt_seq_lens = (
|
||||
block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None
|
||||
)
|
||||
block_state.negative_txt_seq_lens = (
|
||||
block_state.negative_prompt_embeds_mask.sum(dim=1).tolist()
|
||||
if block_state.negative_prompt_embeds_mask is not None
|
||||
else None
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
|
||||
## ControlNet inputs for denoiser
|
||||
class QwenImageControlNetBeforeDenoiserStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@@ -330,7 +330,7 @@ class QwenImageEditPlusResizeDynamicStep(QwenImageEditResizeDynamicStep):
|
||||
output_name: str = "resized_image",
|
||||
vae_image_output_name: str = "vae_image",
|
||||
):
|
||||
"""Create a configurable step for resizing images to the target area (384 * 384) while maintaining the aspect ratio.
|
||||
"""Create a configurable step for resizing images to the target area (1024 * 1024) while maintaining the aspect ratio.
|
||||
|
||||
This block resizes an input image or a list input images and exposes the resized result under configurable
|
||||
input and output names. Use this when you need to wire the resize step to different image fields (e.g.,
|
||||
@@ -809,7 +809,9 @@ class QwenImageProcessImagesInputStep(ModularPipelineBlocks):
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [OutputParam(name="processed_image")]
|
||||
return [
|
||||
OutputParam(name="processed_image"),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def check_inputs(height, width, vae_scale_factor):
|
||||
@@ -849,10 +851,7 @@ class QwenImageProcessImagesInputStep(ModularPipelineBlocks):
|
||||
|
||||
class QwenImageEditPlusProcessImagesInputStep(QwenImageProcessImagesInputStep):
|
||||
model_name = "qwenimage-edit-plus"
|
||||
|
||||
def __init__(self):
|
||||
self.vae_image_size = 1024 * 1024
|
||||
super().__init__()
|
||||
vae_image_size = 1024 * 1024
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
@@ -869,7 +868,6 @@ class QwenImageEditPlusProcessImagesInputStep(QwenImageProcessImagesInputStep):
|
||||
if block_state.vae_image is None and block_state.image is None:
|
||||
raise ValueError("`vae_image` and `image` cannot be None at the same time")
|
||||
|
||||
vae_image_sizes = None
|
||||
if block_state.vae_image is None:
|
||||
image = block_state.image
|
||||
self.check_inputs(
|
||||
@@ -881,19 +879,12 @@ class QwenImageEditPlusProcessImagesInputStep(QwenImageProcessImagesInputStep):
|
||||
image=image, height=height, width=width
|
||||
)
|
||||
else:
|
||||
# QwenImage Edit Plus can allow multiple input images with varied resolutions
|
||||
processed_images = []
|
||||
vae_image_sizes = []
|
||||
for img in block_state.vae_image:
|
||||
width, height = img.size
|
||||
vae_width, vae_height, _ = calculate_dimensions(self.vae_image_size, width / height)
|
||||
vae_image_sizes.append((vae_width, vae_height))
|
||||
processed_images.append(
|
||||
components.image_processor.preprocess(image=img, height=vae_height, width=vae_width)
|
||||
)
|
||||
block_state.processed_image = processed_images
|
||||
width, height = block_state.vae_image[0].size
|
||||
image = block_state.vae_image
|
||||
|
||||
block_state.vae_image_sizes = vae_image_sizes
|
||||
block_state.processed_image = components.image_processor.preprocess(
|
||||
image=image, height=height, width=width
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
@@ -935,12 +926,17 @@ class QwenImageVaeEncoderDynamicStep(ModularPipelineBlocks):
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
components = [ComponentSpec("vae", AutoencoderKLQwenImage)]
|
||||
components = [
|
||||
ComponentSpec("vae", AutoencoderKLQwenImage),
|
||||
]
|
||||
return components
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
inputs = [InputParam(self._image_input_name, required=True), InputParam("generator")]
|
||||
inputs = [
|
||||
InputParam(self._image_input_name, required=True),
|
||||
InputParam("generator"),
|
||||
]
|
||||
return inputs
|
||||
|
||||
@property
|
||||
@@ -978,50 +974,6 @@ class QwenImageVaeEncoderDynamicStep(ModularPipelineBlocks):
|
||||
return components, state
|
||||
|
||||
|
||||
class QwenImageEditPlusVaeEncoderDynamicStep(QwenImageVaeEncoderDynamicStep):
|
||||
model_name = "qwenimage-edit-plus"
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
# Each reference image latent can have varied resolutions hence we return this as a list.
|
||||
return [
|
||||
OutputParam(
|
||||
self._image_latents_output_name,
|
||||
type_hint=List[torch.Tensor],
|
||||
description="The latents representing the reference image(s).",
|
||||
)
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
device = components._execution_device
|
||||
dtype = components.vae.dtype
|
||||
|
||||
image = getattr(block_state, self._image_input_name)
|
||||
|
||||
# Encode image into latents
|
||||
image_latents = []
|
||||
for img in image:
|
||||
image_latents.append(
|
||||
encode_vae_image(
|
||||
image=img,
|
||||
vae=components.vae,
|
||||
generator=block_state.generator,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
latent_channels=components.num_channels_latents,
|
||||
)
|
||||
)
|
||||
|
||||
setattr(block_state, self._image_latents_output_name, image_latents)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
|
||||
class QwenImageControlNetVaeEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
|
||||
@@ -224,7 +224,11 @@ class QwenImageTextInputsStep(ModularPipelineBlocks):
|
||||
class QwenImageInputsDynamicStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
def __init__(self, image_latent_inputs: List[str] = ["image_latents"], additional_batch_inputs: List[str] = []):
|
||||
def __init__(
|
||||
self,
|
||||
image_latent_inputs: List[str] = ["image_latents"],
|
||||
additional_batch_inputs: List[str] = [],
|
||||
):
|
||||
"""Initialize a configurable step that standardizes the inputs for the denoising step. It:\n"
|
||||
|
||||
This step handles multiple common tasks to prepare inputs for the denoising step:
|
||||
@@ -368,76 +372,6 @@ class QwenImageInputsDynamicStep(ModularPipelineBlocks):
|
||||
return components, state
|
||||
|
||||
|
||||
class QwenImageEditPlusInputsDynamicStep(QwenImageInputsDynamicStep):
|
||||
model_name = "qwenimage-edit-plus"
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(name="image_height", type_hint=List[int], description="The height of the image latents"),
|
||||
OutputParam(name="image_width", type_hint=List[int], description="The width of the image latents"),
|
||||
]
|
||||
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
# Process image latent inputs (height/width calculation, patchify, and batch expansion)
|
||||
for image_latent_input_name in self._image_latent_inputs:
|
||||
image_latent_tensor = getattr(block_state, image_latent_input_name)
|
||||
if image_latent_tensor is None:
|
||||
continue
|
||||
|
||||
# Each image latent can have different size in QwenImage Edit Plus.
|
||||
image_heights = []
|
||||
image_widths = []
|
||||
packed_image_latent_tensors = []
|
||||
|
||||
for img_latent_tensor in image_latent_tensor:
|
||||
# 1. Calculate height/width from latents
|
||||
height, width = calculate_dimension_from_latents(img_latent_tensor, components.vae_scale_factor)
|
||||
image_heights.append(height)
|
||||
image_widths.append(width)
|
||||
|
||||
# 2. Patchify the image latent tensor
|
||||
img_latent_tensor = components.pachifier.pack_latents(img_latent_tensor)
|
||||
|
||||
# 3. Expand batch size
|
||||
img_latent_tensor = repeat_tensor_to_batch_size(
|
||||
input_name=image_latent_input_name,
|
||||
input_tensor=img_latent_tensor,
|
||||
num_images_per_prompt=block_state.num_images_per_prompt,
|
||||
batch_size=block_state.batch_size,
|
||||
)
|
||||
packed_image_latent_tensors.append(img_latent_tensor)
|
||||
|
||||
packed_image_latent_tensors = torch.cat(packed_image_latent_tensors, dim=1)
|
||||
block_state.image_height = image_heights
|
||||
block_state.image_width = image_widths
|
||||
setattr(block_state, image_latent_input_name, packed_image_latent_tensors)
|
||||
|
||||
block_state.height = block_state.height or image_heights[-1]
|
||||
block_state.width = block_state.width or image_widths[-1]
|
||||
|
||||
# Process additional batch inputs (only batch expansion)
|
||||
for input_name in self._additional_batch_inputs:
|
||||
input_tensor = getattr(block_state, input_name)
|
||||
if input_tensor is None:
|
||||
continue
|
||||
|
||||
# Only expand batch size
|
||||
input_tensor = repeat_tensor_to_batch_size(
|
||||
input_name=input_name,
|
||||
input_tensor=input_tensor,
|
||||
num_images_per_prompt=block_state.num_images_per_prompt,
|
||||
batch_size=block_state.batch_size,
|
||||
)
|
||||
|
||||
setattr(block_state, input_name, input_tensor)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class QwenImageControlNetInputsStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
|
||||
@@ -18,7 +18,6 @@ from ..modular_pipeline_utils import InsertableDict
|
||||
from .before_denoise import (
|
||||
QwenImageControlNetBeforeDenoiserStep,
|
||||
QwenImageCreateMaskLatentsStep,
|
||||
QwenImageEditPlusRoPEInputsStep,
|
||||
QwenImageEditRoPEInputsStep,
|
||||
QwenImagePrepareLatentsStep,
|
||||
QwenImagePrepareLatentsWithStrengthStep,
|
||||
@@ -41,7 +40,6 @@ from .encoders import (
|
||||
QwenImageEditPlusProcessImagesInputStep,
|
||||
QwenImageEditPlusResizeDynamicStep,
|
||||
QwenImageEditPlusTextEncoderStep,
|
||||
QwenImageEditPlusVaeEncoderDynamicStep,
|
||||
QwenImageEditResizeDynamicStep,
|
||||
QwenImageEditTextEncoderStep,
|
||||
QwenImageInpaintProcessImagesInputStep,
|
||||
@@ -49,12 +47,7 @@ from .encoders import (
|
||||
QwenImageTextEncoderStep,
|
||||
QwenImageVaeEncoderDynamicStep,
|
||||
)
|
||||
from .inputs import (
|
||||
QwenImageControlNetInputsStep,
|
||||
QwenImageEditPlusInputsDynamicStep,
|
||||
QwenImageInputsDynamicStep,
|
||||
QwenImageTextInputsStep,
|
||||
)
|
||||
from .inputs import QwenImageControlNetInputsStep, QwenImageInputsDynamicStep, QwenImageTextInputsStep
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -911,13 +904,13 @@ QwenImageEditPlusVaeEncoderBlocks = InsertableDict(
|
||||
[
|
||||
("resize", QwenImageEditPlusResizeDynamicStep()), # edit plus has a different resize step
|
||||
("preprocess", QwenImageEditPlusProcessImagesInputStep()), # vae_image -> processed_image
|
||||
("encode", QwenImageEditPlusVaeEncoderDynamicStep()), # processed_image -> image_latents
|
||||
("encode", QwenImageVaeEncoderDynamicStep()), # processed_image -> image_latents
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class QwenImageEditPlusVaeEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-edit-plus"
|
||||
model_name = "qwenimage"
|
||||
block_classes = QwenImageEditPlusVaeEncoderBlocks.values()
|
||||
block_names = QwenImageEditPlusVaeEncoderBlocks.keys()
|
||||
|
||||
@@ -926,62 +919,25 @@ class QwenImageEditPlusVaeEncoderStep(SequentialPipelineBlocks):
|
||||
return "Vae encoder step that encode the image inputs into their latent representations."
|
||||
|
||||
|
||||
#### QwenImage Edit Plus input blocks
|
||||
QwenImageEditPlusInputBlocks = InsertableDict(
|
||||
[
|
||||
("text_inputs", QwenImageTextInputsStep()), # default step to process text embeddings
|
||||
(
|
||||
"additional_inputs",
|
||||
QwenImageEditPlusInputsDynamicStep(image_latent_inputs=["image_latents"]),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class QwenImageEditPlusInputStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-edit-plus"
|
||||
block_classes = QwenImageEditPlusInputBlocks.values()
|
||||
block_names = QwenImageEditPlusInputBlocks.keys()
|
||||
|
||||
|
||||
#### QwenImage Edit Plus presets
|
||||
EDIT_PLUS_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", QwenImageEditPlusVLEncoderStep()),
|
||||
("vae_encoder", QwenImageEditPlusVaeEncoderStep()),
|
||||
("input", QwenImageEditPlusInputStep()),
|
||||
("input", QwenImageEditInputStep()),
|
||||
("prepare_latents", QwenImagePrepareLatentsStep()),
|
||||
("set_timesteps", QwenImageSetTimestepsStep()),
|
||||
("prepare_rope_inputs", QwenImageEditPlusRoPEInputsStep()),
|
||||
("prepare_rope_inputs", QwenImageEditRoPEInputsStep()),
|
||||
("denoise", QwenImageEditDenoiseStep()),
|
||||
("decode", QwenImageDecodeStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
QwenImageEditPlusBeforeDenoiseBlocks = InsertableDict(
|
||||
[
|
||||
("prepare_latents", QwenImagePrepareLatentsStep()),
|
||||
("set_timesteps", QwenImageSetTimestepsStep()),
|
||||
("prepare_rope_inputs", QwenImageEditPlusRoPEInputsStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class QwenImageEditPlusBeforeDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-edit-plus"
|
||||
block_classes = QwenImageEditPlusBeforeDenoiseBlocks.values()
|
||||
block_names = QwenImageEditPlusBeforeDenoiseBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for edit task."
|
||||
|
||||
|
||||
# auto before_denoise step for edit tasks
|
||||
class QwenImageEditPlusAutoBeforeDenoiseStep(AutoPipelineBlocks):
|
||||
model_name = "qwenimage-edit-plus"
|
||||
block_classes = [QwenImageEditPlusBeforeDenoiseStep]
|
||||
block_classes = [QwenImageEditBeforeDenoiseStep]
|
||||
block_names = ["edit"]
|
||||
block_trigger_inputs = ["image_latents"]
|
||||
|
||||
@@ -990,7 +946,7 @@ class QwenImageEditPlusAutoBeforeDenoiseStep(AutoPipelineBlocks):
|
||||
return (
|
||||
"Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step.\n"
|
||||
+ "This is an auto pipeline block that works for edit (img2img) task.\n"
|
||||
+ " - `QwenImageEditPlusBeforeDenoiseStep` (edit) is used when `image_latents` is provided and `processed_mask_image` is not provided.\n"
|
||||
+ " - `QwenImageEditBeforeDenoiseStep` (edit) is used when `image_latents` is provided and `processed_mask_image` is not provided.\n"
|
||||
+ " - if `image_latents` is not provided, step will be skipped."
|
||||
)
|
||||
|
||||
@@ -999,7 +955,9 @@ class QwenImageEditPlusAutoBeforeDenoiseStep(AutoPipelineBlocks):
|
||||
|
||||
|
||||
class QwenImageEditPlusAutoVaeEncoderStep(AutoPipelineBlocks):
|
||||
block_classes = [QwenImageEditPlusVaeEncoderStep]
|
||||
block_classes = [
|
||||
QwenImageEditPlusVaeEncoderStep,
|
||||
]
|
||||
block_names = ["edit"]
|
||||
block_trigger_inputs = ["image"]
|
||||
|
||||
@@ -1016,25 +974,10 @@ class QwenImageEditPlusAutoVaeEncoderStep(AutoPipelineBlocks):
|
||||
## 3.3 QwenImage-Edit/auto blocks & presets
|
||||
|
||||
|
||||
class QwenImageEditPlusAutoInputStep(AutoPipelineBlocks):
|
||||
block_classes = [QwenImageEditPlusInputStep]
|
||||
block_names = ["edit"]
|
||||
block_trigger_inputs = ["image_latents"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Input step that prepares the inputs for the edit denoising step.\n"
|
||||
+ " It is an auto pipeline block that works for edit task.\n"
|
||||
+ " - `QwenImageEditPlusInputStep` (edit) is used when `image_latents` is provided.\n"
|
||||
+ " - if `image_latents` is not provided, step will be skipped."
|
||||
)
|
||||
|
||||
|
||||
class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-edit-plus"
|
||||
block_classes = [
|
||||
QwenImageEditPlusAutoInputStep,
|
||||
QwenImageEditAutoInputStep,
|
||||
QwenImageEditPlusAutoBeforeDenoiseStep,
|
||||
QwenImageEditAutoDenoiseStep,
|
||||
]
|
||||
|
||||
@@ -129,6 +129,7 @@ else:
|
||||
]
|
||||
_import_structure["bria"] = ["BriaPipeline"]
|
||||
_import_structure["bria_fibo"] = ["BriaFiboPipeline"]
|
||||
_import_structure["flux2"] = ["Flux2Pipeline"]
|
||||
_import_structure["flux"] = [
|
||||
"FluxControlPipeline",
|
||||
"FluxControlInpaintPipeline",
|
||||
@@ -395,6 +396,7 @@ else:
|
||||
"WanVACEPipeline",
|
||||
"WanAnimatePipeline",
|
||||
]
|
||||
_import_structure["z_image"] = ["ZImagePipeline"]
|
||||
_import_structure["kandinsky5"] = ["Kandinsky5T2VPipeline"]
|
||||
_import_structure["skyreels_v2"] = [
|
||||
"SkyReelsV2DiffusionForcingPipeline",
|
||||
@@ -654,6 +656,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
FluxPriorReduxPipeline,
|
||||
ReduxImageEncoder,
|
||||
)
|
||||
from .flux2 import Flux2Pipeline
|
||||
from .hidream_image import HiDreamImagePipeline
|
||||
from .hunyuan_image import HunyuanImagePipeline, HunyuanImageRefinerPipeline
|
||||
from .hunyuan_video import (
|
||||
@@ -824,6 +827,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
WuerstchenDecoderPipeline,
|
||||
WuerstchenPriorPipeline,
|
||||
)
|
||||
from .z_image import ZImagePipeline
|
||||
|
||||
try:
|
||||
if not is_onnx_available():
|
||||
|
||||
@@ -73,7 +73,7 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
class BriaFiboPipeline(DiffusionPipeline):
|
||||
class BriaFiboPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
||||
r"""
|
||||
Args:
|
||||
transformer (`BriaFiboTransformer2DModel`):
|
||||
|
||||
47
src/diffusers/pipelines/flux2/__init__.py
Normal file
47
src/diffusers/pipelines/flux2/__init__.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_additional_imports = {}
|
||||
_import_structure = {"pipeline_output": ["Flux2PipelineOutput"]}
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_flux2"] = ["Flux2Pipeline"]
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .pipeline_flux2 import Flux2Pipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
for name, value in _additional_imports.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
138
src/diffusers/pipelines/flux2/image_processor.py
Normal file
138
src/diffusers/pipelines/flux2/image_processor.py
Normal file
@@ -0,0 +1,138 @@
|
||||
# Copyright 2025 The Black Forest Labs Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import Tuple
|
||||
|
||||
import PIL.Image
|
||||
|
||||
from ...configuration_utils import register_to_config
|
||||
from ...image_processor import VaeImageProcessor
|
||||
|
||||
|
||||
class Flux2ImageProcessor(VaeImageProcessor):
|
||||
r"""
|
||||
Image processor to preprocess the reference (character) image for the Flux2 model.
|
||||
|
||||
Args:
|
||||
do_resize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
|
||||
`height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
|
||||
vae_scale_factor (`int`, *optional*, defaults to `16`):
|
||||
VAE (spatial) scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of
|
||||
this factor.
|
||||
vae_latent_channels (`int`, *optional*, defaults to `32`):
|
||||
VAE latent channels.
|
||||
do_normalize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to normalize the image to [-1,1].
|
||||
do_convert_rgb (`bool`, *optional*, defaults to be `True`):
|
||||
Whether to convert the images to RGB format.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
do_resize: bool = True,
|
||||
vae_scale_factor: int = 16,
|
||||
vae_latent_channels: int = 32,
|
||||
do_normalize: bool = True,
|
||||
do_convert_rgb: bool = True,
|
||||
):
|
||||
super().__init__(
|
||||
do_resize=do_resize,
|
||||
vae_scale_factor=vae_scale_factor,
|
||||
vae_latent_channels=vae_latent_channels,
|
||||
do_normalize=do_normalize,
|
||||
do_convert_rgb=do_convert_rgb,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def check_image_input(
|
||||
image: PIL.Image.Image, max_aspect_ratio: int = 8, min_side_length: int = 64, max_area: int = 1024 * 1024
|
||||
) -> PIL.Image.Image:
|
||||
"""
|
||||
Check if image meets minimum size and aspect ratio requirements.
|
||||
|
||||
Args:
|
||||
image: PIL Image to validate
|
||||
max_aspect_ratio: Maximum allowed aspect ratio (width/height or height/width)
|
||||
min_side_length: Minimum pixels required for width and height
|
||||
max_area: Maximum allowed area in pixels²
|
||||
|
||||
Returns:
|
||||
The input image if valid
|
||||
|
||||
Raises:
|
||||
ValueError: If image is too small or aspect ratio is too extreme
|
||||
"""
|
||||
if not isinstance(image, PIL.Image.Image):
|
||||
raise ValueError(f"Image must be a PIL.Image.Image, got {type(image)}")
|
||||
|
||||
width, height = image.size
|
||||
|
||||
# Check minimum dimensions
|
||||
if width < min_side_length or height < min_side_length:
|
||||
raise ValueError(
|
||||
f"Image too small: {width}×{height}. Both dimensions must be at least {min_side_length}px"
|
||||
)
|
||||
|
||||
# Check aspect ratio
|
||||
aspect_ratio = max(width / height, height / width)
|
||||
if aspect_ratio > max_aspect_ratio:
|
||||
raise ValueError(
|
||||
f"Aspect ratio too extreme: {width}×{height} (ratio: {aspect_ratio:.1f}:1). "
|
||||
f"Maximum allowed ratio is {max_aspect_ratio}:1"
|
||||
)
|
||||
|
||||
return image
|
||||
|
||||
@staticmethod
|
||||
def _resize_to_target_area(image: PIL.Image.Image, target_area: int = 1024 * 1024) -> Tuple[int, int]:
|
||||
image_width, image_height = image.size
|
||||
|
||||
scale = math.sqrt(target_area / (image_width * image_height))
|
||||
width = int(image_width * scale)
|
||||
height = int(image_height * scale)
|
||||
|
||||
return image.resize((width, height), PIL.Image.Resampling.LANCZOS)
|
||||
|
||||
def _resize_and_crop(
|
||||
self,
|
||||
image: PIL.Image.Image,
|
||||
width: int,
|
||||
height: int,
|
||||
) -> PIL.Image.Image:
|
||||
r"""
|
||||
center crop the image to the specified width and height.
|
||||
|
||||
Args:
|
||||
image (`PIL.Image.Image`):
|
||||
The image to resize and crop.
|
||||
width (`int`):
|
||||
The width to resize the image to.
|
||||
height (`int`):
|
||||
The height to resize the image to.
|
||||
|
||||
Returns:
|
||||
`PIL.Image.Image`:
|
||||
The resized and cropped image.
|
||||
"""
|
||||
image_width, image_height = image.size
|
||||
|
||||
left = (image_width - width) // 2
|
||||
top = (image_height - height) // 2
|
||||
right = left + width
|
||||
bottom = top + height
|
||||
|
||||
return image.crop((left, top, right, bottom))
|
||||
883
src/diffusers/pipelines/flux2/pipeline_flux2.py
Normal file
883
src/diffusers/pipelines/flux2/pipeline_flux2.py
Normal file
@@ -0,0 +1,883 @@
|
||||
# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
from transformers import AutoProcessor, Mistral3ForConditionalGeneration
|
||||
|
||||
from ...loaders import Flux2LoraLoaderMixin
|
||||
from ...models import AutoencoderKLFlux2, Flux2Transformer2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .image_processor import Flux2ImageProcessor
|
||||
from .pipeline_output import Flux2PipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import Flux2Pipeline
|
||||
|
||||
>>> pipe = Flux2Pipeline.from_pretrained("black-forest-labs/FLUX.2-dev", torch_dtype=torch.bfloat16)
|
||||
>>> pipe.to("cuda")
|
||||
>>> prompt = "A cat holding a sign that says hello world"
|
||||
>>> # Depending on the variant being used, the pipeline call will slightly vary.
|
||||
>>> # Refer to the pipeline documentation for more details.
|
||||
>>> image = pipe(prompt, num_inference_steps=50, guidance_scale=2.5).images[0]
|
||||
>>> image.save("flux.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def format_text_input(prompts: List[str], system_message: str = None):
|
||||
# Remove [IMG] tokens from prompts to avoid Pixtral validation issues
|
||||
# when truncation is enabled. The processor counts [IMG] tokens and fails
|
||||
# if the count changes after truncation.
|
||||
cleaned_txt = [prompt.replace("[IMG]", "") for prompt in prompts]
|
||||
|
||||
return [
|
||||
[
|
||||
{
|
||||
"role": "system",
|
||||
"content": [{"type": "text", "text": system_message}],
|
||||
},
|
||||
{"role": "user", "content": [{"type": "text", "text": prompt}]},
|
||||
]
|
||||
for prompt in cleaned_txt
|
||||
]
|
||||
|
||||
|
||||
def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
|
||||
a1, b1 = 8.73809524e-05, 1.89833333
|
||||
a2, b2 = 0.00016927, 0.45666666
|
||||
|
||||
if image_seq_len > 4300:
|
||||
mu = a2 * image_seq_len + b2
|
||||
return float(mu)
|
||||
|
||||
m_200 = a2 * image_seq_len + b2
|
||||
m_10 = a1 * image_seq_len + b1
|
||||
|
||||
a = (m_200 - m_10) / 190.0
|
||||
b = m_200 - 200.0 * a
|
||||
mu = a * num_steps + b
|
||||
|
||||
return float(mu)
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
class Flux2Pipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
|
||||
r"""
|
||||
The Flux2 pipeline for text-to-image generation.
|
||||
|
||||
Reference: [https://bfl.ai/blog/flux-2](https://bfl.ai/blog/flux-2)
|
||||
|
||||
Args:
|
||||
transformer ([`Flux2Transformer2DModel`]):
|
||||
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKLFlux2`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`Mistral3ForConditionalGeneration`]):
|
||||
[Mistral3ForConditionalGeneration](https://huggingface.co/docs/transformers/en/model_doc/mistral3#transformers.Mistral3ForConditionalGeneration)
|
||||
tokenizer (`AutoProcessor`):
|
||||
Tokenizer of class
|
||||
[PixtralProcessor](https://huggingface.co/docs/transformers/en/model_doc/pixtral#transformers.PixtralProcessor).
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
vae: AutoencoderKLFlux2,
|
||||
text_encoder: Mistral3ForConditionalGeneration,
|
||||
tokenizer: AutoProcessor,
|
||||
transformer: Flux2Transformer2DModel,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
scheduler=scheduler,
|
||||
transformer=transformer,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
|
||||
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
|
||||
self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
||||
self.tokenizer_max_length = 512
|
||||
self.default_sample_size = 128
|
||||
|
||||
# fmt: off
|
||||
self.system_message = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation."
|
||||
# fmt: on
|
||||
|
||||
@staticmethod
|
||||
def _get_mistral_3_small_prompt_embeds(
|
||||
text_encoder: Mistral3ForConditionalGeneration,
|
||||
tokenizer: AutoProcessor,
|
||||
prompt: Union[str, List[str]],
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
max_sequence_length: int = 512,
|
||||
# fmt: off
|
||||
system_message: str = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation.",
|
||||
# fmt: on
|
||||
hidden_states_layers: List[int] = (10, 20, 30),
|
||||
):
|
||||
dtype = text_encoder.dtype if dtype is None else dtype
|
||||
device = text_encoder.device if device is None else device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
# Format input messages
|
||||
messages_batch = format_text_input(prompts=prompt, system_message=system_message)
|
||||
|
||||
# Process all messages at once
|
||||
inputs = tokenizer.apply_chat_template(
|
||||
messages_batch,
|
||||
add_generation_prompt=False,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=max_sequence_length,
|
||||
)
|
||||
|
||||
# Move to device
|
||||
input_ids = inputs["input_ids"].to(device)
|
||||
attention_mask = inputs["attention_mask"].to(device)
|
||||
|
||||
# Forward pass through the model
|
||||
output = text_encoder(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
output_hidden_states=True,
|
||||
use_cache=False,
|
||||
)
|
||||
|
||||
# Only use outputs from intermediate layers and stack them
|
||||
out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
|
||||
out = out.to(dtype=dtype, device=device)
|
||||
|
||||
batch_size, num_channels, seq_len, hidden_dim = out.shape
|
||||
prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
@staticmethod
|
||||
def _prepare_text_ids(
|
||||
x: torch.Tensor, # (B, L, D) or (L, D)
|
||||
t_coord: Optional[torch.Tensor] = None,
|
||||
):
|
||||
B, L, _ = x.shape
|
||||
out_ids = []
|
||||
|
||||
for i in range(B):
|
||||
t = torch.arange(1) if t_coord is None else t_coord[i]
|
||||
h = torch.arange(1)
|
||||
w = torch.arange(1)
|
||||
l = torch.arange(L)
|
||||
|
||||
coords = torch.cartesian_prod(t, h, w, l)
|
||||
out_ids.append(coords)
|
||||
|
||||
return torch.stack(out_ids)
|
||||
|
||||
@staticmethod
|
||||
def _prepare_latent_ids(
|
||||
latents: torch.Tensor, # (B, C, H, W)
|
||||
):
|
||||
r"""
|
||||
Generates 4D position coordinates (T, H, W, L) for latent tensors.
|
||||
|
||||
Args:
|
||||
latents (torch.Tensor):
|
||||
Latent tensor of shape (B, C, H, W)
|
||||
|
||||
Returns:
|
||||
torch.Tensor:
|
||||
Position IDs tensor of shape (B, H*W, 4) All batches share the same coordinate structure: T=0,
|
||||
H=[0..H-1], W=[0..W-1], L=0
|
||||
"""
|
||||
|
||||
batch_size, _, height, width = latents.shape
|
||||
|
||||
t = torch.arange(1) # [0] - time dimension
|
||||
h = torch.arange(height)
|
||||
w = torch.arange(width)
|
||||
l = torch.arange(1) # [0] - layer dimension
|
||||
|
||||
# Create position IDs: (H*W, 4)
|
||||
latent_ids = torch.cartesian_prod(t, h, w, l)
|
||||
|
||||
# Expand to batch: (B, H*W, 4)
|
||||
latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1)
|
||||
|
||||
return latent_ids
|
||||
|
||||
@staticmethod
|
||||
def _prepare_image_ids(
|
||||
image_latents: List[torch.Tensor], # [(1, C, H, W), (1, C, H, W), ...]
|
||||
scale: int = 10,
|
||||
):
|
||||
r"""
|
||||
Generates 4D time-space coordinates (T, H, W, L) for a sequence of image latents.
|
||||
|
||||
This function creates a unique coordinate for every pixel/patch across all input latent with different
|
||||
dimensions.
|
||||
|
||||
Args:
|
||||
image_latents (List[torch.Tensor]):
|
||||
A list of image latent feature tensors, typically of shape (C, H, W).
|
||||
scale (int, optional):
|
||||
A factor used to define the time separation (T-coordinate) between latents. T-coordinate for the i-th
|
||||
latent is: 'scale + scale * i'. Defaults to 10.
|
||||
|
||||
Returns:
|
||||
torch.Tensor:
|
||||
The combined coordinate tensor. Shape: (1, N_total, 4) Where N_total is the sum of (H * W) for all
|
||||
input latents.
|
||||
|
||||
Coordinate Components (Dimension 4):
|
||||
- T (Time): The unique index indicating which latent image the coordinate belongs to.
|
||||
- H (Height): The row index within that latent image.
|
||||
- W (Width): The column index within that latent image.
|
||||
- L (Seq. Length): A sequence length dimension, which is always fixed at 0 (size 1)
|
||||
"""
|
||||
|
||||
if not isinstance(image_latents, list):
|
||||
raise ValueError(f"Expected `image_latents` to be a list, got {type(image_latents)}.")
|
||||
|
||||
# create time offset for each reference image
|
||||
t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))]
|
||||
t_coords = [t.view(-1) for t in t_coords]
|
||||
|
||||
image_latent_ids = []
|
||||
for x, t in zip(image_latents, t_coords):
|
||||
x = x.squeeze(0)
|
||||
_, height, width = x.shape
|
||||
|
||||
x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1))
|
||||
image_latent_ids.append(x_ids)
|
||||
|
||||
image_latent_ids = torch.cat(image_latent_ids, dim=0)
|
||||
image_latent_ids = image_latent_ids.unsqueeze(0)
|
||||
|
||||
return image_latent_ids
|
||||
|
||||
@staticmethod
|
||||
def _patchify_latents(latents):
|
||||
batch_size, num_channels_latents, height, width = latents.shape
|
||||
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
||||
latents = latents.permute(0, 1, 3, 5, 2, 4)
|
||||
latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2)
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
def _unpatchify_latents(latents):
|
||||
batch_size, num_channels_latents, height, width = latents.shape
|
||||
latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width)
|
||||
latents = latents.permute(0, 1, 4, 2, 5, 3)
|
||||
latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height * 2, width * 2)
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
def _pack_latents(latents):
|
||||
"""
|
||||
pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels)
|
||||
"""
|
||||
|
||||
batch_size, num_channels, height, width = latents.shape
|
||||
latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1)
|
||||
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch.Tensor]:
|
||||
"""
|
||||
using position ids to scatter tokens into place
|
||||
"""
|
||||
x_list = []
|
||||
for data, pos in zip(x, x_ids):
|
||||
_, ch = data.shape # noqa: F841
|
||||
h_ids = pos[:, 1].to(torch.int64)
|
||||
w_ids = pos[:, 2].to(torch.int64)
|
||||
|
||||
h = torch.max(h_ids) + 1
|
||||
w = torch.max(w_ids) + 1
|
||||
|
||||
flat_ids = h_ids * w + w_ids
|
||||
|
||||
out = torch.zeros((h * w, ch), device=data.device, dtype=data.dtype)
|
||||
out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data)
|
||||
|
||||
# reshape from (H * W, C) to (H, W, C) and permute to (C, H, W)
|
||||
|
||||
out = out.view(h, w, ch).permute(2, 0, 1)
|
||||
x_list.append(out)
|
||||
|
||||
return torch.stack(x_list, dim=0)
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: int = 512,
|
||||
text_encoder_out_layers: Tuple[int] = (10, 20, 30),
|
||||
):
|
||||
device = device or self._execution_device
|
||||
|
||||
if prompt is None:
|
||||
prompt = ""
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds = self._get_mistral_3_small_prompt_embeds(
|
||||
text_encoder=self.text_encoder,
|
||||
tokenizer=self.tokenizer,
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
max_sequence_length=max_sequence_length,
|
||||
system_message=self.system_message,
|
||||
hidden_states_layers=text_encoder_out_layers,
|
||||
)
|
||||
|
||||
batch_size, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
text_ids = self._prepare_text_ids(prompt_embeds)
|
||||
text_ids = text_ids.to(device)
|
||||
return prompt_embeds, text_ids
|
||||
|
||||
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
|
||||
if image.ndim != 4:
|
||||
raise ValueError(f"Expected image dims 4, got {image.ndim}.")
|
||||
|
||||
image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax")
|
||||
image_latents = self._patchify_latents(image_latents)
|
||||
|
||||
latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype)
|
||||
latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps)
|
||||
image_latents = (image_latents - latents_bn_mean) / latents_bn_std
|
||||
|
||||
return image_latents
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
num_latents_channels,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator: torch.Generator,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
):
|
||||
# VAE applies 8x compression on images but we must also account for packing which requires
|
||||
# latent height and width to be divisible by 2.
|
||||
height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
||||
width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
||||
|
||||
shape = (batch_size, num_latents_channels * 4, height // 2, width // 2)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device=device, dtype=dtype)
|
||||
|
||||
latent_ids = self._prepare_latent_ids(latents)
|
||||
latent_ids = latent_ids.to(device)
|
||||
|
||||
latents = self._pack_latents(latents) # [B, C, H, W] -> [B, H*W, C]
|
||||
return latents, latent_ids
|
||||
|
||||
def prepare_image_latents(
|
||||
self,
|
||||
images: List[torch.Tensor],
|
||||
batch_size,
|
||||
generator: torch.Generator,
|
||||
device,
|
||||
dtype,
|
||||
):
|
||||
image_latents = []
|
||||
for image in images:
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
imagge_latent = self._encode_vae_image(image=image, generator=generator)
|
||||
image_latents.append(imagge_latent) # (1, 128, 32, 32)
|
||||
|
||||
image_latent_ids = self._prepare_image_ids(image_latents)
|
||||
|
||||
# Pack each latent and concatenate
|
||||
packed_latents = []
|
||||
for latent in image_latents:
|
||||
# latent: (1, 128, 32, 32)
|
||||
packed = self._pack_latents(latent) # (1, 1024, 128)
|
||||
packed = packed.squeeze(0) # (1024, 128) - remove batch dim
|
||||
packed_latents.append(packed)
|
||||
|
||||
# Concatenate all reference tokens along sequence dimension
|
||||
image_latents = torch.cat(packed_latents, dim=0) # (N*1024, 128)
|
||||
image_latents = image_latents.unsqueeze(0) # (1, N*1024, 128)
|
||||
|
||||
image_latents = image_latents.repeat(batch_size, 1, 1)
|
||||
image_latent_ids = image_latent_ids.repeat(batch_size, 1, 1)
|
||||
image_latent_ids = image_latent_ids.to(device)
|
||||
|
||||
return image_latents, image_latent_ids
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if (
|
||||
height is not None
|
||||
and height % (self.vae_scale_factor * 2) != 0
|
||||
or width is not None
|
||||
and width % (self.vae_scale_factor * 2) != 0
|
||||
):
|
||||
logger.warning(
|
||||
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
|
||||
)
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def joint_attention_kwargs(self):
|
||||
return self._joint_attention_kwargs
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
image: Optional[Union[List[PIL.Image.Image], PIL.Image.Image]] = None,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
guidance_scale: Optional[float] = 4.0,
|
||||
num_images_per_prompt: int = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
text_encoder_out_layers: Tuple[int] = (10, 20, 30),
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
||||
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
|
||||
numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
|
||||
or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
|
||||
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
|
||||
latents as `image`, but if passing latents directly it is not encoded again.
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
guidance_scale (`float`, *optional*, defaults to 1.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion
|
||||
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
||||
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
||||
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
||||
the text `prompt`, usually at the expense of lower image quality.
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
||||
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||
will be used.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will be generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
|
||||
text_encoder_out_layers (`Tuple[int]`):
|
||||
Layer indices to use in the `text_encoder` to derive the final prompt embeddings.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.flux2.Flux2PipelineOutput`] or `tuple`: [`~pipelines.flux2.Flux2PipelineOutput`] if
|
||||
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
|
||||
generated images.
|
||||
"""
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt=prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
prompt_embeds=prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# 3. prepare text embeddings
|
||||
prompt_embeds, text_ids = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
text_encoder_out_layers=text_encoder_out_layers,
|
||||
)
|
||||
|
||||
# 4. process images
|
||||
if image is not None and not isinstance(image, list):
|
||||
image = [image]
|
||||
|
||||
condition_images = None
|
||||
if image is not None:
|
||||
for img in image:
|
||||
self.image_processor.check_image_input(img)
|
||||
|
||||
condition_images = []
|
||||
for img in image:
|
||||
image_width, image_height = img.size
|
||||
if image_width * image_height > 1024 * 1024:
|
||||
img = self.image_processor._resize_to_target_area(img, 1024 * 1024)
|
||||
image_width, image_height = img.size
|
||||
|
||||
multiple_of = self.vae_scale_factor * 2
|
||||
image_width = (image_width // multiple_of) * multiple_of
|
||||
image_height = (image_height // multiple_of) * multiple_of
|
||||
img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop")
|
||||
condition_images.append(img)
|
||||
height = height or image_height
|
||||
width = width or image_width
|
||||
|
||||
height = height or self.default_sample_size * self.vae_scale_factor
|
||||
width = width or self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
# 5. prepare latent variables
|
||||
num_channels_latents = self.transformer.config.in_channels // 4
|
||||
latents, latent_ids = self.prepare_latents(
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
num_latents_channels=num_channels_latents,
|
||||
height=height,
|
||||
width=width,
|
||||
dtype=prompt_embeds.dtype,
|
||||
device=device,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
)
|
||||
|
||||
image_latents = None
|
||||
image_latent_ids = None
|
||||
if condition_images is not None:
|
||||
image_latents, image_latent_ids = self.prepare_image_latents(
|
||||
images=condition_images,
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
generator=generator,
|
||||
device=device,
|
||||
dtype=self.vae.dtype,
|
||||
)
|
||||
|
||||
# 6. Prepare timesteps
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
||||
if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas:
|
||||
sigmas = None
|
||||
image_seq_len = latents.shape[1]
|
||||
mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_inference_steps)
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
num_inference_steps,
|
||||
device,
|
||||
sigmas=sigmas,
|
||||
mu=mu,
|
||||
)
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# handle guidance
|
||||
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
|
||||
guidance = guidance.expand(latents.shape[0])
|
||||
|
||||
# 7. Denoising loop
|
||||
# We set the index here to remove DtoH sync, helpful especially during compilation.
|
||||
# Check out more details here: https://github.com/huggingface/diffusers/pull/11696
|
||||
self.scheduler.set_begin_index(0)
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
|
||||
latent_model_input = latents.to(self.transformer.dtype)
|
||||
latent_image_ids = latent_ids
|
||||
|
||||
if image_latents is not None:
|
||||
latent_model_input = torch.cat([latents, image_latents], dim=1).to(self.transformer.dtype)
|
||||
latent_image_ids = torch.cat([latent_ids, image_latent_ids], dim=1)
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input, # (B, image_seq_len, C)
|
||||
timestep=timestep / 1000,
|
||||
guidance=guidance,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
txt_ids=text_ids, # B, text_seq_len, 4
|
||||
img_ids=latent_image_ids, # B, image_seq_len, 4
|
||||
joint_attention_kwargs=self._attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
noise_pred = noise_pred[:, : latents.size(1) :]
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents_dtype = latents.dtype
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if latents.dtype != latents_dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
latents = latents.to(latents_dtype)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
else:
|
||||
torch.save({"pred": latents}, "pred_d.pt")
|
||||
latents = self._unpack_latents_with_ids(latents, latent_ids)
|
||||
|
||||
latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to(
|
||||
latents.device, latents.dtype
|
||||
)
|
||||
latents = latents * latents_bn_std + latents_bn_mean
|
||||
latents = self._unpatchify_latents(latents)
|
||||
|
||||
image = self.vae.decode(latents, return_dict=False)[0]
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return Flux2PipelineOutput(images=image)
|
||||
23
src/diffusers/pipelines/flux2/pipeline_output.py
Normal file
23
src/diffusers/pipelines/flux2/pipeline_output.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
|
||||
from ...utils import BaseOutput
|
||||
|
||||
|
||||
@dataclass
|
||||
class Flux2PipelineOutput(BaseOutput):
|
||||
"""
|
||||
Output class for Flux2 image generation pipelines.
|
||||
|
||||
Args:
|
||||
images (`List[PIL.Image.Image]` or `torch.Tensor` or `np.ndarray`)
|
||||
List of denoised PIL images of length `batch_size` or numpy array or torch tensor of shape `(batch_size,
|
||||
height, width, num_channels)`. PIL images or numpy array present the denoised images of the diffusion
|
||||
pipeline. Torch tensors can represent either the denoised images or the intermediate latents ready to be
|
||||
passed to the decoder.
|
||||
"""
|
||||
|
||||
images: Union[List[PIL.Image.Image], np.ndarray]
|
||||
50
src/diffusers/pipelines/z_image/__init__.py
Normal file
50
src/diffusers/pipelines/z_image/__init__.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa: F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_output"] = ["ZImagePipelineOutput"]
|
||||
_import_structure["pipeline_z_image"] = ["ZImagePipeline"]
|
||||
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .pipeline_output import ZImagePipelineOutput
|
||||
from .pipeline_z_image import ZImagePipeline
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
35
src/diffusers/pipelines/z_image/pipeline_output.py
Normal file
35
src/diffusers/pipelines/z_image/pipeline_output.py
Normal file
@@ -0,0 +1,35 @@
|
||||
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
|
||||
from diffusers.utils import BaseOutput
|
||||
|
||||
|
||||
@dataclass
|
||||
class ZImagePipelineOutput(BaseOutput):
|
||||
"""
|
||||
Output class for Z-Image pipelines.
|
||||
|
||||
Args:
|
||||
images (`List[PIL.Image.Image]` or `np.ndarray`)
|
||||
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
|
||||
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
|
||||
"""
|
||||
|
||||
images: Union[List[PIL.Image.Image], np.ndarray]
|
||||
607
src/diffusers/pipelines/z_image/pipeline_z_image.py
Normal file
607
src/diffusers/pipelines/z_image/pipeline_z_image.py
Normal file
@@ -0,0 +1,607 @@
|
||||
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers import AutoTokenizer, PreTrainedModel
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import FromSingleFileMixin
|
||||
from ...models.autoencoders import AutoencoderKL
|
||||
from ...models.transformers import ZImageTransformer2DModel
|
||||
from ...pipelines.pipeline_utils import DiffusionPipeline
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from .pipeline_output import ZImagePipelineOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import ZImagePipeline
|
||||
|
||||
>>> pipe = ZImagePipeline.from_pretrained("Z-a-o/Z-Image-Turbo", torch_dtype=torch.bfloat16)
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> # Optionally, set the attention backend to flash-attn 2 or 3, default is SDPA in PyTorch.
|
||||
>>> # (1) Use flash attention 2
|
||||
>>> # pipe.transformer.set_attention_backend("flash")
|
||||
>>> # (2) Use flash attention 3
|
||||
>>> # pipe.transformer.set_attention_backend("_flash_3")
|
||||
|
||||
>>> prompt = "一幅为名为“造相「Z-IMAGE-TURBO」”的项目设计的创意海报。画面巧妙地将文字概念视觉化:一辆复古蒸汽小火车化身为巨大的拉链头,正拉开厚厚的冬日积雪,展露出一个生机盎然的春天。"
|
||||
>>> image = pipe(
|
||||
... prompt,
|
||||
... height=1024,
|
||||
... width=1024,
|
||||
... num_inference_steps=9,
|
||||
... guidance_scale=0.0,
|
||||
... generator=torch.Generator("cuda").manual_seed(42),
|
||||
... ).images[0]
|
||||
>>> image.save("zimage.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
|
||||
def calculate_shift(
|
||||
image_seq_len,
|
||||
base_seq_len: int = 256,
|
||||
max_seq_len: int = 4096,
|
||||
base_shift: float = 0.5,
|
||||
max_shift: float = 1.15,
|
||||
):
|
||||
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
||||
b = base_shift - m * base_seq_len
|
||||
mu = image_seq_len * m + b
|
||||
return mu
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin):
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_optional_components = []
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: PreTrainedModel,
|
||||
tokenizer: AutoTokenizer,
|
||||
transformer: ZImageTransformer2DModel,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
scheduler=scheduler,
|
||||
transformer=transformer,
|
||||
)
|
||||
self.vae_scale_factor = (
|
||||
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
||||
)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
max_sequence_length: int = 512,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
prompt_embeds = self._encode_prompt(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
if negative_prompt is None:
|
||||
negative_prompt = ["" for _ in prompt]
|
||||
else:
|
||||
negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
assert len(prompt) == len(negative_prompt)
|
||||
negative_prompt_embeds = self._encode_prompt(
|
||||
prompt=negative_prompt,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
prompt_embeds=negative_prompt_embeds,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
else:
|
||||
negative_prompt_embeds = []
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
||||
max_sequence_length: int = 512,
|
||||
) -> List[torch.FloatTensor]:
|
||||
assert num_images_per_prompt == 1
|
||||
device = device or self._execution_device
|
||||
|
||||
if prompt_embeds is not None:
|
||||
return prompt_embeds
|
||||
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
|
||||
for i, prompt_item in enumerate(prompt):
|
||||
messages = [
|
||||
{"role": "user", "content": prompt_item},
|
||||
]
|
||||
prompt_item = self.tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=True,
|
||||
)
|
||||
prompt[i] = prompt_item
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
text_input_ids = text_inputs.input_ids.to(device)
|
||||
prompt_masks = text_inputs.attention_mask.to(device).bool()
|
||||
|
||||
prompt_embeds = self.text_encoder(
|
||||
input_ids=text_input_ids,
|
||||
attention_mask=prompt_masks,
|
||||
output_hidden_states=True,
|
||||
).hidden_states[-2]
|
||||
|
||||
embeddings_list = []
|
||||
|
||||
for i in range(len(prompt_embeds)):
|
||||
embeddings_list.append(prompt_embeds[i][prompt_masks[i]])
|
||||
|
||||
return embeddings_list
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
):
|
||||
height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
||||
width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
||||
|
||||
shape = (batch_size, num_channels_latents, height, width)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
if latents.shape != shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
||||
latents = latents.to(device)
|
||||
return latents
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1
|
||||
|
||||
@property
|
||||
def joint_attention_kwargs(self):
|
||||
return self._joint_attention_kwargs
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
guidance_scale: float = 5.0,
|
||||
cfg_normalization: bool = False,
|
||||
cfg_truncation: float = 1.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
||||
negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
height (`int`, *optional*, defaults to 1024):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to 1024):
|
||||
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
||||
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||
will be used.
|
||||
guidance_scale (`float`, *optional*, defaults to 5.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
cfg_normalization (`bool`, *optional*, defaults to False):
|
||||
Whether to apply configuration normalization.
|
||||
cfg_truncation (`float`, *optional*, defaults to 1.0):
|
||||
The truncation value for configuration.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will be generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`List[torch.FloatTensor]`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`List[torch.FloatTensor]`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.ZImagePipelineOutput`] instead of a plain
|
||||
tuple.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int`, *optional*, defaults to 512):
|
||||
Maximum sequence length to use with the `prompt`.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`: [`~pipelines.z_image.ZImagePipelineOutput`] if
|
||||
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
|
||||
generated images.
|
||||
"""
|
||||
height = height or 1024
|
||||
width = width or 1024
|
||||
|
||||
vae_scale = self.vae_scale_factor * 2
|
||||
if height % vae_scale != 0:
|
||||
raise ValueError(
|
||||
f"Height must be divisible by {vae_scale} (got {height}). "
|
||||
f"Please adjust the height to a multiple of {vae_scale}."
|
||||
)
|
||||
if width % vae_scale != 0:
|
||||
raise ValueError(
|
||||
f"Width must be divisible by {vae_scale} (got {width}). "
|
||||
f"Please adjust the width to a multiple of {vae_scale}."
|
||||
)
|
||||
|
||||
assert self.dtype == torch.bfloat16
|
||||
dtype = self.dtype
|
||||
device = self._execution_device
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._joint_attention_kwargs = joint_attention_kwargs
|
||||
self._interrupt = False
|
||||
self._cfg_normalization = cfg_normalization
|
||||
self._cfg_truncation = cfg_truncation
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = len(prompt_embeds)
|
||||
|
||||
lora_scale = (
|
||||
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
|
||||
)
|
||||
|
||||
# If prompt_embeds is provided and prompt is None, skip encoding
|
||||
if prompt_embeds is not None and prompt is None:
|
||||
if self.do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"When `prompt_embeds` is provided without `prompt`, "
|
||||
"`negative_prompt_embeds` must also be provided for classifier-free guidance."
|
||||
)
|
||||
else:
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
|
||||
# 4. Prepare latent variables
|
||||
num_channels_latents = self.transformer.in_channels
|
||||
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
torch.float32,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2)
|
||||
|
||||
# 5. Prepare timesteps
|
||||
mu = calculate_shift(
|
||||
image_seq_len,
|
||||
self.scheduler.config.get("base_image_seq_len", 256),
|
||||
self.scheduler.config.get("max_image_seq_len", 4096),
|
||||
self.scheduler.config.get("base_shift", 0.5),
|
||||
self.scheduler.config.get("max_shift", 1.15),
|
||||
)
|
||||
self.scheduler.sigma_min = 0.0
|
||||
scheduler_kwargs = {"mu": mu}
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
num_inference_steps,
|
||||
device,
|
||||
sigmas=sigmas,
|
||||
**scheduler_kwargs,
|
||||
)
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 6. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0])
|
||||
timestep = (1000 - timestep) / 1000
|
||||
# Normalized time for time-aware config (0 at start, 1 at end)
|
||||
t_norm = timestep[0].item()
|
||||
|
||||
# Handle cfg truncation
|
||||
current_guidance_scale = self.guidance_scale
|
||||
if (
|
||||
self.do_classifier_free_guidance
|
||||
and self._cfg_truncation is not None
|
||||
and float(self._cfg_truncation) <= 1
|
||||
):
|
||||
if t_norm > self._cfg_truncation:
|
||||
current_guidance_scale = 0.0
|
||||
|
||||
# Run CFG only if configured AND scale is non-zero
|
||||
apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0
|
||||
|
||||
if apply_cfg:
|
||||
latents_typed = latents if latents.dtype == dtype else latents.to(dtype)
|
||||
latent_model_input = latents_typed.repeat(2, 1, 1, 1)
|
||||
prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds
|
||||
timestep_model_input = timestep.repeat(2)
|
||||
else:
|
||||
latent_model_input = latents if latents.dtype == dtype else latents.to(dtype)
|
||||
prompt_embeds_model_input = prompt_embeds
|
||||
timestep_model_input = timestep
|
||||
|
||||
latent_model_input = latent_model_input.unsqueeze(2)
|
||||
latent_model_input_list = list(latent_model_input.unbind(dim=0))
|
||||
|
||||
model_out_list = self.transformer(
|
||||
latent_model_input_list,
|
||||
timestep_model_input,
|
||||
prompt_embeds_model_input,
|
||||
)[0]
|
||||
|
||||
if apply_cfg:
|
||||
# Perform CFG
|
||||
pos_out = model_out_list[:batch_size]
|
||||
neg_out = model_out_list[batch_size:]
|
||||
|
||||
noise_pred = []
|
||||
for j in range(batch_size):
|
||||
pos = pos_out[j].float()
|
||||
neg = neg_out[j].float()
|
||||
|
||||
pred = pos + current_guidance_scale * (pos - neg)
|
||||
|
||||
# Renormalization
|
||||
if self._cfg_normalization and float(self._cfg_normalization) > 0.0:
|
||||
ori_pos_norm = torch.linalg.vector_norm(pos)
|
||||
new_pos_norm = torch.linalg.vector_norm(pred)
|
||||
max_new_norm = ori_pos_norm * float(self._cfg_normalization)
|
||||
if new_pos_norm > max_new_norm:
|
||||
pred = pred * (max_new_norm / new_pos_norm)
|
||||
|
||||
noise_pred.append(pred)
|
||||
|
||||
noise_pred = torch.stack(noise_pred, dim=0)
|
||||
else:
|
||||
noise_pred = torch.stack([t.float() for t in model_out_list], dim=0)
|
||||
|
||||
noise_pred = noise_pred.squeeze(2)
|
||||
noise_pred = -noise_pred
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0]
|
||||
assert latents.dtype == torch.float32
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
latents = latents.to(dtype)
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
|
||||
else:
|
||||
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
||||
|
||||
image = self.vae.decode(latents, return_dict=False)[0]
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return ZImagePipelineOutput(images=image)
|
||||
@@ -488,9 +488,20 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
||||
t = t.reshape(sigma.shape)
|
||||
return t
|
||||
|
||||
# copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
|
||||
# Copied from diffusers.schedulers.scheduling_lms_discrete.LMSDiscreteScheduler._convert_to_karras
|
||||
def _convert_to_karras(self, in_sigmas: torch.Tensor) -> torch.Tensor:
|
||||
"""Constructs the noise schedule of Karras et al. (2022)."""
|
||||
"""
|
||||
Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
|
||||
Models](https://huggingface.co/papers/2206.00364).
|
||||
|
||||
Args:
|
||||
in_sigmas (`torch.Tensor`):
|
||||
The input sigma values to be converted.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The converted sigma values following the Karras noise schedule.
|
||||
"""
|
||||
|
||||
sigma_min: float = in_sigmas[-1].item()
|
||||
sigma_max: float = in_sigmas[0].item()
|
||||
|
||||
@@ -99,15 +99,14 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
methods the library implements for all schedulers such as loading and saving.
|
||||
|
||||
Args:
|
||||
num_train_timesteps (`int`, defaults to 1000):
|
||||
num_train_timesteps (`int`, defaults to `1000`):
|
||||
The number of diffusion steps to train the model.
|
||||
beta_start (`float`, defaults to 0.0001):
|
||||
beta_start (`float`, defaults to `0.0001`):
|
||||
The starting `beta` value of inference.
|
||||
beta_end (`float`, defaults to 0.02):
|
||||
beta_end (`float`, defaults to `0.02`):
|
||||
The final `beta` value.
|
||||
beta_schedule (`str`, defaults to `"linear"`):
|
||||
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
|
||||
`linear` or `scaled_linear`.
|
||||
beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`):
|
||||
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model.
|
||||
trained_betas (`np.ndarray`, *optional*):
|
||||
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
|
||||
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
|
||||
@@ -118,14 +117,14 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
use_beta_sigmas (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
|
||||
Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
|
||||
prediction_type (`str`, defaults to `epsilon`, *optional*):
|
||||
prediction_type (`"epsilon"`, `"sample"`, or `"v_prediction"`, defaults to `"epsilon"`):
|
||||
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
|
||||
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
|
||||
Video](https://imagen.research.google/video/paper.pdf) paper).
|
||||
timestep_spacing (`str`, defaults to `"linspace"`):
|
||||
timestep_spacing (`"linspace"`, `"leading"`, or `"trailing"`, defaults to `"linspace"`):
|
||||
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
||||
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
||||
steps_offset (`int`, defaults to 0):
|
||||
steps_offset (`int`, defaults to `0`):
|
||||
An offset added to the inference steps, as required by some model families.
|
||||
"""
|
||||
|
||||
@@ -138,13 +137,13 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
num_train_timesteps: int = 1000,
|
||||
beta_start: float = 0.0001,
|
||||
beta_end: float = 0.02,
|
||||
beta_schedule: str = "linear",
|
||||
beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
|
||||
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
||||
use_karras_sigmas: Optional[bool] = False,
|
||||
use_exponential_sigmas: Optional[bool] = False,
|
||||
use_beta_sigmas: Optional[bool] = False,
|
||||
prediction_type: str = "epsilon",
|
||||
timestep_spacing: str = "linspace",
|
||||
prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
|
||||
timestep_spacing: Literal["linspace", "leading", "trailing"] = "linspace",
|
||||
steps_offset: int = 0,
|
||||
):
|
||||
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
|
||||
@@ -183,7 +182,15 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
@property
|
||||
def init_noise_sigma(self):
|
||||
def init_noise_sigma(self) -> Union[float, torch.Tensor]:
|
||||
"""
|
||||
The standard deviation of the initial noise distribution.
|
||||
|
||||
Returns:
|
||||
`float` or `torch.Tensor`:
|
||||
The standard deviation of the initial noise distribution, computed based on the maximum sigma value and
|
||||
the timestep spacing configuration.
|
||||
"""
|
||||
# standard deviation of the initial noise distribution
|
||||
if self.config.timestep_spacing in ["linspace", "trailing"]:
|
||||
return self.sigmas.max()
|
||||
@@ -191,21 +198,29 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
return (self.sigmas.max() ** 2 + 1) ** 0.5
|
||||
|
||||
@property
|
||||
def step_index(self):
|
||||
def step_index(self) -> Optional[int]:
|
||||
"""
|
||||
The index counter for current timestep. It will increase 1 after each scheduler step.
|
||||
The index counter for current timestep. It will increase by 1 after each scheduler step.
|
||||
|
||||
Returns:
|
||||
`int` or `None`:
|
||||
The current step index, or `None` if not initialized.
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
@property
|
||||
def begin_index(self):
|
||||
def begin_index(self) -> Optional[int]:
|
||||
"""
|
||||
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
||||
|
||||
Returns:
|
||||
`int` or `None`:
|
||||
The begin index for the scheduler, or `None` if not set.
|
||||
"""
|
||||
return self._begin_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
||||
def set_begin_index(self, begin_index: int = 0):
|
||||
def set_begin_index(self, begin_index: int = 0) -> None:
|
||||
"""
|
||||
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
||||
|
||||
@@ -239,14 +254,21 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.is_scale_input_called = True
|
||||
return sample
|
||||
|
||||
def get_lms_coefficient(self, order, t, current_order):
|
||||
def get_lms_coefficient(self, order: int, t: int, current_order: int) -> float:
|
||||
"""
|
||||
Compute the linear multistep coefficient.
|
||||
|
||||
Args:
|
||||
order ():
|
||||
t ():
|
||||
current_order ():
|
||||
order (`int`):
|
||||
The order of the linear multistep method.
|
||||
t (`int`):
|
||||
The current timestep index.
|
||||
current_order (`int`):
|
||||
The current order for which to compute the coefficient.
|
||||
|
||||
Returns:
|
||||
`float`:
|
||||
The computed linear multistep coefficient.
|
||||
"""
|
||||
|
||||
def lms_derivative(tau):
|
||||
@@ -261,7 +283,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
return integrated_coeff
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||
def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None) -> None:
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
|
||||
@@ -367,7 +389,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self._step_index = self._begin_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
|
||||
def _sigma_to_t(self, sigma, log_sigmas):
|
||||
def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Convert sigma values to corresponding timestep values through interpolation.
|
||||
|
||||
@@ -403,9 +425,19 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
t = t.reshape(sigma.shape)
|
||||
return t
|
||||
|
||||
# copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
|
||||
def _convert_to_karras(self, in_sigmas: torch.Tensor) -> torch.Tensor:
|
||||
"""Constructs the noise schedule of Karras et al. (2022)."""
|
||||
"""
|
||||
Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
|
||||
Models](https://huggingface.co/papers/2206.00364).
|
||||
|
||||
Args:
|
||||
in_sigmas (`torch.Tensor`):
|
||||
The input sigma values to be converted.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The converted sigma values following the Karras noise schedule.
|
||||
"""
|
||||
|
||||
sigma_min: float = in_sigmas[-1].item()
|
||||
sigma_max: float = in_sigmas[0].item()
|
||||
@@ -629,5 +661,5 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
noisy_samples = original_samples + noise * sigma
|
||||
return noisy_samples
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
return self.config.num_train_timesteps
|
||||
|
||||
@@ -408,6 +408,21 @@ class AutoencoderKLCosmos(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoencoderKLFlux2(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoencoderKLHunyuanImage(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -843,6 +858,21 @@ class EasyAnimateTransformer3DModel(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class Flux2Transformer2DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class FluxControlNetModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -827,6 +827,21 @@ class EasyAnimatePipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class Flux2Pipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class FluxControlImg2ImgPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
@@ -3645,3 +3660,18 @@ class WuerstchenPriorPipeline(metaclass=DummyObject):
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class ZImagePipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
168
tests/lora/test_lora_layers_flux2.py
Normal file
168
tests/lora/test_lora_layers_flux2.py
Normal file
@@ -0,0 +1,168 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoProcessor, Mistral3ForConditionalGeneration
|
||||
|
||||
from diffusers import AutoencoderKLFlux2, FlowMatchEulerDiscreteScheduler, Flux2Pipeline, Flux2Transformer2DModel
|
||||
|
||||
from ..testing_utils import floats_tensor, require_peft_backend, torch_device
|
||||
|
||||
|
||||
sys.path.append(".")
|
||||
|
||||
from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402
|
||||
|
||||
|
||||
@require_peft_backend
|
||||
class Flux2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
pipeline_class = Flux2Pipeline
|
||||
scheduler_cls = FlowMatchEulerDiscreteScheduler
|
||||
scheduler_kwargs = {}
|
||||
|
||||
transformer_kwargs = {
|
||||
"patch_size": 1,
|
||||
"in_channels": 4,
|
||||
"num_layers": 1,
|
||||
"num_single_layers": 1,
|
||||
"attention_head_dim": 16,
|
||||
"num_attention_heads": 2,
|
||||
"joint_attention_dim": 16,
|
||||
"timestep_guidance_channels": 256,
|
||||
"axes_dims_rope": [4, 4, 4, 4],
|
||||
}
|
||||
transformer_cls = Flux2Transformer2DModel
|
||||
vae_kwargs = {
|
||||
"sample_size": 32,
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"down_block_types": ("DownEncoderBlock2D",),
|
||||
"up_block_types": ("UpDecoderBlock2D",),
|
||||
"block_out_channels": (4,),
|
||||
"layers_per_block": 1,
|
||||
"latent_channels": 1,
|
||||
"norm_num_groups": 1,
|
||||
"use_quant_conv": False,
|
||||
"use_post_quant_conv": False,
|
||||
}
|
||||
vae_cls = AutoencoderKLFlux2
|
||||
|
||||
tokenizer_cls, tokenizer_id = AutoProcessor, "hf-internal-testing/tiny-mistral3-diffusers"
|
||||
text_encoder_cls, text_encoder_id = Mistral3ForConditionalGeneration, "hf-internal-testing/tiny-mistral3-diffusers"
|
||||
denoiser_target_modules = ["to_qkv_mlp_proj", "to_k"]
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (1, 8, 8, 3)
|
||||
|
||||
def get_dummy_inputs(self, with_generator=True):
|
||||
batch_size = 1
|
||||
sequence_length = 10
|
||||
num_channels = 4
|
||||
sizes = (32, 32)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
noise = floats_tensor((batch_size, num_channels) + sizes)
|
||||
input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
|
||||
|
||||
pipeline_inputs = {
|
||||
"prompt": "a dog is dancing",
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 5.0,
|
||||
"height": 8,
|
||||
"width": 8,
|
||||
"max_sequence_length": 8,
|
||||
"output_type": "np",
|
||||
"text_encoder_out_layers": (1,),
|
||||
}
|
||||
if with_generator:
|
||||
pipeline_inputs.update({"generator": generator})
|
||||
|
||||
return noise, input_ids, pipeline_inputs
|
||||
|
||||
# Overriding because (1) text encoder LoRAs are not supported in Flux 2 and (2) because the Flux 2 single block
|
||||
# QKV projections are always fused, it has no `to_q` param as expected by the original test.
|
||||
def test_lora_fuse_nan(self):
|
||||
components, _, denoiser_lora_config = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
|
||||
denoiser.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
|
||||
|
||||
# corrupt one LoRA weight with `inf` values
|
||||
with torch.no_grad():
|
||||
possible_tower_names = ["transformer_blocks", "single_transformer_blocks"]
|
||||
filtered_tower_names = [
|
||||
tower_name for tower_name in possible_tower_names if hasattr(pipe.transformer, tower_name)
|
||||
]
|
||||
if len(filtered_tower_names) == 0:
|
||||
reason = f"`pipe.transformer` didn't have any of the following attributes: {possible_tower_names}."
|
||||
raise ValueError(reason)
|
||||
for tower_name in filtered_tower_names:
|
||||
transformer_tower = getattr(pipe.transformer, tower_name)
|
||||
is_single = "single" in tower_name
|
||||
if is_single:
|
||||
transformer_tower[0].attn.to_qkv_mlp_proj.lora_A["adapter-1"].weight += float("inf")
|
||||
else:
|
||||
transformer_tower[0].attn.to_k.lora_A["adapter-1"].weight += float("inf")
|
||||
|
||||
# with `safe_fusing=True` we should see an Error
|
||||
with self.assertRaises(ValueError):
|
||||
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
|
||||
|
||||
# without we should not see an error, but every image will be black
|
||||
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
|
||||
out = pipe(**inputs)[0]
|
||||
|
||||
self.assertTrue(np.isnan(out).all())
|
||||
|
||||
@unittest.skip("Not supported in Flux2.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported in Flux2.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported in Flux2.")
|
||||
def test_modify_padding_mode(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Flux2.")
|
||||
def test_simple_inference_with_partial_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Flux2.")
|
||||
def test_simple_inference_with_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Flux2.")
|
||||
def test_simple_inference_with_text_lora_and_scale(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Flux2.")
|
||||
def test_simple_inference_with_text_lora_fused(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Flux2.")
|
||||
def test_simple_inference_with_text_lora_save_load(self):
|
||||
pass
|
||||
162
tests/models/transformers/test_models_transformer_flux2.py
Normal file
162
tests/models/transformers/test_models_transformer_flux2.py
Normal file
@@ -0,0 +1,162 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import Flux2Transformer2DModel, attention_backend
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class Flux2TransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = Flux2Transformer2DModel
|
||||
main_input_name = "hidden_states"
|
||||
# We override the items here because the transformer under consideration is small.
|
||||
model_split_percents = [0.7, 0.6, 0.6]
|
||||
|
||||
# Skip setting testing with default: AttnProcessor
|
||||
uses_custom_attn_processor = True
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
return self.prepare_dummy_input()
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (16, 4)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (16, 4)
|
||||
|
||||
def prepare_dummy_input(self, height=4, width=4):
|
||||
batch_size = 1
|
||||
num_latent_channels = 4
|
||||
sequence_length = 48
|
||||
embedding_dim = 32
|
||||
|
||||
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
|
||||
|
||||
t_coords = torch.arange(1)
|
||||
h_coords = torch.arange(height)
|
||||
w_coords = torch.arange(width)
|
||||
l_coords = torch.arange(1)
|
||||
image_ids = torch.cartesian_prod(t_coords, h_coords, w_coords, l_coords) # [height * width, 4]
|
||||
image_ids = image_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
|
||||
|
||||
text_t_coords = torch.arange(1)
|
||||
text_h_coords = torch.arange(1)
|
||||
text_w_coords = torch.arange(1)
|
||||
text_l_coords = torch.arange(sequence_length)
|
||||
text_ids = torch.cartesian_prod(text_t_coords, text_h_coords, text_w_coords, text_l_coords)
|
||||
text_ids = text_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
|
||||
|
||||
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
guidance = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"img_ids": image_ids,
|
||||
"txt_ids": text_ids,
|
||||
"timestep": timestep,
|
||||
"guidance": guidance,
|
||||
}
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"patch_size": 1,
|
||||
"in_channels": 4,
|
||||
"num_layers": 1,
|
||||
"num_single_layers": 1,
|
||||
"attention_head_dim": 16,
|
||||
"num_attention_heads": 2,
|
||||
"joint_attention_dim": 32,
|
||||
"timestep_guidance_channels": 256, # Hardcoded in original code
|
||||
"axes_dims_rope": [4, 4, 4, 4],
|
||||
}
|
||||
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
# TODO (Daniel, Sayak): We can remove this test.
|
||||
def test_flux2_consistency(self, seed=0):
|
||||
torch.manual_seed(seed)
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
torch.manual_seed(seed)
|
||||
model = self.model_class(**init_dict)
|
||||
# state_dict = model.state_dict()
|
||||
# for key, param in state_dict.items():
|
||||
# print(f"{key} | {param.shape}")
|
||||
# torch.save(state_dict, "/raid/daniel_gu/test_flux2_params/diffusers.pt")
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with attention_backend("native"):
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)
|
||||
|
||||
if isinstance(output, dict):
|
||||
output = output.to_tuple()[0]
|
||||
|
||||
self.assertIsNotNone(output)
|
||||
|
||||
# input & output have to have the same shape
|
||||
input_tensor = inputs_dict[self.main_input_name]
|
||||
expected_shape = input_tensor.shape
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
|
||||
# Check against expected slice
|
||||
# fmt: off
|
||||
expected_slice = torch.tensor([-0.3662, 0.4844, 0.6334, -0.3497, 0.2162, 0.0188, 0.0521, -0.2061, -0.2041, -0.0342, -0.7107, 0.4797, -0.3280, 0.7059, -0.0849, 0.4416])
|
||||
# fmt: on
|
||||
|
||||
flat_output = output.cpu().flatten()
|
||||
generated_slice = torch.cat([flat_output[:8], flat_output[-8:]])
|
||||
self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-4))
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"Flux2Transformer2DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class Flux2TransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
|
||||
model_class = Flux2Transformer2DModel
|
||||
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return Flux2TransformerTests().prepare_init_args_and_inputs_for_common()
|
||||
|
||||
def prepare_dummy_input(self, height, width):
|
||||
return Flux2TransformerTests().prepare_dummy_input(height=height, width=width)
|
||||
|
||||
|
||||
class Flux2TransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
|
||||
model_class = Flux2Transformer2DModel
|
||||
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return Flux2TransformerTests().prepare_init_args_and_inputs_for_common()
|
||||
|
||||
def prepare_dummy_input(self, height, width):
|
||||
return Flux2TransformerTests().prepare_dummy_input(height=height, width=width)
|
||||
@@ -26,7 +26,6 @@ from diffusers.modular_pipelines import (
|
||||
QwenImageModularPipeline,
|
||||
)
|
||||
|
||||
from ...testing_utils import torch_device
|
||||
from ..test_modular_pipelines_common import ModularGuiderTesterMixin, ModularPipelineTesterMixin
|
||||
|
||||
|
||||
@@ -105,16 +104,6 @@ class TestQwenImageEditPlusModularPipelineFast(ModularPipelineTesterMixin, Modul
|
||||
inputs["image"] = PIL.Image.new("RGB", (32, 32), 0)
|
||||
return inputs
|
||||
|
||||
def test_multi_images_as_input(self):
|
||||
inputs = self.get_dummy_inputs()
|
||||
image = inputs.pop("image")
|
||||
inputs["image"] = [image, image]
|
||||
|
||||
pipe = self.get_pipeline().to(torch_device)
|
||||
_ = pipe(
|
||||
**inputs,
|
||||
)
|
||||
|
||||
@pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True)
|
||||
def test_num_images_per_prompt(self):
|
||||
super().test_num_images_per_prompt()
|
||||
@@ -128,4 +117,4 @@ class TestQwenImageEditPlusModularPipelineFast(ModularPipelineTesterMixin, Modul
|
||||
super().test_inference_batch_single_identical()
|
||||
|
||||
def test_guider_cfg(self):
|
||||
super().test_guider_cfg(1e-6)
|
||||
super().test_guider_cfg(1e-3)
|
||||
|
||||
@@ -34,7 +34,10 @@ from diffusers.utils import is_torch_version # noqa: E402
|
||||
|
||||
# fmt: off
|
||||
FORWARD_CASES = [
|
||||
("flash_hub", None),
|
||||
(
|
||||
"flash_hub",
|
||||
torch.tensor([0.0820, 0.0859, 0.0918, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2461, 0.2422, 0.2695], dtype=torch.bfloat16)
|
||||
),
|
||||
(
|
||||
"_flash_3_hub",
|
||||
torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0977, 0.0996, 0.1016, 0.1016, 0.2188, 0.2246, 0.2344, 0.2480, 0.2539, 0.2480, 0.2441, 0.2715], dtype=torch.bfloat16),
|
||||
@@ -54,7 +57,11 @@ FORWARD_CASES = [
|
||||
]
|
||||
|
||||
COMPILE_CASES = [
|
||||
("flash_hub", None, True),
|
||||
(
|
||||
"flash_hub",
|
||||
torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2324, 0.2422, 0.2539, 0.2734, 0.2832, 0.2812, 0.2773, 0.3047], dtype=torch.bfloat16),
|
||||
True
|
||||
),
|
||||
(
|
||||
"_flash_3_hub",
|
||||
torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0625, 0.0605, 0.2344, 0.2461, 0.2578, 0.2734, 0.2852, 0.2812, 0.2773, 0.3047], dtype=torch.bfloat16),
|
||||
|
||||
0
tests/pipelines/flux2/__init__.py
Normal file
0
tests/pipelines/flux2/__init__.py
Normal file
190
tests/pipelines/flux2/test_pipeline_flux2.py
Normal file
190
tests/pipelines/flux2/test_pipeline_flux2.py
Normal file
@@ -0,0 +1,190 @@
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoProcessor, Mistral3Config, Mistral3ForConditionalGeneration
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLFlux2,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
Flux2Pipeline,
|
||||
Flux2Transformer2DModel,
|
||||
)
|
||||
|
||||
from ...testing_utils import (
|
||||
torch_device,
|
||||
)
|
||||
from ..test_pipelines_common import (
|
||||
PipelineTesterMixin,
|
||||
check_qkv_fused_layers_exist,
|
||||
)
|
||||
|
||||
|
||||
class Flux2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = Flux2Pipeline
|
||||
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds"])
|
||||
batch_params = frozenset(["prompt"])
|
||||
|
||||
test_xformers_attention = False
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
supports_dduf = False
|
||||
|
||||
def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
|
||||
torch.manual_seed(0)
|
||||
transformer = Flux2Transformer2DModel(
|
||||
patch_size=1,
|
||||
in_channels=4,
|
||||
num_layers=num_layers,
|
||||
num_single_layers=num_single_layers,
|
||||
attention_head_dim=16,
|
||||
num_attention_heads=2,
|
||||
joint_attention_dim=16,
|
||||
timestep_guidance_channels=256, # Hardcoded in original code
|
||||
axes_dims_rope=[4, 4, 4, 4],
|
||||
)
|
||||
|
||||
config = Mistral3Config(
|
||||
text_config={
|
||||
"model_type": "mistral",
|
||||
"vocab_size": 32000,
|
||||
"hidden_size": 16,
|
||||
"intermediate_size": 37,
|
||||
"max_position_embeddings": 512,
|
||||
"num_attention_heads": 4,
|
||||
"num_hidden_layers": 1,
|
||||
"num_key_value_heads": 2,
|
||||
"rms_norm_eps": 1e-05,
|
||||
"rope_theta": 1000000000.0,
|
||||
"sliding_window": None,
|
||||
"bos_token_id": 2,
|
||||
"eos_token_id": 3,
|
||||
"pad_token_id": 4,
|
||||
},
|
||||
vision_config={
|
||||
"model_type": "pixtral",
|
||||
"hidden_size": 16,
|
||||
"num_hidden_layers": 1,
|
||||
"num_attention_heads": 4,
|
||||
"intermediate_size": 37,
|
||||
"image_size": 30,
|
||||
"patch_size": 6,
|
||||
"num_channels": 3,
|
||||
},
|
||||
bos_token_id=2,
|
||||
eos_token_id=3,
|
||||
pad_token_id=4,
|
||||
model_dtype="mistral3",
|
||||
image_seq_length=4,
|
||||
vision_feature_layer=-1,
|
||||
image_token_index=1,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
text_encoder = Mistral3ForConditionalGeneration(config)
|
||||
tokenizer = AutoProcessor.from_pretrained(
|
||||
"hf-internal-testing/Mistral-Small-3.1-24B-Instruct-2503-only-processor"
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKLFlux2(
|
||||
sample_size=32,
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=("DownEncoderBlock2D",),
|
||||
up_block_types=("UpDecoderBlock2D",),
|
||||
block_out_channels=(4,),
|
||||
layers_per_block=1,
|
||||
latent_channels=1,
|
||||
norm_num_groups=1,
|
||||
use_quant_conv=False,
|
||||
use_post_quant_conv=False,
|
||||
)
|
||||
|
||||
scheduler = FlowMatchEulerDiscreteScheduler()
|
||||
|
||||
return {
|
||||
"scheduler": scheduler,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"transformer": transformer,
|
||||
"vae": vae,
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device="cpu").manual_seed(seed)
|
||||
|
||||
inputs = {
|
||||
"prompt": "a dog is dancing",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 5.0,
|
||||
"height": 8,
|
||||
"width": 8,
|
||||
"max_sequence_length": 8,
|
||||
"output_type": "np",
|
||||
"text_encoder_out_layers": (1,),
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_fused_qkv_projections(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = pipe(**inputs).images
|
||||
original_image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
|
||||
# to the pipeline level.
|
||||
pipe.transformer.fuse_qkv_projections()
|
||||
self.assertTrue(
|
||||
check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
|
||||
("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
|
||||
)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = pipe(**inputs).images
|
||||
image_slice_fused = image[0, -3:, -3:, -1]
|
||||
|
||||
pipe.transformer.unfuse_qkv_projections()
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = pipe(**inputs).images
|
||||
image_slice_disabled = image[0, -3:, -3:, -1]
|
||||
|
||||
self.assertTrue(
|
||||
np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3),
|
||||
("Fusion of QKV projections shouldn't affect the outputs."),
|
||||
)
|
||||
self.assertTrue(
|
||||
np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3),
|
||||
("Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."),
|
||||
)
|
||||
self.assertTrue(
|
||||
np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2),
|
||||
("Original outputs should match when fused QKV projections are disabled."),
|
||||
)
|
||||
|
||||
def test_flux_image_output_shape(self):
|
||||
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
height_width_pairs = [(32, 32), (72, 57)]
|
||||
for height, width in height_width_pairs:
|
||||
expected_height = height - height % (pipe.vae_scale_factor * 2)
|
||||
expected_width = width - width % (pipe.vae_scale_factor * 2)
|
||||
|
||||
inputs.update({"height": height, "width": width})
|
||||
image = pipe(**inputs).images[0]
|
||||
output_height, output_width, _ = image.shape
|
||||
self.assertEqual(
|
||||
(output_height, output_width),
|
||||
(expected_height, expected_width),
|
||||
f"Output shape {image.shape} does not match expected shape {(expected_height, expected_width)}",
|
||||
)
|
||||
@@ -103,7 +103,7 @@ def check_qkv_fusion_processors_exist(model):
|
||||
def check_qkv_fused_layers_exist(model, layer_names):
|
||||
is_fused_submodules = []
|
||||
for submodule in model.modules():
|
||||
if not isinstance(submodule, AttentionModuleMixin):
|
||||
if not isinstance(submodule, AttentionModuleMixin) or not submodule._supports_qkv_fusion:
|
||||
continue
|
||||
is_fused_attribute_set = submodule.fused_projections
|
||||
is_fused_layer = True
|
||||
|
||||
0
tests/pipelines/z_image/__init__.py
Normal file
0
tests/pipelines/z_image/__init__.py
Normal file
Reference in New Issue
Block a user