Compare commits

...

4 Commits

Author SHA1 Message Date
sayakpaul
14f5ef3c0e add pipeline tests for klein kv. 2026-03-13 09:24:24 +05:30
sayakpaul
6aacf9ef23 add Flux2KleinKV to docs. 2026-03-13 09:00:38 +05:30
sayakpaul
da43ca6879 add output class to docs. 2026-03-13 08:57:50 +05:30
sayakpaul
d646286b79 implement Flux2Transformer2DModelOutput. 2026-03-13 08:55:43 +05:30
5 changed files with 207 additions and 8 deletions

View File

@@ -17,3 +17,7 @@ A Transformer model for image-like data from [Flux2](https://hf.co/black-forest-
## Flux2Transformer2DModel
[[autodoc]] Flux2Transformer2DModel
## Flux2Transformer2DModelOutput
[[autodoc]] models.transformers.transformer_flux2.Flux2Transformer2DModelOutput

View File

@@ -41,5 +41,11 @@ The [official implementation](https://github.com/black-forest-labs/flux2/blob/5a
## Flux2KleinPipeline
[[autodoc]] Flux2KleinPipeline
- all
- __call__
## Flux2KleinKVPipeline
[[autodoc]] Flux2KleinKVPipeline
- all
- __call__

View File

@@ -13,6 +13,7 @@
# limitations under the License.
import inspect
from dataclasses import dataclass
from typing import Any
import torch
@@ -21,7 +22,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
from ...utils import apply_lora_scale, logging
from ...utils import BaseOutput, apply_lora_scale, logging
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin
from ..attention_dispatch import dispatch_attention_fn
@@ -32,7 +33,6 @@ from ..embeddings import (
apply_rotary_emb,
get_1d_rotary_pos_embed,
)
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous
@@ -40,6 +40,22 @@ from ..normalization import AdaLayerNormContinuous
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class Flux2Transformer2DModelOutput(BaseOutput):
    """
    The output of [`Flux2Transformer2DModel`].

    Args:
        sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
            The hidden states output conditioned on the `encoder_hidden_states` input.
        kv_cache (`Flux2KVCache`, *optional*):
            The populated KV cache for reference image tokens. Only returned when `kv_cache_mode="extract"`.
    """

    # Forward-reference string annotations: torch may be annotation-only here and
    # Flux2KVCache is defined later in this module, so flake8's F821 is silenced.
    sample: "torch.Tensor"  # noqa: F821
    kv_cache: "Flux2KVCache | None" = None
class Flux2KVLayerCache:
"""Per-layer KV cache for reference image tokens in the Flux2 Klein KV model.
@@ -1174,7 +1190,7 @@ class Flux2Transformer2DModel(
kv_cache_mode: str | None = None,
num_ref_tokens: int = 0,
ref_fixed_timestep: float = 0.0,
) -> torch.Tensor | Transformer2DModelOutput:
) -> torch.Tensor | Flux2Transformer2DModelOutput:
"""
The [`Flux2Transformer2DModel`] forward method.
@@ -1356,10 +1372,10 @@ class Flux2Transformer2DModel(
if kv_cache_mode == "extract":
if not return_dict:
return (output,), kv_cache
return Transformer2DModelOutput(sample=output), kv_cache
return (output, kv_cache)
return Flux2Transformer2DModelOutput(sample=output, kv_cache=kv_cache)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
return Flux2Transformer2DModelOutput(sample=output)

View File

@@ -793,7 +793,7 @@ class Flux2KleinKVPipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
latent_model_input = torch.cat([image_latents, latents], dim=1).to(self.transformer.dtype)
latent_image_ids = torch.cat([image_latent_ids, latent_ids], dim=1)
output, kv_cache = self.transformer(
noise_pred, kv_cache = self.transformer(
hidden_states=latent_model_input,
timestep=timestep / 1000,
guidance=None,
@@ -805,7 +805,6 @@ class Flux2KleinKVPipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
kv_cache_mode="extract",
num_ref_tokens=image_latents.shape[1],
)
noise_pred = output[0]
elif kv_cache is not None:
# Steps 1+: use cached ref KV, no ref tokens in input

View File

@@ -0,0 +1,174 @@
import unittest
import numpy as np
import torch
from PIL import Image
from transformers import Qwen2TokenizerFast, Qwen3Config, Qwen3ForCausalLM
from diffusers import (
AutoencoderKLFlux2,
FlowMatchEulerDiscreteScheduler,
Flux2KleinKVPipeline,
Flux2Transformer2DModel,
)
from ...testing_utils import torch_device
from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist
class Flux2KleinKVPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    """Fast, tiny-model tests for `Flux2KleinKVPipeline`.

    Builds minimally-sized components (transformer, Qwen3 text encoder, VAE) so the
    shared `PipelineTesterMixin` checks and the pipeline-specific tests below run
    quickly on CPU.
    """

    pipeline_class = Flux2KleinKVPipeline
    params = frozenset(["prompt", "height", "width", "prompt_embeds", "image"])
    batch_params = frozenset(["prompt"])
    # Feature flags consumed by PipelineTesterMixin to enable/skip shared tests.
    test_xformers_attention = False
    test_layerwise_casting = True
    test_group_offloading = True
    supports_dduf = False

    def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
        """Return a dict of tiny pipeline components keyed by init argument name.

        Args:
            num_layers: number of dual-stream transformer blocks.
            num_single_layers: number of single-stream transformer blocks.
        """
        # Seed before each randomly-initialized component for reproducible weights.
        torch.manual_seed(0)
        transformer = Flux2Transformer2DModel(
            patch_size=1,
            in_channels=4,
            num_layers=num_layers,
            num_single_layers=num_single_layers,
            attention_head_dim=16,
            num_attention_heads=2,
            joint_attention_dim=16,
            timestep_guidance_channels=256,
            axes_dims_rope=[4, 4, 4, 4],
            guidance_embeds=False,
        )

        # Create minimal Qwen3 config
        config = Qwen3Config(
            intermediate_size=16,
            hidden_size=16,
            num_hidden_layers=2,
            num_attention_heads=2,
            num_key_value_heads=2,
            vocab_size=151936,
            max_position_embeddings=512,
        )
        torch.manual_seed(0)
        text_encoder = Qwen3ForCausalLM(config)

        # Use a simple tokenizer for testing
        tokenizer = Qwen2TokenizerFast.from_pretrained(
            "hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration"
        )

        torch.manual_seed(0)
        vae = AutoencoderKLFlux2(
            sample_size=32,
            in_channels=3,
            out_channels=3,
            down_block_types=("DownEncoderBlock2D",),
            up_block_types=("UpDecoderBlock2D",),
            block_out_channels=(4,),
            layers_per_block=1,
            latent_channels=1,
            norm_num_groups=1,
            use_quant_conv=False,
            use_post_quant_conv=False,
        )

        scheduler = FlowMatchEulerDiscreteScheduler()

        return {
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "transformer": transformer,
            "vae": vae,
        }

    def get_dummy_inputs(self, device, seed=0):
        """Return minimal call kwargs (tiny image/resolution, 2 steps) for `device`."""
        if str(device).startswith("mps"):
            # MPS does not support device-local Generators; fall back to global seeding.
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device="cpu").manual_seed(seed)
        inputs = {
            "prompt": "a dog is dancing",
            "image": Image.new("RGB", (64, 64)),
            "generator": generator,
            "num_inference_steps": 2,
            "height": 8,
            "width": 8,
            "max_sequence_length": 64,
            "output_type": "np",
            "text_encoder_out_layers": (1,),
        }
        return inputs

    def test_fused_qkv_projections(self):
        """Fusing/unfusing QKV projections must not change pipeline outputs."""
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        # Baseline output with unfused projections.
        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images
        original_image_slice = image[0, -3:, -3:, -1]

        pipe.transformer.fuse_qkv_projections()
        self.assertTrue(
            check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
            ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
        )
        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images
        image_slice_fused = image[0, -3:, -3:, -1]

        pipe.transformer.unfuse_qkv_projections()
        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images
        image_slice_disabled = image[0, -3:, -3:, -1]

        # Compare a corner slice across the three runs; tolerances allow for
        # float accumulation-order differences between fused and unfused matmuls.
        self.assertTrue(
            np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3),
            ("Fusion of QKV projections shouldn't affect the outputs."),
        )
        self.assertTrue(
            np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3),
            ("Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."),
        )
        self.assertTrue(
            np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2),
            ("Original outputs should match when fused QKV projections are disabled."),
        )

    def test_image_output_shape(self):
        """Requested height/width are rounded down to the pipeline's resolution multiple."""
        pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
        inputs = self.get_dummy_inputs(torch_device)

        height_width_pairs = [(32, 32), (72, 57)]
        for height, width in height_width_pairs:
            # The pipeline snaps dimensions to a multiple of vae_scale_factor * 2.
            expected_height = height - height % (pipe.vae_scale_factor * 2)
            expected_width = width - width % (pipe.vae_scale_factor * 2)
            inputs.update({"height": height, "width": width})
            image = pipe(**inputs).images[0]
            output_height, output_width, _ = image.shape
            self.assertEqual(
                (output_height, output_width),
                (expected_height, expected_width),
                f"Output shape {image.shape} does not match expected shape {(expected_height, expected_width)}",
            )

    def test_without_image(self):
        """The pipeline must also run text-only, i.e. with no reference `image` input."""
        device = "cpu"
        pipe = self.pipeline_class(**self.get_dummy_components()).to(device)
        inputs = self.get_dummy_inputs(device)
        del inputs["image"]
        image = pipe(**inputs).images
        self.assertEqual(image.shape, (1, 8, 8, 3))

    @unittest.skip("Needs to be revisited")
    def test_encode_prompt_works_in_isolation(self):
        pass