Compare commits

...

117 Commits

Author SHA1 Message Date
DN6
d267bb6955 update 2025-06-14 01:20:39 +05:30
DN6
e10f701537 update 2025-06-14 00:51:24 +05:30
DN6
0497faa3db update 2025-06-14 00:33:39 +05:30
DN6
4f00bae5de update 2025-06-14 00:31:33 +05:30
DN6
a967e66d03 update 2025-06-14 00:28:56 +05:30
DN6
2b559e9b79 Merge branch 'chroma-fork' into chroma-final 2025-06-14 00:27:54 +05:30
DN6
589e939e33 Revert "fix equal size list input"
This reverts commit 3fe4ad67d5.
2025-06-14 00:17:17 +05:30
BuildTools
c711e8f10b fix equal size list input 2025-06-14 00:17:17 +05:30
BuildTools
0978b609c8 fix tests 2025-06-14 00:17:17 +05:30
BuildTools
4e24f26d6f default proj dim 2025-06-14 00:17:17 +05:30
BuildTools
8694f2ce53 add encoder test, remove pooled dim 2025-06-14 00:17:17 +05:30
BuildTools
fd3e94450a push local changes, fix docs 2025-06-14 00:17:17 +05:30
Dhruv Nair
41751a3ec0 update 2025-06-13 20:41:49 +02:00
BuildTools
3fe4ad67d5 fix equal size list input 2025-06-13 10:51:31 -06:00
BuildTools
49a4c8bc22 fix tests 2025-06-13 09:41:44 -06:00
BuildTools
06fb9957a7 default proj dim 2025-06-13 08:38:02 -06:00
BuildTools
16b6e33916 add encoder test, remove pooled dim 2025-06-13 08:11:12 -06:00
BuildTools
178c4ec928 push local changes, fix docs 2025-06-13 07:46:29 -06:00
DN6
292469d755 update 2025-06-13 18:43:26 +05:30
DN6
bf56c953b8 Merge branch 'chroma-fork' into chroma-final 2025-06-13 18:41:56 +05:30
BuildTools
b85229e262 Fix all pipeline test 2025-06-13 07:07:05 -06:00
DN6
f1be3ebc98 Merge branch 'chroma-fork' into chroma-final 2025-06-13 18:30:13 +05:30
Dhruv Nair
6735507705 fix for tests 2025-06-13 14:51:42 +02:00
BuildTools
de9a07fc20 fix test skipping again 2025-06-13 05:47:41 -06:00
BuildTools
2b6722ecea fix test skipping 2025-06-13 05:35:58 -06:00
BuildTools
00ebba9725 skip batch tests 2025-06-13 05:25:11 -06:00
BuildTools
bea8b0d86e make style, make quality 2025-06-13 04:54:33 -06:00
BuildTools
28dea06b3d fix docs 2025-06-13 04:53:30 -06:00
BuildTools
60e41b7835 Merge remote-tracking branch 'origin/chroma' into chroma 2025-06-13 04:43:48 -06:00
BuildTools
876649336e Make most transformer tests work 2025-06-13 04:43:31 -06:00
Edna
272685c0e5 Merge branch 'main' into chroma 2025-06-13 04:38:42 -06:00
BuildTools
829c6f199e Make more pipeline tests work 2025-06-13 04:38:13 -06:00
Dhruv Nair
89faa71f04 fix batch inference 2025-06-13 12:31:11 +02:00
DN6
926dcc6319 update to pad tokens 2025-06-13 13:43:17 +05:30
DN6
74fe45e823 update chroma transformer approximator init params 2025-06-13 13:36:39 +05:30
DN6
35dc65b7da update chroma transformer params 2025-06-13 13:30:04 +05:30
DN6
f35ec17a83 Merge remote-tracking branch '11698/chroma' into chroma-final 2025-06-13 11:28:57 +05:30
BuildTools
381e64b966 revert style fix 2025-06-12 22:22:39 -06:00
BuildTools
c330f08fa2 make fix-copes 2025-06-12 21:53:55 -06:00
BuildTools
523150fb2c fix import 2025-06-12 21:47:35 -06:00
BuildTools
2bc51c8387 try to fix import 2025-06-12 21:36:09 -06:00
BuildTools
fd36924620 remove # Copied from on protected members 2025-06-12 21:20:32 -06:00
Edna
e97a4dd0c7 fix # Copied from error 2025-06-12 21:13:12 -06:00
Edna
ad01d636be Merge branch 'main' into chroma 2025-06-12 21:06:55 -06:00
BuildTools
68b9cce897 switch to new input ids 2025-06-12 21:06:43 -06:00
github-actions[bot]
f49b149c1c Apply style fixes 2025-06-13 02:02:25 +00:00
DN6
19733af2fc make style 2025-06-13 07:22:45 +05:30
BuildTools
c85e46bd42 Fix auto pipeline + make style, quality 2025-06-12 10:31:02 -06:00
BuildTools
d31cf81566 Move Approximator and Embeddings 2025-06-12 10:20:27 -06:00
Edna
2347d53f90 Update src/diffusers/models/transformers/transformer_chroma.py
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2025-06-12 10:12:27 -06:00
Edna
cfd5b34051 fix chroma pipeline fast tests 2025-06-12 03:49:39 -06:00
Edna
c8d6aef936 chroma init 2025-06-12 03:47:24 -06:00
Edna
f8d4a1a774 move chroma test (oops) 2025-06-12 03:46:28 -06:00
Edna
15ca813e3e Add transformer tests 2025-06-12 03:45:43 -06:00
Edna
7235805e75 Revert cond + uncond batching 2025-06-12 03:40:52 -06:00
Edna
abf8a33a96 update norm imports 2025-06-12 03:33:23 -06:00
Edna
6a0db55af8 Update # Copied from statements 2025-06-12 03:27:35 -06:00
Edna
fe5af79a19 Add # Copied from for shift 2025-06-12 03:23:09 -06:00
Edna
bedb32087a (untested) batch cond and uncond 2025-06-12 03:18:33 -06:00
Edna
03fbd520f4 Add chroma fast tests 2025-06-12 03:11:48 -06:00
Edna
1442c9789a Remove pruned AdaLayerNorms 2025-06-12 03:05:10 -06:00
Edna
a1fac68a2d Move chroma layers into transformer 2025-06-12 03:04:41 -06:00
Edna
3e36a21c8e Update chroma.md 2025-06-12 02:58:21 -06:00
Edna
a93e64d6fb Update docs/source/en/api/models/chroma_transformer.md
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2025-06-12 02:57:28 -06:00
Edna
3f39b1a730 do_cfg -> self.do_classifier_free_guidance 2025-06-12 02:56:24 -06:00
Edna
18327cb57c Update docs/source/en/api/pipelines/chroma.md
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2025-06-12 02:52:39 -06:00
Edna
da846d1fff fix hf papers regression in more places 2025-06-12 00:53:40 -06:00
Edna
42c0e8ecbe undo arxiv change
unsure why that happened
2025-06-12 00:50:36 -06:00
Edna
0c5eb44701 undo don't change dtype 2025-06-12 00:46:41 -06:00
Edna
b0cf6803a7 initial chroma docs 2025-06-11 22:07:21 -06:00
Edna
f821f2ad5e add .md (oops) 2025-06-11 21:54:43 -06:00
Edna
619921ca22 add chroma autodoc 2025-06-11 21:53:27 -06:00
BuildTools
1efa772f69 remove unused stuff, fix up docs 2025-06-11 21:46:40 -06:00
Edna
3e2452ded0 dont change dtype 2025-06-11 21:23:35 -06:00
Edna
2d57f3dbac Merge branch 'main' into chroma 2025-06-11 21:20:24 -06:00
Edna
1bd8fdfcb6 don't return length 2025-06-11 20:56:27 -06:00
Edna
406ab3b1e9 remove guidance from embeddings 2025-06-11 20:47:59 -06:00
Edna
e31c94866d remove guidance embed (pipeline) 2025-06-11 20:46:59 -06:00
Edna
01bc0dcc56 remove guidance 2025-06-11 20:45:45 -06:00
Edna
e69d73099d use DN6 embeddings 2025-06-11 20:05:28 -06:00
Edna
442f77a2d7 use chroma pipeline output 2025-06-11 19:59:43 -06:00
Edna
ab7942174a use dn6 attn mask + fix true_cfg_scale 2025-06-11 19:57:31 -06:00
Edna
f6de1afc3f update 2025-06-11 19:54:27 -06:00
Edna
f783f38883 ensure correct dtype for chroma embeddings 2025-06-11 19:52:43 -06:00
Edna
a3b6697bc3 Merge branch 'main' into chroma 2025-06-11 19:48:02 -06:00
Edna
68f771bf43 take pooled projections out of transformer 2025-06-11 19:38:38 -06:00
Edna
df7fde7a6d fix load 2025-06-11 19:36:34 -06:00
Edna
77b429eda4 change to my own unpooled embeddeer 2025-06-11 19:35:10 -06:00
Edna
3309ffef1c remove pooled prompt embeds 2025-06-11 19:33:17 -06:00
Edna
146255aba1 no attn mask (can't get it to work) 2025-06-11 19:17:29 -06:00
Edna
c9b46af65f wrap attn mask 2025-06-11 19:16:24 -06:00
Edna
7c75d8e98d dont modify mask (for now) 2025-06-11 19:15:18 -06:00
Edna
38429ffcac remove mask function 2025-06-11 19:11:47 -06:00
Edna
f190c02af7 work on swapping text encoders 2025-06-11 19:09:37 -06:00
Edna
6c0aed14db remove prompt_2 2025-06-11 19:06:45 -06:00
Edna
0b027a2453 swap embedder location 2025-06-11 19:04:52 -06:00
Edna
2fcc75a6d8 take out variant from blocks 2025-06-11 18:55:56 -06:00
Edna
af918c89dd change to chroma transformer 2025-06-11 18:55:03 -06:00
Edna
7445cf422a add chroma to pipeline init 2025-06-11 18:53:06 -06:00
Edna
a6f231c7ce add chroma to auto pipeline 2025-06-11 18:51:45 -06:00
Edna
6441e70def update 2025-06-11 18:48:44 -06:00
Edna
f0c75b6b6f update 2025-06-11 18:46:51 -06:00
Edna
5eb4b822ae fix single file 2025-06-11 18:38:58 -06:00
Edna
4e698b1088 add chroma to init 2025-06-11 18:21:10 -06:00
Edna
c22930d7cc add chroma to init 2025-06-11 18:18:56 -06:00
Edna
7400278857 add chroma transformer to dummy tp 2025-06-11 18:16:44 -06:00
Edna
32659236b2 make chroma output class 2025-06-10 02:24:23 -06:00
Edna
c8cbb31614 add chroma init 2025-06-10 02:22:52 -06:00
Edna
b0df9691d2 get decently far in changing variant stuff 2025-06-10 02:09:52 -06:00
Edna
22ecd19f91 take out variant stuff 2025-06-09 21:32:52 -06:00
Edna
33ea0b65a4 add chroma to transformer init 2025-06-09 21:25:19 -06:00
Edna
bc36a0d883 add chroma to mappings 2025-06-09 21:15:19 -06:00
Edna
32e6a006cf add chroma loader 2025-06-09 21:13:32 -06:00
Edna
15f2bd5c39 working state (embeddings) 2025-06-09 21:05:59 -06:00
Edna
e271af9495 working state (normalization) 2025-06-09 21:03:10 -06:00
Edna
3c2865c534 working state form hameerabbasi and iddl (transformer) 2025-06-09 21:02:12 -06:00
Edna
ff0b9a3c4c working state from hameerabbasi and iddl 2025-06-09 20:59:00 -06:00
23 changed files with 2336 additions and 5 deletions

View File

@@ -283,6 +283,8 @@
title: AllegroTransformer3DModel
- local: api/models/aura_flow_transformer2d
title: AuraFlowTransformer2DModel
- local: api/models/chroma_transformer
title: ChromaTransformer2DModel
- local: api/models/cogvideox_transformer3d
title: CogVideoXTransformer3DModel
- local: api/models/cogview3plus_transformer2d
@@ -405,6 +407,8 @@
title: AutoPipeline
- local: api/pipelines/blip_diffusion
title: BLIP-Diffusion
- local: api/pipelines/chroma
title: Chroma
- local: api/pipelines/cogvideox
title: CogVideoX
- local: api/pipelines/cogview3

View File

@@ -0,0 +1,19 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# ChromaTransformer2DModel
A modified flux Transformer model from [Chroma](https://huggingface.co/lodestones/Chroma)
## ChromaTransformer2DModel
[[autodoc]] ChromaTransformer2DModel

View File

@@ -0,0 +1,71 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Chroma
<div class="flex flex-wrap space-x-1">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
<img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22">
</div>
Chroma is a text to image generation model based on Flux.
Original model checkpoints for Chroma can be found [here](https://huggingface.co/lodestones/Chroma).
<Tip>
Chroma can use all the same optimizations as Flux.
</Tip>
## Inference (Single File)
The `ChromaTransformer2DModel` supports loading checkpoints in the original format. This is also useful when trying to load finetunes or quantized versions of the models that have been published by the community.
The following example demonstrates how to run Chroma from a single file.
Then run the following example
```python
import torch
from diffusers import ChromaTransformer2DModel, ChromaPipeline
from transformers import T5EncoderModel
bfl_repo = "black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16
transformer = ChromaTransformer2DModel.from_single_file("https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v35.safetensors", torch_dtype=dtype)
text_encoder = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
tokenizer = T5Tokenizer.from_pretrained(bfl_repo, subfolder="tokenizer_2", torch_dtype=dtype)
pipe = ChromaPipeline.from_pretrained(bfl_repo, transformer=transformer, text_encoder=text_encoder, tokenizer=tokenizer, torch_dtype=dtype)
pipe.enable_model_cpu_offload()
prompt = "A cat holding a sign that says hello world"
image = pipe(
prompt,
guidance_scale=4.0,
output_type="pil",
num_inference_steps=26,
generator=torch.Generator("cpu").manual_seed(0)
).images[0]
image.save("image.png")
```
## ChromaPipeline
[[autodoc]] ChromaPipeline
- all
- __call__

View File

@@ -159,6 +159,7 @@ else:
"AutoencoderTiny",
"AutoModel",
"CacheMixin",
"ChromaTransformer2DModel",
"CogVideoXTransformer3DModel",
"CogView3PlusTransformer2DModel",
"CogView4Transformer2DModel",
@@ -352,6 +353,7 @@ else:
"AuraFlowPipeline",
"BlipDiffusionControlNetPipeline",
"BlipDiffusionPipeline",
"ChromaPipeline",
"CLIPImageProjection",
"CogVideoXFunControlPipeline",
"CogVideoXImageToVideoPipeline",
@@ -768,6 +770,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderTiny,
AutoModel,
CacheMixin,
ChromaTransformer2DModel,
CogVideoXTransformer3DModel,
CogView3PlusTransformer2DModel,
CogView4Transformer2DModel,
@@ -940,6 +943,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AudioLDM2UNet2DConditionModel,
AudioLDMPipeline,
AuraFlowPipeline,
ChromaPipeline,
CLIPImageProjection,
CogVideoXFunControlPipeline,
CogVideoXImageToVideoPipeline,

View File

@@ -60,6 +60,7 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
"HiDreamImageTransformer2DModel": lambda model_cls, weights: weights,
"HunyuanVideoFramepackTransformer3DModel": lambda model_cls, weights: weights,
"WanVACETransformer3DModel": lambda model_cls, weights: weights,
"ChromaTransformer2DModel": lambda model_cls, weights: weights,
}

View File

@@ -29,6 +29,7 @@ from .single_file_utils import (
convert_animatediff_checkpoint_to_diffusers,
convert_auraflow_transformer_checkpoint_to_diffusers,
convert_autoencoder_dc_checkpoint_to_diffusers,
convert_chroma_transformer_checkpoint_to_diffusers,
convert_controlnet_checkpoint,
convert_flux_transformer_checkpoint_to_diffusers,
convert_hidream_transformer_to_diffusers,
@@ -97,6 +98,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
"checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
"ChromaTransformer2DModel": {
"checkpoint_mapping_fn": convert_chroma_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
"LTXVideoTransformer3DModel": {
"checkpoint_mapping_fn": convert_ltx_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",

View File

@@ -3310,3 +3310,172 @@ def convert_hidream_transformer_to_diffusers(checkpoint, **kwargs):
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
return checkpoint
def convert_chroma_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
converted_state_dict = {}
keys = list(checkpoint.keys())
for k in keys:
if "model.diffusion_model." in k:
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1 # noqa: C401
num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1 # noqa: C401
num_guidance_layers = (
list(set(int(k.split(".", 3)[2]) for k in checkpoint if "distilled_guidance_layer.layers." in k))[-1] + 1 # noqa: C401
)
mlp_ratio = 4.0
inner_dim = 3072
# in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale;
# while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use diffusers implementation
def swap_scale_shift(weight):
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
return new_weight
# guidance
converted_state_dict["distilled_guidance_layer.in_proj.bias"] = checkpoint.pop(
"distilled_guidance_layer.in_proj.bias"
)
converted_state_dict["distilled_guidance_layer.in_proj.weight"] = checkpoint.pop(
"distilled_guidance_layer.in_proj.weight"
)
converted_state_dict["distilled_guidance_layer.out_proj.bias"] = checkpoint.pop(
"distilled_guidance_layer.out_proj.bias"
)
converted_state_dict["distilled_guidance_layer.out_proj.weight"] = checkpoint.pop(
"distilled_guidance_layer.out_proj.weight"
)
for i in range(num_guidance_layers):
block_prefix = f"distilled_guidance_layer.layers.{i}."
converted_state_dict[f"{block_prefix}linear_1.bias"] = checkpoint.pop(
f"distilled_guidance_layer.layers.{i}.in_layer.bias"
)
converted_state_dict[f"{block_prefix}linear_1.weight"] = checkpoint.pop(
f"distilled_guidance_layer.layers.{i}.in_layer.weight"
)
converted_state_dict[f"{block_prefix}linear_2.bias"] = checkpoint.pop(
f"distilled_guidance_layer.layers.{i}.out_layer.bias"
)
converted_state_dict[f"{block_prefix}linear_2.weight"] = checkpoint.pop(
f"distilled_guidance_layer.layers.{i}.out_layer.weight"
)
converted_state_dict[f"distilled_guidance_layer.norms.{i}.weight"] = checkpoint.pop(
f"distilled_guidance_layer.norms.{i}.scale"
)
# context_embedder
converted_state_dict["context_embedder.weight"] = checkpoint.pop("txt_in.weight")
converted_state_dict["context_embedder.bias"] = checkpoint.pop("txt_in.bias")
# x_embedder
converted_state_dict["x_embedder.weight"] = checkpoint.pop("img_in.weight")
converted_state_dict["x_embedder.bias"] = checkpoint.pop("img_in.bias")
# double transformer blocks
for i in range(num_layers):
block_prefix = f"transformer_blocks.{i}."
# Q, K, V
sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0)
context_q, context_k, context_v = torch.chunk(
checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0
)
sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0
)
context_q_bias, context_k_bias, context_v_bias = torch.chunk(
checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0
)
converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([sample_q])
converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([sample_q_bias])
converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([sample_k])
converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([sample_k_bias])
converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([sample_v])
converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([sample_v_bias])
converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = torch.cat([context_q])
converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = torch.cat([context_q_bias])
converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = torch.cat([context_k])
converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = torch.cat([context_k_bias])
converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = torch.cat([context_v])
converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = torch.cat([context_v_bias])
# qk_norm
converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
f"double_blocks.{i}.img_attn.norm.query_norm.scale"
)
converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
f"double_blocks.{i}.img_attn.norm.key_norm.scale"
)
converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = checkpoint.pop(
f"double_blocks.{i}.txt_attn.norm.query_norm.scale"
)
converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = checkpoint.pop(
f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
)
# ff img_mlp
converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = checkpoint.pop(
f"double_blocks.{i}.img_mlp.0.weight"
)
converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.0.bias")
converted_state_dict[f"{block_prefix}ff.net.2.weight"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.weight")
converted_state_dict[f"{block_prefix}ff.net.2.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.bias")
converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = checkpoint.pop(
f"double_blocks.{i}.txt_mlp.0.weight"
)
converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = checkpoint.pop(
f"double_blocks.{i}.txt_mlp.0.bias"
)
converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = checkpoint.pop(
f"double_blocks.{i}.txt_mlp.2.weight"
)
converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = checkpoint.pop(
f"double_blocks.{i}.txt_mlp.2.bias"
)
# output projections.
converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = checkpoint.pop(
f"double_blocks.{i}.img_attn.proj.weight"
)
converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = checkpoint.pop(
f"double_blocks.{i}.img_attn.proj.bias"
)
converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = checkpoint.pop(
f"double_blocks.{i}.txt_attn.proj.weight"
)
converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = checkpoint.pop(
f"double_blocks.{i}.txt_attn.proj.bias"
)
# single transformer blocks
for i in range(num_single_layers):
block_prefix = f"single_transformer_blocks.{i}."
# Q, K, V, mlp
mlp_hidden_dim = int(inner_dim * mlp_ratio)
split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
q, k, v, mlp = torch.split(checkpoint.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0)
q_bias, k_bias, v_bias, mlp_bias = torch.split(
checkpoint.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0
)
converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q])
converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias])
converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k])
converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias])
converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v])
converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias])
converted_state_dict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp])
converted_state_dict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias])
# qk norm
converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
f"single_blocks.{i}.norm.query_norm.scale"
)
converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
f"single_blocks.{i}.norm.key_norm.scale"
)
# output projections.
converted_state_dict[f"{block_prefix}proj_out.weight"] = checkpoint.pop(f"single_blocks.{i}.linear2.weight")
converted_state_dict[f"{block_prefix}proj_out.bias"] = checkpoint.pop(f"single_blocks.{i}.linear2.bias")
converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
return converted_state_dict

View File

@@ -74,6 +74,7 @@ if is_torch_available():
_import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
_import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
_import_structure["transformers.transformer_chroma"] = ["ChromaTransformer2DModel"]
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
_import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"]
_import_structure["transformers.transformer_cosmos"] = ["CosmosTransformer3DModel"]
@@ -151,6 +152,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .transformers import (
AllegroTransformer3DModel,
AuraFlowTransformer2DModel,
ChromaTransformer2DModel,
CogVideoXTransformer3DModel,
CogView3PlusTransformer2DModel,
CogView4Transformer2DModel,

View File

@@ -31,7 +31,7 @@ def get_timestep_embedding(
downscale_freq_shift: float = 1,
scale: float = 1,
max_period: int = 10000,
):
) -> torch.Tensor:
"""
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
@@ -1325,7 +1325,7 @@ class Timesteps(nn.Module):
self.downscale_freq_shift = downscale_freq_shift
self.scale = scale
def forward(self, timesteps):
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
t_emb = get_timestep_embedding(
timesteps,
self.num_channels,

View File

@@ -17,6 +17,7 @@ if is_torch_available():
from .t5_film_transformer import T5FilmDecoder
from .transformer_2d import Transformer2DModel
from .transformer_allegro import AllegroTransformer3DModel
from .transformer_chroma import ChromaTransformer2DModel
from .transformer_cogview3plus import CogView3PlusTransformer2DModel
from .transformer_cogview4 import CogView4Transformer2DModel
from .transformer_cosmos import CosmosTransformer3DModel

View File

@@ -0,0 +1,732 @@
# Copyright 2025 Black Forest Labs, The HuggingFace Team and loadstone-rock . All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.import_utils import is_torch_npu_available
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import (
Attention,
AttentionProcessor,
FluxAttnProcessor2_0,
FluxAttnProcessor2_0_NPU,
FusedFluxAttnProcessor2_0,
)
from ..cache_utils import CacheMixin
from ..embeddings import FluxPosEmbed, PixArtAlphaTextProjection, Timesteps, get_timestep_embedding
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import CombinedTimestepLabelEmbeddings, FP32LayerNorm, RMSNorm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class ChromaAdaLayerNormZeroPruned(nn.Module):
r"""
Norm layer adaptive layer norm zero (adaLN-Zero).
Parameters:
embedding_dim (`int`): The size of each embedding vector.
num_embeddings (`int`): The size of the embeddings dictionary.
"""
def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None, norm_type="layer_norm", bias=True):
super().__init__()
if num_embeddings is not None:
self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
else:
self.emb = None
if norm_type == "layer_norm":
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
elif norm_type == "fp32_layer_norm":
self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False)
else:
raise ValueError(
f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
)
def forward(
self,
x: torch.Tensor,
timestep: Optional[torch.Tensor] = None,
class_labels: Optional[torch.LongTensor] = None,
hidden_dtype: Optional[torch.dtype] = None,
emb: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
if self.emb is not None:
emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.flatten(1, 2).chunk(6, dim=1)
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
class ChromaAdaLayerNormZeroSinglePruned(nn.Module):
r"""
Norm layer adaptive layer norm zero (adaLN-Zero).
Parameters:
embedding_dim (`int`): The size of each embedding vector.
num_embeddings (`int`): The size of the embeddings dictionary.
"""
def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
super().__init__()
if norm_type == "layer_norm":
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
else:
raise ValueError(
f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
)
def forward(
self,
x: torch.Tensor,
emb: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
shift_msa, scale_msa, gate_msa = emb.flatten(1, 2).chunk(3, dim=1)
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
return x, gate_msa
class ChromaAdaLayerNormContinuousPruned(nn.Module):
r"""
Adaptive normalization layer with a norm layer (layer_norm or rms_norm).
Args:
embedding_dim (`int`): Embedding dimension to use during projection.
conditioning_embedding_dim (`int`): Dimension of the input condition.
elementwise_affine (`bool`, defaults to `True`):
Boolean flag to denote if affine transformation should be applied.
eps (`float`, defaults to 1e-5): Epsilon factor.
bias (`bias`, defaults to `True`): Boolean flag to denote if bias should be use.
norm_type (`str`, defaults to `"layer_norm"`):
Normalization layer to use. Values supported: "layer_norm", "rms_norm".
"""
def __init__(
self,
embedding_dim: int,
conditioning_embedding_dim: int,
# NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
# because the output is immediately scaled and shifted by the projected conditioning embeddings.
# Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
# However, this is how it was implemented in the original code, and it's rather likely you should
# set `elementwise_affine` to False.
elementwise_affine=True,
eps=1e-5,
bias=True,
norm_type="layer_norm",
):
super().__init__()
if norm_type == "layer_norm":
self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias)
elif norm_type == "rms_norm":
self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
else:
raise ValueError(f"unknown norm_type {norm_type}")
def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
# convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT)
shift, scale = torch.chunk(emb.flatten(1, 2).to(x.dtype), 2, dim=1)
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
return x
class ChromaCombinedTimestepTextProjEmbeddings(nn.Module):
def __init__(self, num_channels: int, out_dim: int):
super().__init__()
self.time_proj = Timesteps(num_channels=num_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
self.guidance_proj = Timesteps(num_channels=num_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
self.register_buffer(
"mod_proj",
get_timestep_embedding(
torch.arange(out_dim) * 1000, 2 * num_channels, flip_sin_to_cos=True, downscale_freq_shift=0
),
persistent=False,
)
def forward(self, timestep: torch.Tensor) -> torch.Tensor:
mod_index_length = self.mod_proj.shape[0]
batch_size = timestep.shape[0]
timesteps_proj = self.time_proj(timestep).to(dtype=timestep.dtype)
guidance_proj = self.guidance_proj(torch.tensor([0] * batch_size)).to(
dtype=timestep.dtype, device=timestep.device
)
mod_proj = self.mod_proj.to(dtype=timesteps_proj.dtype, device=timesteps_proj.device).repeat(batch_size, 1, 1)
timestep_guidance = (
torch.cat([timesteps_proj, guidance_proj], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1)
)
input_vec = torch.cat([timestep_guidance, mod_proj], dim=-1)
return input_vec.to(timestep.dtype)
class ChromaApproximator(nn.Module):
def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers: int = 5):
super().__init__()
self.in_proj = nn.Linear(in_dim, hidden_dim, bias=True)
self.layers = nn.ModuleList(
[PixArtAlphaTextProjection(hidden_dim, hidden_dim, act_fn="silu") for _ in range(n_layers)]
)
self.norms = nn.ModuleList([nn.RMSNorm(hidden_dim) for _ in range(n_layers)])
self.out_proj = nn.Linear(hidden_dim, out_dim)
def forward(self, x):
x = self.in_proj(x)
for layer, norms in zip(self.layers, self.norms):
x = x + layer(norms(x))
return self.out_proj(x)
@maybe_allow_in_graph
class ChromaSingleTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
mlp_ratio: float = 4.0,
):
super().__init__()
self.mlp_hidden_dim = int(dim * mlp_ratio)
self.norm = ChromaAdaLayerNormZeroSinglePruned(dim)
self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
self.act_mlp = nn.GELU(approximate="tanh")
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
if is_torch_npu_available():
deprecation_message = (
"Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors "
"should be set explicitly using the `set_attn_processor` method."
)
deprecate("npu_processor", "0.34.0", deprecation_message)
processor = FluxAttnProcessor2_0_NPU()
else:
processor = FluxAttnProcessor2_0()
self.attn = Attention(
query_dim=dim,
cross_attention_dim=None,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
bias=True,
processor=processor,
qk_norm="rms_norm",
eps=1e-6,
pre_only=True,
)
def forward(
self,
hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
residual = hidden_states
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
joint_attention_kwargs = joint_attention_kwargs or {}
attn_output = self.attn(
hidden_states=norm_hidden_states,
image_rotary_emb=image_rotary_emb,
**joint_attention_kwargs,
)
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
gate = gate.unsqueeze(1)
hidden_states = gate * self.proj_out(hidden_states)
hidden_states = residual + hidden_states
if hidden_states.dtype == torch.float16:
hidden_states = hidden_states.clip(-65504, 65504)
return hidden_states
@maybe_allow_in_graph
class ChromaTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
qk_norm: str = "rms_norm",
eps: float = 1e-6,
):
super().__init__()
self.norm1 = ChromaAdaLayerNormZeroPruned(dim)
self.norm1_context = ChromaAdaLayerNormZeroPruned(dim)
self.attn = Attention(
query_dim=dim,
cross_attention_dim=None,
added_kv_proj_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
context_pre_only=False,
bias=True,
processor=FluxAttnProcessor2_0(),
qk_norm=qk_norm,
eps=eps,
)
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
temb_img, temb_txt = temb[:, :6], temb[:, 6:]
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb_img)
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
encoder_hidden_states, emb=temb_txt
)
joint_attention_kwargs = joint_attention_kwargs or {}
# Attention.
attention_outputs = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
**joint_attention_kwargs,
)
if len(attention_outputs) == 2:
attn_output, context_attn_output = attention_outputs
elif len(attention_outputs) == 3:
attn_output, context_attn_output, ip_attn_output = attention_outputs
# Process attention outputs for the `hidden_states`.
attn_output = gate_msa.unsqueeze(1) * attn_output
hidden_states = hidden_states + attn_output
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
ff_output = self.ff(norm_hidden_states)
ff_output = gate_mlp.unsqueeze(1) * ff_output
hidden_states = hidden_states + ff_output
if len(attention_outputs) == 3:
hidden_states = hidden_states + ip_attn_output
# Process attention outputs for the `encoder_hidden_states`.
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
encoder_hidden_states = encoder_hidden_states + context_attn_output
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
context_ff_output = self.ff_context(norm_encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
if encoder_hidden_states.dtype == torch.float16:
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
return encoder_hidden_states, hidden_states
class ChromaTransformer2DModel(
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin
):
"""
The Transformer model introduced in Flux, modified for Chroma.
Reference: https://huggingface.co/lodestones/Chroma
Args:
patch_size (`int`, defaults to `1`):
Patch size to turn the input data into small patches.
in_channels (`int`, defaults to `64`):
The number of channels in the input.
out_channels (`int`, *optional*, defaults to `None`):
The number of channels in the output. If not specified, it defaults to `in_channels`.
num_layers (`int`, defaults to `19`):
The number of layers of dual stream DiT blocks to use.
num_single_layers (`int`, defaults to `38`):
The number of layers of single stream DiT blocks to use.
attention_head_dim (`int`, defaults to `128`):
The number of dimensions to use for each attention head.
num_attention_heads (`int`, defaults to `24`):
The number of attention heads to use.
joint_attention_dim (`int`, defaults to `4096`):
The number of dimensions to use for the joint attention (embedding/channel dimension of
`encoder_hidden_states`).
axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
The dimensions to use for the rotary positional embeddings.
"""
_supports_gradient_checkpointing = True
_no_split_modules = ["ChromaTransformerBlock", "ChromaSingleTransformerBlock"]
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
@register_to_config
def __init__(
self,
patch_size: int = 1,
in_channels: int = 64,
out_channels: Optional[int] = None,
num_layers: int = 19,
num_single_layers: int = 38,
attention_head_dim: int = 128,
num_attention_heads: int = 24,
joint_attention_dim: int = 4096,
axes_dims_rope: Tuple[int, ...] = (16, 56, 56),
approximator_num_channels: int = 64,
approximator_hidden_dim: int = 5120,
approximator_layers: int = 5,
):
super().__init__()
self.out_channels = out_channels or in_channels
self.inner_dim = num_attention_heads * attention_head_dim
self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
self.time_text_embed = ChromaCombinedTimestepTextProjEmbeddings(
num_channels=approximator_num_channels // 4,
out_dim=3 * num_single_layers + 2 * 6 * num_layers + 2,
)
self.distilled_guidance_layer = ChromaApproximator(
in_dim=approximator_num_channels,
out_dim=self.inner_dim,
hidden_dim=approximator_hidden_dim,
n_layers=approximator_layers,
)
self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
self.x_embedder = nn.Linear(in_channels, self.inner_dim)
self.transformer_blocks = nn.ModuleList(
[
ChromaTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
)
for _ in range(num_layers)
]
)
self.single_transformer_blocks = nn.ModuleList(
[
ChromaSingleTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
)
for _ in range(num_single_layers)
]
)
self.norm_out = ChromaAdaLayerNormContinuousPruned(
self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6
)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
self.gradient_checkpointing = False
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
self.original_attn_processors = None
for _, attn_processor in self.attn_processors.items():
if "Added" in str(attn_processor.__class__.__name__):
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
self.original_attn_processors = self.attn_processors
for module in self.modules():
if isinstance(module, Attention):
module.fuse_projections(fuse=True)
self.set_attn_processor(FusedFluxAttnProcessor2_0())
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
timestep: torch.LongTensor = None,
img_ids: torch.Tensor = None,
txt_ids: torch.Tensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_block_samples=None,
controlnet_single_block_samples=None,
return_dict: bool = True,
controlnet_blocks_repeat: bool = False,
) -> Union[torch.Tensor, Transformer2DModelOutput]:
"""
The [`FluxTransformer2DModel`] forward method.
Args:
hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
Input `hidden_states`.
encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
A list of tensors that if specified are added to the residuals of transformer blocks.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
hidden_states = self.x_embedder(hidden_states)
timestep = timestep.to(hidden_states.dtype) * 1000
input_vec = self.time_text_embed(timestep)
pooled_temb = self.distilled_guidance_layer(input_vec)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
if txt_ids.ndim == 3:
logger.warning(
"Passing `txt_ids` 3d torch.Tensor is deprecated."
"Please remove the batch dimension and pass it as a 2d torch Tensor"
)
txt_ids = txt_ids[0]
if img_ids.ndim == 3:
logger.warning(
"Passing `img_ids` 3d torch.Tensor is deprecated."
"Please remove the batch dimension and pass it as a 2d torch Tensor"
)
img_ids = img_ids[0]
ids = torch.cat((txt_ids, img_ids), dim=0)
image_rotary_emb = self.pos_embed(ids)
if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
for index_block, block in enumerate(self.transformer_blocks):
img_offset = 3 * len(self.single_transformer_blocks)
txt_offset = img_offset + 6 * len(self.transformer_blocks)
img_modulation = img_offset + 6 * index_block
text_modulation = txt_offset + 6 * index_block
temb = torch.cat(
(
pooled_temb[:, img_modulation : img_modulation + 6],
pooled_temb[:, text_modulation : text_modulation + 6],
),
dim=1,
)
if torch.is_grad_enabled() and self.gradient_checkpointing:
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
)
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
)
# controlnet residual
if controlnet_block_samples is not None:
interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
interval_control = int(np.ceil(interval_control))
# For Xlabs ControlNet.
if controlnet_blocks_repeat:
hidden_states = (
hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
)
else:
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
for index_block, block in enumerate(self.single_transformer_blocks):
start_idx = 3 * index_block
temb = pooled_temb[:, start_idx : start_idx + 3]
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
temb,
image_rotary_emb,
)
else:
hidden_states = block(
hidden_states=hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
)
# controlnet residual
if controlnet_single_block_samples is not None:
interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
interval_control = int(np.ceil(interval_control))
hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+ controlnet_single_block_samples[index_block // interval_control]
)
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
temb = pooled_temb[:, -2:]
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)

View File

@@ -148,6 +148,7 @@ else:
"AudioLDM2UNet2DConditionModel",
]
_import_structure["blip_diffusion"] = ["BlipDiffusionPipeline"]
_import_structure["chroma"] = ["ChromaPipeline"]
_import_structure["cogvideo"] = [
"CogVideoXPipeline",
"CogVideoXImageToVideoPipeline",
@@ -531,6 +532,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
)
from .aura_flow import AuraFlowPipeline
from .blip_diffusion import BlipDiffusionPipeline
from .chroma import ChromaPipeline
from .cogvideo import (
CogVideoXFunControlPipeline,
CogVideoXImageToVideoPipeline,

View File

@@ -21,6 +21,7 @@ from ..configuration_utils import ConfigMixin
from ..models.controlnets import ControlNetUnionModel
from ..utils import is_sentencepiece_available
from .aura_flow import AuraFlowPipeline
from .chroma import ChromaPipeline
from .cogview3 import CogView3PlusPipeline
from .cogview4 import CogView4ControlPipeline, CogView4Pipeline
from .controlnet import (
@@ -143,6 +144,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
("flux-controlnet", FluxControlNetPipeline),
("lumina", LuminaPipeline),
("lumina2", Lumina2Pipeline),
("chroma", ChromaPipeline),
("cogview3", CogView3PlusPipeline),
("cogview4", CogView4Pipeline),
("cogview4-control", CogView4ControlPipeline),

View File

@@ -0,0 +1,47 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_additional_imports = {}
_import_structure = {"pipeline_output": ["ChromaPipelineOutput"]}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_chroma"] = ["ChromaPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .pipeline_chroma import ChromaPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
for name, value in _additional_imports.items():
setattr(sys.modules[__name__], name, value)

View File

@@ -0,0 +1,863 @@
# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, T5EncoderModel, T5TokenizerFast
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ChromaTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
USE_PEFT_BACKEND,
is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import ChromaPipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import ChromaPipeline
>>> pipe = ChromaPipeline.from_single_file(
... "chroma-unlocked-v35-detail-calibrated.safetensors", torch_dtype=torch.bfloat16
... )
>>> pipe.to("cuda")
>>> prompt = "A cat holding a sign that says hello world"
>>> image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0]
>>> image.save("chroma.png")
```
"""
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
max_seq_len: int = 4096,
base_shift: float = 0.5,
max_shift: float = 1.15,
):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
mu = image_seq_len * m + b
return mu
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class ChromaPipeline(
DiffusionPipeline,
FluxLoraLoaderMixin,
FromSingleFileMixin,
TextualInversionLoaderMixin,
FluxIPAdapterMixin,
):
r"""
The Chroma pipeline for text-to-image generation.
Reference: https://huggingface.co/lodestones/Chroma/
Args:
transformer ([`ChromaTransformer2DModel`]):
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representation
text_encoder ([`T5EncoderModel`]):
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
tokenizer (`T5TokenizerFast`):
Second Tokenizer of class
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
"""
model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae"
_optional_components = ["image_encoder", "feature_extractor"]
_callback_tensor_inputs = ["latents", "prompt_embeds"]
def __init__(
self,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKL,
text_encoder: T5EncoderModel,
tokenizer: T5TokenizerFast,
transformer: ChromaTransformer2DModel,
image_encoder: CLIPVisionModelWithProjection = None,
feature_extractor: CLIPImageProcessor = None,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
scheduler=scheduler,
image_encoder=image_encoder,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
self.default_sample_size = 128
def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
num_images_per_prompt: int = 1,
max_sequence_length: int = 512,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_length=False,
return_overflowing_tokens=False,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
attention_mask = text_inputs.attention_mask.clone()
# Chroma requires the attention mask to include one padding token
seq_lengths = attention_mask.sum(dim=1)
mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1)
attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).long()
prompt_embeds = self.text_encoder(
text_input_ids.to(device), output_hidden_states=False, attention_mask=attention_mask.to(device)
)[0]
dtype = self.text_encoder.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
_, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
return prompt_embeds
def encode_prompt(
self,
prompt: Union[str, List[str]],
negative_prompt: Union[str, List[str]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
do_classifier_free_guidance: bool = True,
max_sequence_length: int = 512,
lora_scale: Optional[float] = None,
):
r"""
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
negative_prompt (`str` or `List[str]`, *optional*):
The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
lora_scale (`float`, *optional*):
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
"""
device = device or self._execution_device
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if self.text_encoder is not None and USE_PEFT_BACKEND:
scale_lora_layers(self.text_encoder, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_embeds = self._get_t5_prompt_embeds(
prompt=prompt,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
)
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
negative_text_ids = None
if do_classifier_free_guidance:
if negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt = (
batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
)
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds = self._get_t5_prompt_embeds(
prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
)
negative_text_ids = torch.zeros(negative_prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
if self.text_encoder is not None:
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder, lora_scale)
return prompt_embeds, text_ids, negative_prompt_embeds, negative_text_ids
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image
def encode_image(self, image, device, num_images_per_prompt):
dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype)
image_embeds = self.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
return image_embeds
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
):
image_embeds = []
if ip_adapter_image_embeds is None:
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]
if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:
raise ValueError(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
)
for single_ip_adapter_image in ip_adapter_image:
single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
image_embeds.append(single_image_embeds[None, :])
else:
if not isinstance(ip_adapter_image_embeds, list):
ip_adapter_image_embeds = [ip_adapter_image_embeds]
if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters:
raise ValueError(
f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
)
for single_image_embeds in ip_adapter_image_embeds:
image_embeds.append(single_image_embeds)
ip_adapter_image_embeds = []
for single_image_embeds in image_embeds:
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
single_image_embeds = single_image_embeds.to(device=device)
ip_adapter_image_embeds.append(single_image_embeds)
return ip_adapter_image_embeds
def check_inputs(
self,
prompt,
height,
width,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
):
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
logger.warning(
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
)
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if max_sequence_length is not None and max_sequence_length > 512:
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
@staticmethod
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
latent_image_ids = torch.zeros(height, width, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
latent_image_ids = latent_image_ids.reshape(
latent_image_id_height * latent_image_id_width, latent_image_id_channels
)
return latent_image_ids.to(device=device, dtype=dtype)
@staticmethod
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
return latents
@staticmethod
def _unpack_latents(latents, height, width, vae_scale_factor):
batch_size, num_patches, channels = latents.shape
# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
height = 2 * (int(height) // (vae_scale_factor * 2))
width = 2 * (int(width) // (vae_scale_factor * 2))
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)
latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
return latents
def enable_vae_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.vae.enable_slicing()
def disable_vae_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_slicing()
def enable_vae_tiling(self):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.vae.enable_tiling()
def disable_vae_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_tiling()
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
def prepare_latents(
self,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
):
# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
height = 2 * (int(height) // (self.vae_scale_factor * 2))
width = 2 * (int(width) // (self.vae_scale_factor * 2))
shape = (batch_size, num_channels_latents, height, width)
if latents is not None:
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
return latents.to(device=device, dtype=dtype), latent_image_ids
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
return latents, latent_image_ids
@property
def guidance_scale(self):
return self._guidance_scale
@property
def joint_attention_kwargs(self):
return self._joint_attention_kwargs
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def num_timesteps(self):
return self._num_timesteps
@property
def current_timestep(self):
return self._current_timestep
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
negative_prompt: Union[str, List[str]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 28,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 3.5,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
negative_ip_adapter_image: Optional[PipelineImageInput] = None,
negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
not greater than `1`).
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 3.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
provided, embeddings are computed from the `ip_adapter_image` input argument.
negative_ip_adapter_image:
(`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
provided, embeddings are computed from the `ip_adapter_image` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.flux.ChromaPipelineOutput`] instead of a plain tuple.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
Examples:
Returns:
[`~pipelines.chroma.ChromaPipelineOutput`] or `tuple`: [`~pipelines.chroma.ChromaPipelineOutput`] if
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
generated images.
"""
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
height,
width,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
)
self._guidance_scale = guidance_scale
self._joint_attention_kwargs = joint_attention_kwargs
self._current_timestep = None
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
lora_scale = (
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
)
(
prompt_embeds,
text_ids,
negative_prompt_embeds,
negative_text_ids,
) = self.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
do_classifier_free_guidance=self.do_classifier_free_guidance,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
)
# 4. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels // 4
latents, latent_image_ids = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 5. Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
image_seq_len = latents.shape[1]
mu = calculate_shift(
image_seq_len,
self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.15),
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
sigmas=sigmas,
mu=mu,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
):
negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
):
ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
if self.joint_attention_kwargs is None:
self._joint_attention_kwargs = {}
image_embeds = None
negative_image_embeds = None
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
image_embeds = self.prepare_ip_adapter_image_embeds(
ip_adapter_image,
ip_adapter_image_embeds,
device,
batch_size * num_images_per_prompt,
)
if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
negative_image_embeds = self.prepare_ip_adapter_image_embeds(
negative_ip_adapter_image,
negative_ip_adapter_image_embeds,
device,
batch_size * num_images_per_prompt,
)
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t
if image_embeds is not None:
self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
noise_pred = self.transformer(
hidden_states=latents,
timestep=timestep / 1000,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids,
img_ids=latent_image_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
if self.do_classifier_free_guidance:
if negative_image_embeds is not None:
self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
neg_noise_pred = self.transformer(
hidden_states=latents,
timestep=timestep / 1000,
encoder_hidden_states=negative_prompt_embeds,
txt_ids=negative_text_ids,
img_ids=latent_image_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
noise_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred)
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
self._current_timestep = None
if output_type == "latent":
image = latents
else:
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
image = self.vae.decode(latents, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return ChromaPipelineOutput(images=image)

View File

@@ -0,0 +1,21 @@
from dataclasses import dataclass
from typing import List, Union
import numpy as np
import PIL.Image
from ...utils import BaseOutput
@dataclass
class ChromaPipelineOutput(BaseOutput):
"""
Output class for Stable Diffusion pipelines.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
"""
images: Union[List[PIL.Image.Image], np.ndarray]

View File

@@ -325,6 +325,21 @@ class CacheMixin(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class ChromaTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class CogVideoXTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]

View File

@@ -272,6 +272,21 @@ class AuraFlowPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class ChromaPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class CLIPImageProjection(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

@@ -0,0 +1,183 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import ChromaTransformer2DModel
from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0
from diffusers.models.embeddings import ImageProjection
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
enable_full_determinism()
def create_chroma_ip_adapter_state_dict(model):
# "ip_adapter" (cross-attention weights)
ip_cross_attn_state_dict = {}
key_id = 0
for name in model.attn_processors.keys():
if name.startswith("single_transformer_blocks"):
continue
joint_attention_dim = model.config["joint_attention_dim"]
hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"]
sd = FluxIPAdapterJointAttnProcessor2_0(
hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0
).state_dict()
ip_cross_attn_state_dict.update(
{
f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"],
f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"],
f"{key_id}.to_k_ip.bias": sd["to_k_ip.0.bias"],
f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"],
}
)
key_id += 1
# "image_proj" (ImageProjection layer weights)
image_projection = ImageProjection(
cross_attention_dim=model.config["joint_attention_dim"],
image_embed_dim=model.config["pooled_projection_dim"],
num_image_text_embeds=4,
)
ip_image_projection_state_dict = {}
sd = image_projection.state_dict()
ip_image_projection_state_dict.update(
{
"proj.weight": sd["image_embeds.weight"],
"proj.bias": sd["image_embeds.bias"],
"norm.weight": sd["norm.weight"],
"norm.bias": sd["norm.bias"],
}
)
del sd
ip_state_dict = {}
ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict})
return ip_state_dict
class ChromaTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = ChromaTransformer2DModel
main_input_name = "hidden_states"
# We override the items here because the transformer under consideration is small.
model_split_percents = [0.8, 0.7, 0.7]
# Skip setting testing with default: AttnProcessor
uses_custom_attn_processor = True
@property
def dummy_input(self):
batch_size = 1
num_latent_channels = 4
num_image_channels = 3
height = width = 4
sequence_length = 48
embedding_dim = 32
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device)
image_ids = torch.randn((height * width, num_image_channels)).to(torch_device)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"img_ids": image_ids,
"txt_ids": text_ids,
"timestep": timestep,
}
@property
def input_shape(self):
return (16, 4)
@property
def output_shape(self):
return (16, 4)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"patch_size": 1,
"in_channels": 4,
"num_layers": 1,
"num_single_layers": 1,
"attention_head_dim": 16,
"num_attention_heads": 2,
"joint_attention_dim": 32,
"axes_dims_rope": [4, 4, 8],
"approximator_num_channels": 8,
"approximator_hidden_dim": 16,
"approximator_layers": 1,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_deprecated_inputs_img_txt_ids_3d(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
output_1 = model(**inputs_dict).to_tuple()[0]
# update inputs_dict with txt_ids and img_ids as 3d tensors (deprecated)
text_ids_3d = inputs_dict["txt_ids"].unsqueeze(0)
image_ids_3d = inputs_dict["img_ids"].unsqueeze(0)
assert text_ids_3d.ndim == 3, "text_ids_3d should be a 3d tensor"
assert image_ids_3d.ndim == 3, "img_ids_3d should be a 3d tensor"
inputs_dict["txt_ids"] = text_ids_3d
inputs_dict["img_ids"] = image_ids_3d
with torch.no_grad():
output_2 = model(**inputs_dict).to_tuple()[0]
self.assertEqual(output_1.shape, output_2.shape)
self.assertTrue(
torch.allclose(output_1, output_2, atol=1e-5),
msg="output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) are not equal as them as 2d inputs",
)
def test_gradient_checkpointing_is_applied(self):
expected_set = {"ChromaTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class ChromaTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = ChromaTransformer2DModel
def prepare_init_args_and_inputs_for_common(self):
return ChromaTransformerTests().prepare_init_args_and_inputs_for_common()
class ChromaTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
model_class = ChromaTransformer2DModel
def prepare_init_args_and_inputs_for_common(self):
return ChromaTransformerTests().prepare_init_args_and_inputs_for_common()

View File

@@ -57,7 +57,9 @@ def create_flux_ip_adapter_state_dict(model):
image_projection = ImageProjection(
cross_attention_dim=model.config["joint_attention_dim"],
image_embed_dim=model.config["pooled_projection_dim"],
image_embed_dim=(
model.config["pooled_projection_dim"] if "pooled_projection_dim" in model.config.keys() else 768
),
num_image_text_embeds=4,
)

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,167 @@
import unittest
import numpy as np
import torch
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKL, ChromaPipeline, ChromaTransformer2DModel, FlowMatchEulerDiscreteScheduler
from diffusers.utils.testing_utils import torch_device
from ..test_pipelines_common import (
FluxIPAdapterTesterMixin,
PipelineTesterMixin,
check_qkv_fusion_matches_attn_procs_length,
check_qkv_fusion_processors_exist,
)
class ChromaPipelineFastTests(
unittest.TestCase,
PipelineTesterMixin,
FluxIPAdapterTesterMixin,
):
pipeline_class = ChromaPipeline
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds"])
batch_params = frozenset(["prompt"])
# there is no xformers processor for Flux
test_xformers_attention = False
test_layerwise_casting = True
test_group_offloading = True
def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
torch.manual_seed(0)
transformer = ChromaTransformer2DModel(
patch_size=1,
in_channels=4,
num_layers=num_layers,
num_single_layers=num_single_layers,
attention_head_dim=16,
num_attention_heads=2,
joint_attention_dim=32,
axes_dims_rope=[4, 4, 8],
approximator_hidden_dim=32,
approximator_layers=1,
approximator_num_channels=16,
)
torch.manual_seed(0)
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
torch.manual_seed(0)
vae = AutoencoderKL(
sample_size=32,
in_channels=3,
out_channels=3,
block_out_channels=(4,),
layers_per_block=1,
latent_channels=1,
norm_num_groups=1,
use_quant_conv=False,
use_post_quant_conv=False,
shift_factor=0.0609,
scaling_factor=1.5035,
)
scheduler = FlowMatchEulerDiscreteScheduler()
return {
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"transformer": transformer,
"vae": vae,
"image_encoder": None,
"feature_extractor": None,
}
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device="cpu").manual_seed(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"negative_prompt": "bad, ugly",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 5.0,
"height": 8,
"width": 8,
"max_sequence_length": 48,
"output_type": "np",
}
return inputs
def test_chroma_different_prompts(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
output_same_prompt = pipe(**inputs).images[0]
inputs = self.get_dummy_inputs(torch_device)
inputs["prompt"] = "a different prompt"
output_different_prompts = pipe(**inputs).images[0]
max_diff = np.abs(output_same_prompt - output_different_prompts).max()
# Outputs should be different here
# For some reasons, they don't show large differences
assert max_diff > 1e-6
def test_fused_qkv_projections(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
original_image_slice = image[0, -3:, -3:, -1]
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level.
pipe.transformer.fuse_qkv_projections()
assert check_qkv_fusion_processors_exist(pipe.transformer), (
"Something wrong with the fused attention processors. Expected all the attention processors to be fused."
)
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
image_slice_fused = image[0, -3:, -3:, -1]
pipe.transformer.unfuse_qkv_projections()
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1]
assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
"Fusion of QKV projections shouldn't affect the outputs."
)
assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
"Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
)
assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
"Original outputs should match when fused QKV projections are disabled."
)
def test_chroma_image_output_shape(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
height_width_pairs = [(32, 32), (72, 57)]
for height, width in height_width_pairs:
expected_height = height - height % (pipe.vae_scale_factor * 2)
expected_width = width - width % (pipe.vae_scale_factor * 2)
inputs.update({"height": height, "width": width})
image = pipe(**inputs).images[0]
output_height, output_width, _ = image.shape
assert (output_height, output_width) == (expected_height, expected_width)

View File

@@ -521,7 +521,8 @@ class FluxIPAdapterTesterMixin:
def _modify_inputs_for_ip_adapter_test(self, inputs: Dict[str, Any]):
inputs["negative_prompt"] = ""
inputs["true_cfg_scale"] = 4.0
if "true_cfg_scale" in inspect.signature(self.pipeline_class.__call__).parameters:
inputs["true_cfg_scale"] = 4.0
inputs["output_type"] = "np"
inputs["return_dict"] = False
return inputs
@@ -542,7 +543,11 @@ class FluxIPAdapterTesterMixin:
components = self.get_dummy_components()
pipe = self.pipeline_class(**components).to(torch_device)
pipe.set_progress_bar_config(disable=None)
image_embed_dim = pipe.transformer.config.pooled_projection_dim
image_embed_dim = (
pipe.transformer.config.pooled_projection_dim
if hasattr(pipe.transformer.config, "pooled_projection_dim")
else 768
)
# forward pass without ip adapter
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))