Compare commits

..

19 Commits

Author SHA1 Message Date
DN6
6caa0a9bf4 update 2026-01-08 12:21:13 +05:30
DN6
ba475eee8d update 2026-01-08 12:21:13 +05:30
Sayak Paul
e0ab03d79b Merge branch 'main' into model-test-refactor 2025-12-31 21:03:32 +05:30
DN6
7b3ef42a01 update 2025-12-26 12:45:30 +05:30
DN6
c70de2bc37 update 2025-12-18 13:18:54 +05:30
DN6
e82001e40d update 2025-12-18 13:16:50 +05:30
DN6
d9b73ffd51 update 2025-12-15 16:12:50 +05:30
DN6
dcd6026d17 update 2025-12-15 16:12:15 +05:30
DN6
eae7543712 update 2025-12-15 16:02:38 +05:30
DN6
d08e0bb545 update 2025-12-15 14:19:27 +05:30
DN6
c366b5a817 update 2025-12-11 13:37:06 +05:30
DN6
0fdd9d3a60 update 2025-12-11 11:41:17 +05:30
DN6
489480b02a update 2025-12-11 11:27:59 +05:30
DN6
fe451c367b update 2025-12-11 11:04:47 +05:30
DN6
0f1a4e0c14 update 2025-11-19 21:59:20 +05:30
DN6
aa29af8f0e update 2025-11-19 08:51:38 +05:30
DN6
bffa3a9754 update 2025-11-14 15:48:19 +05:30
DN6
1c558712e8 Merge branch 'main' into model-test-refactor 2025-11-12 10:18:07 +05:30
DN6
1f026ad14e update 2025-11-12 10:17:54 +05:30
31 changed files with 5836 additions and 1025 deletions

View File

@@ -1,22 +0,0 @@
---
name: CodeQL Security Analysis For Github Actions
on:
push:
branches: ["main"]
workflow_dispatch:
# pull_request:
jobs:
codeql:
name: CodeQL Analysis
uses: huggingface/security-workflows/.github/workflows/codeql-reusable.yml@v1
permissions:
security-events: write
packages: read
actions: read
contents: read
with:
languages: '["actions","python"]'
queries: 'security-extended,security-and-quality'
runner: 'ubuntu-latest' #optional if need custom runner

View File

@@ -24,6 +24,7 @@ jobs:
mirror_community_pipeline:
env:
SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL_COMMUNITY_MIRROR }}
runs-on: ubuntu-22.04
steps:
# Checkout to correct ref
@@ -38,28 +39,25 @@ jobs:
# If ref is 'refs/heads/main' => set 'main'
# Else it must be a tag => set {tag}
- name: Set checkout_ref and path_in_repo
EVENT_NAME: ${{ github.event_name }}
EVENT_INPUT_REF: ${{ github.event.inputs.ref }}
GITHUB_REF: ${{ github.ref }}
run: |
if [ "$EVENT_NAME" == "workflow_dispatch" ]; then
if [ -z "$EVENT_INPUT_REF" ]; then
if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then
if [ -z "${{ github.event.inputs.ref }}" ]; then
echo "Error: Missing ref input"
exit 1
elif [ "$EVENT_INPUT_REF" == "main" ]; then
elif [ "${{ github.event.inputs.ref }}" == "main" ]; then
echo "CHECKOUT_REF=refs/heads/main" >> $GITHUB_ENV
echo "PATH_IN_REPO=main" >> $GITHUB_ENV
else
echo "CHECKOUT_REF=refs/tags/$EVENT_INPUT_REF" >> $GITHUB_ENV
echo "PATH_IN_REPO=$EVENT_INPUT_REF" >> $GITHUB_ENV
echo "CHECKOUT_REF=refs/tags/${{ github.event.inputs.ref }}" >> $GITHUB_ENV
echo "PATH_IN_REPO=${{ github.event.inputs.ref }}" >> $GITHUB_ENV
fi
elif [ "$GITHUB_REF" == "refs/heads/main" ]; then
echo "CHECKOUT_REF=$GITHUB_REF" >> $GITHUB_ENV
elif [ "${{ github.ref }}" == "refs/heads/main" ]; then
echo "CHECKOUT_REF=${{ github.ref }}" >> $GITHUB_ENV
echo "PATH_IN_REPO=main" >> $GITHUB_ENV
else
# e.g. refs/tags/v0.28.1 -> v0.28.1
echo "CHECKOUT_REF=$GITHUB_REF" >> $GITHUB_ENV
echo "PATH_IN_REPO=$(echo $GITHUB_REF | sed 's/^refs\/tags\///')" >> $GITHUB_ENV
echo "CHECKOUT_REF=${{ github.ref }}" >> $GITHUB_ENV
echo "PATH_IN_REPO=$(echo ${{ github.ref }} | sed 's/^refs\/tags\///')" >> $GITHUB_ENV
fi
- name: Print env vars
run: |
@@ -101,4 +99,4 @@ jobs:
- name: Report failure status
if: ${{ failure() }}
run: |
pip install requests && python utils/notify_community_pipelines_mirror.py --status=failure
pip install requests && python utils/notify_community_pipelines_mirror.py --status=failure

View File

@@ -263,8 +263,8 @@ def main():
world_size = dist.get_world_size()
pipeline = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to(device)
"black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, device_map=device
)
pipeline.transformer.set_attention_backend("_native_cudnn")
cp_config = ContextParallelConfig(ring_degree=world_size)

View File

@@ -1,844 +0,0 @@
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
import torch
from transformers import AutoTokenizer, PreTrainedModel
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import FromSingleFileMixin, ZImageLoraLoaderMixin
from diffusers.models.autoencoders import AutoencoderKL
from diffusers.models.transformers import ZImageTransformer2DModel
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.z_image.pipeline_output import ZImagePipelineOutput
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import logging, replace_example_docstring
from diffusers.utils.torch_utils import randn_tensor
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from pipeline_z_image_differential_img2img import ZImageDifferentialImg2ImgPipeline
>>> from diffusers.utils import load_image
>>> pipe = ZImageDifferentialImg2ImgPipeline.from_pretrained("Z-a-o/Z-Image-Turbo", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> init_image = load_image(
>>> "https://github.com/exx8/differential-diffusion/blob/main/assets/input.jpg?raw=true",
>>> )
>>> mask = load_image(
>>> "https://github.com/exx8/differential-diffusion/blob/main/assets/map.jpg?raw=true",
>>> )
>>> prompt = "painting of a mountain landscape with a meadow and a forest, meadow background, anime countryside landscape, anime nature wallpap, anime landscape wallpaper, studio ghibli landscape, anime landscape, mountain behind meadow, anime background art, studio ghibli environment, background of flowery hill, anime beautiful peace scene, forrest background, anime scenery, landscape background, background art, anime scenery concept art"
>>> image = pipe(
... prompt,
... image=init_image,
... mask_image=mask,
... strength=0.75,
... num_inference_steps=9,
... guidance_scale=0.0,
... generator=torch.Generator("cuda").manual_seed(41),
... ).images[0]
>>> image.save("image.png")
```
"""
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
max_seq_len: int = 4096,
base_shift: float = 0.5,
max_shift: float = 1.15,
):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
mu = image_seq_len * m + b
return mu
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class ZImageDifferentialImg2ImgPipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMixin):
r"""
The ZImage pipeline for image-to-image generation.
Args:
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`PreTrainedModel`]):
A text encoder model to encode text prompts.
tokenizer ([`AutoTokenizer`]):
A tokenizer to tokenize text prompts.
transformer ([`ZImageTransformer2DModel`]):
A ZImage transformer model to denoise the encoded image latents.
"""
model_cpu_offload_seq = "text_encoder->transformer->vae"
_optional_components = []
_callback_tensor_inputs = ["latents", "prompt_embeds"]
def __init__(
self,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKL,
text_encoder: PreTrainedModel,
tokenizer: AutoTokenizer,
transformer: ZImageTransformer2DModel,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
scheduler=scheduler,
transformer=transformer,
)
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
)
latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor,
vae_latent_channels=latent_channels,
do_normalize=False,
do_binarize=False,
do_convert_grayscale=True,
)
# Copied from diffusers.pipelines.z_image.pipeline_z_image.ZImagePipeline.encode_prompt
def encode_prompt(
self,
prompt: Union[str, List[str]],
device: Optional[torch.device] = None,
do_classifier_free_guidance: bool = True,
negative_prompt: Optional[Union[str, List[str]]] = None,
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
max_sequence_length: int = 512,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
prompt_embeds = self._encode_prompt(
prompt=prompt,
device=device,
prompt_embeds=prompt_embeds,
max_sequence_length=max_sequence_length,
)
if do_classifier_free_guidance:
if negative_prompt is None:
negative_prompt = ["" for _ in prompt]
else:
negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
assert len(prompt) == len(negative_prompt)
negative_prompt_embeds = self._encode_prompt(
prompt=negative_prompt,
device=device,
prompt_embeds=negative_prompt_embeds,
max_sequence_length=max_sequence_length,
)
else:
negative_prompt_embeds = []
return prompt_embeds, negative_prompt_embeds
# Copied from diffusers.pipelines.z_image.pipeline_z_image.ZImagePipeline._encode_prompt
def _encode_prompt(
self,
prompt: Union[str, List[str]],
device: Optional[torch.device] = None,
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
max_sequence_length: int = 512,
) -> List[torch.FloatTensor]:
device = device or self._execution_device
if prompt_embeds is not None:
return prompt_embeds
if isinstance(prompt, str):
prompt = [prompt]
for i, prompt_item in enumerate(prompt):
messages = [
{"role": "user", "content": prompt_item},
]
prompt_item = self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=True,
)
prompt[i] = prompt_item
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids.to(device)
prompt_masks = text_inputs.attention_mask.to(device).bool()
prompt_embeds = self.text_encoder(
input_ids=text_input_ids,
attention_mask=prompt_masks,
output_hidden_states=True,
).hidden_states[-2]
embeddings_list = []
for i in range(len(prompt_embeds)):
embeddings_list.append(prompt_embeds[i][prompt_masks[i]])
return embeddings_list
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength, device):
# get the original timestep using init_timestep
init_timestep = min(num_inference_steps * strength, num_inference_steps)
t_start = int(max(num_inference_steps - init_timestep, 0))
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)
return timesteps, num_inference_steps - t_start
@staticmethod
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
latent_image_ids = latent_image_ids.reshape(
latent_image_id_height * latent_image_id_width, latent_image_id_channels
)
return latent_image_ids.to(device=device, dtype=dtype)
def prepare_latents(
self,
image,
timestep,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
):
height = 2 * (int(height) // (self.vae_scale_factor * 2))
width = 2 * (int(width) // (self.vae_scale_factor * 2))
shape = (batch_size, num_channels_latents, height, width)
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
if latents is not None:
return latents.to(device=device, dtype=dtype)
# Encode the input image
image = image.to(device=device, dtype=dtype)
if image.shape[1] != num_channels_latents:
if isinstance(generator, list):
image_latents = [
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
# Apply scaling (inverse of decoding: decode does latents/scaling_factor + shift_factor)
image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
else:
image_latents = image
# Handle batch size expansion
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
additional_image_per_prompt = batch_size // image_latents.shape[0]
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
raise ValueError(
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
)
# Add noise using flow matching scale_noise
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self.scheduler.scale_noise(image_latents, timestep, noise)
return latents, noise, image_latents, latent_image_ids
def prepare_mask_latents(
self,
mask,
masked_image,
batch_size,
num_images_per_prompt,
height,
width,
dtype,
device,
generator,
):
height = 2 * (int(height) // (self.vae_scale_factor * 2))
width = 2 * (int(width) // (self.vae_scale_factor * 2))
# resize the mask to latents shape as we concatenate the mask to the latents
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
# and half precision
mask = torch.nn.functional.interpolate(mask, size=(height, width))
mask = mask.to(device=device, dtype=dtype)
batch_size = batch_size * num_images_per_prompt
masked_image = masked_image.to(device=device, dtype=dtype)
if masked_image.shape[1] == 16:
masked_image_latents = masked_image
else:
masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)
masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
if mask.shape[0] < batch_size:
if not batch_size % mask.shape[0] == 0:
raise ValueError(
"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
" of masks that you pass is divisible by the total requested batch size."
)
mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
if masked_image_latents.shape[0] < batch_size:
if not batch_size % masked_image_latents.shape[0] == 0:
raise ValueError(
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
" Make sure the number of images that you pass is divisible by the total requested batch size."
)
masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
# aligning device to prevent device errors when concating it with the latent model input
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
return mask, masked_image_latents
@property
def guidance_scale(self):
return self._guidance_scale
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def joint_attention_kwargs(self):
return self._joint_attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
image: PipelineImageInput = None,
mask_image: PipelineImageInput = None,
strength: float = 0.6,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 5.0,
cfg_normalization: bool = False,
cfg_truncation: float = 1.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
):
r"""
Function invoked when calling the pipeline for image-to-image generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a
list of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or
a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`.
mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
`Image`, numpy array or tensor representing an image batch to mask `image`. Black pixels in the mask
are repainted while white pixels are preserved. If `mask_image` is a PIL image, it is converted to a
single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one
color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B,
H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W,
1)`, or `(H, W)`.
strength (`float`, *optional*, defaults to 0.6):
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
starting point and more noise is added the higher the `strength`. The number of denoising steps depends
on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
essentially ignores `image`.
height (`int`, *optional*, defaults to 1024):
The height in pixels of the generated image. If not provided, uses the input image height.
width (`int`, *optional*, defaults to 1024):
The width in pixels of the generated image. If not provided, uses the input image width.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 5.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
cfg_normalization (`bool`, *optional*, defaults to False):
Whether to apply configuration normalization.
cfg_truncation (`float`, *optional*, defaults to 1.0):
The truncation value for configuration.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`List[torch.FloatTensor]`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`List[torch.FloatTensor]`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.ZImagePipelineOutput`] instead of a plain
tuple.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int`, *optional*, defaults to 512):
Maximum sequence length to use with the `prompt`.
Examples:
Returns:
[`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`: [`~pipelines.z_image.ZImagePipelineOutput`] if
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
generated images.
"""
# 1. Check inputs and validate strength
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
# 2. Preprocess image
init_image = self.image_processor.preprocess(image)
init_image = init_image.to(dtype=torch.float32)
# Get dimensions from the preprocessed image if not specified
if height is None:
height = init_image.shape[-2]
if width is None:
width = init_image.shape[-1]
vae_scale = self.vae_scale_factor * 2
if height % vae_scale != 0:
raise ValueError(
f"Height must be divisible by {vae_scale} (got {height}). "
f"Please adjust the height to a multiple of {vae_scale}."
)
if width % vae_scale != 0:
raise ValueError(
f"Width must be divisible by {vae_scale} (got {width}). "
f"Please adjust the width to a multiple of {vae_scale}."
)
device = self._execution_device
self._guidance_scale = guidance_scale
self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False
self._cfg_normalization = cfg_normalization
self._cfg_truncation = cfg_truncation
# 3. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = len(prompt_embeds)
# If prompt_embeds is provided and prompt is None, skip encoding
if prompt_embeds is not None and prompt is None:
if self.do_classifier_free_guidance and negative_prompt_embeds is None:
raise ValueError(
"When `prompt_embeds` is provided without `prompt`, "
"`negative_prompt_embeds` must also be provided for classifier-free guidance."
)
else:
(
prompt_embeds,
negative_prompt_embeds,
) = self.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
device=device,
max_sequence_length=max_sequence_length,
)
# 4. Prepare latent variables
num_channels_latents = self.transformer.in_channels
# Repeat prompt_embeds for num_images_per_prompt
if num_images_per_prompt > 1:
prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)]
if self.do_classifier_free_guidance and negative_prompt_embeds:
negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)]
actual_batch_size = batch_size * num_images_per_prompt
# Calculate latent dimensions for image_seq_len
latent_height = 2 * (int(height) // (self.vae_scale_factor * 2))
latent_width = 2 * (int(width) // (self.vae_scale_factor * 2))
image_seq_len = (latent_height // 2) * (latent_width // 2)
# 5. Prepare timesteps
mu = calculate_shift(
image_seq_len,
self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.15),
)
self.scheduler.sigma_min = 0.0
scheduler_kwargs = {"mu": mu}
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
sigmas=sigmas,
**scheduler_kwargs,
)
# 6. Adjust timesteps based on strength
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
if num_inference_steps < 1:
raise ValueError(
f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline "
f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
)
latent_timestep = timesteps[:1].repeat(actual_batch_size)
# 7. Prepare latents from image
latents, noise, original_image_latents, latent_image_ids = self.prepare_latents(
init_image,
latent_timestep,
actual_batch_size,
num_channels_latents,
height,
width,
prompt_embeds[0].dtype,
device,
generator,
latents,
)
resize_mode = "default"
crops_coords = None
# start diff diff preparation
original_mask = self.mask_processor.preprocess(
mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
)
masked_image = init_image * original_mask
original_mask, _ = self.prepare_mask_latents(
original_mask,
masked_image,
batch_size,
num_images_per_prompt,
height,
width,
prompt_embeds[0].dtype,
device,
generator,
)
mask_thresholds = torch.arange(num_inference_steps, dtype=original_mask.dtype) / num_inference_steps
mask_thresholds = mask_thresholds.reshape(-1, 1, 1, 1).to(device)
masks = original_mask > mask_thresholds
# end diff diff preparation
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# 8. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0])
timestep = (1000 - timestep) / 1000
# Normalized time for time-aware config (0 at start, 1 at end)
t_norm = timestep[0].item()
# Handle cfg truncation
current_guidance_scale = self.guidance_scale
if (
self.do_classifier_free_guidance
and self._cfg_truncation is not None
and float(self._cfg_truncation) <= 1
):
if t_norm > self._cfg_truncation:
current_guidance_scale = 0.0
# Run CFG only if configured AND scale is non-zero
apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0
if apply_cfg:
latents_typed = latents.to(self.transformer.dtype)
latent_model_input = latents_typed.repeat(2, 1, 1, 1)
prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds
timestep_model_input = timestep.repeat(2)
else:
latent_model_input = latents.to(self.transformer.dtype)
prompt_embeds_model_input = prompt_embeds
timestep_model_input = timestep
latent_model_input = latent_model_input.unsqueeze(2)
latent_model_input_list = list(latent_model_input.unbind(dim=0))
model_out_list = self.transformer(
latent_model_input_list,
timestep_model_input,
prompt_embeds_model_input,
)[0]
if apply_cfg:
# Perform CFG
pos_out = model_out_list[:actual_batch_size]
neg_out = model_out_list[actual_batch_size:]
noise_pred = []
for j in range(actual_batch_size):
pos = pos_out[j].float()
neg = neg_out[j].float()
pred = pos + current_guidance_scale * (pos - neg)
# Renormalization
if self._cfg_normalization and float(self._cfg_normalization) > 0.0:
ori_pos_norm = torch.linalg.vector_norm(pos)
new_pos_norm = torch.linalg.vector_norm(pred)
max_new_norm = ori_pos_norm * float(self._cfg_normalization)
if new_pos_norm > max_new_norm:
pred = pred * (max_new_norm / new_pos_norm)
noise_pred.append(pred)
noise_pred = torch.stack(noise_pred, dim=0)
else:
noise_pred = torch.stack([t.float() for t in model_out_list], dim=0)
noise_pred = noise_pred.squeeze(2)
noise_pred = -noise_pred
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0]
assert latents.dtype == torch.float32
# start diff diff
image_latent = original_image_latents
latents_dtype = latents.dtype
if i < len(timesteps) - 1:
noise_timestep = timesteps[i + 1]
image_latent = self.scheduler.scale_noise(
original_image_latents, torch.tensor([noise_timestep]), noise
)
mask = masks[i].to(latents_dtype)
latents = image_latent * mask + latents * (1 - mask)
# end diff diff
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if output_type == "latent":
image = latents
else:
latents = latents.to(self.vae.dtype)
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
image = self.vae.decode(latents, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return ZImagePipelineOutput(images=image)

View File

@@ -274,7 +274,7 @@ version_range_max = max(sys.version_info[1], 10) + 1
setup(
name="diffusers",
version="0.37.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="0.36.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
description="State-of-the-art diffusion in PyTorch and JAX.",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",

View File

@@ -162,7 +162,7 @@ SINGLE_FILE_LOADABLE_CLASSES = {
"default_subfolder": "transformer",
},
"QwenImageTransformer2DModel": {
"checkpoint_mapping_fn": lambda checkpoint, **kwargs: checkpoint,
"checkpoint_mapping_fn": lambda x: x,
"default_subfolder": "transformer",
},
"Flux2Transformer2DModel": {

View File

@@ -120,10 +120,7 @@ CHECKPOINT_KEY_NAMES = {
"hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
"instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight",
"lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"],
"z-image-turbo": [
"model.diffusion_model.layers.0.adaLN_modulation.0.weight",
"layers.0.adaLN_modulation.0.weight",
],
"z-image-turbo": "cap_embedder.0.weight",
"z-image-turbo-controlnet": "control_all_x_embedder.2-1.weight",
"z-image-turbo-controlnet-2.x": "control_layers.14.adaLN_modulation.0.weight",
"sana": [
@@ -226,8 +223,7 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
"cosmos-2.0-v2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-14B-Video2World"},
"z-image-turbo": {"pretrained_model_name_or_path": "Tongyi-MAI/Z-Image-Turbo"},
"z-image-turbo-controlnet": {"pretrained_model_name_or_path": "hlky/Z-Image-Turbo-Fun-Controlnet-Union"},
"z-image-turbo-controlnet-2.0": {"pretrained_model_name_or_path": "hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.0"},
"z-image-turbo-controlnet-2.1": {"pretrained_model_name_or_path": "hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.1"},
"z-image-turbo-controlnet-2.x": {"pretrained_model_name_or_path": "hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.1"},
}
# Use to configure model sample size when original config is provided
@@ -731,7 +727,10 @@ def infer_diffusers_model_type(checkpoint):
):
model_type = "instruct-pix2pix"
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["z-image-turbo"]):
elif (
CHECKPOINT_KEY_NAMES["z-image-turbo"] in checkpoint
and checkpoint[CHECKPOINT_KEY_NAMES["z-image-turbo"]].shape[0] == 2560
):
model_type = "z-image-turbo"
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["lumina2"]):
@@ -785,13 +784,7 @@ def infer_diffusers_model_type(checkpoint):
raise ValueError(f"Unexpected x_embedder shape: {x_embedder_shape} when loading Cosmos 2.0 model.")
elif CHECKPOINT_KEY_NAMES["z-image-turbo-controlnet-2.x"] in checkpoint:
before_proj_weight = checkpoint.get("control_noise_refiner.0.before_proj.weight", None)
if before_proj_weight is None:
model_type = "z-image-turbo-controlnet-2.0"
elif before_proj_weight is not None and torch.all(before_proj_weight == 0.0):
model_type = "z-image-turbo-controlnet-2.0"
else:
model_type = "z-image-turbo-controlnet-2.1"
model_type = "z-image-turbo-controlnet-2.x"
elif CHECKPOINT_KEY_NAMES["z-image-turbo-controlnet"] in checkpoint:
model_type = "z-image-turbo-controlnet"
@@ -3859,7 +3852,6 @@ def convert_z_image_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
".attention.k_norm.weight": ".attention.norm_k.weight",
".attention.q_norm.weight": ".attention.norm_q.weight",
".attention.out.weight": ".attention.to_out.0.weight",
"model.diffusion_model.": "",
}
def convert_z_image_fused_attention(key: str, state_dict: dict[str, object]) -> None:
@@ -3894,9 +3886,6 @@ def convert_z_image_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
update_state_dict(converted_state_dict, key, new_key)
if "norm_final.weight" in converted_state_dict.keys():
_ = converted_state_dict.pop("norm_final.weight")
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
# special_keys_remap
for key in list(converted_state_dict.keys()):

View File

@@ -27,7 +27,7 @@ from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
from .vae import DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -410,7 +410,7 @@ class HunyuanImageDecoder2D(nn.Module):
return h
class AutoencoderKLHunyuanImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
class AutoencoderKLHunyuanImage(ModelMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A VAE model for 2D images with spatial tiling support.
@@ -486,6 +486,27 @@ class AutoencoderKLHunyuanImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromO
self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor
self.tile_latent_min_size = self.tile_sample_min_size // self.config.spatial_compression_ratio
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def _encode(self, x: torch.Tensor):
batch_size, num_channels, height, width = x.shape

View File

@@ -26,7 +26,7 @@ from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
from .vae import DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -584,7 +584,7 @@ class HunyuanImageRefinerDecoder3D(nn.Module):
return hidden_states
class AutoencoderKLHunyuanImageRefiner(ModelMixin, AutoencoderMixin, ConfigMixin):
class AutoencoderKLHunyuanImageRefiner(ModelMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used for
HunyuanImage-2.1 Refiner.
@@ -685,6 +685,27 @@ class AutoencoderKLHunyuanImageRefiner(ModelMixin, AutoencoderMixin, ConfigMixin
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def _encode(self, x: torch.Tensor) -> torch.Tensor:
_, _, _, height, width = x.shape

View File

@@ -26,7 +26,7 @@ from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
from .vae import DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -625,7 +625,7 @@ class HunyuanVideo15Decoder3D(nn.Module):
return hidden_states
class AutoencoderKLHunyuanVideo15(ModelMixin, AutoencoderMixin, ConfigMixin):
class AutoencoderKLHunyuanVideo15(ModelMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used for
HunyuanVideo-1.5.
@@ -723,6 +723,27 @@ class AutoencoderKLHunyuanVideo15(ModelMixin, AutoencoderMixin, ConfigMixin):
self.tile_latent_min_width = tile_latent_min_width or self.tile_latent_min_width
self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def _encode(self, x: torch.Tensor) -> torch.Tensor:
_, _, _, height, width = x.shape

View File

@@ -134,8 +134,7 @@ class WanAttnProcessor:
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
# Reference: https://github.com/huggingface/diffusers/pull/12909
parallel_config=None,
parallel_config=self._parallel_config,
)
hidden_states_img = hidden_states_img.flatten(2, 3)
hidden_states_img = hidden_states_img.type_as(query)
@@ -148,8 +147,7 @@ class WanAttnProcessor:
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
# Reference: https://github.com/huggingface/diffusers/pull/12909
parallel_config=(self._parallel_config if encoder_hidden_states is None else None),
parallel_config=self._parallel_config,
)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.type_as(query)
@@ -554,11 +552,9 @@ class WanTransformer3DModel(
"blocks.0": {
"hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
},
# Reference: https://github.com/huggingface/diffusers/pull/12909
# We need to disable the splitting of encoder_hidden_states because the image_encoder
# (Wan 2.1 I2V) consistently generates 257 tokens for image_embed. This causes the shape
# of encoder_hidden_states—whose token count is always 769 (512 + 257) after concatenation
# —to be indivisible by the number of devices in the CP.
"blocks.*": {
"encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
},
"proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
"": {
"timestep": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),

View File

@@ -609,8 +609,7 @@ class WanAttnProcessor:
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
# Reference: https://github.com/huggingface/diffusers/pull/12909
parallel_config=None,
parallel_config=self._parallel_config,
)
hidden_states_img = hidden_states_img.flatten(2, 3)
hidden_states_img = hidden_states_img.type_as(query)
@@ -623,8 +622,7 @@ class WanAttnProcessor:
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
# Reference: https://github.com/huggingface/diffusers/pull/12909
parallel_config=(self._parallel_config if encoder_hidden_states is None else None),
parallel_config=self._parallel_config,
)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.type_as(query)

View File

@@ -76,7 +76,7 @@ EXAMPLE_DOC_STRING = """
>>> model_id = "nvidia/Cosmos-Predict2.5-2B"
>>> pipe = Cosmos2_5_PredictBasePipeline.from_pretrained(
... model_id, revision="diffusers/base/post-trained", torch_dtype=torch.bfloat16
... model_id, revision="diffusers/base/pre-trianed", torch_dtype=torch.bfloat16
... )
>>> pipe = pipe.to("cuda")

View File

@@ -58,13 +58,14 @@ EXAMPLE_DOC_STRING = """
>>> # torch_dtype=torch.bfloat16,
>>> # )
>>> # 2.0
>>> # 2.0 - `config` is required
>>> # controlnet = ZImageControlNetModel.from_single_file(
>>> # hf_hub_download(
>>> # "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0",
>>> # filename="Z-Image-Turbo-Fun-Controlnet-Union-2.0.safetensors",
>>> # ),
>>> # torch_dtype=torch.bfloat16,
>>> # config="hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.0",
>>> # )
>>> pipe = ZImageControlNetPipeline.from_pretrained(

View File

@@ -50,13 +50,14 @@ EXAMPLE_DOC_STRING = """
... torch_dtype=torch.bfloat16,
... )
>>> # 2.0
>>> # 2.0 - `config` is required
>>> # controlnet = ZImageControlNetModel.from_single_file(
>>> # hf_hub_download(
>>> # "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0",
>>> # filename="Z-Image-Turbo-Fun-Controlnet-Union-2.0.safetensors",
>>> # ),
>>> # torch_dtype=torch.bfloat16,
>>> # config="hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.0",
>>> # )
>>> pipe = ZImageControlNetInpaintPipeline.from_pretrained(

View File

@@ -32,6 +32,21 @@ warnings.simplefilter(action="ignore", category=FutureWarning)
def pytest_configure(config):
config.addinivalue_line("markers", "big_accelerator: marks tests as requiring big accelerator resources")
config.addinivalue_line("markers", "lora: marks tests for LoRA/PEFT functionality")
config.addinivalue_line("markers", "ip_adapter: marks tests for IP Adapter functionality")
config.addinivalue_line("markers", "training: marks tests for training functionality")
config.addinivalue_line("markers", "attention: marks tests for attention processor functionality")
config.addinivalue_line("markers", "memory: marks tests for memory optimization functionality")
config.addinivalue_line("markers", "cpu_offload: marks tests for CPU offloading functionality")
config.addinivalue_line("markers", "group_offload: marks tests for group offloading functionality")
config.addinivalue_line("markers", "compile: marks tests for torch.compile functionality")
config.addinivalue_line("markers", "single_file: marks tests for single file checkpoint loading")
config.addinivalue_line("markers", "bitsandbytes: marks tests for BitsAndBytes quantization functionality")
config.addinivalue_line("markers", "quanto: marks tests for Quanto quantization functionality")
config.addinivalue_line("markers", "torchao: marks tests for TorchAO quantization functionality")
config.addinivalue_line("markers", "gguf: marks tests for GGUF quantization functionality")
config.addinivalue_line("markers", "modelopt: marks tests for NVIDIA ModelOpt quantization functionality")
config.addinivalue_line("markers", "context_parallel: marks tests for context parallel inference functionality")
def pytest_addoption(parser):

View File

@@ -0,0 +1,79 @@
from .attention import AttentionTesterMixin
from .cache import (
CacheTesterMixin,
FasterCacheConfigMixin,
FasterCacheTesterMixin,
FirstBlockCacheConfigMixin,
FirstBlockCacheTesterMixin,
PyramidAttentionBroadcastConfigMixin,
PyramidAttentionBroadcastTesterMixin,
)
from .common import BaseModelTesterConfig, ModelTesterMixin
from .compile import TorchCompileTesterMixin
from .ip_adapter import IPAdapterTesterMixin
from .lora import LoraHotSwappingForModelTesterMixin, LoraTesterMixin
from .memory import CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin, MemoryTesterMixin
from .parallelism import ContextParallelTesterMixin
from .quantization import (
BitsAndBytesCompileTesterMixin,
BitsAndBytesConfigMixin,
BitsAndBytesTesterMixin,
GGUFCompileTesterMixin,
GGUFConfigMixin,
GGUFTesterMixin,
ModelOptCompileTesterMixin,
ModelOptConfigMixin,
ModelOptTesterMixin,
QuantizationCompileTesterMixin,
QuantizationTesterMixin,
QuantoCompileTesterMixin,
QuantoConfigMixin,
QuantoTesterMixin,
TorchAoCompileTesterMixin,
TorchAoConfigMixin,
TorchAoTesterMixin,
)
from .single_file import SingleFileTesterMixin
from .training import TrainingTesterMixin
__all__ = [
"AttentionTesterMixin",
"BaseModelTesterConfig",
"BitsAndBytesCompileTesterMixin",
"BitsAndBytesConfigMixin",
"BitsAndBytesTesterMixin",
"CacheTesterMixin",
"ContextParallelTesterMixin",
"CPUOffloadTesterMixin",
"FasterCacheConfigMixin",
"FasterCacheTesterMixin",
"FirstBlockCacheConfigMixin",
"FirstBlockCacheTesterMixin",
"GGUFCompileTesterMixin",
"GGUFConfigMixin",
"GGUFTesterMixin",
"GroupOffloadTesterMixin",
"IPAdapterTesterMixin",
"LayerwiseCastingTesterMixin",
"LoraHotSwappingForModelTesterMixin",
"LoraTesterMixin",
"MemoryTesterMixin",
"ModelOptCompileTesterMixin",
"ModelOptConfigMixin",
"ModelOptTesterMixin",
"ModelTesterMixin",
"PyramidAttentionBroadcastConfigMixin",
"PyramidAttentionBroadcastTesterMixin",
"QuantizationCompileTesterMixin",
"QuantizationTesterMixin",
"QuantoCompileTesterMixin",
"QuantoConfigMixin",
"QuantoTesterMixin",
"SingleFileTesterMixin",
"TorchAoCompileTesterMixin",
"TorchAoConfigMixin",
"TorchAoTesterMixin",
"TorchCompileTesterMixin",
"TrainingTesterMixin",
]

View File

@@ -0,0 +1,171 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from diffusers.models.attention import AttentionModuleMixin
from diffusers.models.attention_processor import (
AttnProcessor,
)
from ...testing_utils import (
assert_tensors_close,
is_attention,
torch_device,
)
@is_attention
class AttentionTesterMixin:
"""
Mixin class for testing attention processor and module functionality on models.
Tests functionality from AttentionModuleMixin including:
- Attention processor management (set/get)
- QKV projection fusion/unfusion
- Attention backends (XFormers, NPU, etc.)
Expected class attributes to be set by subclasses:
- model_class: The model class to test
- uses_custom_attn_processor: Whether model uses custom attention processors (default: False)
Expected methods to be implemented by subclasses:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: attention
Use `pytest -m "not attention"` to skip these tests
"""
@torch.no_grad()
def test_fuse_unfuse_qkv_projections(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
if not hasattr(model, "fuse_qkv_projections"):
pytest.skip("Model does not support QKV projection fusion.")
output_before_fusion = model(**inputs_dict, return_dict=False)[0]
model.fuse_qkv_projections()
has_fused_projections = False
for module in model.modules():
if isinstance(module, AttentionModuleMixin):
if hasattr(module, "to_qkv") or hasattr(module, "to_kv"):
has_fused_projections = True
assert module.fused_projections, "fused_projections flag should be True"
break
if has_fused_projections:
output_after_fusion = model(**inputs_dict, return_dict=False)[0]
assert_tensors_close(
output_before_fusion,
output_after_fusion,
atol=1e-3,
rtol=0,
msg="Output should not change after fusing projections",
)
model.unfuse_qkv_projections()
for module in model.modules():
if isinstance(module, AttentionModuleMixin):
assert not hasattr(module, "to_qkv"), "to_qkv should be removed after unfusing"
assert not hasattr(module, "to_kv"), "to_kv should be removed after unfusing"
assert not module.fused_projections, "fused_projections flag should be False"
output_after_unfusion = model(**inputs_dict, return_dict=False)[0]
assert_tensors_close(
output_before_fusion,
output_after_unfusion,
atol=1e-3,
rtol=0,
msg="Output should match original after unfusing projections",
)
def test_get_set_processor(self):
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
model.to(torch_device)
# Check if model has attention processors
if not hasattr(model, "attn_processors"):
pytest.skip("Model does not have attention processors.")
# Test getting processors
processors = model.attn_processors
assert isinstance(processors, dict), "attn_processors should return a dict"
assert len(processors) > 0, "Model should have at least one attention processor"
# Test that all processors can be retrieved via get_processor
for module in model.modules():
if isinstance(module, AttentionModuleMixin):
processor = module.get_processor()
assert processor is not None, "get_processor should return a processor"
# Test setting a new processor
new_processor = AttnProcessor()
module.set_processor(new_processor)
retrieved_processor = module.get_processor()
assert retrieved_processor is new_processor, "Retrieved processor should be the same as the one set"
def test_attention_processor_dict(self):
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
model.to(torch_device)
if not hasattr(model, "set_attn_processor"):
pytest.skip("Model does not support setting attention processors.")
# Get current processors
current_processors = model.attn_processors
# Create a dict of new processors
new_processors = {key: AttnProcessor() for key in current_processors.keys()}
# Set processors using dict
model.set_attn_processor(new_processors)
# Verify all processors were set
updated_processors = model.attn_processors
for key in current_processors.keys():
assert type(updated_processors[key]) == AttnProcessor, f"Processor {key} should be AttnProcessor"
def test_attention_processor_count_mismatch_raises_error(self):
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
model.to(torch_device)
if not hasattr(model, "set_attn_processor"):
pytest.skip("Model does not support setting attention processors.")
# Get current processors
current_processors = model.attn_processors
# Create a dict with wrong number of processors
wrong_processors = {list(current_processors.keys())[0]: AttnProcessor()}
# Verify error is raised
with pytest.raises(ValueError) as exc_info:
model.set_attn_processor(wrong_processors)
assert "number of processors" in str(exc_info.value).lower(), "Error should mention processor count mismatch"

View File

@@ -0,0 +1,530 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import pytest
import torch
from diffusers.hooks import FasterCacheConfig, FirstBlockCacheConfig, PyramidAttentionBroadcastConfig
from diffusers.hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
from diffusers.hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
from diffusers.hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
from diffusers.models.cache_utils import CacheMixin
from ...testing_utils import backend_empty_cache, is_cache, torch_device
def require_cache_mixin(func):
"""Decorator to skip tests if model doesn't use CacheMixin."""
def wrapper(self, *args, **kwargs):
if not issubclass(self.model_class, CacheMixin):
pytest.skip(f"{self.model_class.__name__} does not use CacheMixin.")
return func(self, *args, **kwargs)
return wrapper
class CacheTesterMixin:
"""
Base mixin class providing common test implementations for cache testing.
Cache-specific mixins should:
1. Inherit from their respective config mixin (e.g., PyramidAttentionBroadcastConfigMixin)
2. Inherit from this mixin
3. Define the cache config to use for tests
Expected class attributes:
- model_class: The model class to test (must use CacheMixin)
Expected methods in test classes:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
"""
def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
def _get_cache_config(self):
"""
Get the cache config for testing.
Should be implemented by subclasses.
"""
raise NotImplementedError("Subclass must implement _get_cache_config")
def _get_hook_names(self):
"""
Get the hook names to check for this cache type.
Should be implemented by subclasses.
Returns a list of hook name strings.
"""
raise NotImplementedError("Subclass must implement _get_hook_names")
def _test_cache_enable_disable_state(self):
"""Test that cache enable/disable updates the is_cache_enabled state correctly."""
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
# Initially cache should not be enabled
assert not model.is_cache_enabled, "Cache should not be enabled initially."
config = self._get_cache_config()
# Enable cache
model.enable_cache(config)
assert model.is_cache_enabled, "Cache should be enabled after enable_cache()."
# Disable cache
model.disable_cache()
assert not model.is_cache_enabled, "Cache should not be enabled after disable_cache()."
def _test_cache_double_enable_raises_error(self):
"""Test that enabling cache twice raises an error."""
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
config = self._get_cache_config()
model.enable_cache(config)
# Trying to enable again should raise ValueError
with pytest.raises(ValueError, match="Caching has already been enabled"):
model.enable_cache(config)
# Cleanup
model.disable_cache()
def _test_cache_hooks_registered(self):
"""Test that cache hooks are properly registered and removed."""
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
config = self._get_cache_config()
hook_names = self._get_hook_names()
model.enable_cache(config)
# Check that at least one hook was registered
hook_count = 0
for module in model.modules():
if hasattr(module, "_diffusers_hook"):
for hook_name in hook_names:
hook = module._diffusers_hook.get_hook(hook_name)
if hook is not None:
hook_count += 1
assert hook_count > 0, f"At least one cache hook should be registered. Hook names: {hook_names}"
# Disable and verify hooks are removed
model.disable_cache()
hook_count_after = 0
for module in model.modules():
if hasattr(module, "_diffusers_hook"):
for hook_name in hook_names:
hook = module._diffusers_hook.get_hook(hook_name)
if hook is not None:
hook_count_after += 1
assert hook_count_after == 0, "Cache hooks should be removed after disable_cache()."
@torch.no_grad()
def _test_cache_inference(self):
"""Test that model can run inference with cache enabled."""
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
config = self._get_cache_config()
model.enable_cache(config)
# First pass populates the cache
_ = model(**inputs_dict, return_dict=False)[0]
# Create modified inputs for second pass (vary hidden_states to simulate denoising)
inputs_dict_step2 = inputs_dict.copy()
if "hidden_states" in inputs_dict_step2:
inputs_dict_step2["hidden_states"] = inputs_dict_step2["hidden_states"] + 0.1
# Second pass uses cached attention with different hidden_states (produces approximated output)
output_with_cache = model(**inputs_dict_step2, return_dict=False)[0]
assert output_with_cache is not None, "Model output should not be None with cache enabled."
assert not torch.isnan(output_with_cache).any(), "Model output contains NaN with cache enabled."
# Run same inputs without cache to compare
model.disable_cache()
output_without_cache = model(**inputs_dict_step2, return_dict=False)[0]
# Cached output should be different from non-cached output (due to approximation)
assert not torch.allclose(output_without_cache, output_with_cache, atol=1e-5), (
"Cached output should be different from non-cached output due to cache approximation."
)
def _test_cache_context_manager(self):
"""Test the cache_context context manager."""
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
config = self._get_cache_config()
model.enable_cache(config)
# Test cache_context works without error
with model.cache_context("test_context"):
pass
model.disable_cache()
@torch.no_grad()
def _test_reset_stateful_cache(self):
"""Test that _reset_stateful_cache resets the cache state."""
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
config = self._get_cache_config()
model.enable_cache(config)
_ = model(**inputs_dict, return_dict=False)[0]
model._reset_stateful_cache()
model.disable_cache()
@is_cache
class PyramidAttentionBroadcastConfigMixin:
"""
Base mixin providing PyramidAttentionBroadcast cache config.
Expected class attributes:
- model_class: The model class to test (must use CacheMixin)
"""
# Default PAB config - can be overridden by subclasses
PAB_CONFIG = {
"spatial_attention_block_skip_range": 2,
}
# Store timestep for callback (must be within default range (100, 800) for skipping to trigger)
_current_timestep = 500
def _get_cache_config(self):
config_kwargs = self.PAB_CONFIG.copy()
config_kwargs["current_timestep_callback"] = lambda: self._current_timestep
return PyramidAttentionBroadcastConfig(**config_kwargs)
def _get_hook_names(self):
return [_PYRAMID_ATTENTION_BROADCAST_HOOK]
@is_cache
class PyramidAttentionBroadcastTesterMixin(PyramidAttentionBroadcastConfigMixin, CacheTesterMixin):
"""
Mixin class for testing PyramidAttentionBroadcast caching on models.
Expected class attributes:
- model_class: The model class to test (must use CacheMixin)
Expected methods to be implemented by subclasses:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: cache
Use `pytest -m "not cache"` to skip these tests
"""
@require_cache_mixin
def test_pab_cache_enable_disable_state(self):
self._test_cache_enable_disable_state()
@require_cache_mixin
def test_pab_cache_double_enable_raises_error(self):
self._test_cache_double_enable_raises_error()
@require_cache_mixin
def test_pab_cache_hooks_registered(self):
self._test_cache_hooks_registered()
@require_cache_mixin
def test_pab_cache_inference(self):
self._test_cache_inference()
@require_cache_mixin
def test_pab_cache_context_manager(self):
self._test_cache_context_manager()
@require_cache_mixin
def test_pab_reset_stateful_cache(self):
self._test_reset_stateful_cache()
@is_cache
class FirstBlockCacheConfigMixin:
"""
Base mixin providing FirstBlockCache config.
Expected class attributes:
- model_class: The model class to test (must use CacheMixin)
"""
# Default FBC config - can be overridden by subclasses
# Higher threshold makes FBC more aggressive about caching (skips more often)
FBC_CONFIG = {
"threshold": 1.0,
}
def _get_cache_config(self):
return FirstBlockCacheConfig(**self.FBC_CONFIG)
def _get_hook_names(self):
return [_FBC_LEADER_BLOCK_HOOK, _FBC_BLOCK_HOOK]
@is_cache
class FirstBlockCacheTesterMixin(FirstBlockCacheConfigMixin, CacheTesterMixin):
"""
Mixin class for testing FirstBlockCache on models.
Expected class attributes:
- model_class: The model class to test (must use CacheMixin)
Expected methods to be implemented by subclasses:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: cache
Use `pytest -m "not cache"` to skip these tests
"""
@torch.no_grad()
def _test_cache_inference(self):
"""Test that model can run inference with FBC cache enabled (requires cache_context)."""
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
config = self._get_cache_config()
model.enable_cache(config)
# FBC requires cache_context to be set for inference
with model.cache_context("fbc_test"):
# First pass populates the cache
_ = model(**inputs_dict, return_dict=False)[0]
# Create modified inputs for second pass (small perturbation keeps residuals similar)
inputs_dict_step2 = inputs_dict.copy()
if "hidden_states" in inputs_dict_step2:
inputs_dict_step2["hidden_states"] = inputs_dict_step2["hidden_states"] + 0.01
# Second pass - FBC should skip remaining blocks and use cached residuals
output_with_cache = model(**inputs_dict_step2, return_dict=False)[0]
assert output_with_cache is not None, "Model output should not be None with cache enabled."
assert not torch.isnan(output_with_cache).any(), "Model output contains NaN with cache enabled."
# Run same inputs without cache to compare
model.disable_cache()
output_without_cache = model(**inputs_dict_step2, return_dict=False)[0]
# Cached output should be different from non-cached output (due to approximation)
assert not torch.allclose(output_without_cache, output_with_cache, atol=1e-5), (
"Cached output should be different from non-cached output due to cache approximation."
)
@torch.no_grad()
def _test_reset_stateful_cache(self):
"""Test that _reset_stateful_cache resets the FBC cache state (requires cache_context)."""
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
config = self._get_cache_config()
model.enable_cache(config)
with model.cache_context("fbc_test"):
_ = model(**inputs_dict, return_dict=False)[0]
model._reset_stateful_cache()
model.disable_cache()
@require_cache_mixin
def test_fbc_cache_enable_disable_state(self):
self._test_cache_enable_disable_state()
@require_cache_mixin
def test_fbc_cache_double_enable_raises_error(self):
self._test_cache_double_enable_raises_error()
@require_cache_mixin
def test_fbc_cache_hooks_registered(self):
self._test_cache_hooks_registered()
@require_cache_mixin
def test_fbc_cache_inference(self):
self._test_cache_inference()
@require_cache_mixin
def test_fbc_cache_context_manager(self):
self._test_cache_context_manager()
@require_cache_mixin
def test_fbc_reset_stateful_cache(self):
self._test_reset_stateful_cache()
@is_cache
class FasterCacheConfigMixin:
"""
Base mixin providing FasterCache config.
Expected class attributes:
- model_class: The model class to test (must use CacheMixin)
"""
# Default FasterCache config - can be overridden by subclasses
FASTER_CACHE_CONFIG = {
"spatial_attention_block_skip_range": 2,
"spatial_attention_timestep_skip_range": (-1, 901),
"tensor_format": "BCHW",
}
# Store timestep for callback - use a list so it can be mutated during test
# Starts outside skip range so first pass computes; changed to inside range for subsequent passes
_current_timestep = [1000]
def _get_cache_config(self):
config_kwargs = self.FASTER_CACHE_CONFIG.copy()
config_kwargs["current_timestep_callback"] = lambda: self._current_timestep[0]
return FasterCacheConfig(**config_kwargs)
def _get_hook_names(self):
return [_FASTER_CACHE_DENOISER_HOOK, _FASTER_CACHE_BLOCK_HOOK]
@is_cache
class FasterCacheTesterMixin(FasterCacheConfigMixin, CacheTesterMixin):
"""
Mixin class for testing FasterCache on models.
Note: FasterCache is designed for pipeline-level inference with proper CFG batch handling
and timestep management. Inference tests are skipped at model level - FasterCache should
be tested via pipeline tests (e.g., FluxPipeline, HunyuanVideoPipeline).
Expected class attributes:
- model_class: The model class to test (must use CacheMixin)
Expected methods to be implemented by subclasses:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: cache
Use `pytest -m "not cache"` to skip these tests
"""
@torch.no_grad()
def _test_cache_inference(self):
"""Test that model can run inference with FasterCache enabled."""
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
config = self._get_cache_config()
model.enable_cache(config)
# First pass with timestep outside skip range - computes and populates cache
self._current_timestep[0] = 1000
_ = model(**inputs_dict, return_dict=False)[0]
# Move timestep inside skip range so subsequent passes use cache
self._current_timestep[0] = 500
# Create modified inputs for second pass
inputs_dict_step2 = inputs_dict.copy()
if "hidden_states" in inputs_dict_step2:
inputs_dict_step2["hidden_states"] = inputs_dict_step2["hidden_states"] + 0.1
# Second pass uses cached attention with different hidden_states
output_with_cache = model(**inputs_dict_step2, return_dict=False)[0]
assert output_with_cache is not None, "Model output should not be None with cache enabled."
assert not torch.isnan(output_with_cache).any(), "Model output contains NaN with cache enabled."
# Run same inputs without cache to compare
model.disable_cache()
output_without_cache = model(**inputs_dict_step2, return_dict=False)[0]
# Cached output should be different from non-cached output (due to approximation)
assert not torch.allclose(output_without_cache, output_with_cache, atol=1e-5), (
"Cached output should be different from non-cached output due to cache approximation."
)
@torch.no_grad()
def _test_reset_stateful_cache(self):
"""Test that _reset_stateful_cache resets the FasterCache state."""
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
config = self._get_cache_config()
model.enable_cache(config)
self._current_timestep[0] = 1000
_ = model(**inputs_dict, return_dict=False)[0]
model._reset_stateful_cache()
model.disable_cache()
@require_cache_mixin
def test_faster_cache_enable_disable_state(self):
self._test_cache_enable_disable_state()
@require_cache_mixin
def test_faster_cache_double_enable_raises_error(self):
self._test_cache_double_enable_raises_error()
@require_cache_mixin
def test_faster_cache_hooks_registered(self):
self._test_cache_hooks_registered()
@require_cache_mixin
def test_faster_cache_inference(self):
self._test_cache_inference()
@require_cache_mixin
def test_faster_cache_context_manager(self):
self._test_cache_context_manager()
@require_cache_mixin
def test_faster_cache_reset_stateful_cache(self):
self._test_reset_stateful_cache()

View File

@@ -0,0 +1,646 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from collections import defaultdict
from typing import Any, Dict, Optional, Type
import pytest
import torch
import torch.nn as nn
from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size
from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME, _add_variant, logging
from diffusers.utils.testing_utils import require_accelerator, require_torch_multi_accelerator
from ...testing_utils import assert_tensors_close, torch_device
def named_persistent_module_tensors(
module: nn.Module,
recurse: bool = False,
):
"""
A helper function that gathers all the tensors (parameters + persistent buffers) of a given module.
Args:
module (`torch.nn.Module`):
The module we want the tensors on.
recurse (`bool`, *optional`, defaults to `False`):
Whether or not to go look in every submodule or just return the direct parameters and buffers.
"""
yield from module.named_parameters(recurse=recurse)
for named_buffer in module.named_buffers(recurse=recurse):
name, _ = named_buffer
# Get parent by splitting on dots and traversing the model
parent = module
if "." in name:
parent_name = name.rsplit(".", 1)[0]
for part in parent_name.split("."):
parent = getattr(parent, part)
name = name.split(".")[-1]
if name not in parent._non_persistent_buffers_set:
yield named_buffer
def compute_module_persistent_sizes(
model: nn.Module,
dtype: str | torch.device | None = None,
special_dtypes: dict[str, str | torch.device] | None = None,
):
"""
Compute the size of each submodule of a given model (parameters + persistent buffers).
"""
if dtype is not None:
dtype = _get_proper_dtype(dtype)
dtype_size = dtype_byte_size(dtype)
if special_dtypes is not None:
special_dtypes = {key: _get_proper_dtype(dtyp) for key, dtyp in special_dtypes.items()}
special_dtypes_size = {key: dtype_byte_size(dtyp) for key, dtyp in special_dtypes.items()}
module_sizes = defaultdict(int)
module_list = []
module_list = named_persistent_module_tensors(model, recurse=True)
for name, tensor in module_list:
if special_dtypes is not None and name in special_dtypes:
size = tensor.numel() * special_dtypes_size[name]
elif dtype is None:
size = tensor.numel() * dtype_byte_size(tensor.dtype)
elif str(tensor.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
# According to the code in set_module_tensor_to_device, these types won't be converted
# so use their original size here
size = tensor.numel() * dtype_byte_size(tensor.dtype)
else:
size = tensor.numel() * min(dtype_size, dtype_byte_size(tensor.dtype))
name_parts = name.split(".")
for idx in range(len(name_parts) + 1):
module_sizes[".".join(name_parts[:idx])] += size
return module_sizes
def calculate_expected_num_shards(index_map_path):
"""
Calculate expected number of shards from index file.
Args:
index_map_path: Path to the sharded checkpoint index file
Returns:
int: Expected number of shards
"""
with open(index_map_path) as f:
weight_map_dict = json.load(f)["weight_map"]
first_key = list(weight_map_dict.keys())[0]
weight_loc = weight_map_dict[first_key] # e.g., diffusion_pytorch_model-00001-of-00002.safetensors
expected_num_shards = int(weight_loc.split("-")[-1].split(".")[0])
return expected_num_shards
def check_device_map_is_respected(model, device_map):
for param_name, param in model.named_parameters():
# Find device in device_map
while len(param_name) > 0 and param_name not in device_map:
param_name = ".".join(param_name.split(".")[:-1])
if param_name not in device_map:
raise ValueError("device map is incomplete, it does not contain any device for `param_name`.")
param_device = device_map[param_name]
if param_device in ["cpu", "disk"]:
assert param.device == torch.device("meta"), f"Expected device 'meta' for {param_name}, got {param.device}"
else:
assert param.device == torch.device(param_device), (
f"Expected device {param_device} for {param_name}, got {param.device}"
)
class BaseModelTesterConfig:
"""
Base class defining the configuration interface for model testing.
This class defines the contract that all model test classes must implement.
It provides a consistent interface for accessing model configuration, initialization
parameters, and test inputs across all testing mixins.
Required properties (must be implemented by subclasses):
- model_class: The model class to test
Optional properties (can be overridden, have sensible defaults):
- pretrained_model_name_or_path: Hub repository ID for pretrained model (default: None)
- pretrained_model_kwargs: Additional kwargs for from_pretrained (default: {})
- output_shape: Expected output shape for output validation tests (default: None)
- model_split_percents: Percentages for model parallelism tests (default: [0.5, 0.7])
Required methods (must be implemented by subclasses):
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Example usage:
class MyModelTestConfig(BaseModelTesterConfig):
@property
def model_class(self):
return MyModel
@property
def pretrained_model_name_or_path(self):
return "org/my-model"
@property
def output_shape(self):
return (1, 3, 32, 32)
def get_init_dict(self):
return {"in_channels": 3, "out_channels": 3}
def get_dummy_inputs(self):
return {"sample": torch.randn(1, 3, 32, 32, device=torch_device)}
class TestMyModel(MyModelTestConfig, ModelTesterMixin, QuantizationTesterMixin):
pass
"""
# ==================== Required Properties ====================
@property
def model_class(self) -> Type[nn.Module]:
"""The model class to test. Must be implemented by subclasses."""
raise NotImplementedError("Subclasses must implement the `model_class` property.")
# ==================== Optional Properties ====================
@property
def pretrained_model_name_or_path(self) -> Optional[str]:
"""Hub repository ID for the pretrained model (used for quantization and hub tests)."""
return None
@property
def pretrained_model_kwargs(self) -> Dict[str, Any]:
"""Additional kwargs to pass to from_pretrained (e.g., subfolder, variant)."""
return {}
@property
def output_shape(self) -> Optional[tuple]:
"""Expected output shape for output validation tests."""
return None
@property
def model_split_percents(self) -> list:
"""Percentages for model parallelism tests."""
return [0.5, 0.7]
# ==================== Required Methods ====================
def get_init_dict(self) -> Dict[str, Any]:
"""
Returns dict of arguments to initialize the model.
Returns:
Dict[str, Any]: Initialization arguments for the model constructor.
Example:
return {
"in_channels": 3,
"out_channels": 3,
"sample_size": 32,
}
"""
raise NotImplementedError("Subclasses must implement `get_init_dict()`.")
def get_dummy_inputs(self) -> Dict[str, Any]:
"""
Returns dict of inputs to pass to the model forward pass.
Returns:
Dict[str, Any]: Input tensors/values for model.forward().
Example:
return {
"sample": torch.randn(1, 3, 32, 32, device=torch_device),
"timestep": torch.tensor([1], device=torch_device),
}
"""
raise NotImplementedError("Subclasses must implement `get_dummy_inputs()`.")
class ModelTesterMixin:
"""
Base mixin class for model testing with common test methods.
This mixin expects the test class to also inherit from BaseModelTesterConfig
(or implement its interface) which provides:
- model_class: The model class to test
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Example:
class MyModelTestConfig(BaseModelTesterConfig):
model_class = MyModel
def get_init_dict(self): ...
def get_dummy_inputs(self): ...
class TestMyModel(MyModelTestConfig, ModelTesterMixin):
pass
"""
@torch.no_grad()
def test_from_save_pretrained(self, tmp_path, atol=5e-5, rtol=5e-5):
torch.manual_seed(0)
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
model.save_pretrained(tmp_path)
new_model = self.model_class.from_pretrained(tmp_path)
new_model.to(torch_device)
for param_name in model.state_dict().keys():
param_1 = model.state_dict()[param_name]
param_2 = new_model.state_dict()[param_name]
assert param_1.shape == param_2.shape, (
f"Parameter shape mismatch for {param_name}. Original: {param_1.shape}, loaded: {param_2.shape}"
)
image = model(**self.get_dummy_inputs(), return_dict=False)[0]
new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]
assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")
@torch.no_grad()
def test_from_save_pretrained_variant(self, tmp_path, atol=5e-5, rtol=0):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
model.save_pretrained(tmp_path, variant="fp16")
new_model = self.model_class.from_pretrained(tmp_path, variant="fp16")
with pytest.raises(OSError) as exc_info:
self.model_class.from_pretrained(tmp_path)
assert "Error no file named diffusion_pytorch_model.bin found in directory" in str(exc_info.value)
new_model.to(torch_device)
image = model(**self.get_dummy_inputs(), return_dict=False)[0]
new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]
assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"])
def test_from_save_pretrained_dtype(self, tmp_path, dtype):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
if torch_device == "mps" and dtype == torch.bfloat16:
pytest.skip(reason=f"{dtype} is not supported on {torch_device}")
model.to(dtype)
model.save_pretrained(tmp_path)
new_model = self.model_class.from_pretrained(tmp_path, low_cpu_mem_usage=True, torch_dtype=dtype)
assert new_model.dtype == dtype
if hasattr(self.model_class, "_keep_in_fp32_modules") and self.model_class._keep_in_fp32_modules is None:
# When loading without accelerate dtype == torch.float32 if _keep_in_fp32_modules is not None
new_model = self.model_class.from_pretrained(tmp_path, low_cpu_mem_usage=False, torch_dtype=dtype)
assert new_model.dtype == dtype
@torch.no_grad()
def test_determinism(self, atol=1e-5, rtol=0):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
first = model(**self.get_dummy_inputs(), return_dict=False)[0]
second = model(**self.get_dummy_inputs(), return_dict=False)[0]
first_flat = first.flatten()
second_flat = second.flatten()
mask = ~(torch.isnan(first_flat) | torch.isnan(second_flat))
first_filtered = first_flat[mask]
second_filtered = second_flat[mask]
assert_tensors_close(
first_filtered, second_filtered, atol=atol, rtol=rtol, msg="Model outputs are not deterministic"
)
@torch.no_grad()
def test_output(self, expected_output_shape=None):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
inputs_dict = self.get_dummy_inputs()
output = model(**inputs_dict, return_dict=False)[0]
assert output is not None, "Model output is None"
assert output[0].shape == expected_output_shape or self.output_shape, (
f"Output shape does not match expected. Expected {expected_output_shape}, got {output.shape}"
)
@torch.no_grad()
def test_outputs_equivalence(self):
def set_nan_tensor_to_zero(t):
device = t.device
if device.type == "mps":
t = t.to("cpu")
t[t != t] = 0
return t.to(device)
def recursive_check(tuple_object, dict_object):
if isinstance(tuple_object, (list, tuple)):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif isinstance(tuple_object, dict):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif tuple_object is None:
return
else:
assert_tensors_close(
set_nan_tensor_to_zero(tuple_object),
set_nan_tensor_to_zero(dict_object),
atol=1e-5,
rtol=0,
msg="Tuple and dict output are not equal",
)
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
outputs_dict = model(**self.get_dummy_inputs())
outputs_tuple = model(**self.get_dummy_inputs(), return_dict=False)
recursive_check(outputs_tuple, outputs_dict)
def test_getattr_is_correct(self, caplog):
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
# save some things to test
model.dummy_attribute = 5
model.register_to_config(test_attribute=5)
logger_name = "diffusers.models.modeling_utils"
with caplog.at_level(logging.WARNING, logger=logger_name):
caplog.clear()
assert hasattr(model, "dummy_attribute")
assert getattr(model, "dummy_attribute") == 5
assert model.dummy_attribute == 5
# no warning should be thrown
assert caplog.text == ""
with caplog.at_level(logging.WARNING, logger=logger_name):
caplog.clear()
assert hasattr(model, "save_pretrained")
fn = model.save_pretrained
fn_1 = getattr(model, "save_pretrained")
assert fn == fn_1
# no warning should be thrown
assert caplog.text == ""
# warning should be thrown for config attributes accessed directly
with pytest.warns(FutureWarning):
assert model.test_attribute == 5
with pytest.warns(FutureWarning):
assert getattr(model, "test_attribute") == 5
with pytest.raises(AttributeError) as error:
model.does_not_exist
assert str(error.value) == f"'{type(model).__name__}' object has no attribute 'does_not_exist'"
@require_accelerator
@pytest.mark.skipif(
torch_device not in ["cuda", "xpu"],
reason="float16 and bfloat16 can only be used with an accelerator",
)
def test_keep_in_fp32_modules(self):
model = self.model_class(**self.get_init_dict())
fp32_modules = model._keep_in_fp32_modules
if fp32_modules is None or len(fp32_modules) == 0:
pytest.skip("Model does not have _keep_in_fp32_modules defined.")
# Test with float16
model.to(torch_device)
model.to(torch.float16)
for name, param in model.named_parameters():
if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules):
assert param.dtype == torch.float32, f"Parameter {name} should be float32 but got {param.dtype}"
else:
assert param.dtype == torch.float16, f"Parameter {name} should be float16 but got {param.dtype}"
@require_accelerator
@pytest.mark.skipif(
torch_device not in ["cuda", "xpu"],
reason="float16 and bfloat16 can only be use for inference with an accelerator",
)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
@torch.no_grad()
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
fp32_modules = model._keep_in_fp32_modules
model.to(dtype).save_pretrained(tmp_path)
model_loaded = self.model_class.from_pretrained(tmp_path, torch_dtype=dtype).to(torch_device)
for name, param in model_loaded.named_parameters():
if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules):
assert param.data.dtype == torch.float32
else:
assert param.data.dtype == dtype
output = model(**self.get_dummy_inputs(), return_dict=False)[0]
output_loaded = model_loaded(**self.get_dummy_inputs(), return_dict=False)[0]
assert_tensors_close(output, output_loaded, atol=1e-4, rtol=0, msg=f"Loaded model output differs for {dtype}")
@require_accelerator
@torch.no_grad()
def test_sharded_checkpoints(self, tmp_path):
torch.manual_seed(0)
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**config).eval()
model = model.to(torch_device)
base_output = model(**inputs_dict, return_dict=False)[0]
model_size = compute_module_persistent_sizes(model)[""]
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small
model.cpu().save_pretrained(tmp_path, max_shard_size=f"{max_shard_size}KB")
assert os.path.exists(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME)), "Index file should exist"
# Check if the right number of shards exists
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME))
actual_num_shards = len([file for file in os.listdir(tmp_path) if file.endswith(".safetensors")])
assert actual_num_shards == expected_num_shards, (
f"Expected {expected_num_shards} shards, got {actual_num_shards}"
)
new_model = self.model_class.from_pretrained(tmp_path).eval()
new_model = new_model.to(torch_device)
torch.manual_seed(0)
inputs_dict_new = self.get_dummy_inputs()
new_output = new_model(**inputs_dict_new, return_dict=False)[0]
assert_tensors_close(
base_output, new_output, atol=1e-5, rtol=0, msg="Output should match after sharded save/load"
)
@require_accelerator
@torch.no_grad()
def test_sharded_checkpoints_with_variant(self, tmp_path):
torch.manual_seed(0)
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**config).eval()
model = model.to(torch_device)
base_output = model(**inputs_dict, return_dict=False)[0]
model_size = compute_module_persistent_sizes(model)[""]
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small
variant = "fp16"
model.cpu().save_pretrained(tmp_path, max_shard_size=f"{max_shard_size}KB", variant=variant)
index_filename = _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
assert os.path.exists(os.path.join(tmp_path, index_filename)), (
f"Variant index file {index_filename} should exist"
)
# Check if the right number of shards exists
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_path, index_filename))
actual_num_shards = len([file for file in os.listdir(tmp_path) if file.endswith(".safetensors")])
assert actual_num_shards == expected_num_shards, (
f"Expected {expected_num_shards} shards, got {actual_num_shards}"
)
new_model = self.model_class.from_pretrained(tmp_path, variant=variant).eval()
new_model = new_model.to(torch_device)
torch.manual_seed(0)
inputs_dict_new = self.get_dummy_inputs()
new_output = new_model(**inputs_dict_new, return_dict=False)[0]
assert_tensors_close(
base_output, new_output, atol=1e-5, rtol=0, msg="Output should match after variant sharded save/load"
)
@torch.no_grad()
def test_sharded_checkpoints_with_parallel_loading(self, tmp_path):
from diffusers.utils import constants
torch.manual_seed(0)
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**config).eval()
model = model.to(torch_device)
base_output = model(**inputs_dict, return_dict=False)[0]
model_size = compute_module_persistent_sizes(model)[""]
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small
# Save original values to restore after test
original_parallel_loading = constants.HF_ENABLE_PARALLEL_LOADING
original_parallel_workers = getattr(constants, "HF_PARALLEL_WORKERS", None)
try:
model.cpu().save_pretrained(tmp_path, max_shard_size=f"{max_shard_size}KB")
assert os.path.exists(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME)), "Index file should exist"
# Check if the right number of shards exists
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME))
actual_num_shards = len([file for file in os.listdir(tmp_path) if file.endswith(".safetensors")])
assert actual_num_shards == expected_num_shards, (
f"Expected {expected_num_shards} shards, got {actual_num_shards}"
)
# Load without parallel loading
constants.HF_ENABLE_PARALLEL_LOADING = False
model_sequential = self.model_class.from_pretrained(tmp_path).eval()
model_sequential = model_sequential.to(torch_device)
# Load with parallel loading
constants.HF_ENABLE_PARALLEL_LOADING = True
constants.DEFAULT_HF_PARALLEL_LOADING_WORKERS = 2
torch.manual_seed(0)
model_parallel = self.model_class.from_pretrained(tmp_path).eval()
model_parallel = model_parallel.to(torch_device)
torch.manual_seed(0)
inputs_dict_parallel = self.get_dummy_inputs()
output_parallel = model_parallel(**inputs_dict_parallel, return_dict=False)[0]
assert_tensors_close(
base_output, output_parallel, atol=1e-5, rtol=0, msg="Output should match with parallel loading"
)
finally:
# Restore original values
constants.HF_ENABLE_PARALLEL_LOADING = original_parallel_loading
if original_parallel_workers is not None:
constants.HF_PARALLEL_WORKERS = original_parallel_workers
@require_torch_multi_accelerator
@torch.no_grad()
def test_model_parallelism(self, tmp_path):
if self.model_class._no_split_modules is None:
pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**config).eval()
model = model.to(torch_device)
torch.manual_seed(0)
base_output = model(**inputs_dict, return_dict=False)[0]
model_size = compute_module_sizes(model)[""]
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents]
model.cpu().save_pretrained(tmp_path)
for max_size in max_gpu_sizes:
max_memory = {0: max_size, 1: model_size * 2, "cpu": model_size * 2}
new_model = self.model_class.from_pretrained(tmp_path, device_map="auto", max_memory=max_memory)
# Making sure part of the model will be on GPU 0 and GPU 1
assert set(new_model.hf_device_map.values()) == {0, 1}, "Model should be split across GPUs"
check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict, return_dict=False)[0]
assert_tensors_close(
base_output, new_output, atol=1e-5, rtol=0, msg="Output should match with model parallelism"
)

View File

@@ -0,0 +1,161 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import os
import pytest
import torch
from ...testing_utils import (
backend_empty_cache,
is_torch_compile,
require_accelerator,
require_torch_version_greater,
torch_device,
)
@is_torch_compile
@require_accelerator
@require_torch_version_greater("2.7.1")
class TorchCompileTesterMixin:
"""
Mixin class for testing torch.compile functionality on models.
Expected class attributes to be set by subclasses:
- model_class: The model class to test
- different_shapes_for_compilation: Optional list of (height, width) tuples for dynamic shape testing
Expected methods to be implemented by subclasses:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: compile
Use `pytest -m "not compile"` to skip these tests
"""
different_shapes_for_compilation = None
def setup_method(self):
torch.compiler.reset()
gc.collect()
backend_empty_cache(torch_device)
def teardown_method(self):
torch.compiler.reset()
gc.collect()
backend_empty_cache(torch_device)
@torch.no_grad()
def test_torch_compile_recompilation_and_graph_break(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
model = torch.compile(model, fullgraph=True)
with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(error_on_recompile=True),
):
_ = model(**inputs_dict)
_ = model(**inputs_dict)
@torch.no_grad()
def test_torch_compile_repeated_blocks(self):
if self.model_class._repeated_blocks is None:
pytest.skip("Skipping test as the model class doesn't have `_repeated_blocks` set.")
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
model.compile_repeated_blocks(fullgraph=True)
recompile_limit = 1
if self.model_class.__name__ == "UNet2DConditionModel":
recompile_limit = 2
with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(recompile_limit=recompile_limit),
):
_ = model(**inputs_dict)
_ = model(**inputs_dict)
@torch.no_grad()
def test_compile_with_group_offloading(self):
if not self.model_class._supports_group_offloading:
pytest.skip("Model does not support group offloading.")
torch._dynamo.config.cache_size_limit = 10000
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.eval()
group_offload_kwargs = {
"onload_device": torch_device,
"offload_device": "cpu",
"offload_type": "block_level",
"num_blocks_per_group": 1,
"use_stream": True,
"non_blocking": True,
}
model.enable_group_offload(**group_offload_kwargs)
model.compile()
_ = model(**inputs_dict)
_ = model(**inputs_dict)
@torch.no_grad()
def test_compile_on_different_shapes(self):
if self.different_shapes_for_compilation is None:
pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
torch.fx.experimental._config.use_duck_shape = False
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
model = torch.compile(model, fullgraph=True, dynamic=True)
for height, width in self.different_shapes_for_compilation:
with torch._dynamo.config.patch(error_on_recompile=True):
inputs_dict = self.get_dummy_inputs(height=height, width=width)
_ = model(**inputs_dict)
@torch.no_grad()
def test_compile_works_with_aot(self, tmp_path):
from torch._inductor.package import load_package
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
exported_model = torch.export.export(model, args=(), kwargs=inputs_dict)
package_path = os.path.join(str(tmp_path), f"{self.model_class.__name__}.pt2")
_ = torch._inductor.aoti_compile_and_package(exported_model, package_path=package_path)
assert os.path.exists(package_path), f"Package file not created at {package_path}"
loaded_binary = load_package(package_path, run_single_threaded=True)
model.forward = loaded_binary
_ = model(**inputs_dict)
_ = model(**inputs_dict)

View File

@@ -0,0 +1,139 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from ...testing_utils import is_ip_adapter, torch_device
def check_if_ip_adapter_correctly_set(model, processor_cls) -> bool:
"""
Check if IP Adapter processors are correctly set in the model.
Args:
model: The model to check
Returns:
bool: True if IP Adapter is correctly set, False otherwise
"""
for module in model.attn_processors.values():
if isinstance(module, processor_cls):
return True
return False
@is_ip_adapter
class IPAdapterTesterMixin:
"""
Mixin class for testing IP Adapter functionality on models.
Expected class attributes to be set by subclasses:
- model_class: The model class to test
Expected methods to be implemented by subclasses:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: ip_adapter
Use `pytest -m "not ip_adapter"` to skip these tests
"""
ip_adapter_processor_cls = None
def create_ip_adapter_state_dict(self, model):
raise NotImplementedError("child class must implement method to create IPAdapter State Dict")
def modify_inputs_for_ip_adapter(self, model, inputs_dict):
raise NotImplementedError("child class must implement method to create IPAdapter model inputs")
@torch.no_grad()
def test_load_ip_adapter(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
torch.manual_seed(0)
output_no_adapter = model(**inputs_dict, return_dict=False)[0]
ip_adapter_state_dict = self.create_ip_adapter_state_dict(model)
model._load_ip_adapter_weights([ip_adapter_state_dict])
assert check_if_ip_adapter_correctly_set(model, self.ip_adapter_processor_cls), (
"IP Adapter processors not set correctly"
)
inputs_dict_with_adapter = self.modify_inputs_for_ip_adapter(model, inputs_dict.copy())
outputs_with_adapter = model(**inputs_dict_with_adapter, return_dict=False)[0]
assert not torch.allclose(output_no_adapter, outputs_with_adapter, atol=1e-4, rtol=1e-4), (
"Output should differ with IP Adapter enabled"
)
@pytest.mark.skip(
reason="Setting IP Adapter scale is not defined at the model level. Enable this test after refactoring"
)
def test_ip_adapter_scale(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
ip_adapter_state_dict = self.create_ip_adapter_state_dict(model)
model._load_ip_adapter_weights([ip_adapter_state_dict])
inputs_dict_with_adapter = self.modify_inputs_for_ip_adapter(model, inputs_dict.copy())
# Test scale = 0.0 (no effect)
model.set_ip_adapter_scale(0.0)
torch.manual_seed(0)
output_scale_zero = model(**inputs_dict_with_adapter, return_dict=False)[0]
# Test scale = 1.0 (full effect)
model.set_ip_adapter_scale(1.0)
torch.manual_seed(0)
output_scale_one = model(**inputs_dict_with_adapter, return_dict=False)[0]
# Outputs should differ with different scales
assert not torch.allclose(output_scale_zero, output_scale_one, atol=1e-4, rtol=1e-4), (
"Output should differ with different IP Adapter scales"
)
@pytest.mark.skip(
reason="Unloading IP Adapter is not defined at the model level. Enable this test after refactoring"
)
def test_unload_ip_adapter(self):
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
# Save original processors
original_processors = {k: type(v).__name__ for k, v in model.attn_processors.items()}
# Create and load IP adapter
ip_adapter_state_dict = self.create_ip_adapter_state_dict(model)
model._load_ip_adapter_weights([ip_adapter_state_dict])
assert check_if_ip_adapter_correctly_set(model, self.ip_adapter_processor_cls), "IP Adapter should be set"
# Unload IP adapter
model.unload_ip_adapter()
assert not check_if_ip_adapter_correctly_set(model, self.ip_adapter_processor_cls), (
"IP Adapter should be unloaded"
)
# Verify processors are restored
current_processors = {k: type(v).__name__ for k, v in model.attn_processors.items()}
assert original_processors == current_processors, "Processors should be restored after unload"

View File

@@ -0,0 +1,549 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import json
import os
import re
import pytest
import safetensors.torch
import torch
import torch.nn as nn
from diffusers.utils.import_utils import is_peft_available
from diffusers.utils.testing_utils import check_if_dicts_are_equal
from ...testing_utils import (
assert_tensors_close,
backend_empty_cache,
is_lora,
is_torch_compile,
require_peft_backend,
require_peft_version_greater,
require_torch_accelerator,
require_torch_version_greater,
torch_device,
)
if is_peft_available():
from diffusers.loaders.peft import PeftAdapterMixin
def check_if_lora_correctly_set(model) -> bool:
"""
Check if LoRA layers are correctly set in the model.
Args:
model: The model to check
Returns:
bool: True if LoRA is correctly set, False otherwise
"""
from peft.tuners.tuners_utils import BaseTunerLayer
for module in model.modules():
if isinstance(module, BaseTunerLayer):
return True
return False
@is_lora
@require_peft_backend
class LoraTesterMixin:
"""
Mixin class for testing LoRA/PEFT functionality on models.
Expected class attributes to be set by subclasses:
- model_class: The model class to test
Expected methods to be implemented by subclasses:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: lora
Use `pytest -m "not lora"` to skip these tests
"""
def setup_method(self):
if not issubclass(self.model_class, PeftAdapterMixin):
pytest.skip(f"PEFT is not supported for this model ({self.model_class.__name__}).")
@torch.no_grad()
def test_save_load_lora_adapter(self, tmp_path, rank=4, lora_alpha=4, use_dora=False):
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
torch.manual_seed(0)
output_no_lora = model(**inputs_dict, return_dict=False)[0]
denoiser_lora_config = LoraConfig(
r=rank,
lora_alpha=lora_alpha,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
init_lora_weights=False,
use_dora=use_dora,
)
model.add_adapter(denoiser_lora_config)
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
torch.manual_seed(0)
outputs_with_lora = model(**inputs_dict, return_dict=False)[0]
assert not torch.allclose(output_no_lora, outputs_with_lora, atol=1e-4, rtol=1e-4), (
"Output should differ with LoRA enabled"
)
model.save_lora_adapter(tmp_path)
assert os.path.isfile(os.path.join(tmp_path, "pytorch_lora_weights.safetensors")), (
"LoRA weights file not created"
)
state_dict_loaded = safetensors.torch.load_file(os.path.join(tmp_path, "pytorch_lora_weights.safetensors"))
model.unload_lora()
assert not check_if_lora_correctly_set(model), "LoRA should be unloaded"
model.load_lora_adapter(tmp_path, prefix=None, use_safetensors=True)
state_dict_retrieved = get_peft_model_state_dict(model, adapter_name="default_0")
for k in state_dict_loaded:
loaded_v = state_dict_loaded[k]
retrieved_v = state_dict_retrieved[k].to(loaded_v.device)
assert_tensors_close(loaded_v, retrieved_v, atol=1e-5, rtol=0, msg=f"Mismatch in LoRA weight {k}")
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly after reload"
torch.manual_seed(0)
outputs_with_lora_2 = model(**inputs_dict, return_dict=False)[0]
assert not torch.allclose(output_no_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4), (
"Output should differ with LoRA enabled"
)
assert_tensors_close(
outputs_with_lora,
outputs_with_lora_2,
atol=1e-4,
rtol=1e-4,
msg="Outputs should match before and after save/load",
)
def test_lora_wrong_adapter_name_raises_error(self, tmp_path):
from peft import LoraConfig
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
denoiser_lora_config = LoraConfig(
r=4,
lora_alpha=4,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
init_lora_weights=False,
use_dora=False,
)
model.add_adapter(denoiser_lora_config)
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
wrong_name = "foo"
with pytest.raises(ValueError) as exc_info:
model.save_lora_adapter(tmp_path, adapter_name=wrong_name)
assert f"Adapter name {wrong_name} not found in the model." in str(exc_info.value)
def test_lora_adapter_metadata_is_loaded_correctly(self, tmp_path, rank=4, lora_alpha=4, use_dora=False):
from peft import LoraConfig
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
denoiser_lora_config = LoraConfig(
r=rank,
lora_alpha=lora_alpha,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
init_lora_weights=False,
use_dora=use_dora,
)
model.add_adapter(denoiser_lora_config)
metadata = model.peft_config["default"].to_dict()
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
model.save_lora_adapter(tmp_path)
model_file = os.path.join(tmp_path, "pytorch_lora_weights.safetensors")
assert os.path.isfile(model_file), "LoRA weights file not created"
model.unload_lora()
assert not check_if_lora_correctly_set(model), "LoRA should be unloaded"
model.load_lora_adapter(tmp_path, prefix=None, use_safetensors=True)
parsed_metadata = model.peft_config["default_0"].to_dict()
check_if_dicts_are_equal(metadata, parsed_metadata)
def test_lora_adapter_wrong_metadata_raises_error(self, tmp_path):
from peft import LoraConfig
from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
denoiser_lora_config = LoraConfig(
r=4,
lora_alpha=4,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
init_lora_weights=False,
use_dora=False,
)
model.add_adapter(denoiser_lora_config)
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
model.save_lora_adapter(tmp_path)
model_file = os.path.join(tmp_path, "pytorch_lora_weights.safetensors")
assert os.path.isfile(model_file), "LoRA weights file not created"
# Perturb the metadata in the state dict
loaded_state_dict = safetensors.torch.load_file(model_file)
metadata = {"format": "pt"}
lora_adapter_metadata = denoiser_lora_config.to_dict()
lora_adapter_metadata.update({"foo": 1, "bar": 2})
for key, value in lora_adapter_metadata.items():
if isinstance(value, set):
lora_adapter_metadata[key] = list(value)
metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
safetensors.torch.save_file(loaded_state_dict, model_file, metadata=metadata)
model.unload_lora()
assert not check_if_lora_correctly_set(model), "LoRA should be unloaded"
with pytest.raises(TypeError) as exc_info:
model.load_lora_adapter(tmp_path, prefix=None, use_safetensors=True)
assert "`LoraConfig` class could not be instantiated" in str(exc_info.value)
@is_lora
@is_torch_compile
@require_peft_backend
@require_peft_version_greater("0.14.0")
@require_torch_version_greater("2.7.1")
@require_torch_accelerator
class LoraHotSwappingForModelTesterMixin:
"""
Mixin class for testing LoRA hot swapping functionality on models.
Test that hotswapping does not result in recompilation on the model directly.
We're not extensively testing the hotswapping functionality since it is implemented in PEFT
and is extensively tested there. The goal of this test is specifically to ensure that
hotswapping with diffusers does not require recompilation.
See https://github.com/huggingface/peft/blob/eaab05e18d51fb4cce20a73c9acd82a00c013b83/tests/test_gpu_examples.py#L4252
for the analogous PEFT test.
Expected class attributes to be set by subclasses:
- model_class: The model class to test
- different_shapes_for_compilation: Optional list of (height, width) tuples for dynamic compilation tests
Expected methods to be implemented by subclasses:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest marks: lora, torch_compile
Use `pytest -m "not lora"` or `pytest -m "not torch_compile"` to skip these tests
"""
different_shapes_for_compilation = None
def setup_method(self):
if not issubclass(self.model_class, PeftAdapterMixin):
pytest.skip(f"PEFT is not supported for this model ({self.model_class.__name__}).")
def teardown_method(self):
# It is critical that the dynamo cache is reset for each test. Otherwise, if the test re-uses the same model,
# there will be recompilation errors, as torch caches the model when run in the same process.
torch.compiler.reset()
gc.collect()
backend_empty_cache(torch_device)
def _get_lora_config(self, lora_rank, lora_alpha, target_modules):
from peft import LoraConfig
lora_config = LoraConfig(
r=lora_rank,
lora_alpha=lora_alpha,
target_modules=target_modules,
init_lora_weights=False,
use_dora=False,
)
return lora_config
def _get_linear_module_name_other_than_attn(self, model):
linear_names = [
name for name, module in model.named_modules() if isinstance(module, nn.Linear) and "to_" not in name
]
return linear_names[0]
def _check_model_hotswap(self, tmp_path, do_compile, rank0, rank1, target_modules0, target_modules1=None):
"""
Check that hotswapping works on a model.
Steps:
- create 2 LoRA adapters and save them
- load the first adapter
- hotswap the second adapter
- check that the outputs are correct
- optionally compile the model
- optionally check if recompilations happen on different shapes
Note: We set rank == alpha here because save_lora_adapter does not save the alpha scalings, thus the test would
fail if the values are different. Since rank != alpha does not matter for the purpose of this test, this is
fine.
"""
different_shapes = self.different_shapes_for_compilation
# create 2 adapters with different ranks and alphas
torch.manual_seed(0)
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
alpha0, alpha1 = rank0, rank1
max_rank = max([rank0, rank1])
if target_modules1 is None:
target_modules1 = target_modules0[:]
lora_config0 = self._get_lora_config(rank0, alpha0, target_modules0)
lora_config1 = self._get_lora_config(rank1, alpha1, target_modules1)
model.add_adapter(lora_config0, adapter_name="adapter0")
with torch.inference_mode():
torch.manual_seed(0)
output0_before = model(**inputs_dict)["sample"]
model.add_adapter(lora_config1, adapter_name="adapter1")
model.set_adapter("adapter1")
with torch.inference_mode():
torch.manual_seed(0)
output1_before = model(**inputs_dict)["sample"]
# sanity checks:
tol = 5e-3
assert not torch.allclose(output0_before, output1_before, atol=tol, rtol=tol)
assert not (output0_before == 0).all()
assert not (output1_before == 0).all()
# save the adapter checkpoints
model.save_lora_adapter(os.path.join(tmp_path, "0"), safe_serialization=True, adapter_name="adapter0")
model.save_lora_adapter(os.path.join(tmp_path, "1"), safe_serialization=True, adapter_name="adapter1")
del model
# load the first adapter
torch.manual_seed(0)
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
if do_compile or (rank0 != rank1):
# no need to prepare if the model is not compiled or if the ranks are identical
model.enable_lora_hotswap(target_rank=max_rank)
file_name0 = os.path.join(os.path.join(tmp_path, "0"), "pytorch_lora_weights.safetensors")
file_name1 = os.path.join(os.path.join(tmp_path, "1"), "pytorch_lora_weights.safetensors")
model.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None)
if do_compile:
model = torch.compile(model, mode="reduce-overhead", dynamic=different_shapes is not None)
with torch.inference_mode():
# additionally check if dynamic compilation works.
if different_shapes is not None:
for height, width in different_shapes:
new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
_ = model(**new_inputs_dict)
else:
output0_after = model(**inputs_dict)["sample"]
assert_tensors_close(
output0_before, output0_after, atol=tol, rtol=tol, msg="Output mismatch after loading adapter0"
)
# hotswap the 2nd adapter
model.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None)
# we need to call forward to potentially trigger recompilation
with torch.inference_mode():
if different_shapes is not None:
for height, width in different_shapes:
new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
_ = model(**new_inputs_dict)
else:
output1_after = model(**inputs_dict)["sample"]
assert_tensors_close(
output1_before,
output1_after,
atol=tol,
rtol=tol,
msg="Output mismatch after hotswapping to adapter1",
)
# check error when not passing valid adapter name
name = "does-not-exist"
msg = f"Trying to hotswap LoRA adapter '{name}' but there is no existing adapter by that name"
with pytest.raises(ValueError, match=re.escape(msg)):
model.load_lora_adapter(file_name1, adapter_name=name, hotswap=True, prefix=None)
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
def test_hotswapping_model(self, tmp_path, rank0, rank1):
self._check_model_hotswap(
tmp_path, do_compile=False, rank0=rank0, rank1=rank1, target_modules0=["to_q", "to_k", "to_v", "to_out.0"]
)
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
def test_hotswapping_compiled_model_linear(self, tmp_path, rank0, rank1):
# It's important to add this context to raise an error on recompilation
target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
self._check_model_hotswap(
tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules
)
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
def test_hotswapping_compiled_model_conv2d(self, tmp_path, rank0, rank1):
if "unet" not in self.model_class.__name__.lower():
pytest.skip("Test only applies to UNet.")
# It's important to add this context to raise an error on recompilation
target_modules = ["conv", "conv1", "conv2"]
with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
self._check_model_hotswap(
tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules
)
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
def test_hotswapping_compiled_model_both_linear_and_conv2d(self, tmp_path, rank0, rank1):
if "unet" not in self.model_class.__name__.lower():
pytest.skip("Test only applies to UNet.")
# It's important to add this context to raise an error on recompilation
target_modules = ["to_q", "conv"]
with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
self._check_model_hotswap(
tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules
)
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
def test_hotswapping_compiled_model_both_linear_and_other(self, tmp_path, rank0, rank1):
# In `test_hotswapping_compiled_model_both_linear_and_conv2d()`, we check if we can do hotswapping
# with `torch.compile()` for models that have both linear and conv layers. In this test, we check
# if we can target a linear layer from the transformer blocks and another linear layer from non-attention
# block.
target_modules = ["to_q"]
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
target_modules.append(self._get_linear_module_name_other_than_attn(model))
del model
# It's important to add this context to raise an error on recompilation
with torch._dynamo.config.patch(error_on_recompile=True):
self._check_model_hotswap(
tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules
)
def test_enable_lora_hotswap_called_after_adapter_added_raises(self):
# ensure that enable_lora_hotswap is called before loading the first adapter
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
model.add_adapter(lora_config)
msg = re.escape("Call `enable_lora_hotswap` before loading the first adapter.")
with pytest.raises(RuntimeError, match=msg):
model.enable_lora_hotswap(target_rank=32)
def test_enable_lora_hotswap_called_after_adapter_added_warning(self, caplog):
# ensure that enable_lora_hotswap is called before loading the first adapter
import logging
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
model.add_adapter(lora_config)
msg = (
"It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
)
with caplog.at_level(logging.WARNING):
model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
assert any(msg in record.message for record in caplog.records)
def test_enable_lora_hotswap_called_after_adapter_added_ignore(self, caplog):
# check possibility to ignore the error/warning
import logging
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
model.add_adapter(lora_config)
with caplog.at_level(logging.WARNING):
model.enable_lora_hotswap(target_rank=32, check_compiled="ignore")
assert len(caplog.records) == 0
def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self):
# check that wrong argument value raises an error
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
model.add_adapter(lora_config)
msg = re.escape("check_compiles should be one of 'error', 'warn', or 'ignore', got 'wrong-argument' instead.")
with pytest.raises(ValueError, match=msg):
model.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument")
def test_hotswap_second_adapter_targets_more_layers_raises(self, tmp_path, caplog):
# check the error and log
import logging
# at the moment, PEFT requires the 2nd adapter to target the same or a subset of layers
target_modules0 = ["to_q"]
target_modules1 = ["to_q", "to_k"]
with pytest.raises(RuntimeError): # peft raises RuntimeError
with caplog.at_level(logging.ERROR):
self._check_model_hotswap(
tmp_path,
do_compile=True,
rank0=8,
rank1=8,
target_modules0=target_modules0,
target_modules1=target_modules1,
)
assert any("Hotswapping adapter0 was unsuccessful" in record.message for record in caplog.records)
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
@require_torch_version_greater("2.7.1")
def test_hotswapping_compile_on_different_shapes(self, tmp_path, rank0, rank1):
different_shapes_for_compilation = self.different_shapes_for_compilation
if different_shapes_for_compilation is None:
pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
# Specifying `use_duck_shape=False` instructs the compiler if it should use the same symbolic
# variable to represent input sizes that are the same. For more details,
# check out this [comment](https://github.com/huggingface/diffusers/pull/11327#discussion_r2047659790).
torch.fx.experimental._config.use_duck_shape = False
target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
with torch._dynamo.config.patch(error_on_recompile=True):
self._check_model_hotswap(
tmp_path,
do_compile=True,
rank0=rank0,
rank1=rank1,
target_modules0=target_modules,
)

View File

@@ -0,0 +1,501 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import glob
import inspect
from functools import wraps
import pytest
import torch
from accelerate.utils.modeling import compute_module_sizes
from diffusers.utils.testing_utils import _check_safetensors_serialization
from diffusers.utils.torch_utils import get_torch_cuda_device_capability
from ...testing_utils import (
assert_tensors_close,
backend_empty_cache,
backend_max_memory_allocated,
backend_reset_peak_memory_stats,
backend_synchronize,
is_cpu_offload,
is_group_offload,
is_memory,
require_accelerator,
torch_device,
)
from .common import check_device_map_is_respected
def cast_maybe_tensor_dtype(inputs_dict, from_dtype, to_dtype):
"""Helper to cast tensor inputs from one dtype to another."""
for key, value in inputs_dict.items():
if isinstance(value, torch.Tensor) and value.dtype == from_dtype:
inputs_dict[key] = value.to(to_dtype)
return inputs_dict
def require_offload_support(func):
"""
Decorator to skip tests if model doesn't support offloading (requires _no_split_modules).
"""
@wraps(func)
def wrapper(self, *args, **kwargs):
if self.model_class._no_split_modules is None:
pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")
return func(self, *args, **kwargs)
return wrapper
def require_group_offload_support(func):
"""
Decorator to skip tests if model doesn't support group offloading.
"""
@wraps(func)
def wrapper(self, *args, **kwargs):
if not self.model_class._supports_group_offloading:
pytest.skip("Model does not support group offloading.")
return func(self, *args, **kwargs)
return wrapper
@is_cpu_offload
class CPUOffloadTesterMixin:
"""
Mixin class for testing CPU offloading functionality.
Expected class attributes to be set by subclasses:
- model_class: The model class to test
- model_split_percents: List of percentages for splitting model across devices
Expected methods to be implemented by subclasses:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: cpu_offload
Use `pytest -m "not cpu_offload"` to skip these tests
"""
model_split_percents = [0.5, 0.7]
@require_offload_support
@torch.no_grad()
def test_cpu_offload(self, tmp_path):
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**config).eval()
model = model.to(torch_device)
torch.manual_seed(0)
base_output = model(**inputs_dict)
model_size = compute_module_sizes(model)[""]
# We test several splits of sizes to make sure it works
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents]
model.cpu().save_pretrained(str(tmp_path))
for max_size in max_gpu_sizes:
max_memory = {0: max_size, "cpu": model_size * 2}
new_model = self.model_class.from_pretrained(str(tmp_path), device_map="auto", max_memory=max_memory)
# Making sure part of the model will actually end up offloaded
assert set(new_model.hf_device_map.values()) == {0, "cpu"}, "Model should be split between GPU and CPU"
check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
assert_tensors_close(
base_output[0], new_output[0], atol=1e-5, rtol=0, msg="Output should match with CPU offloading"
)
@require_offload_support
@torch.no_grad()
def test_disk_offload_without_safetensors(self, tmp_path):
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**config).eval()
model = model.to(torch_device)
torch.manual_seed(0)
base_output = model(**inputs_dict)
model_size = compute_module_sizes(model)[""]
max_size = int(self.model_split_percents[0] * model_size)
# Force disk offload by setting very small CPU memory
max_memory = {0: max_size, "cpu": int(0.1 * max_size)}
model.cpu().save_pretrained(str(tmp_path), safe_serialization=False)
# This errors out because it's missing an offload folder
with pytest.raises(ValueError):
new_model = self.model_class.from_pretrained(str(tmp_path), device_map="auto", max_memory=max_memory)
new_model = self.model_class.from_pretrained(
str(tmp_path), device_map="auto", max_memory=max_memory, offload_folder=str(tmp_path)
)
check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
assert_tensors_close(
base_output[0], new_output[0], atol=1e-5, rtol=0, msg="Output should match with disk offloading"
)
@require_offload_support
@torch.no_grad()
def test_disk_offload_with_safetensors(self, tmp_path):
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**config).eval()
model = model.to(torch_device)
torch.manual_seed(0)
base_output = model(**inputs_dict)
model_size = compute_module_sizes(model)[""]
model.cpu().save_pretrained(str(tmp_path))
max_size = int(self.model_split_percents[0] * model_size)
max_memory = {0: max_size, "cpu": max_size}
new_model = self.model_class.from_pretrained(
str(tmp_path), device_map="auto", offload_folder=str(tmp_path), max_memory=max_memory
)
check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
assert_tensors_close(
base_output[0],
new_output[0],
atol=1e-5,
rtol=0,
msg="Output should match with disk offloading (safetensors)",
)
@is_group_offload
class GroupOffloadTesterMixin:
"""
Mixin class for testing group offloading functionality.
Expected class attributes to be set by subclasses:
- model_class: The model class to test
Expected methods to be implemented by subclasses:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: group_offload
Use `pytest -m "not group_offload"` to skip these tests
"""
@require_group_offload_support
@pytest.mark.parametrize("record_stream", [False, True])
def test_group_offloading(self, record_stream):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
torch.manual_seed(0)
@torch.no_grad()
def run_forward(model):
assert all(
module._diffusers_hook.get_hook("group_offloading") is not None
for module in model.modules()
if hasattr(module, "_diffusers_hook")
), "Group offloading hook should be set"
model.eval()
return model(**inputs_dict)[0]
model = self.model_class(**init_dict)
model.to(torch_device)
output_without_group_offloading = run_forward(model)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1)
output_with_group_offloading1 = run_forward(model)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, non_blocking=True)
output_with_group_offloading2 = run_forward(model)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.enable_group_offload(torch_device, offload_type="leaf_level")
output_with_group_offloading3 = run_forward(model)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.enable_group_offload(
torch_device, offload_type="leaf_level", use_stream=True, record_stream=record_stream
)
output_with_group_offloading4 = run_forward(model)
assert_tensors_close(
output_without_group_offloading,
output_with_group_offloading1,
atol=1e-5,
rtol=0,
msg="Output should match with block-level offloading",
)
assert_tensors_close(
output_without_group_offloading,
output_with_group_offloading2,
atol=1e-5,
rtol=0,
msg="Output should match with non-blocking block-level offloading",
)
assert_tensors_close(
output_without_group_offloading,
output_with_group_offloading3,
atol=1e-5,
rtol=0,
msg="Output should match with leaf-level offloading",
)
assert_tensors_close(
output_without_group_offloading,
output_with_group_offloading4,
atol=1e-5,
rtol=0,
msg="Output should match with leaf-level offloading with stream",
)
@require_group_offload_support
@pytest.mark.parametrize("record_stream", [False, True])
@pytest.mark.parametrize("offload_type", ["block_level", "leaf_level"])
@torch.no_grad()
def test_group_offloading_with_layerwise_casting(self, record_stream, offload_type):
torch.manual_seed(0)
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
_ = model(**inputs_dict)[0]
torch.manual_seed(0)
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
storage_dtype, compute_dtype = torch.float16, torch.float32
inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
model = self.model_class(**init_dict)
model.eval()
additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": 1}
model.enable_group_offload(
torch_device, offload_type=offload_type, use_stream=True, record_stream=record_stream, **additional_kwargs
)
model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
_ = model(**inputs_dict)[0]
@require_group_offload_support
@pytest.mark.parametrize("record_stream", [False, True])
@pytest.mark.parametrize("offload_type", ["block_level", "leaf_level"])
@torch.no_grad()
@torch.inference_mode()
def test_group_offloading_with_disk(self, tmp_path, record_stream, offload_type, atol=1e-5):
def _has_generator_arg(model):
sig = inspect.signature(model.forward)
params = sig.parameters
return "generator" in params
def _run_forward(model, inputs_dict):
accepts_generator = _has_generator_arg(model)
if accepts_generator:
inputs_dict["generator"] = torch.manual_seed(0)
torch.manual_seed(0)
return model(**inputs_dict)[0]
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.eval()
model.to(torch_device)
output_without_group_offloading = _run_forward(model, inputs_dict)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.eval()
num_blocks_per_group = None if offload_type == "leaf_level" else 1
additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": num_blocks_per_group}
tmpdir = str(tmp_path)
model.enable_group_offload(
torch_device,
offload_type=offload_type,
offload_to_disk_path=tmpdir,
use_stream=True,
record_stream=record_stream,
**additional_kwargs,
)
has_safetensors = glob.glob(f"{tmpdir}/*.safetensors")
assert has_safetensors, "No safetensors found in the directory."
# For "leaf-level", there is a prefetching hook which makes this check a bit non-deterministic
# in nature. So, skip it.
if offload_type != "leaf_level":
is_correct, extra_files, missing_files = _check_safetensors_serialization(
module=model,
offload_to_disk_path=tmpdir,
offload_type=offload_type,
num_blocks_per_group=num_blocks_per_group,
)
if not is_correct:
if extra_files:
raise ValueError(f"Found extra files: {', '.join(extra_files)}")
elif missing_files:
raise ValueError(f"Following files are missing: {', '.join(missing_files)}")
output_with_group_offloading = _run_forward(model, inputs_dict)
assert_tensors_close(
output_without_group_offloading,
output_with_group_offloading,
atol=atol,
rtol=0,
msg="Output should match with disk-based group offloading",
)
class LayerwiseCastingTesterMixin:
"""
Mixin class for testing layerwise dtype casting for memory optimization.
Expected class attributes to be set by subclasses:
- model_class: The model class to test
Expected methods to be implemented by subclasses:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
"""
@torch.no_grad()
def test_layerwise_casting_memory(self):
MB_TOLERANCE = 0.2
LEAST_COMPUTE_CAPABILITY = 8.0
def reset_memory_stats():
gc.collect()
backend_synchronize(torch_device)
backend_empty_cache(torch_device)
backend_reset_peak_memory_stats(torch_device)
def get_memory_usage(storage_dtype, compute_dtype):
torch.manual_seed(0)
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
model = self.model_class(**config).eval()
model = model.to(torch_device, dtype=compute_dtype)
model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
reset_memory_stats()
model(**inputs_dict)
model_memory_footprint = model.get_memory_footprint()
peak_inference_memory_allocated_mb = backend_max_memory_allocated(torch_device) / 1024**2
return model_memory_footprint, peak_inference_memory_allocated_mb
fp32_memory_footprint, fp32_max_memory = get_memory_usage(torch.float32, torch.float32)
fp8_e4m3_fp32_memory_footprint, fp8_e4m3_fp32_max_memory = get_memory_usage(torch.float8_e4m3fn, torch.float32)
fp8_e4m3_bf16_memory_footprint, fp8_e4m3_bf16_max_memory = get_memory_usage(
torch.float8_e4m3fn, torch.bfloat16
)
compute_capability = get_torch_cuda_device_capability() if torch_device == "cuda" else None
assert fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint < fp32_memory_footprint, (
"Memory footprint should decrease with lower precision storage"
)
# NOTE: the following assertion would fail on our CI (running Tesla T4) due to bf16 using more memory than fp32.
# On other devices, such as DGX (Ampere) and Audace (Ada), the test passes. So, we conditionally check it.
if compute_capability and compute_capability >= LEAST_COMPUTE_CAPABILITY:
assert fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory, (
"Peak memory should be lower with bf16 compute on newer GPUs"
)
# On this dummy test case with a small model, sometimes fp8_e4m3_fp32 max memory usage is higher than fp32 by a few
# bytes. This only happens for some models, so we allow a small tolerance.
# For any real model being tested, the order would be fp8_e4m3_bf16 < fp8_e4m3_fp32 < fp32.
assert (
fp8_e4m3_fp32_max_memory < fp32_max_memory
or abs(fp8_e4m3_fp32_max_memory - fp32_max_memory) < MB_TOLERANCE
), "Peak memory should be lower or within tolerance with fp8 storage"
def test_layerwise_casting_training(self):
def test_fn(storage_dtype, compute_dtype):
if torch.device(torch_device).type == "cpu" and compute_dtype == torch.bfloat16:
pytest.skip("Skipping test because CPU doesn't go well with bfloat16.")
model = self.model_class(**self.get_init_dict())
model = model.to(torch_device, dtype=compute_dtype)
model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
model.train()
inputs_dict = self.get_inputs_dict()
inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
with torch.amp.autocast(device_type=torch.device(torch_device).type):
output = model(**inputs_dict, return_dict=False)[0]
input_tensor = inputs_dict[self.main_input_name]
noise = torch.randn((input_tensor.shape[0],) + self.output_shape).to(torch_device)
noise = cast_maybe_tensor_dtype(noise, torch.float32, compute_dtype)
loss = torch.nn.functional.mse_loss(output, noise)
loss.backward()
test_fn(torch.float16, torch.float32)
test_fn(torch.float8_e4m3fn, torch.float32)
test_fn(torch.float8_e5m2, torch.float32)
test_fn(torch.float8_e4m3fn, torch.bfloat16)
@is_memory
@require_accelerator
class MemoryTesterMixin(CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin):
"""
Combined mixin class for all memory optimization tests including CPU/disk offloading,
group offloading, and layerwise dtype casting.
This mixin inherits from:
- CPUOffloadTesterMixin: CPU and disk offloading tests
- GroupOffloadTesterMixin: Group offloading tests (block-level and leaf-level)
- LayerwiseCastingTesterMixin: Layerwise dtype casting tests
Expected class attributes to be set by subclasses:
- model_class: The model class to test
- model_split_percents: List of percentages for splitting model across devices (default: [0.5, 0.7])
Expected methods to be implemented by subclasses:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: memory
Use `pytest -m "not memory"` to skip these tests
"""
pass

View File

@@ -0,0 +1,99 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pytest
import torch
import torch.multiprocessing as mp
from diffusers.models._modeling_parallel import ContextParallelConfig
from ...testing_utils import (
is_context_parallel,
require_torch_multi_accelerator,
)
@torch.no_grad()
def _context_parallel_worker(rank, world_size, model_class, init_dict, cp_dict, inputs_dict, result_queue):
try:
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
torch.distributed.init_process_group(
backend="nccl",
init_method="env://",
world_size=world_size,
rank=rank,
)
torch.cuda.set_device(rank)
device = torch.device(f"cuda:{rank}")
model = model_class(**init_dict)
model.to(device)
model.eval()
inputs_on_device = {}
for key, value in inputs_dict.items():
if isinstance(value, torch.Tensor):
inputs_on_device[key] = value.to(device)
else:
inputs_on_device[key] = value
cp_config = ContextParallelConfig(**cp_dict)
model.enable_parallelism(config=cp_config)
output = model(**inputs_on_device, return_dict=False)[0]
if rank == 0:
result_queue.put(("success", output.shape))
except Exception as e:
if rank == 0:
result_queue.put(("error", str(e)))
finally:
if torch.distributed.is_initialized():
torch.distributed.destroy_process_group()
@is_context_parallel
@require_torch_multi_accelerator
class ContextParallelTesterMixin:
@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
def test_context_parallel_inference(self, cp_type):
if not torch.distributed.is_available():
pytest.skip("torch.distributed is not available.")
if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
world_size = 2
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
cp_dict = {cp_type: world_size}
ctx = mp.get_context("spawn")
result_queue = ctx.Queue()
mp.spawn(
_context_parallel_worker,
args=(world_size, self.model_class, init_dict, cp_dict, inputs_dict, result_queue),
nprocs=world_size,
join=True,
)
status, result = result_queue.get(timeout=60)
assert status == "success", f"Context parallel inference failed: {result}"

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,244 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import torch
from huggingface_hub import hf_hub_download, snapshot_download
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
from ...testing_utils import (
assert_tensors_close,
backend_empty_cache,
is_single_file,
nightly,
require_torch_accelerator,
torch_device,
)
def download_single_file_checkpoint(pretrained_model_name_or_path, filename, tmpdir):
"""Download a single file checkpoint from the Hub to a temporary directory."""
path = hf_hub_download(pretrained_model_name_or_path, filename=filename, local_dir=tmpdir)
return path
def download_diffusers_config(pretrained_model_name_or_path, tmpdir):
"""Download diffusers config files (excluding weights) from a repository."""
path = snapshot_download(
pretrained_model_name_or_path,
ignore_patterns=[
"**/*.ckpt",
"*.ckpt",
"**/*.bin",
"*.bin",
"**/*.pt",
"*.pt",
"**/*.safetensors",
"*.safetensors",
],
allow_patterns=["**/*.json", "*.json", "*.txt", "**/*.txt"],
local_dir=tmpdir,
)
return path
@nightly
@require_torch_accelerator
@is_single_file
class SingleFileTesterMixin:
"""
Mixin class for testing single file loading for models.
Expected class attributes:
- model_class: The model class to test
- pretrained_model_name_or_path: Hub repository ID for the pretrained model
- ckpt_path: Path or Hub path to the single file checkpoint
- subfolder: (Optional) Subfolder within the repo
- torch_dtype: (Optional) torch dtype to use for testing
Pytest mark: single_file
Use `pytest -m "not single_file"` to skip these tests
"""
pretrained_model_name_or_path = None
ckpt_path = None
def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
def test_single_file_model_config(self):
pretrained_kwargs = {}
single_file_kwargs = {}
pretrained_kwargs["device"] = torch_device
single_file_kwargs["device"] = torch_device
if hasattr(self, "subfolder") and self.subfolder:
pretrained_kwargs["subfolder"] = self.subfolder
if hasattr(self, "torch_dtype") and self.torch_dtype:
pretrained_kwargs["torch_dtype"] = self.torch_dtype
single_file_kwargs["torch_dtype"] = self.torch_dtype
model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, **pretrained_kwargs)
model_single_file = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs)
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
assert model.config[param_name] == param_value, (
f"{param_name} differs between pretrained loading and single file loading: "
f"pretrained={model.config[param_name]}, single_file={param_value}"
)
def test_single_file_model_parameters(self):
pretrained_kwargs = {}
single_file_kwargs = {}
pretrained_kwargs["device"] = torch_device
single_file_kwargs["device"] = torch_device
if hasattr(self, "subfolder") and self.subfolder:
pretrained_kwargs["subfolder"] = self.subfolder
if hasattr(self, "torch_dtype") and self.torch_dtype:
pretrained_kwargs["torch_dtype"] = self.torch_dtype
single_file_kwargs["torch_dtype"] = self.torch_dtype
model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, **pretrained_kwargs)
model_single_file = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs)
state_dict = model.state_dict()
state_dict_single_file = model_single_file.state_dict()
assert set(state_dict.keys()) == set(state_dict_single_file.keys()), (
"Model parameters keys differ between pretrained and single file loading. "
f"Missing in single file: {set(state_dict.keys()) - set(state_dict_single_file.keys())}. "
f"Extra in single file: {set(state_dict_single_file.keys()) - set(state_dict.keys())}"
)
for key in state_dict.keys():
param = state_dict[key]
param_single_file = state_dict_single_file[key]
assert param.shape == param_single_file.shape, (
f"Parameter shape mismatch for {key}: "
f"pretrained {param.shape} vs single file {param_single_file.shape}"
)
assert_tensors_close(
param, param_single_file, atol=1e-5, rtol=1e-5, msg=f"Parameter values differ for {key}"
)
def test_single_file_loading_local_files_only(self, tmp_path):
single_file_kwargs = {}
if hasattr(self, "torch_dtype") and self.torch_dtype:
single_file_kwargs["torch_dtype"] = self.torch_dtype
pretrained_model_name_or_path, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
local_ckpt_path = download_single_file_checkpoint(pretrained_model_name_or_path, weight_name, str(tmp_path))
model_single_file = self.model_class.from_single_file(
local_ckpt_path, local_files_only=True, **single_file_kwargs
)
assert model_single_file is not None, "Failed to load model with local_files_only=True"
def test_single_file_loading_with_diffusers_config(self):
single_file_kwargs = {}
if hasattr(self, "torch_dtype") and self.torch_dtype:
single_file_kwargs["torch_dtype"] = self.torch_dtype
# Load with config parameter
model_single_file = self.model_class.from_single_file(
self.ckpt_path, config=self.pretrained_model_name_or_path, **single_file_kwargs
)
# Load pretrained for comparison
pretrained_kwargs = {}
if hasattr(self, "subfolder") and self.subfolder:
pretrained_kwargs["subfolder"] = self.subfolder
if hasattr(self, "torch_dtype") and self.torch_dtype:
pretrained_kwargs["torch_dtype"] = self.torch_dtype
model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, **pretrained_kwargs)
# Compare configs
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
assert model.config[param_name] == param_value, (
f"{param_name} differs: pretrained={model.config[param_name]}, single_file={param_value}"
)
def test_single_file_loading_with_diffusers_config_local_files_only(self, tmp_path):
single_file_kwargs = {}
if hasattr(self, "torch_dtype") and self.torch_dtype:
single_file_kwargs["torch_dtype"] = self.torch_dtype
pretrained_model_name_or_path, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
local_ckpt_path = download_single_file_checkpoint(pretrained_model_name_or_path, weight_name, str(tmp_path))
local_diffusers_config = download_diffusers_config(self.pretrained_model_name_or_path, str(tmp_path))
model_single_file = self.model_class.from_single_file(
local_ckpt_path, config=local_diffusers_config, local_files_only=True, **single_file_kwargs
)
assert model_single_file is not None, "Failed to load model with config and local_files_only=True"
def test_single_file_loading_dtype(self):
for dtype in [torch.float32, torch.float16]:
if torch_device == "mps" and dtype == torch.bfloat16:
continue
model_single_file = self.model_class.from_single_file(self.ckpt_path, torch_dtype=dtype)
assert model_single_file.dtype == dtype, f"Expected dtype {dtype}, got {model_single_file.dtype}"
# Cleanup
del model_single_file
gc.collect()
backend_empty_cache(torch_device)
def test_checkpoint_variant_loading(self):
if not hasattr(self, "alternate_ckpt_paths") or not self.alternate_ckpt_paths:
return
for ckpt_path in self.alternate_ckpt_paths:
backend_empty_cache(torch_device)
single_file_kwargs = {}
if hasattr(self, "torch_dtype") and self.torch_dtype:
single_file_kwargs["torch_dtype"] = self.torch_dtype
model = self.model_class.from_single_file(ckpt_path, **single_file_kwargs)
assert model is not None, f"Failed to load checkpoint from {ckpt_path}"
del model
gc.collect()
backend_empty_cache(torch_device)

View File

@@ -0,0 +1,207 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import pytest
import torch
from diffusers.training_utils import EMAModel
from ...testing_utils import is_training, require_torch_accelerator_with_training, torch_all_close, torch_device
@is_training
@require_torch_accelerator_with_training
class TrainingTesterMixin:
"""
Mixin class for testing training functionality on models.
Expected class attributes to be set by subclasses:
- model_class: The model class to test
Expected methods to be implemented by subclasses:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Expected properties to be implemented by subclasses:
- output_shape: Tuple defining the expected output shape
Pytest mark: training
Use `pytest -m "not training"` to skip these tests
"""
def test_training(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.to(torch_device)
model.train()
output = model(**inputs_dict, return_dict=False)[0]
noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
loss = torch.nn.functional.mse_loss(output, noise)
loss.backward()
def test_training_with_ema(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.to(torch_device)
model.train()
ema_model = EMAModel(model.parameters())
output = model(**inputs_dict, return_dict=False)[0]
noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
loss = torch.nn.functional.mse_loss(output, noise)
loss.backward()
ema_model.step(model.parameters())
def test_gradient_checkpointing(self):
if not self.model_class._supports_gradient_checkpointing:
pytest.skip("Gradient checkpointing is not supported.")
init_dict = self.get_init_dict()
# at init model should have gradient checkpointing disabled
model = self.model_class(**init_dict)
assert not model.is_gradient_checkpointing, "Gradient checkpointing should be disabled at init"
# check enable works
model.enable_gradient_checkpointing()
assert model.is_gradient_checkpointing, "Gradient checkpointing should be enabled"
# check disable works
model.disable_gradient_checkpointing()
assert not model.is_gradient_checkpointing, "Gradient checkpointing should be disabled"
def test_gradient_checkpointing_is_applied(self, expected_set=None):
if not self.model_class._supports_gradient_checkpointing:
pytest.skip("Gradient checkpointing is not supported.")
if expected_set is None:
pytest.skip("expected_set must be provided to verify gradient checkpointing is applied.")
init_dict = self.get_init_dict()
model_class_copy = copy.copy(self.model_class)
model = model_class_copy(**init_dict)
model.enable_gradient_checkpointing()
modules_with_gc_enabled = {}
for submodule in model.modules():
if hasattr(submodule, "gradient_checkpointing"):
assert submodule.gradient_checkpointing, f"{submodule.__class__.__name__} should have GC enabled"
modules_with_gc_enabled[submodule.__class__.__name__] = True
assert set(modules_with_gc_enabled.keys()) == expected_set, (
f"Modules with GC enabled {set(modules_with_gc_enabled.keys())} do not match expected set {expected_set}"
)
assert all(modules_with_gc_enabled.values()), "All modules should have GC enabled"
def test_gradient_checkpointing_equivalence(self, loss_tolerance=1e-5, param_grad_tol=5e-5, skip=None):
if not self.model_class._supports_gradient_checkpointing:
pytest.skip("Gradient checkpointing is not supported.")
if skip is None:
skip = set()
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
inputs_dict_copy = copy.deepcopy(inputs_dict)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.to(torch_device)
assert not model.is_gradient_checkpointing and model.training
out = model(**inputs_dict, return_dict=False)[0]
# run the backwards pass on the model
model.zero_grad()
labels = torch.randn_like(out)
loss = (out - labels).mean()
loss.backward()
# re-instantiate the model now enabling gradient checkpointing
torch.manual_seed(0)
model_2 = self.model_class(**init_dict)
# clone model
model_2.load_state_dict(model.state_dict())
model_2.to(torch_device)
model_2.enable_gradient_checkpointing()
assert model_2.is_gradient_checkpointing and model_2.training
out_2 = model_2(**inputs_dict_copy, return_dict=False)[0]
# run the backwards pass on the model
model_2.zero_grad()
loss_2 = (out_2 - labels).mean()
loss_2.backward()
# compare the output and parameters gradients
assert (loss - loss_2).abs() < loss_tolerance, (
f"Loss difference {(loss - loss_2).abs()} exceeds tolerance {loss_tolerance}"
)
named_params = dict(model.named_parameters())
named_params_2 = dict(model_2.named_parameters())
for name, param in named_params.items():
if "post_quant_conv" in name:
continue
if name in skip:
continue
if param.grad is None:
continue
assert torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=param_grad_tol), (
f"Gradient mismatch for {name}"
)
def test_mixed_precision_training(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.to(torch_device)
model.train()
# Test with float16
if torch.device(torch_device).type != "cpu":
with torch.amp.autocast(device_type=torch.device(torch_device).type, dtype=torch.float16):
output = model(**inputs_dict, return_dict=False)[0]
noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
loss = torch.nn.functional.mse_loss(output, noise)
loss.backward()
# Test with bfloat16
if torch.device(torch_device).type != "cpu":
model.zero_grad()
with torch.amp.autocast(device_type=torch.device(torch_device).type, dtype=torch.bfloat16):
output = model(**inputs_dict, return_dict=False)[0]
noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
loss = torch.nn.functional.mse_loss(output, noise)
loss.backward()

View File

@@ -13,23 +13,51 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from typing import Any
import torch
from diffusers import FluxTransformer2DModel
from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0
from diffusers.models.embeddings import ImageProjection
from diffusers.models.transformers.transformer_flux import FluxIPAdapterAttnProcessor
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, is_peft_available, torch_device
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
AttentionTesterMixin,
BaseModelTesterConfig,
BitsAndBytesCompileTesterMixin,
BitsAndBytesTesterMixin,
ContextParallelTesterMixin,
FasterCacheTesterMixin,
FirstBlockCacheTesterMixin,
GGUFCompileTesterMixin,
GGUFTesterMixin,
IPAdapterTesterMixin,
LoraHotSwappingForModelTesterMixin,
LoraTesterMixin,
MemoryTesterMixin,
ModelOptCompileTesterMixin,
ModelOptTesterMixin,
ModelTesterMixin,
PyramidAttentionBroadcastTesterMixin,
QuantoCompileTesterMixin,
QuantoTesterMixin,
SingleFileTesterMixin,
TorchAoCompileTesterMixin,
TorchAoTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
enable_full_determinism()
def create_flux_ip_adapter_state_dict(model):
# "ip_adapter" (cross-attention weights)
# TODO: This standalone function maintains backward compatibility with pipeline tests
# (tests/pipelines/test_pipelines_common.py) and will be refactored.
def create_flux_ip_adapter_state_dict(model) -> dict[str, dict[str, Any]]:
"""Create a dummy IP Adapter state dict for Flux transformer testing."""
ip_cross_attn_state_dict = {}
key_id = 0
@@ -39,7 +67,7 @@ def create_flux_ip_adapter_state_dict(model):
joint_attention_dim = model.config["joint_attention_dim"]
hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"]
sd = FluxIPAdapterJointAttnProcessor2_0(
sd = FluxIPAdapterAttnProcessor(
hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0
).state_dict()
ip_cross_attn_state_dict.update(
@@ -50,11 +78,8 @@ def create_flux_ip_adapter_state_dict(model):
f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"],
}
)
key_id += 1
# "image_proj" (ImageProjection layer weights)
image_projection = ImageProjection(
cross_attention_dim=model.config["joint_attention_dim"],
image_embed_dim=(
@@ -75,57 +100,37 @@ def create_flux_ip_adapter_state_dict(model):
)
del sd
ip_state_dict = {}
ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict})
return ip_state_dict
return {"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict}
class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = FluxTransformer2DModel
main_input_name = "hidden_states"
# We override the items here because the transformer under consideration is small.
model_split_percents = [0.7, 0.6, 0.6]
# Skip setting testing with default: AttnProcessor
uses_custom_attn_processor = True
class FluxTransformerTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return FluxTransformer2DModel
@property
def dummy_input(self):
return self.prepare_dummy_input()
def pretrained_model_name_or_path(self):
return "hf-internal-testing/tiny-flux-pipe"
@property
def input_shape(self):
def pretrained_model_kwargs(self):
return {"subfolder": "transformer"}
@property
def output_shape(self) -> tuple[int, int]:
return (16, 4)
@property
def output_shape(self):
def input_shape(self) -> tuple[int, int]:
return (16, 4)
def prepare_dummy_input(self, height=4, width=4):
batch_size = 1
num_latent_channels = 4
num_image_channels = 3
sequence_length = 48
embedding_dim = 32
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(torch_device)
text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device)
image_ids = torch.randn((height * width, num_image_channels)).to(torch_device)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict[str, int | list[int]]:
"""Return Flux model initialization arguments."""
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"img_ids": image_ids,
"txt_ids": text_ids,
"pooled_projections": pooled_prompt_embeds,
"timestep": timestep,
}
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"patch_size": 1,
"in_channels": 4,
"num_layers": 1,
@@ -137,11 +142,32 @@ class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
"axes_dims_rope": [4, 4, 8],
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
batch_size = 1
height = width = 4
num_latent_channels = 4
num_image_channels = 3
sequence_length = 48
embedding_dim = 32
return {
"hidden_states": randn_tensor((batch_size, height * width, num_latent_channels), generator=self.generator),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator
),
"pooled_projections": randn_tensor((batch_size, embedding_dim), generator=self.generator),
"img_ids": randn_tensor((height * width, num_image_channels), generator=self.generator),
"txt_ids": randn_tensor((sequence_length, num_image_channels), generator=self.generator),
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
}
class TestFluxTransformer(FluxTransformerTesterConfig, ModelTesterMixin):
def test_deprecated_inputs_img_txt_ids_3d(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
"""Test that deprecated 3D img_ids and txt_ids still work."""
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
@@ -162,63 +188,267 @@ class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
with torch.no_grad():
output_2 = model(**inputs_dict).to_tuple()[0]
self.assertEqual(output_1.shape, output_2.shape)
self.assertTrue(
torch.allclose(output_1, output_2, atol=1e-5),
msg="output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) are not equal as them as 2d inputs",
assert output_1.shape == output_2.shape
assert torch.allclose(output_1, output_2, atol=1e-5), (
"output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) "
"are not equal as them as 2d inputs"
)
def test_gradient_checkpointing_is_applied(self):
expected_set = {"FluxTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
# The test exists for cases like
# https://github.com/huggingface/diffusers/issues/11874
@unittest.skipIf(not is_peft_available(), "Only with PEFT")
def test_lora_exclude_modules(self):
from peft import LoraConfig, get_peft_model_state_dict, inject_adapter_in_model, set_peft_model_state_dict
class TestFluxTransformerMemory(FluxTransformerTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for Flux Transformer."""
lora_rank = 4
target_module = "single_transformer_blocks.0.proj_out"
adapter_name = "foo"
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
pass
state_dict = model.state_dict()
target_mod_shape = state_dict[f"{target_module}.weight"].shape
lora_state_dict = {
f"{target_module}.lora_A.weight": torch.ones(lora_rank, target_mod_shape[1]) * 22,
f"{target_module}.lora_B.weight": torch.ones(target_mod_shape[0], lora_rank) * 33,
class TestFluxTransformerTraining(FluxTransformerTesterConfig, TrainingTesterMixin):
"""Training tests for Flux Transformer."""
pass
class TestFluxTransformerAttention(FluxTransformerTesterConfig, AttentionTesterMixin):
"""Attention processor tests for Flux Transformer."""
pass
class TestFluxTransformerContextParallel(FluxTransformerTesterConfig, ContextParallelTesterMixin):
"""Context Parallel inference tests for Flux Transformer"""
pass
class TestFluxTransformerIPAdapter(FluxTransformerTesterConfig, IPAdapterTesterMixin):
"""IP Adapter tests for Flux Transformer."""
ip_adapter_processor_cls = FluxIPAdapterAttnProcessor
def modify_inputs_for_ip_adapter(self, model, inputs_dict):
torch.manual_seed(0)
# Create dummy image embeds for IP adapter
cross_attention_dim = getattr(model.config, "joint_attention_dim", 32)
image_embeds = torch.randn(1, 1, cross_attention_dim).to(torch_device)
inputs_dict.update({"joint_attention_kwargs": {"ip_adapter_image_embeds": image_embeds}})
return inputs_dict
def create_ip_adapter_state_dict(self, model: Any) -> dict[str, dict[str, Any]]:
return create_flux_ip_adapter_state_dict(model)
class TestFluxTransformerLoRA(FluxTransformerTesterConfig, LoraTesterMixin):
"""LoRA adapter tests for Flux Transformer."""
pass
class TestFluxTransformerLoRAHotSwap(FluxTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
"""LoRA hot-swapping tests for Flux Transformer."""
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
"""Override to support dynamic height/width for LoRA hotswap tests."""
batch_size = 1
num_latent_channels = 4
num_image_channels = 3
sequence_length = 24
embedding_dim = 8
return {
"hidden_states": randn_tensor((batch_size, height * width, num_latent_channels)),
"encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim)),
"pooled_projections": randn_tensor((batch_size, embedding_dim)),
"img_ids": randn_tensor((height * width, num_image_channels)),
"txt_ids": randn_tensor((sequence_length, num_image_channels)),
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
}
# Passing exclude_modules should no longer be necessary (or even passing target_modules, for that matter).
config = LoraConfig(
r=lora_rank, target_modules=["single_transformer_blocks.0.proj_out"], exclude_modules=["proj_out"]
)
inject_adapter_in_model(config, model, adapter_name=adapter_name, state_dict=lora_state_dict)
set_peft_model_state_dict(model, lora_state_dict, adapter_name)
retrieved_lora_state_dict = get_peft_model_state_dict(model, adapter_name=adapter_name)
assert len(retrieved_lora_state_dict) == len(lora_state_dict)
assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_A.weight"] == 22).all()
assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_B.weight"] == 33).all()
class FluxTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = FluxTransformer2DModel
class TestFluxTransformerCompile(FluxTransformerTesterConfig, TorchCompileTesterMixin):
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
def prepare_init_args_and_inputs_for_common(self):
return FluxTransformerTests().prepare_init_args_and_inputs_for_common()
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
"""Override to support dynamic height/width for compilation tests."""
batch_size = 1
num_latent_channels = 4
num_image_channels = 3
sequence_length = 24
embedding_dim = 8
def prepare_dummy_input(self, height, width):
return FluxTransformerTests().prepare_dummy_input(height=height, width=width)
return {
"hidden_states": randn_tensor((batch_size, height * width, num_latent_channels)),
"encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim)),
"pooled_projections": randn_tensor((batch_size, embedding_dim)),
"img_ids": randn_tensor((height * width, num_image_channels)),
"txt_ids": randn_tensor((sequence_length, num_image_channels)),
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
}
class FluxTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
model_class = FluxTransformer2DModel
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
class TestFluxSingleFile(FluxTransformerTesterConfig, SingleFileTesterMixin):
ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
alternate_keys_ckpt_paths = ["https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"]
pretrained_model_name_or_path = "black-forest-labs/FLUX.1-dev"
subfolder = "transformer"
pass
def prepare_init_args_and_inputs_for_common(self):
return FluxTransformerTests().prepare_init_args_and_inputs_for_common()
def prepare_dummy_input(self, height, width):
return FluxTransformerTests().prepare_dummy_input(height=height, width=width)
class TestFluxTransformerBitsAndBytes(FluxTransformerTesterConfig, BitsAndBytesTesterMixin):
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
return {
"hidden_states": randn_tensor((1, 4096, 64)),
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
"pooled_projections": randn_tensor((1, 768)),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3)),
"txt_ids": randn_tensor((512, 3)),
"guidance": torch.tensor([3.5]).to(torch_device),
}
class TestFluxTransformerQuanto(FluxTransformerTesterConfig, QuantoTesterMixin):
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
return {
"hidden_states": randn_tensor((1, 4096, 64)),
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
"pooled_projections": randn_tensor((1, 768)),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3)),
"txt_ids": randn_tensor((512, 3)),
"guidance": torch.tensor([3.5]).to(torch_device),
}
class TestFluxTransformerTorchAo(FluxTransformerTesterConfig, TorchAoTesterMixin):
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
return {
"hidden_states": randn_tensor((1, 4096, 64)),
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
"pooled_projections": randn_tensor((1, 768)),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3)),
"txt_ids": randn_tensor((512, 3)),
"guidance": torch.tensor([3.5]).to(torch_device),
}
class TestFluxTransformerGGUF(FluxTransformerTesterConfig, GGUFTesterMixin):
gguf_filename = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q8_0.gguf"
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
return {
"hidden_states": randn_tensor((1, 4096, 64)),
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
"pooled_projections": randn_tensor((1, 768)),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3)),
"txt_ids": randn_tensor((512, 3)),
"guidance": torch.tensor([3.5]).to(torch_device),
}
class TestFluxTransformerModelOpt(FluxTransformerTesterConfig, ModelOptTesterMixin):
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
return {
"hidden_states": randn_tensor((1, 4096, 64)),
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
"pooled_projections": randn_tensor((1, 768)),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3)),
"txt_ids": randn_tensor((512, 3)),
"guidance": torch.tensor([3.5]).to(torch_device),
}
class TestFluxTransformerBitsAndBytesCompile(FluxTransformerTesterConfig, BitsAndBytesCompileTesterMixin):
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
return {
"hidden_states": randn_tensor((1, 4096, 64)),
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
"pooled_projections": randn_tensor((1, 768)),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3)),
"txt_ids": randn_tensor((512, 3)),
"guidance": torch.tensor([3.5]).to(torch_device),
}
class TestFluxTransformerQuantoCompile(FluxTransformerTesterConfig, QuantoCompileTesterMixin):
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
return {
"hidden_states": randn_tensor((1, 4096, 64)),
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
"pooled_projections": randn_tensor((1, 768)),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3)),
"txt_ids": randn_tensor((512, 3)),
"guidance": torch.tensor([3.5]).to(torch_device),
}
class TestFluxTransformerTorchAoCompile(FluxTransformerTesterConfig, TorchAoCompileTesterMixin):
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
return {
"hidden_states": randn_tensor((1, 4096, 64)),
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
"pooled_projections": randn_tensor((1, 768)),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3)),
"txt_ids": randn_tensor((512, 3)),
"guidance": torch.tensor([3.5]).to(torch_device),
}
class TestFluxTransformerGGUFCompile(FluxTransformerTesterConfig, GGUFCompileTesterMixin):
gguf_filename = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q8_0.gguf"
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
return {
"hidden_states": randn_tensor((1, 4096, 64)),
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
"pooled_projections": randn_tensor((1, 768)),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3)),
"txt_ids": randn_tensor((512, 3)),
"guidance": torch.tensor([3.5]).to(torch_device),
}
class TestFluxTransformerModelOptCompile(FluxTransformerTesterConfig, ModelOptCompileTesterMixin):
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
return {
"hidden_states": randn_tensor((1, 4096, 64)),
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
"pooled_projections": randn_tensor((1, 768)),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3)),
"txt_ids": randn_tensor((512, 3)),
"guidance": torch.tensor([3.5]).to(torch_device),
}
class TestFluxTransformerPABCache(FluxTransformerTesterConfig, PyramidAttentionBroadcastTesterMixin):
"""PyramidAttentionBroadcast cache tests for Flux Transformer."""
pass
class TestFluxTransformerFBCCache(FluxTransformerTesterConfig, FirstBlockCacheTesterMixin):
"""FirstBlockCache tests for Flux Transformer."""
pass
class TestFluxTransformerFasterCache(FluxTransformerTesterConfig, FasterCacheTesterMixin):
"""FasterCache tests for Flux Transformer."""
# Flux is guidance distilled, so we can test at model level without CFG batch handling
FASTER_CACHE_CONFIG = {
"spatial_attention_block_skip_range": 2,
"spatial_attention_timestep_skip_range": (-1, 901),
"tensor_format": "BCHW",
"is_guidance_distilled": True,
}

View File

@@ -38,6 +38,7 @@ from diffusers.utils.import_utils import (
is_gguf_available,
is_kernels_available,
is_note_seq_available,
is_nvidia_modelopt_available,
is_onnx_available,
is_opencv_available,
is_optimum_quanto_available,
@@ -130,6 +131,59 @@ def torch_all_close(a, b, *args, **kwargs):
return True
def assert_tensors_close(
actual: "torch.Tensor",
expected: "torch.Tensor",
atol: float = 1e-5,
rtol: float = 1e-5,
msg: str = "",
) -> None:
"""
Assert that two tensors are close within tolerance.
Uses the same formula as torch.allclose: |actual - expected| <= atol + rtol * |expected|
Provides concise, actionable error messages without dumping full tensors.
Args:
actual: The actual tensor from the computation.
expected: The expected tensor to compare against.
atol: Absolute tolerance.
rtol: Relative tolerance.
msg: Optional message prefix for the assertion error.
Raises:
AssertionError: If tensors have different shapes or values exceed tolerance.
Example:
>>> assert_tensors_close(output, expected_output, atol=1e-5, rtol=1e-5, msg="Forward pass")
"""
if not is_torch_available():
raise ValueError("PyTorch needs to be installed to use this function.")
if actual.shape != expected.shape:
raise AssertionError(f"{msg} Shape mismatch: actual {actual.shape} vs expected {expected.shape}")
if not torch.allclose(actual, expected, atol=atol, rtol=rtol):
abs_diff = (actual - expected).abs()
max_diff = abs_diff.max().item()
flat_idx = abs_diff.argmax().item()
max_idx = tuple(torch.unravel_index(torch.tensor(flat_idx), actual.shape).tolist())
threshold = atol + rtol * expected.abs()
mismatched = (abs_diff > threshold).sum().item()
total = actual.numel()
raise AssertionError(
f"{msg}\n"
f"Tensors not close! Mismatched elements: {mismatched}/{total} ({100 * mismatched / total:.1f}%)\n"
f" Max diff: {max_diff:.6e} at index {max_idx}\n"
f" Actual: {actual.flatten()[flat_idx].item():.6e}\n"
f" Expected: {expected.flatten()[flat_idx].item():.6e}\n"
f" atol: {atol:.6e}, rtol: {rtol:.6e}"
)
def numpy_cosine_similarity_distance(a, b):
similarity = np.dot(a, b) / (norm(a) * norm(b))
distance = 1.0 - similarity.mean()
@@ -241,7 +295,6 @@ def parse_flag_from_env(key, default=False):
_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
_run_nightly_tests = parse_flag_from_env("RUN_NIGHTLY", default=False)
_run_compile_tests = parse_flag_from_env("RUN_COMPILE", default=False)
def floats_tensor(shape, scale=1.0, rng=None, name=None):
@@ -282,12 +335,155 @@ def nightly(test_case):
def is_torch_compile(test_case):
"""
Decorator marking a test that runs compile tests in the diffusers CI.
Compile tests are skipped by default. Set the RUN_COMPILE environment variable to a truthy value to run them.
Decorator marking a test as a torch.compile test. These tests can be filtered using:
pytest -m "not compile" to skip
pytest -m compile to run only these tests
"""
return pytest.mark.skipif(not _run_compile_tests, reason="test is torch compile")(test_case)
return pytest.mark.compile(test_case)
def is_single_file(test_case):
"""
Decorator marking a test as a single file loading test. These tests can be filtered using:
pytest -m "not single_file" to skip
pytest -m single_file to run only these tests
"""
return pytest.mark.single_file(test_case)
def is_lora(test_case):
"""
Decorator marking a test as a LoRA test. These tests can be filtered using:
pytest -m "not lora" to skip
pytest -m lora to run only these tests
"""
return pytest.mark.lora(test_case)
def is_ip_adapter(test_case):
"""
Decorator marking a test as an IP Adapter test. These tests can be filtered using:
pytest -m "not ip_adapter" to skip
pytest -m ip_adapter to run only these tests
"""
return pytest.mark.ip_adapter(test_case)
def is_training(test_case):
"""
Decorator marking a test as a training test. These tests can be filtered using:
pytest -m "not training" to skip
pytest -m training to run only these tests
"""
return pytest.mark.training(test_case)
def is_attention(test_case):
"""
Decorator marking a test as an attention test. These tests can be filtered using:
pytest -m "not attention" to skip
pytest -m attention to run only these tests
"""
return pytest.mark.attention(test_case)
def is_memory(test_case):
"""
Decorator marking a test as a memory optimization test. These tests can be filtered using:
pytest -m "not memory" to skip
pytest -m memory to run only these tests
"""
return pytest.mark.memory(test_case)
def is_cpu_offload(test_case):
"""
Decorator marking a test as a CPU offload test. These tests can be filtered using:
pytest -m "not cpu_offload" to skip
pytest -m cpu_offload to run only these tests
"""
return pytest.mark.cpu_offload(test_case)
def is_group_offload(test_case):
"""
Decorator marking a test as a group offload test. These tests can be filtered using:
pytest -m "not group_offload" to skip
pytest -m group_offload to run only these tests
"""
return pytest.mark.group_offload(test_case)
def is_quantization(test_case):
"""
Decorator marking a test as a quantization test. These tests can be filtered using:
pytest -m "not quantization" to skip
pytest -m quantization to run only these tests
"""
return pytest.mark.quantization(test_case)
def is_bitsandbytes(test_case):
"""
Decorator marking a test as a BitsAndBytes quantization test. These tests can be filtered using:
pytest -m "not bitsandbytes" to skip
pytest -m bitsandbytes to run only these tests
"""
return pytest.mark.bitsandbytes(test_case)
def is_quanto(test_case):
"""
Decorator marking a test as a Quanto quantization test. These tests can be filtered using:
pytest -m "not quanto" to skip
pytest -m quanto to run only these tests
"""
return pytest.mark.quanto(test_case)
def is_torchao(test_case):
"""
Decorator marking a test as a TorchAO quantization test. These tests can be filtered using:
pytest -m "not torchao" to skip
pytest -m torchao to run only these tests
"""
return pytest.mark.torchao(test_case)
def is_gguf(test_case):
"""
Decorator marking a test as a GGUF quantization test. These tests can be filtered using:
pytest -m "not gguf" to skip
pytest -m gguf to run only these tests
"""
return pytest.mark.gguf(test_case)
def is_modelopt(test_case):
"""
Decorator marking a test as a NVIDIA ModelOpt quantization test. These tests can be filtered using:
pytest -m "not modelopt" to skip
pytest -m modelopt to run only these tests
"""
return pytest.mark.modelopt(test_case)
def is_context_parallel(test_case):
"""
Decorator marking a test as a context parallel inference test. These tests can be filtered using:
pytest -m "not context_parallel" to skip
pytest -m context_parallel to run only these tests
"""
return pytest.mark.context_parallel(test_case)
def is_cache(test_case):
"""
Decorator marking a test as a cache test. These tests can be filtered using:
pytest -m "not cache" to skip
pytest -m cache to run only these tests
"""
return pytest.mark.cache(test_case)
def require_torch(test_case):
@@ -650,6 +846,19 @@ def require_kernels_version_greater_or_equal(kernels_version):
return decorator
def require_modelopt_version_greater_or_equal(modelopt_version):
def decorator(test_case):
correct_nvidia_modelopt_version = is_nvidia_modelopt_available() and version.parse(
version.parse(importlib.metadata.version("modelopt")).base_version
) >= version.parse(modelopt_version)
return pytest.mark.skipif(
not correct_nvidia_modelopt_version,
reason=f"Test requires modelopt with version greater than {modelopt_version}.",
)(test_case)
return decorator
def deprecate_after_peft_backend(test_case):
"""
Decorator marking a test that will be skipped after PEFT backend

View File

@@ -0,0 +1,567 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utility script to generate test suites for diffusers model classes.
Usage:
python utils/generate_model_tests.py src/diffusers/models/transformers/transformer_flux.py
This will analyze the model file and generate a test file with appropriate
test classes based on the model's mixins and attributes.
"""
import argparse
import ast
import sys
from pathlib import Path
MIXIN_TO_TESTER = {
"ModelMixin": "ModelTesterMixin",
"PeftAdapterMixin": "LoraTesterMixin",
}
ATTRIBUTE_TO_TESTER = {
"_cp_plan": "ContextParallelTesterMixin",
"_supports_gradient_checkpointing": "TrainingTesterMixin",
}
ALWAYS_INCLUDE_TESTERS = [
"ModelTesterMixin",
"MemoryTesterMixin",
"TorchCompileTesterMixin",
]
# Attention-related class names that indicate the model uses attention
ATTENTION_INDICATORS = {
"AttentionMixin",
"AttentionModuleMixin",
}
OPTIONAL_TESTERS = [
# Quantization testers
("BitsAndBytesTesterMixin", "bnb"),
("QuantoTesterMixin", "quanto"),
("TorchAoTesterMixin", "torchao"),
("GGUFTesterMixin", "gguf"),
("ModelOptTesterMixin", "modelopt"),
# Quantization compile testers
("BitsAndBytesCompileTesterMixin", "bnb_compile"),
("QuantoCompileTesterMixin", "quanto_compile"),
("TorchAoCompileTesterMixin", "torchao_compile"),
("GGUFCompileTesterMixin", "gguf_compile"),
("ModelOptCompileTesterMixin", "modelopt_compile"),
# Cache testers
("PyramidAttentionBroadcastTesterMixin", "pab_cache"),
("FirstBlockCacheTesterMixin", "fbc_cache"),
("FasterCacheTesterMixin", "faster_cache"),
# Other testers
("SingleFileTesterMixin", "single_file"),
("IPAdapterTesterMixin", "ip_adapter"),
]
class ModelAnalyzer(ast.NodeVisitor):
def __init__(self):
self.model_classes = []
self.current_class = None
self.imports = set()
def visit_Import(self, node: ast.Import):
for alias in node.names:
self.imports.add(alias.name.split(".")[-1])
self.generic_visit(node)
def visit_ImportFrom(self, node: ast.ImportFrom):
for alias in node.names:
self.imports.add(alias.name)
self.generic_visit(node)
def visit_ClassDef(self, node: ast.ClassDef):
base_names = []
for base in node.bases:
if isinstance(base, ast.Name):
base_names.append(base.id)
elif isinstance(base, ast.Attribute):
base_names.append(base.attr)
if "ModelMixin" in base_names:
class_info = {
"name": node.name,
"bases": base_names,
"attributes": {},
"has_forward": False,
"init_params": [],
}
for item in node.body:
if isinstance(item, ast.Assign):
for target in item.targets:
if isinstance(target, ast.Name):
attr_name = target.id
if attr_name.startswith("_"):
class_info["attributes"][attr_name] = self._get_value(item.value)
elif isinstance(item, ast.FunctionDef):
if item.name == "forward":
class_info["has_forward"] = True
class_info["forward_params"] = self._extract_func_params(item)
elif item.name == "__init__":
class_info["init_params"] = self._extract_func_params(item)
self.model_classes.append(class_info)
self.generic_visit(node)
def _extract_func_params(self, func_node: ast.FunctionDef) -> list[dict]:
params = []
args = func_node.args
num_defaults = len(args.defaults)
num_args = len(args.args)
first_default_idx = num_args - num_defaults
for i, arg in enumerate(args.args):
if arg.arg == "self":
continue
param_info = {"name": arg.arg, "type": None, "default": None}
if arg.annotation:
param_info["type"] = self._get_annotation_str(arg.annotation)
default_idx = i - first_default_idx
if default_idx >= 0 and default_idx < len(args.defaults):
param_info["default"] = self._get_value(args.defaults[default_idx])
params.append(param_info)
return params
def _get_annotation_str(self, node) -> str:
if isinstance(node, ast.Name):
return node.id
elif isinstance(node, ast.Constant):
return repr(node.value)
elif isinstance(node, ast.Subscript):
base = self._get_annotation_str(node.value)
if isinstance(node.slice, ast.Tuple):
args = ", ".join(self._get_annotation_str(el) for el in node.slice.elts)
else:
args = self._get_annotation_str(node.slice)
return f"{base}[{args}]"
elif isinstance(node, ast.Attribute):
return f"{self._get_annotation_str(node.value)}.{node.attr}"
elif isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
left = self._get_annotation_str(node.left)
right = self._get_annotation_str(node.right)
return f"{left} | {right}"
elif isinstance(node, ast.Tuple):
return ", ".join(self._get_annotation_str(el) for el in node.elts)
return "Any"
def _get_value(self, node):
if isinstance(node, ast.Constant):
return node.value
elif isinstance(node, ast.Name):
if node.id == "None":
return None
elif node.id == "True":
return True
elif node.id == "False":
return False
return node.id
elif isinstance(node, ast.List):
return [self._get_value(el) for el in node.elts]
elif isinstance(node, ast.Dict):
return {self._get_value(k): self._get_value(v) for k, v in zip(node.keys, node.values)}
return "<complex>"
def analyze_model_file(filepath: str) -> tuple[list[dict], set[str]]:
with open(filepath) as f:
source = f.read()
tree = ast.parse(source)
analyzer = ModelAnalyzer()
analyzer.visit(tree)
return analyzer.model_classes, analyzer.imports
def determine_testers(model_info: dict, include_optional: list[str], imports: set[str]) -> list[str]:
testers = list(ALWAYS_INCLUDE_TESTERS)
for base in model_info["bases"]:
if base in MIXIN_TO_TESTER:
tester = MIXIN_TO_TESTER[base]
if tester not in testers:
testers.append(tester)
for attr, tester in ATTRIBUTE_TO_TESTER.items():
if attr in model_info["attributes"]:
value = model_info["attributes"][attr]
if value is not None and value is not False:
if tester not in testers:
testers.append(tester)
if "_cp_plan" in model_info["attributes"] and model_info["attributes"]["_cp_plan"] is not None:
if "ContextParallelTesterMixin" not in testers:
testers.append("ContextParallelTesterMixin")
# Include AttentionTesterMixin if the model imports attention-related classes
if imports & ATTENTION_INDICATORS:
testers.append("AttentionTesterMixin")
for tester, flag in OPTIONAL_TESTERS:
if flag in include_optional:
if tester not in testers:
testers.append(tester)
return testers
def generate_config_class(model_info: dict, model_name: str) -> str:
class_name = f"{model_name}TesterConfig"
model_class = model_info["name"]
forward_params = model_info.get("forward_params", [])
init_params = model_info.get("init_params", [])
lines = [
f"class {class_name}:",
f" model_class = {model_class}",
' pretrained_model_name_or_path = ""',
' pretrained_model_kwargs = {"subfolder": "transformer"}',
"",
" @property",
" def generator(self):",
' return torch.Generator("cpu").manual_seed(0)',
"",
" def get_init_dict(self) -> dict[str, int | list[int]]:",
]
if init_params:
lines.append(" # __init__ parameters:")
for param in init_params:
type_str = f": {param['type']}" if param["type"] else ""
default_str = f" = {param['default']}" if param["default"] is not None else ""
lines.append(f" # {param['name']}{type_str}{default_str}")
lines.extend(
[
" return {}",
"",
" def get_dummy_inputs(self) -> dict[str, torch.Tensor]:",
]
)
if forward_params:
lines.append(" # forward() parameters:")
for param in forward_params:
type_str = f": {param['type']}" if param["type"] else ""
default_str = f" = {param['default']}" if param["default"] is not None else ""
lines.append(f" # {param['name']}{type_str}{default_str}")
lines.extend(
[
" # TODO: Fill in dummy inputs",
" return {}",
"",
" @property",
" def input_shape(self) -> tuple[int, ...]:",
" return (1, 1)",
"",
" @property",
" def output_shape(self) -> tuple[int, ...]:",
" return (1, 1)",
]
)
return "\n".join(lines)
def generate_test_class(model_name: str, config_class: str, tester: str) -> str:
tester_short = tester.replace("TesterMixin", "")
class_name = f"Test{model_name}{tester_short}"
lines = [f"class {class_name}({config_class}, {tester}):"]
if tester == "TorchCompileTesterMixin":
lines.extend(
[
" different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]",
"",
" def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:",
" # TODO: Implement dynamic input generation",
" return {}",
]
)
elif tester == "IPAdapterTesterMixin":
lines.extend(
[
" ip_adapter_processor_cls = None # TODO: Set processor class",
"",
" def modify_inputs_for_ip_adapter(self, model, inputs_dict):",
" # TODO: Add IP adapter image embeds to inputs",
" return inputs_dict",
"",
" def create_ip_adapter_state_dict(self, model):",
" # TODO: Create IP adapter state dict",
" return {}",
]
)
elif tester == "SingleFileTesterMixin":
lines.extend(
[
' ckpt_path = "" # TODO: Set checkpoint path',
" alternate_keys_ckpt_paths = []",
' pretrained_model_name_or_path = ""',
' subfolder = "transformer"',
]
)
elif tester == "GGUFTesterMixin":
lines.extend(
[
' gguf_filename = "" # TODO: Set GGUF filename',
"",
" def get_dummy_inputs(self) -> dict[str, torch.Tensor]:",
" # TODO: Override with larger inputs for quantization tests",
" return {}",
]
)
elif tester in ["BitsAndBytesTesterMixin", "QuantoTesterMixin", "TorchAoTesterMixin", "ModelOptTesterMixin"]:
lines.extend(
[
" def get_dummy_inputs(self) -> dict[str, torch.Tensor]:",
" # TODO: Override with larger inputs for quantization tests",
" return {}",
]
)
elif tester in [
"BitsAndBytesCompileTesterMixin",
"QuantoCompileTesterMixin",
"TorchAoCompileTesterMixin",
"ModelOptCompileTesterMixin",
]:
lines.extend(
[
" def get_dummy_inputs(self) -> dict[str, torch.Tensor]:",
" # TODO: Override with larger inputs for quantization compile tests",
" return {}",
]
)
elif tester == "GGUFCompileTesterMixin":
lines.extend(
[
' gguf_filename = "" # TODO: Set GGUF filename',
"",
" def get_dummy_inputs(self) -> dict[str, torch.Tensor]:",
" # TODO: Override with larger inputs for quantization compile tests",
" return {}",
]
)
elif tester in [
"PyramidAttentionBroadcastTesterMixin",
"FirstBlockCacheTesterMixin",
"FasterCacheTesterMixin",
]:
lines.append(" pass")
elif tester == "LoraHotSwappingForModelTesterMixin":
lines.extend(
[
" different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]",
"",
" def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:",
" # TODO: Implement dynamic input generation",
" return {}",
]
)
else:
lines.append(" pass")
return "\n".join(lines)
def generate_test_file(model_info: dict, model_filepath: str, include_optional: list[str], imports: set[str]) -> str:
model_name = model_info["name"].replace("2DModel", "").replace("3DModel", "").replace("Model", "")
testers = determine_testers(model_info, include_optional, imports)
tester_imports = sorted(set(testers) - {"LoraHotSwappingForModelTesterMixin"})
lines = [
"# coding=utf-8",
"# Copyright 2025 HuggingFace Inc.",
"#",
'# Licensed under the Apache License, Version 2.0 (the "License");',
"# you may not use this file except in compliance with the License.",
"# You may obtain a copy of the License at",
"#",
"# http://www.apache.org/licenses/LICENSE-2.0",
"#",
"# Unless required by applicable law or agreed to in writing, software",
'# distributed under the License is distributed on an "AS IS" BASIS,',
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.",
"# See the License for the specific language governing permissions and",
"# limitations under the License.",
"",
"import torch",
"",
f"from diffusers import {model_info['name']}",
"from diffusers.utils.torch_utils import randn_tensor",
"",
"from ...testing_utils import enable_full_determinism, torch_device",
]
if "LoraTesterMixin" in testers:
lines.append("from ..test_modeling_common import LoraHotSwappingForModelTesterMixin")
lines.extend(
[
"from ..testing_utils import (",
*[f" {tester}," for tester in sorted(tester_imports)],
")",
"",
"",
"enable_full_determinism()",
"",
"",
]
)
config_class = f"{model_name}TesterConfig"
lines.append(generate_config_class(model_info, model_name))
lines.append("")
lines.append("")
for tester in testers:
lines.append(generate_test_class(model_name, config_class, tester))
lines.append("")
lines.append("")
if "LoraTesterMixin" in testers:
lines.append(generate_test_class(model_name, config_class, "LoraHotSwappingForModelTesterMixin"))
lines.append("")
lines.append("")
return "\n".join(lines).rstrip() + "\n"
def get_test_output_path(model_filepath: str) -> str:
path = Path(model_filepath)
model_filename = path.stem
if "transformers" in path.parts:
return f"tests/models/transformers/test_models_{model_filename}.py"
elif "unets" in path.parts:
return f"tests/models/unets/test_models_{model_filename}.py"
elif "autoencoders" in path.parts:
return f"tests/models/autoencoders/test_models_{model_filename}.py"
else:
return f"tests/models/test_models_{model_filename}.py"
def main():
parser = argparse.ArgumentParser(description="Generate test suite for a diffusers model class")
parser.add_argument(
"model_filepath",
type=str,
help="Path to the model file (e.g., src/diffusers/models/transformers/transformer_flux.py)",
)
parser.add_argument(
"--output", "-o", type=str, default=None, help="Output file path (default: auto-generated based on model path)"
)
parser.add_argument(
"--include",
"-i",
type=str,
nargs="*",
default=[],
choices=[
"bnb",
"quanto",
"torchao",
"gguf",
"modelopt",
"bnb_compile",
"quanto_compile",
"torchao_compile",
"gguf_compile",
"modelopt_compile",
"pab_cache",
"fbc_cache",
"faster_cache",
"single_file",
"ip_adapter",
"all",
],
help="Optional testers to include",
)
parser.add_argument(
"--class-name",
"-c",
type=str,
default=None,
help="Specific model class to generate tests for (default: first model class found)",
)
parser.add_argument("--dry-run", action="store_true", help="Print generated code without writing to file")
args = parser.parse_args()
if not Path(args.model_filepath).exists():
print(f"Error: File not found: {args.model_filepath}", file=sys.stderr)
sys.exit(1)
model_classes, imports = analyze_model_file(args.model_filepath)
if not model_classes:
print(f"Error: No model classes found in {args.model_filepath}", file=sys.stderr)
sys.exit(1)
if args.class_name:
model_info = next((m for m in model_classes if m["name"] == args.class_name), None)
if not model_info:
available = [m["name"] for m in model_classes]
print(f"Error: Class '{args.class_name}' not found. Available: {available}", file=sys.stderr)
sys.exit(1)
else:
model_info = model_classes[0]
if len(model_classes) > 1:
print(f"Multiple model classes found, using: {model_info['name']}", file=sys.stderr)
print("Use --class-name to specify a different class", file=sys.stderr)
include_optional = args.include
if "all" in include_optional:
include_optional = [flag for _, flag in OPTIONAL_TESTERS]
generated_code = generate_test_file(model_info, args.model_filepath, include_optional, imports)
if args.dry_run:
print(generated_code)
else:
output_path = args.output or get_test_output_path(args.model_filepath)
output_dir = Path(output_path).parent
output_dir.mkdir(parents=True, exist_ok=True)
with open(output_path, "w") as f:
f.write(generated_code)
print(f"Generated test file: {output_path}")
print(f"Model class: {model_info['name']}")
print(f"Detected attributes: {list(model_info['attributes'].keys())}")
if __name__ == "__main__":
main()