HunyuanImage21 (#12333)

* add hunyuanimage2.1


---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
YiYi Xu
2025-10-23 22:31:12 -10:00
committed by GitHub
parent bc4039886d
commit a138d71ec1
40 changed files with 6656 additions and 224 deletions

View File

@@ -347,6 +347,8 @@
title: HiDreamImageTransformer2DModel
- local: api/models/hunyuan_transformer2d
title: HunyuanDiT2DModel
- local: api/models/hunyuanimage_transformer_2d
title: HunyuanImageTransformer2DModel
- local: api/models/hunyuan_video_transformer_3d
title: HunyuanVideoTransformer3DModel
- local: api/models/latte_transformer3d
@@ -411,6 +413,10 @@
title: AutoencoderKLCogVideoX
- local: api/models/autoencoderkl_cosmos
title: AutoencoderKLCosmos
- local: api/models/autoencoder_kl_hunyuanimage
title: AutoencoderKLHunyuanImage
- local: api/models/autoencoder_kl_hunyuanimage_refiner
title: AutoencoderKLHunyuanImageRefiner
- local: api/models/autoencoder_kl_hunyuan_video
title: AutoencoderKLHunyuanVideo
- local: api/models/autoencoderkl_ltx_video
@@ -620,6 +626,8 @@
title: ConsisID
- local: api/pipelines/framepack
title: Framepack
- local: api/pipelines/hunyuanimage21
title: HunyuanImage2.1
- local: api/pipelines/hunyuan_video
title: HunyuanVideo
- local: api/pipelines/i2vgenxl

View File

@@ -0,0 +1,32 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# AutoencoderKLHunyuanImage
The 2D variational autoencoder (VAE) model with KL loss used in [HunyuanImage2.1].
The model can be loaded with the following code snippet.
```python
from diffusers import AutoencoderKLHunyuanImage
vae = AutoencoderKLHunyuanImage.from_pretrained("hunyuanvideo-community/HunyuanImage-2.1-Diffusers", subfolder="vae", torch_dtype=torch.bfloat16)
```
## AutoencoderKLHunyuanImage
[[autodoc]] AutoencoderKLHunyuanImage
- decode
- all
## DecoderOutput
[[autodoc]] models.autoencoders.vae.DecoderOutput

View File

@@ -0,0 +1,32 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# AutoencoderKLHunyuanImageRefiner
The 3D variational autoencoder (VAE) model with KL loss used in [HunyuanImage2.1](https://github.com/Tencent-Hunyuan/HunyuanImage-2.1) for its refiner pipeline.
The model can be loaded with the following code snippet.
```python
from diffusers import AutoencoderKLHunyuanImageRefiner
vae = AutoencoderKLHunyuanImageRefiner.from_pretrained("hunyuanvideo-community/HunyuanImage-2.1-Refiner-Diffusers", subfolder="vae", torch_dtype=torch.bfloat16)
```
## AutoencoderKLHunyuanImageRefiner
[[autodoc]] AutoencoderKLHunyuanImageRefiner
- decode
- all
## DecoderOutput
[[autodoc]] models.autoencoders.vae.DecoderOutput

View File

@@ -0,0 +1,30 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# HunyuanImageTransformer2DModel
A Diffusion Transformer model for [HunyuanImage2.1](https://github.com/Tencent-Hunyuan/HunyuanImage-2.1).
The model can be loaded with the following code snippet.
```python
from diffusers import HunyuanImageTransformer2DModel
transformer = HunyuanImageTransformer2DModel.from_pretrained("hunyuanvideo-community/HunyuanImage-2.1-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
```
## HunyuanImageTransformer2DModel
[[autodoc]] HunyuanImageTransformer2DModel
## Transformer2DModelOutput
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput

View File

@@ -0,0 +1,152 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->
# HunyuanImage2.1
HunyuanImage-2.1 is a 17B text-to-image model that is capable of generating 2K (2048 x 2048) resolution images
HunyuanImage-2.1 comes in the following variants:
| model type | model id |
|:----------:|:--------:|
| HunyuanImage-2.1 | [hunyuanvideo-community/HunyuanImage-2.1-Diffusers](https://huggingface.co/hunyuanvideo-community/HunyuanImage-2.1-Diffusers) |
| HunyuanImage-2.1-Distilled | [hunyuanvideo-community/HunyuanImage-2.1-Distilled-Diffusers](https://huggingface.co/hunyuanvideo-community/HunyuanImage-2.1-Distilled-Diffusers) |
| HunyuanImage-2.1-Refiner | [hunyuanvideo-community/HunyuanImage-2.1-Refiner-Diffusers](https://huggingface.co/hunyuanvideo-community/HunyuanImage-2.1-Refiner-Diffusers) |
> [!TIP]
> [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.
## HunyuanImage-2.1
HunyuanImage-2.1 applies [Adaptive Projected Guidance (APG)](https://huggingface.co/papers/2410.02416) combined with Classifier-Free Guidance (CFG) in the denoising loop. `HunyuanImagePipeline` has a `guider` component (read more about [Guider](../modular_diffusers/guiders.md)) and does not take a `guidance_scale` parameter at runtime. To change guider-related parameters, e.g., `guidance_scale`, you can update the `guider` configuration instead.
```python
import torch
from diffusers import HunyuanImagePipeline
pipe = HunyuanImagePipeline.from_pretrained(
"hunyuanvideo-community/HunyuanImage-2.1-Diffusers",
torch_dtype=torch.bfloat16
)
pipe = pipe.to("cuda")
```
You can inspect the `guider` object:
```py
>>> pipe.guider
AdaptiveProjectedMixGuidance {
"_class_name": "AdaptiveProjectedMixGuidance",
"_diffusers_version": "0.36.0.dev0",
"adaptive_projected_guidance_momentum": -0.5,
"adaptive_projected_guidance_rescale": 10.0,
"adaptive_projected_guidance_scale": 10.0,
"adaptive_projected_guidance_start_step": 5,
"enabled": true,
"eta": 0.0,
"guidance_rescale": 0.0,
"guidance_scale": 3.5,
"start": 0.0,
"stop": 1.0,
"use_original_formulation": false
}
State:
step: None
num_inference_steps: None
timestep: None
count_prepared: 0
enabled: True
num_conditions: 2
momentum_buffer: None
is_apg_enabled: False
is_cfg_enabled: True
```
To update the guider with a different configuration, use the `new()` method. For example, to generate an image with `guidance_scale=5.0` while keeping all other default guidance parameters:
```py
import torch
from diffusers import HunyuanImagePipeline
pipe = HunyuanImagePipeline.from_pretrained(
"hunyuanvideo-community/HunyuanImage-2.1-Diffusers",
torch_dtype=torch.bfloat16
)
pipe = pipe.to("cuda")
# Update the guider configuration
pipe.guider = pipe.guider.new(guidance_scale=5.0)
prompt = (
"A cute, cartoon-style anthropomorphic penguin plush toy with fluffy fur, standing in a painting studio, "
"wearing a red knitted scarf and a red beret with the word 'Tencent' on it, holding a paintbrush with a "
"focused expression as it paints an oil painting of the Mona Lisa, rendered in a photorealistic photographic style."
)
image = pipe(
prompt=prompt,
num_inference_steps=50,
height=2048,
width=2048,
).images[0]
image.save("image.png")
```
## HunyuanImage-2.1-Distilled
use `distilled_guidance_scale` with the guidance-distilled checkpoint,
```py
import torch
from diffusers import HunyuanImagePipeline
pipe = HunyuanImagePipeline.from_pretrained("hunyuanvideo-community/HunyuanImage-2.1-Distilled-Diffusers", torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")
prompt = (
"A cute, cartoon-style anthropomorphic penguin plush toy with fluffy fur, standing in a painting studio, "
"wearing a red knitted scarf and a red beret with the word 'Tencent' on it, holding a paintbrush with a "
"focused expression as it paints an oil painting of the Mona Lisa, rendered in a photorealistic photographic style."
)
out = pipe(
prompt,
num_inference_steps=8,
distilled_guidance_scale=3.25,
height=2048,
width=2048,
generator=generator,
).images[0]
```
## HunyuanImagePipeline
[[autodoc]] HunyuanImagePipeline
- all
- __call__
## HunyuanImageRefinerPipeline
[[autodoc]] HunyuanImageRefinerPipeline
- all
- __call__
## HunyuanImagePipelineOutput
[[autodoc]] pipelines.hunyuan_image.pipeline_output.HunyuanImagePipelineOutput

File diff suppressed because it is too large Load Diff

View File

@@ -149,7 +149,9 @@ else:
_import_structure["guiders"].extend(
[
"AdaptiveProjectedGuidance",
"AdaptiveProjectedMixGuidance",
"AutoGuidance",
"BaseGuidance",
"ClassifierFreeGuidance",
"ClassifierFreeZeroStarGuidance",
"FrequencyDecoupledGuidance",
@@ -184,6 +186,8 @@ else:
"AutoencoderKLAllegro",
"AutoencoderKLCogVideoX",
"AutoencoderKLCosmos",
"AutoencoderKLHunyuanImage",
"AutoencoderKLHunyuanImageRefiner",
"AutoencoderKLHunyuanVideo",
"AutoencoderKLLTXVideo",
"AutoencoderKLMagvit",
@@ -216,6 +220,7 @@ else:
"HunyuanDiT2DControlNetModel",
"HunyuanDiT2DModel",
"HunyuanDiT2DMultiControlNetModel",
"HunyuanImageTransformer2DModel",
"HunyuanVideoFramepackTransformer3DModel",
"HunyuanVideoTransformer3DModel",
"I2VGenXLUNet",
@@ -462,6 +467,8 @@ else:
"HunyuanDiTControlNetPipeline",
"HunyuanDiTPAGPipeline",
"HunyuanDiTPipeline",
"HunyuanImagePipeline",
"HunyuanImageRefinerPipeline",
"HunyuanSkyreelsImageToVideoPipeline",
"HunyuanVideoFramepackPipeline",
"HunyuanVideoImageToVideoPipeline",
@@ -849,7 +856,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .guiders import (
AdaptiveProjectedGuidance,
AdaptiveProjectedMixGuidance,
AutoGuidance,
BaseGuidance,
ClassifierFreeGuidance,
ClassifierFreeZeroStarGuidance,
FrequencyDecoupledGuidance,
@@ -880,6 +889,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLAllegro,
AutoencoderKLCogVideoX,
AutoencoderKLCosmos,
AutoencoderKLHunyuanImage,
AutoencoderKLHunyuanImageRefiner,
AutoencoderKLHunyuanVideo,
AutoencoderKLLTXVideo,
AutoencoderKLMagvit,
@@ -912,6 +923,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
HunyuanDiT2DControlNetModel,
HunyuanDiT2DModel,
HunyuanDiT2DMultiControlNetModel,
HunyuanImageTransformer2DModel,
HunyuanVideoFramepackTransformer3DModel,
HunyuanVideoTransformer3DModel,
I2VGenXLUNet,
@@ -1128,6 +1140,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
HunyuanDiTControlNetPipeline,
HunyuanDiTPAGPipeline,
HunyuanDiTPipeline,
HunyuanImagePipeline,
HunyuanImageRefinerPipeline,
HunyuanSkyreelsImageToVideoPipeline,
HunyuanVideoFramepackPipeline,
HunyuanVideoImageToVideoPipeline,

View File

@@ -14,28 +14,24 @@
from typing import Union
from ..utils import is_torch_available
from ..utils import is_torch_available, logging
logger = logging.get_logger(__name__)
logger.warning(
"Guiders are currently an experimental feature under active development. The API is subject to breaking changes in future releases."
)
if is_torch_available():
from .adaptive_projected_guidance import AdaptiveProjectedGuidance
from .adaptive_projected_guidance_mix import AdaptiveProjectedMixGuidance
from .auto_guidance import AutoGuidance
from .classifier_free_guidance import ClassifierFreeGuidance
from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
from .frequency_decoupled_guidance import FrequencyDecoupledGuidance
from .guider_utils import BaseGuidance
from .perturbed_attention_guidance import PerturbedAttentionGuidance
from .skip_layer_guidance import SkipLayerGuidance
from .smoothed_energy_guidance import SmoothedEnergyGuidance
from .tangential_classifier_free_guidance import TangentialClassifierFreeGuidance
GuiderType = Union[
AdaptiveProjectedGuidance,
AutoGuidance,
ClassifierFreeGuidance,
ClassifierFreeZeroStarGuidance,
FrequencyDecoupledGuidance,
PerturbedAttentionGuidance,
SkipLayerGuidance,
SmoothedEnergyGuidance,
TangentialClassifierFreeGuidance,
]

View File

@@ -13,7 +13,7 @@
# limitations under the License.
import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
import torch
@@ -65,8 +65,9 @@ class AdaptiveProjectedGuidance(BaseGuidance):
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
enabled: bool = True,
):
super().__init__(start, stop)
super().__init__(start, stop, enabled)
self.guidance_scale = guidance_scale
self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
@@ -76,19 +77,14 @@ class AdaptiveProjectedGuidance(BaseGuidance):
self.use_original_formulation = use_original_formulation
self.momentum_buffer = None
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
if self._step == 0:
if self.adaptive_projected_guidance_momentum is not None:
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
@@ -152,6 +148,44 @@ class MomentumBuffer:
new_average = self.momentum * self.running_average
self.running_average = update_value + new_average
def __repr__(self) -> str:
"""
Returns a string representation showing momentum, shape, statistics, and a slice of the running_average.
"""
if isinstance(self.running_average, torch.Tensor):
shape = tuple(self.running_average.shape)
# Calculate statistics
with torch.no_grad():
stats = {
"mean": self.running_average.mean().item(),
"std": self.running_average.std().item(),
"min": self.running_average.min().item(),
"max": self.running_average.max().item(),
}
# Get a slice (max 3 elements per dimension)
slice_indices = tuple(slice(None, min(3, dim)) for dim in shape)
sliced_data = self.running_average[slice_indices]
# Format the slice for display (convert to float32 for numpy compatibility with bfloat16)
slice_str = str(sliced_data.detach().float().cpu().numpy())
if len(slice_str) > 200: # Truncate if too long
slice_str = slice_str[:200] + "..."
stats_str = ", ".join([f"{k}={v:.4f}" for k, v in stats.items()])
return (
f"MomentumBuffer(\n"
f" momentum={self.momentum},\n"
f" shape={shape},\n"
f" stats=[{stats_str}],\n"
f" slice={slice_str}\n"
f")"
)
else:
return f"MomentumBuffer(momentum={self.momentum}, running_average={self.running_average})"
def normalized_guidance(
pred_cond: torch.Tensor,

View File

@@ -0,0 +1,284 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
import torch
from ..configuration_utils import register_to_config
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState
class AdaptiveProjectedMixGuidance(BaseGuidance):
"""
Adaptive Projected Guidance (APG) https://huggingface.co/papers/2410.02416 combined with Classifier-Free Guidance
(CFG). This guider is used in HunyuanImage2.1 https://github.com/Tencent-Hunyuan/HunyuanImage-2.1
Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
adaptive_projected_guidance_momentum (`float`, defaults to `None`):
The momentum parameter for the adaptive projected guidance. Disabled if set to `None`.
adaptive_projected_guidance_rescale (`float`, defaults to `15.0`):
The rescale factor applied to the noise predictions for adaptive projected guidance. This is used to
improve image quality and fix
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions for classifier-free guidance. This is used to improve
image quality and fix overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample
Steps are Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
start (`float`, defaults to `0.0`):
The fraction of the total number of denoising steps after which the classifier-free guidance starts.
stop (`float`, defaults to `1.0`):
The fraction of the total number of denoising steps after which the classifier-free guidance stops.
adaptive_projected_guidance_start_step (`int`, defaults to `5`):
The step at which the adaptive projected guidance starts (before this step, classifier-free guidance is
used, and momentum buffer is updated).
enabled (`bool`, defaults to `True`):
Whether this guidance is enabled.
"""
_input_predictions = ["pred_cond", "pred_uncond"]
@register_to_config
def __init__(
self,
guidance_scale: float = 3.5,
guidance_rescale: float = 0.0,
adaptive_projected_guidance_scale: float = 10.0,
adaptive_projected_guidance_momentum: float = -0.5,
adaptive_projected_guidance_rescale: float = 10.0,
eta: float = 0.0,
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
adaptive_projected_guidance_start_step: int = 5,
enabled: bool = True,
):
super().__init__(start, stop, enabled)
self.guidance_scale = guidance_scale
self.guidance_rescale = guidance_rescale
self.adaptive_projected_guidance_scale = adaptive_projected_guidance_scale
self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
self.eta = eta
self.adaptive_projected_guidance_start_step = adaptive_projected_guidance_start_step
self.use_original_formulation = use_original_formulation
self.momentum_buffer = None
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
if self._step == 0:
if self.adaptive_projected_guidance_momentum is not None:
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
pred = None
# no guidance
if not self._is_cfg_enabled():
pred = pred_cond
# CFG + update momentum buffer
elif not self._is_apg_enabled():
if self.momentum_buffer is not None:
update_momentum_buffer(pred_cond, pred_uncond, self.momentum_buffer)
# CFG + update momentum buffer
shift = pred_cond - pred_uncond
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift
# APG
elif self._is_apg_enabled():
pred = normalized_guidance(
pred_cond,
pred_uncond,
self.adaptive_projected_guidance_scale,
self.momentum_buffer,
self.eta,
self.adaptive_projected_guidance_rescale,
self.use_original_formulation,
)
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
@property
def is_conditional(self) -> bool:
return self._count_prepared == 1
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_apg_enabled() or self._is_cfg_enabled():
num_conditions += 1
return num_conditions
# Copied from diffusers.guiders.classifier_free_guidance.ClassifierFreeGuidance._is_cfg_enabled
def _is_cfg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close
def _is_apg_enabled(self) -> bool:
if not self._enabled:
return False
if not self._is_cfg_enabled():
return False
is_within_range = False
if self._step is not None:
is_within_range = self._step > self.adaptive_projected_guidance_start_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.adaptive_projected_guidance_scale, 0.0)
else:
is_close = math.isclose(self.adaptive_projected_guidance_scale, 1.0)
return is_within_range and not is_close
def get_state(self):
state = super().get_state()
state["momentum_buffer"] = self.momentum_buffer
state["is_apg_enabled"] = self._is_apg_enabled()
state["is_cfg_enabled"] = self._is_cfg_enabled()
return state
# Copied from diffusers.guiders.adaptive_projected_guidance.MomentumBuffer
class MomentumBuffer:
def __init__(self, momentum: float):
self.momentum = momentum
self.running_average = 0
def update(self, update_value: torch.Tensor):
new_average = self.momentum * self.running_average
self.running_average = update_value + new_average
def __repr__(self) -> str:
"""
Returns a string representation showing momentum, shape, statistics, and a slice of the running_average.
"""
if isinstance(self.running_average, torch.Tensor):
shape = tuple(self.running_average.shape)
# Calculate statistics
with torch.no_grad():
stats = {
"mean": self.running_average.mean().item(),
"std": self.running_average.std().item(),
"min": self.running_average.min().item(),
"max": self.running_average.max().item(),
}
# Get a slice (max 3 elements per dimension)
slice_indices = tuple(slice(None, min(3, dim)) for dim in shape)
sliced_data = self.running_average[slice_indices]
# Format the slice for display (convert to float32 for numpy compatibility with bfloat16)
slice_str = str(sliced_data.detach().float().cpu().numpy())
if len(slice_str) > 200: # Truncate if too long
slice_str = slice_str[:200] + "..."
stats_str = ", ".join([f"{k}={v:.4f}" for k, v in stats.items()])
return (
f"MomentumBuffer(\n"
f" momentum={self.momentum},\n"
f" shape={shape},\n"
f" stats=[{stats_str}],\n"
f" slice={slice_str}\n"
f")"
)
else:
return f"MomentumBuffer(momentum={self.momentum}, running_average={self.running_average})"
def update_momentum_buffer(
pred_cond: torch.Tensor,
pred_uncond: torch.Tensor,
momentum_buffer: Optional[MomentumBuffer] = None,
):
diff = pred_cond - pred_uncond
if momentum_buffer is not None:
momentum_buffer.update(diff)
def normalized_guidance(
pred_cond: torch.Tensor,
pred_uncond: torch.Tensor,
guidance_scale: float,
momentum_buffer: Optional[MomentumBuffer] = None,
eta: float = 1.0,
norm_threshold: float = 0.0,
use_original_formulation: bool = False,
):
if momentum_buffer is not None:
update_momentum_buffer(pred_cond, pred_uncond, momentum_buffer)
diff = momentum_buffer.running_average
else:
diff = pred_cond - pred_uncond
dim = [-i for i in range(1, len(diff.shape))]
if norm_threshold > 0:
ones = torch.ones_like(diff)
diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
diff = diff * scale_factor
v0, v1 = diff.double(), pred_cond.double()
v1 = torch.nn.functional.normalize(v1, dim=dim)
v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
v0_orthogonal = v0 - v0_parallel
diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
normalized_update = diff_orthogonal + eta * diff_parallel
pred = pred_cond if use_original_formulation else pred_uncond
pred = pred + guidance_scale * normalized_update
return pred

View File

@@ -72,8 +72,9 @@ class AutoGuidance(BaseGuidance):
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
enabled: bool = True,
):
super().__init__(start, stop)
super().__init__(start, stop, enabled)
self.guidance_scale = guidance_scale
self.auto_guidance_layers = auto_guidance_layers
@@ -132,16 +133,11 @@ class AutoGuidance(BaseGuidance):
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
registry.remove_hook(name, recurse=True)
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches

View File

@@ -13,7 +13,7 @@
# limitations under the License.
import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
import torch
@@ -27,43 +27,50 @@ if TYPE_CHECKING:
class ClassifierFreeGuidance(BaseGuidance):
"""
Classifier-free guidance (CFG): https://huggingface.co/papers/2207.12598
Implements Classifier-Free Guidance (CFG) for diffusion models.
CFG is a technique used to improve generation quality and condition-following in diffusion models. It works by
jointly training a model on both conditional and unconditional data, and using a weighted sum of the two during
inference. This allows the model to tradeoff between generation quality and sample diversity. The original paper
proposes scaling and shifting the conditional distribution based on the difference between conditional and
unconditional predictions. [x_pred = x_cond + scale * (x_cond - x_uncond)]
Reference: https://huggingface.co/papers/2207.12598
Diffusers implemented the scaling and shifting on the unconditional prediction instead based on the [Imagen
paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original paper proposed in
theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)]
CFG improves generation quality and prompt adherence by jointly training models on both conditional and
unconditional data, then combining predictions during inference. This allows trading off between quality (high
guidance) and diversity (low guidance).
The intution behind the original formulation can be thought of as moving the conditional distribution estimates
further away from the unconditional distribution estimates, while the diffusers-native implementation can be
thought of as moving the unconditional distribution towards the conditional distribution estimates to get rid of
the unconditional predictions (usually negative features like "bad quality, bad anotomy, watermarks", etc.)
**Two CFG Formulations:**
The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the
paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time.
1. **Original formulation** (from paper):
```
x_pred = x_cond + guidance_scale * (x_cond - x_uncond)
```
Moves conditional predictions further from unconditional ones.
2. **Diffusers-native formulation** (default, from Imagen paper):
```
x_pred = x_uncond + guidance_scale * (x_cond - x_uncond)
```
Moves unconditional predictions toward conditional ones, effectively suppressing negative features (e.g., "bad
quality", "watermarks"). Equivalent in theory but more intuitive.
Use `use_original_formulation=True` to switch to the original formulation.
Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
CFG scale applied by this guider during post-processing. Higher values = stronger prompt conditioning but
may reduce quality. Typical range: 1.0-20.0.
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
Rescaling factor to prevent overexposure from high guidance scales. Based on [Common Diffusion Noise
Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). Range: 0.0 (no rescaling)
to 1.0 (full rescaling).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
If `True`, uses the original CFG formulation from the paper. If `False` (default), uses the
diffusers-native formulation from the Imagen paper.
start (`float`, defaults to `0.0`):
The fraction of the total number of denoising steps after which guidance starts.
Fraction of denoising steps (0.0-1.0) after which CFG starts. Use > 0.0 to disable CFG in early denoising
steps.
stop (`float`, defaults to `1.0`):
The fraction of the total number of denoising steps after which guidance stops.
Fraction of denoising steps (0.0-1.0) after which CFG stops. Use < 1.0 to disable CFG in late denoising
steps.
enabled (`bool`, defaults to `True`):
Whether CFG is enabled. Set to `False` to disable CFG entirely (uses only conditional predictions).
"""
_input_predictions = ["pred_cond", "pred_uncond"]
@@ -76,23 +83,19 @@ class ClassifierFreeGuidance(BaseGuidance):
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
enabled: bool = True,
):
super().__init__(start, stop)
super().__init__(start, stop, enabled)
self.guidance_scale = guidance_scale
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches

View File

@@ -13,7 +13,7 @@
# limitations under the License.
import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
import torch
@@ -68,31 +68,31 @@ class ClassifierFreeZeroStarGuidance(BaseGuidance):
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
enabled: bool = True,
):
super().__init__(start, stop)
super().__init__(start, stop, enabled)
self.guidance_scale = guidance_scale
self.zero_init_steps = zero_init_steps
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
pred = None
if self._step < self.zero_init_steps:
# YiYi Notes: add default behavior for self._enabled == False
if not self._enabled:
pred = pred_cond
elif self._step < self.zero_init_steps:
pred = torch.zeros_like(pred_cond)
elif not self._is_cfg_enabled():
pred = pred_cond

View File

@@ -149,6 +149,7 @@ class FrequencyDecoupledGuidance(BaseGuidance):
stop: Union[float, List[float], Tuple[float]] = 1.0,
guidance_rescale_space: str = "data",
upcast_to_double: bool = True,
enabled: bool = True,
):
if not _CAN_USE_KORNIA:
raise ImportError(
@@ -160,7 +161,7 @@ class FrequencyDecoupledGuidance(BaseGuidance):
# Set start to earliest start for any freq component and stop to latest stop for any freq component
min_start = start if isinstance(start, float) else min(start)
max_stop = stop if isinstance(stop, float) else max(stop)
super().__init__(min_start, max_stop)
super().__init__(min_start, max_stop, enabled)
self.guidance_scales = guidance_scales
self.levels = len(guidance_scales)
@@ -217,16 +218,11 @@ class FrequencyDecoupledGuidance(BaseGuidance):
f"({len(self.guidance_scales)})"
)
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches

View File

@@ -40,7 +40,7 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
_input_predictions = None
_identifier_key = "__guidance_identifier__"
def __init__(self, start: float = 0.0, stop: float = 1.0):
def __init__(self, start: float = 0.0, stop: float = 1.0, enabled: bool = True):
self._start = start
self._stop = stop
self._step: int = None
@@ -48,7 +48,7 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
self._timestep: torch.LongTensor = None
self._count_prepared = 0
self._input_fields: Dict[str, Union[str, Tuple[str, str]]] = None
self._enabled = True
self._enabled = enabled
if not (0.0 <= start < 1.0):
raise ValueError(f"Expected `start` to be between 0.0 and 1.0, but got {start}.")
@@ -60,6 +60,31 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
"`_input_predictions` must be a list of required prediction names for the guidance technique."
)
def new(self, **kwargs):
"""
Creates a copy of this guider instance, optionally with modified configuration parameters.
Args:
**kwargs: Configuration parameters to override in the new instance. If no kwargs are provided,
returns an exact copy with the same configuration.
Returns:
A new guider instance with the same (or updated) configuration.
Example:
```python
# Create a CFG guider
guider = ClassifierFreeGuidance(guidance_scale=3.5)
# Create an exact copy
same_guider = guider.new()
# Create a copy with different start step, keeping other config the same
new_guider = guider.new(guidance_scale=5)
```
"""
return self.__class__.from_config(self.config, **kwargs)
def disable(self):
self._enabled = False
@@ -72,42 +97,52 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
self._timestep = timestep
self._count_prepared = 0
def set_input_fields(self, **kwargs: Dict[str, Union[str, Tuple[str, str]]]) -> None:
def get_state(self) -> Dict[str, Any]:
"""
Set the input fields for the guidance technique. The input fields are used to specify the names of the returned
attributes containing the prepared data after `prepare_inputs` is called. The prepared data is obtained from
the values of the provided keyword arguments to this method.
Args:
**kwargs (`Dict[str, Union[str, Tuple[str, str]]]`):
A dictionary where the keys are the names of the fields that will be used to store the data once it is
prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, which is used
to look up the required data provided for preparation.
If a string is provided, it will be used as the conditional data (or unconditional if used with a
guidance method that requires it). If a tuple of length 2 is provided, the first element must be the
conditional data identifier and the second element must be the unconditional data identifier or None.
Example:
```
data = {"prompt_embeds": <some tensor>, "negative_prompt_embeds": <some tensor>, "latents": <some tensor>}
BaseGuidance.set_input_fields(
latents="latents",
prompt_embeds=("prompt_embeds", "negative_prompt_embeds"),
)
```
Returns the current state of the guidance technique as a dictionary. The state variables will be included in
the __repr__ method. Returns:
`Dict[str, Any]`: A dictionary containing the current state variables including:
- step: Current inference step
- num_inference_steps: Total number of inference steps
- timestep: Current timestep tensor
- count_prepared: Number of times prepare_models has been called
- enabled: Whether the guidance is enabled
- num_conditions: Number of conditions
"""
for key, value in kwargs.items():
is_string = isinstance(value, str)
is_tuple_of_str_with_len_2 = (
isinstance(value, tuple) and len(value) == 2 and all(isinstance(v, str) for v in value)
)
if not (is_string or is_tuple_of_str_with_len_2):
raise ValueError(
f"Expected `set_input_fields` to be called with a string or a tuple of string with length 2, but got {type(value)} for key {key}."
)
self._input_fields = kwargs
state = {
"step": self._step,
"num_inference_steps": self._num_inference_steps,
"timestep": self._timestep,
"count_prepared": self._count_prepared,
"enabled": self._enabled,
"num_conditions": self.num_conditions,
}
return state
def __repr__(self) -> str:
"""
Returns a string representation of the guidance object including both config and current state.
"""
# Get ConfigMixin's __repr__
str_repr = super().__repr__()
# Get current state
state = self.get_state()
# Format each state variable on its own line with indentation
state_lines = []
for k, v in state.items():
# Convert value to string and handle multi-line values
v_str = str(v)
if "\n" in v_str:
# For multi-line values (like MomentumBuffer), indent subsequent lines
v_lines = v_str.split("\n")
v_str = v_lines[0] + "\n" + "\n".join([" " + line for line in v_lines[1:]])
state_lines.append(f" {k}: {v_str}")
state_str = "\n".join(state_lines)
return f"{str_repr}\nState:\n{state_str}"
def prepare_models(self, denoiser: torch.nn.Module) -> None:
"""
@@ -155,8 +190,7 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
@classmethod
def _prepare_batch(
cls,
input_fields: Dict[str, Union[str, Tuple[str, str]]],
data: "BlockState",
data: Dict[str, Tuple[torch.Tensor, torch.Tensor]],
tuple_index: int,
identifier: str,
) -> "BlockState":
@@ -182,21 +216,16 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
"""
from ..modular_pipelines.modular_pipeline import BlockState
if input_fields is None:
raise ValueError(
"Input fields cannot be None. Please pass `input_fields` to `prepare_inputs` or call `set_input_fields` before preparing inputs."
)
data_batch = {}
for key, value in input_fields.items():
for key, value in data.items():
try:
if isinstance(value, str):
data_batch[key] = getattr(data, value)
if isinstance(value, torch.Tensor):
data_batch[key] = value
elif isinstance(value, tuple):
data_batch[key] = getattr(data, value[tuple_index])
data_batch[key] = value[tuple_index]
else:
# We've already checked that value is a string or a tuple of strings with length 2
pass
except AttributeError:
raise ValueError(f"Invalid value type: {type(value)}")
except ValueError:
logger.debug(f"`data` does not have attribute(s) {value}, skipping.")
data_batch[cls._identifier_key] = identifier
return BlockState(**data_batch)

View File

@@ -98,8 +98,9 @@ class PerturbedAttentionGuidance(BaseGuidance):
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
enabled: bool = True,
):
super().__init__(start, stop)
super().__init__(start, stop, enabled)
self.guidance_scale = guidance_scale
self.skip_layer_guidance_scale = perturbed_guidance_scale
@@ -168,12 +169,7 @@ class PerturbedAttentionGuidance(BaseGuidance):
registry.remove_hook(hook_name, recurse=True)
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_inputs
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
if self.num_conditions == 1:
tuple_indices = [0]
input_predictions = ["pred_cond"]
@@ -186,8 +182,8 @@ class PerturbedAttentionGuidance(BaseGuidance):
tuple_indices = [0, 1, 0]
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches

View File

@@ -100,8 +100,9 @@ class SkipLayerGuidance(BaseGuidance):
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
enabled: bool = True,
):
super().__init__(start, stop)
super().__init__(start, stop, enabled)
self.guidance_scale = guidance_scale
self.skip_layer_guidance_scale = skip_layer_guidance_scale
@@ -164,12 +165,7 @@ class SkipLayerGuidance(BaseGuidance):
for hook_name in self._skip_layer_hook_names:
registry.remove_hook(hook_name, recurse=True)
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
if self.num_conditions == 1:
tuple_indices = [0]
input_predictions = ["pred_cond"]
@@ -182,8 +178,8 @@ class SkipLayerGuidance(BaseGuidance):
tuple_indices = [0, 1, 0]
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches

View File

@@ -92,8 +92,9 @@ class SmoothedEnergyGuidance(BaseGuidance):
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
enabled: bool = True,
):
super().__init__(start, stop)
super().__init__(start, stop, enabled)
self.guidance_scale = guidance_scale
self.seg_guidance_scale = seg_guidance_scale
@@ -153,12 +154,7 @@ class SmoothedEnergyGuidance(BaseGuidance):
for hook_name in self._seg_layer_hook_names:
registry.remove_hook(hook_name, recurse=True)
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
if self.num_conditions == 1:
tuple_indices = [0]
input_predictions = ["pred_cond"]
@@ -171,8 +167,8 @@ class SmoothedEnergyGuidance(BaseGuidance):
tuple_indices = [0, 1, 0]
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches

View File

@@ -13,7 +13,7 @@
# limitations under the License.
import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
import torch
@@ -58,23 +58,19 @@ class TangentialClassifierFreeGuidance(BaseGuidance):
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
enabled: bool = True,
):
super().__init__(start, stop)
super().__init__(start, stop, enabled)
self.guidance_scale = guidance_scale
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches

View File

@@ -108,6 +108,7 @@ def _register_attention_processors_metadata():
from ..models.attention_processor import AttnProcessor2_0
from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor
from ..models.transformers.transformer_flux import FluxAttnProcessor
from ..models.transformers.transformer_hunyuanimage import HunyuanImageAttnProcessor
from ..models.transformers.transformer_qwenimage import QwenDoubleStreamAttnProcessor2_0
from ..models.transformers.transformer_wan import WanAttnProcessor2_0
@@ -149,6 +150,14 @@ def _register_attention_processors_metadata():
),
)
# HunyuanImageAttnProcessor
AttentionProcessorRegistry.register(
model_class=HunyuanImageAttnProcessor,
metadata=AttentionProcessorMetadata(
skip_processor_output_fn=_skip_proc_output_fn_Attention_HunyuanImageAttnProcessor,
),
)
def _register_transformer_blocks_metadata():
from ..models.attention import BasicTransformerBlock
@@ -162,6 +171,10 @@ def _register_transformer_blocks_metadata():
HunyuanVideoTokenReplaceTransformerBlock,
HunyuanVideoTransformerBlock,
)
from ..models.transformers.transformer_hunyuanimage import (
HunyuanImageSingleTransformerBlock,
HunyuanImageTransformerBlock,
)
from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
from ..models.transformers.transformer_mochi import MochiTransformerBlock
from ..models.transformers.transformer_qwenimage import QwenImageTransformerBlock
@@ -283,6 +296,22 @@ def _register_transformer_blocks_metadata():
),
)
# HunyuanImage2.1
TransformerBlockRegistry.register(
model_class=HunyuanImageTransformerBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
TransformerBlockRegistry.register(
model_class=HunyuanImageSingleTransformerBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
# fmt: off
def _skip_attention___ret___hidden_states(self, *args, **kwargs):
@@ -308,4 +337,5 @@ _skip_proc_output_fn_Attention_WanAttnProcessor2_0 = _skip_attention___ret___hid
# not sure what this is yet.
_skip_proc_output_fn_Attention_FluxAttnProcessor = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_QwenDoubleStreamAttnProcessor2_0 = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_HunyuanImageAttnProcessor = _skip_attention___ret___hidden_states
# fmt: on

View File

@@ -36,6 +36,8 @@ if is_torch_available():
_import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
_import_structure["autoencoders.autoencoder_kl_cosmos"] = ["AutoencoderKLCosmos"]
_import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = ["AutoencoderKLHunyuanVideo"]
_import_structure["autoencoders.autoencoder_kl_hunyuanimage"] = ["AutoencoderKLHunyuanImage"]
_import_structure["autoencoders.autoencoder_kl_hunyuanimage_refiner"] = ["AutoencoderKLHunyuanImageRefiner"]
_import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
_import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"]
_import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
@@ -91,6 +93,7 @@ if is_torch_available():
_import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"]
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
_import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"]
_import_structure["transformers.transformer_hunyuanimage"] = ["HunyuanImageTransformer2DModel"]
_import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"]
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
@@ -133,6 +136,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLAllegro,
AutoencoderKLCogVideoX,
AutoencoderKLCosmos,
AutoencoderKLHunyuanImage,
AutoencoderKLHunyuanImageRefiner,
AutoencoderKLHunyuanVideo,
AutoencoderKLLTXVideo,
AutoencoderKLMagvit,
@@ -182,6 +187,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxTransformer2DModel,
HiDreamImageTransformer2DModel,
HunyuanDiT2DModel,
HunyuanImageTransformer2DModel,
HunyuanVideoFramepackTransformer3DModel,
HunyuanVideoTransformer3DModel,
Kandinsky5Transformer3DModel,

View File

@@ -5,6 +5,8 @@ from .autoencoder_kl_allegro import AutoencoderKLAllegro
from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
from .autoencoder_kl_cosmos import AutoencoderKLCosmos
from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo
from .autoencoder_kl_hunyuanimage import AutoencoderKLHunyuanImage
from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner
from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
from .autoencoder_kl_magvit import AutoencoderKLMagvit
from .autoencoder_kl_mochi import AutoencoderKLMochi

View File

@@ -0,0 +1,709 @@
# Copyright 2025 The Hunyuan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin
from ...utils import logging
from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class HunyuanImageResnetBlock(nn.Module):
r"""
Residual block with two convolutions and optional channel change.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
non_linearity (str, optional): Type of non-linearity to use. Default is "silu".
"""
def __init__(self, in_channels: int, out_channels: int, non_linearity: str = "silu") -> None:
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.nonlinearity = get_activation(non_linearity)
# layers
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if in_channels != out_channels:
self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
else:
self.conv_shortcut = None
def forward(self, x):
# Apply shortcut connection
residual = x
# First normalization and activation
x = self.norm1(x)
x = self.nonlinearity(x)
x = self.conv1(x)
x = self.norm2(x)
x = self.nonlinearity(x)
x = self.conv2(x)
if self.conv_shortcut is not None:
x = self.conv_shortcut(x)
# Add residual connection
return x + residual
class HunyuanImageAttentionBlock(nn.Module):
r"""
Self-attention with a single head.
Args:
in_channels (int): The number of channels in the input tensor.
"""
def __init__(self, in_channels: int):
super().__init__()
# layers
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
self.to_q = nn.Conv2d(in_channels, in_channels, 1)
self.to_k = nn.Conv2d(in_channels, in_channels, 1)
self.to_v = nn.Conv2d(in_channels, in_channels, 1)
self.proj = nn.Conv2d(in_channels, in_channels, 1)
def forward(self, x):
identity = x
x = self.norm(x)
# compute query, key, value
query = self.to_q(x)
key = self.to_k(x)
value = self.to_v(x)
batch_size, channels, height, width = query.shape
query = query.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels).contiguous()
key = key.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels).contiguous()
value = value.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels).contiguous()
# apply attention
x = F.scaled_dot_product_attention(query, key, value)
x = x.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)
# output projection
x = self.proj(x)
return x + identity
class HunyuanImageDownsample(nn.Module):
"""
Downsampling block for spatial reduction.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
"""
def __init__(self, in_channels: int, out_channels: int):
super().__init__()
factor = 4
if out_channels % factor != 0:
raise ValueError(f"out_channels % factor != 0: {out_channels % factor}")
self.conv = nn.Conv2d(in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1)
self.group_size = factor * in_channels // out_channels
def forward(self, x: torch.Tensor) -> torch.Tensor:
h = self.conv(x)
B, C, H, W = h.shape
h = h.reshape(B, C, H // 2, 2, W // 2, 2)
h = h.permute(0, 3, 5, 1, 2, 4) # b, r1, r2, c, h, w
h = h.reshape(B, 4 * C, H // 2, W // 2)
B, C, H, W = x.shape
shortcut = x.reshape(B, C, H // 2, 2, W // 2, 2)
shortcut = shortcut.permute(0, 3, 5, 1, 2, 4) # b, r1, r2, c, h, w
shortcut = shortcut.reshape(B, 4 * C, H // 2, W // 2)
B, C, H, W = shortcut.shape
shortcut = shortcut.view(B, h.shape[1], self.group_size, H, W).mean(dim=2)
return h + shortcut
class HunyuanImageUpsample(nn.Module):
"""
Upsampling block for spatial expansion.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
"""
def __init__(self, in_channels: int, out_channels: int):
super().__init__()
factor = 4
self.conv = nn.Conv2d(in_channels, out_channels * factor, kernel_size=3, stride=1, padding=1)
self.repeats = factor * out_channels // in_channels
def forward(self, x: torch.Tensor) -> torch.Tensor:
h = self.conv(x)
B, C, H, W = h.shape
h = h.reshape(B, 2, 2, C // 4, H, W) # b, r1, r2, c, h, w
h = h.permute(0, 3, 4, 1, 5, 2) # b, c, h, r1, w, r2
h = h.reshape(B, C // 4, H * 2, W * 2)
shortcut = x.repeat_interleave(repeats=self.repeats, dim=1)
B, C, H, W = shortcut.shape
shortcut = shortcut.reshape(B, 2, 2, C // 4, H, W) # b, r1, r2, c, h, w
shortcut = shortcut.permute(0, 3, 4, 1, 5, 2) # b, c, h, r1, w, r2
shortcut = shortcut.reshape(B, C // 4, H * 2, W * 2)
return h + shortcut
class HunyuanImageMidBlock(nn.Module):
"""
Middle block for HunyuanImageVAE encoder and decoder.
Args:
in_channels (int): Number of input channels.
num_layers (int): Number of layers.
"""
def __init__(self, in_channels: int, num_layers: int = 1):
super().__init__()
resnets = [HunyuanImageResnetBlock(in_channels=in_channels, out_channels=in_channels)]
attentions = []
for _ in range(num_layers):
attentions.append(HunyuanImageAttentionBlock(in_channels))
resnets.append(HunyuanImageResnetBlock(in_channels=in_channels, out_channels=in_channels))
self.resnets = nn.ModuleList(resnets)
self.attentions = nn.ModuleList(attentions)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.resnets[0](x)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
x = attn(x)
x = resnet(x)
return x
class HunyuanImageEncoder2D(nn.Module):
r"""
Encoder network that compresses input to latent representation.
Args:
in_channels (int): Number of input channels.
z_channels (int): Number of latent channels.
block_out_channels (list of int): Output channels for each block.
num_res_blocks (int): Number of residual blocks per block.
spatial_compression_ratio (int): Spatial downsampling factor.
non_linearity (str): Type of non-linearity to use. Default is "silu".
downsample_match_channel (bool): Whether to match channels during downsampling.
"""
def __init__(
self,
in_channels: int,
z_channels: int,
block_out_channels: Tuple[int, ...],
num_res_blocks: int,
spatial_compression_ratio: int,
non_linearity: str = "silu",
downsample_match_channel: bool = True,
):
super().__init__()
if block_out_channels[-1] % (2 * z_channels) != 0:
raise ValueError(
f"block_out_channels[-1 has to be divisible by 2 * out_channels, you have block_out_channels = {block_out_channels[-1]} and out_channels = {z_channels}"
)
self.in_channels = in_channels
self.z_channels = z_channels
self.block_out_channels = block_out_channels
self.num_res_blocks = num_res_blocks
self.spatial_compression_ratio = spatial_compression_ratio
self.group_size = block_out_channels[-1] // (2 * z_channels)
self.nonlinearity = get_activation(non_linearity)
# init block
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
# downsample blocks
self.down_blocks = nn.ModuleList([])
block_in_channel = block_out_channels[0]
for i in range(len(block_out_channels)):
block_out_channel = block_out_channels[i]
# residual blocks
for _ in range(num_res_blocks):
self.down_blocks.append(
HunyuanImageResnetBlock(in_channels=block_in_channel, out_channels=block_out_channel)
)
block_in_channel = block_out_channel
# downsample block
if i < np.log2(spatial_compression_ratio) and i != len(block_out_channels) - 1:
if downsample_match_channel:
block_out_channel = block_out_channels[i + 1]
self.down_blocks.append(
HunyuanImageDownsample(in_channels=block_in_channel, out_channels=block_out_channel)
)
block_in_channel = block_out_channel
# middle blocks
self.mid_block = HunyuanImageMidBlock(in_channels=block_out_channels[-1], num_layers=1)
# output blocks
# Output layers
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_out_channels[-1], eps=1e-6, affine=True)
self.conv_out = nn.Conv2d(block_out_channels[-1], 2 * z_channels, kernel_size=3, stride=1, padding=1)
self.gradient_checkpointing = False
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.conv_in(x)
## downsamples
for down_block in self.down_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
x = self._gradient_checkpointing_func(down_block, x)
else:
x = down_block(x)
## middle
if torch.is_grad_enabled() and self.gradient_checkpointing:
x = self._gradient_checkpointing_func(self.mid_block, x)
else:
x = self.mid_block(x)
## head
B, C, H, W = x.shape
residual = x.view(B, C // self.group_size, self.group_size, H, W).mean(dim=2)
x = self.norm_out(x)
x = self.nonlinearity(x)
x = self.conv_out(x)
return x + residual
class HunyuanImageDecoder2D(nn.Module):
r"""
Decoder network that reconstructs output from latent representation.
Args:
z_channels : int
Number of latent channels.
out_channels : int
Number of output channels.
block_out_channels : Tuple[int, ...]
Output channels for each block.
num_res_blocks : int
Number of residual blocks per block.
spatial_compression_ratio : int
Spatial upsampling factor.
upsample_match_channel : bool
Whether to match channels during upsampling.
non_linearity (str): Type of non-linearity to use. Default is "silu".
"""
def __init__(
self,
z_channels: int,
out_channels: int,
block_out_channels: Tuple[int, ...],
num_res_blocks: int,
spatial_compression_ratio: int,
upsample_match_channel: bool = True,
non_linearity: str = "silu",
):
super().__init__()
if block_out_channels[0] % z_channels != 0:
raise ValueError(
f"block_out_channels[0] should be divisible by z_channels but has block_out_channels[0] = {block_out_channels[0]} and z_channels = {z_channels}"
)
self.z_channels = z_channels
self.block_out_channels = block_out_channels
self.num_res_blocks = num_res_blocks
self.repeat = block_out_channels[0] // z_channels
self.spatial_compression_ratio = spatial_compression_ratio
self.nonlinearity = get_activation(non_linearity)
self.conv_in = nn.Conv2d(z_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
# Middle blocks with attention
self.mid_block = HunyuanImageMidBlock(in_channels=block_out_channels[0], num_layers=1)
# Upsampling blocks
block_in_channel = block_out_channels[0]
self.up_blocks = nn.ModuleList()
for i in range(len(block_out_channels)):
block_out_channel = block_out_channels[i]
for _ in range(self.num_res_blocks + 1):
self.up_blocks.append(
HunyuanImageResnetBlock(in_channels=block_in_channel, out_channels=block_out_channel)
)
block_in_channel = block_out_channel
if i < np.log2(spatial_compression_ratio) and i != len(block_out_channels) - 1:
if upsample_match_channel:
block_out_channel = block_out_channels[i + 1]
self.up_blocks.append(HunyuanImageUpsample(block_in_channel, block_out_channel))
block_in_channel = block_out_channel
# Output layers
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_out_channels[-1], eps=1e-6, affine=True)
self.conv_out = nn.Conv2d(block_out_channels[-1], out_channels, kernel_size=3, stride=1, padding=1)
self.gradient_checkpointing = False
def forward(self, x: torch.Tensor) -> torch.Tensor:
h = self.conv_in(x) + x.repeat_interleave(repeats=self.repeat, dim=1)
if torch.is_grad_enabled() and self.gradient_checkpointing:
h = self._gradient_checkpointing_func(self.mid_block, h)
else:
h = self.mid_block(h)
for up_block in self.up_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
h = self._gradient_checkpointing_func(up_block, h)
else:
h = up_block(h)
h = self.norm_out(h)
h = self.nonlinearity(h)
h = self.conv_out(h)
return h
class AutoencoderKLHunyuanImage(ModelMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A VAE model for 2D images with spatial tiling support.
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
for all models (such as downloading or saving).
"""
_supports_gradient_checkpointing = False
# fmt: off
@register_to_config
def __init__(
self,
in_channels: int,
out_channels: int,
latent_channels: int,
block_out_channels: Tuple[int, ...],
layers_per_block: int,
spatial_compression_ratio: int,
sample_size: int,
scaling_factor: float = None,
downsample_match_channel: bool = True,
upsample_match_channel: bool = True,
) -> None:
# fmt: on
super().__init__()
self.encoder = HunyuanImageEncoder2D(
in_channels=in_channels,
z_channels=latent_channels,
block_out_channels=block_out_channels,
num_res_blocks=layers_per_block,
spatial_compression_ratio=spatial_compression_ratio,
downsample_match_channel=downsample_match_channel,
)
self.decoder = HunyuanImageDecoder2D(
z_channels=latent_channels,
out_channels=out_channels,
block_out_channels=list(reversed(block_out_channels)),
num_res_blocks=layers_per_block,
spatial_compression_ratio=spatial_compression_ratio,
upsample_match_channel=upsample_match_channel,
)
# Tiling and slicing configuration
self.use_slicing = False
self.use_tiling = False
# Tiling parameters
self.tile_sample_min_size = sample_size
self.tile_latent_min_size = sample_size // spatial_compression_ratio
self.tile_overlap_factor = 0.25
def enable_tiling(
self,
tile_sample_min_size: Optional[int] = None,
tile_overlap_factor: Optional[float] = None,
) -> None:
r"""
Enable spatial tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles
to compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to
allow processing larger images.
Args:
tile_sample_min_size (`int`, *optional*):
The minimum size required for a sample to be separated into tiles across the spatial dimension.
tile_overlap_factor (`float`, *optional*):
The overlap factor required for a latent to be separated into tiles across the spatial dimension.
"""
self.use_tiling = True
self.tile_sample_min_size = tile_sample_min_size or self.tile_sample_min_size
self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor
self.tile_latent_min_size = self.tile_sample_min_size // self.config.spatial_compression_ratio
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def _encode(self, x: torch.Tensor):
batch_size, num_channels, height, width = x.shape
if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size):
return self.tiled_encode(x)
enc = self.encoder(x)
return enc
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
r"""
Encode a batch of images into latents.
Args:
x (`torch.Tensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
The latent representations of the encoded videos. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self._encode(x)
posterior = DiagonalGaussianDistribution(h)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.Tensor, return_dict: bool = True):
batch_size, num_channels, height, width = z.shape
if self.use_tiling and (width > self.tile_latent_min_size or height > self.tile_latent_min_size):
return self.tiled_decode(z, return_dict=return_dict)
dec = self.decoder(z)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
@apply_forward_hook
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
r"""
Decode a batch of images.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z).sample
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
for y in range(blend_extent):
b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (
y / blend_extent
)
return b
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
for x in range(blend_extent):
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (
x / blend_extent
)
return b
def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
"""
Encode input using spatial tiling strategy.
Args:
x (`torch.Tensor`): Input tensor of shape (B, C, T, H, W).
Returns:
`torch.Tensor`:
The latent representation of the encoded images.
"""
_, _, _, height, width = x.shape
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
row_limit = self.tile_latent_min_size - blend_extent
rows = []
for i in range(0, height, overlap_size):
row = []
for j in range(0, width, overlap_size):
tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
tile = self.encoder(tile)
row.append(tile)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=-1))
moments = torch.cat(result_rows, dim=-2)
return moments
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
"""
Decode latent using spatial tiling strategy.
Args:
z (`torch.Tensor`): Latent tensor of shape (B, C, H, W).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
_, _, height, width = z.shape
overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
row_limit = self.tile_sample_min_size - blend_extent
rows = []
for i in range(0, height, overlap_size):
row = []
for j in range(0, width, overlap_size):
tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
decoded = self.decoder(tile)
row.append(decoded)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=-1))
dec = torch.cat(result_rows, dim=-2)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
def forward(
self,
sample: torch.Tensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, torch.Tensor]:
"""
Args:
sample (`torch.Tensor`): Input sample.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
"""
posterior = self.encode(sample).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z, return_dict=return_dict)
return dec

View File

@@ -0,0 +1,934 @@
# Copyright 2025 The Hunyuan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class HunyuanImageRefinerCausalConv3d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int, int]] = 3,
stride: Union[int, Tuple[int, int, int]] = 1,
padding: Union[int, Tuple[int, int, int]] = 0,
dilation: Union[int, Tuple[int, int, int]] = 1,
bias: bool = True,
pad_mode: str = "replicate",
) -> None:
super().__init__()
kernel_size = (kernel_size, kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
self.pad_mode = pad_mode
self.time_causal_padding = (
kernel_size[0] // 2,
kernel_size[0] // 2,
kernel_size[1] // 2,
kernel_size[1] // 2,
kernel_size[2] - 1,
0,
)
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = F.pad(hidden_states, self.time_causal_padding, mode=self.pad_mode)
return self.conv(hidden_states)
class HunyuanImageRefinerRMS_norm(nn.Module):
r"""
A custom RMS normalization layer.
Args:
dim (int): The number of dimensions to normalize over.
channel_first (bool, optional): Whether the input tensor has channels as the first dimension.
Default is True.
images (bool, optional): Whether the input represents image data. Default is True.
bias (bool, optional): Whether to include a learnable bias term. Default is False.
"""
def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None:
super().__init__()
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
self.channel_first = channel_first
self.scale = dim**0.5
self.gamma = nn.Parameter(torch.ones(shape))
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
def forward(self, x):
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
class HunyuanImageRefinerAttnBlock(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
self.in_channels = in_channels
self.norm = HunyuanImageRefinerRMS_norm(in_channels, images=False)
self.to_q = nn.Conv3d(in_channels, in_channels, kernel_size=1)
self.to_k = nn.Conv3d(in_channels, in_channels, kernel_size=1)
self.to_v = nn.Conv3d(in_channels, in_channels, kernel_size=1)
self.proj_out = nn.Conv3d(in_channels, in_channels, kernel_size=1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
identity = x
x = self.norm(x)
query = self.to_q(x)
key = self.to_k(x)
value = self.to_v(x)
batch_size, channels, frames, height, width = query.shape
query = query.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous()
key = key.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous()
value = value.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous()
x = nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=None)
# batch_size, 1, frames * height * width, channels
x = x.squeeze(1).reshape(batch_size, frames, height, width, channels).permute(0, 4, 1, 2, 3)
x = self.proj_out(x)
return x + identity
class HunyuanImageRefinerUpsampleDCAE(nn.Module):
def __init__(self, in_channels: int, out_channels: int, add_temporal_upsample: bool = True):
super().__init__()
factor = 2 * 2 * 2 if add_temporal_upsample else 1 * 2 * 2
self.conv = HunyuanImageRefinerCausalConv3d(in_channels, out_channels * factor, kernel_size=3)
self.add_temporal_upsample = add_temporal_upsample
self.repeats = factor * out_channels // in_channels
@staticmethod
def _dcae_upsample_rearrange(tensor, r1=1, r2=2, r3=2):
"""
Convert (b, r1*r2*r3*c, f, h, w) -> (b, c, r1*f, r2*h, r3*w)
Args:
tensor: Input tensor of shape (b, r1*r2*r3*c, f, h, w)
r1: temporal upsampling factor
r2: height upsampling factor
r3: width upsampling factor
"""
b, packed_c, f, h, w = tensor.shape
factor = r1 * r2 * r3
c = packed_c // factor
tensor = tensor.view(b, r1, r2, r3, c, f, h, w)
tensor = tensor.permute(0, 4, 5, 1, 6, 2, 7, 3)
return tensor.reshape(b, c, f * r1, h * r2, w * r3)
def forward(self, x: torch.Tensor):
r1 = 2 if self.add_temporal_upsample else 1
h = self.conv(x)
if self.add_temporal_upsample:
h = self._dcae_upsample_rearrange(h, r1=1, r2=2, r3=2)
h = h[:, : h.shape[1] // 2]
# shortcut computation
shortcut = self._dcae_upsample_rearrange(x, r1=1, r2=2, r3=2)
shortcut = shortcut.repeat_interleave(repeats=self.repeats // 2, dim=1)
else:
h = self._dcae_upsample_rearrange(h, r1=r1, r2=2, r3=2)
shortcut = x.repeat_interleave(repeats=self.repeats, dim=1)
shortcut = self._dcae_upsample_rearrange(shortcut, r1=r1, r2=2, r3=2)
return h + shortcut
class HunyuanImageRefinerDownsampleDCAE(nn.Module):
def __init__(self, in_channels: int, out_channels: int, add_temporal_downsample: bool = True):
super().__init__()
factor = 2 * 2 * 2 if add_temporal_downsample else 1 * 2 * 2
assert out_channels % factor == 0
# self.conv = Conv3d(in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1)
self.conv = HunyuanImageRefinerCausalConv3d(in_channels, out_channels // factor, kernel_size=3)
self.add_temporal_downsample = add_temporal_downsample
self.group_size = factor * in_channels // out_channels
@staticmethod
def _dcae_downsample_rearrange(tensor, r1=1, r2=2, r3=2):
"""
Convert (b, c, r1*f, r2*h, r3*w) -> (b, r1*r2*r3*c, f, h, w)
This packs spatial/temporal dimensions into channels (opposite of upsample)
"""
b, c, packed_f, packed_h, packed_w = tensor.shape
f, h, w = packed_f // r1, packed_h // r2, packed_w // r3
tensor = tensor.view(b, c, f, r1, h, r2, w, r3)
tensor = tensor.permute(0, 3, 5, 7, 1, 2, 4, 6)
return tensor.reshape(b, r1 * r2 * r3 * c, f, h, w)
def forward(self, x: torch.Tensor):
r1 = 2 if self.add_temporal_downsample else 1
h = self.conv(x)
if self.add_temporal_downsample:
# h = rearrange(h, "b c f (h r2) (w r3) -> b (r2 r3 c) f h w", r2=2, r3=2)
h = self._dcae_downsample_rearrange(h, r1=1, r2=2, r3=2)
h = torch.cat([h, h], dim=1)
# shortcut computation
# shortcut = rearrange(x, "b c f (h r2) (w r3) -> b (r2 r3 c) f h w", r2=2, r3=2)
shortcut = self._dcae_downsample_rearrange(x, r1=1, r2=2, r3=2)
B, C, T, H, W = shortcut.shape
shortcut = shortcut.view(B, h.shape[1], self.group_size // 2, T, H, W).mean(dim=2)
else:
# h = rearrange(h, "b c (f r1) (h r2) (w r3) -> b (r1 r2 r3 c) f h w", r1=r1, r2=2, r3=2)
h = self._dcae_downsample_rearrange(h, r1=r1, r2=2, r3=2)
# shortcut = rearrange(x, "b c (f r1) (h r2) (w r3) -> b (r1 r2 r3 c) f h w", r1=r1, r2=2, r3=2)
shortcut = self._dcae_downsample_rearrange(x, r1=r1, r2=2, r3=2)
B, C, T, H, W = shortcut.shape
shortcut = shortcut.view(B, h.shape[1], self.group_size, T, H, W).mean(dim=2)
return h + shortcut
class HunyuanImageRefinerResnetBlock(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: Optional[int] = None,
non_linearity: str = "swish",
) -> None:
super().__init__()
out_channels = out_channels or in_channels
self.nonlinearity = get_activation(non_linearity)
self.norm1 = HunyuanImageRefinerRMS_norm(in_channels, images=False)
self.conv1 = HunyuanImageRefinerCausalConv3d(in_channels, out_channels, kernel_size=3)
self.norm2 = HunyuanImageRefinerRMS_norm(out_channels, images=False)
self.conv2 = HunyuanImageRefinerCausalConv3d(out_channels, out_channels, kernel_size=3)
self.conv_shortcut = None
if in_channels != out_channels:
self.conv_shortcut = nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
residual = hidden_states
hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states)
hidden_states = self.norm2(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
residual = self.conv_shortcut(residual)
return hidden_states + residual
class HunyuanImageRefinerMidBlock(nn.Module):
def __init__(
self,
in_channels: int,
num_layers: int = 1,
add_attention: bool = True,
) -> None:
super().__init__()
self.add_attention = add_attention
# There is always at least one resnet
resnets = [
HunyuanImageRefinerResnetBlock(
in_channels=in_channels,
out_channels=in_channels,
)
]
attentions = []
for _ in range(num_layers):
if self.add_attention:
attentions.append(HunyuanImageRefinerAttnBlock(in_channels))
else:
attentions.append(None)
resnets.append(
HunyuanImageRefinerResnetBlock(
in_channels=in_channels,
out_channels=in_channels,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.resnets[0](hidden_states)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if attn is not None:
hidden_states = attn(hidden_states)
hidden_states = resnet(hidden_states)
return hidden_states
class HunyuanImageRefinerDownBlock3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_layers: int = 1,
downsample_out_channels: Optional[int] = None,
add_temporal_downsample: int = True,
) -> None:
super().__init__()
resnets = []
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
HunyuanImageRefinerResnetBlock(
in_channels=in_channels,
out_channels=out_channels,
)
)
self.resnets = nn.ModuleList(resnets)
if downsample_out_channels is not None:
self.downsamplers = nn.ModuleList(
[
HunyuanImageRefinerDownsampleDCAE(
out_channels,
out_channels=downsample_out_channels,
add_temporal_downsample=add_temporal_downsample,
)
]
)
else:
self.downsamplers = None
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
for resnet in self.resnets:
hidden_states = resnet(hidden_states)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
return hidden_states
class HunyuanImageRefinerUpBlock3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_layers: int = 1,
upsample_out_channels: Optional[int] = None,
add_temporal_upsample: bool = True,
) -> None:
super().__init__()
resnets = []
for i in range(num_layers):
input_channels = in_channels if i == 0 else out_channels
resnets.append(
HunyuanImageRefinerResnetBlock(
in_channels=input_channels,
out_channels=out_channels,
)
)
self.resnets = nn.ModuleList(resnets)
if upsample_out_channels is not None:
self.upsamplers = nn.ModuleList(
[
HunyuanImageRefinerUpsampleDCAE(
out_channels,
out_channels=upsample_out_channels,
add_temporal_upsample=add_temporal_upsample,
)
]
)
else:
self.upsamplers = None
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if torch.is_grad_enabled() and self.gradient_checkpointing:
for resnet in self.resnets:
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)
else:
for resnet in self.resnets:
hidden_states = resnet(hidden_states)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)
return hidden_states
class HunyuanImageRefinerEncoder3D(nn.Module):
r"""
3D vae encoder for HunyuanImageRefiner.
"""
def __init__(
self,
in_channels: int = 3,
out_channels: int = 64,
block_out_channels: Tuple[int, ...] = (128, 256, 512, 1024, 1024),
layers_per_block: int = 2,
temporal_compression_ratio: int = 4,
spatial_compression_ratio: int = 16,
downsample_match_channel: bool = True,
) -> None:
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.group_size = block_out_channels[-1] // self.out_channels
self.conv_in = HunyuanImageRefinerCausalConv3d(in_channels, block_out_channels[0], kernel_size=3)
self.mid_block = None
self.down_blocks = nn.ModuleList([])
input_channel = block_out_channels[0]
for i in range(len(block_out_channels)):
add_spatial_downsample = i < np.log2(spatial_compression_ratio)
output_channel = block_out_channels[i]
if not add_spatial_downsample:
down_block = HunyuanImageRefinerDownBlock3D(
num_layers=layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
downsample_out_channels=None,
add_temporal_downsample=False,
)
input_channel = output_channel
else:
add_temporal_downsample = i >= np.log2(spatial_compression_ratio // temporal_compression_ratio)
downsample_out_channels = block_out_channels[i + 1] if downsample_match_channel else output_channel
down_block = HunyuanImageRefinerDownBlock3D(
num_layers=layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
downsample_out_channels=downsample_out_channels,
add_temporal_downsample=add_temporal_downsample,
)
input_channel = downsample_out_channels
self.down_blocks.append(down_block)
self.mid_block = HunyuanImageRefinerMidBlock(in_channels=block_out_channels[-1])
self.norm_out = HunyuanImageRefinerRMS_norm(block_out_channels[-1], images=False)
self.conv_act = nn.SiLU()
self.conv_out = HunyuanImageRefinerCausalConv3d(block_out_channels[-1], out_channels, kernel_size=3)
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.conv_in(hidden_states)
if torch.is_grad_enabled() and self.gradient_checkpointing:
for down_block in self.down_blocks:
hidden_states = self._gradient_checkpointing_func(down_block, hidden_states)
hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
else:
for down_block in self.down_blocks:
hidden_states = down_block(hidden_states)
hidden_states = self.mid_block(hidden_states)
# short_cut = rearrange(hidden_states, "b (c r) f h w -> b c r f h w", r=self.group_size).mean(dim=2)
batch_size, _, frame, height, width = hidden_states.shape
short_cut = hidden_states.view(batch_size, -1, self.group_size, frame, height, width).mean(dim=2)
hidden_states = self.norm_out(hidden_states)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
hidden_states += short_cut
return hidden_states
class HunyuanImageRefinerDecoder3D(nn.Module):
r"""
Causal decoder for 3D video-like data used for HunyuanImage-2.1 Refiner.
"""
def __init__(
self,
in_channels: int = 32,
out_channels: int = 3,
block_out_channels: Tuple[int, ...] = (1024, 1024, 512, 256, 128),
layers_per_block: int = 2,
spatial_compression_ratio: int = 16,
temporal_compression_ratio: int = 4,
upsample_match_channel: bool = True,
):
super().__init__()
self.layers_per_block = layers_per_block
self.in_channels = in_channels
self.out_channels = out_channels
self.repeat = block_out_channels[0] // self.in_channels
self.conv_in = HunyuanImageRefinerCausalConv3d(self.in_channels, block_out_channels[0], kernel_size=3)
self.up_blocks = nn.ModuleList([])
# mid
self.mid_block = HunyuanImageRefinerMidBlock(in_channels=block_out_channels[0])
# up
input_channel = block_out_channels[0]
for i in range(len(block_out_channels)):
output_channel = block_out_channels[i]
add_spatial_upsample = i < np.log2(spatial_compression_ratio)
add_temporal_upsample = i < np.log2(temporal_compression_ratio)
if add_spatial_upsample or add_temporal_upsample:
upsample_out_channels = block_out_channels[i + 1] if upsample_match_channel else output_channel
up_block = HunyuanImageRefinerUpBlock3D(
num_layers=self.layers_per_block + 1,
in_channels=input_channel,
out_channels=output_channel,
upsample_out_channels=upsample_out_channels,
add_temporal_upsample=add_temporal_upsample,
)
input_channel = upsample_out_channels
else:
up_block = HunyuanImageRefinerUpBlock3D(
num_layers=self.layers_per_block + 1,
in_channels=input_channel,
out_channels=output_channel,
upsample_out_channels=None,
add_temporal_upsample=False,
)
input_channel = output_channel
self.up_blocks.append(up_block)
# out
self.norm_out = HunyuanImageRefinerRMS_norm(block_out_channels[-1], images=False)
self.conv_act = nn.SiLU()
self.conv_out = HunyuanImageRefinerCausalConv3d(block_out_channels[-1], out_channels, kernel_size=3)
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.conv_in(hidden_states) + hidden_states.repeat_interleave(repeats=self.repeat, dim=1)
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
for up_block in self.up_blocks:
hidden_states = self._gradient_checkpointing_func(up_block, hidden_states)
else:
hidden_states = self.mid_block(hidden_states)
for up_block in self.up_blocks:
hidden_states = up_block(hidden_states)
# post-process
hidden_states = self.norm_out(hidden_states)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
return hidden_states
class AutoencoderKLHunyuanImageRefiner(ModelMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used for
HunyuanImage-2.1 Refiner.
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
for all models (such as downloading or saving).
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
latent_channels: int = 32,
block_out_channels: Tuple[int] = (128, 256, 512, 1024, 1024),
layers_per_block: int = 2,
spatial_compression_ratio: int = 16,
temporal_compression_ratio: int = 4,
downsample_match_channel: bool = True,
upsample_match_channel: bool = True,
scaling_factor: float = 1.03682,
) -> None:
super().__init__()
self.encoder = HunyuanImageRefinerEncoder3D(
in_channels=in_channels,
out_channels=latent_channels * 2,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
temporal_compression_ratio=temporal_compression_ratio,
spatial_compression_ratio=spatial_compression_ratio,
downsample_match_channel=downsample_match_channel,
)
self.decoder = HunyuanImageRefinerDecoder3D(
in_channels=latent_channels,
out_channels=out_channels,
block_out_channels=list(reversed(block_out_channels)),
layers_per_block=layers_per_block,
temporal_compression_ratio=temporal_compression_ratio,
spatial_compression_ratio=spatial_compression_ratio,
upsample_match_channel=upsample_match_channel,
)
self.spatial_compression_ratio = spatial_compression_ratio
self.temporal_compression_ratio = temporal_compression_ratio
# When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
# to perform decoding of a single video latent at a time.
self.use_slicing = False
# When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
# frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
# intermediate tiles together, the memory requirement can be lowered.
self.use_tiling = False
# The minimal tile height and width for spatial tiling to be used
self.tile_sample_min_height = 256
self.tile_sample_min_width = 256
# The minimal distance between two spatial tiles
self.tile_sample_stride_height = 192
self.tile_sample_stride_width = 192
self.tile_overlap_factor = 0.25
def enable_tiling(
self,
tile_sample_min_height: Optional[int] = None,
tile_sample_min_width: Optional[int] = None,
tile_sample_stride_height: Optional[float] = None,
tile_sample_stride_width: Optional[float] = None,
tile_overlap_factor: Optional[float] = None,
) -> None:
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
Args:
tile_sample_min_height (`int`, *optional*):
The minimum height required for a sample to be separated into tiles across the height dimension.
tile_sample_min_width (`int`, *optional*):
The minimum width required for a sample to be separated into tiles across the width dimension.
tile_sample_stride_height (`int`, *optional*):
The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
no tiling artifacts produced across the height dimension.
tile_sample_stride_width (`int`, *optional*):
The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
artifacts produced across the width dimension.
"""
self.use_tiling = True
self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def _encode(self, x: torch.Tensor) -> torch.Tensor:
_, _, _, height, width = x.shape
if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
return self.tiled_encode(x)
x = self.encoder(x)
return x
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
r"""
Encode a batch of images into latents.
Args:
x (`torch.Tensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
The latent representations of the encoded videos. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self._encode(x)
posterior = DiagonalGaussianDistribution(h)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.Tensor) -> torch.Tensor:
_, _, _, height, width = z.shape
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
return self.tiled_decode(z)
dec = self.decoder(z)
return dec
@apply_forward_hook
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
r"""
Decode a batch of images.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z)
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
for y in range(blend_extent):
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
y / blend_extent
)
return b
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
for x in range(blend_extent):
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
x / blend_extent
)
return b
def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
for x in range(blend_extent):
b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (
x / blend_extent
)
return b
def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
r"""Encode a batch of images using a tiled encoder.
Args:
x (`torch.Tensor`): Input batch of videos.
Returns:
`torch.Tensor`:
The latent representation of the encoded videos.
"""
_, _, _, height, width = x.shape
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
overlap_height = int(tile_latent_min_height * (1 - self.tile_overlap_factor)) # 256 * (1 - 0.25) = 192
overlap_width = int(tile_latent_min_width * (1 - self.tile_overlap_factor)) # 256 * (1 - 0.25) = 192
blend_height = int(tile_latent_min_height * self.tile_overlap_factor) # 8 * 0.25 = 2
blend_width = int(tile_latent_min_width * self.tile_overlap_factor) # 8 * 0.25 = 2
row_limit_height = tile_latent_min_height - blend_height # 8 - 2 = 6
row_limit_width = tile_latent_min_width - blend_width # 8 - 2 = 6
rows = []
for i in range(0, height, overlap_height):
row = []
for j in range(0, width, overlap_width):
tile = x[
:,
:,
:,
i : i + self.tile_sample_min_height,
j : j + self.tile_sample_min_width,
]
tile = self.encoder(tile)
row.append(tile)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_width)
result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
result_rows.append(torch.cat(result_row, dim=-1))
moments = torch.cat(result_rows, dim=-2)
return moments
def tiled_decode(self, z: torch.Tensor) -> torch.Tensor:
r"""
Decode a batch of images using a tiled decoder.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
_, _, _, height, width = z.shape
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
overlap_height = int(tile_latent_min_height * (1 - self.tile_overlap_factor)) # 8 * (1 - 0.25) = 6
overlap_width = int(tile_latent_min_width * (1 - self.tile_overlap_factor)) # 8 * (1 - 0.25) = 6
blend_height = int(tile_latent_min_height * self.tile_overlap_factor) # 256 * 0.25 = 64
blend_width = int(tile_latent_min_width * self.tile_overlap_factor) # 256 * 0.25 = 64
row_limit_height = tile_latent_min_height - blend_height # 256 - 64 = 192
row_limit_width = tile_latent_min_width - blend_width # 256 - 64 = 192
rows = []
for i in range(0, height, overlap_height):
row = []
for j in range(0, width, overlap_width):
tile = z[
:,
:,
:,
i : i + tile_latent_min_height,
j : j + tile_latent_min_width,
]
decoded = self.decoder(tile)
row.append(decoded)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_width)
result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
result_rows.append(torch.cat(result_row, dim=-1))
dec = torch.cat(result_rows, dim=-2)
return dec
def forward(
self,
sample: torch.Tensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, torch.Tensor]:
r"""
Args:
sample (`torch.Tensor`): Input sample.
sample_posterior (`bool`, *optional*, defaults to `False`):
Whether to sample from the posterior.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
"""
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z, return_dict=return_dict)
return dec

View File

@@ -27,6 +27,7 @@ if is_torch_available():
from .transformer_hidream_image import HiDreamImageTransformer2DModel
from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel
from .transformer_hunyuanimage import HunyuanImageTransformer2DModel
from .transformer_kandinsky import Kandinsky5Transformer3DModel
from .transformer_ltx import LTXVideoTransformer3DModel
from .transformer_lumina2 import Lumina2Transformer2DModel

View File

@@ -0,0 +1,971 @@
# Copyright 2025 The Hunyuan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.loaders import FromOriginalModelMixin
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..attention_processor import Attention, AttentionProcessor
from ..cache_utils import CacheMixin
from ..embeddings import (
CombinedTimestepTextProjEmbeddings,
TimestepEmbedding,
Timesteps,
get_1d_rotary_pos_embed,
)
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class HunyuanImageAttnProcessor:
_attention_backend = None
_parallel_config = None
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"HunyuanImageAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
)
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if attn.add_q_proj is None and encoder_hidden_states is not None:
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
# 1. QKV projections
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
query = query.unflatten(2, (attn.heads, -1)) # batch_size, seq_len, heads, head_dim
key = key.unflatten(2, (attn.heads, -1))
value = value.unflatten(2, (attn.heads, -1))
# 2. QK normalization
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# 3. Rotational positional embeddings applied to latent stream
if image_rotary_emb is not None:
from ..embeddings import apply_rotary_emb
if attn.add_q_proj is None and encoder_hidden_states is not None:
query = torch.cat(
[
apply_rotary_emb(
query[:, : -encoder_hidden_states.shape[1]], image_rotary_emb, sequence_dim=1
),
query[:, -encoder_hidden_states.shape[1] :],
],
dim=1,
)
key = torch.cat(
[
apply_rotary_emb(key[:, : -encoder_hidden_states.shape[1]], image_rotary_emb, sequence_dim=1),
key[:, -encoder_hidden_states.shape[1] :],
],
dim=1,
)
else:
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
# 4. Encoder condition QKV projection and normalization
if attn.add_q_proj is not None and encoder_hidden_states is not None:
encoder_query = attn.add_q_proj(encoder_hidden_states)
encoder_key = attn.add_k_proj(encoder_hidden_states)
encoder_value = attn.add_v_proj(encoder_hidden_states)
encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
if attn.norm_added_q is not None:
encoder_query = attn.norm_added_q(encoder_query)
if attn.norm_added_k is not None:
encoder_key = attn.norm_added_k(encoder_key)
query = torch.cat([query, encoder_query], dim=1)
key = torch.cat([key, encoder_key], dim=1)
value = torch.cat([value, encoder_value], dim=1)
# 5. Attention
hidden_states = dispatch_attention_fn(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
# 6. Output projection
if encoder_hidden_states is not None:
hidden_states, encoder_hidden_states = (
hidden_states[:, : -encoder_hidden_states.shape[1]],
hidden_states[:, -encoder_hidden_states.shape[1] :],
)
if getattr(attn, "to_out", None) is not None:
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
if getattr(attn, "to_add_out", None) is not None:
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
class HunyuanImagePatchEmbed(nn.Module):
def __init__(
self,
patch_size: Union[Tuple[int, int], Tuple[int, int, int]] = (16, 16),
in_chans: int = 3,
embed_dim: int = 768,
) -> None:
super().__init__()
self.patch_size = patch_size
if len(patch_size) == 2:
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
elif len(patch_size) == 3:
self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
else:
raise ValueError(f"patch_size must be a tuple of length 2 or 3, got {len(patch_size)}")
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.proj(hidden_states)
hidden_states = hidden_states.flatten(2).transpose(1, 2)
return hidden_states
class HunyuanImageByT5TextProjection(nn.Module):
def __init__(self, in_features: int, hidden_size: int, out_features: int):
super().__init__()
self.norm = nn.LayerNorm(in_features)
self.linear_1 = nn.Linear(in_features, hidden_size)
self.linear_2 = nn.Linear(hidden_size, hidden_size)
self.linear_3 = nn.Linear(hidden_size, out_features)
self.act_fn = nn.GELU()
def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.norm(encoder_hidden_states)
hidden_states = self.linear_1(hidden_states)
hidden_states = self.act_fn(hidden_states)
hidden_states = self.linear_2(hidden_states)
hidden_states = self.act_fn(hidden_states)
hidden_states = self.linear_3(hidden_states)
return hidden_states
class HunyuanImageAdaNorm(nn.Module):
def __init__(self, in_features: int, out_features: Optional[int] = None) -> None:
super().__init__()
out_features = out_features or 2 * in_features
self.linear = nn.Linear(in_features, out_features)
self.nonlinearity = nn.SiLU()
def forward(
self, temb: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
temb = self.linear(self.nonlinearity(temb))
gate_msa, gate_mlp = temb.chunk(2, dim=1)
gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1)
return gate_msa, gate_mlp
class HunyuanImageCombinedTimeGuidanceEmbedding(nn.Module):
def __init__(
self,
embedding_dim: int,
guidance_embeds: bool = False,
use_meanflow: bool = False,
):
super().__init__()
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
self.use_meanflow = use_meanflow
self.time_proj_r = None
self.timestep_embedder_r = None
if use_meanflow:
self.time_proj_r = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder_r = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
self.guidance_embedder = None
if guidance_embeds:
self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
def forward(
self,
timestep: torch.Tensor,
timestep_r: Optional[torch.Tensor] = None,
guidance: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=timestep.dtype))
if timestep_r is not None:
timesteps_proj_r = self.time_proj_r(timestep_r)
timesteps_emb_r = self.timestep_embedder_r(timesteps_proj_r.to(dtype=timestep.dtype))
timesteps_emb = (timesteps_emb + timesteps_emb_r) / 2
if self.guidance_embedder is not None:
guidance_proj = self.time_proj(guidance)
guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=timestep.dtype))
conditioning = timesteps_emb + guidance_emb
else:
conditioning = timesteps_emb
return conditioning
# IndividualTokenRefinerBlock
@maybe_allow_in_graph
class HunyuanImageIndividualTokenRefinerBlock(nn.Module):
def __init__(
self,
num_attention_heads: int, # 28
attention_head_dim: int, # 128
mlp_width_ratio: str = 4.0,
mlp_drop_rate: float = 0.0,
attention_bias: bool = True,
) -> None:
super().__init__()
hidden_size = num_attention_heads * attention_head_dim
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
self.attn = Attention(
query_dim=hidden_size,
cross_attention_dim=None,
heads=num_attention_heads,
dim_head=attention_head_dim,
bias=attention_bias,
)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="linear-silu", dropout=mlp_drop_rate)
self.norm_out = HunyuanImageAdaNorm(hidden_size, 2 * hidden_size)
def forward(
self,
hidden_states: torch.Tensor,
temb: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
norm_hidden_states = self.norm1(hidden_states)
attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=None,
attention_mask=attention_mask,
)
gate_msa, gate_mlp = self.norm_out(temb)
hidden_states = hidden_states + attn_output * gate_msa
ff_output = self.ff(self.norm2(hidden_states))
hidden_states = hidden_states + ff_output * gate_mlp
return hidden_states
class HunyuanImageIndividualTokenRefiner(nn.Module):
def __init__(
self,
num_attention_heads: int,
attention_head_dim: int,
num_layers: int,
mlp_width_ratio: float = 4.0,
mlp_drop_rate: float = 0.0,
attention_bias: bool = True,
) -> None:
super().__init__()
self.refiner_blocks = nn.ModuleList(
[
HunyuanImageIndividualTokenRefinerBlock(
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
mlp_width_ratio=mlp_width_ratio,
mlp_drop_rate=mlp_drop_rate,
attention_bias=attention_bias,
)
for _ in range(num_layers)
]
)
def forward(
self,
hidden_states: torch.Tensor,
temb: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> None:
self_attn_mask = None
if attention_mask is not None:
batch_size = attention_mask.shape[0]
seq_len = attention_mask.shape[1]
attention_mask = attention_mask.to(hidden_states.device)
self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
self_attn_mask[:, :, :, 0] = True
for block in self.refiner_blocks:
hidden_states = block(hidden_states, temb, self_attn_mask)
return hidden_states
# txt_in
class HunyuanImageTokenRefiner(nn.Module):
def __init__(
self,
in_channels: int,
num_attention_heads: int,
attention_head_dim: int,
num_layers: int,
mlp_ratio: float = 4.0,
mlp_drop_rate: float = 0.0,
attention_bias: bool = True,
) -> None:
super().__init__()
hidden_size = num_attention_heads * attention_head_dim
self.time_text_embed = CombinedTimestepTextProjEmbeddings(
embedding_dim=hidden_size, pooled_projection_dim=in_channels
)
self.proj_in = nn.Linear(in_channels, hidden_size, bias=True)
self.token_refiner = HunyuanImageIndividualTokenRefiner(
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
num_layers=num_layers,
mlp_width_ratio=mlp_ratio,
mlp_drop_rate=mlp_drop_rate,
attention_bias=attention_bias,
)
def forward(
self,
hidden_states: torch.Tensor,
timestep: torch.LongTensor,
attention_mask: Optional[torch.LongTensor] = None,
) -> torch.Tensor:
if attention_mask is None:
pooled_hidden_states = hidden_states.mean(dim=1)
else:
original_dtype = hidden_states.dtype
mask_float = attention_mask.float().unsqueeze(-1)
pooled_hidden_states = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1)
pooled_hidden_states = pooled_hidden_states.to(original_dtype)
temb = self.time_text_embed(timestep, pooled_hidden_states)
hidden_states = self.proj_in(hidden_states)
hidden_states = self.token_refiner(hidden_states, temb, attention_mask)
return hidden_states
class HunyuanImageRotaryPosEmbed(nn.Module):
def __init__(
self, patch_size: Union[Tuple, List[int]], rope_dim: Union[Tuple, List[int]], theta: float = 256.0
) -> None:
super().__init__()
if not isinstance(patch_size, (tuple, list)) or len(patch_size) not in [2, 3]:
raise ValueError(f"patch_size must be a tuple or list of length 2 or 3, got {patch_size}")
if not isinstance(rope_dim, (tuple, list)) or len(rope_dim) not in [2, 3]:
raise ValueError(f"rope_dim must be a tuple or list of length 2 or 3, got {rope_dim}")
if not len(patch_size) == len(rope_dim):
raise ValueError(f"patch_size and rope_dim must have the same length, got {patch_size} and {rope_dim}")
self.patch_size = patch_size
self.rope_dim = rope_dim
self.theta = theta
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if hidden_states.ndim == 5:
_, _, frame, height, width = hidden_states.shape
patch_size_frame, patch_size_height, patch_size_width = self.patch_size
rope_sizes = [frame // patch_size_frame, height // patch_size_height, width // patch_size_width]
elif hidden_states.ndim == 4:
_, _, height, width = hidden_states.shape
patch_size_height, patch_size_width = self.patch_size
rope_sizes = [height // patch_size_height, width // patch_size_width]
else:
raise ValueError(f"hidden_states must be a 4D or 5D tensor, got {hidden_states.shape}")
axes_grids = []
for i in range(len(rope_sizes)):
grid = torch.arange(0, rope_sizes[i], device=hidden_states.device, dtype=torch.float32)
axes_grids.append(grid)
grid = torch.meshgrid(*axes_grids, indexing="ij") # dim x [H, W]
grid = torch.stack(grid, dim=0) # [2, H, W]
freqs = []
for i in range(len(rope_sizes)):
freq = get_1d_rotary_pos_embed(self.rope_dim[i], grid[i].reshape(-1), self.theta, use_real=True)
freqs.append(freq)
freqs_cos = torch.cat([f[0] for f in freqs], dim=1) # (W * H * T, D / 2)
freqs_sin = torch.cat([f[1] for f in freqs], dim=1) # (W * H * T, D / 2)
return freqs_cos, freqs_sin
@maybe_allow_in_graph
class HunyuanImageSingleTransformerBlock(nn.Module):
def __init__(
self,
num_attention_heads: int,
attention_head_dim: int,
mlp_ratio: float = 4.0,
qk_norm: str = "rms_norm",
) -> None:
super().__init__()
hidden_size = num_attention_heads * attention_head_dim
mlp_dim = int(hidden_size * mlp_ratio)
self.attn = Attention(
query_dim=hidden_size,
cross_attention_dim=None,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=hidden_size,
bias=True,
processor=HunyuanImageAttnProcessor(),
qk_norm=qk_norm,
eps=1e-6,
pre_only=True,
)
self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm")
self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
self.act_mlp = nn.GELU(approximate="tanh")
self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
*args,
**kwargs,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.shape[1]
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
residual = hidden_states
# 1. Input normalization
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
norm_hidden_states, norm_encoder_hidden_states = (
norm_hidden_states[:, :-text_seq_length, :],
norm_hidden_states[:, -text_seq_length:, :],
)
# 2. Attention
attn_output, context_attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
)
attn_output = torch.cat([attn_output, context_attn_output], dim=1)
# 3. Modulation and residual connection
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
hidden_states = gate.unsqueeze(1) * self.proj_out(hidden_states)
hidden_states = hidden_states + residual
hidden_states, encoder_hidden_states = (
hidden_states[:, :-text_seq_length, :],
hidden_states[:, -text_seq_length:, :],
)
return hidden_states, encoder_hidden_states
@maybe_allow_in_graph
class HunyuanImageTransformerBlock(nn.Module):
def __init__(
self,
num_attention_heads: int,
attention_head_dim: int,
mlp_ratio: float,
qk_norm: str = "rms_norm",
) -> None:
super().__init__()
hidden_size = num_attention_heads * attention_head_dim
self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
self.attn = Attention(
query_dim=hidden_size,
cross_attention_dim=None,
added_kv_proj_dim=hidden_size,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=hidden_size,
context_pre_only=False,
bias=True,
processor=HunyuanImageAttnProcessor(),
qk_norm=qk_norm,
eps=1e-6,
)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
*args,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
# 1. Input normalization
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
encoder_hidden_states, emb=temb
)
# 2. Joint attention
attn_output, context_attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
)
# 3. Modulation and residual connection
hidden_states = hidden_states + attn_output * gate_msa.unsqueeze(1)
encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1)
norm_hidden_states = self.norm2(hidden_states)
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
# 4. Feed-forward
ff_output = self.ff(norm_hidden_states)
context_ff_output = self.ff_context(norm_encoder_hidden_states)
hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
return hidden_states, encoder_hidden_states
class HunyuanImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
r"""
The Transformer model used in [HunyuanImage-2.1](https://github.com/Tencent-Hunyuan/HunyuanImage-2.1).
Args:
in_channels (`int`, defaults to `16`):
The number of channels in the input.
out_channels (`int`, defaults to `16`):
The number of channels in the output.
num_attention_heads (`int`, defaults to `24`):
The number of heads to use for multi-head attention.
attention_head_dim (`int`, defaults to `128`):
The number of channels in each head.
num_layers (`int`, defaults to `20`):
The number of layers of dual-stream blocks to use.
num_single_layers (`int`, defaults to `40`):
The number of layers of single-stream blocks to use.
num_refiner_layers (`int`, defaults to `2`):
The number of layers of refiner blocks to use.
mlp_ratio (`float`, defaults to `4.0`):
The ratio of the hidden layer size to the input size in the feedforward network.
patch_size (`int`, defaults to `2`):
The size of the spatial patches to use in the patch embedding layer.
patch_size_t (`int`, defaults to `1`):
The size of the tmeporal patches to use in the patch embedding layer.
qk_norm (`str`, defaults to `rms_norm`):
The normalization to use for the query and key projections in the attention layers.
guidance_embeds (`bool`, defaults to `True`):
Whether to use guidance embeddings in the model.
text_embed_dim (`int`, defaults to `4096`):
Input dimension of text embeddings from the text encoder.
pooled_projection_dim (`int`, defaults to `768`):
The dimension of the pooled projection of the text embeddings.
rope_theta (`float`, defaults to `256.0`):
The value of theta to use in the RoPE layer.
rope_axes_dim (`Tuple[int]`, defaults to `(16, 56, 56)`):
The dimensions of the axes to use in the RoPE layer.
image_condition_type (`str`, *optional*, defaults to `None`):
The type of image conditioning to use. If `None`, no image conditioning is used. If `latent_concat`, the
image is concatenated to the latent stream. If `token_replace`, the image is used to replace first-frame
tokens in the latent stream and apply conditioning.
"""
_supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = ["x_embedder", "context_embedder", "norm"]
_no_split_modules = [
"HunyuanImageTransformerBlock",
"HunyuanImageSingleTransformerBlock",
"HunyuanImagePatchEmbed",
"HunyuanImageTokenRefiner",
]
_repeated_blocks = [
"HunyuanImageTransformerBlock",
"HunyuanImageSingleTransformerBlock",
]
@register_to_config
def __init__(
self,
in_channels: int = 64,
out_channels: int = 64,
num_attention_heads: int = 28,
attention_head_dim: int = 128,
num_layers: int = 20,
num_single_layers: int = 40,
num_refiner_layers: int = 2,
mlp_ratio: float = 4.0,
patch_size: Tuple[int, int] = (1, 1),
qk_norm: str = "rms_norm",
guidance_embeds: bool = False,
text_embed_dim: int = 3584,
text_embed_2_dim: Optional[int] = None,
rope_theta: float = 256.0,
rope_axes_dim: Tuple[int] = (64, 64),
use_meanflow: bool = False,
) -> None:
super().__init__()
if not (isinstance(patch_size, (tuple, list)) and len(patch_size) in [2, 3]):
raise ValueError(f"patch_size must be a tuple of length 2 or 3, got {patch_size}")
inner_dim = num_attention_heads * attention_head_dim
out_channels = out_channels or in_channels
# 1. Latent and condition embedders
self.x_embedder = HunyuanImagePatchEmbed(patch_size, in_channels, inner_dim)
self.context_embedder = HunyuanImageTokenRefiner(
text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
)
if text_embed_2_dim is not None:
self.context_embedder_2 = HunyuanImageByT5TextProjection(text_embed_2_dim, 2048, inner_dim)
else:
self.context_embedder_2 = None
self.time_guidance_embed = HunyuanImageCombinedTimeGuidanceEmbedding(inner_dim, guidance_embeds, use_meanflow)
# 2. RoPE
self.rope = HunyuanImageRotaryPosEmbed(patch_size, rope_axes_dim, rope_theta)
# 3. Dual stream transformer blocks
self.transformer_blocks = nn.ModuleList(
[
HunyuanImageTransformerBlock(
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
)
for _ in range(num_layers)
]
)
# 4. Single stream transformer blocks
self.single_transformer_blocks = nn.ModuleList(
[
HunyuanImageSingleTransformerBlock(
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
)
for _ in range(num_single_layers)
]
)
# 5. Output projection
self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out = nn.Linear(inner_dim, math.prod(patch_size) * out_channels)
self.gradient_checkpointing = False
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def forward(
self,
hidden_states: torch.Tensor,
timestep: torch.LongTensor,
encoder_hidden_states: torch.Tensor,
encoder_attention_mask: torch.Tensor,
timestep_r: Optional[torch.LongTensor] = None,
encoder_hidden_states_2: Optional[torch.Tensor] = None,
encoder_attention_mask_2: Optional[torch.Tensor] = None,
guidance: Optional[torch.Tensor] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
if hidden_states.ndim == 4:
batch_size, channels, height, width = hidden_states.shape
sizes = (height, width)
elif hidden_states.ndim == 5:
batch_size, channels, frame, height, width = hidden_states.shape
sizes = (frame, height, width)
else:
raise ValueError(f"hidden_states must be a 4D or 5D tensor, got {hidden_states.shape}")
post_patch_sizes = tuple(d // p for d, p in zip(sizes, self.config.patch_size))
# 1. RoPE
image_rotary_emb = self.rope(hidden_states)
# 2. Conditional embeddings
encoder_attention_mask = encoder_attention_mask.bool()
temb = self.time_guidance_embed(timestep, guidance=guidance, timestep_r=timestep_r)
hidden_states = self.x_embedder(hidden_states)
encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask)
if self.context_embedder_2 is not None and encoder_hidden_states_2 is not None:
encoder_hidden_states_2 = self.context_embedder_2(encoder_hidden_states_2)
encoder_attention_mask_2 = encoder_attention_mask_2.bool()
# reorder and combine text tokens: combine valid tokens first, then padding
new_encoder_hidden_states = []
new_encoder_attention_mask = []
for text, text_mask, text_2, text_mask_2 in zip(
encoder_hidden_states, encoder_attention_mask, encoder_hidden_states_2, encoder_attention_mask_2
):
# Concatenate: [valid_mllm, valid_byt5, invalid_mllm, invalid_byt5]
new_encoder_hidden_states.append(
torch.cat(
[
text_2[text_mask_2], # valid byt5
text[text_mask], # valid mllm
text_2[~text_mask_2], # invalid byt5
text[~text_mask], # invalid mllm
],
dim=0,
)
)
# Apply same reordering to attention masks
new_encoder_attention_mask.append(
torch.cat(
[
text_mask_2[text_mask_2],
text_mask[text_mask],
text_mask_2[~text_mask_2],
text_mask[~text_mask],
],
dim=0,
)
)
encoder_hidden_states = torch.stack(new_encoder_hidden_states)
encoder_attention_mask = torch.stack(new_encoder_attention_mask)
attention_mask = torch.nn.functional.pad(encoder_attention_mask, (hidden_states.shape[1], 0), value=True)
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
# 3. Transformer blocks
if torch.is_grad_enabled() and self.gradient_checkpointing:
for block in self.transformer_blocks:
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
temb,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
)
for block in self.single_transformer_blocks:
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
temb,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
)
else:
for block in self.transformer_blocks:
hidden_states, encoder_hidden_states = block(
hidden_states,
encoder_hidden_states,
temb,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
)
for block in self.single_transformer_blocks:
hidden_states, encoder_hidden_states = block(
hidden_states,
encoder_hidden_states,
temb,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
)
# 4. Output projection
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)
# 5. unpatchify
# reshape: [batch_size, *post_patch_dims, channels, *patch_size]
out_channels = self.config.out_channels
reshape_dims = [batch_size] + list(post_patch_sizes) + [out_channels] + list(self.config.patch_size)
hidden_states = hidden_states.reshape(*reshape_dims)
# create permutation pattern: batch, channels, then interleave post_patch and patch dims
# For 4D: [0, 3, 1, 4, 2, 5] -> batch, channels, post_patch_height, patch_size_height, post_patch_width, patch_size_width
# For 5D: [0, 4, 1, 5, 2, 6, 3, 7] -> batch, channels, post_patch_frame, patch_size_frame, post_patch_height, patch_size_height, post_patch_width, patch_size_width
ndim = len(post_patch_sizes)
permute_pattern = [0, ndim + 1] # batch, channels
for i in range(ndim):
permute_pattern.extend([i + 1, ndim + 2 + i]) # post_patch_sizes[i], patch_sizes[i]
hidden_states = hidden_states.permute(*permute_pattern)
# flatten patch dimensions: flatten each (post_patch_size, patch_size) pair
# batch_size, channels, post_patch_sizes[0] * patch_sizes[0], post_patch_sizes[1] * patch_sizes[1], ...
final_dims = [batch_size, out_channels] + [
post_patch * patch for post_patch, patch in zip(post_patch_sizes, self.config.patch_size)
]
hidden_states = hidden_states.reshape(*final_dims)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (hidden_states,)
return Transformer2DModelOutput(sample=hidden_states)

View File

@@ -130,8 +130,14 @@ class PipelineState:
Allow attribute access to intermediate values. If an attribute is not found in the object, look for it in the
intermediates dict.
"""
if name in self.values:
return self.values[name]
# Use object.__getattribute__ to avoid infinite recursion during deepcopy
try:
values = object.__getattribute__(self, "values")
except AttributeError:
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
if name in values:
return values[name]
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
def __repr__(self):
@@ -2492,6 +2498,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
"""
if state is None:
state = PipelineState()
else:
state = deepcopy(state)
# Make a copy of the input kwargs
passed_kwargs = kwargs.copy()

View File

@@ -238,19 +238,27 @@ class QwenImageLoopDenoiser(ModularPipelineBlocks):
@torch.no_grad()
def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
guider_input_fields = {
"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
"encoder_hidden_states_mask": ("prompt_embeds_mask", "negative_prompt_embeds_mask"),
"txt_seq_lens": ("txt_seq_lens", "negative_txt_seq_lens"),
guider_inputs = {
"encoder_hidden_states": (
getattr(block_state, "prompt_embeds", None),
getattr(block_state, "negative_prompt_embeds", None),
),
"encoder_hidden_states_mask": (
getattr(block_state, "prompt_embeds_mask", None),
getattr(block_state, "negative_prompt_embeds_mask", None),
),
"txt_seq_lens": (
getattr(block_state, "txt_seq_lens", None),
getattr(block_state, "negative_txt_seq_lens", None),
),
}
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
guider_state = components.guider.prepare_inputs(guider_inputs)
for guider_state_batch in guider_state:
components.guider.prepare_models(components.transformer)
cond_kwargs = guider_state_batch.as_dict()
cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields}
cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()}
# YiYi TODO: add cache context
guider_state_batch.noise_pred = components.transformer(
@@ -328,19 +336,27 @@ class QwenImageEditLoopDenoiser(ModularPipelineBlocks):
@torch.no_grad()
def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
guider_input_fields = {
"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
"encoder_hidden_states_mask": ("prompt_embeds_mask", "negative_prompt_embeds_mask"),
"txt_seq_lens": ("txt_seq_lens", "negative_txt_seq_lens"),
guider_inputs = {
"encoder_hidden_states": (
getattr(block_state, "prompt_embeds", None),
getattr(block_state, "negative_prompt_embeds", None),
),
"encoder_hidden_states_mask": (
getattr(block_state, "prompt_embeds_mask", None),
getattr(block_state, "negative_prompt_embeds_mask", None),
),
"txt_seq_lens": (
getattr(block_state, "txt_seq_lens", None),
getattr(block_state, "negative_txt_seq_lens", None),
),
}
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
guider_state = components.guider.prepare_inputs(guider_inputs)
for guider_state_batch in guider_state:
components.guider.prepare_models(components.transformer)
cond_kwargs = guider_state_batch.as_dict()
cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields}
cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()}
# YiYi TODO: add cache context
guider_state_batch.noise_pred = components.transformer(

View File

@@ -201,27 +201,41 @@ class StableDiffusionXLLoopDenoiser(ModularPipelineBlocks):
) -> PipelineState:
# Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds)
# to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds)
guider_input_fields = {
"prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"),
"time_ids": ("add_time_ids", "negative_add_time_ids"),
"text_embeds": ("pooled_prompt_embeds", "negative_pooled_prompt_embeds"),
"image_embeds": ("ip_adapter_embeds", "negative_ip_adapter_embeds"),
guider_inputs = {
"prompt_embeds": (
getattr(block_state, "prompt_embeds", None),
getattr(block_state, "negative_prompt_embeds", None),
),
"time_ids": (
getattr(block_state, "add_time_ids", None),
getattr(block_state, "negative_add_time_ids", None),
),
"text_embeds": (
getattr(block_state, "pooled_prompt_embeds", None),
getattr(block_state, "negative_pooled_prompt_embeds", None),
),
"image_embeds": (
getattr(block_state, "ip_adapter_embeds", None),
getattr(block_state, "negative_ip_adapter_embeds", None),
),
}
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
# Prepare minibatches according to guidance method and `guider_input_fields`
# Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds.
# e.g. for CFG, we prepare two batches: one for uncond, one for cond
# for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds
# for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds
guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
# The guider splits model inputs into separate batches for conditional/unconditional predictions.
# For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
# you will get a guider_state with two batches:
# guider_state = [
# {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch
# {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch
# ]
# Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
guider_state = components.guider.prepare_inputs(guider_inputs)
# run the denoiser for each guidance batch
for guider_state_batch in guider_state:
components.guider.prepare_models(components.unet)
cond_kwargs = guider_state_batch.as_dict()
cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields}
cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()}
prompt_embeds = cond_kwargs.pop("prompt_embeds")
# Predict the noise residual
@@ -344,11 +358,23 @@ class StableDiffusionXLControlNetLoopDenoiser(ModularPipelineBlocks):
# Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds)
# to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds)
guider_input_fields = {
"prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"),
"time_ids": ("add_time_ids", "negative_add_time_ids"),
"text_embeds": ("pooled_prompt_embeds", "negative_pooled_prompt_embeds"),
"image_embeds": ("ip_adapter_embeds", "negative_ip_adapter_embeds"),
guider_inputs = {
"prompt_embeds": (
getattr(block_state, "prompt_embeds", None),
getattr(block_state, "negative_prompt_embeds", None),
),
"time_ids": (
getattr(block_state, "add_time_ids", None),
getattr(block_state, "negative_add_time_ids", None),
),
"text_embeds": (
getattr(block_state, "pooled_prompt_embeds", None),
getattr(block_state, "negative_pooled_prompt_embeds", None),
),
"image_embeds": (
getattr(block_state, "ip_adapter_embeds", None),
getattr(block_state, "negative_ip_adapter_embeds", None),
),
}
# cond_scale for the timestep (controlnet input)
@@ -369,12 +395,15 @@ class StableDiffusionXLControlNetLoopDenoiser(ModularPipelineBlocks):
# guided denoiser step
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
# Prepare minibatches according to guidance method and `guider_input_fields`
# Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds.
# e.g. for CFG, we prepare two batches: one for uncond, one for cond
# for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds
# for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds
guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
# The guider splits model inputs into separate batches for conditional/unconditional predictions.
# For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
# you will get a guider_state with two batches:
# guider_state = [
# {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch
# {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch
# ]
# Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
guider_state = components.guider.prepare_inputs(guider_inputs)
# run the denoiser for each guidance batch
for guider_state_batch in guider_state:

View File

@@ -94,25 +94,30 @@ class WanLoopDenoiser(ModularPipelineBlocks):
) -> PipelineState:
# Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds)
# to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds)
guider_input_fields = {
"prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"),
guider_inputs = {
"prompt_embeds": (
getattr(block_state, "prompt_embeds", None),
getattr(block_state, "negative_prompt_embeds", None),
),
}
transformer_dtype = components.transformer.dtype
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
# Prepare minibatches according to guidance method and `guider_input_fields`
# Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds.
# e.g. for CFG, we prepare two batches: one for uncond, one for cond
# for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds
# for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds
guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
# The guider splits model inputs into separate batches for conditional/unconditional predictions.
# For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
# you will get a guider_state with two batches:
# guider_state = [
# {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch
# {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch
# ]
# Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
guider_state = components.guider.prepare_inputs(guider_inputs)
# run the denoiser for each guidance batch
for guider_state_batch in guider_state:
components.guider.prepare_models(components.transformer)
cond_kwargs = guider_state_batch.as_dict()
cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields}
cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()}
prompt_embeds = cond_kwargs.pop("prompt_embeds")
# Predict the noise residual

View File

@@ -241,6 +241,7 @@ else:
"HunyuanVideoImageToVideoPipeline",
"HunyuanVideoFramepackPipeline",
]
_import_structure["hunyuan_image"] = ["HunyuanImagePipeline", "HunyuanImageRefinerPipeline"]
_import_structure["kandinsky"] = [
"KandinskyCombinedPipeline",
"KandinskyImg2ImgCombinedPipeline",
@@ -640,6 +641,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
ReduxImageEncoder,
)
from .hidream_image import HiDreamImagePipeline
from .hunyuan_image import HunyuanImagePipeline, HunyuanImageRefinerPipeline
from .hunyuan_video import (
HunyuanSkyreelsImageToVideoPipeline,
HunyuanVideoFramepackPipeline,

View File

@@ -0,0 +1,50 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_hunyuanimage"] = ["HunyuanImagePipeline"]
_import_structure["pipeline_hunyuanimage_refiner"] = ["HunyuanImageRefinerPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_hunyuanimage import HunyuanImagePipeline
from .pipeline_hunyuanimage_refiner import HunyuanImageRefinerPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)

View File

@@ -0,0 +1,866 @@
# Copyright 2025 Hunyuan-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import re
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch
from transformers import ByT5Tokenizer, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, T5EncoderModel
from ...guiders import AdaptiveProjectedMixGuidance
from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKLHunyuanImage, HunyuanImageTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import HunyuanImagePipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import HunyuanImagePipeline
>>> pipe = HunyuanImagePipeline.from_pretrained(
... "hunyuanvideo-community/HunyuanImage-2.1-Diffusers", torch_dtype=torch.bfloat16
... )
>>> pipe.to("cuda")
>>> prompt = "A cat holding a sign that says hello world"
>>> # Depending on the variant being used, the pipeline call will slightly vary.
>>> # Refer to the pipeline documentation for more details.
>>> image = pipe(prompt, negative_prompt="", num_inference_steps=50).images[0]
>>> image.save("hunyuanimage.png")
```
"""
def extract_glyph_text(prompt: str):
"""
Extract text enclosed in quotes for glyph rendering.
Finds text in single quotes, double quotes, and Chinese quotes, then formats it for byT5 processing.
Args:
prompt: Input text prompt
Returns:
Formatted glyph text string or None if no quoted text found
"""
text_prompt_texts = []
pattern_quote_single = r"\'(.*?)\'"
pattern_quote_double = r"\"(.*?)\""
pattern_quote_chinese_single = r"(.*?)"
pattern_quote_chinese_double = r"“(.*?)”"
matches_quote_single = re.findall(pattern_quote_single, prompt)
matches_quote_double = re.findall(pattern_quote_double, prompt)
matches_quote_chinese_single = re.findall(pattern_quote_chinese_single, prompt)
matches_quote_chinese_double = re.findall(pattern_quote_chinese_double, prompt)
text_prompt_texts.extend(matches_quote_single)
text_prompt_texts.extend(matches_quote_double)
text_prompt_texts.extend(matches_quote_chinese_single)
text_prompt_texts.extend(matches_quote_chinese_double)
if text_prompt_texts:
glyph_text_formatted = ". ".join([f'Text "{text}"' for text in text_prompt_texts]) + ". "
else:
glyph_text_formatted = None
return glyph_text_formatted
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class HunyuanImagePipeline(DiffusionPipeline):
r"""
The HunyuanImage pipeline for text-to-image generation.
Args:
transformer ([`HunyuanImageTransformer2DModel`]):
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKLHunyuanImage`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`Qwen2.5-VL-7B-Instruct`]):
[Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the
[Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant.
tokenizer (`Qwen2Tokenizer`): Tokenizer of class [Qwen2Tokenizer].
text_encoder_2 ([`T5EncoderModel`]):
[T5EncoderModel](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel)
variant.
tokenizer_2 (`ByT5Tokenizer`): Tokenizer of class [ByT5Tokenizer]
guider ([`AdaptiveProjectedMixGuidance`]):
[AdaptiveProjectedMixGuidance]to be used to guide the image generation.
ocr_guider ([`AdaptiveProjectedMixGuidance`], *optional*):
[AdaptiveProjectedMixGuidance] to be used to guide the image generation when text rendering is needed.
"""
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds"]
_optional_components = ["ocr_guider", "guider"]
def __init__(
self,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKLHunyuanImage,
text_encoder: Qwen2_5_VLForConditionalGeneration,
tokenizer: Qwen2Tokenizer,
text_encoder_2: T5EncoderModel,
tokenizer_2: ByT5Tokenizer,
transformer: HunyuanImageTransformer2DModel,
guider: Optional[AdaptiveProjectedMixGuidance] = None,
ocr_guider: Optional[AdaptiveProjectedMixGuidance] = None,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
text_encoder_2=text_encoder_2,
tokenizer_2=tokenizer_2,
transformer=transformer,
scheduler=scheduler,
guider=guider,
ocr_guider=ocr_guider,
)
self.vae_scale_factor = self.vae.config.spatial_compression_ratio if getattr(self, "vae", None) else 32
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.tokenizer_max_length = 1000
self.tokenizer_2_max_length = 128
self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>"
self.prompt_template_encode_start_idx = 34
self.default_sample_size = 64
def _get_qwen_prompt_embeds(
self,
tokenizer: Qwen2Tokenizer,
text_encoder: Qwen2_5_VLForConditionalGeneration,
prompt: Union[str, List[str]] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
tokenizer_max_length: int = 1000,
template: str = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>",
drop_idx: int = 34,
hidden_state_skip_layer: int = 2,
):
device = device or self._execution_device
dtype = dtype or text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
txt = [template.format(e) for e in prompt]
txt_tokens = tokenizer(
txt, max_length=tokenizer_max_length + drop_idx, padding="max_length", truncation=True, return_tensors="pt"
).to(device)
encoder_hidden_states = text_encoder(
input_ids=txt_tokens.input_ids,
attention_mask=txt_tokens.attention_mask,
output_hidden_states=True,
)
prompt_embeds = encoder_hidden_states.hidden_states[-(hidden_state_skip_layer + 1)]
prompt_embeds = prompt_embeds[:, drop_idx:]
encoder_attention_mask = txt_tokens.attention_mask[:, drop_idx:]
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
encoder_attention_mask = encoder_attention_mask.to(device=device)
return prompt_embeds, encoder_attention_mask
def _get_byt5_prompt_embeds(
self,
tokenizer: ByT5Tokenizer,
text_encoder: T5EncoderModel,
prompt: str,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
tokenizer_max_length: int = 128,
):
device = device or self._execution_device
dtype = dtype or text_encoder.dtype
if isinstance(prompt, list):
raise ValueError("byt5 prompt should be a string")
elif prompt is None:
raise ValueError("byt5 prompt should not be None")
txt_tokens = tokenizer(
prompt,
padding="max_length",
max_length=tokenizer_max_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
).to(device)
prompt_embeds = text_encoder(
input_ids=txt_tokens.input_ids,
attention_mask=txt_tokens.attention_mask.float(),
)[0]
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
encoder_attention_mask = txt_tokens.attention_mask.to(device=device)
return prompt_embeds, encoder_attention_mask
def encode_prompt(
self,
prompt: Union[str, List[str]],
device: Optional[torch.device] = None,
batch_size: int = 1,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_embeds_mask: Optional[torch.Tensor] = None,
prompt_embeds_2: Optional[torch.Tensor] = None,
prompt_embeds_mask_2: Optional[torch.Tensor] = None,
):
r"""
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
batch_size (`int`):
batch size of prompts, defaults to 1
num_images_per_prompt (`int`):
number of images that should be generated per prompt
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. If not provided, text embeddings will be generated from `prompt` input
argument.
prompt_embeds_mask (`torch.Tensor`, *optional*):
Pre-generated text mask. If not provided, text mask will be generated from `prompt` input argument.
prompt_embeds_2 (`torch.Tensor`, *optional*):
Pre-generated glyph text embeddings from ByT5. If not provided, will be generated from `prompt` input
argument using self.tokenizer_2 and self.text_encoder_2.
prompt_embeds_mask_2 (`torch.Tensor`, *optional*):
Pre-generated glyph text mask from ByT5. If not provided, will be generated from `prompt` input
argument using self.tokenizer_2 and self.text_encoder_2.
"""
device = device or self._execution_device
if prompt is None:
prompt = [""] * batch_size
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt_embeds is None:
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(
tokenizer=self.tokenizer,
text_encoder=self.text_encoder,
prompt=prompt,
device=device,
tokenizer_max_length=self.tokenizer_max_length,
template=self.prompt_template_encode,
drop_idx=self.prompt_template_encode_start_idx,
)
if prompt_embeds_2 is None:
prompt_embeds_2_list = []
prompt_embeds_mask_2_list = []
glyph_texts = [extract_glyph_text(p) for p in prompt]
for glyph_text in glyph_texts:
if glyph_text is None:
glyph_text_embeds = torch.zeros(
(1, self.tokenizer_2_max_length, self.text_encoder_2.config.d_model), device=device
)
glyph_text_embeds_mask = torch.zeros(
(1, self.tokenizer_2_max_length), device=device, dtype=torch.int64
)
else:
glyph_text_embeds, glyph_text_embeds_mask = self._get_byt5_prompt_embeds(
tokenizer=self.tokenizer_2,
text_encoder=self.text_encoder_2,
prompt=glyph_text,
device=device,
tokenizer_max_length=self.tokenizer_2_max_length,
)
prompt_embeds_2_list.append(glyph_text_embeds)
prompt_embeds_mask_2_list.append(glyph_text_embeds_mask)
prompt_embeds_2 = torch.cat(prompt_embeds_2_list, dim=0)
prompt_embeds_mask_2 = torch.cat(prompt_embeds_mask_2_list, dim=0)
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
_, seq_len_2, _ = prompt_embeds_2.shape
prompt_embeds_2 = prompt_embeds_2.repeat(1, num_images_per_prompt, 1)
prompt_embeds_2 = prompt_embeds_2.view(batch_size * num_images_per_prompt, seq_len_2, -1)
prompt_embeds_mask_2 = prompt_embeds_mask_2.repeat(1, num_images_per_prompt, 1)
prompt_embeds_mask_2 = prompt_embeds_mask_2.view(batch_size * num_images_per_prompt, seq_len_2)
return prompt_embeds, prompt_embeds_mask, prompt_embeds_2, prompt_embeds_mask_2
def check_inputs(
self,
prompt,
height,
width,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
prompt_embeds_mask=None,
negative_prompt_embeds_mask=None,
prompt_embeds_2=None,
prompt_embeds_mask_2=None,
negative_prompt_embeds_2=None,
negative_prompt_embeds_mask_2=None,
callback_on_step_end_tensor_inputs=None,
):
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
logger.warning(
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
)
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and prompt_embeds_mask is None:
raise ValueError(
"If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
)
if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
raise ValueError(
"If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
)
if prompt is None and prompt_embeds_2 is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined."
)
if prompt_embeds_2 is not None and prompt_embeds_mask_2 is None:
raise ValueError(
"If `prompt_embeds_2` are provided, `prompt_embeds_mask_2` also have to be passed. Make sure to generate `prompt_embeds_mask_2` from the same text encoder that was used to generate `prompt_embeds_2`."
)
if negative_prompt_embeds_2 is not None and negative_prompt_embeds_mask_2 is None:
raise ValueError(
"If `negative_prompt_embeds_2` are provided, `negative_prompt_embeds_mask_2` also have to be passed. Make sure to generate `negative_prompt_embeds_mask_2` from the same text encoder that was used to generate `negative_prompt_embeds_2`."
)
def prepare_latents(
self,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
):
height = int(height) // self.vae_scale_factor
width = int(width) // self.vae_scale_factor
shape = (batch_size, num_channels_latents, height, width)
if latents is not None:
return latents.to(device=device, dtype=dtype)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
return latents
@property
def attention_kwargs(self):
return self._attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@property
def current_timestep(self):
return self._current_timestep
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
negative_prompt: Union[str, List[str]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
distilled_guidance_scale: Optional[float] = 3.25,
sigmas: Optional[List[float]] = None,
num_images_per_prompt: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_embeds_mask: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
prompt_embeds_2: Optional[torch.Tensor] = None,
prompt_embeds_mask_2: Optional[torch.Tensor] = None,
negative_prompt_embeds_2: Optional[torch.Tensor] = None,
negative_prompt_embeds_mask_2: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined and negative_prompt_embeds is
not provided, will use an empty negative prompt. Ignored when not using guidance. ).
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
distilled_guidance_scale (`float`, *optional*, defaults to None):
A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance
where the guidance scale is applied during inference through noise prediction rescaling, guidance
distilled models take the guidance scale directly as an input parameter during forward pass. Guidance
is enabled by setting `distilled_guidance_scale > 1`. Higher guidance scale encourages to generate
images that are closely linked to the text `prompt`, usually at the expense of lower image quality. For
guidance distilled models, this parameter is required. For non-distilled models, this parameter will be
ignored.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
prompt_embeds_mask (`torch.Tensor`, *optional*):
Pre-generated text embeddings mask. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, text embeddings mask will be generated from `prompt` input argument.
prompt_embeds_2 (`torch.Tensor`, *optional*):
Pre-generated text embeddings for ocr. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, text embeddings for ocr will be generated from `prompt` input argument.
prompt_embeds_mask_2 (`torch.Tensor`, *optional*):
Pre-generated text embeddings mask for ocr. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, text embeddings mask for ocr will be generated from `prompt` input
argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
negative_prompt_embeds_mask (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings mask. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative text embeddings mask will be generated from `negative_prompt`
input argument.
negative_prompt_embeds_2 (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings for ocr. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative text embeddings for ocr will be generated from `negative_prompt`
input argument.
negative_prompt_embeds_mask_2 (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings mask for ocr. Can be used to easily tweak text inputs, *e.g.*
prompt weighting. If not provided, negative text embeddings mask for ocr will be generated from
`negative_prompt` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
Examples:
Returns:
[`~pipelines.hunyuan_image.HunyuanImagePipelineOutput`] or `tuple`:
[`~pipelines.hunyuan_image.HunyuanImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is a list with the generated images.
"""
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
height,
width,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_embeds_mask=prompt_embeds_mask,
negative_prompt_embeds_mask=negative_prompt_embeds_mask,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
prompt_embeds_2=prompt_embeds_2,
prompt_embeds_mask_2=prompt_embeds_mask_2,
negative_prompt_embeds_2=negative_prompt_embeds_2,
negative_prompt_embeds_mask_2=negative_prompt_embeds_mask_2,
)
self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# 3. prepare prompt embeds
prompt_embeds, prompt_embeds_mask, prompt_embeds_2, prompt_embeds_mask_2 = self.encode_prompt(
prompt=prompt,
prompt_embeds=prompt_embeds,
prompt_embeds_mask=prompt_embeds_mask,
device=device,
batch_size=batch_size,
num_images_per_prompt=num_images_per_prompt,
prompt_embeds_2=prompt_embeds_2,
prompt_embeds_mask_2=prompt_embeds_mask_2,
)
prompt_embeds = prompt_embeds.to(self.transformer.dtype)
prompt_embeds_2 = prompt_embeds_2.to(self.transformer.dtype)
# select guider
if not torch.all(prompt_embeds_2 == 0) and self.ocr_guider is not None:
# prompt contains ocr and pipeline has a guider for ocr
guider = self.ocr_guider
elif self.guider is not None:
guider = self.guider
# distilled model does not use guidance method, use default guider with enabled=False
else:
guider = AdaptiveProjectedMixGuidance(enabled=False)
if guider._enabled and guider.num_conditions > 1:
(
negative_prompt_embeds,
negative_prompt_embeds_mask,
negative_prompt_embeds_2,
negative_prompt_embeds_mask_2,
) = self.encode_prompt(
prompt=negative_prompt,
prompt_embeds=negative_prompt_embeds,
prompt_embeds_mask=negative_prompt_embeds_mask,
device=device,
batch_size=batch_size,
num_images_per_prompt=num_images_per_prompt,
prompt_embeds_2=negative_prompt_embeds_2,
prompt_embeds_mask_2=negative_prompt_embeds_mask_2,
)
negative_prompt_embeds = negative_prompt_embeds.to(self.transformer.dtype)
negative_prompt_embeds_2 = negative_prompt_embeds_2.to(self.transformer.dtype)
# 4. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size=batch_size * num_images_per_prompt,
num_channels_latents=num_channels_latents,
height=height,
width=width,
dtype=prompt_embeds.dtype,
device=device,
generator=generator,
latents=latents,
)
# 5. Prepare timesteps
sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# handle guidance (for guidance-distilled model)
if self.transformer.config.guidance_embeds and distilled_guidance_scale is None:
raise ValueError("`distilled_guidance_scale` is required for guidance-distilled model.")
if self.transformer.config.guidance_embeds:
guidance = (
torch.tensor(
[distilled_guidance_scale] * latents.shape[0], dtype=self.transformer.dtype, device=device
)
* 1000.0
)
else:
guidance = None
if self.attention_kwargs is None:
self._attention_kwargs = {}
# 6. Denoising loop
self.scheduler.set_begin_index(0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
if self.transformer.config.use_meanflow:
if i == len(timesteps) - 1:
timestep_r = torch.tensor([0.0], device=device)
else:
timestep_r = timesteps[i + 1]
timestep_r = timestep_r.expand(latents.shape[0]).to(latents.dtype)
else:
timestep_r = None
# Step 1: Collect model inputs needed for the guidance method
# conditional inputs should always be first element in the tuple
guider_inputs = {
"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds),
"encoder_attention_mask": (prompt_embeds_mask, negative_prompt_embeds_mask),
"encoder_hidden_states_2": (prompt_embeds_2, negative_prompt_embeds_2),
"encoder_attention_mask_2": (prompt_embeds_mask_2, negative_prompt_embeds_mask_2),
}
# Step 2: Update guider's internal state for this denoising step
guider.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t)
# Step 3: Prepare batched model inputs based on the guidance method
# The guider splits model inputs into separate batches for conditional/unconditional predictions.
# For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
# you will get a guider_state with two batches:
# guider_state = [
# {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch
# {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch
# ]
# Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
guider_state = guider.prepare_inputs(guider_inputs)
# Step 4: Run the denoiser for each batch
# Each batch in guider_state represents a different conditioning (conditional, unconditional, etc.).
# We run the model once per batch and store the noise prediction in guider_state_batch.noise_pred.
for guider_state_batch in guider_state:
guider.prepare_models(self.transformer)
# Extract conditioning kwargs for this batch (e.g., encoder_hidden_states)
cond_kwargs = {
input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()
}
# e.g. "pred_cond"/"pred_uncond"
context_name = getattr(guider_state_batch, guider._identifier_key)
with self.transformer.cache_context(context_name):
# Run denoiser and store noise prediction in this batch
guider_state_batch.noise_pred = self.transformer(
hidden_states=latents,
timestep=timestep,
timestep_r=timestep_r,
guidance=guidance,
attention_kwargs=self.attention_kwargs,
return_dict=False,
**cond_kwargs,
)[0]
# Cleanup model (e.g., remove hooks)
guider.cleanup_models(self.transformer)
# Step 5: Combine predictions using the guidance method
# The guider takes all noise predictions from guider_state and combines them according to the guidance algorithm.
# Continuing the CFG example, the guider receives:
# guider_state = [
# {"encoder_hidden_states": prompt_embeds, "noise_pred": noise_pred_cond, "__guidance_identifier__": "pred_cond"}, # batch 0
# {"encoder_hidden_states": negative_prompt_embeds, "noise_pred": noise_pred_uncond, "__guidance_identifier__": "pred_uncond"}, # batch 1
# ]
# And extracts predictions using the __guidance_identifier__:
# pred_cond = guider_state[0]["noise_pred"] # extracts noise_pred_cond
# pred_uncond = guider_state[1]["noise_pred"] # extracts noise_pred_uncond
# Then applies CFG formula:
# noise_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)
# Returns GuiderOutput(pred=noise_pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
noise_pred = guider(guider_state)[0]
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
self._current_timestep = None
if output_type == "latent":
image = latents
else:
latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
image = self.vae.decode(latents, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return HunyuanImagePipelineOutput(images=image)

View File

@@ -0,0 +1,752 @@
# Copyright 2025 Hunyuan-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
from ...guiders import AdaptiveProjectedMixGuidance
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...models import AutoencoderKLHunyuanImageRefiner, HunyuanImageTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import HunyuanImagePipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import HunyuanImageRefinerPipeline
>>> pipe = HunyuanImageRefinerPipeline.from_pretrained(
... "hunyuanvideo-community/HunyuanImage-2.1-Refiner-Diffusers", torch_dtype=torch.bfloat16
... )
>>> pipe.to("cuda")
>>> prompt = "A cat holding a sign that says hello world"
>>> image = load_image("path/to/image.png")
>>> # Depending on the variant being used, the pipeline call will slightly vary.
>>> # Refer to the pipeline documentation for more details.
>>> image = pipe(prompt, image=image, num_inference_steps=4).images[0]
>>> image.save("hunyuanimage.png")
```
"""
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
class HunyuanImageRefinerPipeline(DiffusionPipeline):
r"""
The HunyuanImage pipeline for text-to-image generation.
Args:
transformer ([`HunyuanImageTransformer2DModel`]):
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKLHunyuanImageRefiner`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`Qwen2.5-VL-7B-Instruct`]):
[Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the
[Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant.
tokenizer (`Qwen2Tokenizer`): Tokenizer of class [Qwen2Tokenizer].
"""
model_cpu_offload_seq = "text_encoder->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds"]
_optional_components = ["guider"]
def __init__(
self,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKLHunyuanImageRefiner,
text_encoder: Qwen2_5_VLForConditionalGeneration,
tokenizer: Qwen2Tokenizer,
transformer: HunyuanImageTransformer2DModel,
guider: Optional[AdaptiveProjectedMixGuidance] = None,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
scheduler=scheduler,
guider=guider,
)
self.vae_scale_factor = self.vae.config.spatial_compression_ratio if getattr(self, "vae", None) else 16
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.tokenizer_max_length = 256
self.prompt_template_encode = "<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
self.prompt_template_encode_start_idx = 36
self.default_sample_size = 64
self.latent_channels = self.transformer.config.in_channels // 2 if getattr(self, "transformer", None) else 64
# Copied from diffusers.pipelines.hunyuan_image.pipeline_hunyuanimage.HunyuanImagePipeline._get_qwen_prompt_embeds
def _get_qwen_prompt_embeds(
self,
tokenizer: Qwen2Tokenizer,
text_encoder: Qwen2_5_VLForConditionalGeneration,
prompt: Union[str, List[str]] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
tokenizer_max_length: int = 1000,
template: str = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>",
drop_idx: int = 34,
hidden_state_skip_layer: int = 2,
):
device = device or self._execution_device
dtype = dtype or text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
txt = [template.format(e) for e in prompt]
txt_tokens = tokenizer(
txt, max_length=tokenizer_max_length + drop_idx, padding="max_length", truncation=True, return_tensors="pt"
).to(device)
encoder_hidden_states = text_encoder(
input_ids=txt_tokens.input_ids,
attention_mask=txt_tokens.attention_mask,
output_hidden_states=True,
)
prompt_embeds = encoder_hidden_states.hidden_states[-(hidden_state_skip_layer + 1)]
prompt_embeds = prompt_embeds[:, drop_idx:]
encoder_attention_mask = txt_tokens.attention_mask[:, drop_idx:]
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
encoder_attention_mask = encoder_attention_mask.to(device=device)
return prompt_embeds, encoder_attention_mask
def encode_prompt(
self,
prompt: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
batch_size: int = 1,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_embeds_mask: Optional[torch.Tensor] = None,
):
r"""
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
batch_size (`int`):
batch size of prompts, defaults to 1
num_images_per_prompt (`int`):
number of images that should be generated per prompt
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. If not provided, text embeddings will be generated from `prompt` input
argument.
prompt_embeds_mask (`torch.Tensor`, *optional*):
Pre-generated text mask. If not provided, text mask will be generated from `prompt` input argument.
prompt_embeds_2 (`torch.Tensor`, *optional*):
Pre-generated glyph text embeddings from ByT5. If not provided, will be generated from `prompt` input
argument using self.tokenizer_2 and self.text_encoder_2.
prompt_embeds_mask_2 (`torch.Tensor`, *optional*):
Pre-generated glyph text mask from ByT5. If not provided, will be generated from `prompt` input
argument using self.tokenizer_2 and self.text_encoder_2.
"""
device = device or self._execution_device
if prompt is None:
prompt = [""] * batch_size
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt_embeds is None:
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(
tokenizer=self.tokenizer,
text_encoder=self.text_encoder,
prompt=prompt,
device=device,
tokenizer_max_length=self.tokenizer_max_length,
template=self.prompt_template_encode,
drop_idx=self.prompt_template_encode_start_idx,
)
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
return prompt_embeds, prompt_embeds_mask
def check_inputs(
self,
prompt,
height,
width,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
prompt_embeds_mask=None,
negative_prompt_embeds_mask=None,
callback_on_step_end_tensor_inputs=None,
):
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
logger.warning(
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
)
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and prompt_embeds_mask is None:
raise ValueError(
"If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
)
if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
raise ValueError(
"If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
)
def prepare_latents(
self,
image_latents,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
strength=0.25,
):
height = int(height) // self.vae_scale_factor
width = int(width) // self.vae_scale_factor
shape = (batch_size, num_channels_latents, 1, height, width)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device=device, dtype=dtype)
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
# expand init_latents for batch_size
additional_image_per_prompt = batch_size // image_latents.shape[0]
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
raise ValueError(
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
cond_latents = strength * noise + (1 - strength) * image_latents
return latents, cond_latents
@staticmethod
def _reorder_image_tokens(image_latents):
image_latents = torch.cat((image_latents[:, :, :1], image_latents), dim=2)
batch_size, num_latent_channels, num_latent_frames, latent_height, latent_width = image_latents.shape
image_latents = image_latents.permute(0, 2, 1, 3, 4)
image_latents = image_latents.reshape(
batch_size, num_latent_frames // 2, num_latent_channels * 2, latent_height, latent_width
)
image_latents = image_latents.permute(0, 2, 1, 3, 4).contiguous()
return image_latents
@staticmethod
def _restore_image_tokens_order(latents):
"""Restore image tokens order by splitting channels and removing first frame slice."""
batch_size, num_latent_channels, num_latent_frames, latent_height, latent_width = latents.shape
latents = latents.permute(0, 2, 1, 3, 4) # B, F, C, H, W
latents = latents.reshape(
batch_size, num_latent_frames * 2, num_latent_channels // 2, latent_height, latent_width
) # B, F*2, C//2, H, W
latents = latents.permute(0, 2, 1, 3, 4) # B, C//2, F*2, H, W
# Remove first frame slice
latents = latents[:, :, 1:]
return latents
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
if isinstance(generator, list):
image_latents = [
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="sample")
for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="sample")
image_latents = self._reorder_image_tokens(image_latents)
image_latents = image_latents * self.vae.config.scaling_factor
return image_latents
@property
def attention_kwargs(self):
return self._attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@property
def current_timestep(self):
return self._current_timestep
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
negative_prompt: Union[str, List[str]] = None,
distilled_guidance_scale: Optional[float] = 3.25,
image: Optional[PipelineImageInput] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 4,
sigmas: Optional[List[float]] = None,
num_images_per_prompt: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_embeds_mask: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, will use an empty negative
prompt. Ignored when not using guidance.
distilled_guidance_scale (`float`, *optional*, defaults to None):
A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance
where the guidance scale is applied during inference through noise prediction rescaling, guidance
distilled models take the guidance scale directly as an input parameter during forward pass. Guidance
is enabled by setting `distilled_guidance_scale > 1`. Higher guidance scale encourages to generate
images that are closely linked to the text `prompt`, usually at the expense of lower image quality. For
guidance distilled models, this parameter is required. For non-distilled models, this parameter will be
ignored.
num_images_per_prompt (`int`, *optional*, defaults to 1):
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
Examples:
Returns:
[`~pipelines.hunyuan_image.HunyuanImagePipelineOutput`] or `tuple`:
[`~pipelines.hunyuan_image.HunyuanImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is a list with the generated images.
"""
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
height,
width,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_embeds_mask=prompt_embeds_mask,
negative_prompt_embeds_mask=negative_prompt_embeds_mask,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
)
self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# 3. process image
if image is not None and isinstance(image, torch.Tensor) and image.shape[1] == self.latent_channels:
image_latents = image
else:
image = self.image_processor.preprocess(image, height, width)
image = image.unsqueeze(2).to(device, dtype=self.vae.dtype)
image_latents = self._encode_vae_image(image=image, generator=generator)
# 3.prepare prompt embeds
if self.guider is not None:
guider = self.guider
else:
# distilled model does not use guidance method, use default guider with enabled=False
guider = AdaptiveProjectedMixGuidance(enabled=False)
requires_unconditional_embeds = guider._enabled and guider.num_conditions > 1
prompt_embeds, prompt_embeds_mask = self.encode_prompt(
prompt=prompt,
prompt_embeds=prompt_embeds,
prompt_embeds_mask=prompt_embeds_mask,
device=device,
batch_size=batch_size,
num_images_per_prompt=num_images_per_prompt,
)
prompt_embeds = prompt_embeds.to(self.transformer.dtype)
if requires_unconditional_embeds:
(
negative_prompt_embeds,
negative_prompt_embeds_mask,
) = self.encode_prompt(
prompt=negative_prompt,
prompt_embeds=negative_prompt_embeds,
prompt_embeds_mask=negative_prompt_embeds_mask,
device=device,
batch_size=batch_size,
num_images_per_prompt=num_images_per_prompt,
)
negative_prompt_embeds = negative_prompt_embeds.to(self.transformer.dtype)
# 4. Prepare latent variables
latents, cond_latents = self.prepare_latents(
image_latents=image_latents,
batch_size=batch_size * num_images_per_prompt,
num_channels_latents=self.latent_channels,
height=height,
width=width,
dtype=prompt_embeds.dtype,
device=device,
generator=generator,
latents=latents,
)
# 5. Prepare timesteps
sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# handle guidance (this pipeline only supports guidance-distilled models)
if distilled_guidance_scale is None:
raise ValueError("`distilled_guidance_scale` is required for guidance-distilled model.")
guidance = (
torch.tensor([distilled_guidance_scale] * latents.shape[0], dtype=self.transformer.dtype, device=device)
* 1000.0
)
if self.attention_kwargs is None:
self._attention_kwargs = {}
# 6. Denoising loop
self.scheduler.set_begin_index(0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
latent_model_input = torch.cat([latents, cond_latents], dim=1).to(self.transformer.dtype)
timestep = t.expand(latents.shape[0]).to(latents.dtype)
# Step 1: Collect model inputs needed for the guidance method
# conditional inputs should always be first element in the tuple
guider_inputs = {
"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds),
"encoder_attention_mask": (prompt_embeds_mask, negative_prompt_embeds_mask),
}
# Step 2: Update guider's internal state for this denoising step
guider.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t)
# Step 3: Prepare batched model inputs based on the guidance method
# The guider splits model inputs into separate batches for conditional/unconditional predictions.
# For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
# you will get a guider_state with two batches:
# guider_state = [
# {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch
# {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch
# ]
# Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
guider_state = guider.prepare_inputs(guider_inputs)
# Step 4: Run the denoiser for each batch
# Each batch in guider_state represents a different conditioning (conditional, unconditional, etc.).
# We run the model once per batch and store the noise prediction in guider_state_batch.noise_pred.
for guider_state_batch in guider_state:
guider.prepare_models(self.transformer)
# Extract conditioning kwargs for this batch (e.g., encoder_hidden_states)
cond_kwargs = {
input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()
}
# e.g. "pred_cond"/"pred_uncond"
context_name = getattr(guider_state_batch, guider._identifier_key)
with self.transformer.cache_context(context_name):
# Run denoiser and store noise prediction in this batch
guider_state_batch.noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
guidance=guidance,
attention_kwargs=self.attention_kwargs,
return_dict=False,
**cond_kwargs,
)[0]
# Cleanup model (e.g., remove hooks)
guider.cleanup_models(self.transformer)
# Step 5: Combine predictions using the guidance method
# The guider takes all noise predictions from guider_state and combines them according to the guidance algorithm.
# Continuing the CFG example, the guider receives:
# guider_state = [
# {"encoder_hidden_states": prompt_embeds, "noise_pred": noise_pred_cond, "__guidance_identifier__": "pred_cond"}, # batch 0
# {"encoder_hidden_states": negative_prompt_embeds, "noise_pred": noise_pred_uncond, "__guidance_identifier__": "pred_uncond"}, # batch 1
# ]
# And extracts predictions using the __guidance_identifier__:
# pred_cond = guider_state[0]["noise_pred"] # extracts noise_pred_cond
# pred_uncond = guider_state[1]["noise_pred"] # extracts noise_pred_uncond
# Then applies CFG formula:
# noise_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)
# Returns GuiderOutput(pred=noise_pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
noise_pred = guider(guider_state)[0]
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
self._current_timestep = None
if output_type == "latent":
image = latents
else:
latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
latents = self._restore_image_tokens_order(latents)
image = self.vae.decode(latents, return_dict=False)[0]
image = self.image_processor.postprocess(image.squeeze(2), output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return HunyuanImagePipelineOutput(images=image)

View File

@@ -0,0 +1,21 @@
from dataclasses import dataclass
from typing import List, Union
import numpy as np
import PIL.Image
from ...utils import BaseOutput
@dataclass
class HunyuanImagePipelineOutput(BaseOutput):
"""
Output class for HunyuanImage pipelines.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
"""
images: Union[List[PIL.Image.Image], np.ndarray]

View File

@@ -76,6 +76,7 @@ LOADABLE_CLASSES = {
"SchedulerMixin": ["save_pretrained", "from_pretrained"],
"DiffusionPipeline": ["save_pretrained", "from_pretrained"],
"OnnxRuntimeModel": ["save_pretrained", "from_pretrained"],
"BaseGuidance": ["save_pretrained", "from_pretrained"],
},
"transformers": {
"PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],

View File

@@ -17,6 +17,21 @@ class AdaptiveProjectedGuidance(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class AdaptiveProjectedMixGuidance(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class AutoGuidance(metaclass=DummyObject):
_backends = ["torch"]
@@ -32,6 +47,21 @@ class AutoGuidance(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class BaseGuidance(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class ClassifierFreeGuidance(metaclass=DummyObject):
_backends = ["torch"]
@@ -378,6 +408,36 @@ class AutoencoderKLCosmos(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class AutoencoderKLHunyuanImage(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class AutoencoderKLHunyuanImageRefiner(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class AutoencoderKLHunyuanVideo(metaclass=DummyObject):
_backends = ["torch"]
@@ -858,6 +918,21 @@ class HunyuanDiT2DMultiControlNetModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class HunyuanImageTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class HunyuanVideoFramepackTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]

View File

@@ -1037,6 +1037,36 @@ class HunyuanDiTPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class HunyuanImagePipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class HunyuanImageRefinerPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class HunyuanSkyreelsImageToVideoPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

@@ -0,0 +1,290 @@
# Copyright 2025 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import torch
from transformers import (
ByT5Tokenizer,
Qwen2_5_VLConfig,
Qwen2_5_VLForConditionalGeneration,
Qwen2Tokenizer,
T5Config,
T5EncoderModel,
)
from diffusers import (
AdaptiveProjectedMixGuidance,
AutoencoderKLHunyuanImage,
FlowMatchEulerDiscreteScheduler,
HunyuanImagePipeline,
HunyuanImageTransformer2DModel,
)
from ...testing_utils import enable_full_determinism
from ..test_pipelines_common import (
FirstBlockCacheTesterMixin,
PipelineTesterMixin,
to_np,
)
enable_full_determinism()
class HunyuanImagePipelineFastTests(
PipelineTesterMixin,
FirstBlockCacheTesterMixin,
unittest.TestCase,
):
pipeline_class = HunyuanImagePipeline
params = frozenset(["prompt", "height", "width"])
batch_params = frozenset(["prompt", "negative_prompt"])
required_optional_params = frozenset(
[
"num_inference_steps",
"generator",
"latents",
"return_dict",
"callback_on_step_end",
"callback_on_step_end_tensor_inputs",
]
)
test_xformers_attention = False
test_layerwise_casting = True
test_group_offloading = True
test_attention_slicing = False
supports_dduf = False
def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1, guidance_embeds: bool = False):
torch.manual_seed(0)
transformer = HunyuanImageTransformer2DModel(
in_channels=4,
out_channels=4,
num_attention_heads=4,
attention_head_dim=8,
num_layers=num_layers,
num_single_layers=num_single_layers,
num_refiner_layers=1,
patch_size=(1, 1),
guidance_embeds=guidance_embeds,
text_embed_dim=32,
text_embed_2_dim=32,
rope_axes_dim=(4, 4),
)
torch.manual_seed(0)
vae = AutoencoderKLHunyuanImage(
in_channels=3,
out_channels=3,
latent_channels=4,
block_out_channels=(32, 64, 64, 64),
layers_per_block=1,
scaling_factor=0.476986,
spatial_compression_ratio=8,
sample_size=128,
)
torch.manual_seed(0)
scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
if not guidance_embeds:
torch.manual_seed(0)
guider = AdaptiveProjectedMixGuidance(adaptive_projected_guidance_start_step=2)
ocr_guider = AdaptiveProjectedMixGuidance(adaptive_projected_guidance_start_step=3)
else:
guider = None
ocr_guider = None
torch.manual_seed(0)
config = Qwen2_5_VLConfig(
text_config={
"hidden_size": 32,
"intermediate_size": 32,
"num_hidden_layers": 2,
"num_attention_heads": 2,
"num_key_value_heads": 2,
"rope_scaling": {
"mrope_section": [2, 2, 4],
"rope_type": "default",
"type": "default",
},
"rope_theta": 1000000.0,
},
vision_config={
"depth": 2,
"hidden_size": 32,
"intermediate_size": 32,
"num_heads": 2,
"out_hidden_size": 32,
},
hidden_size=32,
vocab_size=152064,
vision_end_token_id=151653,
vision_start_token_id=151652,
vision_token_id=151654,
)
text_encoder = Qwen2_5_VLForConditionalGeneration(config)
tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")
torch.manual_seed(0)
t5_config = T5Config(
d_model=32,
d_kv=4,
d_ff=16,
num_layers=2,
num_heads=2,
relative_attention_num_buckets=8,
relative_attention_max_distance=32,
vocab_size=256,
feed_forward_proj="gated-gelu",
dense_act_fn="gelu_new",
is_encoder_decoder=False,
use_cache=False,
tie_word_embeddings=False,
)
text_encoder_2 = T5EncoderModel(t5_config)
tokenizer_2 = ByT5Tokenizer()
components = {
"transformer": transformer,
"vae": vae,
"scheduler": scheduler,
"text_encoder": text_encoder,
"text_encoder_2": text_encoder_2,
"tokenizer": tokenizer,
"tokenizer_2": tokenizer_2,
"guider": guider,
"ocr_guider": ocr_guider,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"generator": generator,
"num_inference_steps": 5,
"height": 16,
"width": 16,
"output_type": "pt",
}
return inputs
def test_inference(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
generated_image = image[0]
self.assertEqual(generated_image.shape, (3, 16, 16))
expected_slice_np = np.array(
[0.6252659, 0.51482046, 0.60799813, 0.59267783, 0.488082, 0.5857634, 0.523781, 0.58028054, 0.5674121]
)
output_slice = generated_image[0, -3:, -3:].flatten().cpu().numpy()
self.assertTrue(
np.abs(output_slice - expected_slice_np).max() < 1e-3,
f"output_slice: {output_slice}, expected_slice_np: {expected_slice_np}",
)
def test_inference_guider(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
pipe.guider = pipe.guider.new(guidance_scale=1000)
pipe.ocr_guider = pipe.ocr_guider.new(guidance_scale=1000)
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
generated_image = image[0]
self.assertEqual(generated_image.shape, (3, 16, 16))
expected_slice_np = np.array(
[0.61494756, 0.49616697, 0.60327923, 0.6115793, 0.49047345, 0.56977504, 0.53066164, 0.58880305, 0.5570612]
)
output_slice = generated_image[0, -3:, -3:].flatten().cpu().numpy()
self.assertTrue(
np.abs(output_slice - expected_slice_np).max() < 1e-3,
f"output_slice: {output_slice}, expected_slice_np: {expected_slice_np}",
)
def test_inference_with_distilled_guidance(self):
device = "cpu"
components = self.get_dummy_components(guidance_embeds=True)
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs["distilled_guidance_scale"] = 3.5
image = pipe(**inputs).images
generated_image = image[0]
self.assertEqual(generated_image.shape, (3, 16, 16))
expected_slice_np = np.array(
[0.63667065, 0.5187377, 0.66757566, 0.6320319, 0.4913387, 0.54813194, 0.5335031, 0.5736143, 0.5461346]
)
output_slice = generated_image[0, -3:, -3:].flatten().cpu().numpy()
self.assertTrue(
np.abs(output_slice - expected_slice_np).max() < 1e-3,
f"output_slice: {output_slice}, expected_slice_np: {expected_slice_np}",
)
def test_vae_tiling(self, expected_diff_max: float = 0.2):
generator_device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to("cpu")
pipe.set_progress_bar_config(disable=None)
# Without tiling
inputs = self.get_dummy_inputs(generator_device)
inputs["height"] = inputs["width"] = 128
output_without_tiling = pipe(**inputs)[0]
# With tiling
pipe.vae.enable_tiling(tile_sample_min_size=96)
inputs = self.get_dummy_inputs(generator_device)
inputs["height"] = inputs["width"] = 128
output_with_tiling = pipe(**inputs)[0]
self.assertLess(
(to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
expected_diff_max,
"VAE tiling should not affect the inference results",
)
@unittest.skip("TODO: Test not supported for now because needs to be adjusted to work with guiders.")
def test_encode_prompt_works_in_isolation(self):
pass