HunyuanImage21 (#12333)
* add hunyuanimage2.1

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
@@ -347,6 +347,8 @@
  title: HiDreamImageTransformer2DModel
- local: api/models/hunyuan_transformer2d
  title: HunyuanDiT2DModel
- local: api/models/hunyuanimage_transformer_2d
  title: HunyuanImageTransformer2DModel
- local: api/models/hunyuan_video_transformer_3d
  title: HunyuanVideoTransformer3DModel
- local: api/models/latte_transformer3d
@@ -411,6 +413,10 @@
  title: AutoencoderKLCogVideoX
- local: api/models/autoencoderkl_cosmos
  title: AutoencoderKLCosmos
- local: api/models/autoencoder_kl_hunyuanimage
  title: AutoencoderKLHunyuanImage
- local: api/models/autoencoder_kl_hunyuanimage_refiner
  title: AutoencoderKLHunyuanImageRefiner
- local: api/models/autoencoder_kl_hunyuan_video
  title: AutoencoderKLHunyuanVideo
- local: api/models/autoencoderkl_ltx_video
@@ -620,6 +626,8 @@
  title: ConsisID
- local: api/pipelines/framepack
  title: Framepack
- local: api/pipelines/hunyuanimage21
  title: HunyuanImage2.1
- local: api/pipelines/hunyuan_video
  title: HunyuanVideo
- local: api/pipelines/i2vgenxl
32 docs/source/en/api/models/autoencoder_kl_hunyuanimage.md Normal file
@@ -0,0 +1,32 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# AutoencoderKLHunyuanImage

The 2D variational autoencoder (VAE) model with KL loss used in [HunyuanImage2.1](https://github.com/Tencent-Hunyuan/HunyuanImage-2.1).

The model can be loaded with the following code snippet.

```python
import torch

from diffusers import AutoencoderKLHunyuanImage

vae = AutoencoderKLHunyuanImage.from_pretrained("hunyuanvideo-community/HunyuanImage-2.1-Diffusers", subfolder="vae", torch_dtype=torch.bfloat16)
```
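
A minimal usage sketch (assuming the model follows the same `encode`/`decode` interface as the other KL autoencoders in Diffusers; the input size and dtype below are illustrative):

```python
import torch
from diffusers import AutoencoderKLHunyuanImage

vae = AutoencoderKLHunyuanImage.from_pretrained(
    "hunyuanvideo-community/HunyuanImage-2.1-Diffusers", subfolder="vae", torch_dtype=torch.bfloat16
).to("cuda")

# Dummy image batch in [-1, 1]; replace with a real preprocessed image tensor.
image = torch.randn(1, 3, 1024, 1024, dtype=torch.bfloat16, device="cuda")

with torch.no_grad():
    latents = vae.encode(image).latent_dist.sample()  # sample from the posterior
    reconstruction = vae.decode(latents).sample       # DecoderOutput.sample, back in pixel space
```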

## AutoencoderKLHunyuanImage

[[autodoc]] AutoencoderKLHunyuanImage
  - decode
  - all

## DecoderOutput

[[autodoc]] models.autoencoders.vae.DecoderOutput
@@ -0,0 +1,32 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# AutoencoderKLHunyuanImageRefiner

The 3D variational autoencoder (VAE) model with KL loss used in [HunyuanImage2.1](https://github.com/Tencent-Hunyuan/HunyuanImage-2.1) for its refiner pipeline.

The model can be loaded with the following code snippet.

```python
import torch

from diffusers import AutoencoderKLHunyuanImageRefiner

vae = AutoencoderKLHunyuanImageRefiner.from_pretrained("hunyuanvideo-community/HunyuanImage-2.1-Refiner-Diffusers", subfolder="vae", torch_dtype=torch.bfloat16)
```

## AutoencoderKLHunyuanImageRefiner

[[autodoc]] AutoencoderKLHunyuanImageRefiner
  - decode
  - all

## DecoderOutput

[[autodoc]] models.autoencoders.vae.DecoderOutput
30 docs/source/en/api/models/hunyuanimage_transformer_2d.md Normal file
@@ -0,0 +1,30 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# HunyuanImageTransformer2DModel

A Diffusion Transformer model for [HunyuanImage2.1](https://github.com/Tencent-Hunyuan/HunyuanImage-2.1).

The model can be loaded with the following code snippet.

```python
import torch

from diffusers import HunyuanImageTransformer2DModel

transformer = HunyuanImageTransformer2DModel.from_pretrained("hunyuanvideo-community/HunyuanImage-2.1-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
```
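
If you load the transformer separately like this (for example, to share it between pipelines), it can be passed to the pipeline at load time. A short sketch, assuming the standard component-override pattern of `from_pretrained`:

```python
import torch
from diffusers import HunyuanImagePipeline, HunyuanImageTransformer2DModel

transformer = HunyuanImageTransformer2DModel.from_pretrained(
    "hunyuanvideo-community/HunyuanImage-2.1-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Reuse the already-loaded transformer instead of loading it again with the pipeline.
pipe = HunyuanImagePipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanImage-2.1-Diffusers",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
).to("cuda")
```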

## HunyuanImageTransformer2DModel

[[autodoc]] HunyuanImageTransformer2DModel

## Transformer2DModelOutput

[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
152 docs/source/en/api/pipelines/hunyuanimage21.md Normal file
@@ -0,0 +1,152 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->

# HunyuanImage2.1

HunyuanImage-2.1 is a 17B text-to-image model capable of generating 2K (2048 x 2048) resolution images.

HunyuanImage-2.1 comes in the following variants:

| model type | model id |
|:----------:|:--------:|
| HunyuanImage-2.1 | [hunyuanvideo-community/HunyuanImage-2.1-Diffusers](https://huggingface.co/hunyuanvideo-community/HunyuanImage-2.1-Diffusers) |
| HunyuanImage-2.1-Distilled | [hunyuanvideo-community/HunyuanImage-2.1-Distilled-Diffusers](https://huggingface.co/hunyuanvideo-community/HunyuanImage-2.1-Distilled-Diffusers) |
| HunyuanImage-2.1-Refiner | [hunyuanvideo-community/HunyuanImage-2.1-Refiner-Diffusers](https://huggingface.co/hunyuanvideo-community/HunyuanImage-2.1-Refiner-Diffusers) |

> [!TIP]
> [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.
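
As a rough sketch of what enabling caching can look like (assuming `HunyuanImageTransformer2DModel` supports the generic `enable_cache` API; see the caching guide for the options that actually apply to this model):

```python
import torch
from diffusers import FirstBlockCacheConfig, HunyuanImagePipeline

pipe = HunyuanImagePipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanImage-2.1-Diffusers", torch_dtype=torch.bfloat16
).to("cuda")

# Cache transformer block outputs across denoising steps; the threshold value is illustrative.
pipe.transformer.enable_cache(FirstBlockCacheConfig(threshold=0.2))
```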

## HunyuanImage-2.1

HunyuanImage-2.1 applies [Adaptive Projected Guidance (APG)](https://huggingface.co/papers/2410.02416) combined with Classifier-Free Guidance (CFG) in the denoising loop. `HunyuanImagePipeline` has a `guider` component (read more about [Guider](../modular_diffusers/guiders.md)) and does not take a `guidance_scale` parameter at runtime. To change guider-related parameters, e.g., `guidance_scale`, you can update the `guider` configuration instead.

```python
import torch
from diffusers import HunyuanImagePipeline

pipe = HunyuanImagePipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanImage-2.1-Diffusers",
    torch_dtype=torch.bfloat16
)
pipe = pipe.to("cuda")
```

You can inspect the `guider` object:

```py
>>> pipe.guider
AdaptiveProjectedMixGuidance {
  "_class_name": "AdaptiveProjectedMixGuidance",
  "_diffusers_version": "0.36.0.dev0",
  "adaptive_projected_guidance_momentum": -0.5,
  "adaptive_projected_guidance_rescale": 10.0,
  "adaptive_projected_guidance_scale": 10.0,
  "adaptive_projected_guidance_start_step": 5,
  "enabled": true,
  "eta": 0.0,
  "guidance_rescale": 0.0,
  "guidance_scale": 3.5,
  "start": 0.0,
  "stop": 1.0,
  "use_original_formulation": false
}

State:
  step: None
  num_inference_steps: None
  timestep: None
  count_prepared: 0
  enabled: True
  num_conditions: 2
  momentum_buffer: None
  is_apg_enabled: False
  is_cfg_enabled: True
```

To update the guider with a different configuration, use the `new()` method. For example, to generate an image with `guidance_scale=5.0` while keeping all other default guidance parameters:

```py
import torch
from diffusers import HunyuanImagePipeline

pipe = HunyuanImagePipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanImage-2.1-Diffusers",
    torch_dtype=torch.bfloat16
)
pipe = pipe.to("cuda")

# Update the guider configuration
pipe.guider = pipe.guider.new(guidance_scale=5.0)

prompt = (
    "A cute, cartoon-style anthropomorphic penguin plush toy with fluffy fur, standing in a painting studio, "
    "wearing a red knitted scarf and a red beret with the word 'Tencent' on it, holding a paintbrush with a "
    "focused expression as it paints an oil painting of the Mona Lisa, rendered in a photorealistic photographic style."
)

image = pipe(
    prompt=prompt,
    num_inference_steps=50,
    height=2048,
    width=2048,
).images[0]
image.save("image.png")
```
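
The same mechanism can switch guidance off entirely. A brief sketch based on the guider's `enabled` flag and `disable()` helper:

```py
# Replace the guider with a copy that has guidance disabled
# (only the conditional prediction is used during denoising).
pipe.guider = pipe.guider.new(enabled=False)

# Or toggle the existing guider in place.
pipe.guider.disable()
```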

## HunyuanImage-2.1-Distilled

Use `distilled_guidance_scale` with the guidance-distilled checkpoint:

```py
import torch
from diffusers import HunyuanImagePipeline

pipe = HunyuanImagePipeline.from_pretrained("hunyuanvideo-community/HunyuanImage-2.1-Distilled-Diffusers", torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")

prompt = (
    "A cute, cartoon-style anthropomorphic penguin plush toy with fluffy fur, standing in a painting studio, "
    "wearing a red knitted scarf and a red beret with the word 'Tencent' on it, holding a paintbrush with a "
    "focused expression as it paints an oil painting of the Mona Lisa, rendered in a photorealistic photographic style."
)

# Optional: seed the generator for reproducible results.
generator = torch.Generator(device="cuda").manual_seed(0)

image = pipe(
    prompt,
    num_inference_steps=8,
    distilled_guidance_scale=3.25,
    height=2048,
    width=2048,
    generator=generator,
).images[0]
image.save("image.png")
```

## HunyuanImagePipeline

[[autodoc]] HunyuanImagePipeline
  - all
  - __call__

## HunyuanImageRefinerPipeline

[[autodoc]] HunyuanImageRefinerPipeline
  - all
  - __call__

## HunyuanImagePipelineOutput

[[autodoc]] pipelines.hunyuan_image.pipeline_output.HunyuanImagePipelineOutput
1044 scripts/convert_hunyuan_image_to_diffusers.py Normal file
File diff suppressed because it is too large
@@ -149,7 +149,9 @@ else:
|
||||
_import_structure["guiders"].extend(
|
||||
[
|
||||
"AdaptiveProjectedGuidance",
|
||||
"AdaptiveProjectedMixGuidance",
|
||||
"AutoGuidance",
|
||||
"BaseGuidance",
|
||||
"ClassifierFreeGuidance",
|
||||
"ClassifierFreeZeroStarGuidance",
|
||||
"FrequencyDecoupledGuidance",
|
||||
@@ -184,6 +186,8 @@ else:
|
||||
"AutoencoderKLAllegro",
|
||||
"AutoencoderKLCogVideoX",
|
||||
"AutoencoderKLCosmos",
|
||||
"AutoencoderKLHunyuanImage",
|
||||
"AutoencoderKLHunyuanImageRefiner",
|
||||
"AutoencoderKLHunyuanVideo",
|
||||
"AutoencoderKLLTXVideo",
|
||||
"AutoencoderKLMagvit",
|
||||
@@ -216,6 +220,7 @@ else:
|
||||
"HunyuanDiT2DControlNetModel",
|
||||
"HunyuanDiT2DModel",
|
||||
"HunyuanDiT2DMultiControlNetModel",
|
||||
"HunyuanImageTransformer2DModel",
|
||||
"HunyuanVideoFramepackTransformer3DModel",
|
||||
"HunyuanVideoTransformer3DModel",
|
||||
"I2VGenXLUNet",
|
||||
@@ -462,6 +467,8 @@ else:
|
||||
"HunyuanDiTControlNetPipeline",
|
||||
"HunyuanDiTPAGPipeline",
|
||||
"HunyuanDiTPipeline",
|
||||
"HunyuanImagePipeline",
|
||||
"HunyuanImageRefinerPipeline",
|
||||
"HunyuanSkyreelsImageToVideoPipeline",
|
||||
"HunyuanVideoFramepackPipeline",
|
||||
"HunyuanVideoImageToVideoPipeline",
|
||||
@@ -849,7 +856,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
else:
|
||||
from .guiders import (
|
||||
AdaptiveProjectedGuidance,
|
||||
AdaptiveProjectedMixGuidance,
|
||||
AutoGuidance,
|
||||
BaseGuidance,
|
||||
ClassifierFreeGuidance,
|
||||
ClassifierFreeZeroStarGuidance,
|
||||
FrequencyDecoupledGuidance,
|
||||
@@ -880,6 +889,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderKLAllegro,
|
||||
AutoencoderKLCogVideoX,
|
||||
AutoencoderKLCosmos,
|
||||
AutoencoderKLHunyuanImage,
|
||||
AutoencoderKLHunyuanImageRefiner,
|
||||
AutoencoderKLHunyuanVideo,
|
||||
AutoencoderKLLTXVideo,
|
||||
AutoencoderKLMagvit,
|
||||
@@ -912,6 +923,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
HunyuanDiT2DControlNetModel,
|
||||
HunyuanDiT2DModel,
|
||||
HunyuanDiT2DMultiControlNetModel,
|
||||
HunyuanImageTransformer2DModel,
|
||||
HunyuanVideoFramepackTransformer3DModel,
|
||||
HunyuanVideoTransformer3DModel,
|
||||
I2VGenXLUNet,
|
||||
@@ -1128,6 +1140,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
HunyuanDiTControlNetPipeline,
|
||||
HunyuanDiTPAGPipeline,
|
||||
HunyuanDiTPipeline,
|
||||
HunyuanImagePipeline,
|
||||
HunyuanImageRefinerPipeline,
|
||||
HunyuanSkyreelsImageToVideoPipeline,
|
||||
HunyuanVideoFramepackPipeline,
|
||||
HunyuanVideoImageToVideoPipeline,
|
||||
|
||||
@@ -14,28 +14,24 @@
|
||||
|
||||
from typing import Union
|
||||
|
||||
from ..utils import is_torch_available
|
||||
from ..utils import is_torch_available, logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
logger.warning(
|
||||
"Guiders are currently an experimental feature under active development. The API is subject to breaking changes in future releases."
|
||||
)
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from .adaptive_projected_guidance import AdaptiveProjectedGuidance
|
||||
from .adaptive_projected_guidance_mix import AdaptiveProjectedMixGuidance
|
||||
from .auto_guidance import AutoGuidance
|
||||
from .classifier_free_guidance import ClassifierFreeGuidance
|
||||
from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
|
||||
from .frequency_decoupled_guidance import FrequencyDecoupledGuidance
|
||||
from .guider_utils import BaseGuidance
|
||||
from .perturbed_attention_guidance import PerturbedAttentionGuidance
|
||||
from .skip_layer_guidance import SkipLayerGuidance
|
||||
from .smoothed_energy_guidance import SmoothedEnergyGuidance
|
||||
from .tangential_classifier_free_guidance import TangentialClassifierFreeGuidance
|
||||
|
||||
GuiderType = Union[
|
||||
AdaptiveProjectedGuidance,
|
||||
AutoGuidance,
|
||||
ClassifierFreeGuidance,
|
||||
ClassifierFreeZeroStarGuidance,
|
||||
FrequencyDecoupledGuidance,
|
||||
PerturbedAttentionGuidance,
|
||||
SkipLayerGuidance,
|
||||
SmoothedEnergyGuidance,
|
||||
TangentialClassifierFreeGuidance,
|
||||
]
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -65,8 +65,9 @@ class AdaptiveProjectedGuidance(BaseGuidance):
|
||||
use_original_formulation: bool = False,
|
||||
start: float = 0.0,
|
||||
stop: float = 1.0,
|
||||
enabled: bool = True,
|
||||
):
|
||||
super().__init__(start, stop)
|
||||
super().__init__(start, stop, enabled)
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
|
||||
@@ -76,19 +77,14 @@ class AdaptiveProjectedGuidance(BaseGuidance):
|
||||
self.use_original_formulation = use_original_formulation
|
||||
self.momentum_buffer = None
|
||||
|
||||
def prepare_inputs(
|
||||
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
|
||||
) -> List["BlockState"]:
|
||||
if input_fields is None:
|
||||
input_fields = self._input_fields
|
||||
|
||||
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
|
||||
if self._step == 0:
|
||||
if self.adaptive_projected_guidance_momentum is not None:
|
||||
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for i in range(self.num_conditions):
|
||||
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
@@ -152,6 +148,44 @@ class MomentumBuffer:
|
||||
new_average = self.momentum * self.running_average
|
||||
self.running_average = update_value + new_average
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""
|
||||
Returns a string representation showing momentum, shape, statistics, and a slice of the running_average.
|
||||
"""
|
||||
if isinstance(self.running_average, torch.Tensor):
|
||||
shape = tuple(self.running_average.shape)
|
||||
|
||||
# Calculate statistics
|
||||
with torch.no_grad():
|
||||
stats = {
|
||||
"mean": self.running_average.mean().item(),
|
||||
"std": self.running_average.std().item(),
|
||||
"min": self.running_average.min().item(),
|
||||
"max": self.running_average.max().item(),
|
||||
}
|
||||
|
||||
# Get a slice (max 3 elements per dimension)
|
||||
slice_indices = tuple(slice(None, min(3, dim)) for dim in shape)
|
||||
sliced_data = self.running_average[slice_indices]
|
||||
|
||||
# Format the slice for display (convert to float32 for numpy compatibility with bfloat16)
|
||||
slice_str = str(sliced_data.detach().float().cpu().numpy())
|
||||
if len(slice_str) > 200: # Truncate if too long
|
||||
slice_str = slice_str[:200] + "..."
|
||||
|
||||
stats_str = ", ".join([f"{k}={v:.4f}" for k, v in stats.items()])
|
||||
|
||||
return (
|
||||
f"MomentumBuffer(\n"
|
||||
f" momentum={self.momentum},\n"
|
||||
f" shape={shape},\n"
|
||||
f" stats=[{stats_str}],\n"
|
||||
f" slice={slice_str}\n"
|
||||
f")"
|
||||
)
|
||||
else:
|
||||
return f"MomentumBuffer(momentum={self.momentum}, running_average={self.running_average})"
|
||||
|
||||
|
||||
def normalized_guidance(
|
||||
pred_cond: torch.Tensor,
|
||||
|
||||
284 src/diffusers/guiders/adaptive_projected_guidance_mix.py Normal file
@@ -0,0 +1,284 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import register_to_config
|
||||
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..modular_pipelines.modular_pipeline import BlockState
|
||||
|
||||
|
||||
class AdaptiveProjectedMixGuidance(BaseGuidance):
|
||||
"""
|
||||
Adaptive Projected Guidance (APG) https://huggingface.co/papers/2410.02416 combined with Classifier-Free Guidance
|
||||
(CFG). This guider is used in HunyuanImage2.1 https://github.com/Tencent-Hunyuan/HunyuanImage-2.1
|
||||
|
||||
Args:
|
||||
guidance_scale (`float`, defaults to `3.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
adaptive_projected_guidance_momentum (`float`, defaults to `-0.5`):
The momentum parameter for the adaptive projected guidance. Disabled if set to `None`.
adaptive_projected_guidance_rescale (`float`, defaults to `10.0`):
The rescale factor applied to the noise predictions for adaptive projected guidance. This is used to
improve image quality and fix overexposure.
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions for classifier-free guidance. This is used to improve
image quality and fix overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample
Steps are Flawed](https://huggingface.co/papers/2305.08891).
|
||||
use_original_formulation (`bool`, defaults to `False`):
|
||||
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
||||
we use the diffusers-native implementation that has been in the codebase for a long time. See
|
||||
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
||||
start (`float`, defaults to `0.0`):
|
||||
The fraction of the total number of denoising steps after which the classifier-free guidance starts.
|
||||
stop (`float`, defaults to `1.0`):
|
||||
The fraction of the total number of denoising steps after which the classifier-free guidance stops.
|
||||
adaptive_projected_guidance_start_step (`int`, defaults to `5`):
|
||||
The step at which the adaptive projected guidance starts (before this step, classifier-free guidance is
|
||||
used, and momentum buffer is updated).
|
||||
enabled (`bool`, defaults to `True`):
|
||||
Whether this guidance is enabled.
|
||||
"""
|
||||
|
||||
_input_predictions = ["pred_cond", "pred_uncond"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
guidance_scale: float = 3.5,
|
||||
guidance_rescale: float = 0.0,
|
||||
adaptive_projected_guidance_scale: float = 10.0,
|
||||
adaptive_projected_guidance_momentum: float = -0.5,
|
||||
adaptive_projected_guidance_rescale: float = 10.0,
|
||||
eta: float = 0.0,
|
||||
use_original_formulation: bool = False,
|
||||
start: float = 0.0,
|
||||
stop: float = 1.0,
|
||||
adaptive_projected_guidance_start_step: int = 5,
|
||||
enabled: bool = True,
|
||||
):
|
||||
super().__init__(start, stop, enabled)
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.guidance_rescale = guidance_rescale
|
||||
self.adaptive_projected_guidance_scale = adaptive_projected_guidance_scale
|
||||
self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
|
||||
self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
|
||||
self.eta = eta
|
||||
self.adaptive_projected_guidance_start_step = adaptive_projected_guidance_start_step
|
||||
self.use_original_formulation = use_original_formulation
|
||||
self.momentum_buffer = None
|
||||
|
||||
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
|
||||
if self._step == 0:
|
||||
if self.adaptive_projected_guidance_momentum is not None:
|
||||
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
|
||||
pred = None
|
||||
|
||||
# no guidance
|
||||
if not self._is_cfg_enabled():
|
||||
pred = pred_cond
|
||||
|
||||
# CFG + update momentum buffer
|
||||
elif not self._is_apg_enabled():
|
||||
if self.momentum_buffer is not None:
|
||||
update_momentum_buffer(pred_cond, pred_uncond, self.momentum_buffer)
|
||||
# CFG + update momentum buffer
|
||||
shift = pred_cond - pred_uncond
|
||||
pred = pred_cond if self.use_original_formulation else pred_uncond
|
||||
pred = pred + self.guidance_scale * shift
|
||||
|
||||
# APG
|
||||
elif self._is_apg_enabled():
|
||||
pred = normalized_guidance(
|
||||
pred_cond,
|
||||
pred_uncond,
|
||||
self.adaptive_projected_guidance_scale,
|
||||
self.momentum_buffer,
|
||||
self.eta,
|
||||
self.adaptive_projected_guidance_rescale,
|
||||
self.use_original_formulation,
|
||||
)
|
||||
|
||||
if self.guidance_rescale > 0.0:
|
||||
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
||||
|
||||
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
|
||||
|
||||
@property
|
||||
def is_conditional(self) -> bool:
|
||||
return self._count_prepared == 1
|
||||
|
||||
@property
|
||||
def num_conditions(self) -> int:
|
||||
num_conditions = 1
|
||||
if self._is_apg_enabled() or self._is_cfg_enabled():
|
||||
num_conditions += 1
|
||||
return num_conditions
|
||||
|
||||
# Copied from diffusers.guiders.classifier_free_guidance.ClassifierFreeGuidance._is_cfg_enabled
|
||||
def _is_cfg_enabled(self) -> bool:
|
||||
if not self._enabled:
|
||||
return False
|
||||
|
||||
is_within_range = True
|
||||
if self._num_inference_steps is not None:
|
||||
skip_start_step = int(self._start * self._num_inference_steps)
|
||||
skip_stop_step = int(self._stop * self._num_inference_steps)
|
||||
is_within_range = skip_start_step <= self._step < skip_stop_step
|
||||
|
||||
is_close = False
|
||||
if self.use_original_formulation:
|
||||
is_close = math.isclose(self.guidance_scale, 0.0)
|
||||
else:
|
||||
is_close = math.isclose(self.guidance_scale, 1.0)
|
||||
|
||||
return is_within_range and not is_close
|
||||
|
||||
def _is_apg_enabled(self) -> bool:
|
||||
if not self._enabled:
|
||||
return False
|
||||
|
||||
if not self._is_cfg_enabled():
|
||||
return False
|
||||
|
||||
is_within_range = False
|
||||
if self._step is not None:
|
||||
is_within_range = self._step > self.adaptive_projected_guidance_start_step
|
||||
|
||||
is_close = False
|
||||
if self.use_original_formulation:
|
||||
is_close = math.isclose(self.adaptive_projected_guidance_scale, 0.0)
|
||||
else:
|
||||
is_close = math.isclose(self.adaptive_projected_guidance_scale, 1.0)
|
||||
|
||||
return is_within_range and not is_close
|
||||
|
||||
def get_state(self):
|
||||
state = super().get_state()
|
||||
state["momentum_buffer"] = self.momentum_buffer
|
||||
state["is_apg_enabled"] = self._is_apg_enabled()
|
||||
state["is_cfg_enabled"] = self._is_cfg_enabled()
|
||||
return state
|
||||
|
||||
|
||||
# Copied from diffusers.guiders.adaptive_projected_guidance.MomentumBuffer
|
||||
class MomentumBuffer:
|
||||
def __init__(self, momentum: float):
|
||||
self.momentum = momentum
|
||||
self.running_average = 0
|
||||
|
||||
def update(self, update_value: torch.Tensor):
|
||||
new_average = self.momentum * self.running_average
|
||||
self.running_average = update_value + new_average
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""
|
||||
Returns a string representation showing momentum, shape, statistics, and a slice of the running_average.
|
||||
"""
|
||||
if isinstance(self.running_average, torch.Tensor):
|
||||
shape = tuple(self.running_average.shape)
|
||||
|
||||
# Calculate statistics
|
||||
with torch.no_grad():
|
||||
stats = {
|
||||
"mean": self.running_average.mean().item(),
|
||||
"std": self.running_average.std().item(),
|
||||
"min": self.running_average.min().item(),
|
||||
"max": self.running_average.max().item(),
|
||||
}
|
||||
|
||||
# Get a slice (max 3 elements per dimension)
|
||||
slice_indices = tuple(slice(None, min(3, dim)) for dim in shape)
|
||||
sliced_data = self.running_average[slice_indices]
|
||||
|
||||
# Format the slice for display (convert to float32 for numpy compatibility with bfloat16)
|
||||
slice_str = str(sliced_data.detach().float().cpu().numpy())
|
||||
if len(slice_str) > 200: # Truncate if too long
|
||||
slice_str = slice_str[:200] + "..."
|
||||
|
||||
stats_str = ", ".join([f"{k}={v:.4f}" for k, v in stats.items()])
|
||||
|
||||
return (
|
||||
f"MomentumBuffer(\n"
|
||||
f" momentum={self.momentum},\n"
|
||||
f" shape={shape},\n"
|
||||
f" stats=[{stats_str}],\n"
|
||||
f" slice={slice_str}\n"
|
||||
f")"
|
||||
)
|
||||
else:
|
||||
return f"MomentumBuffer(momentum={self.momentum}, running_average={self.running_average})"
|
||||
|
||||
|
||||
def update_momentum_buffer(
|
||||
pred_cond: torch.Tensor,
|
||||
pred_uncond: torch.Tensor,
|
||||
momentum_buffer: Optional[MomentumBuffer] = None,
|
||||
):
|
||||
diff = pred_cond - pred_uncond
|
||||
if momentum_buffer is not None:
|
||||
momentum_buffer.update(diff)
|
||||
|
||||
|
||||
def normalized_guidance(
|
||||
pred_cond: torch.Tensor,
|
||||
pred_uncond: torch.Tensor,
|
||||
guidance_scale: float,
|
||||
momentum_buffer: Optional[MomentumBuffer] = None,
|
||||
eta: float = 1.0,
|
||||
norm_threshold: float = 0.0,
|
||||
use_original_formulation: bool = False,
|
||||
):
|
||||
if momentum_buffer is not None:
|
||||
update_momentum_buffer(pred_cond, pred_uncond, momentum_buffer)
|
||||
diff = momentum_buffer.running_average
|
||||
else:
|
||||
diff = pred_cond - pred_uncond
|
||||
|
||||
dim = [-i for i in range(1, len(diff.shape))]
|
||||
|
||||
if norm_threshold > 0:
|
||||
ones = torch.ones_like(diff)
|
||||
diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
|
||||
scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
|
||||
diff = diff * scale_factor
|
||||
|
||||
v0, v1 = diff.double(), pred_cond.double()
|
||||
v1 = torch.nn.functional.normalize(v1, dim=dim)
|
||||
v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
|
||||
v0_orthogonal = v0 - v0_parallel
|
||||
diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
|
||||
normalized_update = diff_orthogonal + eta * diff_parallel
|
||||
|
||||
pred = pred_cond if use_original_formulation else pred_uncond
|
||||
pred = pred + guidance_scale * normalized_update
|
||||
|
||||
return pred
|
||||
@@ -72,8 +72,9 @@ class AutoGuidance(BaseGuidance):
|
||||
use_original_formulation: bool = False,
|
||||
start: float = 0.0,
|
||||
stop: float = 1.0,
|
||||
enabled: bool = True,
|
||||
):
|
||||
super().__init__(start, stop)
|
||||
super().__init__(start, stop, enabled)
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.auto_guidance_layers = auto_guidance_layers
|
||||
@@ -132,16 +133,11 @@ class AutoGuidance(BaseGuidance):
|
||||
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
|
||||
registry.remove_hook(name, recurse=True)
|
||||
|
||||
def prepare_inputs(
|
||||
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
|
||||
) -> List["BlockState"]:
|
||||
if input_fields is None:
|
||||
input_fields = self._input_fields
|
||||
|
||||
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for i in range(self.num_conditions):
|
||||
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -27,43 +27,50 @@ if TYPE_CHECKING:
|
||||
|
||||
class ClassifierFreeGuidance(BaseGuidance):
|
||||
"""
|
||||
Classifier-free guidance (CFG): https://huggingface.co/papers/2207.12598
|
||||
Implements Classifier-Free Guidance (CFG) for diffusion models.
|
||||
|
||||
CFG is a technique used to improve generation quality and condition-following in diffusion models. It works by
|
||||
jointly training a model on both conditional and unconditional data, and using a weighted sum of the two during
|
||||
inference. This allows the model to tradeoff between generation quality and sample diversity. The original paper
|
||||
proposes scaling and shifting the conditional distribution based on the difference between conditional and
|
||||
unconditional predictions. [x_pred = x_cond + scale * (x_cond - x_uncond)]
|
||||
Reference: https://huggingface.co/papers/2207.12598
|
||||
|
||||
Diffusers implemented the scaling and shifting on the unconditional prediction instead based on the [Imagen
|
||||
paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original paper proposed in
|
||||
theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)]
|
||||
CFG improves generation quality and prompt adherence by jointly training models on both conditional and
|
||||
unconditional data, then combining predictions during inference. This allows trading off between quality (high
|
||||
guidance) and diversity (low guidance).
|
||||
|
||||
The intuition behind the original formulation can be thought of as moving the conditional distribution estimates
further away from the unconditional distribution estimates, while the diffusers-native implementation can be
thought of as moving the unconditional distribution towards the conditional distribution estimates to get rid of
the unconditional predictions (usually negative features like "bad quality, bad anatomy, watermarks", etc.)
|
||||
**Two CFG Formulations:**
|
||||
|
||||
The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the
|
||||
paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time.
|
||||
1. **Original formulation** (from paper):
|
||||
```
|
||||
x_pred = x_cond + guidance_scale * (x_cond - x_uncond)
|
||||
```
|
||||
Moves conditional predictions further from unconditional ones.
|
||||
|
||||
2. **Diffusers-native formulation** (default, from Imagen paper):
|
||||
```
|
||||
x_pred = x_uncond + guidance_scale * (x_cond - x_uncond)
|
||||
```
|
||||
Moves unconditional predictions toward conditional ones, effectively suppressing negative features (e.g., "bad
|
||||
quality", "watermarks"). Equivalent in theory but more intuitive.
|
||||
|
||||
Use `use_original_formulation=True` to switch to the original formulation.
|
||||
|
||||
Args:
|
||||
guidance_scale (`float`, defaults to `7.5`):
|
||||
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
|
||||
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
|
||||
deterioration of image quality.
|
||||
CFG scale applied by this guider during post-processing. Higher values = stronger prompt conditioning but
|
||||
may reduce quality. Typical range: 1.0-20.0.
|
||||
guidance_rescale (`float`, defaults to `0.0`):
|
||||
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
||||
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891).
|
||||
Rescaling factor to prevent overexposure from high guidance scales. Based on [Common Diffusion Noise
|
||||
Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). Range: 0.0 (no rescaling)
|
||||
to 1.0 (full rescaling).
|
||||
use_original_formulation (`bool`, defaults to `False`):
|
||||
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
||||
we use the diffusers-native implementation that has been in the codebase for a long time. See
|
||||
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
||||
If `True`, uses the original CFG formulation from the paper. If `False` (default), uses the
|
||||
diffusers-native formulation from the Imagen paper.
|
||||
start (`float`, defaults to `0.0`):
|
||||
The fraction of the total number of denoising steps after which guidance starts.
|
||||
Fraction of denoising steps (0.0-1.0) after which CFG starts. Use > 0.0 to disable CFG in early denoising
|
||||
steps.
|
||||
stop (`float`, defaults to `1.0`):
|
||||
The fraction of the total number of denoising steps after which guidance stops.
|
||||
Fraction of denoising steps (0.0-1.0) after which CFG stops. Use < 1.0 to disable CFG in late denoising
|
||||
steps.
|
||||
enabled (`bool`, defaults to `True`):
|
||||
Whether CFG is enabled. Set to `False` to disable CFG entirely (uses only conditional predictions).
|
||||
"""
|
||||
|
||||
_input_predictions = ["pred_cond", "pred_uncond"]
|
||||
@@ -76,23 +83,19 @@ class ClassifierFreeGuidance(BaseGuidance):
|
||||
use_original_formulation: bool = False,
|
||||
start: float = 0.0,
|
||||
stop: float = 1.0,
|
||||
enabled: bool = True,
|
||||
):
|
||||
super().__init__(start, stop)
|
||||
super().__init__(start, stop, enabled)
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.guidance_rescale = guidance_rescale
|
||||
self.use_original_formulation = use_original_formulation
|
||||
|
||||
def prepare_inputs(
|
||||
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
|
||||
) -> List["BlockState"]:
|
||||
if input_fields is None:
|
||||
input_fields = self._input_fields
|
||||
|
||||
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for i in range(self.num_conditions):
|
||||
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -68,31 +68,31 @@ class ClassifierFreeZeroStarGuidance(BaseGuidance):
|
||||
use_original_formulation: bool = False,
|
||||
start: float = 0.0,
|
||||
stop: float = 1.0,
|
||||
enabled: bool = True,
|
||||
):
|
||||
super().__init__(start, stop)
|
||||
super().__init__(start, stop, enabled)
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.zero_init_steps = zero_init_steps
|
||||
self.guidance_rescale = guidance_rescale
|
||||
self.use_original_formulation = use_original_formulation
|
||||
|
||||
def prepare_inputs(
|
||||
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
|
||||
) -> List["BlockState"]:
|
||||
if input_fields is None:
|
||||
input_fields = self._input_fields
|
||||
|
||||
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for i in range(self.num_conditions):
|
||||
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
|
||||
pred = None
|
||||
|
||||
if self._step < self.zero_init_steps:
|
||||
# YiYi Notes: add default behavior for self._enabled == False
|
||||
if not self._enabled:
|
||||
pred = pred_cond
|
||||
|
||||
elif self._step < self.zero_init_steps:
|
||||
pred = torch.zeros_like(pred_cond)
|
||||
elif not self._is_cfg_enabled():
|
||||
pred = pred_cond
|
||||
|
||||
@@ -149,6 +149,7 @@ class FrequencyDecoupledGuidance(BaseGuidance):
|
||||
stop: Union[float, List[float], Tuple[float]] = 1.0,
|
||||
guidance_rescale_space: str = "data",
|
||||
upcast_to_double: bool = True,
|
||||
enabled: bool = True,
|
||||
):
|
||||
if not _CAN_USE_KORNIA:
|
||||
raise ImportError(
|
||||
@@ -160,7 +161,7 @@ class FrequencyDecoupledGuidance(BaseGuidance):
|
||||
# Set start to earliest start for any freq component and stop to latest stop for any freq component
|
||||
min_start = start if isinstance(start, float) else min(start)
|
||||
max_stop = stop if isinstance(stop, float) else max(stop)
|
||||
super().__init__(min_start, max_stop)
|
||||
super().__init__(min_start, max_stop, enabled)
|
||||
|
||||
self.guidance_scales = guidance_scales
|
||||
self.levels = len(guidance_scales)
|
||||
@@ -217,16 +218,11 @@ class FrequencyDecoupledGuidance(BaseGuidance):
|
||||
f"({len(self.guidance_scales)})"
|
||||
)
|
||||
|
||||
def prepare_inputs(
|
||||
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
|
||||
) -> List["BlockState"]:
|
||||
if input_fields is None:
|
||||
input_fields = self._input_fields
|
||||
|
||||
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for i in range(self.num_conditions):
|
||||
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
|
||||
_input_predictions = None
|
||||
_identifier_key = "__guidance_identifier__"
|
||||
|
||||
def __init__(self, start: float = 0.0, stop: float = 1.0):
|
||||
def __init__(self, start: float = 0.0, stop: float = 1.0, enabled: bool = True):
|
||||
self._start = start
|
||||
self._stop = stop
|
||||
self._step: int = None
|
||||
@@ -48,7 +48,7 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
|
||||
self._timestep: torch.LongTensor = None
|
||||
self._count_prepared = 0
|
||||
self._input_fields: Dict[str, Union[str, Tuple[str, str]]] = None
|
||||
self._enabled = True
|
||||
self._enabled = enabled
|
||||
|
||||
if not (0.0 <= start < 1.0):
|
||||
raise ValueError(f"Expected `start` to be between 0.0 and 1.0, but got {start}.")
|
||||
@@ -60,6 +60,31 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
|
||||
"`_input_predictions` must be a list of required prediction names for the guidance technique."
|
||||
)
|
||||
|
||||
def new(self, **kwargs):
|
||||
"""
|
||||
Creates a copy of this guider instance, optionally with modified configuration parameters.
|
||||
|
||||
Args:
|
||||
**kwargs: Configuration parameters to override in the new instance. If no kwargs are provided,
|
||||
returns an exact copy with the same configuration.
|
||||
|
||||
Returns:
|
||||
A new guider instance with the same (or updated) configuration.
|
||||
|
||||
Example:
|
||||
```python
|
||||
# Create a CFG guider
|
||||
guider = ClassifierFreeGuidance(guidance_scale=3.5)
|
||||
|
||||
# Create an exact copy
|
||||
same_guider = guider.new()
|
||||
|
||||
# Create a copy with different start step, keeping other config the same
|
||||
new_guider = guider.new(guidance_scale=5)
|
||||
```
|
||||
"""
|
||||
return self.__class__.from_config(self.config, **kwargs)
|
||||
|
||||
def disable(self):
|
||||
self._enabled = False
|
||||
|
||||
@@ -72,42 +97,52 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
|
||||
self._timestep = timestep
|
||||
self._count_prepared = 0
|
||||
|
||||
def set_input_fields(self, **kwargs: Dict[str, Union[str, Tuple[str, str]]]) -> None:
|
||||
def get_state(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Set the input fields for the guidance technique. The input fields are used to specify the names of the returned
|
||||
attributes containing the prepared data after `prepare_inputs` is called. The prepared data is obtained from
|
||||
the values of the provided keyword arguments to this method.
|
||||
|
||||
Args:
|
||||
**kwargs (`Dict[str, Union[str, Tuple[str, str]]]`):
|
||||
A dictionary where the keys are the names of the fields that will be used to store the data once it is
|
||||
prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, which is used
|
||||
to look up the required data provided for preparation.
|
||||
|
||||
If a string is provided, it will be used as the conditional data (or unconditional if used with a
|
||||
guidance method that requires it). If a tuple of length 2 is provided, the first element must be the
|
||||
conditional data identifier and the second element must be the unconditional data identifier or None.
|
||||
|
||||
Example:
|
||||
```
|
||||
data = {"prompt_embeds": <some tensor>, "negative_prompt_embeds": <some tensor>, "latents": <some tensor>}
|
||||
|
||||
BaseGuidance.set_input_fields(
|
||||
latents="latents",
|
||||
prompt_embeds=("prompt_embeds", "negative_prompt_embeds"),
|
||||
)
|
||||
```
|
||||
Returns the current state of the guidance technique as a dictionary. The state variables will be included in
|
||||
the __repr__ method. Returns:
|
||||
`Dict[str, Any]`: A dictionary containing the current state variables including:
|
||||
- step: Current inference step
|
||||
- num_inference_steps: Total number of inference steps
|
||||
- timestep: Current timestep tensor
|
||||
- count_prepared: Number of times prepare_models has been called
|
||||
- enabled: Whether the guidance is enabled
|
||||
- num_conditions: Number of conditions
|
||||
"""
|
||||
for key, value in kwargs.items():
|
||||
is_string = isinstance(value, str)
|
||||
is_tuple_of_str_with_len_2 = (
|
||||
isinstance(value, tuple) and len(value) == 2 and all(isinstance(v, str) for v in value)
|
||||
)
|
||||
if not (is_string or is_tuple_of_str_with_len_2):
|
||||
raise ValueError(
|
||||
f"Expected `set_input_fields` to be called with a string or a tuple of string with length 2, but got {type(value)} for key {key}."
|
||||
)
|
||||
self._input_fields = kwargs
|
||||
state = {
|
||||
"step": self._step,
|
||||
"num_inference_steps": self._num_inference_steps,
|
||||
"timestep": self._timestep,
|
||||
"count_prepared": self._count_prepared,
|
||||
"enabled": self._enabled,
|
||||
"num_conditions": self.num_conditions,
|
||||
}
|
||||
return state
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""
|
||||
Returns a string representation of the guidance object including both config and current state.
|
||||
"""
|
||||
# Get ConfigMixin's __repr__
|
||||
str_repr = super().__repr__()
|
||||
|
||||
# Get current state
|
||||
state = self.get_state()
|
||||
|
||||
# Format each state variable on its own line with indentation
|
||||
state_lines = []
|
||||
for k, v in state.items():
|
||||
# Convert value to string and handle multi-line values
|
||||
v_str = str(v)
|
||||
if "\n" in v_str:
|
||||
# For multi-line values (like MomentumBuffer), indent subsequent lines
|
||||
v_lines = v_str.split("\n")
|
||||
v_str = v_lines[0] + "\n" + "\n".join([" " + line for line in v_lines[1:]])
|
||||
state_lines.append(f" {k}: {v_str}")
|
||||
|
||||
state_str = "\n".join(state_lines)
|
||||
|
||||
return f"{str_repr}\nState:\n{state_str}"
|
||||
|
||||
def prepare_models(self, denoiser: torch.nn.Module) -> None:
|
||||
"""
|
||||
@@ -155,8 +190,7 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
|
||||
@classmethod
|
||||
def _prepare_batch(
|
||||
cls,
|
||||
input_fields: Dict[str, Union[str, Tuple[str, str]]],
|
||||
data: "BlockState",
|
||||
data: Dict[str, Tuple[torch.Tensor, torch.Tensor]],
|
||||
tuple_index: int,
|
||||
identifier: str,
|
||||
) -> "BlockState":
|
||||
@@ -182,21 +216,16 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
|
||||
"""
|
||||
from ..modular_pipelines.modular_pipeline import BlockState
|
||||
|
||||
if input_fields is None:
|
||||
raise ValueError(
|
||||
"Input fields cannot be None. Please pass `input_fields` to `prepare_inputs` or call `set_input_fields` before preparing inputs."
|
||||
)
|
||||
data_batch = {}
|
||||
for key, value in input_fields.items():
|
||||
for key, value in data.items():
|
||||
try:
|
||||
if isinstance(value, str):
|
||||
data_batch[key] = getattr(data, value)
|
||||
if isinstance(value, torch.Tensor):
|
||||
data_batch[key] = value
|
||||
elif isinstance(value, tuple):
|
||||
data_batch[key] = getattr(data, value[tuple_index])
|
||||
data_batch[key] = value[tuple_index]
|
||||
else:
|
||||
# We've already checked that value is a string or a tuple of strings with length 2
|
||||
pass
|
||||
except AttributeError:
|
||||
raise ValueError(f"Invalid value type: {type(value)}")
|
||||
except ValueError:
|
||||
logger.debug(f"`data` does not have attribute(s) {value}, skipping.")
|
||||
data_batch[cls._identifier_key] = identifier
|
||||
return BlockState(**data_batch)
|
||||
|
||||
@@ -98,8 +98,9 @@ class PerturbedAttentionGuidance(BaseGuidance):
|
||||
use_original_formulation: bool = False,
|
||||
start: float = 0.0,
|
||||
stop: float = 1.0,
|
||||
enabled: bool = True,
|
||||
):
|
||||
super().__init__(start, stop)
|
||||
super().__init__(start, stop, enabled)
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.skip_layer_guidance_scale = perturbed_guidance_scale
|
||||
@@ -168,12 +169,7 @@ class PerturbedAttentionGuidance(BaseGuidance):
|
||||
registry.remove_hook(hook_name, recurse=True)
|
||||
|
||||
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_inputs
|
||||
def prepare_inputs(
|
||||
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
|
||||
) -> List["BlockState"]:
|
||||
if input_fields is None:
|
||||
input_fields = self._input_fields
|
||||
|
||||
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
|
||||
if self.num_conditions == 1:
|
||||
tuple_indices = [0]
|
||||
input_predictions = ["pred_cond"]
|
||||
@@ -186,8 +182,8 @@ class PerturbedAttentionGuidance(BaseGuidance):
|
||||
tuple_indices = [0, 1, 0]
|
||||
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
|
||||
data_batches = []
|
||||
for i in range(self.num_conditions):
|
||||
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
|
||||
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
|
||||
@@ -100,8 +100,9 @@ class SkipLayerGuidance(BaseGuidance):
|
||||
use_original_formulation: bool = False,
|
||||
start: float = 0.0,
|
||||
stop: float = 1.0,
|
||||
enabled: bool = True,
|
||||
):
|
||||
super().__init__(start, stop)
|
||||
super().__init__(start, stop, enabled)
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.skip_layer_guidance_scale = skip_layer_guidance_scale
|
||||
@@ -164,12 +165,7 @@ class SkipLayerGuidance(BaseGuidance):
|
||||
for hook_name in self._skip_layer_hook_names:
|
||||
registry.remove_hook(hook_name, recurse=True)
|
||||
|
||||
def prepare_inputs(
|
||||
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
|
||||
) -> List["BlockState"]:
|
||||
if input_fields is None:
|
||||
input_fields = self._input_fields
|
||||
|
||||
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
|
||||
if self.num_conditions == 1:
|
||||
tuple_indices = [0]
|
||||
input_predictions = ["pred_cond"]
|
||||
@@ -182,8 +178,8 @@ class SkipLayerGuidance(BaseGuidance):
|
||||
tuple_indices = [0, 1, 0]
|
||||
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
|
||||
data_batches = []
|
||||
for i in range(self.num_conditions):
|
||||
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
|
||||
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
|
||||
@@ -92,8 +92,9 @@ class SmoothedEnergyGuidance(BaseGuidance):
|
||||
use_original_formulation: bool = False,
|
||||
start: float = 0.0,
|
||||
stop: float = 1.0,
|
||||
enabled: bool = True,
|
||||
):
|
||||
super().__init__(start, stop)
|
||||
super().__init__(start, stop, enabled)
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.seg_guidance_scale = seg_guidance_scale
|
||||
@@ -153,12 +154,7 @@ class SmoothedEnergyGuidance(BaseGuidance):
|
||||
for hook_name in self._seg_layer_hook_names:
|
||||
registry.remove_hook(hook_name, recurse=True)
|
||||
|
||||
def prepare_inputs(
|
||||
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
|
||||
) -> List["BlockState"]:
|
||||
if input_fields is None:
|
||||
input_fields = self._input_fields
|
||||
|
||||
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
|
||||
if self.num_conditions == 1:
|
||||
tuple_indices = [0]
|
||||
input_predictions = ["pred_cond"]
|
||||
@@ -171,8 +167,8 @@ class SmoothedEnergyGuidance(BaseGuidance):
|
||||
tuple_indices = [0, 1, 0]
|
||||
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
|
||||
data_batches = []
|
||||
for i in range(self.num_conditions):
|
||||
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
|
||||
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -58,23 +58,19 @@ class TangentialClassifierFreeGuidance(BaseGuidance):
|
||||
use_original_formulation: bool = False,
|
||||
start: float = 0.0,
|
||||
stop: float = 1.0,
|
||||
enabled: bool = True,
|
||||
):
|
||||
super().__init__(start, stop)
|
||||
super().__init__(start, stop, enabled)
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.guidance_rescale = guidance_rescale
|
||||
self.use_original_formulation = use_original_formulation
|
||||
|
||||
def prepare_inputs(
|
||||
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
|
||||
) -> List["BlockState"]:
|
||||
if input_fields is None:
|
||||
input_fields = self._input_fields
|
||||
|
||||
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for i in range(self.num_conditions):
|
||||
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
|
||||
@@ -108,6 +108,7 @@ def _register_attention_processors_metadata():
|
||||
from ..models.attention_processor import AttnProcessor2_0
|
||||
from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor
|
||||
from ..models.transformers.transformer_flux import FluxAttnProcessor
|
||||
from ..models.transformers.transformer_hunyuanimage import HunyuanImageAttnProcessor
|
||||
from ..models.transformers.transformer_qwenimage import QwenDoubleStreamAttnProcessor2_0
|
||||
from ..models.transformers.transformer_wan import WanAttnProcessor2_0
|
||||
|
||||
@@ -149,6 +150,14 @@ def _register_attention_processors_metadata():
|
||||
),
|
||||
)
|
||||
|
||||
# HunyuanImageAttnProcessor
|
||||
AttentionProcessorRegistry.register(
|
||||
model_class=HunyuanImageAttnProcessor,
|
||||
metadata=AttentionProcessorMetadata(
|
||||
skip_processor_output_fn=_skip_proc_output_fn_Attention_HunyuanImageAttnProcessor,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _register_transformer_blocks_metadata():
|
||||
from ..models.attention import BasicTransformerBlock
|
||||
@@ -162,6 +171,10 @@ def _register_transformer_blocks_metadata():
|
||||
HunyuanVideoTokenReplaceTransformerBlock,
|
||||
HunyuanVideoTransformerBlock,
|
||||
)
|
||||
from ..models.transformers.transformer_hunyuanimage import (
|
||||
HunyuanImageSingleTransformerBlock,
|
||||
HunyuanImageTransformerBlock,
|
||||
)
|
||||
from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
|
||||
from ..models.transformers.transformer_mochi import MochiTransformerBlock
|
||||
from ..models.transformers.transformer_qwenimage import QwenImageTransformerBlock
|
||||
@@ -283,6 +296,22 @@ def _register_transformer_blocks_metadata():
|
||||
),
|
||||
)
|
||||
|
||||
# HunyuanImage2.1
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=HunyuanImageTransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=1,
|
||||
),
|
||||
)
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=HunyuanImageSingleTransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=1,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# fmt: off
|
||||
def _skip_attention___ret___hidden_states(self, *args, **kwargs):
|
||||
@@ -308,4 +337,5 @@ _skip_proc_output_fn_Attention_WanAttnProcessor2_0 = _skip_attention___ret___hid
|
||||
# not sure what this is yet.
|
||||
_skip_proc_output_fn_Attention_FluxAttnProcessor = _skip_attention___ret___hidden_states
|
||||
_skip_proc_output_fn_Attention_QwenDoubleStreamAttnProcessor2_0 = _skip_attention___ret___hidden_states
|
||||
_skip_proc_output_fn_Attention_HunyuanImageAttnProcessor = _skip_attention___ret___hidden_states
|
||||
# fmt: on
|
||||
|
||||
@@ -36,6 +36,8 @@ if is_torch_available():
|
||||
_import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
|
||||
_import_structure["autoencoders.autoencoder_kl_cosmos"] = ["AutoencoderKLCosmos"]
|
||||
_import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = ["AutoencoderKLHunyuanVideo"]
|
||||
_import_structure["autoencoders.autoencoder_kl_hunyuanimage"] = ["AutoencoderKLHunyuanImage"]
|
||||
_import_structure["autoencoders.autoencoder_kl_hunyuanimage_refiner"] = ["AutoencoderKLHunyuanImageRefiner"]
|
||||
_import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
|
||||
_import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"]
|
||||
_import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
|
||||
@@ -91,6 +93,7 @@ if is_torch_available():
|
||||
_import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_hunyuanimage"] = ["HunyuanImageTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"]
|
||||
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
|
||||
@@ -133,6 +136,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderKLAllegro,
|
||||
AutoencoderKLCogVideoX,
|
||||
AutoencoderKLCosmos,
|
||||
AutoencoderKLHunyuanImage,
|
||||
AutoencoderKLHunyuanImageRefiner,
|
||||
AutoencoderKLHunyuanVideo,
|
||||
AutoencoderKLLTXVideo,
|
||||
AutoencoderKLMagvit,
|
||||
@@ -182,6 +187,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
FluxTransformer2DModel,
|
||||
HiDreamImageTransformer2DModel,
|
||||
HunyuanDiT2DModel,
|
||||
HunyuanImageTransformer2DModel,
|
||||
HunyuanVideoFramepackTransformer3DModel,
|
||||
HunyuanVideoTransformer3DModel,
|
||||
Kandinsky5Transformer3DModel,
|
||||
|
||||
@@ -5,6 +5,8 @@ from .autoencoder_kl_allegro import AutoencoderKLAllegro
|
||||
from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
|
||||
from .autoencoder_kl_cosmos import AutoencoderKLCosmos
|
||||
from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo
|
||||
from .autoencoder_kl_hunyuanimage import AutoencoderKLHunyuanImage
|
||||
from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner
|
||||
from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
|
||||
from .autoencoder_kl_magvit import AutoencoderKLMagvit
|
||||
from .autoencoder_kl_mochi import AutoencoderKLMochi
|
||||
|
||||
709
src/diffusers/models/autoencoders/autoencoder_kl_hunyuanimage.py
Normal file
@@ -0,0 +1,709 @@
|
||||
# Copyright 2025 The Hunyuan Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin
|
||||
from ...utils import logging
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..activations import get_activation
|
||||
from ..modeling_outputs import AutoencoderKLOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .vae import DecoderOutput, DiagonalGaussianDistribution
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class HunyuanImageResnetBlock(nn.Module):
|
||||
r"""
|
||||
Residual block with two convolutions and optional channel change.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
out_channels (int): Number of output channels.
|
||||
non_linearity (str, optional): Type of non-linearity to use. Default is "silu".
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels: int, out_channels: int, non_linearity: str = "silu") -> None:
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.nonlinearity = get_activation(non_linearity)
|
||||
|
||||
# layers
|
||||
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
|
||||
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
if in_channels != out_channels:
|
||||
self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
else:
|
||||
self.conv_shortcut = None
|
||||
|
||||
def forward(self, x):
|
||||
# Apply shortcut connection
|
||||
residual = x
|
||||
|
||||
# First normalization and activation
|
||||
x = self.norm1(x)
|
||||
x = self.nonlinearity(x)
|
||||
|
||||
x = self.conv1(x)
|
||||
x = self.norm2(x)
|
||||
x = self.nonlinearity(x)
|
||||
x = self.conv2(x)
|
||||
|
||||
if self.conv_shortcut is not None:
|
||||
residual = self.conv_shortcut(residual)
|
||||
# Add residual connection
|
||||
return x + residual
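As a quick standalone check of the shortcut projection above (illustrative channel sizes, not the model's configuration): the 1x1 convolution brings the residual path to the new channel count so the addition is well defined.

```python
import torch
import torch.nn as nn

in_channels, out_channels = 64, 128
x = torch.randn(1, in_channels, 32, 32)

body = nn.Sequential(  # stand-in for norm/act/conv1/norm/act/conv2
    nn.Conv2d(in_channels, out_channels, 3, padding=1),
    nn.SiLU(),
    nn.Conv2d(out_channels, out_channels, 3, padding=1),
)
conv_shortcut = nn.Conv2d(in_channels, out_channels, 1)

out = body(x) + conv_shortcut(x)  # residual projected to out_channels before the add
print(out.shape)  # torch.Size([1, 128, 32, 32])
```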
|
||||
|
||||
|
||||
class HunyuanImageAttentionBlock(nn.Module):
|
||||
r"""
|
||||
Self-attention with a single head.
|
||||
|
||||
Args:
|
||||
in_channels (int): The number of channels in the input tensor.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels: int):
|
||||
super().__init__()
|
||||
|
||||
# layers
|
||||
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
self.to_q = nn.Conv2d(in_channels, in_channels, 1)
|
||||
self.to_k = nn.Conv2d(in_channels, in_channels, 1)
|
||||
self.to_v = nn.Conv2d(in_channels, in_channels, 1)
|
||||
self.proj = nn.Conv2d(in_channels, in_channels, 1)
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
x = self.norm(x)
|
||||
|
||||
# compute query, key, value
|
||||
query = self.to_q(x)
|
||||
key = self.to_k(x)
|
||||
value = self.to_v(x)
|
||||
|
||||
batch_size, channels, height, width = query.shape
|
||||
query = query.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels).contiguous()
|
||||
key = key.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels).contiguous()
|
||||
value = value.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels).contiguous()
|
||||
|
||||
# apply attention
|
||||
x = F.scaled_dot_product_attention(query, key, value)
|
||||
|
||||
x = x.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)
|
||||
# output projection
|
||||
x = self.proj(x)
|
||||
|
||||
return x + identity
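A minimal standalone sketch of the flatten-attend-reshape pattern used above, with random weights and one shared 1x1 projection instead of the block's separate q/k/v convolutions; it is only a shape check, not the module itself.

```python
import torch
import torch.nn.functional as F

B, C, H, W = 1, 32, 16, 16
feat = torch.randn(B, C, H, W)
proj = torch.nn.Conv2d(C, C, 1)  # the real block uses separate to_q / to_k / to_v convs

q = proj(feat).permute(0, 2, 3, 1).reshape(B, H * W, C)
k = proj(feat).permute(0, 2, 3, 1).reshape(B, H * W, C)
v = proj(feat).permute(0, 2, 3, 1).reshape(B, H * W, C)

out = F.scaled_dot_product_attention(q, k, v)       # attention over the H*W token axis
out = out.reshape(B, H, W, C).permute(0, 3, 1, 2)   # back to (B, C, H, W)
print(out.shape)  # torch.Size([1, 32, 16, 16])
```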
|
||||
|
||||
|
||||
class HunyuanImageDownsample(nn.Module):
|
||||
"""
|
||||
Downsampling block for spatial reduction.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
out_channels (int): Number of output channels.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels: int, out_channels: int):
|
||||
super().__init__()
|
||||
factor = 4
|
||||
if out_channels % factor != 0:
|
||||
raise ValueError(f"out_channels % factor != 0: {out_channels % factor}")
|
||||
|
||||
self.conv = nn.Conv2d(in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1)
|
||||
self.group_size = factor * in_channels // out_channels
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
h = self.conv(x)
|
||||
|
||||
B, C, H, W = h.shape
|
||||
h = h.reshape(B, C, H // 2, 2, W // 2, 2)
|
||||
h = h.permute(0, 3, 5, 1, 2, 4) # b, r1, r2, c, h, w
|
||||
h = h.reshape(B, 4 * C, H // 2, W // 2)
|
||||
|
||||
B, C, H, W = x.shape
|
||||
shortcut = x.reshape(B, C, H // 2, 2, W // 2, 2)
|
||||
shortcut = shortcut.permute(0, 3, 5, 1, 2, 4) # b, r1, r2, c, h, w
|
||||
shortcut = shortcut.reshape(B, 4 * C, H // 2, W // 2)
|
||||
|
||||
B, C, H, W = shortcut.shape
|
||||
shortcut = shortcut.view(B, h.shape[1], self.group_size, H, W).mean(dim=2)
|
||||
return h + shortcut
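The reshape/permute pair above is a pixel-unshuffle-style rearrangement, ordered sub-pixel-first rather than channel-first, so it is not byte-for-byte identical to `F.pixel_unshuffle`. A self-contained sketch of the same index bookkeeping:

```python
import torch

B, C, H, W = 1, 3, 8, 8
x = torch.arange(B * C * H * W, dtype=torch.float32).reshape(B, C, H, W)

h = x.reshape(B, C, H // 2, 2, W // 2, 2)
h = h.permute(0, 3, 5, 1, 2, 4)          # (b, r1, r2, c, h, w)
h = h.reshape(B, 4 * C, H // 2, W // 2)

# channel block k = r1 * 2 + r2 holds the (r1, r2) sub-grid of every input channel
assert torch.equal(h[:, :C], x[:, :, 0::2, 0::2])        # r1 = 0, r2 = 0
assert torch.equal(h[:, C:2 * C], x[:, :, 0::2, 1::2])   # r1 = 0, r2 = 1
assert torch.equal(h[:, 3 * C:], x[:, :, 1::2, 1::2])    # r1 = 1, r2 = 1
print(h.shape)  # torch.Size([1, 12, 4, 4])
```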
|
||||
|
||||
|
||||
class HunyuanImageUpsample(nn.Module):
|
||||
"""
|
||||
Upsampling block for spatial expansion.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
out_channels (int): Number of output channels.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels: int, out_channels: int):
|
||||
super().__init__()
|
||||
factor = 4
|
||||
self.conv = nn.Conv2d(in_channels, out_channels * factor, kernel_size=3, stride=1, padding=1)
|
||||
self.repeats = factor * out_channels // in_channels
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
h = self.conv(x)
|
||||
|
||||
B, C, H, W = h.shape
|
||||
h = h.reshape(B, 2, 2, C // 4, H, W) # b, r1, r2, c, h, w
|
||||
h = h.permute(0, 3, 4, 1, 5, 2) # b, c, h, r1, w, r2
|
||||
h = h.reshape(B, C // 4, H * 2, W * 2)
|
||||
|
||||
shortcut = x.repeat_interleave(repeats=self.repeats, dim=1)
|
||||
|
||||
B, C, H, W = shortcut.shape
|
||||
shortcut = shortcut.reshape(B, 2, 2, C // 4, H, W) # b, r1, r2, c, h, w
|
||||
shortcut = shortcut.permute(0, 3, 4, 1, 5, 2) # b, c, h, r1, w, r2
|
||||
shortcut = shortcut.reshape(B, C // 4, H * 2, W * 2)
|
||||
return h + shortcut
|
||||
|
||||
|
||||
class HunyuanImageMidBlock(nn.Module):
|
||||
"""
|
||||
Middle block for HunyuanImageVAE encoder and decoder.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
num_layers (int): Number of layers.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels: int, num_layers: int = 1):
|
||||
super().__init__()
|
||||
|
||||
resnets = [HunyuanImageResnetBlock(in_channels=in_channels, out_channels=in_channels)]
|
||||
|
||||
attentions = []
|
||||
for _ in range(num_layers):
|
||||
attentions.append(HunyuanImageAttentionBlock(in_channels))
|
||||
resnets.append(HunyuanImageResnetBlock(in_channels=in_channels, out_channels=in_channels))
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.resnets[0](x)
|
||||
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
x = attn(x)
|
||||
x = resnet(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class HunyuanImageEncoder2D(nn.Module):
|
||||
r"""
|
||||
Encoder network that compresses input to latent representation.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
z_channels (int): Number of latent channels.
|
||||
block_out_channels (list of int): Output channels for each block.
|
||||
num_res_blocks (int): Number of residual blocks per block.
|
||||
spatial_compression_ratio (int): Spatial downsampling factor.
|
||||
non_linearity (str): Type of non-linearity to use. Default is "silu".
|
||||
downsample_match_channel (bool): Whether to match channels during downsampling.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
z_channels: int,
|
||||
block_out_channels: Tuple[int, ...],
|
||||
num_res_blocks: int,
|
||||
spatial_compression_ratio: int,
|
||||
non_linearity: str = "silu",
|
||||
downsample_match_channel: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
if block_out_channels[-1] % (2 * z_channels) != 0:
|
||||
raise ValueError(
|
||||
f"block_out_channels[-1 has to be divisible by 2 * out_channels, you have block_out_channels = {block_out_channels[-1]} and out_channels = {z_channels}"
|
||||
)
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.z_channels = z_channels
|
||||
self.block_out_channels = block_out_channels
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.spatial_compression_ratio = spatial_compression_ratio
|
||||
|
||||
self.group_size = block_out_channels[-1] // (2 * z_channels)
|
||||
self.nonlinearity = get_activation(non_linearity)
|
||||
|
||||
# init block
|
||||
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
|
||||
|
||||
# downsample blocks
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
|
||||
block_in_channel = block_out_channels[0]
|
||||
for i in range(len(block_out_channels)):
|
||||
block_out_channel = block_out_channels[i]
|
||||
# residual blocks
|
||||
for _ in range(num_res_blocks):
|
||||
self.down_blocks.append(
|
||||
HunyuanImageResnetBlock(in_channels=block_in_channel, out_channels=block_out_channel)
|
||||
)
|
||||
block_in_channel = block_out_channel
|
||||
|
||||
# downsample block
|
||||
if i < np.log2(spatial_compression_ratio) and i != len(block_out_channels) - 1:
|
||||
if downsample_match_channel:
|
||||
block_out_channel = block_out_channels[i + 1]
|
||||
self.down_blocks.append(
|
||||
HunyuanImageDownsample(in_channels=block_in_channel, out_channels=block_out_channel)
|
||||
)
|
||||
block_in_channel = block_out_channel
|
||||
|
||||
# middle blocks
|
||||
self.mid_block = HunyuanImageMidBlock(in_channels=block_out_channels[-1], num_layers=1)
|
||||
|
||||
# output layers
|
||||
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_out_channels[-1], eps=1e-6, affine=True)
|
||||
self.conv_out = nn.Conv2d(block_out_channels[-1], 2 * z_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.conv_in(x)
|
||||
|
||||
## downsamples
|
||||
for down_block in self.down_blocks:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
x = self._gradient_checkpointing_func(down_block, x)
|
||||
else:
|
||||
x = down_block(x)
|
||||
|
||||
## middle
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
x = self._gradient_checkpointing_func(self.mid_block, x)
|
||||
else:
|
||||
x = self.mid_block(x)
|
||||
|
||||
## head
|
||||
B, C, H, W = x.shape
|
||||
residual = x.view(B, C // self.group_size, self.group_size, H, W).mean(dim=2)
|
||||
|
||||
x = self.norm_out(x)
|
||||
x = self.nonlinearity(x)
|
||||
x = self.conv_out(x)
|
||||
return x + residual
|
||||
|
||||
|
||||
class HunyuanImageDecoder2D(nn.Module):
|
||||
r"""
|
||||
Decoder network that reconstructs output from latent representation.
|
||||
|
||||
Args:
|
||||
z_channels (int): Number of latent channels.
out_channels (int): Number of output channels.
block_out_channels (Tuple[int, ...]): Output channels for each block.
num_res_blocks (int): Number of residual blocks per block.
spatial_compression_ratio (int): Spatial upsampling factor.
upsample_match_channel (bool): Whether to match channels during upsampling.
non_linearity (str): Type of non-linearity to use. Default is "silu".
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
z_channels: int,
|
||||
out_channels: int,
|
||||
block_out_channels: Tuple[int, ...],
|
||||
num_res_blocks: int,
|
||||
spatial_compression_ratio: int,
|
||||
upsample_match_channel: bool = True,
|
||||
non_linearity: str = "silu",
|
||||
):
|
||||
super().__init__()
|
||||
if block_out_channels[0] % z_channels != 0:
|
||||
raise ValueError(
|
||||
f"block_out_channels[0] should be divisible by z_channels but has block_out_channels[0] = {block_out_channels[0]} and z_channels = {z_channels}"
|
||||
)
|
||||
|
||||
self.z_channels = z_channels
|
||||
self.block_out_channels = block_out_channels
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.repeat = block_out_channels[0] // z_channels
|
||||
self.spatial_compression_ratio = spatial_compression_ratio
|
||||
self.nonlinearity = get_activation(non_linearity)
|
||||
|
||||
self.conv_in = nn.Conv2d(z_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
|
||||
|
||||
# Middle blocks with attention
|
||||
self.mid_block = HunyuanImageMidBlock(in_channels=block_out_channels[0], num_layers=1)
|
||||
|
||||
# Upsampling blocks
|
||||
block_in_channel = block_out_channels[0]
|
||||
self.up_blocks = nn.ModuleList()
|
||||
for i in range(len(block_out_channels)):
|
||||
block_out_channel = block_out_channels[i]
|
||||
for _ in range(self.num_res_blocks + 1):
|
||||
self.up_blocks.append(
|
||||
HunyuanImageResnetBlock(in_channels=block_in_channel, out_channels=block_out_channel)
|
||||
)
|
||||
block_in_channel = block_out_channel
|
||||
|
||||
if i < np.log2(spatial_compression_ratio) and i != len(block_out_channels) - 1:
|
||||
if upsample_match_channel:
|
||||
block_out_channel = block_out_channels[i + 1]
|
||||
self.up_blocks.append(HunyuanImageUpsample(block_in_channel, block_out_channel))
|
||||
block_in_channel = block_out_channel
|
||||
|
||||
# Output layers
|
||||
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_out_channels[-1], eps=1e-6, affine=True)
|
||||
self.conv_out = nn.Conv2d(block_out_channels[-1], out_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
h = self.conv_in(x) + x.repeat_interleave(repeats=self.repeat, dim=1)
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
h = self._gradient_checkpointing_func(self.mid_block, h)
|
||||
else:
|
||||
h = self.mid_block(h)
|
||||
|
||||
for up_block in self.up_blocks:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
h = self._gradient_checkpointing_func(up_block, h)
|
||||
else:
|
||||
h = up_block(h)
|
||||
h = self.norm_out(h)
|
||||
h = self.nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
|
||||
class AutoencoderKLHunyuanImage(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
r"""
|
||||
A VAE model for 2D images with spatial tiling support.
|
||||
|
||||
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
|
||||
for all models (such as downloading or saving).
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = False
|
||||
|
||||
# fmt: off
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
latent_channels: int,
|
||||
block_out_channels: Tuple[int, ...],
|
||||
layers_per_block: int,
|
||||
spatial_compression_ratio: int,
|
||||
sample_size: int,
|
||||
scaling_factor: Optional[float] = None,
|
||||
downsample_match_channel: bool = True,
|
||||
upsample_match_channel: bool = True,
|
||||
) -> None:
|
||||
# fmt: on
|
||||
super().__init__()
|
||||
|
||||
self.encoder = HunyuanImageEncoder2D(
|
||||
in_channels=in_channels,
|
||||
z_channels=latent_channels,
|
||||
block_out_channels=block_out_channels,
|
||||
num_res_blocks=layers_per_block,
|
||||
spatial_compression_ratio=spatial_compression_ratio,
|
||||
downsample_match_channel=downsample_match_channel,
|
||||
)
|
||||
|
||||
self.decoder = HunyuanImageDecoder2D(
|
||||
z_channels=latent_channels,
|
||||
out_channels=out_channels,
|
||||
block_out_channels=list(reversed(block_out_channels)),
|
||||
num_res_blocks=layers_per_block,
|
||||
spatial_compression_ratio=spatial_compression_ratio,
|
||||
upsample_match_channel=upsample_match_channel,
|
||||
)
|
||||
|
||||
# Tiling and slicing configuration
|
||||
self.use_slicing = False
|
||||
self.use_tiling = False
|
||||
|
||||
# Tiling parameters
|
||||
self.tile_sample_min_size = sample_size
|
||||
self.tile_latent_min_size = sample_size // spatial_compression_ratio
|
||||
self.tile_overlap_factor = 0.25
|
||||
|
||||
def enable_tiling(
|
||||
self,
|
||||
tile_sample_min_size: Optional[int] = None,
|
||||
tile_overlap_factor: Optional[float] = None,
|
||||
) -> None:
|
||||
r"""
|
||||
Enable spatial tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles
|
||||
to compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to
|
||||
allow processing larger images.
|
||||
|
||||
Args:
|
||||
tile_sample_min_size (`int`, *optional*):
|
||||
The minimum size required for a sample to be separated into tiles across the spatial dimension.
|
||||
tile_overlap_factor (`float`, *optional*):
|
||||
The overlap factor required for a latent to be separated into tiles across the spatial dimension.
|
||||
"""
|
||||
self.use_tiling = True
|
||||
self.tile_sample_min_size = tile_sample_min_size or self.tile_sample_min_size
|
||||
self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor
|
||||
self.tile_latent_min_size = self.tile_sample_min_size // self.config.spatial_compression_ratio
|
||||
|
||||
def disable_tiling(self) -> None:
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_tiling = False
|
||||
|
||||
def enable_slicing(self) -> None:
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.use_slicing = True
|
||||
|
||||
def disable_slicing(self) -> None:
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_slicing = False
|
||||
|
||||
def _encode(self, x: torch.Tensor):
|
||||
|
||||
batch_size, num_channels, height, width = x.shape
|
||||
|
||||
if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size):
|
||||
return self.tiled_encode(x)
|
||||
|
||||
enc = self.encoder(x)
|
||||
|
||||
return enc
|
||||
|
||||
@apply_forward_hook
|
||||
def encode(
|
||||
self, x: torch.Tensor, return_dict: bool = True
|
||||
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
||||
r"""
|
||||
Encode a batch of images into latents.
|
||||
|
||||
Args:
|
||||
x (`torch.Tensor`): Input batch of images.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
The latent representations of the encoded images. If `return_dict` is True, a
|
||||
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
|
||||
"""
|
||||
if self.use_slicing and x.shape[0] > 1:
|
||||
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
|
||||
h = torch.cat(encoded_slices)
|
||||
else:
|
||||
h = self._encode(x)
|
||||
posterior = DiagonalGaussianDistribution(h)
|
||||
|
||||
if not return_dict:
|
||||
return (posterior,)
|
||||
return AutoencoderKLOutput(latent_dist=posterior)
|
||||
|
||||
def _decode(self, z: torch.Tensor, return_dict: bool = True):
|
||||
|
||||
batch_size, num_channels, height, width = z.shape
|
||||
|
||||
if self.use_tiling and (width > self.tile_latent_min_size or height > self.tile_latent_min_size):
|
||||
return self.tiled_decode(z, return_dict=return_dict)
|
||||
|
||||
dec = self.decoder(z)
|
||||
|
||||
if not return_dict:
|
||||
return (dec,)
|
||||
|
||||
return DecoderOutput(sample=dec)
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
r"""
|
||||
Decode a batch of images.
|
||||
|
||||
Args:
|
||||
z (`torch.Tensor`): Input batch of latent vectors.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.vae.DecoderOutput`] or `tuple`:
|
||||
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
||||
returned.
|
||||
"""
|
||||
if self.use_slicing and z.shape[0] > 1:
|
||||
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
|
||||
decoded = torch.cat(decoded_slices)
|
||||
else:
|
||||
decoded = self._decode(z).sample
|
||||
|
||||
if not return_dict:
|
||||
return (decoded,)
|
||||
return DecoderOutput(sample=decoded)
|
||||
|
||||
|
||||
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
||||
blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
|
||||
for y in range(blend_extent):
|
||||
b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (
|
||||
y / blend_extent
|
||||
)
|
||||
return b
|
||||
|
||||
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
||||
blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
|
||||
for x in range(blend_extent):
|
||||
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (
|
||||
x / blend_extent
|
||||
)
|
||||
return b
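`blend_h` (and `blend_v` along the height axis) cross-fades the overlapping columns of two neighbouring tiles with linear weights. A tiny standalone demo of the ramp:

```python
import torch

a = torch.zeros(1, 1, 2, 4)   # right edge of the previous tile
b = torch.ones(1, 1, 2, 4)    # left edge of the current tile
blend_extent = 4

for x in range(blend_extent):
    b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)

print(b[0, 0, 0])  # tensor([0.0000, 0.2500, 0.5000, 0.7500]) -- linear ramp from a into b
```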
|
||||
|
||||
def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Encode input using spatial tiling strategy.
|
||||
|
||||
Args:
|
||||
x (`torch.Tensor`): Input tensor of shape (B, C, H, W).
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The latent representation of the encoded images.
|
||||
"""
|
||||
_, _, height, width = x.shape
|
||||
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
|
||||
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
|
||||
row_limit = self.tile_latent_min_size - blend_extent
|
||||
|
||||
rows = []
|
||||
for i in range(0, height, overlap_size):
|
||||
row = []
|
||||
for j in range(0, width, overlap_size):
|
||||
tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
|
||||
tile = self.encoder(tile)
|
||||
row.append(tile)
|
||||
rows.append(row)
|
||||
|
||||
result_rows = []
|
||||
for i, row in enumerate(rows):
|
||||
result_row = []
|
||||
for j, tile in enumerate(row):
|
||||
if i > 0:
|
||||
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
||||
if j > 0:
|
||||
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
||||
result_row.append(tile[:, :, :row_limit, :row_limit])
|
||||
result_rows.append(torch.cat(result_row, dim=-1))
|
||||
|
||||
moments = torch.cat(result_rows, dim=-2)
|
||||
|
||||
return moments
|
||||
|
||||
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
"""
|
||||
Decode latent using spatial tiling strategy.
|
||||
|
||||
Args:
|
||||
z (`torch.Tensor`): Latent tensor of shape (B, C, H, W).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.vae.DecoderOutput`] or `tuple`:
|
||||
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
||||
returned.
|
||||
"""
|
||||
_, _, height, width = z.shape
|
||||
overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
|
||||
blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
|
||||
row_limit = self.tile_sample_min_size - blend_extent
|
||||
|
||||
rows = []
|
||||
for i in range(0, height, overlap_size):
|
||||
row = []
|
||||
for j in range(0, width, overlap_size):
|
||||
tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
|
||||
decoded = self.decoder(tile)
|
||||
row.append(decoded)
|
||||
rows.append(row)
|
||||
|
||||
result_rows = []
|
||||
for i, row in enumerate(rows):
|
||||
result_row = []
|
||||
for j, tile in enumerate(row):
|
||||
if i > 0:
|
||||
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
||||
if j > 0:
|
||||
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
||||
result_row.append(tile[:, :, :row_limit, :row_limit])
|
||||
result_rows.append(torch.cat(result_row, dim=-1))
|
||||
|
||||
dec = torch.cat(result_rows, dim=-2)
|
||||
if not return_dict:
|
||||
return (dec,)
|
||||
return DecoderOutput(sample=dec)
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
sample (`torch.Tensor`): Input sample.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
|
||||
"""
|
||||
posterior = self.encode(sample).latent_dist
|
||||
if sample_posterior:
|
||||
z = posterior.sample(generator=generator)
|
||||
else:
|
||||
z = posterior.mode()
|
||||
dec = self.decode(z, return_dict=return_dict)
|
||||
|
||||
return dec
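To see the class end to end, the sketch below builds a small model from scratch with illustrative configuration values (much smaller than any released checkpoint, random weights) and runs a tiled encode/decode round trip. It assumes a diffusers build that already exposes `AutoencoderKLHunyuanImage`.

```python
import torch
from diffusers import AutoencoderKLHunyuanImage  # requires a build that includes this model

# Illustrative config only -- not the shipped checkpoint configuration.
vae = AutoencoderKLHunyuanImage(
    in_channels=3,
    out_channels=3,
    latent_channels=16,
    block_out_channels=(64, 128, 256, 256),
    layers_per_block=1,
    spatial_compression_ratio=8,
    sample_size=128,
)
vae.eval()
vae.enable_tiling(tile_sample_min_size=64)  # force the tiled code path for this small input

image = torch.randn(1, 3, 128, 128)
with torch.no_grad():
    latents = vae.encode(image).latent_dist.mode()  # (1, 16, 16, 16): 8x spatial compression
    recon = vae.decode(latents).sample              # (1, 3, 128, 128)

print(latents.shape, recon.shape)
```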
934
src/diffusers/models/autoencoders/autoencoder_kl_hunyuanimage_refiner.py
Normal file
@@ -0,0 +1,934 @@
|
||||
# Copyright 2025 The Hunyuan Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import logging
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..activations import get_activation
|
||||
from ..modeling_outputs import AutoencoderKLOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .vae import DecoderOutput, DiagonalGaussianDistribution
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class HunyuanImageRefinerCausalConv3d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, int, int]] = 3,
|
||||
stride: Union[int, Tuple[int, int, int]] = 1,
|
||||
padding: Union[int, Tuple[int, int, int]] = 0,
|
||||
dilation: Union[int, Tuple[int, int, int]] = 1,
|
||||
bias: bool = True,
|
||||
pad_mode: str = "replicate",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
kernel_size = (kernel_size, kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
|
||||
|
||||
self.pad_mode = pad_mode
|
||||
self.time_causal_padding = (
|
||||
kernel_size[0] // 2,
|
||||
kernel_size[0] // 2,
|
||||
kernel_size[1] // 2,
|
||||
kernel_size[1] // 2,
|
||||
kernel_size[2] - 1,
|
||||
0,
|
||||
)
|
||||
|
||||
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = F.pad(hidden_states, self.time_causal_padding, mode=self.pad_mode)
|
||||
return self.conv(hidden_states)
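A standalone check of the causal padding scheme above (kernel 3, so two frames of front padding and none at the back): the output keeps the same number of frames, and earlier output frames never depend on later input frames. This is a sketch with plain torch ops, not the module itself.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
conv = nn.Conv3d(1, 1, kernel_size=3)
pad = (1, 1, 1, 1, 2, 0)                       # (W, W, H, H, T-front, T-back)

x = torch.randn(1, 1, 4, 8, 8)
y = conv(F.pad(x, pad, mode="replicate"))

x2 = x.clone()
x2[:, :, -1] += 100.0                          # perturb only the last frame
y2 = conv(F.pad(x2, pad, mode="replicate"))

print(y.shape)                                      # torch.Size([1, 1, 4, 8, 8])
print(torch.allclose(y[:, :, :-1], y2[:, :, :-1]))  # True: earlier frames are unchanged
```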
|
||||
|
||||
|
||||
class HunyuanImageRefinerRMS_norm(nn.Module):
|
||||
r"""
|
||||
A custom RMS normalization layer.
|
||||
|
||||
Args:
|
||||
dim (int): The number of dimensions to normalize over.
|
||||
channel_first (bool, optional): Whether the input tensor has channels as the first dimension.
|
||||
Default is True.
|
||||
images (bool, optional): Whether the input represents image data. Default is True.
|
||||
bias (bool, optional): Whether to include a learnable bias term. Default is False.
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None:
|
||||
super().__init__()
|
||||
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
|
||||
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
|
||||
|
||||
self.channel_first = channel_first
|
||||
self.scale = dim**0.5
|
||||
self.gamma = nn.Parameter(torch.ones(shape))
|
||||
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
|
||||
|
||||
def forward(self, x):
|
||||
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
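The `F.normalize`-based implementation above is an RMS norm over the channel axis: dividing by the L2 norm and multiplying by `sqrt(dim)` is the same as dividing by the root-mean-square. A quick numerical check:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 8, 3, 4, 4)                  # (B, C, T, H, W), channel_first

rms = x.pow(2).mean(dim=1, keepdim=True).sqrt()
manual = x / rms
via_normalize = F.normalize(x, dim=1) * (x.shape[1] ** 0.5)

print(torch.allclose(manual, via_normalize, atol=1e-6))  # True
```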
|
||||
|
||||
|
||||
class HunyuanImageRefinerAttnBlock(nn.Module):
|
||||
def __init__(self, in_channels: int):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = HunyuanImageRefinerRMS_norm(in_channels, images=False)
|
||||
|
||||
self.to_q = nn.Conv3d(in_channels, in_channels, kernel_size=1)
|
||||
self.to_k = nn.Conv3d(in_channels, in_channels, kernel_size=1)
|
||||
self.to_v = nn.Conv3d(in_channels, in_channels, kernel_size=1)
|
||||
self.proj_out = nn.Conv3d(in_channels, in_channels, kernel_size=1)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
identity = x
|
||||
|
||||
x = self.norm(x)
|
||||
|
||||
query = self.to_q(x)
|
||||
key = self.to_k(x)
|
||||
value = self.to_v(x)
|
||||
|
||||
batch_size, channels, frames, height, width = query.shape
|
||||
|
||||
query = query.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous()
|
||||
key = key.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous()
|
||||
value = value.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous()
|
||||
|
||||
x = nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=None)
|
||||
|
||||
# batch_size, 1, frames * height * width, channels
|
||||
|
||||
x = x.squeeze(1).reshape(batch_size, frames, height, width, channels).permute(0, 4, 1, 2, 3)
|
||||
x = self.proj_out(x)
|
||||
|
||||
return x + identity
|
||||
|
||||
|
||||
class HunyuanImageRefinerUpsampleDCAE(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int, add_temporal_upsample: bool = True):
|
||||
super().__init__()
|
||||
factor = 2 * 2 * 2 if add_temporal_upsample else 1 * 2 * 2
|
||||
self.conv = HunyuanImageRefinerCausalConv3d(in_channels, out_channels * factor, kernel_size=3)
|
||||
|
||||
self.add_temporal_upsample = add_temporal_upsample
|
||||
self.repeats = factor * out_channels // in_channels
|
||||
|
||||
@staticmethod
|
||||
def _dcae_upsample_rearrange(tensor, r1=1, r2=2, r3=2):
|
||||
"""
|
||||
Convert (b, r1*r2*r3*c, f, h, w) -> (b, c, r1*f, r2*h, r3*w)
|
||||
|
||||
Args:
|
||||
tensor: Input tensor of shape (b, r1*r2*r3*c, f, h, w)
|
||||
r1: temporal upsampling factor
|
||||
r2: height upsampling factor
|
||||
r3: width upsampling factor
|
||||
"""
|
||||
b, packed_c, f, h, w = tensor.shape
|
||||
factor = r1 * r2 * r3
|
||||
c = packed_c // factor
|
||||
|
||||
tensor = tensor.view(b, r1, r2, r3, c, f, h, w)
|
||||
tensor = tensor.permute(0, 4, 5, 1, 6, 2, 7, 3)
|
||||
return tensor.reshape(b, c, f * r1, h * r2, w * r3)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
r1 = 2 if self.add_temporal_upsample else 1
|
||||
h = self.conv(x)
|
||||
if self.add_temporal_upsample:
|
||||
h = self._dcae_upsample_rearrange(h, r1=1, r2=2, r3=2)
|
||||
h = h[:, : h.shape[1] // 2]
|
||||
|
||||
# shortcut computation
|
||||
shortcut = self._dcae_upsample_rearrange(x, r1=1, r2=2, r3=2)
|
||||
shortcut = shortcut.repeat_interleave(repeats=self.repeats // 2, dim=1)
|
||||
|
||||
else:
|
||||
h = self._dcae_upsample_rearrange(h, r1=r1, r2=2, r3=2)
|
||||
shortcut = x.repeat_interleave(repeats=self.repeats, dim=1)
|
||||
shortcut = self._dcae_upsample_rearrange(shortcut, r1=r1, r2=2, r3=2)
|
||||
return h + shortcut
|
||||
|
||||
|
||||
class HunyuanImageRefinerDownsampleDCAE(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int, add_temporal_downsample: bool = True):
|
||||
super().__init__()
|
||||
factor = 2 * 2 * 2 if add_temporal_downsample else 1 * 2 * 2
|
||||
assert out_channels % factor == 0
|
||||
# self.conv = Conv3d(in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1)
|
||||
self.conv = HunyuanImageRefinerCausalConv3d(in_channels, out_channels // factor, kernel_size=3)
|
||||
|
||||
self.add_temporal_downsample = add_temporal_downsample
|
||||
self.group_size = factor * in_channels // out_channels
|
||||
|
||||
@staticmethod
|
||||
def _dcae_downsample_rearrange(tensor, r1=1, r2=2, r3=2):
|
||||
"""
|
||||
Convert (b, c, r1*f, r2*h, r3*w) -> (b, r1*r2*r3*c, f, h, w)
|
||||
|
||||
This packs spatial/temporal dimensions into channels (opposite of upsample)
|
||||
"""
|
||||
b, c, packed_f, packed_h, packed_w = tensor.shape
|
||||
f, h, w = packed_f // r1, packed_h // r2, packed_w // r3
|
||||
|
||||
tensor = tensor.view(b, c, f, r1, h, r2, w, r3)
|
||||
tensor = tensor.permute(0, 3, 5, 7, 1, 2, 4, 6)
|
||||
return tensor.reshape(b, r1 * r2 * r3 * c, f, h, w)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
r1 = 2 if self.add_temporal_downsample else 1
|
||||
h = self.conv(x)
|
||||
if self.add_temporal_downsample:
|
||||
# h = rearrange(h, "b c f (h r2) (w r3) -> b (r2 r3 c) f h w", r2=2, r3=2)
|
||||
h = self._dcae_downsample_rearrange(h, r1=1, r2=2, r3=2)
|
||||
h = torch.cat([h, h], dim=1)
|
||||
# shortcut computation
|
||||
# shortcut = rearrange(x, "b c f (h r2) (w r3) -> b (r2 r3 c) f h w", r2=2, r3=2)
|
||||
shortcut = self._dcae_downsample_rearrange(x, r1=1, r2=2, r3=2)
|
||||
B, C, T, H, W = shortcut.shape
|
||||
shortcut = shortcut.view(B, h.shape[1], self.group_size // 2, T, H, W).mean(dim=2)
|
||||
else:
|
||||
# h = rearrange(h, "b c (f r1) (h r2) (w r3) -> b (r1 r2 r3 c) f h w", r1=r1, r2=2, r3=2)
|
||||
h = self._dcae_downsample_rearrange(h, r1=r1, r2=2, r3=2)
|
||||
# shortcut = rearrange(x, "b c (f r1) (h r2) (w r3) -> b (r1 r2 r3 c) f h w", r1=r1, r2=2, r3=2)
|
||||
shortcut = self._dcae_downsample_rearrange(x, r1=r1, r2=2, r3=2)
|
||||
B, C, T, H, W = shortcut.shape
|
||||
shortcut = shortcut.view(B, h.shape[1], self.group_size, T, H, W).mean(dim=2)
|
||||
|
||||
return h + shortcut
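The two `_dcae_*_rearrange` helpers are inverses of each other: packing an (r1, r2, r3) neighbourhood into channels and unpacking it again recovers the input exactly. A self-contained round-trip check (the helpers are re-implemented inline so the snippet does not depend on this module):

```python
import torch

def pack(t, r1=1, r2=2, r3=2):    # mirrors _dcae_downsample_rearrange
    b, c, F_, H, W = t.shape
    f, h, w = F_ // r1, H // r2, W // r3
    t = t.view(b, c, f, r1, h, r2, w, r3).permute(0, 3, 5, 7, 1, 2, 4, 6)
    return t.reshape(b, r1 * r2 * r3 * c, f, h, w)

def unpack(t, r1=1, r2=2, r3=2):  # mirrors _dcae_upsample_rearrange
    b, packed_c, f, h, w = t.shape
    c = packed_c // (r1 * r2 * r3)
    t = t.view(b, r1, r2, r3, c, f, h, w).permute(0, 4, 5, 1, 6, 2, 7, 3)
    return t.reshape(b, c, f * r1, h * r2, w * r3)

x = torch.randn(1, 8, 4, 16, 16)          # (B, C, T, H, W)
packed = pack(x, r1=2, r2=2, r3=2)        # (1, 64, 2, 8, 8)
print(packed.shape, torch.equal(unpack(packed, r1=2, r2=2, r3=2), x))  # ... True
```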
|
||||
|
||||
|
||||
class HunyuanImageRefinerResnetBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: Optional[int] = None,
|
||||
non_linearity: str = "swish",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
out_channels = out_channels or in_channels
|
||||
|
||||
self.nonlinearity = get_activation(non_linearity)
|
||||
|
||||
self.norm1 = HunyuanImageRefinerRMS_norm(in_channels, images=False)
|
||||
self.conv1 = HunyuanImageRefinerCausalConv3d(in_channels, out_channels, kernel_size=3)
|
||||
|
||||
self.norm2 = HunyuanImageRefinerRMS_norm(out_channels, images=False)
|
||||
self.conv2 = HunyuanImageRefinerCausalConv3d(out_channels, out_channels, kernel_size=3)
|
||||
|
||||
self.conv_shortcut = None
|
||||
if in_channels != out_channels:
|
||||
self.conv_shortcut = nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
hidden_states = self.conv1(hidden_states)
|
||||
|
||||
hidden_states = self.norm2(hidden_states)
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
|
||||
if self.conv_shortcut is not None:
|
||||
residual = self.conv_shortcut(residual)
|
||||
|
||||
return hidden_states + residual
|
||||
|
||||
|
||||
class HunyuanImageRefinerMidBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
num_layers: int = 1,
|
||||
add_attention: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.add_attention = add_attention
|
||||
|
||||
# There is always at least one resnet
|
||||
resnets = [
|
||||
HunyuanImageRefinerResnetBlock(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
)
|
||||
]
|
||||
attentions = []
|
||||
|
||||
for _ in range(num_layers):
|
||||
if self.add_attention:
|
||||
attentions.append(HunyuanImageRefinerAttnBlock(in_channels))
|
||||
else:
|
||||
attentions.append(None)
|
||||
|
||||
resnets.append(
|
||||
HunyuanImageRefinerResnetBlock(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
)
|
||||
)
|
||||
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.resnets[0](hidden_states)
|
||||
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
if attn is not None:
|
||||
hidden_states = attn(hidden_states)
|
||||
hidden_states = resnet(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanImageRefinerDownBlock3D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
num_layers: int = 1,
|
||||
downsample_out_channels: Optional[int] = None,
|
||||
add_temporal_downsample: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
resnets = []
|
||||
|
||||
for i in range(num_layers):
|
||||
in_channels = in_channels if i == 0 else out_channels
|
||||
resnets.append(
|
||||
HunyuanImageRefinerResnetBlock(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
)
|
||||
)
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if downsample_out_channels is not None:
|
||||
self.downsamplers = nn.ModuleList(
|
||||
[
|
||||
HunyuanImageRefinerDownsampleDCAE(
|
||||
out_channels,
|
||||
out_channels=downsample_out_channels,
|
||||
add_temporal_downsample=add_temporal_downsample,
|
||||
)
|
||||
]
|
||||
)
|
||||
else:
|
||||
self.downsamplers = None
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanImageRefinerUpBlock3D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
num_layers: int = 1,
|
||||
upsample_out_channels: Optional[int] = None,
|
||||
add_temporal_upsample: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
resnets = []
|
||||
|
||||
for i in range(num_layers):
|
||||
input_channels = in_channels if i == 0 else out_channels
|
||||
|
||||
resnets.append(
|
||||
HunyuanImageRefinerResnetBlock(
|
||||
in_channels=input_channels,
|
||||
out_channels=out_channels,
|
||||
)
|
||||
)
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if upsample_out_channels is not None:
|
||||
self.upsamplers = nn.ModuleList(
|
||||
[
|
||||
HunyuanImageRefinerUpsampleDCAE(
|
||||
out_channels,
|
||||
out_channels=upsample_out_channels,
|
||||
add_temporal_upsample=add_temporal_upsample,
|
||||
)
|
||||
]
|
||||
)
|
||||
else:
|
||||
self.upsamplers = None
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
for resnet in self.resnets:
|
||||
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)
|
||||
|
||||
else:
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanImageRefinerEncoder3D(nn.Module):
|
||||
r"""
|
||||
3D vae encoder for HunyuanImageRefiner.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 64,
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 512, 1024, 1024),
|
||||
layers_per_block: int = 2,
|
||||
temporal_compression_ratio: int = 4,
|
||||
spatial_compression_ratio: int = 16,
|
||||
downsample_match_channel: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.group_size = block_out_channels[-1] // self.out_channels
|
||||
|
||||
self.conv_in = HunyuanImageRefinerCausalConv3d(in_channels, block_out_channels[0], kernel_size=3)
|
||||
self.mid_block = None
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
|
||||
input_channel = block_out_channels[0]
|
||||
for i in range(len(block_out_channels)):
|
||||
add_spatial_downsample = i < np.log2(spatial_compression_ratio)
|
||||
output_channel = block_out_channels[i]
|
||||
if not add_spatial_downsample:
|
||||
down_block = HunyuanImageRefinerDownBlock3D(
|
||||
num_layers=layers_per_block,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
downsample_out_channels=None,
|
||||
add_temporal_downsample=False,
|
||||
)
|
||||
input_channel = output_channel
|
||||
else:
|
||||
add_temporal_downsample = i >= np.log2(spatial_compression_ratio // temporal_compression_ratio)
|
||||
downsample_out_channels = block_out_channels[i + 1] if downsample_match_channel else output_channel
|
||||
down_block = HunyuanImageRefinerDownBlock3D(
|
||||
num_layers=layers_per_block,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
downsample_out_channels=downsample_out_channels,
|
||||
add_temporal_downsample=add_temporal_downsample,
|
||||
)
|
||||
input_channel = downsample_out_channels
|
||||
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
self.mid_block = HunyuanImageRefinerMidBlock(in_channels=block_out_channels[-1])
|
||||
|
||||
self.norm_out = HunyuanImageRefinerRMS_norm(block_out_channels[-1], images=False)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = HunyuanImageRefinerCausalConv3d(block_out_channels[-1], out_channels, kernel_size=3)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
for down_block in self.down_blocks:
|
||||
hidden_states = self._gradient_checkpointing_func(down_block, hidden_states)
|
||||
|
||||
hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
|
||||
else:
|
||||
for down_block in self.down_blocks:
|
||||
hidden_states = down_block(hidden_states)
|
||||
|
||||
hidden_states = self.mid_block(hidden_states)
|
||||
|
||||
# short_cut = rearrange(hidden_states, "b (c r) f h w -> b c r f h w", r=self.group_size).mean(dim=2)
|
||||
batch_size, _, frame, height, width = hidden_states.shape
|
||||
short_cut = hidden_states.view(batch_size, -1, self.group_size, frame, height, width).mean(dim=2)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
|
||||
hidden_states += short_cut
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanImageRefinerDecoder3D(nn.Module):
|
||||
r"""
|
||||
Causal decoder for 3D video-like data used for HunyuanImage-2.1 Refiner.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 32,
|
||||
out_channels: int = 3,
|
||||
block_out_channels: Tuple[int, ...] = (1024, 1024, 512, 256, 128),
|
||||
layers_per_block: int = 2,
|
||||
spatial_compression_ratio: int = 16,
|
||||
temporal_compression_ratio: int = 4,
|
||||
upsample_match_channel: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.layers_per_block = layers_per_block
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.repeat = block_out_channels[0] // self.in_channels
|
||||
|
||||
self.conv_in = HunyuanImageRefinerCausalConv3d(self.in_channels, block_out_channels[0], kernel_size=3)
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
|
||||
# mid
|
||||
self.mid_block = HunyuanImageRefinerMidBlock(in_channels=block_out_channels[0])
|
||||
|
||||
# up
|
||||
input_channel = block_out_channels[0]
|
||||
for i in range(len(block_out_channels)):
|
||||
output_channel = block_out_channels[i]
|
||||
|
||||
add_spatial_upsample = i < np.log2(spatial_compression_ratio)
|
||||
add_temporal_upsample = i < np.log2(temporal_compression_ratio)
|
||||
if add_spatial_upsample or add_temporal_upsample:
|
||||
upsample_out_channels = block_out_channels[i + 1] if upsample_match_channel else output_channel
|
||||
up_block = HunyuanImageRefinerUpBlock3D(
|
||||
num_layers=self.layers_per_block + 1,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
upsample_out_channels=upsample_out_channels,
|
||||
add_temporal_upsample=add_temporal_upsample,
|
||||
)
|
||||
input_channel = upsample_out_channels
|
||||
else:
|
||||
up_block = HunyuanImageRefinerUpBlock3D(
|
||||
num_layers=self.layers_per_block + 1,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
upsample_out_channels=None,
|
||||
add_temporal_upsample=False,
|
||||
)
|
||||
input_channel = output_channel
|
||||
|
||||
self.up_blocks.append(up_block)
|
||||
|
||||
# out
|
||||
self.norm_out = HunyuanImageRefinerRMS_norm(block_out_channels[-1], images=False)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = HunyuanImageRefinerCausalConv3d(block_out_channels[-1], out_channels, kernel_size=3)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.conv_in(hidden_states) + hidden_states.repeat_interleave(repeats=self.repeat, dim=1)
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
|
||||
|
||||
for up_block in self.up_blocks:
|
||||
hidden_states = self._gradient_checkpointing_func(up_block, hidden_states)
|
||||
else:
|
||||
hidden_states = self.mid_block(hidden_states)
|
||||
|
||||
for up_block in self.up_blocks:
|
||||
hidden_states = up_block(hidden_states)
|
||||
|
||||
# post-process
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AutoencoderKLHunyuanImageRefiner(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used for
|
||||
HunyuanImage-2.1 Refiner.
|
||||
|
||||
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
|
||||
for all models (such as downloading or saving).
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
latent_channels: int = 32,
|
||||
block_out_channels: Tuple[int] = (128, 256, 512, 1024, 1024),
|
||||
layers_per_block: int = 2,
|
||||
spatial_compression_ratio: int = 16,
|
||||
temporal_compression_ratio: int = 4,
|
||||
downsample_match_channel: bool = True,
|
||||
upsample_match_channel: bool = True,
|
||||
scaling_factor: float = 1.03682,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.encoder = HunyuanImageRefinerEncoder3D(
|
||||
in_channels=in_channels,
|
||||
out_channels=latent_channels * 2,
|
||||
block_out_channels=block_out_channels,
|
||||
layers_per_block=layers_per_block,
|
||||
temporal_compression_ratio=temporal_compression_ratio,
|
||||
spatial_compression_ratio=spatial_compression_ratio,
|
||||
downsample_match_channel=downsample_match_channel,
|
||||
)
|
||||
|
||||
self.decoder = HunyuanImageRefinerDecoder3D(
|
||||
in_channels=latent_channels,
|
||||
out_channels=out_channels,
|
||||
block_out_channels=list(reversed(block_out_channels)),
|
||||
layers_per_block=layers_per_block,
|
||||
temporal_compression_ratio=temporal_compression_ratio,
|
||||
spatial_compression_ratio=spatial_compression_ratio,
|
||||
upsample_match_channel=upsample_match_channel,
|
||||
)
|
||||
|
||||
self.spatial_compression_ratio = spatial_compression_ratio
|
||||
self.temporal_compression_ratio = temporal_compression_ratio
|
||||
|
||||
# When decoding a batch of latents, one can save memory by slicing across the batch dimension to decode a
# single latent at a time.
|
||||
self.use_slicing = False
|
||||
|
||||
# When decoding spatially large latents, the memory requirement is very high. By breaking the latent into
# smaller tiles, performing multiple forward passes for decoding, and then blending the intermediate tiles
# together, the memory requirement can be lowered.
|
||||
self.use_tiling = False
|
||||
|
||||
# The minimal tile height and width for spatial tiling to be used
|
||||
self.tile_sample_min_height = 256
|
||||
self.tile_sample_min_width = 256
|
||||
|
||||
# The minimal distance between two spatial tiles
|
||||
self.tile_sample_stride_height = 192
|
||||
self.tile_sample_stride_width = 192
|
||||
|
||||
self.tile_overlap_factor = 0.25
|
||||
|
||||
def enable_tiling(
|
||||
self,
|
||||
tile_sample_min_height: Optional[int] = None,
|
||||
tile_sample_min_width: Optional[int] = None,
|
||||
tile_sample_stride_height: Optional[float] = None,
|
||||
tile_sample_stride_width: Optional[float] = None,
|
||||
tile_overlap_factor: Optional[float] = None,
|
||||
) -> None:
|
||||
r"""
|
||||
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
||||
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||
processing larger images.
|
||||
|
||||
Args:
|
||||
tile_sample_min_height (`int`, *optional*):
|
||||
The minimum height required for a sample to be separated into tiles across the height dimension.
|
||||
tile_sample_min_width (`int`, *optional*):
|
||||
The minimum width required for a sample to be separated into tiles across the width dimension.
|
||||
tile_sample_stride_height (`int`, *optional*):
|
||||
The stride between two consecutive vertical tiles. This is to ensure that there are no tiling
artifacts produced across the height dimension.
|
||||
tile_sample_stride_width (`int`, *optional*):
|
||||
The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
|
||||
artifacts produced across the width dimension.
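tile_overlap_factor (`float`, *optional*):
    The fraction of each tile that overlaps with the neighbouring tile and is blended with it to hide seams.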
|
||||
"""
|
||||
self.use_tiling = True
|
||||
self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
|
||||
self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
|
||||
self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
|
||||
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
|
||||
self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor
|
||||
|
||||
def disable_tiling(self) -> None:
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_tiling = False
|
||||
|
||||
def enable_slicing(self) -> None:
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.use_slicing = True
|
||||
|
||||
def disable_slicing(self) -> None:
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_slicing = False
|
||||
|
||||
def _encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
_, _, _, height, width = x.shape
|
||||
|
||||
if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
|
||||
return self.tiled_encode(x)
|
||||
|
||||
x = self.encoder(x)
|
||||
return x
|
||||
|
||||
@apply_forward_hook
|
||||
def encode(
|
||||
self, x: torch.Tensor, return_dict: bool = True
|
||||
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
||||
r"""
|
||||
Encode a batch of images into latents.
|
||||
|
||||
Args:
|
||||
x (`torch.Tensor`): Input batch of images.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
The latent representations of the encoded images. If `return_dict` is True, a
|
||||
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
|
||||
"""
|
||||
if self.use_slicing and x.shape[0] > 1:
|
||||
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
|
||||
h = torch.cat(encoded_slices)
|
||||
else:
|
||||
h = self._encode(x)
|
||||
|
||||
posterior = DiagonalGaussianDistribution(h)
|
||||
|
||||
if not return_dict:
|
||||
return (posterior,)
|
||||
return AutoencoderKLOutput(latent_dist=posterior)
|
||||
|
||||
def _decode(self, z: torch.Tensor) -> torch.Tensor:
|
||||
_, _, _, height, width = z.shape
|
||||
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
|
||||
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
|
||||
|
||||
if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
|
||||
return self.tiled_decode(z)
|
||||
|
||||
dec = self.decoder(z)
|
||||
|
||||
return dec
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
r"""
|
||||
Decode a batch of images.
|
||||
|
||||
Args:
|
||||
z (`torch.Tensor`): Input batch of latent vectors.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.vae.DecoderOutput`] or `tuple`:
|
||||
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
||||
returned.
|
||||
"""
|
||||
if self.use_slicing and z.shape[0] > 1:
|
||||
decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]
|
||||
decoded = torch.cat(decoded_slices)
|
||||
else:
|
||||
decoded = self._decode(z)
|
||||
|
||||
if not return_dict:
|
||||
return (decoded,)
|
||||
|
||||
return DecoderOutput(sample=decoded)
|
||||
|
||||
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
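# Linearly cross-fade the last `blend_extent` rows of tile `a` into the first rows of tile `b` along the height axis.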
|
||||
blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
|
||||
for y in range(blend_extent):
|
||||
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
|
||||
y / blend_extent
|
||||
)
|
||||
return b
|
||||
|
||||
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
||||
blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
|
||||
for x in range(blend_extent):
|
||||
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
|
||||
x / blend_extent
|
||||
)
|
||||
return b
|
||||
|
||||
def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
||||
blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
|
||||
for x in range(blend_extent):
|
||||
b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (
|
||||
x / blend_extent
|
||||
)
|
||||
return b
|
||||
|
||||
def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
r"""Encode a batch of images using a tiled encoder.
|
||||
|
||||
Args:
|
||||
x (`torch.Tensor`): Input batch of images.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The latent representation of the encoded images.
|
||||
"""
|
||||
_, _, _, height, width = x.shape
|
||||
|
||||
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
|
||||
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
|
||||
overlap_height = int(self.tile_sample_min_height * (1 - self.tile_overlap_factor))  # 256 * (1 - 0.25) = 192
overlap_width = int(self.tile_sample_min_width * (1 - self.tile_overlap_factor))  # 256 * (1 - 0.25) = 192
blend_height = int(tile_latent_min_height * self.tile_overlap_factor)  # 16 * 0.25 = 4
blend_width = int(tile_latent_min_width * self.tile_overlap_factor)  # 16 * 0.25 = 4
row_limit_height = tile_latent_min_height - blend_height  # 16 - 4 = 12
row_limit_width = tile_latent_min_width - blend_width  # 16 - 4 = 12
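# Example with the defaults (256 px tiles, compression ratio 16, overlap factor 0.25): a 512x512 image yields
# sample tiles starting at 0, 192 and 384 px; each encoded tile spans up to 16 latent rows, neighbouring tiles
# share 4 latent rows that are blended, and 12 rows per tile are kept, giving 512 / 16 = 32 latent rows in total.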
|
||||
|
||||
rows = []
|
||||
for i in range(0, height, overlap_height):
|
||||
row = []
|
||||
for j in range(0, width, overlap_width):
|
||||
tile = x[
|
||||
:,
|
||||
:,
|
||||
:,
|
||||
i : i + self.tile_sample_min_height,
|
||||
j : j + self.tile_sample_min_width,
|
||||
]
|
||||
tile = self.encoder(tile)
|
||||
row.append(tile)
|
||||
rows.append(row)
|
||||
|
||||
result_rows = []
|
||||
for i, row in enumerate(rows):
|
||||
result_row = []
|
||||
for j, tile in enumerate(row):
|
||||
if i > 0:
|
||||
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
|
||||
if j > 0:
|
||||
tile = self.blend_h(row[j - 1], tile, blend_width)
|
||||
result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
|
||||
result_rows.append(torch.cat(result_row, dim=-1))
|
||||
moments = torch.cat(result_rows, dim=-2)
|
||||
|
||||
return moments
|
||||
|
||||
def tiled_decode(self, z: torch.Tensor) -> torch.Tensor:
|
||||
r"""
|
||||
Decode a batch of images using a tiled decoder.
|
||||
|
||||
Args:
|
||||
z (`torch.Tensor`): Input batch of latent vectors.
|
||||
Returns:
    `torch.Tensor`:
        The decoded batch of images.
|
||||
"""
|
||||
|
||||
_, _, _, height, width = z.shape
|
||||
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
|
||||
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
|
||||
overlap_height = int(tile_latent_min_height * (1 - self.tile_overlap_factor))  # 16 * (1 - 0.25) = 12
overlap_width = int(tile_latent_min_width * (1 - self.tile_overlap_factor))  # 16 * (1 - 0.25) = 12
blend_height = int(self.tile_sample_min_height * self.tile_overlap_factor)  # 256 * 0.25 = 64
blend_width = int(self.tile_sample_min_width * self.tile_overlap_factor)  # 256 * 0.25 = 64
row_limit_height = self.tile_sample_min_height - blend_height  # 256 - 64 = 192
row_limit_width = self.tile_sample_min_width - blend_width  # 256 - 64 = 192
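# Example with the defaults: a 32-row latent (512 px output) is split into latent tiles starting at rows 0, 12
# and 24; each decodes to up to 256 px, neighbouring decoded tiles overlap by 64 px, which is cross-faded by
# `blend_v` / `blend_h` before each tile is cropped to 192 px and stitched back together.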
|
||||
|
||||
rows = []
|
||||
for i in range(0, height, overlap_height):
|
||||
row = []
|
||||
for j in range(0, width, overlap_width):
|
||||
tile = z[
|
||||
:,
|
||||
:,
|
||||
:,
|
||||
i : i + tile_latent_min_height,
|
||||
j : j + tile_latent_min_width,
|
||||
]
|
||||
decoded = self.decoder(tile)
|
||||
row.append(decoded)
|
||||
rows.append(row)
|
||||
|
||||
result_rows = []
|
||||
for i, row in enumerate(rows):
|
||||
result_row = []
|
||||
for j, tile in enumerate(row):
|
||||
if i > 0:
|
||||
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
|
||||
if j > 0:
|
||||
tile = self.blend_h(row[j - 1], tile, blend_width)
|
||||
result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
|
||||
result_rows.append(torch.cat(result_row, dim=-1))
|
||||
dec = torch.cat(result_rows, dim=-2)
|
||||
|
||||
return dec
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
r"""
|
||||
Args:
|
||||
sample (`torch.Tensor`): Input sample.
|
||||
sample_posterior (`bool`, *optional*, defaults to `False`):
|
||||
Whether to sample from the posterior.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
|
||||
"""
|
||||
x = sample
|
||||
posterior = self.encode(x).latent_dist
|
||||
if sample_posterior:
|
||||
z = posterior.sample(generator=generator)
|
||||
else:
|
||||
z = posterior.mode()
|
||||
dec = self.decode(z, return_dict=return_dict)
|
||||
return dec
|
||||
@@ -27,6 +27,7 @@ if is_torch_available():
|
||||
from .transformer_hidream_image import HiDreamImageTransformer2DModel
|
||||
from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
|
||||
from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel
|
||||
from .transformer_hunyuanimage import HunyuanImageTransformer2DModel
|
||||
from .transformer_kandinsky import Kandinsky5Transformer3DModel
|
||||
from .transformer_ltx import LTXVideoTransformer3DModel
|
||||
from .transformer_lumina2 import Lumina2Transformer2DModel
|
||||
|
||||
971
src/diffusers/models/transformers/transformer_hunyuanimage.py
Normal file
@@ -0,0 +1,971 @@
|
||||
# Copyright 2025 The Hunyuan Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from diffusers.loaders import FromOriginalModelMixin
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..attention_processor import Attention, AttentionProcessor
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import (
|
||||
CombinedTimestepTextProjEmbeddings,
|
||||
TimestepEmbedding,
|
||||
Timesteps,
|
||||
get_1d_rotary_pos_embed,
|
||||
)
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class HunyuanImageAttnProcessor:
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(
|
||||
"HunyuanImageAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
if attn.add_q_proj is None and encoder_hidden_states is not None:
|
||||
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
|
||||
|
||||
# 1. QKV projections
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
|
||||
query = query.unflatten(2, (attn.heads, -1)) # batch_size, seq_len, heads, head_dim
|
||||
key = key.unflatten(2, (attn.heads, -1))
|
||||
value = value.unflatten(2, (attn.heads, -1))
|
||||
|
||||
# 2. QK normalization
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# 3. Rotational positional embeddings applied to latent stream
|
||||
if image_rotary_emb is not None:
|
||||
from ..embeddings import apply_rotary_emb
|
||||
|
||||
if attn.add_q_proj is None and encoder_hidden_states is not None:
|
||||
query = torch.cat(
|
||||
[
|
||||
apply_rotary_emb(
|
||||
query[:, : -encoder_hidden_states.shape[1]], image_rotary_emb, sequence_dim=1
|
||||
),
|
||||
query[:, -encoder_hidden_states.shape[1] :],
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
key = torch.cat(
|
||||
[
|
||||
apply_rotary_emb(key[:, : -encoder_hidden_states.shape[1]], image_rotary_emb, sequence_dim=1),
|
||||
key[:, -encoder_hidden_states.shape[1] :],
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
else:
|
||||
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
|
||||
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
|
||||
|
||||
# 4. Encoder condition QKV projection and normalization
|
||||
if attn.add_q_proj is not None and encoder_hidden_states is not None:
|
||||
encoder_query = attn.add_q_proj(encoder_hidden_states)
|
||||
encoder_key = attn.add_k_proj(encoder_hidden_states)
|
||||
encoder_value = attn.add_v_proj(encoder_hidden_states)
|
||||
|
||||
encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
|
||||
encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
|
||||
encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
|
||||
|
||||
if attn.norm_added_q is not None:
|
||||
encoder_query = attn.norm_added_q(encoder_query)
|
||||
if attn.norm_added_k is not None:
|
||||
encoder_key = attn.norm_added_k(encoder_key)
|
||||
|
||||
query = torch.cat([query, encoder_query], dim=1)
|
||||
key = torch.cat([key, encoder_key], dim=1)
|
||||
value = torch.cat([value, encoder_value], dim=1)
|
||||
|
||||
# 5. Attention
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# 6. Output projection
|
||||
if encoder_hidden_states is not None:
|
||||
hidden_states, encoder_hidden_states = (
|
||||
hidden_states[:, : -encoder_hidden_states.shape[1]],
|
||||
hidden_states[:, -encoder_hidden_states.shape[1] :],
|
||||
)
|
||||
|
||||
if getattr(attn, "to_out", None) is not None:
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if getattr(attn, "to_add_out", None) is not None:
|
||||
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class HunyuanImagePatchEmbed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: Union[Tuple[int, int], Tuple[int, int, int]] = (16, 16),
|
||||
in_chans: int = 3,
|
||||
embed_dim: int = 768,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.patch_size = patch_size
|
||||
|
||||
if len(patch_size) == 2:
|
||||
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
||||
elif len(patch_size) == 3:
|
||||
self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
||||
else:
|
||||
raise ValueError(f"patch_size must be a tuple of length 2 or 3, got {len(patch_size)}")
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
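# Project with a strided (2D or 3D) convolution, then flatten the spatial grid into a token sequence:
# (batch, channels, *spatial) -> (batch, num_patches, embed_dim).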
|
||||
hidden_states = self.proj(hidden_states)
|
||||
hidden_states = hidden_states.flatten(2).transpose(1, 2)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanImageByT5TextProjection(nn.Module):
|
||||
def __init__(self, in_features: int, hidden_size: int, out_features: int):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(in_features)
|
||||
self.linear_1 = nn.Linear(in_features, hidden_size)
|
||||
self.linear_2 = nn.Linear(hidden_size, hidden_size)
|
||||
self.linear_3 = nn.Linear(hidden_size, out_features)
|
||||
self.act_fn = nn.GELU()
|
||||
|
||||
def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.norm(encoder_hidden_states)
|
||||
hidden_states = self.linear_1(hidden_states)
|
||||
hidden_states = self.act_fn(hidden_states)
|
||||
hidden_states = self.linear_2(hidden_states)
|
||||
hidden_states = self.act_fn(hidden_states)
|
||||
hidden_states = self.linear_3(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanImageAdaNorm(nn.Module):
|
||||
def __init__(self, in_features: int, out_features: Optional[int] = None) -> None:
|
||||
super().__init__()
|
||||
|
||||
out_features = out_features or 2 * in_features
|
||||
self.linear = nn.Linear(in_features, out_features)
|
||||
self.nonlinearity = nn.SiLU()
|
||||
|
||||
def forward(
|
||||
self, temb: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
temb = self.linear(self.nonlinearity(temb))
|
||||
gate_msa, gate_mlp = temb.chunk(2, dim=1)
|
||||
gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1)
|
||||
return gate_msa, gate_mlp
|
||||
|
||||
|
||||
class HunyuanImageCombinedTimeGuidanceEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
guidance_embeds: bool = False,
|
||||
use_meanflow: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
|
||||
self.use_meanflow = use_meanflow
|
||||
|
||||
self.time_proj_r = None
|
||||
self.timestep_embedder_r = None
|
||||
if use_meanflow:
|
||||
self.time_proj_r = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.timestep_embedder_r = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
|
||||
self.guidance_embedder = None
|
||||
if guidance_embeds:
|
||||
self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
timestep: torch.Tensor,
|
||||
timestep_r: Optional[torch.Tensor] = None,
|
||||
guidance: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=timestep.dtype))
|
||||
|
||||
if timestep_r is not None:
|
||||
timesteps_proj_r = self.time_proj_r(timestep_r)
|
||||
timesteps_emb_r = self.timestep_embedder_r(timesteps_proj_r.to(dtype=timestep.dtype))
|
||||
timesteps_emb = (timesteps_emb + timesteps_emb_r) / 2
|
||||
|
||||
if self.guidance_embedder is not None:
|
||||
guidance_proj = self.time_proj(guidance)
|
||||
guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=timestep.dtype))
|
||||
conditioning = timesteps_emb + guidance_emb
|
||||
else:
|
||||
conditioning = timesteps_emb
|
||||
|
||||
return conditioning
|
||||
|
||||
|
||||
# IndividualTokenRefinerBlock
|
||||
@maybe_allow_in_graph
|
||||
class HunyuanImageIndividualTokenRefinerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int, # 28
|
||||
attention_head_dim: int, # 128
|
||||
mlp_width_ratio: str = 4.0,
|
||||
mlp_drop_rate: float = 0.0,
|
||||
attention_bias: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
hidden_size = num_attention_heads * attention_head_dim
|
||||
|
||||
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
|
||||
self.attn = Attention(
|
||||
query_dim=hidden_size,
|
||||
cross_attention_dim=None,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
bias=attention_bias,
|
||||
)
|
||||
|
||||
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
|
||||
self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="linear-silu", dropout=mlp_drop_rate)
|
||||
|
||||
self.norm_out = HunyuanImageAdaNorm(hidden_size, 2 * hidden_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
|
||||
attn_output = self.attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
|
||||
gate_msa, gate_mlp = self.norm_out(temb)
|
||||
hidden_states = hidden_states + attn_output * gate_msa
|
||||
|
||||
ff_output = self.ff(self.norm2(hidden_states))
|
||||
hidden_states = hidden_states + ff_output * gate_mlp
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanImageIndividualTokenRefiner(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
num_layers: int,
|
||||
mlp_width_ratio: float = 4.0,
|
||||
mlp_drop_rate: float = 0.0,
|
||||
attention_bias: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.refiner_blocks = nn.ModuleList(
|
||||
[
|
||||
HunyuanImageIndividualTokenRefinerBlock(
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
mlp_width_ratio=mlp_width_ratio,
|
||||
mlp_drop_rate=mlp_drop_rate,
|
||||
attention_bias=attention_bias,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
self_attn_mask = None
|
||||
if attention_mask is not None:
|
||||
batch_size = attention_mask.shape[0]
|
||||
seq_len = attention_mask.shape[1]
|
||||
attention_mask = attention_mask.to(hidden_states.device)
|
||||
self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
|
||||
self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
|
||||
self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
|
||||
self_attn_mask[:, :, :, 0] = True
|
||||
|
||||
for block in self.refiner_blocks:
|
||||
hidden_states = block(hidden_states, temb, self_attn_mask)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
# txt_in
|
||||
class HunyuanImageTokenRefiner(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
num_layers: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
mlp_drop_rate: float = 0.0,
|
||||
attention_bias: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
hidden_size = num_attention_heads * attention_head_dim
|
||||
|
||||
self.time_text_embed = CombinedTimestepTextProjEmbeddings(
|
||||
embedding_dim=hidden_size, pooled_projection_dim=in_channels
|
||||
)
|
||||
self.proj_in = nn.Linear(in_channels, hidden_size, bias=True)
|
||||
self.token_refiner = HunyuanImageIndividualTokenRefiner(
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_layers=num_layers,
|
||||
mlp_width_ratio=mlp_ratio,
|
||||
mlp_drop_rate=mlp_drop_rate,
|
||||
attention_bias=attention_bias,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
timestep: torch.LongTensor,
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if attention_mask is None:
|
||||
pooled_hidden_states = hidden_states.mean(dim=1)
|
||||
else:
|
||||
original_dtype = hidden_states.dtype
|
||||
mask_float = attention_mask.float().unsqueeze(-1)
|
||||
pooled_hidden_states = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1)
|
||||
pooled_hidden_states = pooled_hidden_states.to(original_dtype)
|
||||
|
||||
temb = self.time_text_embed(timestep, pooled_hidden_states)
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
hidden_states = self.token_refiner(hidden_states, temb, attention_mask)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanImageRotaryPosEmbed(nn.Module):
|
||||
def __init__(
|
||||
self, patch_size: Union[Tuple, List[int]], rope_dim: Union[Tuple, List[int]], theta: float = 256.0
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
if not isinstance(patch_size, (tuple, list)) or len(patch_size) not in [2, 3]:
|
||||
raise ValueError(f"patch_size must be a tuple or list of length 2 or 3, got {patch_size}")
|
||||
|
||||
if not isinstance(rope_dim, (tuple, list)) or len(rope_dim) not in [2, 3]:
|
||||
raise ValueError(f"rope_dim must be a tuple or list of length 2 or 3, got {rope_dim}")
|
||||
|
||||
if not len(patch_size) == len(rope_dim):
|
||||
raise ValueError(f"patch_size and rope_dim must have the same length, got {patch_size} and {rope_dim}")
|
||||
|
||||
self.patch_size = patch_size
|
||||
self.rope_dim = rope_dim
|
||||
self.theta = theta
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if hidden_states.ndim == 5:
|
||||
_, _, frame, height, width = hidden_states.shape
|
||||
patch_size_frame, patch_size_height, patch_size_width = self.patch_size
|
||||
rope_sizes = [frame // patch_size_frame, height // patch_size_height, width // patch_size_width]
|
||||
elif hidden_states.ndim == 4:
|
||||
_, _, height, width = hidden_states.shape
|
||||
patch_size_height, patch_size_width = self.patch_size
|
||||
rope_sizes = [height // patch_size_height, width // patch_size_width]
|
||||
else:
|
||||
raise ValueError(f"hidden_states must be a 4D or 5D tensor, got {hidden_states.shape}")
|
||||
|
||||
axes_grids = []
|
||||
for i in range(len(rope_sizes)):
|
||||
grid = torch.arange(0, rope_sizes[i], device=hidden_states.device, dtype=torch.float32)
|
||||
axes_grids.append(grid)
|
||||
grid = torch.meshgrid(*axes_grids, indexing="ij") # dim x [H, W]
|
||||
grid = torch.stack(grid, dim=0) # [2, H, W]
|
||||
|
||||
freqs = []
|
||||
for i in range(len(rope_sizes)):
|
||||
freq = get_1d_rotary_pos_embed(self.rope_dim[i], grid[i].reshape(-1), self.theta, use_real=True)
|
||||
freqs.append(freq)
|
||||
|
||||
freqs_cos = torch.cat([f[0] for f in freqs], dim=1) # (W * H * T, D / 2)
|
||||
freqs_sin = torch.cat([f[1] for f in freqs], dim=1) # (W * H * T, D / 2)
|
||||
return freqs_cos, freqs_sin
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class HunyuanImageSingleTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
qk_norm: str = "rms_norm",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
hidden_size = num_attention_heads * attention_head_dim
|
||||
mlp_dim = int(hidden_size * mlp_ratio)
|
||||
|
||||
self.attn = Attention(
|
||||
query_dim=hidden_size,
|
||||
cross_attention_dim=None,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
out_dim=hidden_size,
|
||||
bias=True,
|
||||
processor=HunyuanImageAttnProcessor(),
|
||||
qk_norm=qk_norm,
|
||||
eps=1e-6,
|
||||
pre_only=True,
|
||||
)
|
||||
|
||||
self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm")
|
||||
self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
|
||||
self.act_mlp = nn.GELU(approximate="tanh")
|
||||
self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
text_seq_length = encoder_hidden_states.shape[1]
|
||||
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
# 1. Input normalization
|
||||
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
|
||||
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
|
||||
|
||||
norm_hidden_states, norm_encoder_hidden_states = (
|
||||
norm_hidden_states[:, :-text_seq_length, :],
|
||||
norm_hidden_states[:, -text_seq_length:, :],
|
||||
)
|
||||
|
||||
# 2. Attention
|
||||
attn_output, context_attn_output = self.attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
attn_output = torch.cat([attn_output, context_attn_output], dim=1)
|
||||
|
||||
# 3. Modulation and residual connection
|
||||
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
|
||||
hidden_states = gate.unsqueeze(1) * self.proj_out(hidden_states)
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states, encoder_hidden_states = (
|
||||
hidden_states[:, :-text_seq_length, :],
|
||||
hidden_states[:, -text_seq_length:, :],
|
||||
)
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class HunyuanImageTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
mlp_ratio: float,
|
||||
qk_norm: str = "rms_norm",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
hidden_size = num_attention_heads * attention_head_dim
|
||||
|
||||
self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
|
||||
self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
|
||||
|
||||
self.attn = Attention(
|
||||
query_dim=hidden_size,
|
||||
cross_attention_dim=None,
|
||||
added_kv_proj_dim=hidden_size,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
out_dim=hidden_size,
|
||||
context_pre_only=False,
|
||||
bias=True,
|
||||
processor=HunyuanImageAttnProcessor(),
|
||||
qk_norm=qk_norm,
|
||||
eps=1e-6,
|
||||
)
|
||||
|
||||
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
|
||||
|
||||
self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# 1. Input normalization
|
||||
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
|
||||
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
|
||||
encoder_hidden_states, emb=temb
|
||||
)
|
||||
|
||||
# 2. Joint attention
|
||||
attn_output, context_attn_output = self.attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
|
||||
# 3. Modulation and residual connection
|
||||
hidden_states = hidden_states + attn_output * gate_msa.unsqueeze(1)
|
||||
encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1)
|
||||
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
|
||||
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
||||
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
|
||||
|
||||
# 4. Feed-forward
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
context_ff_output = self.ff_context(norm_encoder_hidden_states)
|
||||
|
||||
hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output
|
||||
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class HunyuanImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
|
||||
r"""
|
||||
The Transformer model used in [HunyuanImage-2.1](https://github.com/Tencent-Hunyuan/HunyuanImage-2.1).
|
||||
|
||||
Args:
|
||||
in_channels (`int`, defaults to `64`):
    The number of channels in the input.
out_channels (`int`, defaults to `64`):
    The number of channels in the output.
num_attention_heads (`int`, defaults to `28`):
    The number of heads to use for multi-head attention.
attention_head_dim (`int`, defaults to `128`):
    The number of channels in each head.
num_layers (`int`, defaults to `20`):
    The number of layers of dual-stream blocks to use.
num_single_layers (`int`, defaults to `40`):
    The number of layers of single-stream blocks to use.
num_refiner_layers (`int`, defaults to `2`):
    The number of layers of refiner blocks to use.
mlp_ratio (`float`, defaults to `4.0`):
    The ratio of the hidden layer size to the input size in the feedforward network.
patch_size (`Tuple[int]`, defaults to `(1, 1)`):
    The patch size used in the patch embedding layer; a 2-tuple for image (2D) inputs or a 3-tuple for
    video (3D) inputs.
qk_norm (`str`, defaults to `rms_norm`):
    The normalization to use for the query and key projections in the attention layers.
guidance_embeds (`bool`, defaults to `False`):
    Whether to use guidance embeddings in the model.
text_embed_dim (`int`, defaults to `3584`):
    Input dimension of text embeddings from the primary text encoder.
text_embed_2_dim (`int`, *optional*):
    Input dimension of text embeddings from the secondary (byT5) text encoder. If `None`, the secondary
    text branch is disabled.
rope_theta (`float`, defaults to `256.0`):
    The value of theta to use in the RoPE layer.
rope_axes_dim (`Tuple[int]`, defaults to `(64, 64)`):
    The dimensions of the axes to use in the RoPE layer.
use_meanflow (`bool`, defaults to `False`):
    Whether to condition on an additional timestep (`timestep_r`) and average its embedding with the main
    timestep embedding.
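
Example (a minimal sketch with a tiny, randomly initialised configuration; it only illustrates the expected
inputs and shapes, not the released checkpoint, which would normally be loaded with `from_pretrained`):

    ```python
    import torch
    from diffusers import HunyuanImageTransformer2DModel

    transformer = HunyuanImageTransformer2DModel(
        in_channels=64,
        out_channels=64,
        num_attention_heads=2,
        attention_head_dim=8,
        num_layers=1,
        num_single_layers=1,
        text_embed_dim=16,
        rope_axes_dim=(4, 4),  # the axes dims must sum to `attention_head_dim`
    )

    latents = torch.randn(1, 64, 8, 8)  # packed image latents: (batch, in_channels, height, width)
    timestep = torch.tensor([500.0])
    prompt_embeds = torch.randn(1, 4, 16)  # text-encoder hidden states: (batch, seq_len, text_embed_dim)
    prompt_mask = torch.ones(1, 4, dtype=torch.bool)

    sample = transformer(latents, timestep, prompt_embeds, prompt_mask).sample
    ```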
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_skip_layerwise_casting_patterns = ["x_embedder", "context_embedder", "norm"]
|
||||
_no_split_modules = [
|
||||
"HunyuanImageTransformerBlock",
|
||||
"HunyuanImageSingleTransformerBlock",
|
||||
"HunyuanImagePatchEmbed",
|
||||
"HunyuanImageTokenRefiner",
|
||||
]
|
||||
_repeated_blocks = [
|
||||
"HunyuanImageTransformerBlock",
|
||||
"HunyuanImageSingleTransformerBlock",
|
||||
]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 64,
|
||||
out_channels: int = 64,
|
||||
num_attention_heads: int = 28,
|
||||
attention_head_dim: int = 128,
|
||||
num_layers: int = 20,
|
||||
num_single_layers: int = 40,
|
||||
num_refiner_layers: int = 2,
|
||||
mlp_ratio: float = 4.0,
|
||||
patch_size: Tuple[int, int] = (1, 1),
|
||||
qk_norm: str = "rms_norm",
|
||||
guidance_embeds: bool = False,
|
||||
text_embed_dim: int = 3584,
|
||||
text_embed_2_dim: Optional[int] = None,
|
||||
rope_theta: float = 256.0,
|
||||
rope_axes_dim: Tuple[int] = (64, 64),
|
||||
use_meanflow: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
if not (isinstance(patch_size, (tuple, list)) and len(patch_size) in [2, 3]):
|
||||
raise ValueError(f"patch_size must be a tuple of length 2 or 3, got {patch_size}")
|
||||
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
out_channels = out_channels or in_channels
|
||||
|
||||
# 1. Latent and condition embedders
|
||||
self.x_embedder = HunyuanImagePatchEmbed(patch_size, in_channels, inner_dim)
|
||||
self.context_embedder = HunyuanImageTokenRefiner(
|
||||
text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
|
||||
)
|
||||
|
||||
if text_embed_2_dim is not None:
|
||||
self.context_embedder_2 = HunyuanImageByT5TextProjection(text_embed_2_dim, 2048, inner_dim)
|
||||
else:
|
||||
self.context_embedder_2 = None
|
||||
|
||||
self.time_guidance_embed = HunyuanImageCombinedTimeGuidanceEmbedding(inner_dim, guidance_embeds, use_meanflow)
|
||||
|
||||
# 2. RoPE
|
||||
self.rope = HunyuanImageRotaryPosEmbed(patch_size, rope_axes_dim, rope_theta)
|
||||
|
||||
# 3. Dual stream transformer blocks
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
HunyuanImageTransformerBlock(
|
||||
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 4. Single stream transformer blocks
|
||||
self.single_transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
HunyuanImageSingleTransformerBlock(
|
||||
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
|
||||
)
|
||||
for _ in range(num_single_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 5. Output projection
|
||||
self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
|
||||
self.proj_out = nn.Linear(inner_dim, math.prod(patch_size) * out_channels)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model,
indexed by their weight names.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
timestep: torch.LongTensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
encoder_attention_mask: torch.Tensor,
|
||||
timestep_r: Optional[torch.LongTensor] = None,
|
||||
encoder_hidden_states_2: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask_2: Optional[torch.Tensor] = None,
|
||||
guidance: Optional[torch.Tensor] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
if hidden_states.ndim == 4:
|
||||
batch_size, channels, height, width = hidden_states.shape
|
||||
sizes = (height, width)
|
||||
elif hidden_states.ndim == 5:
|
||||
batch_size, channels, frame, height, width = hidden_states.shape
|
||||
sizes = (frame, height, width)
|
||||
else:
|
||||
raise ValueError(f"hidden_states must be a 4D or 5D tensor, got {hidden_states.shape}")
|
||||
|
||||
post_patch_sizes = tuple(d // p for d, p in zip(sizes, self.config.patch_size))
|
||||
|
||||
# 1. RoPE
|
||||
image_rotary_emb = self.rope(hidden_states)
|
||||
|
||||
# 2. Conditional embeddings
|
||||
encoder_attention_mask = encoder_attention_mask.bool()
|
||||
temb = self.time_guidance_embed(timestep, guidance=guidance, timestep_r=timestep_r)
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask)
|
||||
|
||||
if self.context_embedder_2 is not None and encoder_hidden_states_2 is not None:
|
||||
encoder_hidden_states_2 = self.context_embedder_2(encoder_hidden_states_2)
|
||||
|
||||
encoder_attention_mask_2 = encoder_attention_mask_2.bool()
|
||||
|
||||
# reorder and combine text tokens: combine valid tokens first, then padding
|
||||
new_encoder_hidden_states = []
|
||||
new_encoder_attention_mask = []
|
||||
|
||||
for text, text_mask, text_2, text_mask_2 in zip(
|
||||
encoder_hidden_states, encoder_attention_mask, encoder_hidden_states_2, encoder_attention_mask_2
|
||||
):
|
||||
# Concatenate: [valid_mllm, valid_byt5, invalid_mllm, invalid_byt5]
|
||||
new_encoder_hidden_states.append(
|
||||
torch.cat(
|
||||
[
|
||||
text_2[text_mask_2], # valid byt5
|
||||
text[text_mask], # valid mllm
|
||||
text_2[~text_mask_2], # invalid byt5
|
||||
text[~text_mask], # invalid mllm
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
)
|
||||
|
||||
# Apply same reordering to attention masks
|
||||
new_encoder_attention_mask.append(
|
||||
torch.cat(
|
||||
[
|
||||
text_mask_2[text_mask_2],
|
||||
text_mask[text_mask],
|
||||
text_mask_2[~text_mask_2],
|
||||
text_mask[~text_mask],
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
)
|
||||
|
||||
encoder_hidden_states = torch.stack(new_encoder_hidden_states)
|
||||
encoder_attention_mask = torch.stack(new_encoder_attention_mask)
|
||||
|
||||
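# Build a key-side attention mask over the combined [image tokens, text tokens] sequence: the image positions
# (padded in on the left) are always attended to, while padded text positions stay masked; the final shape is
# (batch, 1, 1, seq_len) so it broadcasts over heads and query positions.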
attention_mask = torch.nn.functional.pad(encoder_attention_mask, (hidden_states.shape[1], 0), value=True)
|
||||
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
||||
# 3. Transformer blocks
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
attention_mask=attention_mask,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
|
||||
for block in self.single_transformer_blocks:
|
||||
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
attention_mask=attention_mask,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
|
||||
else:
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states, encoder_hidden_states = block(
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
attention_mask=attention_mask,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
|
||||
for block in self.single_transformer_blocks:
|
||||
hidden_states, encoder_hidden_states = block(
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
attention_mask=attention_mask,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
|
||||
# 4. Output projection
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
# 5. unpatchify
|
||||
# reshape: [batch_size, *post_patch_dims, channels, *patch_size]
|
||||
out_channels = self.config.out_channels
|
||||
reshape_dims = [batch_size] + list(post_patch_sizes) + [out_channels] + list(self.config.patch_size)
|
||||
hidden_states = hidden_states.reshape(*reshape_dims)
|
||||
|
||||
# create permutation pattern: batch, channels, then interleave post_patch and patch dims
|
||||
# For 4D: [0, 3, 1, 4, 2, 5] -> batch, channels, post_patch_height, patch_size_height, post_patch_width, patch_size_width
|
||||
# For 5D: [0, 4, 1, 5, 2, 6, 3, 7] -> batch, channels, post_patch_frame, patch_size_frame, post_patch_height, patch_size_height, post_patch_width, patch_size_width
|
||||
ndim = len(post_patch_sizes)
|
||||
permute_pattern = [0, ndim + 1] # batch, channels
|
||||
for i in range(ndim):
|
||||
permute_pattern.extend([i + 1, ndim + 2 + i]) # post_patch_sizes[i], patch_sizes[i]
|
||||
hidden_states = hidden_states.permute(*permute_pattern)
|
||||
|
||||
# flatten patch dimensions: flatten each (post_patch_size, patch_size) pair
|
||||
# batch_size, channels, post_patch_sizes[0] * patch_sizes[0], post_patch_sizes[1] * patch_sizes[1], ...
|
||||
final_dims = [batch_size, out_channels] + [
|
||||
post_patch * patch for post_patch, patch in zip(post_patch_sizes, self.config.patch_size)
|
||||
]
|
||||
hidden_states = hidden_states.reshape(*final_dims)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (hidden_states,)
|
||||
|
||||
return Transformer2DModelOutput(sample=hidden_states)
|
||||
@@ -130,8 +130,14 @@ class PipelineState:
|
||||
Allow attribute access to intermediate values. If an attribute is not found in the object, look for it in the
|
||||
intermediates dict.
|
||||
"""
|
||||
if name in self.values:
|
||||
return self.values[name]
|
||||
# Use object.__getattribute__ to avoid infinite recursion during deepcopy
|
||||
try:
|
||||
values = object.__getattribute__(self, "values")
|
||||
except AttributeError:
|
||||
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
|
||||
|
||||
if name in values:
|
||||
return values[name]
|
||||
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
|
||||
|
||||
def __repr__(self):
|
||||
@@ -2492,6 +2498,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
"""
|
||||
if state is None:
|
||||
state = PipelineState()
|
||||
else:
|
||||
state = deepcopy(state)
|
||||
|
||||
# Make a copy of the input kwargs
|
||||
passed_kwargs = kwargs.copy()
|
||||
|
||||
@@ -238,19 +238,27 @@ class QwenImageLoopDenoiser(ModularPipelineBlocks):
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
|
||||
guider_input_fields = {
|
||||
"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
|
||||
"encoder_hidden_states_mask": ("prompt_embeds_mask", "negative_prompt_embeds_mask"),
|
||||
"txt_seq_lens": ("txt_seq_lens", "negative_txt_seq_lens"),
|
||||
guider_inputs = {
|
||||
"encoder_hidden_states": (
|
||||
getattr(block_state, "prompt_embeds", None),
|
||||
getattr(block_state, "negative_prompt_embeds", None),
|
||||
),
|
||||
"encoder_hidden_states_mask": (
|
||||
getattr(block_state, "prompt_embeds_mask", None),
|
||||
getattr(block_state, "negative_prompt_embeds_mask", None),
|
||||
),
|
||||
"txt_seq_lens": (
|
||||
getattr(block_state, "txt_seq_lens", None),
|
||||
getattr(block_state, "negative_txt_seq_lens", None),
|
||||
),
|
||||
}
|
||||
|
||||
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
|
||||
guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
|
||||
guider_state = components.guider.prepare_inputs(guider_inputs)
|
||||
|
||||
for guider_state_batch in guider_state:
|
||||
components.guider.prepare_models(components.transformer)
|
||||
cond_kwargs = guider_state_batch.as_dict()
|
||||
cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields}
|
||||
cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()}
|
||||
|
||||
# YiYi TODO: add cache context
|
||||
guider_state_batch.noise_pred = components.transformer(
|
||||
@@ -328,19 +336,27 @@ class QwenImageEditLoopDenoiser(ModularPipelineBlocks):
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
|
||||
guider_input_fields = {
|
||||
"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
|
||||
"encoder_hidden_states_mask": ("prompt_embeds_mask", "negative_prompt_embeds_mask"),
|
||||
"txt_seq_lens": ("txt_seq_lens", "negative_txt_seq_lens"),
|
||||
guider_inputs = {
|
||||
"encoder_hidden_states": (
|
||||
getattr(block_state, "prompt_embeds", None),
|
||||
getattr(block_state, "negative_prompt_embeds", None),
|
||||
),
|
||||
"encoder_hidden_states_mask": (
|
||||
getattr(block_state, "prompt_embeds_mask", None),
|
||||
getattr(block_state, "negative_prompt_embeds_mask", None),
|
||||
),
|
||||
"txt_seq_lens": (
|
||||
getattr(block_state, "txt_seq_lens", None),
|
||||
getattr(block_state, "negative_txt_seq_lens", None),
|
||||
),
|
||||
}
|
||||
|
||||
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
|
||||
guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
|
||||
guider_state = components.guider.prepare_inputs(guider_inputs)
|
||||
|
||||
for guider_state_batch in guider_state:
|
||||
components.guider.prepare_models(components.transformer)
|
||||
cond_kwargs = guider_state_batch.as_dict()
|
||||
cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields}
|
||||
cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()}
|
||||
|
||||
# YiYi TODO: add cache context
|
||||
guider_state_batch.noise_pred = components.transformer(
|
||||
|
||||
@@ -201,27 +201,41 @@ class StableDiffusionXLLoopDenoiser(ModularPipelineBlocks):
|
||||
) -> PipelineState:
|
||||
# Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds)
|
||||
# to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds)
|
||||
guider_input_fields = {
|
||||
"prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"),
|
||||
"time_ids": ("add_time_ids", "negative_add_time_ids"),
|
||||
"text_embeds": ("pooled_prompt_embeds", "negative_pooled_prompt_embeds"),
|
||||
"image_embeds": ("ip_adapter_embeds", "negative_ip_adapter_embeds"),
|
||||
guider_inputs = {
|
||||
"prompt_embeds": (
|
||||
getattr(block_state, "prompt_embeds", None),
|
||||
getattr(block_state, "negative_prompt_embeds", None),
|
||||
),
|
||||
"time_ids": (
|
||||
getattr(block_state, "add_time_ids", None),
|
||||
getattr(block_state, "negative_add_time_ids", None),
|
||||
),
|
||||
"text_embeds": (
|
||||
getattr(block_state, "pooled_prompt_embeds", None),
|
||||
getattr(block_state, "negative_pooled_prompt_embeds", None),
|
||||
),
|
||||
"image_embeds": (
|
||||
getattr(block_state, "ip_adapter_embeds", None),
|
||||
getattr(block_state, "negative_ip_adapter_embeds", None),
|
||||
),
|
||||
}
|
||||
|
||||
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
|
||||
|
||||
# Prepare mini‐batches according to guidance method and `guider_input_fields`
|
||||
# Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds.
|
||||
# e.g. for CFG, we prepare two batches: one for uncond, one for cond
|
||||
# for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds
|
||||
# for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds
|
||||
guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
|
||||
# The guider splits model inputs into separate batches for conditional/unconditional predictions.
|
||||
# For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
|
||||
# you will get a guider_state with two batches:
|
||||
# guider_state = [
|
||||
# {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch
|
||||
# {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch
|
||||
# ]
|
||||
# Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
|
||||
guider_state = components.guider.prepare_inputs(guider_inputs)
|
||||
|
||||
# run the denoiser for each guidance batch
|
||||
for guider_state_batch in guider_state:
|
||||
components.guider.prepare_models(components.unet)
|
||||
cond_kwargs = guider_state_batch.as_dict()
|
||||
cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields}
|
||||
cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()}
|
||||
prompt_embeds = cond_kwargs.pop("prompt_embeds")
|
||||
|
||||
# Predict the noise residual
|
||||
@@ -344,11 +358,23 @@ class StableDiffusionXLControlNetLoopDenoiser(ModularPipelineBlocks):
|
||||
|
||||
# Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds)
|
||||
# to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds)
|
||||
guider_input_fields = {
|
||||
"prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"),
|
||||
"time_ids": ("add_time_ids", "negative_add_time_ids"),
|
||||
"text_embeds": ("pooled_prompt_embeds", "negative_pooled_prompt_embeds"),
|
||||
"image_embeds": ("ip_adapter_embeds", "negative_ip_adapter_embeds"),
|
||||
guider_inputs = {
|
||||
"prompt_embeds": (
|
||||
getattr(block_state, "prompt_embeds", None),
|
||||
getattr(block_state, "negative_prompt_embeds", None),
|
||||
),
|
||||
"time_ids": (
|
||||
getattr(block_state, "add_time_ids", None),
|
||||
getattr(block_state, "negative_add_time_ids", None),
|
||||
),
|
||||
"text_embeds": (
|
||||
getattr(block_state, "pooled_prompt_embeds", None),
|
||||
getattr(block_state, "negative_pooled_prompt_embeds", None),
|
||||
),
|
||||
"image_embeds": (
|
||||
getattr(block_state, "ip_adapter_embeds", None),
|
||||
getattr(block_state, "negative_ip_adapter_embeds", None),
|
||||
),
|
||||
}
|
||||
|
||||
# cond_scale for the timestep (controlnet input)
|
||||
@@ -369,12 +395,15 @@ class StableDiffusionXLControlNetLoopDenoiser(ModularPipelineBlocks):
|
||||
# guided denoiser step
|
||||
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
|
||||
|
||||
# Prepare mini‐batches according to guidance method and `guider_input_fields`
|
||||
# Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds.
|
||||
# e.g. for CFG, we prepare two batches: one for uncond, one for cond
|
||||
# for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds
|
||||
# for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds
|
||||
guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
|
||||
# The guider splits model inputs into separate batches for conditional/unconditional predictions.
|
||||
# For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
|
||||
# you will get a guider_state with two batches:
|
||||
# guider_state = [
|
||||
# {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch
|
||||
# {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch
|
||||
# ]
|
||||
# Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
|
||||
guider_state = components.guider.prepare_inputs(guider_inputs)
|
||||
|
||||
# run the denoiser for each guidance batch
|
||||
for guider_state_batch in guider_state:
|
||||
|
||||
@@ -94,25 +94,30 @@ class WanLoopDenoiser(ModularPipelineBlocks):
|
||||
) -> PipelineState:
|
||||
# Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds)
|
||||
# to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds)
|
||||
guider_input_fields = {
|
||||
"prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"),
|
||||
guider_inputs = {
|
||||
"prompt_embeds": (
|
||||
getattr(block_state, "prompt_embeds", None),
|
||||
getattr(block_state, "negative_prompt_embeds", None),
|
||||
),
|
||||
}
|
||||
transformer_dtype = components.transformer.dtype
|
||||
|
||||
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
|
||||
|
||||
# Prepare mini‐batches according to guidance method and `guider_input_fields`
|
||||
# Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds.
|
||||
# e.g. for CFG, we prepare two batches: one for uncond, one for cond
|
||||
# for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds
|
||||
# for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds
|
||||
guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
|
||||
# The guider splits model inputs into separate batches for conditional/unconditional predictions.
|
||||
# For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
|
||||
# you will get a guider_state with two batches:
|
||||
# guider_state = [
|
||||
# {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch
|
||||
# {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch
|
||||
# ]
|
||||
# Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
|
||||
guider_state = components.guider.prepare_inputs(guider_inputs)
|
||||
|
||||
# run the denoiser for each guidance batch
|
||||
for guider_state_batch in guider_state:
|
||||
components.guider.prepare_models(components.transformer)
|
||||
cond_kwargs = guider_state_batch.as_dict()
|
||||
cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields}
|
||||
cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()}
|
||||
prompt_embeds = cond_kwargs.pop("prompt_embeds")
|
||||
|
||||
# Predict the noise residual
|
||||
|
||||
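The comment blocks added to these denoiser loops all describe the same contract: `guider_inputs` maps each model argument to a `(cond, uncond)` pair, and `prepare_inputs` splits that dict into one batch per prediction, each tagged with `__guidance_identifier__`. A minimal sketch of that splitting step, using a hypothetical stand-in rather than the real `diffusers.guiders` API:

```python
# Hypothetical stand-in for the input-splitting contract described in the comments above.
# The real guiders API differs; this only illustrates the shape of the data.
from types import SimpleNamespace


def split_guider_inputs(guider_inputs, identifiers=("pred_cond", "pred_uncond")):
    """Turn {"name": (cond, uncond), ...} into one tagged batch per prediction."""
    batches = []
    for idx, identifier in enumerate(identifiers):
        fields = {name: pair[idx] for name, pair in guider_inputs.items()}
        fields["__guidance_identifier__"] = identifier
        batches.append(SimpleNamespace(**fields))
    return batches


guider_inputs = {"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds")}
for batch in split_guider_inputs(guider_inputs):
    print(batch.__guidance_identifier__, batch.encoder_hidden_states)
# pred_cond prompt_embeds
# pred_uncond negative_prompt_embeds
```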
@@ -241,6 +241,7 @@ else:
|
||||
"HunyuanVideoImageToVideoPipeline",
|
||||
"HunyuanVideoFramepackPipeline",
|
||||
]
|
||||
_import_structure["hunyuan_image"] = ["HunyuanImagePipeline", "HunyuanImageRefinerPipeline"]
|
||||
_import_structure["kandinsky"] = [
|
||||
"KandinskyCombinedPipeline",
|
||||
"KandinskyImg2ImgCombinedPipeline",
|
||||
@@ -640,6 +641,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
ReduxImageEncoder,
|
||||
)
|
||||
from .hidream_image import HiDreamImagePipeline
|
||||
from .hunyuan_image import HunyuanImagePipeline, HunyuanImageRefinerPipeline
|
||||
from .hunyuan_video import (
|
||||
HunyuanSkyreelsImageToVideoPipeline,
|
||||
HunyuanVideoFramepackPipeline,
|
||||
|
||||
50
src/diffusers/pipelines/hunyuan_image/__init__.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_hunyuanimage"] = ["HunyuanImagePipeline"]
|
||||
_import_structure["pipeline_hunyuanimage_refiner"] = ["HunyuanImageRefinerPipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .pipeline_hunyuanimage import HunyuanImagePipeline
|
||||
from .pipeline_hunyuanimage_refiner import HunyuanImageRefinerPipeline
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
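For reference, the lazy registration above (together with the `hunyuan_image` entries added to `pipelines/__init__.py` earlier in this diff) means the new classes are importable through the usual paths, with the submodule only loaded on first access. A small sketch, assuming `torch` and `transformers` are installed:

```python
# Both spellings resolve through _LazyModule and point at the same class objects.
from diffusers import HunyuanImagePipeline, HunyuanImageRefinerPipeline
from diffusers.pipelines.hunyuan_image import HunyuanImagePipeline as _SameClass

assert HunyuanImagePipeline is _SameClass
```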
866
src/diffusers/pipelines/hunyuan_image/pipeline_hunyuanimage.py
Normal file
@@ -0,0 +1,866 @@
|
||||
# Copyright 2025 Hunyuan-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import re
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import ByT5Tokenizer, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, T5EncoderModel
|
||||
|
||||
from ...guiders import AdaptiveProjectedMixGuidance
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...models import AutoencoderKLHunyuanImage, HunyuanImageTransformer2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .pipeline_output import HunyuanImagePipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import HunyuanImagePipeline
|
||||
|
||||
>>> pipe = HunyuanImagePipeline.from_pretrained(
|
||||
... "hunyuanvideo-community/HunyuanImage-2.1-Diffusers", torch_dtype=torch.bfloat16
|
||||
... )
|
||||
>>> pipe.to("cuda")
|
||||
>>> prompt = "A cat holding a sign that says hello world"
|
||||
>>> # Depending on the variant being used, the pipeline call will slightly vary.
|
||||
>>> # Refer to the pipeline documentation for more details.
|
||||
>>> image = pipe(prompt, negative_prompt="", num_inference_steps=50).images[0]
|
||||
>>> image.save("hunyuanimage.png")
|
||||
```
|
||||
"""
|
||||
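A hedged variation of the example above for GPUs that cannot hold all components at once; `enable_model_cpu_offload` is the standard diffusers offloading API and follows the `model_cpu_offload_seq` declared on the class, though whether bf16 fits a given card remains an assumption:

```python
import torch
from diffusers import HunyuanImagePipeline

pipe = HunyuanImagePipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanImage-2.1-Diffusers", torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()  # each component is moved to the GPU only while it runs

image = pipe(
    "A cat holding a sign that says hello world", negative_prompt="", num_inference_steps=50
).images[0]
image.save("hunyuanimage_offloaded.png")
```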
|
||||
|
||||
def extract_glyph_text(prompt: str):
|
||||
"""
|
||||
Extract text enclosed in quotes for glyph rendering.
|
||||
|
||||
Finds text in single quotes, double quotes, and Chinese quotes, then formats it for byT5 processing.
|
||||
|
||||
Args:
|
||||
prompt: Input text prompt
|
||||
|
||||
Returns:
|
||||
Formatted glyph text string or None if no quoted text found
|
||||
"""
|
||||
text_prompt_texts = []
|
||||
pattern_quote_single = r"\'(.*?)\'"
|
||||
pattern_quote_double = r"\"(.*?)\""
|
||||
pattern_quote_chinese_single = r"‘(.*?)’"
|
||||
pattern_quote_chinese_double = r"“(.*?)”"
|
||||
|
||||
matches_quote_single = re.findall(pattern_quote_single, prompt)
|
||||
matches_quote_double = re.findall(pattern_quote_double, prompt)
|
||||
matches_quote_chinese_single = re.findall(pattern_quote_chinese_single, prompt)
|
||||
matches_quote_chinese_double = re.findall(pattern_quote_chinese_double, prompt)
|
||||
|
||||
text_prompt_texts.extend(matches_quote_single)
|
||||
text_prompt_texts.extend(matches_quote_double)
|
||||
text_prompt_texts.extend(matches_quote_chinese_single)
|
||||
text_prompt_texts.extend(matches_quote_chinese_double)
|
||||
|
||||
if text_prompt_texts:
|
||||
glyph_text_formatted = ". ".join([f'Text "{text}"' for text in text_prompt_texts]) + ". "
|
||||
else:
|
||||
glyph_text_formatted = None
|
||||
|
||||
return glyph_text_formatted
|
||||
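A quick check of what the helper above produces, traced directly from its regex patterns (assuming the function is in scope):

```python
print(extract_glyph_text('A cat holding a sign that says "hello world"'))
# -> 'Text "hello world". '  (quoted spans are routed to the ByT5 glyph encoder)
print(extract_glyph_text("a quiet mountain lake at dawn"))
# -> None  (no quoted text, so the glyph branch receives zero embeddings)
```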
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
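A standalone sketch of how the pipeline below drives `retrieve_timesteps` with its default linear sigma schedule (the scheduler is constructed with default settings here, which is an assumption rather than the checkpoint's config):

```python
import numpy as np
from diffusers import FlowMatchEulerDiscreteScheduler

scheduler = FlowMatchEulerDiscreteScheduler()
num_inference_steps = 50
sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1]  # same schedule __call__ builds
timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps, sigmas=sigmas)
print(len(timesteps))  # 50
```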
|
||||
|
||||
class HunyuanImagePipeline(DiffusionPipeline):
|
||||
r"""
|
||||
The HunyuanImage pipeline for text-to-image generation.
|
||||
|
||||
Args:
|
||||
transformer ([`HunyuanImageTransformer2DModel`]):
|
||||
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKLHunyuanImage`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`Qwen2_5_VLForConditionalGeneration`]):
The text encoder, specifically the
[Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant.
tokenizer (`Qwen2Tokenizer`): Tokenizer of class [Qwen2Tokenizer].
text_encoder_2 ([`T5EncoderModel`]):
A ByT5 variant of
[T5EncoderModel](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel),
used to encode quoted glyph text.
tokenizer_2 (`ByT5Tokenizer`): Tokenizer of class [ByT5Tokenizer].
guider ([`AdaptiveProjectedMixGuidance`]):
[`AdaptiveProjectedMixGuidance`] to be used to guide the image generation.
|
||||
ocr_guider ([`AdaptiveProjectedMixGuidance`], *optional*):
|
||||
[AdaptiveProjectedMixGuidance] to be used to guide the image generation when text rendering is needed.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
||||
_optional_components = ["ocr_guider", "guider"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
vae: AutoencoderKLHunyuanImage,
|
||||
text_encoder: Qwen2_5_VLForConditionalGeneration,
|
||||
tokenizer: Qwen2Tokenizer,
|
||||
text_encoder_2: T5EncoderModel,
|
||||
tokenizer_2: ByT5Tokenizer,
|
||||
transformer: HunyuanImageTransformer2DModel,
|
||||
guider: Optional[AdaptiveProjectedMixGuidance] = None,
|
||||
ocr_guider: Optional[AdaptiveProjectedMixGuidance] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
text_encoder_2=text_encoder_2,
|
||||
tokenizer_2=tokenizer_2,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
guider=guider,
|
||||
ocr_guider=ocr_guider,
|
||||
)
|
||||
|
||||
self.vae_scale_factor = self.vae.config.spatial_compression_ratio if getattr(self, "vae", None) else 32
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.tokenizer_max_length = 1000
|
||||
self.tokenizer_2_max_length = 128
|
||||
self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>"
|
||||
self.prompt_template_encode_start_idx = 34
|
||||
self.default_sample_size = 64
|
||||
|
||||
def _get_qwen_prompt_embeds(
|
||||
self,
|
||||
tokenizer: Qwen2Tokenizer,
|
||||
text_encoder: Qwen2_5_VLForConditionalGeneration,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
tokenizer_max_length: int = 1000,
|
||||
template: str = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>",
|
||||
drop_idx: int = 34,
|
||||
hidden_state_skip_layer: int = 2,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
txt = [template.format(e) for e in prompt]
|
||||
txt_tokens = tokenizer(
|
||||
txt, max_length=tokenizer_max_length + drop_idx, padding="max_length", truncation=True, return_tensors="pt"
|
||||
).to(device)
|
||||
|
||||
encoder_hidden_states = text_encoder(
|
||||
input_ids=txt_tokens.input_ids,
|
||||
attention_mask=txt_tokens.attention_mask,
|
||||
output_hidden_states=True,
|
||||
)
|
||||
prompt_embeds = encoder_hidden_states.hidden_states[-(hidden_state_skip_layer + 1)]
|
||||
|
||||
prompt_embeds = prompt_embeds[:, drop_idx:]
|
||||
encoder_attention_mask = txt_tokens.attention_mask[:, drop_idx:]
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
encoder_attention_mask = encoder_attention_mask.to(device=device)
|
||||
|
||||
return prompt_embeds, encoder_attention_mask
|
||||
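A shape-only sketch of the slicing above: the prompt is wrapped in the chat template, padded to `tokenizer_max_length + drop_idx`, and the first `drop_idx` positions (the template prefix) are discarded from both the hidden states and the attention mask. The hidden size below is illustrative, not read from the checkpoint:

```python
import torch

batch, drop_idx, max_len, hidden = 1, 34, 1000, 3584  # hidden width is an assumed value
hidden_states = torch.randn(batch, max_len + drop_idx, hidden)
attention_mask = torch.ones(batch, max_len + drop_idx, dtype=torch.long)

prompt_embeds = hidden_states[:, drop_idx:]         # (1, 1000, 3584): template prefix dropped
prompt_embeds_mask = attention_mask[:, drop_idx:]   # (1, 1000)
print(prompt_embeds.shape, prompt_embeds_mask.shape)
```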
|
||||
def _get_byt5_prompt_embeds(
|
||||
self,
|
||||
tokenizer: ByT5Tokenizer,
|
||||
text_encoder: T5EncoderModel,
|
||||
prompt: str,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
tokenizer_max_length: int = 128,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or text_encoder.dtype
|
||||
|
||||
if isinstance(prompt, list):
|
||||
raise ValueError("byt5 prompt should be a string")
|
||||
elif prompt is None:
|
||||
raise ValueError("byt5 prompt should not be None")
|
||||
|
||||
txt_tokens = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=tokenizer_max_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
).to(device)
|
||||
|
||||
prompt_embeds = text_encoder(
|
||||
input_ids=txt_tokens.input_ids,
|
||||
attention_mask=txt_tokens.attention_mask.float(),
|
||||
)[0]
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
encoder_attention_mask = txt_tokens.attention_mask.to(device=device)
|
||||
|
||||
return prompt_embeds, encoder_attention_mask
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
device: Optional[torch.device] = None,
|
||||
batch_size: int = 1,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_embeds_mask: Optional[torch.Tensor] = None,
|
||||
prompt_embeds_2: Optional[torch.Tensor] = None,
|
||||
prompt_embeds_mask_2: Optional[torch.Tensor] = None,
|
||||
):
|
||||
r"""
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
batch_size (`int`):
|
||||
batch size of prompts, defaults to 1
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. If not provided, text embeddings will be generated from `prompt` input
|
||||
argument.
|
||||
prompt_embeds_mask (`torch.Tensor`, *optional*):
|
||||
Pre-generated text mask. If not provided, text mask will be generated from `prompt` input argument.
|
||||
prompt_embeds_2 (`torch.Tensor`, *optional*):
|
||||
Pre-generated glyph text embeddings from ByT5. If not provided, will be generated from `prompt` input
|
||||
argument using self.tokenizer_2 and self.text_encoder_2.
|
||||
prompt_embeds_mask_2 (`torch.Tensor`, *optional*):
|
||||
Pre-generated glyph text mask from ByT5. If not provided, will be generated from `prompt` input
|
||||
argument using self.tokenizer_2 and self.text_encoder_2.
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
if prompt is None:
|
||||
prompt = [""] * batch_size
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(
|
||||
tokenizer=self.tokenizer,
|
||||
text_encoder=self.text_encoder,
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
tokenizer_max_length=self.tokenizer_max_length,
|
||||
template=self.prompt_template_encode,
|
||||
drop_idx=self.prompt_template_encode_start_idx,
|
||||
)
|
||||
|
||||
if prompt_embeds_2 is None:
|
||||
prompt_embeds_2_list = []
|
||||
prompt_embeds_mask_2_list = []
|
||||
|
||||
glyph_texts = [extract_glyph_text(p) for p in prompt]
|
||||
for glyph_text in glyph_texts:
|
||||
if glyph_text is None:
|
||||
glyph_text_embeds = torch.zeros(
|
||||
(1, self.tokenizer_2_max_length, self.text_encoder_2.config.d_model), device=device
|
||||
)
|
||||
glyph_text_embeds_mask = torch.zeros(
|
||||
(1, self.tokenizer_2_max_length), device=device, dtype=torch.int64
|
||||
)
|
||||
else:
|
||||
glyph_text_embeds, glyph_text_embeds_mask = self._get_byt5_prompt_embeds(
|
||||
tokenizer=self.tokenizer_2,
|
||||
text_encoder=self.text_encoder_2,
|
||||
prompt=glyph_text,
|
||||
device=device,
|
||||
tokenizer_max_length=self.tokenizer_2_max_length,
|
||||
)
|
||||
|
||||
prompt_embeds_2_list.append(glyph_text_embeds)
|
||||
prompt_embeds_mask_2_list.append(glyph_text_embeds_mask)
|
||||
|
||||
prompt_embeds_2 = torch.cat(prompt_embeds_2_list, dim=0)
|
||||
prompt_embeds_mask_2 = torch.cat(prompt_embeds_mask_2_list, dim=0)
|
||||
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
|
||||
|
||||
_, seq_len_2, _ = prompt_embeds_2.shape
|
||||
prompt_embeds_2 = prompt_embeds_2.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds_2 = prompt_embeds_2.view(batch_size * num_images_per_prompt, seq_len_2, -1)
|
||||
prompt_embeds_mask_2 = prompt_embeds_mask_2.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds_mask_2 = prompt_embeds_mask_2.view(batch_size * num_images_per_prompt, seq_len_2)
|
||||
|
||||
return prompt_embeds, prompt_embeds_mask, prompt_embeds_2, prompt_embeds_mask_2
|
||||
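A usage sketch for the method above, assuming the checkpoint from the example docstring is available and enough memory is present; the quoted phrase is what routes text through the ByT5 branch:

```python
import torch
from diffusers import HunyuanImagePipeline

pipe = HunyuanImagePipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanImage-2.1-Diffusers", torch_dtype=torch.bfloat16
).to("cuda")

embeds, mask, embeds_2, mask_2 = pipe.encode_prompt('A poster that reads "GRAND OPENING"')
# embeds / mask:      Qwen2.5-VL stream for the whole prompt
# embeds_2 / mask_2:  ByT5 glyph stream; zero-filled when the prompt contains no quoted text
```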
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
prompt_embeds_mask=None,
|
||||
negative_prompt_embeds_mask=None,
|
||||
prompt_embeds_2=None,
|
||||
prompt_embeds_mask_2=None,
|
||||
negative_prompt_embeds_2=None,
|
||||
negative_prompt_embeds_mask_2=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
|
||||
logger.warning(
|
||||
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
|
||||
)
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and prompt_embeds_mask is None:
|
||||
raise ValueError(
|
||||
"If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
|
||||
)
|
||||
if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
|
||||
raise ValueError(
|
||||
"If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
||||
)
|
||||
|
||||
if prompt is None and prompt_embeds_2 is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined."
|
||||
)
|
||||
|
||||
if prompt_embeds_2 is not None and prompt_embeds_mask_2 is None:
|
||||
raise ValueError(
|
||||
"If `prompt_embeds_2` are provided, `prompt_embeds_mask_2` also have to be passed. Make sure to generate `prompt_embeds_mask_2` from the same text encoder that was used to generate `prompt_embeds_2`."
|
||||
)
|
||||
if negative_prompt_embeds_2 is not None and negative_prompt_embeds_mask_2 is None:
|
||||
raise ValueError(
|
||||
"If `negative_prompt_embeds_2` are provided, `negative_prompt_embeds_mask_2` also have to be passed. Make sure to generate `negative_prompt_embeds_mask_2` from the same text encoder that was used to generate `negative_prompt_embeds_2`."
|
||||
)
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
):
|
||||
height = int(height) // self.vae_scale_factor
|
||||
width = int(width) // self.vae_scale_factor
|
||||
|
||||
shape = (batch_size, num_channels_latents, height, width)
|
||||
|
||||
if latents is not None:
|
||||
return latents.to(device=device, dtype=dtype)
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
|
||||
return latents
|
||||
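The latent grid prepared above is simply the pixel resolution divided by the VAE's 32x spatial compression; the channel count in this sketch is an assumption for illustration, not a value taken from this PR:

```python
batch_size, num_channels_latents = 1, 64           # channel count assumed for illustration
height, width, vae_scale_factor = 2048, 2048, 32   # 2048 = default_sample_size (64) * 32

shape = (batch_size, num_channels_latents, height // vae_scale_factor, width // vae_scale_factor)
print(shape)  # (1, 64, 64, 64)
```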
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
negative_prompt: Union[str, List[str]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
distilled_guidance_scale: Optional[float] = 3.25,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_embeds_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
|
||||
prompt_embeds_2: Optional[torch.Tensor] = None,
|
||||
prompt_embeds_mask_2: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds_2: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds_mask_2: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, `prompt_embeds` must be passed instead.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined and `negative_prompt_embeds` is
not provided, an empty negative prompt is used. Ignored when guidance is not enabled.
|
||||
height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This defaults to 2048 for the best results.
width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This defaults to 2048 for the best results.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
||||
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||
will be used.
|
||||
distilled_guidance_scale (`float`, *optional*, defaults to 3.25):
|
||||
A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance
|
||||
where the guidance scale is applied during inference through noise prediction rescaling, guidance
|
||||
distilled models take the guidance scale directly as an input parameter during forward pass. Guidance
|
||||
is enabled by setting `distilled_guidance_scale > 1`. Higher guidance scale encourages to generate
|
||||
images that are closely linked to the text `prompt`, usually at the expense of lower image quality. For
|
||||
guidance distilled models, this parameter is required. For non-distilled models, this parameter will be
|
||||
ignored.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will be generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
prompt_embeds_mask (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings mask. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
||||
If not provided, text embeddings mask will be generated from `prompt` input argument.
|
||||
prompt_embeds_2 (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings for ocr. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, text embeddings for ocr will be generated from `prompt` input argument.
|
||||
prompt_embeds_mask_2 (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings mask for ocr. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, text embeddings mask for ocr will be generated from `prompt` input
|
||||
argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
negative_prompt_embeds_mask (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings mask. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative text embeddings mask will be generated from `negative_prompt`
|
||||
input argument.
|
||||
negative_prompt_embeds_2 (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings for ocr. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative text embeddings for ocr will be generated from `negative_prompt`
|
||||
input argument.
|
||||
negative_prompt_embeds_mask_2 (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings mask for ocr. Can be used to easily tweak text inputs, *e.g.*
|
||||
prompt weighting. If not provided, negative text embeddings mask for ocr will be generated from
|
||||
`negative_prompt` input argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.hunyuan_image.HunyuanImagePipelineOutput`] instead of a plain tuple.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.hunyuan_image.HunyuanImagePipelineOutput`] or `tuple`:
|
||||
[`~pipelines.hunyuan_image.HunyuanImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
||||
returning a tuple, the first element is a list with the generated images.
|
||||
"""
|
||||
|
||||
height = height or self.default_sample_size * self.vae_scale_factor
|
||||
width = width or self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_embeds_mask=prompt_embeds_mask,
|
||||
negative_prompt_embeds_mask=negative_prompt_embeds_mask,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
prompt_embeds_2=prompt_embeds_2,
|
||||
prompt_embeds_mask_2=prompt_embeds_mask_2,
|
||||
negative_prompt_embeds_2=negative_prompt_embeds_2,
|
||||
negative_prompt_embeds_mask_2=negative_prompt_embeds_mask_2,
|
||||
)
|
||||
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# 3. prepare prompt embeds
|
||||
|
||||
prompt_embeds, prompt_embeds_mask, prompt_embeds_2, prompt_embeds_mask_2 = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
prompt_embeds_mask=prompt_embeds_mask,
|
||||
device=device,
|
||||
batch_size=batch_size,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
prompt_embeds_2=prompt_embeds_2,
|
||||
prompt_embeds_mask_2=prompt_embeds_mask_2,
|
||||
)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(self.transformer.dtype)
|
||||
prompt_embeds_2 = prompt_embeds_2.to(self.transformer.dtype)
|
||||
|
||||
# select guider
|
||||
if not torch.all(prompt_embeds_2 == 0) and self.ocr_guider is not None:
|
||||
# prompt contains ocr and pipeline has a guider for ocr
|
||||
guider = self.ocr_guider
|
||||
elif self.guider is not None:
|
||||
guider = self.guider
|
||||
# distilled model does not use guidance method, use default guider with enabled=False
|
||||
else:
|
||||
guider = AdaptiveProjectedMixGuidance(enabled=False)
|
||||
|
||||
if guider._enabled and guider.num_conditions > 1:
|
||||
(
|
||||
negative_prompt_embeds,
|
||||
negative_prompt_embeds_mask,
|
||||
negative_prompt_embeds_2,
|
||||
negative_prompt_embeds_mask_2,
|
||||
) = self.encode_prompt(
|
||||
prompt=negative_prompt,
|
||||
prompt_embeds=negative_prompt_embeds,
|
||||
prompt_embeds_mask=negative_prompt_embeds_mask,
|
||||
device=device,
|
||||
batch_size=batch_size,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
prompt_embeds_2=negative_prompt_embeds_2,
|
||||
prompt_embeds_mask_2=negative_prompt_embeds_mask_2,
|
||||
)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(self.transformer.dtype)
|
||||
negative_prompt_embeds_2 = negative_prompt_embeds_2.to(self.transformer.dtype)
|
||||
|
||||
# 4. Prepare latent variables
|
||||
num_channels_latents = self.transformer.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
num_channels_latents=num_channels_latents,
|
||||
height=height,
|
||||
width=width,
|
||||
dtype=prompt_embeds.dtype,
|
||||
device=device,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
)
|
||||
|
||||
# 5. Prepare timesteps
|
||||
sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
|
||||
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# handle guidance (for guidance-distilled model)
|
||||
if self.transformer.config.guidance_embeds and distilled_guidance_scale is None:
|
||||
raise ValueError("`distilled_guidance_scale` is required for guidance-distilled model.")
|
||||
|
||||
if self.transformer.config.guidance_embeds:
|
||||
guidance = (
|
||||
torch.tensor(
|
||||
[distilled_guidance_scale] * latents.shape[0], dtype=self.transformer.dtype, device=device
|
||||
)
|
||||
* 1000.0
|
||||
)
|
||||
|
||||
else:
|
||||
guidance = None
|
||||
|
||||
if self.attention_kwargs is None:
|
||||
self._attention_kwargs = {}
|
||||
|
||||
# 6. Denoising loop
|
||||
self.scheduler.set_begin_index(0)
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
|
||||
if self.transformer.config.use_meanflow:
|
||||
if i == len(timesteps) - 1:
|
||||
timestep_r = torch.tensor([0.0], device=device)
|
||||
else:
|
||||
timestep_r = timesteps[i + 1]
|
||||
timestep_r = timestep_r.expand(latents.shape[0]).to(latents.dtype)
|
||||
else:
|
||||
timestep_r = None
|
||||
|
||||
# Step 1: Collect model inputs needed for the guidance method
|
||||
# conditional inputs should always be first element in the tuple
|
||||
guider_inputs = {
|
||||
"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds),
|
||||
"encoder_attention_mask": (prompt_embeds_mask, negative_prompt_embeds_mask),
|
||||
"encoder_hidden_states_2": (prompt_embeds_2, negative_prompt_embeds_2),
|
||||
"encoder_attention_mask_2": (prompt_embeds_mask_2, negative_prompt_embeds_mask_2),
|
||||
}
|
||||
|
||||
# Step 2: Update guider's internal state for this denoising step
|
||||
guider.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t)
|
||||
|
||||
# Step 3: Prepare batched model inputs based on the guidance method
|
||||
# The guider splits model inputs into separate batches for conditional/unconditional predictions.
|
||||
# For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
|
||||
# you will get a guider_state with two batches:
|
||||
# guider_state = [
|
||||
# {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch
|
||||
# {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch
|
||||
# ]
|
||||
# Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
|
||||
guider_state = guider.prepare_inputs(guider_inputs)
|
||||
# Step 4: Run the denoiser for each batch
|
||||
# Each batch in guider_state represents a different conditioning (conditional, unconditional, etc.).
|
||||
# We run the model once per batch and store the noise prediction in guider_state_batch.noise_pred.
|
||||
for guider_state_batch in guider_state:
|
||||
guider.prepare_models(self.transformer)
|
||||
|
||||
# Extract conditioning kwargs for this batch (e.g., encoder_hidden_states)
|
||||
cond_kwargs = {
|
||||
input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()
|
||||
}
|
||||
|
||||
# e.g. "pred_cond"/"pred_uncond"
|
||||
context_name = getattr(guider_state_batch, guider._identifier_key)
|
||||
with self.transformer.cache_context(context_name):
|
||||
# Run denoiser and store noise prediction in this batch
|
||||
guider_state_batch.noise_pred = self.transformer(
|
||||
hidden_states=latents,
|
||||
timestep=timestep,
|
||||
timestep_r=timestep_r,
|
||||
guidance=guidance,
|
||||
attention_kwargs=self.attention_kwargs,
|
||||
return_dict=False,
|
||||
**cond_kwargs,
|
||||
)[0]
|
||||
|
||||
# Cleanup model (e.g., remove hooks)
|
||||
guider.cleanup_models(self.transformer)
|
||||
|
||||
# Step 5: Combine predictions using the guidance method
|
||||
# The guider takes all noise predictions from guider_state and combines them according to the guidance algorithm.
|
||||
# Continuing the CFG example, the guider receives:
|
||||
# guider_state = [
|
||||
# {"encoder_hidden_states": prompt_embeds, "noise_pred": noise_pred_cond, "__guidance_identifier__": "pred_cond"}, # batch 0
|
||||
# {"encoder_hidden_states": negative_prompt_embeds, "noise_pred": noise_pred_uncond, "__guidance_identifier__": "pred_uncond"}, # batch 1
|
||||
# ]
|
||||
# And extracts predictions using the __guidance_identifier__:
|
||||
# pred_cond = guider_state[0]["noise_pred"] # extracts noise_pred_cond
|
||||
# pred_uncond = guider_state[1]["noise_pred"] # extracts noise_pred_uncond
|
||||
# Then applies CFG formula:
|
||||
# noise_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)
|
||||
# Returns GuiderOutput(pred=noise_pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
|
||||
noise_pred = guider(guider_state)[0]
|
||||
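# Editor's illustration (not part of this file): for plain classifier-free guidance the call
# above reduces to reading the two tagged predictions and applying
#     noise_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)
# e.g. with guidance_scale = 3.5 and scalar stand-ins pred_cond = 1.0, pred_uncond = 0.2:
#     noise_pred = 0.2 + 3.5 * (1.0 - 0.2) = 3.0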
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents_dtype = latents.dtype
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if latents.dtype != latents_dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
latents = latents.to(latents_dtype)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
else:
|
||||
latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
|
||||
image = self.vae.decode(latents, return_dict=False)[0]
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return HunyuanImagePipelineOutput(images=image)
|
||||
752
src/diffusers/pipelines/hunyuan_image/pipeline_hunyuanimage_refiner.py
Normal file
@@ -0,0 +1,752 @@
|
||||
# Copyright 2025 Hunyuan-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
|
||||
|
||||
from ...guiders import AdaptiveProjectedMixGuidance
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...models import AutoencoderKLHunyuanImageRefiner, HunyuanImageTransformer2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .pipeline_output import HunyuanImagePipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import HunyuanImageRefinerPipeline
>>> from diffusers.utils import load_image
|
||||
|
||||
>>> pipe = HunyuanImageRefinerPipeline.from_pretrained(
|
||||
... "hunyuanvideo-community/HunyuanImage-2.1-Refiner-Diffusers", torch_dtype=torch.bfloat16
|
||||
... )
|
||||
>>> pipe.to("cuda")
|
||||
>>> prompt = "A cat holding a sign that says hello world"
|
||||
>>> image = load_image("path/to/image.png")
|
||||
>>> # Depending on the variant being used, the pipeline call will slightly vary.
|
||||
>>> # Refer to the pipeline documentation for more details.
|
||||
>>> image = pipe(prompt, image=image, num_inference_steps=4).images[0]
|
||||
>>> image.save("hunyuanimage.png")
|
||||
```
|
||||
"""
|
||||
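Putting the two example docstrings together, a possible end-to-end sketch that feeds the base pipeline's output into the refiner (checkpoint names as above; VRAM requirements and the choice of 4 refiner steps follow the docstring and are otherwise assumptions):

```python
import torch
from diffusers import HunyuanImagePipeline, HunyuanImageRefinerPipeline

prompt = "A cat holding a sign that says hello world"

base = HunyuanImagePipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanImage-2.1-Diffusers", torch_dtype=torch.bfloat16
).to("cuda")
image = base(prompt, negative_prompt="", num_inference_steps=50).images[0]

refiner = HunyuanImageRefinerPipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanImage-2.1-Refiner-Diffusers", torch_dtype=torch.bfloat16
).to("cuda")
image = refiner(prompt, image=image, num_inference_steps=4).images[0]
image.save("hunyuanimage_refined.png")
```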
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
class HunyuanImageRefinerPipeline(DiffusionPipeline):
|
||||
r"""
|
||||
    The HunyuanImage 2.1 refiner pipeline for enhancing images generated by [`HunyuanImagePipeline`].
|
||||
|
||||
Args:
|
||||
transformer ([`HunyuanImageTransformer2DModel`]):
|
||||
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKLHunyuanImageRefiner`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
        text_encoder ([`Qwen2_5_VLForConditionalGeneration`]):
            The text encoder, specifically the
            [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant.
        tokenizer ([`~transformers.Qwen2Tokenizer`]):
            The tokenizer used with the text encoder.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
||||
_optional_components = ["guider"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
vae: AutoencoderKLHunyuanImageRefiner,
|
||||
text_encoder: Qwen2_5_VLForConditionalGeneration,
|
||||
tokenizer: Qwen2Tokenizer,
|
||||
transformer: HunyuanImageTransformer2DModel,
|
||||
guider: Optional[AdaptiveProjectedMixGuidance] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
guider=guider,
|
||||
)
|
||||
|
||||
self.vae_scale_factor = self.vae.config.spatial_compression_ratio if getattr(self, "vae", None) else 16
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.tokenizer_max_length = 256
|
||||
self.prompt_template_encode = "<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
|
||||
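        # tokens before this index belong to the fixed template prefix and are dropped from the
        # returned embeddings (see `_get_qwen_prompt_embeds`)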
self.prompt_template_encode_start_idx = 36
|
||||
self.default_sample_size = 64
|
||||
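        # the transformer consumes the noisy latents concatenated with the conditioning latents along the
        # channel dimension, so each half accounts for `in_channels // 2` channels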
self.latent_channels = self.transformer.config.in_channels // 2 if getattr(self, "transformer", None) else 64
|
||||
|
||||
# Copied from diffusers.pipelines.hunyuan_image.pipeline_hunyuanimage.HunyuanImagePipeline._get_qwen_prompt_embeds
|
||||
def _get_qwen_prompt_embeds(
|
||||
self,
|
||||
tokenizer: Qwen2Tokenizer,
|
||||
text_encoder: Qwen2_5_VLForConditionalGeneration,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
tokenizer_max_length: int = 1000,
|
||||
template: str = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>",
|
||||
drop_idx: int = 34,
|
||||
hidden_state_skip_layer: int = 2,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
txt = [template.format(e) for e in prompt]
|
||||
txt_tokens = tokenizer(
|
||||
txt, max_length=tokenizer_max_length + drop_idx, padding="max_length", truncation=True, return_tensors="pt"
|
||||
).to(device)
|
||||
|
||||
encoder_hidden_states = text_encoder(
|
||||
input_ids=txt_tokens.input_ids,
|
||||
attention_mask=txt_tokens.attention_mask,
|
||||
output_hidden_states=True,
|
||||
)
|
||||
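        # take an intermediate hidden state `hidden_state_skip_layer` layers before the last one rather than the final output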
prompt_embeds = encoder_hidden_states.hidden_states[-(hidden_state_skip_layer + 1)]
|
||||
|
||||
prompt_embeds = prompt_embeds[:, drop_idx:]
|
||||
encoder_attention_mask = txt_tokens.attention_mask[:, drop_idx:]
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
encoder_attention_mask = encoder_attention_mask.to(device=device)
|
||||
|
||||
return prompt_embeds, encoder_attention_mask
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Optional[Union[str, List[str]]] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
batch_size: int = 1,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_embeds_mask: Optional[torch.Tensor] = None,
|
||||
):
|
||||
r"""
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
batch_size (`int`):
|
||||
batch size of prompts, defaults to 1
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. If not provided, text embeddings will be generated from `prompt` input
|
||||
argument.
|
||||
prompt_embeds_mask (`torch.Tensor`, *optional*):
|
||||
Pre-generated text mask. If not provided, text mask will be generated from `prompt` input argument.
|
||||
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
if prompt is None:
|
||||
prompt = [""] * batch_size
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(
|
||||
tokenizer=self.tokenizer,
|
||||
text_encoder=self.text_encoder,
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
tokenizer_max_length=self.tokenizer_max_length,
|
||||
template=self.prompt_template_encode,
|
||||
drop_idx=self.prompt_template_encode_start_idx,
|
||||
)
|
||||
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
|
||||
|
||||
return prompt_embeds, prompt_embeds_mask
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
prompt_embeds_mask=None,
|
||||
negative_prompt_embeds_mask=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
|
||||
logger.warning(
|
||||
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
|
||||
)
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and prompt_embeds_mask is None:
|
||||
raise ValueError(
|
||||
"If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
|
||||
)
|
||||
if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
|
||||
raise ValueError(
|
||||
"If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
||||
)
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
image_latents,
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
strength=0.25,
|
||||
):
|
||||
height = int(height) // self.vae_scale_factor
|
||||
width = int(width) // self.vae_scale_factor
|
||||
|
||||
shape = (batch_size, num_channels_latents, 1, height, width)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device=device, dtype=dtype)
|
||||
|
||||
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
|
||||
# expand init_latents for batch_size
|
||||
additional_image_per_prompt = batch_size // image_latents.shape[0]
|
||||
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
|
||||
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
|
||||
raise ValueError(
|
||||
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
|
||||
)
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
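        # the conditioning latents are a lightly re-noised copy of the encoded image: with the default
        # strength of 0.25 they keep 75% of the image latents and mix in 25% fresh noise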
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
cond_latents = strength * noise + (1 - strength) * image_latents
|
||||
|
||||
return latents, cond_latents
|
||||
|
||||
@staticmethod
|
||||
def _reorder_image_tokens(image_latents):
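        """Duplicate the first frame slice and fold frame pairs into the channel dimension (inverse of `_restore_image_tokens_order`)."""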
|
||||
image_latents = torch.cat((image_latents[:, :, :1], image_latents), dim=2)
|
||||
batch_size, num_latent_channels, num_latent_frames, latent_height, latent_width = image_latents.shape
|
||||
image_latents = image_latents.permute(0, 2, 1, 3, 4)
|
||||
image_latents = image_latents.reshape(
|
||||
batch_size, num_latent_frames // 2, num_latent_channels * 2, latent_height, latent_width
|
||||
)
|
||||
image_latents = image_latents.permute(0, 2, 1, 3, 4).contiguous()
|
||||
|
||||
return image_latents
|
||||
|
||||
@staticmethod
|
||||
def _restore_image_tokens_order(latents):
|
||||
"""Restore image tokens order by splitting channels and removing first frame slice."""
|
||||
batch_size, num_latent_channels, num_latent_frames, latent_height, latent_width = latents.shape
|
||||
|
||||
latents = latents.permute(0, 2, 1, 3, 4) # B, F, C, H, W
|
||||
latents = latents.reshape(
|
||||
batch_size, num_latent_frames * 2, num_latent_channels // 2, latent_height, latent_width
|
||||
) # B, F*2, C//2, H, W
|
||||
|
||||
latents = latents.permute(0, 2, 1, 3, 4) # B, C//2, F*2, H, W
|
||||
# Remove first frame slice
|
||||
latents = latents[:, :, 1:]
|
||||
|
||||
return latents
|
||||
|
||||
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
|
||||
if isinstance(generator, list):
|
||||
image_latents = [
|
||||
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="sample")
|
||||
for i in range(image.shape[0])
|
||||
]
|
||||
image_latents = torch.cat(image_latents, dim=0)
|
||||
else:
|
||||
image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="sample")
|
||||
image_latents = self._reorder_image_tokens(image_latents)
|
||||
|
||||
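        # scale into the normalized latent space expected by the transformer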
image_latents = image_latents * self.vae.config.scaling_factor
|
||||
|
||||
return image_latents
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
negative_prompt: Union[str, List[str]] = None,
|
||||
distilled_guidance_scale: Optional[float] = 3.25,
|
||||
image: Optional[PipelineImageInput] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 4,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_embeds_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, will use an empty negative
|
||||
prompt. Ignored when not using guidance.
|
||||
            distilled_guidance_scale (`float`, *optional*, defaults to `3.25`):
|
||||
A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance
|
||||
where the guidance scale is applied during inference through noise prediction rescaling, guidance
|
||||
distilled models take the guidance scale directly as an input parameter during forward pass. Guidance
|
||||
is enabled by setting `distilled_guidance_scale > 1`. Higher guidance scale encourages to generate
|
||||
images that are closely linked to the text `prompt`, usually at the expense of lower image quality. For
|
||||
guidance distilled models, this parameter is required. For non-distilled models, this parameter will be
|
||||
ignored.
|
||||
            image (`PipelineImageInput`, *optional*):
                The image to refine. It is resized to `height` x `width` and encoded with the VAE into the
                conditioning latents used during denoising.
            height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image. This is set to 1024 by default for the best results.
            width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
            num_inference_steps (`int`, *optional*, defaults to 4):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
||||
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||
will be used.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will be generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
                Whether or not to return a [`~pipelines.hunyuan_image.HunyuanImagePipelineOutput`] instead of a plain tuple.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
                A function that is called at the end of each denoising step during inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.hunyuan_image.HunyuanImagePipelineOutput`] or `tuple`:
|
||||
[`~pipelines.hunyuan_image.HunyuanImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
||||
returning a tuple, the first element is a list with the generated images.
|
||||
"""
|
||||
|
||||
height = height or self.default_sample_size * self.vae_scale_factor
|
||||
width = width or self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_embeds_mask=prompt_embeds_mask,
|
||||
negative_prompt_embeds_mask=negative_prompt_embeds_mask,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
        # 3. Process the input image
|
||||
if image is not None and isinstance(image, torch.Tensor) and image.shape[1] == self.latent_channels:
|
||||
image_latents = image
|
||||
else:
|
||||
image = self.image_processor.preprocess(image, height, width)
|
||||
image = image.unsqueeze(2).to(device, dtype=self.vae.dtype)
|
||||
image_latents = self._encode_vae_image(image=image, generator=generator)
|
||||
|
||||
        # 4. Prepare prompt embeds
|
||||
|
||||
if self.guider is not None:
|
||||
guider = self.guider
|
||||
else:
|
||||
            # distilled models do not use a guidance method; fall back to a default guider with enabled=False
|
||||
guider = AdaptiveProjectedMixGuidance(enabled=False)
|
||||
|
||||
requires_unconditional_embeds = guider._enabled and guider.num_conditions > 1
|
||||
prompt_embeds, prompt_embeds_mask = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
prompt_embeds_mask=prompt_embeds_mask,
|
||||
device=device,
|
||||
batch_size=batch_size,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(self.transformer.dtype)
|
||||
|
||||
if requires_unconditional_embeds:
|
||||
(
|
||||
negative_prompt_embeds,
|
||||
negative_prompt_embeds_mask,
|
||||
) = self.encode_prompt(
|
||||
prompt=negative_prompt,
|
||||
prompt_embeds=negative_prompt_embeds,
|
||||
prompt_embeds_mask=negative_prompt_embeds_mask,
|
||||
device=device,
|
||||
batch_size=batch_size,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(self.transformer.dtype)
|
||||
|
||||
        # 5. Prepare latent variables
|
||||
latents, cond_latents = self.prepare_latents(
|
||||
image_latents=image_latents,
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
num_channels_latents=self.latent_channels,
|
||||
height=height,
|
||||
width=width,
|
||||
dtype=prompt_embeds.dtype,
|
||||
device=device,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
)
|
||||
|
||||
        # 6. Prepare timesteps
|
||||
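        # default schedule: evenly spaced sigmas from 1.0 down to (but excluding) 0.0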
sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
|
||||
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# handle guidance (this pipeline only supports guidance-distilled models)
|
||||
if distilled_guidance_scale is None:
|
||||
raise ValueError("`distilled_guidance_scale` is required for guidance-distilled model.")
|
||||
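        # guidance is fed to the transformer as a timestep-like conditioning signal, hence the 1000x scaling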
guidance = (
|
||||
torch.tensor([distilled_guidance_scale] * latents.shape[0], dtype=self.transformer.dtype, device=device)
|
||||
* 1000.0
|
||||
)
|
||||
|
||||
if self.attention_kwargs is None:
|
||||
self._attention_kwargs = {}
|
||||
|
||||
        # 7. Denoising loop
|
||||
self.scheduler.set_begin_index(0)
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
latent_model_input = torch.cat([latents, cond_latents], dim=1).to(self.transformer.dtype)
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
|
||||
# Step 1: Collect model inputs needed for the guidance method
|
||||
# conditional inputs should always be first element in the tuple
|
||||
guider_inputs = {
|
||||
"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds),
|
||||
"encoder_attention_mask": (prompt_embeds_mask, negative_prompt_embeds_mask),
|
||||
}
|
||||
|
||||
# Step 2: Update guider's internal state for this denoising step
|
||||
guider.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t)
|
||||
|
||||
# Step 3: Prepare batched model inputs based on the guidance method
|
||||
# The guider splits model inputs into separate batches for conditional/unconditional predictions.
|
||||
# For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
|
||||
# you will get a guider_state with two batches:
|
||||
# guider_state = [
|
||||
# {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch
|
||||
# {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch
|
||||
# ]
|
||||
# Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
|
||||
guider_state = guider.prepare_inputs(guider_inputs)
|
||||
|
||||
# Step 4: Run the denoiser for each batch
|
||||
# Each batch in guider_state represents a different conditioning (conditional, unconditional, etc.).
|
||||
# We run the model once per batch and store the noise prediction in guider_state_batch.noise_pred.
|
||||
for guider_state_batch in guider_state:
|
||||
guider.prepare_models(self.transformer)
|
||||
|
||||
# Extract conditioning kwargs for this batch (e.g., encoder_hidden_states)
|
||||
cond_kwargs = {
|
||||
input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()
|
||||
}
|
||||
|
||||
# e.g. "pred_cond"/"pred_uncond"
|
||||
context_name = getattr(guider_state_batch, guider._identifier_key)
|
||||
with self.transformer.cache_context(context_name):
|
||||
# Run denoiser and store noise prediction in this batch
|
||||
guider_state_batch.noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
guidance=guidance,
|
||||
attention_kwargs=self.attention_kwargs,
|
||||
return_dict=False,
|
||||
**cond_kwargs,
|
||||
)[0]
|
||||
|
||||
# Cleanup model (e.g., remove hooks)
|
||||
guider.cleanup_models(self.transformer)
|
||||
|
||||
# Step 5: Combine predictions using the guidance method
|
||||
# The guider takes all noise predictions from guider_state and combines them according to the guidance algorithm.
|
||||
# Continuing the CFG example, the guider receives:
|
||||
# guider_state = [
|
||||
# {"encoder_hidden_states": prompt_embeds, "noise_pred": noise_pred_cond, "__guidance_identifier__": "pred_cond"}, # batch 0
|
||||
# {"encoder_hidden_states": negative_prompt_embeds, "noise_pred": noise_pred_uncond, "__guidance_identifier__": "pred_uncond"}, # batch 1
|
||||
# ]
|
||||
# And extracts predictions using the __guidance_identifier__:
|
||||
# pred_cond = guider_state[0]["noise_pred"] # extracts noise_pred_cond
|
||||
# pred_uncond = guider_state[1]["noise_pred"] # extracts noise_pred_uncond
|
||||
# Then applies CFG formula:
|
||||
# noise_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)
|
||||
# Returns GuiderOutput(pred=noise_pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
|
||||
noise_pred = guider(guider_state)[0]
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents_dtype = latents.dtype
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if latents.dtype != latents_dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
latents = latents.to(latents_dtype)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
else:
|
||||
latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
|
||||
latents = self._restore_image_tokens_order(latents)
|
||||
|
||||
image = self.vae.decode(latents, return_dict=False)[0]
|
||||
image = self.image_processor.postprocess(image.squeeze(2), output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return HunyuanImagePipelineOutput(images=image)
|
||||
21
src/diffusers/pipelines/hunyuan_image/pipeline_output.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
|
||||
from ...utils import BaseOutput
|
||||
|
||||
|
||||
@dataclass
|
||||
class HunyuanImagePipelineOutput(BaseOutput):
|
||||
"""
|
||||
Output class for HunyuanImage pipelines.
|
||||
|
||||
Args:
|
||||
        images (`List[PIL.Image.Image]` or `np.ndarray`):
            List of denoised PIL images of length `batch_size` or a NumPy array of shape `(batch_size, height, width,
            num_channels)` representing the denoised images of the diffusion pipeline.
|
||||
"""
|
||||
|
||||
images: Union[List[PIL.Image.Image], np.ndarray]
|
||||
@@ -76,6 +76,7 @@ LOADABLE_CLASSES = {
|
||||
"SchedulerMixin": ["save_pretrained", "from_pretrained"],
|
||||
"DiffusionPipeline": ["save_pretrained", "from_pretrained"],
|
||||
"OnnxRuntimeModel": ["save_pretrained", "from_pretrained"],
|
||||
"BaseGuidance": ["save_pretrained", "from_pretrained"],
|
||||
},
|
||||
"transformers": {
|
||||
"PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
|
||||
|
||||
@@ -17,6 +17,21 @@ class AdaptiveProjectedGuidance(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AdaptiveProjectedMixGuidance(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoGuidance(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -32,6 +47,21 @@ class AutoGuidance(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class BaseGuidance(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class ClassifierFreeGuidance(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -378,6 +408,36 @@ class AutoencoderKLCosmos(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoencoderKLHunyuanImage(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoencoderKLHunyuanImageRefiner(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoencoderKLHunyuanVideo(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -858,6 +918,21 @@ class HunyuanDiT2DMultiControlNetModel(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class HunyuanImageTransformer2DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class HunyuanVideoFramepackTransformer3DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -1037,6 +1037,36 @@ class HunyuanDiTPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class HunyuanImagePipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class HunyuanImageRefinerPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class HunyuanSkyreelsImageToVideoPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
0
tests/pipelines/hunyuan_image_21/__init__.py
Normal file
290
tests/pipelines/hunyuan_image_21/test_hunyuanimage.py
Normal file
@@ -0,0 +1,290 @@
|
||||
# Copyright 2025 The HuggingFace Team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import (
|
||||
ByT5Tokenizer,
|
||||
Qwen2_5_VLConfig,
|
||||
Qwen2_5_VLForConditionalGeneration,
|
||||
Qwen2Tokenizer,
|
||||
T5Config,
|
||||
T5EncoderModel,
|
||||
)
|
||||
|
||||
from diffusers import (
|
||||
AdaptiveProjectedMixGuidance,
|
||||
AutoencoderKLHunyuanImage,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
HunyuanImagePipeline,
|
||||
HunyuanImageTransformer2DModel,
|
||||
)
|
||||
|
||||
from ...testing_utils import enable_full_determinism
|
||||
from ..test_pipelines_common import (
|
||||
FirstBlockCacheTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
to_np,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class HunyuanImagePipelineFastTests(
|
||||
PipelineTesterMixin,
|
||||
FirstBlockCacheTesterMixin,
|
||||
unittest.TestCase,
|
||||
):
|
||||
pipeline_class = HunyuanImagePipeline
|
||||
params = frozenset(["prompt", "height", "width"])
|
||||
batch_params = frozenset(["prompt", "negative_prompt"])
|
||||
required_optional_params = frozenset(
|
||||
[
|
||||
"num_inference_steps",
|
||||
"generator",
|
||||
"latents",
|
||||
"return_dict",
|
||||
"callback_on_step_end",
|
||||
"callback_on_step_end_tensor_inputs",
|
||||
]
|
||||
)
|
||||
|
||||
test_xformers_attention = False
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
test_attention_slicing = False
|
||||
supports_dduf = False
|
||||
|
||||
def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1, guidance_embeds: bool = False):
|
||||
torch.manual_seed(0)
|
||||
transformer = HunyuanImageTransformer2DModel(
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
num_attention_heads=4,
|
||||
attention_head_dim=8,
|
||||
num_layers=num_layers,
|
||||
num_single_layers=num_single_layers,
|
||||
num_refiner_layers=1,
|
||||
patch_size=(1, 1),
|
||||
guidance_embeds=guidance_embeds,
|
||||
text_embed_dim=32,
|
||||
text_embed_2_dim=32,
|
||||
rope_axes_dim=(4, 4),
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKLHunyuanImage(
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
latent_channels=4,
|
||||
block_out_channels=(32, 64, 64, 64),
|
||||
layers_per_block=1,
|
||||
scaling_factor=0.476986,
|
||||
spatial_compression_ratio=8,
|
||||
sample_size=128,
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
|
||||
|
||||
if not guidance_embeds:
|
||||
torch.manual_seed(0)
|
||||
guider = AdaptiveProjectedMixGuidance(adaptive_projected_guidance_start_step=2)
|
||||
ocr_guider = AdaptiveProjectedMixGuidance(adaptive_projected_guidance_start_step=3)
|
||||
else:
|
||||
guider = None
|
||||
ocr_guider = None
|
||||
torch.manual_seed(0)
|
||||
config = Qwen2_5_VLConfig(
|
||||
text_config={
|
||||
"hidden_size": 32,
|
||||
"intermediate_size": 32,
|
||||
"num_hidden_layers": 2,
|
||||
"num_attention_heads": 2,
|
||||
"num_key_value_heads": 2,
|
||||
"rope_scaling": {
|
||||
"mrope_section": [2, 2, 4],
|
||||
"rope_type": "default",
|
||||
"type": "default",
|
||||
},
|
||||
"rope_theta": 1000000.0,
|
||||
},
|
||||
vision_config={
|
||||
"depth": 2,
|
||||
"hidden_size": 32,
|
||||
"intermediate_size": 32,
|
||||
"num_heads": 2,
|
||||
"out_hidden_size": 32,
|
||||
},
|
||||
hidden_size=32,
|
||||
vocab_size=152064,
|
||||
vision_end_token_id=151653,
|
||||
vision_start_token_id=151652,
|
||||
vision_token_id=151654,
|
||||
)
|
||||
text_encoder = Qwen2_5_VLForConditionalGeneration(config)
|
||||
tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")
|
||||
|
||||
torch.manual_seed(0)
|
||||
t5_config = T5Config(
|
||||
d_model=32,
|
||||
d_kv=4,
|
||||
d_ff=16,
|
||||
num_layers=2,
|
||||
num_heads=2,
|
||||
relative_attention_num_buckets=8,
|
||||
relative_attention_max_distance=32,
|
||||
vocab_size=256,
|
||||
feed_forward_proj="gated-gelu",
|
||||
dense_act_fn="gelu_new",
|
||||
is_encoder_decoder=False,
|
||||
use_cache=False,
|
||||
tie_word_embeddings=False,
|
||||
)
|
||||
text_encoder_2 = T5EncoderModel(t5_config)
|
||||
tokenizer_2 = ByT5Tokenizer()
|
||||
|
||||
components = {
|
||||
"transformer": transformer,
|
||||
"vae": vae,
|
||||
"scheduler": scheduler,
|
||||
"text_encoder": text_encoder,
|
||||
"text_encoder_2": text_encoder_2,
|
||||
"tokenizer": tokenizer,
|
||||
"tokenizer_2": tokenizer_2,
|
||||
"guider": guider,
|
||||
"ocr_guider": ocr_guider,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 5,
|
||||
"height": 16,
|
||||
"width": 16,
|
||||
"output_type": "pt",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_inference(self):
|
||||
device = "cpu"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = pipe(**inputs).images
|
||||
generated_image = image[0]
|
||||
self.assertEqual(generated_image.shape, (3, 16, 16))
|
||||
|
||||
expected_slice_np = np.array(
|
||||
[0.6252659, 0.51482046, 0.60799813, 0.59267783, 0.488082, 0.5857634, 0.523781, 0.58028054, 0.5674121]
|
||||
)
|
||||
output_slice = generated_image[0, -3:, -3:].flatten().cpu().numpy()
|
||||
|
||||
self.assertTrue(
|
||||
np.abs(output_slice - expected_slice_np).max() < 1e-3,
|
||||
f"output_slice: {output_slice}, expected_slice_np: {expected_slice_np}",
|
||||
)
|
||||
|
||||
def test_inference_guider(self):
|
||||
device = "cpu"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
pipe.guider = pipe.guider.new(guidance_scale=1000)
|
||||
pipe.ocr_guider = pipe.ocr_guider.new(guidance_scale=1000)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = pipe(**inputs).images
|
||||
generated_image = image[0]
|
||||
self.assertEqual(generated_image.shape, (3, 16, 16))
|
||||
|
||||
expected_slice_np = np.array(
|
||||
[0.61494756, 0.49616697, 0.60327923, 0.6115793, 0.49047345, 0.56977504, 0.53066164, 0.58880305, 0.5570612]
|
||||
)
|
||||
output_slice = generated_image[0, -3:, -3:].flatten().cpu().numpy()
|
||||
|
||||
self.assertTrue(
|
||||
np.abs(output_slice - expected_slice_np).max() < 1e-3,
|
||||
f"output_slice: {output_slice}, expected_slice_np: {expected_slice_np}",
|
||||
)
|
||||
|
||||
def test_inference_with_distilled_guidance(self):
|
||||
device = "cpu"
|
||||
|
||||
components = self.get_dummy_components(guidance_embeds=True)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
inputs["distilled_guidance_scale"] = 3.5
|
||||
image = pipe(**inputs).images
|
||||
generated_image = image[0]
|
||||
self.assertEqual(generated_image.shape, (3, 16, 16))
|
||||
|
||||
expected_slice_np = np.array(
|
||||
[0.63667065, 0.5187377, 0.66757566, 0.6320319, 0.4913387, 0.54813194, 0.5335031, 0.5736143, 0.5461346]
|
||||
)
|
||||
output_slice = generated_image[0, -3:, -3:].flatten().cpu().numpy()
|
||||
|
||||
self.assertTrue(
|
||||
np.abs(output_slice - expected_slice_np).max() < 1e-3,
|
||||
f"output_slice: {output_slice}, expected_slice_np: {expected_slice_np}",
|
||||
)
|
||||
|
||||
def test_vae_tiling(self, expected_diff_max: float = 0.2):
|
||||
generator_device = "cpu"
|
||||
components = self.get_dummy_components()
|
||||
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to("cpu")
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# Without tiling
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
inputs["height"] = inputs["width"] = 128
|
||||
output_without_tiling = pipe(**inputs)[0]
|
||||
|
||||
# With tiling
|
||||
pipe.vae.enable_tiling(tile_sample_min_size=96)
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
inputs["height"] = inputs["width"] = 128
|
||||
output_with_tiling = pipe(**inputs)[0]
|
||||
|
||||
self.assertLess(
|
||||
(to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
|
||||
expected_diff_max,
|
||||
"VAE tiling should not affect the inference results",
|
||||
)
|
||||
|
||||
@unittest.skip("TODO: Test not supported for now because needs to be adjusted to work with guiders.")
|
||||
def test_encode_prompt_works_in_isolation(self):
|
||||
pass
|
||||