Compare commits

..

1 Commits

Author SHA1 Message Date
sayakpaul
ec739c0441 fix klein lora loading. 2026-03-23 12:22:54 +05:30
5 changed files with 275 additions and 551 deletions

View File

@@ -2443,6 +2443,191 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
return converted_state_dict
def _convert_kohya_flux2_lora_to_diffusers(state_dict):
def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
if sds_key + ".lora_down.weight" not in sds_sd:
return
down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
# scale weight by alpha and dim
rank = down_weight.shape[0]
default_alpha = torch.tensor(rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False)
alpha = sds_sd.pop(sds_key + ".alpha", default_alpha).item()
scale = alpha / rank
scale_down = scale
scale_up = 1.0
while scale_down * 2 < scale_up:
scale_down *= 2
scale_up /= 2
ait_sd[ait_key + ".lora_A.weight"] = down_weight * scale_down
ait_sd[ait_key + ".lora_B.weight"] = sds_sd.pop(sds_key + ".lora_up.weight") * scale_up
def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
if sds_key + ".lora_down.weight" not in sds_sd:
return
down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
up_weight = sds_sd.pop(sds_key + ".lora_up.weight")
sd_lora_rank = down_weight.shape[0]
default_alpha = torch.tensor(
sd_lora_rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False
)
alpha = sds_sd.pop(sds_key + ".alpha", default_alpha)
scale = alpha / sd_lora_rank
scale_down = scale
scale_up = 1.0
while scale_down * 2 < scale_up:
scale_down *= 2
scale_up /= 2
down_weight = down_weight * scale_down
up_weight = up_weight * scale_up
num_splits = len(ait_keys)
if dims is None:
dims = [up_weight.shape[0] // num_splits] * num_splits
else:
assert sum(dims) == up_weight.shape[0]
# check if upweight is sparse
is_sparse = False
if sd_lora_rank % num_splits == 0:
ait_rank = sd_lora_rank // num_splits
is_sparse = True
i = 0
for j in range(len(dims)):
for k in range(len(dims)):
if j == k:
continue
is_sparse = is_sparse and torch.all(
up_weight[i : i + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0
)
i += dims[j]
if is_sparse:
logger.info(f"weight is sparse: {sds_key}")
ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
if not is_sparse:
ait_sd.update(dict.fromkeys(ait_down_keys, down_weight))
ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416
else:
ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))}) # noqa: C416
i = 0
for j in range(len(dims)):
ait_sd[ait_up_keys[j]] = up_weight[i : i + dims[j], j * ait_rank : (j + 1) * ait_rank].contiguous()
i += dims[j]
# Detect number of blocks from keys
num_double_layers = 0
num_single_layers = 0
for key in state_dict.keys():
if key.startswith("lora_unet_double_blocks_"):
block_idx = int(key.split("_")[4])
num_double_layers = max(num_double_layers, block_idx + 1)
elif key.startswith("lora_unet_single_blocks_"):
block_idx = int(key.split("_")[4])
num_single_layers = max(num_single_layers, block_idx + 1)
ait_sd = {}
for i in range(num_double_layers):
# Attention projections
_convert_to_ai_toolkit(
state_dict,
ait_sd,
f"lora_unet_double_blocks_{i}_img_attn_proj",
f"transformer.transformer_blocks.{i}.attn.to_out.0",
)
_convert_to_ai_toolkit_cat(
state_dict,
ait_sd,
f"lora_unet_double_blocks_{i}_img_attn_qkv",
[
f"transformer.transformer_blocks.{i}.attn.to_q",
f"transformer.transformer_blocks.{i}.attn.to_k",
f"transformer.transformer_blocks.{i}.attn.to_v",
],
)
_convert_to_ai_toolkit(
state_dict,
ait_sd,
f"lora_unet_double_blocks_{i}_txt_attn_proj",
f"transformer.transformer_blocks.{i}.attn.to_add_out",
)
_convert_to_ai_toolkit_cat(
state_dict,
ait_sd,
f"lora_unet_double_blocks_{i}_txt_attn_qkv",
[
f"transformer.transformer_blocks.{i}.attn.add_q_proj",
f"transformer.transformer_blocks.{i}.attn.add_k_proj",
f"transformer.transformer_blocks.{i}.attn.add_v_proj",
],
)
# MLP layers (Flux2 uses ff.linear_in/linear_out)
_convert_to_ai_toolkit(
state_dict,
ait_sd,
f"lora_unet_double_blocks_{i}_img_mlp_0",
f"transformer.transformer_blocks.{i}.ff.linear_in",
)
_convert_to_ai_toolkit(
state_dict,
ait_sd,
f"lora_unet_double_blocks_{i}_img_mlp_2",
f"transformer.transformer_blocks.{i}.ff.linear_out",
)
_convert_to_ai_toolkit(
state_dict,
ait_sd,
f"lora_unet_double_blocks_{i}_txt_mlp_0",
f"transformer.transformer_blocks.{i}.ff_context.linear_in",
)
_convert_to_ai_toolkit(
state_dict,
ait_sd,
f"lora_unet_double_blocks_{i}_txt_mlp_2",
f"transformer.transformer_blocks.{i}.ff_context.linear_out",
)
for i in range(num_single_layers):
# Single blocks: linear1 -> attn.to_qkv_mlp_proj (fused, no split needed)
_convert_to_ai_toolkit(
state_dict,
ait_sd,
f"lora_unet_single_blocks_{i}_linear1",
f"transformer.single_transformer_blocks.{i}.attn.to_qkv_mlp_proj",
)
# Single blocks: linear2 -> attn.to_out
_convert_to_ai_toolkit(
state_dict,
ait_sd,
f"lora_unet_single_blocks_{i}_linear2",
f"transformer.single_transformer_blocks.{i}.attn.to_out",
)
# Handle optional extra keys
extra_mappings = {
"lora_unet_img_in": "transformer.x_embedder",
"lora_unet_txt_in": "transformer.context_embedder",
"lora_unet_time_in_in_layer": "transformer.time_guidance_embed.timestep_embedder.linear_1",
"lora_unet_time_in_out_layer": "transformer.time_guidance_embed.timestep_embedder.linear_2",
"lora_unet_final_layer_linear": "transformer.proj_out",
}
for sds_key, ait_key in extra_mappings.items():
_convert_to_ai_toolkit(state_dict, ait_sd, sds_key, ait_key)
remaining_keys = list(state_dict.keys())
if remaining_keys:
logger.warning(f"Unsupported keys for Kohya Flux2 LoRA conversion: {remaining_keys}")
return ait_sd
def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
"""
Convert non-diffusers ZImage LoRA state dict to diffusers format.

View File

@@ -43,6 +43,7 @@ from .lora_conversion_utils import (
_convert_bfl_flux_control_lora_to_diffusers,
_convert_fal_kontext_lora_to_diffusers,
_convert_hunyuan_video_lora_to_diffusers,
_convert_kohya_flux2_lora_to_diffusers,
_convert_kohya_flux_lora_to_diffusers,
_convert_musubi_wan_lora_to_diffusers,
_convert_non_diffusers_flux2_lora_to_diffusers,
@@ -5673,6 +5674,13 @@ class Flux2LoraLoaderMixin(LoraBaseMixin):
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
is_kohya = any(".lora_down.weight" in k for k in state_dict)
if is_kohya:
state_dict = _convert_kohya_flux2_lora_to_diffusers(state_dict)
# Kohya already takes care of scaling the LoRA parameters with alpha.
out = (state_dict, metadata) if return_lora_metadata else state_dict
return out
is_peft_format = any(k.startswith("base_model.model.") for k in state_dict)
if is_peft_format:
state_dict = {k.replace("base_model.model.", "diffusion_model."): v for k, v in state_dict.items()}

View File

@@ -16,29 +16,22 @@ from typing import Callable
import numpy as np
import torch
import torchvision
import torchvision.transforms
import torchvision.transforms.functional
from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput
from ...models import AutoencoderKLWan, CosmosTransformer3DModel
from ...schedulers import UniPCMultistepScheduler
from ...utils import (
is_cosmos_guardrail_available,
is_torch_xla_available,
is_torchvision_available,
logging,
replace_example_docstring,
)
from ...utils import is_cosmos_guardrail_available, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import CosmosPipelineOutput
if is_torchvision_available():
import torchvision.transforms.functional
if is_cosmos_guardrail_available():
from cosmos_guardrail import CosmosSafetyChecker
else:

View File

@@ -41,6 +41,7 @@ from ..testing_utils import (
ModelOptCompileTesterMixin,
ModelOptTesterMixin,
ModelTesterMixin,
PyramidAttentionBroadcastTesterMixin,
QuantoCompileTesterMixin,
QuantoTesterMixin,
SingleFileTesterMixin,
@@ -218,10 +219,6 @@ class TestFluxTransformerMemory(FluxTransformerTesterConfig, MemoryTesterMixin):
class TestFluxTransformerTraining(FluxTransformerTesterConfig, TrainingTesterMixin):
"""Training tests for Flux Transformer."""
def test_gradient_checkpointing_is_applied(self):
expected_set = {"FluxTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestFluxTransformerAttention(FluxTransformerTesterConfig, AttentionTesterMixin):
"""Attention processor tests for Flux Transformer."""
@@ -415,6 +412,10 @@ class TestFluxTransformerBitsAndBytesCompile(FluxTransformerTesterConfig, BitsAn
"""BitsAndBytes + compile tests for Flux Transformer."""
class TestFluxTransformerPABCache(FluxTransformerTesterConfig, PyramidAttentionBroadcastTesterMixin):
"""PyramidAttentionBroadcast cache tests for Flux Transformer."""
class TestFluxTransformerFBCCache(FluxTransformerTesterConfig, FirstBlockCacheTesterMixin):
"""FirstBlockCache tests for Flux Transformer."""

View File

@@ -13,95 +13,48 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import Flux2Transformer2DModel
from diffusers.models.transformers.transformer_flux2 import (
Flux2KVAttnProcessor,
Flux2KVCache,
Flux2KVLayerCache,
Flux2KVParallelSelfAttnProcessor,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers import Flux2Transformer2DModel, attention_backend
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
AttentionTesterMixin,
BaseModelTesterConfig,
BitsAndBytesTesterMixin,
ContextParallelTesterMixin,
GGUFCompileTesterMixin,
GGUFTesterMixin,
LoraHotSwappingForModelTesterMixin,
LoraTesterMixin,
MemoryTesterMixin,
ModelTesterMixin,
TorchAoCompileTesterMixin,
TorchAoTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
enable_full_determinism()
class Flux2TransformerTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return Flux2Transformer2DModel
class Flux2TransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = Flux2Transformer2DModel
main_input_name = "hidden_states"
# We override the items here because the transformer under consideration is small.
model_split_percents = [0.7, 0.6, 0.6]
# Skip setting testing with default: AttnProcessor
uses_custom_attn_processor = True
@property
def output_shape(self) -> tuple[int, int]:
def dummy_input(self):
return self.prepare_dummy_input()
@property
def input_shape(self):
return (16, 4)
@property
def input_shape(self) -> tuple[int, int]:
def output_shape(self):
return (16, 4)
@property
def model_split_percents(self) -> list:
# We override the items here because the transformer under consideration is small.
return [0.7, 0.6, 0.6]
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def uses_custom_attn_processor(self) -> bool:
# Skip setting testing with default: AttnProcessor
return True
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict[str, int | list[int]]:
return {
"patch_size": 1,
"in_channels": 4,
"num_layers": 1,
"num_single_layers": 1,
"attention_head_dim": 16,
"num_attention_heads": 2,
"joint_attention_dim": 32,
"timestep_guidance_channels": 256, # Hardcoded in original code
"axes_dims_rope": [4, 4, 4, 4],
}
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
def prepare_dummy_input(self, height=4, width=4):
batch_size = 1
num_latent_channels = 4
sequence_length = 48
embedding_dim = 32
hidden_states = randn_tensor(
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
)
encoder_hidden_states = randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
)
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
t_coords = torch.arange(1)
h_coords = torch.arange(height)
@@ -129,286 +82,8 @@ class Flux2TransformerTesterConfig(BaseModelTesterConfig):
"guidance": guidance,
}
class TestFlux2Transformer(Flux2TransformerTesterConfig, ModelTesterMixin):
pass
class TestFlux2TransformerMemory(Flux2TransformerTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for Flux2 Transformer."""
class TestFlux2TransformerTraining(Flux2TransformerTesterConfig, TrainingTesterMixin):
"""Training tests for Flux2 Transformer."""
def test_gradient_checkpointing_is_applied(self):
expected_set = {"Flux2Transformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestFlux2TransformerAttention(Flux2TransformerTesterConfig, AttentionTesterMixin):
"""Attention processor tests for Flux2 Transformer."""
class TestFlux2TransformerContextParallel(Flux2TransformerTesterConfig, ContextParallelTesterMixin):
"""Context Parallel inference tests for Flux2 Transformer."""
class TestFlux2TransformerLoRA(Flux2TransformerTesterConfig, LoraTesterMixin):
"""LoRA adapter tests for Flux2 Transformer."""
class TestFlux2TransformerLoRAHotSwap(Flux2TransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
"""LoRA hot-swapping tests for Flux2 Transformer."""
@property
def different_shapes_for_compilation(self):
return [(4, 4), (4, 8), (8, 8)]
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
"""Override to support dynamic height/width for LoRA hotswap tests."""
batch_size = 1
num_latent_channels = 4
sequence_length = 48
embedding_dim = 32
hidden_states = randn_tensor(
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
)
encoder_hidden_states = randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
)
t_coords = torch.arange(1)
h_coords = torch.arange(height)
w_coords = torch.arange(width)
l_coords = torch.arange(1)
image_ids = torch.cartesian_prod(t_coords, h_coords, w_coords, l_coords)
image_ids = image_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
text_t_coords = torch.arange(1)
text_h_coords = torch.arange(1)
text_w_coords = torch.arange(1)
text_l_coords = torch.arange(sequence_length)
text_ids = torch.cartesian_prod(text_t_coords, text_h_coords, text_w_coords, text_l_coords)
text_ids = text_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
guidance = torch.tensor([1.0]).to(torch_device).expand(batch_size)
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"img_ids": image_ids,
"txt_ids": text_ids,
"timestep": timestep,
"guidance": guidance,
}
class TestFlux2TransformerCompile(Flux2TransformerTesterConfig, TorchCompileTesterMixin):
@property
def different_shapes_for_compilation(self):
return [(4, 4), (4, 8), (8, 8)]
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
"""Override to support dynamic height/width for compilation tests."""
batch_size = 1
num_latent_channels = 4
sequence_length = 48
embedding_dim = 32
hidden_states = randn_tensor(
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
)
encoder_hidden_states = randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
)
t_coords = torch.arange(1)
h_coords = torch.arange(height)
w_coords = torch.arange(width)
l_coords = torch.arange(1)
image_ids = torch.cartesian_prod(t_coords, h_coords, w_coords, l_coords)
image_ids = image_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
text_t_coords = torch.arange(1)
text_h_coords = torch.arange(1)
text_w_coords = torch.arange(1)
text_l_coords = torch.arange(sequence_length)
text_ids = torch.cartesian_prod(text_t_coords, text_h_coords, text_w_coords, text_l_coords)
text_ids = text_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
guidance = torch.tensor([1.0]).to(torch_device).expand(batch_size)
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"img_ids": image_ids,
"txt_ids": text_ids,
"timestep": timestep,
"guidance": guidance,
}
class TestFlux2TransformerBitsAndBytes(Flux2TransformerTesterConfig, BitsAndBytesTesterMixin):
"""BitsAndBytes quantization tests for Flux2 Transformer."""
class TestFlux2TransformerTorchAo(Flux2TransformerTesterConfig, TorchAoTesterMixin):
"""TorchAO quantization tests for Flux2 Transformer."""
class TestFlux2TransformerGGUF(Flux2TransformerTesterConfig, GGUFTesterMixin):
"""GGUF quantization tests for Flux2 Transformer."""
@property
def gguf_filename(self):
return "https://huggingface.co/unsloth/FLUX.2-dev-GGUF/blob/main/flux2-dev-Q2_K.gguf"
@property
def torch_dtype(self):
return torch.bfloat16
def get_dummy_inputs(self):
"""Override to provide inputs matching the real FLUX2 model dimensions.
Flux2 defaults: in_channels=128, joint_attention_dim=15360
"""
batch_size = 1
height = 64
width = 64
sequence_length = 512
hidden_states = randn_tensor(
(batch_size, height * width, 128), generator=self.generator, device=torch_device, dtype=self.torch_dtype
)
encoder_hidden_states = randn_tensor(
(batch_size, sequence_length, 15360), generator=self.generator, device=torch_device, dtype=self.torch_dtype
)
# Flux2 uses 4D image/text IDs (t, h, w, l)
t_coords = torch.arange(1)
h_coords = torch.arange(height)
w_coords = torch.arange(width)
l_coords = torch.arange(1)
image_ids = torch.cartesian_prod(t_coords, h_coords, w_coords, l_coords)
image_ids = image_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
text_t_coords = torch.arange(1)
text_h_coords = torch.arange(1)
text_w_coords = torch.arange(1)
text_l_coords = torch.arange(sequence_length)
text_ids = torch.cartesian_prod(text_t_coords, text_h_coords, text_w_coords, text_l_coords)
text_ids = text_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
timestep = torch.tensor([1.0]).to(torch_device, self.torch_dtype)
guidance = torch.tensor([3.5]).to(torch_device, self.torch_dtype)
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"img_ids": image_ids,
"txt_ids": text_ids,
"timestep": timestep,
"guidance": guidance,
}
class TestFlux2TransformerTorchAoCompile(Flux2TransformerTesterConfig, TorchAoCompileTesterMixin):
"""TorchAO + compile tests for Flux2 Transformer."""
class TestFlux2TransformerGGUFCompile(Flux2TransformerTesterConfig, GGUFCompileTesterMixin):
"""GGUF + compile tests for Flux2 Transformer."""
@property
def gguf_filename(self):
return "https://huggingface.co/unsloth/FLUX.2-dev-GGUF/blob/main/flux2-dev-Q2_K.gguf"
@property
def torch_dtype(self):
return torch.bfloat16
def get_dummy_inputs(self):
"""Override to provide inputs matching the real FLUX2 model dimensions.
Flux2 defaults: in_channels=128, joint_attention_dim=15360
"""
batch_size = 1
height = 64
width = 64
sequence_length = 512
hidden_states = randn_tensor(
(batch_size, height * width, 128), generator=self.generator, device=torch_device, dtype=self.torch_dtype
)
encoder_hidden_states = randn_tensor(
(batch_size, sequence_length, 15360), generator=self.generator, device=torch_device, dtype=self.torch_dtype
)
# Flux2 uses 4D image/text IDs (t, h, w, l)
t_coords = torch.arange(1)
h_coords = torch.arange(height)
w_coords = torch.arange(width)
l_coords = torch.arange(1)
image_ids = torch.cartesian_prod(t_coords, h_coords, w_coords, l_coords)
image_ids = image_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
text_t_coords = torch.arange(1)
text_h_coords = torch.arange(1)
text_w_coords = torch.arange(1)
text_l_coords = torch.arange(sequence_length)
text_ids = torch.cartesian_prod(text_t_coords, text_h_coords, text_w_coords, text_l_coords)
text_ids = text_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
timestep = torch.tensor([1.0]).to(torch_device, self.torch_dtype)
guidance = torch.tensor([3.5]).to(torch_device, self.torch_dtype)
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"img_ids": image_ids,
"txt_ids": text_ids,
"timestep": timestep,
"guidance": guidance,
}
class Flux2TransformerKVCacheTesterConfig(BaseModelTesterConfig):
num_ref_tokens = 4
@property
def model_class(self):
return Flux2Transformer2DModel
@property
def output_shape(self) -> tuple[int, int]:
return (16, 4)
@property
def input_shape(self) -> tuple[int, int]:
return (16, 4)
@property
def model_split_percents(self) -> list:
return [0.7, 0.6, 0.6]
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def uses_custom_attn_processor(self) -> bool:
return True
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict[str, int | list[int]]:
return {
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"patch_size": 1,
"in_channels": 4,
"num_layers": 1,
@@ -416,210 +91,72 @@ class Flux2TransformerKVCacheTesterConfig(BaseModelTesterConfig):
"attention_head_dim": 16,
"num_attention_heads": 2,
"joint_attention_dim": 32,
"timestep_guidance_channels": 256,
"timestep_guidance_channels": 256, # Hardcoded in original code
"axes_dims_rope": [4, 4, 4, 4],
}
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
batch_size = 1
num_latent_channels = 4
sequence_length = 48
embedding_dim = 32
num_ref_tokens = self.num_ref_tokens
inputs_dict = self.dummy_input
return init_dict, inputs_dict
ref_hidden_states = randn_tensor(
(batch_size, num_ref_tokens, num_latent_channels), generator=self.generator, device=torch_device
)
img_hidden_states = randn_tensor(
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
)
hidden_states = torch.cat([ref_hidden_states, img_hidden_states], dim=1)
# TODO (Daniel, Sayak): We can remove this test.
def test_flux2_consistency(self, seed=0):
torch.manual_seed(seed)
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
encoder_hidden_states = randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
)
ref_t_coords = torch.arange(1)
ref_h_coords = torch.arange(num_ref_tokens)
ref_w_coords = torch.arange(1)
ref_l_coords = torch.arange(1)
ref_ids = torch.cartesian_prod(ref_t_coords, ref_h_coords, ref_w_coords, ref_l_coords)
ref_ids = ref_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
t_coords = torch.arange(1)
h_coords = torch.arange(height)
w_coords = torch.arange(width)
l_coords = torch.arange(1)
image_ids = torch.cartesian_prod(t_coords, h_coords, w_coords, l_coords)
image_ids = image_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
image_ids = torch.cat([ref_ids, image_ids], dim=1)
text_t_coords = torch.arange(1)
text_h_coords = torch.arange(1)
text_w_coords = torch.arange(1)
text_l_coords = torch.arange(sequence_length)
text_ids = torch.cartesian_prod(text_t_coords, text_h_coords, text_w_coords, text_l_coords)
text_ids = text_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
guidance = torch.tensor([1.0]).to(torch_device).expand(batch_size)
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"img_ids": image_ids,
"txt_ids": text_ids,
"timestep": timestep,
"guidance": guidance,
}
class TestFlux2TransformerKVCache(Flux2TransformerKVCacheTesterConfig):
"""KV cache tests for Flux2 Transformer."""
def test_kv_layer_cache_store_and_get(self):
cache = Flux2KVLayerCache()
k = torch.randn(1, 4, 2, 16)
v = torch.randn(1, 4, 2, 16)
cache.store(k, v)
k_out, v_out = cache.get()
assert torch.equal(k, k_out)
assert torch.equal(v, v_out)
def test_kv_layer_cache_get_before_store_raises(self):
cache = Flux2KVLayerCache()
try:
cache.get()
assert False, "Expected RuntimeError"
except RuntimeError:
pass
def test_kv_layer_cache_clear(self):
cache = Flux2KVLayerCache()
cache.store(torch.randn(1, 4, 2, 16), torch.randn(1, 4, 2, 16))
cache.clear()
assert cache.k_ref is None
assert cache.v_ref is None
def test_kv_cache_structure(self):
num_double = 3
num_single = 2
cache = Flux2KVCache(num_double, num_single)
assert len(cache.double_block_caches) == num_double
assert len(cache.single_block_caches) == num_single
assert cache.num_ref_tokens == 0
for i in range(num_double):
assert isinstance(cache.get_double(i), Flux2KVLayerCache)
for i in range(num_single):
assert isinstance(cache.get_single(i), Flux2KVLayerCache)
def test_kv_cache_clear(self):
cache = Flux2KVCache(2, 1)
cache.num_ref_tokens = 4
cache.get_double(0).store(torch.randn(1, 4, 2, 16), torch.randn(1, 4, 2, 16))
cache.clear()
assert cache.num_ref_tokens == 0
assert cache.get_double(0).k_ref is None
def _set_kv_attn_processors(self, model):
for block in model.transformer_blocks:
block.attn.set_processor(Flux2KVAttnProcessor())
for block in model.single_transformer_blocks:
block.attn.set_processor(Flux2KVParallelSelfAttnProcessor())
@torch.no_grad()
def test_extract_mode_returns_cache(self):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
self._set_kv_attn_processors(model)
output = model(
**self.get_dummy_inputs(),
kv_cache_mode="extract",
num_ref_tokens=self.num_ref_tokens,
ref_fixed_timestep=0.0,
)
assert output.kv_cache is not None
assert isinstance(output.kv_cache, Flux2KVCache)
assert output.kv_cache.num_ref_tokens == self.num_ref_tokens
for layer_cache in output.kv_cache.double_block_caches:
assert layer_cache.k_ref is not None
assert layer_cache.v_ref is not None
for layer_cache in output.kv_cache.single_block_caches:
assert layer_cache.k_ref is not None
assert layer_cache.v_ref is not None
@torch.no_grad()
def test_extract_mode_output_shape(self):
model = self.model_class(**self.get_init_dict())
torch.manual_seed(seed)
model = self.model_class(**init_dict)
# state_dict = model.state_dict()
# for key, param in state_dict.items():
# print(f"{key} | {param.shape}")
# torch.save(state_dict, "/raid/daniel_gu/test_flux2_params/diffusers.pt")
model.to(torch_device)
model.eval()
height, width = 4, 4
output = model(
**self.get_dummy_inputs(height=height, width=width),
kv_cache_mode="extract",
num_ref_tokens=self.num_ref_tokens,
ref_fixed_timestep=0.0,
)
with attention_backend("native"):
with torch.no_grad():
output = model(**inputs_dict)
assert output.sample.shape == (1, height * width, 4)
if isinstance(output, dict):
output = output.to_tuple()[0]
@torch.no_grad()
def test_cached_mode_uses_cache(self):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
self.assertIsNotNone(output)
height, width = 4, 4
extract_output = model(
**self.get_dummy_inputs(height=height, width=width),
kv_cache_mode="extract",
num_ref_tokens=self.num_ref_tokens,
ref_fixed_timestep=0.0,
)
# input & output have to have the same shape
input_tensor = inputs_dict[self.main_input_name]
expected_shape = input_tensor.shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
base_config = Flux2TransformerTesterConfig()
cached_inputs = base_config.get_dummy_inputs(height=height, width=width)
cached_output = model(
**cached_inputs,
kv_cache=extract_output.kv_cache,
kv_cache_mode="cached",
)
# Check against expected slice
# fmt: off
expected_slice = torch.tensor([-0.3662, 0.4844, 0.6334, -0.3497, 0.2162, 0.0188, 0.0521, -0.2061, -0.2041, -0.0342, -0.7107, 0.4797, -0.3280, 0.7059, -0.0849, 0.4416])
# fmt: on
assert cached_output.sample.shape == (1, height * width, 4)
assert cached_output.kv_cache is None
flat_output = output.cpu().flatten()
generated_slice = torch.cat([flat_output[:8], flat_output[-8:]])
self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-4))
@torch.no_grad()
def test_extract_return_dict_false(self):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
def test_gradient_checkpointing_is_applied(self):
expected_set = {"Flux2Transformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
output = model(
**self.get_dummy_inputs(),
kv_cache_mode="extract",
num_ref_tokens=self.num_ref_tokens,
ref_fixed_timestep=0.0,
return_dict=False,
)
assert isinstance(output, tuple)
assert len(output) == 2
assert isinstance(output[1], Flux2KVCache)
class Flux2TransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = Flux2Transformer2DModel
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
@torch.no_grad()
def test_no_kv_cache_mode_returns_no_cache(self):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
def prepare_init_args_and_inputs_for_common(self):
return Flux2TransformerTests().prepare_init_args_and_inputs_for_common()
base_config = Flux2TransformerTesterConfig()
output = model(**base_config.get_dummy_inputs())
def prepare_dummy_input(self, height, width):
return Flux2TransformerTests().prepare_dummy_input(height=height, width=width)
assert output.kv_cache is None
class Flux2TransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
model_class = Flux2Transformer2DModel
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
def prepare_init_args_and_inputs_for_common(self):
return Flux2TransformerTests().prepare_init_args_and_inputs_for_common()
def prepare_dummy_input(self, height, width):
return Flux2TransformerTests().prepare_dummy_input(height=height, width=width)