Compare commits

7 Commits

Author SHA1 Message Date
Dhruv Nair
cb8c90b683 update 2024-06-03 12:46:42 +00:00
Dhruv Nair
278c16aa05 update 2024-06-03 12:32:31 +00:00
Dhruv Nair
cfbfce4e90 update 2024-06-03 05:02:41 +00:00
Dhruv Nair
b8a61d6fb9 update 2024-06-03 04:45:30 +00:00
XCL
413604405f Tencent Hunyuan Team: add HunyuanDiT related updates (#8240)
* Hunyuan Team: add HunyuanDiT related updates


---------

Co-authored-by: XCLiu <liuxc1996@gmail.com>
Co-authored-by: yiyixuxu <yixu310@gmail.com>
2024-06-01 12:41:21 -10:00
39th president of the United States, probably
bc108e1533 Fix DREAM training (#8302)
Co-authored-by: Jimmy <39@🇺🇸.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
2024-06-01 11:27:57 +04:00
Anton Obukhov
86555c9f59 Fix marigold documentation (#8372)
* rename prs-eth/marigold-lcm-v1-0 into prs-eth/marigold-depth-lcm-v1-0

* update image paths in https://huggingface.co/datasets/huggingface/documentation-images to use main branch

* fix relative paths to other diffusers pages

* Update docs/source/en/using-diffusers/marigold_usage.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2024-05-31 12:10:05 -10:00
20 changed files with 2027 additions and 50 deletions

View File

@@ -59,7 +59,7 @@ jobs:
runs-on: [single-gpu, nvidia-gpu, t4, ci]
container:
image: diffusers/diffusers-pytorch-cuda
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/ --gpus 0
options: --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface/diffusers:/mnt/cache/ --gpus 0
steps:
- name: Checkout diffusers
uses: actions/checkout@v3

View File

@@ -62,7 +62,7 @@ jobs:
runs-on: [single-gpu, nvidia-gpu, t4, ci]
container:
image: diffusers/diffusers-pytorch-cuda
options: --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface/diffusers:/mnt/cache/ --gpus 0 --privileged
options: --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface/diffusers:/mnt/cache/ --gpus 0
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
@@ -71,12 +71,6 @@ jobs:
- name: NVIDIA-SMI
run: |
nvidia-smi
- name: Tailscale
uses: huggingface/tailscale-action@v1
with:
authkey: ${{ secrets.TAILSCALE_SSH_AUTHKEY }}
slackChannel: ${{ secrets.SLACK_CIFEEDBACK_CHANNEL }}
slackToken: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
@@ -95,18 +89,11 @@ jobs:
-s -v -k "not Flax and not Onnx" \
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
tests/pipelines/${{ matrix.module }}
- name: Tailscale Wait
if: ${{ failure() || runner.debug == '1' }}
uses: huggingface/tailscale-action@v1
with:
waitForSSH: true
authkey: ${{ secrets.TAILSCALE_SSH_AUTHKEY }}
- name: Failure short reports
if: ${{ failure() }}
run: |
cat reports/tests_pipeline_${{ matrix.module }}_cuda_stats.txt
cat reports/tests_pipeline_${{ matrix.module }}_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v2

View File

@@ -25,7 +25,7 @@ jobs:
runs-on: [single-gpu, nvidia-gpu, "${{ github.event.inputs.runner_type }}", ci]
container:
image: ${{ github.event.inputs.docker_image }}
options: --gpus all --privileged --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
options: --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface/diffusers:/mnt/cache/ --gpus 0 --privileged
steps:
- name: Checkout diffusers

View File

@@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License.
# Marigold Pipelines for Computer Vision Tasks
[Marigold](marigold) is a novel diffusion-based dense prediction approach, and a set of pipelines for various computer vision tasks, such as monocular depth estimation.
[Marigold](../api/pipelines/marigold) is a novel diffusion-based dense prediction approach, and a set of pipelines for various computer vision tasks, such as monocular depth estimation.
This guide will show you how to use Marigold to obtain fast and high-quality predictions for images and videos.
@@ -31,7 +31,7 @@ The original code can also be used to train new checkpoints.
| Checkpoint | Modality | Comment |
|-----------------------------------------------------------------------------------------------|----------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| [prs-eth/marigold-v1-0](https://huggingface.co/prs-eth/marigold-v1-0) | Depth | The first Marigold Depth checkpoint, which predicts *affine-invariant depth* maps. The performance of this checkpoint in benchmarks was studied in the original [paper](https://huggingface.co/papers/2312.02145). Designed to be used with the `DDIMScheduler` at inference, it requires at least 10 steps to get reliable predictions. Affine-invariant depth prediction has a range of values in each pixel between 0 (near plane) and 1 (far plane); both planes are chosen by the model as part of the inference process. See the `MarigoldImageProcessor` reference for visualization utilities. |
| [prs-eth/marigold-lcm-v1-0](https://huggingface.co/prs-eth/marigold-lcm-v1-0) | Depth | The fast Marigold Depth checkpoint, fine-tuned from `prs-eth/marigold-v1-0`. Designed to be used with the `LCMScheduler` at inference, it requires as little as 1 step to get reliable predictions. The prediction reliability saturates at 4 steps and declines after that. |
| [prs-eth/marigold-depth-lcm-v1-0](https://huggingface.co/prs-eth/marigold-depth-lcm-v1-0) | Depth | The fast Marigold Depth checkpoint, fine-tuned from `prs-eth/marigold-v1-0`. Designed to be used with the `LCMScheduler` at inference, it requires as little as 1 step to get reliable predictions. The prediction reliability saturates at 4 steps and declines after that. |
| [prs-eth/marigold-normals-v0-1](https://huggingface.co/prs-eth/marigold-normals-v0-1) | Normals | A preview checkpoint for the Marigold Normals pipeline. Designed to be used with the `DDIMScheduler` at inference, it requires at least 10 steps to get reliable predictions. The surface normals predictions are unit-length 3D vectors with values in the range from -1 to 1. *This checkpoint will be phased out after the release of `v1-0` version.* |
| [prs-eth/marigold-normals-lcm-v0-1](https://huggingface.co/prs-eth/marigold-normals-lcm-v0-1) | Normals | The fast Marigold Normals checkpoint, fine-tuned from `prs-eth/marigold-normals-v0-1`. Designed to be used with the `LCMScheduler` at inference, it requires as little as 1 step to get reliable predictions. The prediction reliability saturates at 4 steps and declines after that. *This checkpoint will be phased out after the release of `v1-0` version.* |
The examples below are mostly given for depth prediction, but they can be universally applied with other supported modalities.
@@ -76,13 +76,13 @@ Below are the raw and the visualized predictions; as can be seen, dark areas (mu
<div class="flex gap-4">
<div style="flex: 1 1 50%; max-width: 50%;">
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/6838ae9b9148cfe22ce9bb4c0ab0907c757c4010/marigold/marigold_einstein_lcm_depth_16bit.png"/>
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/marigold/marigold_einstein_lcm_depth_16bit.png"/>
<figcaption class="mt-1 text-center text-sm text-gray-500">
Predicted depth (16-bit PNG)
</figcaption>
</div>
<div style="flex: 1 1 50%; max-width: 50%;">
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/6838ae9b9148cfe22ce9bb4c0ab0907c757c4010/marigold/marigold_einstein_lcm_depth.png"/>
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/marigold/marigold_einstein_lcm_depth.png"/>
<figcaption class="mt-1 text-center text-sm text-gray-500">
Predicted depth visualization (Spectral)
</figcaption>
@@ -115,7 +115,7 @@ Below is the visualized prediction:
<div class="flex gap-4" style="justify-content: center; width: 100%;">
<div style="flex: 1 1 50%; max-width: 50%;">
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/6838ae9b9148cfe22ce9bb4c0ab0907c757c4010/marigold/marigold_einstein_lcm_normals.png"/>
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/marigold/marigold_einstein_lcm_normals.png"/>
<figcaption class="mt-1 text-center text-sm text-gray-500">
Predicted surface normals visualization
</figcaption>
@@ -133,7 +133,7 @@ The above quick start snippets are already optimized for speed: they load the LC
The `pipe(image)` call completes in 280ms on RTX 3090 GPU.
Internally, the input image is encoded with the Stable Diffusion VAE encoder, then the U-Net performs one denoising step, and finally, the prediction latent is decoded with the VAE decoder into pixel space.
In this case, two out of three module calls are dedicated to converting between pixel and latent space of LDM.
Because Marigold's latent space is compatible with the base Stable Diffusion, it is possible to speed up the pipeline call by more than 3x (85ms on RTX 3090) by using a [lightweight replacement of the SD VAE](autoencoder_tiny):
Because Marigold's latent space is compatible with the base Stable Diffusion, it is possible to speed up the pipeline call by more than 3x (85ms on RTX 3090) by using a [lightweight replacement of the SD VAE](../api/models/autoencoder_tiny):
```diff
import diffusers
@@ -151,7 +151,7 @@ Because Marigold's latent space is compatible with the base Stable Diffusion, it
depth = pipe(image)
```
As suggested in [Optimizations](torch2.0), adding `torch.compile` may squeeze extra performance depending on the target hardware:
As suggested in [Optimizations](../optimization/torch2.0#torch.compile), adding `torch.compile` may squeeze extra performance depending on the target hardware:
```diff
import diffusers
@@ -173,13 +173,13 @@ With the above speed optimizations, Marigold delivers predictions with more deta
<div class="flex gap-4">
<div style="flex: 1 1 50%; max-width: 50%;">
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/6838ae9b9148cfe22ce9bb4c0ab0907c757c4010/marigold/marigold_einstein_lcm_depth.png"/>
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/marigold/marigold_einstein_lcm_depth.png"/>
<figcaption class="mt-1 text-center text-sm text-gray-500">
Marigold LCM fp16 with Tiny AutoEncoder
</figcaption>
</div>
<div style="flex: 1 1 50%; max-width: 50%;">
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/bfe7cb56ca1cc0811b328212472350879dfa7f8b/marigold/einstein_depthanything_large.png"/>
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/marigold/einstein_depthanything_large.png"/>
<figcaption class="mt-1 text-center text-sm text-gray-500">
Depth Anything Large
</figcaption>
@@ -224,13 +224,13 @@ vis[0].save("einstein_normals.png")
<div class="flex gap-4">
<div style="flex: 1 1 50%; max-width: 50%;">
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/6838ae9b9148cfe22ce9bb4c0ab0907c757c4010/marigold/marigold_einstein_lcm_normals.png"/>
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/marigold/marigold_einstein_lcm_normals.png"/>
<figcaption class="mt-1 text-center text-sm text-gray-500">
Surface normals, no ensembling
</figcaption>
</div>
<div style="flex: 1 1 50%; max-width: 50%;">
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/6838ae9b9148cfe22ce9bb4c0ab0907c757c4010/marigold/marigold_einstein_normals.png"/>
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/marigold/marigold_einstein_normals.png"/>
<figcaption class="mt-1 text-center text-sm text-gray-500">
Surface normals, with ensembling
</figcaption>
@@ -303,13 +303,13 @@ uncertainty[0].save("einstein_depth_uncertainty.png")
<div class="flex gap-4">
<div style="flex: 1 1 50%; max-width: 50%;">
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/6838ae9b9148cfe22ce9bb4c0ab0907c757c4010/marigold/marigold_einstein_depth_uncertainty.png"/>
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/marigold/marigold_einstein_depth_uncertainty.png"/>
<figcaption class="mt-1 text-center text-sm text-gray-500">
Depth uncertainty
</figcaption>
</div>
<div style="flex: 1 1 50%; max-width: 50%;">
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/6838ae9b9148cfe22ce9bb4c0ab0907c757c4010/marigold/marigold_einstein_normals_uncertainty.png"/>
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/marigold/marigold_einstein_normals_uncertainty.png"/>
<figcaption class="mt-1 text-center text-sm text-gray-500">
Surface normals uncertainty
</figcaption>
@@ -327,11 +327,11 @@ This becomes an obvious drawback compared to traditional end-to-end dense regres
<div class="flex gap-4">
<div style="flex: 1 1 50%; max-width: 50%;">
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/25024b5443a6c1357492751fd09355bd3f967845/marigold/marigold_obama.gif"/>
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/marigold/marigold_obama.gif"/>
<figcaption class="mt-1 text-center text-sm text-gray-500">Input video</figcaption>
</div>
<div style="flex: 1 1 50%; max-width: 50%;">
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/25024b5443a6c1357492751fd09355bd3f967845/marigold/marigold_obama_depth_independent.gif"/>
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/marigold/marigold_obama_depth_independent.gif"/>
<figcaption class="mt-1 text-center text-sm text-gray-500">Marigold Depth applied to input video frames independently</figcaption>
</div>
</div>
@@ -351,7 +351,7 @@ path_in = "obama.mp4"
path_out = "obama_depth.gif"
pipe = diffusers.MarigoldDepthPipeline.from_pretrained(
"prs-eth/marigold-lcm-v1-0", variant="fp16", torch_dtype=torch.float16
"prs-eth/marigold-depth-lcm-v1-0", variant="fp16", torch_dtype=torch.float16
).to(device)
pipe.vae = diffusers.AutoencoderTiny.from_pretrained(
"madebyollin/taesd", torch_dtype=torch.float16
@@ -387,11 +387,11 @@ The result is much more stable now:
<div class="flex gap-4">
<div style="flex: 1 1 50%; max-width: 50%;">
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/25024b5443a6c1357492751fd09355bd3f967845/marigold/marigold_obama_depth_independent.gif"/>
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/marigold/marigold_obama_depth_independent.gif"/>
<figcaption class="mt-1 text-center text-sm text-gray-500">Marigold Depth applied to input video frames independently</figcaption>
</div>
<div style="flex: 1 1 50%; max-width: 50%;">
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/25024b5443a6c1357492751fd09355bd3f967845/marigold/marigold_obama_depth_consistent.gif"/>
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/marigold/marigold_obama_depth_consistent.gif"/>
<figcaption class="mt-1 text-center text-sm text-gray-500">Marigold Depth with forced latents initialization</figcaption>
</div>
</div>
@@ -414,7 +414,7 @@ image = diffusers.utils.load_image(
)
pipe = diffusers.MarigoldDepthPipeline.from_pretrained(
"prs-eth/marigold-lcm-v1-0", torch_dtype=torch.float16, variant="fp16"
"prs-eth/marigold-depth-lcm-v1-0", torch_dtype=torch.float16, variant="fp16"
).to("cuda")
depth_image = pipe(image, generator=generator).prediction
@@ -450,13 +450,13 @@ controlnet_out[0].save("motorcycle_controlnet_out.png")
</figcaption>
</div>
<div style="flex: 1 1 33%; max-width: 33%;">
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/8e61e31f9feb7756c0404ceff26f3f0e5d3fe610/marigold/motorcycle_controlnet_depth.png"/>
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/marigold/motorcycle_controlnet_depth.png"/>
<figcaption class="mt-1 text-center text-sm text-gray-500">
Depth in the format compatible with ControlNet
</figcaption>
</div>
<div style="flex: 1 1 33%; max-width: 33%;">
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/8e61e31f9feb7756c0404ceff26f3f0e5d3fe610/marigold/motorcycle_controlnet_out.png"/>
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/marigold/motorcycle_controlnet_out.png"/>
<figcaption class="mt-1 text-center text-sm text-gray-500">
ControlNet generation, conditioned on depth and prompt: "high quality photo of a sports bike, city"
</figcaption>

View File

@@ -83,6 +83,7 @@ else:
"ControlNetModel",
"ControlNetXSAdapter",
"DiTTransformer2DModel",
"HunyuanDiT2DModel",
"I2VGenXLUNet",
"Kandinsky3UNet",
"ModelMixin",
@@ -229,6 +230,7 @@ else:
"BlipDiffusionPipeline",
"CLIPImageProjection",
"CycleDiffusionPipeline",
"HunyuanDiTPipeline",
"I2VGenXLPipeline",
"IFImg2ImgPipeline",
"IFImg2ImgSuperResolutionPipeline",
@@ -487,6 +489,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
ControlNetModel,
ControlNetXSAdapter,
DiTTransformer2DModel,
HunyuanDiT2DModel,
I2VGenXLUNet,
Kandinsky3UNet,
ModelMixin,
@@ -611,6 +614,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AudioLDMPipeline,
CLIPImageProjection,
CycleDiffusionPipeline,
HunyuanDiTPipeline,
I2VGenXLPipeline,
IFImg2ImgPipeline,
IFImg2ImgSuperResolutionPipeline,

View File

@@ -38,6 +38,7 @@ if is_torch_available():
_import_structure["embeddings"] = ["ImageProjection"]
_import_structure["modeling_utils"] = ["ModelMixin"]
_import_structure["transformers.dit_transformer_2d"] = ["DiTTransformer2DModel"]
_import_structure["transformers.hunyuan_transformer_2d"] = ["HunyuanDiT2DModel"]
_import_structure["transformers.pixart_transformer_2d"] = ["PixArtTransformer2DModel"]
_import_structure["transformers.prior_transformer"] = ["PriorTransformer"]
_import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
@@ -78,6 +79,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .transformers import (
DiTTransformer2DModel,
DualTransformer2DModel,
HunyuanDiT2DModel,
PixArtTransformer2DModel,
PriorTransformer,
T5FilmDecoder,

View File

@@ -50,6 +50,18 @@ def get_activation(act_fn: str) -> nn.Module:
raise ValueError(f"Unsupported activation function: {act_fn}")
class FP32SiLU(nn.Module):
r"""
SiLU activation function with input upcasted to torch.float32.
"""
def __init__(self):
super().__init__()
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
return F.silu(inputs.float(), inplace=False).to(inputs.dtype)
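The `FP32SiLU` module above upcasts its input, applies SiLU, and casts the result back to the input dtype. The following is a minimal pure-Python sketch of why that pattern can help. It makes two assumptions: half precision is emulated by round-tripping through the `struct` module's binary16 format, and the "upcast" computation runs in Python's double precision rather than float32. The helper names (`to_half`, `silu_all_half`, `silu_upcast`) are hypothetical and not part of diffusers.

```python
import math
import struct


def to_half(value: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", value))[0]


def silu(value: float) -> float:
    """SiLU(x) = x * sigmoid(x), evaluated in double precision."""
    return value / (1.0 + math.exp(-value))


def silu_all_half(value: float) -> float:
    """Emulate SiLU with every intermediate result rounded to half precision."""
    x = to_half(value)
    e = to_half(math.exp(-x))
    denom = to_half(1.0 + e)
    return to_half(x / denom)


def silu_upcast(value: float) -> float:
    """Emulate the upcast pattern: compute in high precision, round once at the end."""
    x = to_half(value)
    return to_half(silu(x))


if __name__ == "__main__":
    for v in (-3.3, 0.1, 5.5):
        ref = silu(to_half(v))
        print(v, abs(silu_all_half(v) - ref), abs(silu_upcast(v) - ref))
```

Because the upcast variant rounds only once, its result is the half-precision value closest to the high-precision reference, so its error can never exceed that of the variant that rounds after every step.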
class GELU(nn.Module):
r"""
GELU activation function with tanh approximation support with `approximate="tanh"`.

View File

@@ -103,6 +103,7 @@ class Attention(nn.Module):
upcast_softmax: bool = False,
cross_attention_norm: Optional[str] = None,
cross_attention_norm_num_groups: int = 32,
qk_norm: Optional[str] = None,
added_kv_proj_dim: Optional[int] = None,
norm_num_groups: Optional[int] = None,
spatial_norm_dim: Optional[int] = None,
@@ -161,6 +162,15 @@ class Attention(nn.Module):
else:
self.spatial_norm = None
if qk_norm is None:
self.norm_q = None
self.norm_k = None
elif qk_norm == "layer_norm":
self.norm_q = nn.LayerNorm(dim_head, eps=eps)
self.norm_k = nn.LayerNorm(dim_head, eps=eps)
else:
raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'")
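With `qk_norm="layer_norm"`, the hunk above creates `nn.LayerNorm(dim_head, eps=eps)` for queries and keys, so normalization runs over the channels of each attention head separately. Below is a rough pure-Python sketch of that per-head normalization. It assumes the affine parameters sit at their initial values (weight 1, bias 0) and uses an `eps` of `1e-5` purely for illustration; `layer_norm` and `normalize_heads` are hypothetical helpers, not diffusers functions.

```python
import math
from typing import List


def layer_norm(vector: List[float], eps: float = 1e-5) -> List[float]:
    """Normalize one head's channels to zero mean and (almost) unit variance."""
    mean = sum(vector) / len(vector)
    var = sum((v - mean) ** 2 for v in vector) / len(vector)  # biased variance
    inv_std = 1.0 / math.sqrt(var + eps)
    return [(v - mean) * inv_std for v in vector]


def normalize_heads(heads: List[List[float]], eps: float = 1e-5) -> List[List[float]]:
    """Apply the same normalization independently to every head of one token."""
    return [layer_norm(head, eps) for head in heads]


if __name__ == "__main__":
    # One token's query split into two heads with very different magnitudes.
    query_heads = [[1.0, 2.0, 3.0, 4.0], [100.0, 200.0, 300.0, 400.0]]
    for head in normalize_heads(query_heads):
        print([round(v, 4) for v in head])
```

After normalization both heads carry values on the same scale, which keeps the query-key dot products from being dominated by heads with large activations.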
if cross_attention_norm is None:
self.norm_cross = None
elif cross_attention_norm == "layer_norm":
@@ -1426,6 +1436,104 @@ class AttnProcessor2_0:
return hidden_states
class HunyuanAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
used in the HunyuanDiT model. It applies a normalization layer and rotary embedding to the query and key vectors.
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("HunyuanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
from .embeddings import apply_rotary_emb
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb)
if not attn.is_cross_attention:
key = apply_rotary_emb(key, image_rotary_emb)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class FusedAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses

View File

@@ -16,10 +16,11 @@ from typing import List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from ..utils import deprecate
from .activations import get_activation
from .activations import FP32SiLU, get_activation
from .attention_processor import Attention
@@ -135,6 +136,7 @@ class PatchEmbed(nn.Module):
flatten=True,
bias=True,
interpolation_scale=1,
pos_embed_type="sincos",
):
super().__init__()
@@ -156,10 +158,18 @@ class PatchEmbed(nn.Module):
self.height, self.width = height // patch_size, width // patch_size
self.base_size = height // patch_size
self.interpolation_scale = interpolation_scale
pos_embed = get_2d_sincos_pos_embed(
embed_dim, int(num_patches**0.5), base_size=self.base_size, interpolation_scale=self.interpolation_scale
)
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
if pos_embed_type is None:
self.pos_embed = None
elif pos_embed_type == "sincos":
pos_embed = get_2d_sincos_pos_embed(
embed_dim,
int(num_patches**0.5),
base_size=self.base_size,
interpolation_scale=self.interpolation_scale,
)
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
else:
raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}")
def forward(self, latent):
height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
@@ -169,6 +179,8 @@ class PatchEmbed(nn.Module):
latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
if self.layer_norm:
latent = self.norm(latent)
if self.pos_embed is None:
return latent.to(latent.dtype)
# Interpolate positional embeddings if needed.
# (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
@@ -187,6 +199,113 @@ class PatchEmbed(nn.Module):
return (latent + pos_embed).to(latent.dtype)
def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
"""
RoPE for image tokens with 2d structure.
Args:
embed_dim: (`int`):
The embedding dimension size
crops_coords (`Tuple[int]`)
The top-left and bottom-right coordinates of the crop.
grid_size (`Tuple[int]`):
The grid size of the positional embedding.
use_real (`bool`):
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
Returns:
`torch.Tensor`: positional embedding with shape `(grid_size * grid_size, embed_dim/2)`.
"""
start, stop = crops_coords
grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0) # [2, W, H]
grid = grid.reshape([2, 1, *grid.shape[1:]])
pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
return pos_embed
def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
assert embed_dim % 4 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_rotary_pos_embed(embed_dim // 2, grid[0].reshape(-1), use_real=use_real) # (H*W, D/4)
emb_w = get_1d_rotary_pos_embed(embed_dim // 2, grid[1].reshape(-1), use_real=use_real) # (H*W, D/4)
if use_real:
cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D/2)
sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D/2)
return cos, sin
else:
emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2)
return emb
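The grid construction in `get_2d_rotary_pos_embed` can be pictured as assigning every image token a pair of fractional coordinates inside the crop. The sketch below reproduces only that step in plain Python. It assumes `crops_coords` is `((top, left), (bottom, right))` and `grid_size` is `(rows, cols)`, matching how the indices are used above; `linspace_open` and `token_positions` are hypothetical names.

```python
from typing import List, Tuple


def linspace_open(start: float, stop: float, num: int) -> List[float]:
    """Evenly spaced samples over [start, stop), like np.linspace(..., endpoint=False)."""
    step = (stop - start) / num
    return [start + i * step for i in range(num)]


def token_positions(
    crops_coords: Tuple[Tuple[float, float], Tuple[float, float]],
    grid_size: Tuple[int, int],
) -> List[Tuple[float, float]]:
    """Return one (row, col) coordinate per token, in row-major order."""
    (top, left), (bottom, right) = crops_coords
    rows = linspace_open(top, bottom, grid_size[0])
    cols = linspace_open(left, right, grid_size[1])
    return [(r, c) for r in rows for c in cols]


if __name__ == "__main__":
    # A 2x3 token grid covering a crop that spans rows 0..2 and columns 0..3.
    for position in token_positions(((0, 0), (2, 3)), (2, 3)):
        print(position)
```

Each coordinate of a token then drives a 1D rotary embedding over half of the channels, and the two halves are concatenated.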
def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False):
"""
Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the
position indices 'pos'. The 'theta' parameter scales the frequencies. Unless `use_real` is set, the returned tensor
contains complex values in complex64 data type.
Args:
dim (`int`): Dimension of the frequency tensor.
pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
theta (`float`, *optional*, defaults to 10000.0):
Scaling factor for frequency computation. Defaults to 10000.0.
use_real (`bool`, *optional*):
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
Returns:
`torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
"""
if isinstance(pos, int):
pos = np.arange(pos)
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2]
t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S]
freqs = torch.outer(t, freqs).float() # type: ignore # [S, D/2]
if use_real:
freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
return freqs_cos, freqs_sin
else:
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
return freqs_cis
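For the `use_real=True` branch, the computation above amounts to building a cosine table and a sine table in which each frequency is repeated for two adjacent channels. The following is a small pure-Python sketch of that branch, with lists standing in for tensors; `rotary_cos_sin` is a hypothetical name and the code is an illustration rather than the library implementation.

```python
import math
from typing import List, Tuple


def rotary_cos_sin(
    dim: int, positions: List[float], theta: float = 10000.0
) -> Tuple[List[List[float]], List[List[float]]]:
    """Return (cos, sin) tables, each of shape [len(positions), dim]."""
    assert dim % 2 == 0, "dim must be even so channels can be paired"
    # One frequency per channel pair: 1 / theta^(2i / dim) for i = 0 .. dim/2 - 1.
    freqs = [1.0 / (theta ** (i / dim)) for i in range(0, dim, 2)]
    cos_table, sin_table = [], []
    for pos in positions:
        cos_row, sin_row = [], []
        for freq in freqs:
            angle = pos * freq
            # Repeat each value twice, mirroring repeat_interleave(2, dim=1).
            cos_row += [math.cos(angle)] * 2
            sin_row += [math.sin(angle)] * 2
        cos_table.append(cos_row)
        sin_table.append(sin_row)
    return cos_table, sin_table


if __name__ == "__main__":
    cos, sin = rotary_cos_sin(dim=4, positions=[0, 1, 2])
    print(len(cos), len(cos[0]))  # table shape: positions x dim
```

Position 0 always maps to an angle of zero, so its cosine row is all ones and its sine row is all zeros.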
def apply_rotary_emb(
x: torch.Tensor,
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
) -> torch.Tensor:
"""
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
tensors contain rotary embeddings and are returned as real tensors.
Args:
x (`torch.Tensor`):
Query or key tensor to apply rotary embeddings to. [B, H, S, D]
freqs_cis (`Tuple[torch.Tensor]`): Precomputed cosine and sine frequency tensors. ([S, D], [S, D],)
Returns:
`torch.Tensor`: The query or key tensor with rotary embeddings applied.
"""
cos, sin = freqs_cis # [S, D]
cos = cos[None, None]
sin = sin[None, None]
cos, sin = cos.to(x.device), sin.to(x.device)
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, H, S, D//2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
return out
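The arithmetic in `apply_rotary_emb` rotates each adjacent channel pair of `x` by a position-dependent angle. Here is a hedged pure-Python sketch for a single vector at a single position, using the same pairing and the same `x * cos + x_rotated * sin` form; `rotate_pairs` is a hypothetical helper, and the default `theta` of 10000.0 is taken from `get_1d_rotary_pos_embed` above.

```python
import math
from typing import List


def rotate_pairs(x: List[float], position: float, theta: float = 10000.0) -> List[float]:
    """Rotate each adjacent channel pair (x[i], x[i + 1]) by position * frequency."""
    dim = len(x)
    assert dim % 2 == 0, "channels are consumed in pairs"
    out = []
    for i in range(0, dim, 2):
        angle = position / (theta ** (i / dim))
        c, s = math.cos(angle), math.sin(angle)
        x_real, x_imag = x[i], x[i + 1]
        # Same result as x * cos + x_rotated * sin with x_rotated = (-x_imag, x_real).
        out.append(x_real * c - x_imag * s)
        out.append(x_imag * c + x_real * s)
    return out


def dot(a: List[float], b: List[float]) -> float:
    return sum(u * v for u, v in zip(a, b))


if __name__ == "__main__":
    q = [1.0, 2.0, 3.0, 4.0]
    k = [0.5, -1.0, 2.0, 0.25]
    # Shifting both positions by the same offset leaves the score unchanged.
    print(dot(rotate_pairs(q, 3), rotate_pairs(k, 1)))
    print(dot(rotate_pairs(q, 7), rotate_pairs(k, 5)))
```

The two printed scores agree up to rounding, which is the property that makes rotary embeddings encode relative rather than absolute position in the attention scores.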
class TimestepEmbedding(nn.Module):
def __init__(
self,
@@ -507,6 +626,88 @@ class CombinedTimestepLabelEmbeddings(nn.Module):
return conditioning
class HunyuanDiTAttentionPool(nn.Module):
# Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
super().__init__()
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim**0.5)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
self.num_heads = num_heads
def forward(self, x):
x = x.permute(1, 0, 2) # NLC -> LNC
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (L+1)NC
x, _ = F.multi_head_attention_forward(
query=x[:1],
key=x,
value=x,
embed_dim_to_check=x.shape[-1],
num_heads=self.num_heads,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
in_proj_weight=None,
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
bias_k=None,
bias_v=None,
add_zero_attn=False,
dropout_p=0,
out_proj_weight=self.c_proj.weight,
out_proj_bias=self.c_proj.bias,
use_separate_proj_weight=True,
training=self.training,
need_weights=False,
)
return x.squeeze(0)
class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module):
def __init__(self, embedding_dim, pooled_projection_dim=1024, seq_len=256, cross_attention_dim=2048):
super().__init__()
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
self.pooler = HunyuanDiTAttentionPool(
seq_len, cross_attention_dim, num_heads=8, output_dim=pooled_projection_dim
)
# Here we use a default learned embedder layer for future extension.
self.style_embedder = nn.Embedding(1, embedding_dim)
extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim
self.extra_embedder = PixArtAlphaTextProjection(
in_features=extra_in_dim,
hidden_size=embedding_dim * 4,
out_features=embedding_dim,
act_fn="silu_fp32",
)
def forward(self, timestep, encoder_hidden_states, image_meta_size, style, hidden_dtype=None):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, embedding_dim)
# extra condition1: text
pooled_projections = self.pooler(encoder_hidden_states) # (N, 1024)
# extra condition2: image meta size embedding
image_meta_size = get_timestep_embedding(image_meta_size.view(-1), 256, True, 0)
image_meta_size = image_meta_size.to(dtype=hidden_dtype)
image_meta_size = image_meta_size.view(-1, 6 * 256) # (N, 1536)
# extra condition3: style embedding
style_embedding = self.style_embedder(style) # (N, embedding_dim)
# Concatenate all extra vectors
extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1)
conditioning = timesteps_emb + self.extra_embedder(extra_cond) # [B, D]
return conditioning
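
The input width of `extra_embedder` follows directly from the three concatenated conditions. The sketch below only redoes that arithmetic; the helper name `extra_cond_width` is hypothetical, and the value 1152 is taken from the default `hidden_size` in the `HunyuanDiT2DModel` signature later in this diff.

```python
def extra_cond_width(embedding_dim, pooled_projection_dim=1024, size_fields=6, sinusoid_dim=256):
    """Return the per-part widths and total width of the concatenated extra condition."""
    parts = {
        "pooled_text": pooled_projection_dim,            # output of the attention pooler
        "image_meta_size": size_fields * sinusoid_dim,   # 6 size values, 256 sinusoidal dims each
        "style": embedding_dim,                          # learned style embedding
    }
    return parts, sum(parts.values())


parts, width = extra_cond_width(embedding_dim=1152)
```

The total matches `extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim` in `__init__`.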
class TextTimeEmbedding(nn.Module):
def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
super().__init__()
@@ -793,11 +994,18 @@ class PixArtAlphaTextProjection(nn.Module):
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
"""
def __init__(self, in_features, hidden_size, num_tokens=120):
def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"):
super().__init__()
if out_features is None:
out_features = hidden_size
self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
self.act_1 = nn.GELU(approximate="tanh")
self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True)
if act_fn == "gelu_tanh":
self.act_1 = nn.GELU(approximate="tanh")
elif act_fn == "silu_fp32":
self.act_1 = FP32SiLU()
else:
raise ValueError(f"Unknown activation function: {act_fn}")
self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True)
def forward(self, caption):
hidden_states = self.linear_1(caption)


@@ -176,7 +176,8 @@ class AdaLayerNormContinuous(nn.Module):
raise ValueError(f"unknown norm_type {norm_type}")
def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
emb = self.linear(self.silu(conditioning_embedding))
# convert back to the original dtype in case `conditioning_embedding` is upcast to float32 (needed for HunyuanDiT)
emb = self.linear(self.silu(conditioning_embedding).to(x.dtype))
scale, shift = torch.chunk(emb, 2, dim=1)
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
return x


@@ -4,6 +4,7 @@ from ...utils import is_torch_available
if is_torch_available():
from .dit_transformer_2d import DiTTransformer2DModel
from .dual_transformer_2d import DualTransformer2DModel
from .hunyuan_transformer_2d import HunyuanDiT2DModel
from .pixart_transformer_2d import PixArtTransformer2DModel
from .prior_transformer import PriorTransformer
from .t5_film_transformer import T5FilmDecoder


@@ -0,0 +1,427 @@
# Copyright 2024 HunyuanDiT Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
import torch.nn.functional as F
from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import Attention, HunyuanAttnProcessor2_0
from ..embeddings import (
HunyuanCombinedTimestepTextSizeStyleEmbedding,
PatchEmbed,
PixArtAlphaTextProjection,
)
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class FP32LayerNorm(nn.LayerNorm):
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
origin_dtype = inputs.dtype
return F.layer_norm(
inputs.float(), self.normalized_shape, self.weight.float(), self.bias.float(), self.eps
).to(origin_dtype)
class AdaLayerNormShift(nn.Module):
r"""
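Norm layer modified to incorporate timestep embeddings through a learned shift.
Parameters:
embedding_dim (`int`): The size of each embedding vector.
elementwise_affine (`bool`, *optional*, defaults to `True`): Whether the layer norm has learnable affine parameters.
eps (`float`, *optional*, defaults to 1e-6): A small constant added to the denominator for numerical stability.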
"""
def __init__(self, embedding_dim: int, elementwise_affine=True, eps=1e-6):
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, embedding_dim)
self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=elementwise_affine, eps=eps)
def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
shift = self.linear(self.silu(emb.to(torch.float32)).to(emb.dtype))
x = self.norm(x) + shift.unsqueeze(dim=1)
return x
@maybe_allow_in_graph
class HunyuanDiTBlock(nn.Module):
r"""
Transformer block used in the Hunyuan-DiT model (https://github.com/Tencent/HunyuanDiT). Supports long skip
connections and QK normalization.
Parameters:
dim (`int`):
The number of channels in the input and output.
num_attention_heads (`int`):
The number of heads to use for multi-head attention.
cross_attention_dim (`int`, *optional*):
The size of the encoder_hidden_states vector for cross attention.
dropout (`float`, *optional*, defaults to 0.0):
The dropout probability to use.
activation_fn (`str`, *optional*, defaults to `"geglu"`):
Activation function to be used in feed-forward.
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
Whether to use learnable elementwise affine parameters for normalization.
norm_eps (`float`, *optional*, defaults to 1e-6):
A small constant added to the denominator in normalization layers to prevent division by zero.
final_dropout (`bool`, *optional*, defaults to `False`):
Whether to apply a final dropout after the last feed-forward layer.
ff_inner_dim (`int`, *optional*):
The size of the hidden layer in the feed-forward block. Defaults to `None`.
ff_bias (`bool`, *optional*, defaults to `True`):
Whether to use bias in the feed-forward block.
skip (`bool`, *optional*, defaults to `False`):
Whether to use skip connection. Defaults to `False` for down-blocks and mid-blocks.
qk_norm (`bool`, *optional*, defaults to `True`):
Whether to use normalization in QK calculation. Defaults to `True`.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
cross_attention_dim: int = 1024,
dropout=0.0,
activation_fn: str = "geglu",
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-6,
final_dropout: bool = False,
ff_inner_dim: Optional[int] = None,
ff_bias: bool = True,
skip: bool = False,
qk_norm: bool = True,
):
super().__init__()
# Define 3 blocks. Each block has its own normalization layer.
# NOTE: when a new version comes, check norm2 and norm3
# 1. Self-Attn
self.norm1 = AdaLayerNormShift(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.attn1 = Attention(
query_dim=dim,
cross_attention_dim=None,
dim_head=dim // num_attention_heads,
heads=num_attention_heads,
qk_norm="layer_norm" if qk_norm else None,
eps=1e-6,
bias=True,
processor=HunyuanAttnProcessor2_0(),
)
# 2. Cross-Attn
self.norm2 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
dim_head=dim // num_attention_heads,
heads=num_attention_heads,
qk_norm="layer_norm" if qk_norm else None,
eps=1e-6,
bias=True,
processor=HunyuanAttnProcessor2_0(),
)
# 3. Feed-forward
self.norm3 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
self.ff = FeedForward(
dim,
dropout=dropout, ### 0.0
activation_fn=activation_fn, ### approx GeLU
final_dropout=final_dropout, ### 0.0
inner_dim=ff_inner_dim, ### int(dim * mlp_ratio)
bias=ff_bias,
)
# 4. Skip Connection
if skip:
self.skip_norm = FP32LayerNorm(2 * dim, norm_eps, elementwise_affine=True)
self.skip_linear = nn.Linear(2 * dim, dim)
else:
self.skip_linear = None
# let chunk size default to None
self._chunk_size = None
self._chunk_dim = 0
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
# Sets chunk feed-forward
self._chunk_size = chunk_size
self._chunk_dim = dim
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
image_rotary_emb=None,
skip=None,
) -> torch.Tensor:
# Notice that normalization is always applied before the real computation in the following blocks.
# 0. Long Skip Connection
if self.skip_linear is not None:
cat = torch.cat([hidden_states, skip], dim=-1)
cat = self.skip_norm(cat)
hidden_states = self.skip_linear(cat)
# 1. Self-Attention
norm_hidden_states = self.norm1(hidden_states, temb) ### checked: self.norm1 is correct
attn_output = self.attn1(
norm_hidden_states,
image_rotary_emb=image_rotary_emb,
)
hidden_states = hidden_states + attn_output
# 2. Cross-Attention
hidden_states = hidden_states + self.attn2(
self.norm2(hidden_states),
encoder_hidden_states=encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
)
# FFN Layer ### TODO: switch norm2 and norm3 in the state dict
mlp_inputs = self.norm3(hidden_states)
hidden_states = hidden_states + self.ff(mlp_inputs)
return hidden_states
class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
"""
HunYuanDiT: Diffusion model with a Transformer backbone.
Inherits from ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers.
Parameters:
num_attention_heads (`int`, *optional*, defaults to 16):
The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 88):
The number of channels in each head.
in_channels (`int`, *optional*):
The number of channels in the input and output (specify if the input is **continuous**).
patch_size (`int`, *optional*):
The size of the patch to use for the input.
activation_fn (`str`, *optional*, defaults to `"gelu-approximate"`):
Activation function to use in feed-forward.
sample_size (`int`, *optional*):
The width of the latent images. This is fixed during training since it is used to learn a number of
position embeddings.
dropout (`float`, *optional*, defaults to 0.0):
The dropout probability to use.
cross_attention_dim (`int`, *optional*):
The number of dimensions in the CLIP text embedding.
hidden_size (`int`, *optional*):
The size of hidden layer in the conditioning embedding layers.
num_layers (`int`, *optional*, defaults to 28):
The number of layers of Transformer blocks to use.
mlp_ratio (`float`, *optional*, defaults to 4.0):
The ratio of the hidden layer size to the input size.
learn_sigma (`bool`, *optional*, defaults to `True`):
Whether to predict variance.
cross_attention_dim_t5 (`int`, *optional*):
The number of dimensions in the T5 text embedding.
pooled_projection_dim (`int`, *optional*):
The size of the pooled projection.
text_len (`int`, *optional*):
The length of the CLIP text embedding.
text_len_t5 (`int`, *optional*):
The length of the T5 text embedding.
"""
@register_to_config
def __init__(
self,
num_attention_heads: int = 16,
attention_head_dim: int = 88,
in_channels: Optional[int] = None,
patch_size: Optional[int] = None,
activation_fn: str = "gelu-approximate",
sample_size=32,
hidden_size=1152,
num_layers: int = 28,
mlp_ratio: float = 4.0,
learn_sigma: bool = True,
cross_attention_dim: int = 1024,
norm_type: str = "layer_norm",
cross_attention_dim_t5: int = 2048,
pooled_projection_dim: int = 1024,
text_len: int = 77,
text_len_t5: int = 256,
):
super().__init__()
self.out_channels = in_channels * 2 if learn_sigma else in_channels
self.num_heads = num_attention_heads
self.inner_dim = num_attention_heads * attention_head_dim
self.text_embedder = PixArtAlphaTextProjection(
in_features=cross_attention_dim_t5,
hidden_size=cross_attention_dim_t5 * 4,
out_features=cross_attention_dim,
act_fn="silu_fp32",
)
self.text_embedding_padding = nn.Parameter(
torch.randn(text_len + text_len_t5, cross_attention_dim, dtype=torch.float32)
)
self.pos_embed = PatchEmbed(
height=sample_size,
width=sample_size,
in_channels=in_channels,
embed_dim=hidden_size,
patch_size=patch_size,
pos_embed_type=None,
)
self.time_extra_emb = HunyuanCombinedTimestepTextSizeStyleEmbedding(
hidden_size,
pooled_projection_dim=pooled_projection_dim,
seq_len=text_len_t5,
cross_attention_dim=cross_attention_dim_t5,
)
# HunyuanDiT Blocks
self.blocks = nn.ModuleList(
[
HunyuanDiTBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
activation_fn=activation_fn,
ff_inner_dim=int(self.inner_dim * mlp_ratio),
cross_attention_dim=cross_attention_dim,
qk_norm=True, # See http://arxiv.org/abs/2302.05442 for details.
skip=layer > num_layers // 2,
)
for layer in range(num_layers)
]
)
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
def forward(
self,
hidden_states,
timestep,
encoder_hidden_states=None,
text_embedding_mask=None,
encoder_hidden_states_t5=None,
text_embedding_mask_t5=None,
image_meta_size=None,
style=None,
image_rotary_emb=None,
return_dict=True,
):
"""
The [`HunyuanDiT2DModel`] forward method.
Args:
hidden_states (`torch.Tensor` of shape `(batch size, dim, height, width)`):
The input tensor.
timestep (`torch.LongTensor`, *optional*):
Used to indicate the denoising step.
encoder_hidden_states (`torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
Conditional embeddings for the cross attention layer. This is the output of `BertModel`.
text_embedding_mask (`torch.Tensor`):
An attention mask of shape `(batch, key_tokens)` that is applied to `encoder_hidden_states`. This is the output
of `BertModel`.
encoder_hidden_states_t5 (`torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
Conditional embeddings for the cross attention layer. This is the output of the T5 text encoder.
text_embedding_mask_t5 (`torch.Tensor`):
An attention mask of shape `(batch, key_tokens)` that is applied to `encoder_hidden_states_t5`. This is the
output of the T5 text encoder.
image_meta_size (`torch.Tensor`):
Conditional embedding indicating the image size.
style (`torch.Tensor`):
Conditional embedding indicating the style.
image_rotary_emb (`torch.Tensor`):
The image rotary embeddings to apply on query and key tensors during attention calculation.
return_dict (`bool`):
Whether to return a dictionary.
"""
height, width = hidden_states.shape[-2:]
hidden_states = self.pos_embed(hidden_states)
temb = self.time_extra_emb(
timestep, encoder_hidden_states_t5, image_meta_size, style, hidden_dtype=timestep.dtype
) # [B, D]
# text projection
batch_size, sequence_length, _ = encoder_hidden_states_t5.shape
encoder_hidden_states_t5 = self.text_embedder(
encoder_hidden_states_t5.view(-1, encoder_hidden_states_t5.shape[-1])
)
encoder_hidden_states_t5 = encoder_hidden_states_t5.view(batch_size, sequence_length, -1)
encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_t5], dim=1)
text_embedding_mask = torch.cat([text_embedding_mask, text_embedding_mask_t5], dim=-1)
text_embedding_mask = text_embedding_mask.unsqueeze(2).bool()
encoder_hidden_states = torch.where(text_embedding_mask, encoder_hidden_states, self.text_embedding_padding)
skips = []
for layer, block in enumerate(self.blocks):
if layer > self.config.num_layers // 2:
skip = skips.pop()
hidden_states = block(
hidden_states,
temb=temb,
encoder_hidden_states=encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
skip=skip,
) # (N, L, D)
else:
hidden_states = block(
hidden_states,
temb=temb,
encoder_hidden_states=encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
) # (N, L, D)
if layer < (self.config.num_layers // 2 - 1):
skips.append(hidden_states)
# final layer
hidden_states = self.norm_out(hidden_states, temb.to(torch.float32))
hidden_states = self.proj_out(hidden_states)
# (N, L, patch_size ** 2 * out_channels)
# unpatchify: (N, out_channels, H, W)
patch_size = self.pos_embed.patch_size
height = height // patch_size
width = width // patch_size
hidden_states = hidden_states.reshape(
shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels)
)
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
output = hidden_states.reshape(
shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
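
The unpatchify step at the end of `forward` (reshape, `einsum("nhwpqc->nchpwq")`, reshape) can be traced with plain nested lists. This is a minimal sketch for a single sample, assuming the same feature ordering as the diff (patch row, patch column, channel); the helper name `unpatchify` is hypothetical.

```python
def unpatchify(tokens, height, width, patch_size, channels):
    """tokens: height * width rows, each holding patch_size**2 * channels values."""
    p = patch_size
    image = [[[0.0] * (width * p) for _ in range(height * p)] for _ in range(channels)]
    for i in range(height):          # patch row in the token grid
        for j in range(width):       # patch column in the token grid
            token = tokens[i * width + j]
            for pr in range(p):      # pixel row inside the patch
                for pc in range(p):  # pixel column inside the patch
                    for c in range(channels):
                        image[c][i * p + pr][j * p + pc] = token[(pr * p + pc) * channels + c]
    return image


# Two side-by-side 2x2 patches with one channel.
tokens = [[0, 1, 2, 3], [4, 5, 6, 7]]
image = unpatchify(tokens, height=1, width=2, patch_size=2, channels=1)
```

Each token contributes one `patch_size` by `patch_size` tile, and the tiles are laid out in row-major order over the token grid.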

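The long skip connections in `HunyuanDiT2DModel.forward` pair early blocks with late blocks through a stack. The sketch below reproduces only that bookkeeping with layer indices in place of tensors; the helper name `skip_schedule` is hypothetical. It assumes an even `num_layers`: with an odd count, the conditions used in the diff would pop from an empty list.

```python
def skip_schedule(num_layers):
    """Return (source_layer, consumer_layer) pairs and whatever is left on the stack."""
    stack = []
    pairs = []
    for layer in range(num_layers):
        if layer > num_layers // 2:
            # Late blocks consume the most recently stored early output.
            pairs.append((stack.pop(), layer))
        if layer < num_layers // 2 - 1:
            # Early blocks store their output for a later block.
            stack.append(layer)
    return pairs, stack


pairs, leftover = skip_schedule(28)
```

For the default of 28 layers, blocks 0 to 12 feed blocks 27 down to 15 in last-in, first-out order, and blocks 13 and 14 neither store nor consume a skip.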

@@ -150,6 +150,7 @@ else:
"IFPipeline",
"IFSuperResolutionPipeline",
]
_import_structure["hunyuandit"] = ["HunyuanDiTPipeline"]
_import_structure["kandinsky"] = [
"KandinskyCombinedPipeline",
"KandinskyImg2ImgCombinedPipeline",
@@ -418,6 +419,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
VersatileDiffusionTextToImagePipeline,
VQDiffusionPipeline,
)
from .hunyuandit import HunyuanDiTPipeline
from .i2vgen_xl import I2VGenXLPipeline
from .kandinsky import (
KandinskyCombinedPipeline,


@@ -0,0 +1,48 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_hunyuandit"] = ["HunyuanDiTPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_hunyuandit import HunyuanDiTPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)


@@ -0,0 +1,881 @@
# Copyright 2024 HunyuanDiT Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from transformers import BertModel, BertTokenizer, CLIPImageProcessor, MT5Tokenizer, T5EncoderModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL, HunyuanDiT2DModel
from ...models.embeddings import get_2d_rotary_pos_embed
from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from ...schedulers import DDPMScheduler
from ...utils import (
is_torch_xla_available,
logging,
replace_example_docstring,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import HunyuanDiTPipeline
>>> pipe = HunyuanDiTPipeline.from_pretrained("Tencent-Hunyuan/HunyuanDiT", torch_dtype=torch.float16)
>>> pipe.to("cuda")
>>> # You may also use English prompt as HunyuanDiT supports both English and Chinese
>>> # prompt = "An astronaut riding a horse"
>>> prompt = "一个宇航员在骑马"
>>> image = pipe(prompt).images[0]
```
"""
STANDARD_RATIO = np.array(
[
1.0, # 1:1
4.0 / 3.0, # 4:3
3.0 / 4.0, # 3:4
16.0 / 9.0, # 16:9
9.0 / 16.0, # 9:16
]
)
STANDARD_SHAPE = [
[(1024, 1024), (1280, 1280)], # 1:1
[(1024, 768), (1152, 864), (1280, 960)], # 4:3
[(768, 1024), (864, 1152), (960, 1280)], # 3:4
[(1280, 768)], # 16:9
[(768, 1280)], # 9:16
]
STANDARD_AREA = [np.array([w * h for w, h in shapes]) for shapes in STANDARD_SHAPE]
SUPPORTED_SHAPE = [
(1024, 1024),
(1280, 1280), # 1:1
(1024, 768),
(1152, 864),
(1280, 960), # 4:3
(768, 1024),
(864, 1152),
(960, 1280), # 3:4
(1280, 768), # 16:9
(768, 1280), # 9:16
]
def map_to_standard_shapes(target_width, target_height):
target_ratio = target_width / target_height
closest_ratio_idx = np.argmin(np.abs(STANDARD_RATIO - target_ratio))
closest_area_idx = np.argmin(np.abs(STANDARD_AREA[closest_ratio_idx] - target_width * target_height))
width, height = STANDARD_SHAPE[closest_ratio_idx][closest_area_idx]
return width, height
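
To see how `map_to_standard_shapes` snaps a request to a supported resolution, here is a dependency-free sketch of the same two-step lookup (closest aspect ratio, then closest area). It replaces the NumPy `argmin` calls with `min`, which is an illustrative substitution rather than the code in the diff; the table values are copied from `STANDARD_RATIO` and `STANDARD_SHAPE` above.

```python
STANDARD_RATIO = [1.0, 4.0 / 3.0, 3.0 / 4.0, 16.0 / 9.0, 9.0 / 16.0]
STANDARD_SHAPE = [
    [(1024, 1024), (1280, 1280)],              # 1:1
    [(1024, 768), (1152, 864), (1280, 960)],   # 4:3
    [(768, 1024), (864, 1152), (960, 1280)],   # 3:4
    [(1280, 768)],                             # 16:9
    [(768, 1280)],                             # 9:16
]


def map_to_standard_shapes(target_width, target_height):
    target_ratio = target_width / target_height
    # Step 1: pick the bucket whose aspect ratio is closest to the request.
    ratio_idx = min(range(len(STANDARD_RATIO)), key=lambda i: abs(STANDARD_RATIO[i] - target_ratio))
    # Step 2: inside that bucket, pick the shape whose area is closest.
    target_area = target_width * target_height
    return min(STANDARD_SHAPE[ratio_idx], key=lambda wh: abs(wh[0] * wh[1] - target_area))


shape = map_to_standard_shapes(1000, 800)
```

A 1000 x 800 request has ratio 1.25, which is nearest to 4:3, and its area is nearest to that of 1024 x 768.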
def get_resize_crop_region_for_grid(src, tgt_size):
th = tw = tgt_size
h, w = src
r = h / w
# resize
if r > 1:
resize_height = th
resize_width = int(round(th / h * w))
else:
resize_width = tw
resize_height = int(round(tw / w * h))
crop_top = int(round((th - resize_height) / 2.0))
crop_left = int(round((tw - resize_width) / 2.0))
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
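
A worked example may make the return value of `get_resize_crop_region_for_grid` easier to read. The function body below is repeated from the diff so the block runs on its own; the input sizes are arbitrary illustrative values.

```python
def get_resize_crop_region_for_grid(src, tgt_size):
    th = tw = tgt_size
    h, w = src
    r = h / w
    # Fit the longer side of the source to the target, keeping the aspect ratio.
    if r > 1:
        resize_height = th
        resize_width = int(round(th / h * w))
    else:
        resize_width = tw
        resize_height = int(round(tw / w * h))
    # Center the resized region inside the square target grid.
    crop_top = int(round((th - resize_height) / 2.0))
    crop_left = int(round((tw - resize_width) / 2.0))
    return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)


portrait = get_resize_crop_region_for_grid((64, 48), 32)
landscape = get_resize_crop_region_for_grid((48, 64), 32)
```

The result is a pair of (top, left) and (bottom, right) corners: a portrait grid is padded on the left and right, a landscape grid on the top and bottom.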
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
"""
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
# rescale the results from guidance (fixes overexposure)
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
return noise_cfg
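
The effect of `guidance_rescale` can be checked on a flat list of numbers. This is a minimal sketch for a single flattened sample, assuming the sample standard deviation (which is what `torch.Tensor.std` computes by default); the helper name `rescale_noise_cfg_1d` is hypothetical.

```python
from statistics import stdev


def rescale_noise_cfg_1d(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    std_text = stdev(noise_pred_text)
    std_cfg = stdev(noise_cfg)
    # Rescale the guided prediction so its spread matches the text-conditioned prediction.
    rescaled = [v * (std_text / std_cfg) for v in noise_cfg]
    # Blend the rescaled and original guided predictions.
    return [guidance_rescale * r + (1 - guidance_rescale) * v for r, v in zip(rescaled, noise_cfg)]


text = [0.0, 1.0, 2.0]
cfg = [0.0, 2.0, 4.0]
full = rescale_noise_cfg_1d(cfg, text, guidance_rescale=1.0)
half = rescale_noise_cfg_1d(cfg, text, guidance_rescale=0.5)
```

With `guidance_rescale=0.0` the guided prediction is returned unchanged; with `1.0` its standard deviation is matched to that of the text-conditioned prediction.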
class HunyuanDiTPipeline(DiffusionPipeline):
r"""
Pipeline for English/Chinese-to-image generation using HunyuanDiT.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
HunyuanDiT uses two text encoders: [mT5](https://huggingface.co/google/mt5-base) and a bilingual CLIP fine-tuned by
the Tencent Hunyuan team.
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. We use
`sdxl-vae-fp16-fix`.
text_encoder (Optional[`~transformers.BertModel`, `~transformers.CLIPTextModel`]):
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
HunyuanDiT uses a fine-tuned bilingual CLIP.
tokenizer (Optional[`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer`]):
A `BertTokenizer` or `CLIPTokenizer` to tokenize text.
transformer ([`HunyuanDiT2DModel`]):
The HunyuanDiT model designed by Tencent Hunyuan.
text_encoder_2 (`T5EncoderModel`):
The mT5 embedder. Specifically, it is 't5-v1_1-xxl'.
tokenizer_2 (`MT5Tokenizer`):
The tokenizer for the mT5 embedder.
scheduler ([`DDPMScheduler`]):
A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents.
"""
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
_optional_components = [
"safety_checker",
"feature_extractor",
"text_encoder_2",
"tokenizer_2",
"text_encoder",
"tokenizer",
]
_exclude_from_cpu_offload = ["safety_checker"]
_callback_tensor_inputs = [
"latents",
"prompt_embeds",
"negative_prompt_embeds",
"prompt_embeds_2",
"negative_prompt_embeds_2",
]
def __init__(
self,
vae: AutoencoderKL,
text_encoder: BertModel,
tokenizer: BertTokenizer,
transformer: HunyuanDiT2DModel,
scheduler: DDPMScheduler,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor,
requires_safety_checker: bool = True,
text_encoder_2=T5EncoderModel,
tokenizer_2=MT5Tokenizer,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
tokenizer_2=tokenizer_2,
transformer=transformer,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
text_encoder_2=text_encoder_2,
)
if safety_checker is None and requires_safety_checker:
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
if safety_checker is not None and feature_extractor is None:
raise ValueError(
f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
self.default_sample_size = self.transformer.config.sample_size
def encode_prompt(
self,
prompt: str,
device: torch.device,
dtype: torch.dtype,
num_images_per_prompt: int = 1,
do_classifier_free_guidance: bool = True,
negative_prompt: Optional[str] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
max_sequence_length: Optional[int] = None,
text_encoder_index: int = 0,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
dtype (`torch.dtype`):
torch dtype
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
prompt_attention_mask (`torch.Tensor`, *optional*):
Attention mask for the prompt. Required when `prompt_embeds` is passed directly.
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly.
max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt.
text_encoder_index (`int`, *optional*):
Index of the text encoder to use. `0` for CLIP and `1` for T5.
"""
tokenizers = [self.tokenizer, self.tokenizer_2]
text_encoders = [self.text_encoder, self.text_encoder_2]
tokenizer = tokenizers[text_encoder_index]
text_encoder = text_encoders[text_encoder_index]
if max_sequence_length is None:
if text_encoder_index == 0:
max_length = 77
if text_encoder_index == 1:
max_length = 256
else:
max_length = max_sequence_length
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=max_length,
truncation=True,
return_attention_mask=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {tokenizer.model_max_length} tokens: {removed_text}"
)
prompt_attention_mask = text_inputs.attention_mask.to(device)
prompt_embeds = text_encoder(
text_input_ids.to(device),
attention_mask=prompt_attention_mask,
)
prompt_embeds = prompt_embeds[0]
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
max_length = prompt_embeds.shape[1]
uncond_input = tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
negative_prompt_attention_mask = uncond_input.attention_mask.to(device)
negative_prompt_embeds = text_encoder(
uncond_input.input_ids.to(device),
attention_mask=negative_prompt_attention_mask,
)
negative_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
has_nsfw_concept = None
else:
if torch.is_tensor(image):
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
else:
feature_extractor_input = self.image_processor.numpy_to_pil(image)
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
return image, has_nsfw_concept
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def check_inputs(
self,
prompt,
height,
width,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
prompt_attention_mask=None,
negative_prompt_attention_mask=None,
prompt_embeds_2=None,
negative_prompt_embeds_2=None,
prompt_attention_mask_2=None,
negative_prompt_attention_mask_2=None,
callback_on_step_end_tensor_inputs=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is None and prompt_embeds_2 is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if prompt_embeds is not None and prompt_attention_mask is None:
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
if prompt_embeds_2 is not None and prompt_attention_mask_2 is None:
raise ValueError("Must provide `prompt_attention_mask_2` when specifying `prompt_embeds_2`.")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
if negative_prompt_embeds_2 is not None and negative_prompt_attention_mask_2 is None:
raise ValueError(
"Must provide `negative_prompt_attention_mask_2` when specifying `negative_prompt_embeds_2`."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
if prompt_embeds_2 is not None and negative_prompt_embeds_2 is not None:
if prompt_embeds_2.shape != negative_prompt_embeds_2.shape:
raise ValueError(
"`prompt_embeds_2` and `negative_prompt_embeds_2` must have the same shape when passed directly, but"
f" got: `prompt_embeds_2` {prompt_embeds_2.shape} != `negative_prompt_embeds_2`"
f" {negative_prompt_embeds_2.shape}."
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
@property
def guidance_scale(self):
return self._guidance_scale
@property
def guidance_rescale(self):
return self._guidance_rescale
# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier-free guidance.
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: Optional[float] = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_embeds_2: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds_2: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
prompt_attention_mask_2: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_attention_mask_2: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
guidance_rescale: float = 0.0,
original_size: Optional[Tuple[int, int]] = (1024, 1024),
target_size: Optional[Tuple[int, int]] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
use_resolution_binning: bool = True,
):
r"""
The call function to the pipeline for generation with HunyuanDiT.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
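height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
The height in pixels of the generated image. Rounded down to a multiple of 16.
width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
The width in pixels of the generated image. Rounded down to a multiple of 16.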
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 5.0):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide what to not include in image generation. If not defined, an empty
prompt is used unless `negative_prompt_embeds` is passed. Ignored when not using guidance
(`guidance_scale <= 1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument.
prompt_embeds_2 (`torch.Tensor`, *optional*):
Pre-generated text embeddings for the second text encoder (`text_encoder_index=1`). If not
provided, they are generated from the `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
negative_prompt_embeds_2 (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings for the second text encoder (`text_encoder_index=1`). If
not provided, `negative_prompt_embeds_2` are generated from the `negative_prompt` input argument.
prompt_attention_mask (`torch.Tensor`, *optional*):
Attention mask for the prompt. Required when `prompt_embeds` is passed directly.
prompt_attention_mask_2 (`torch.Tensor`, *optional*):
Attention mask for the prompt. Required when `prompt_embeds_2` is passed directly.
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly.
negative_prompt_attention_mask_2 (`torch.Tensor`, *optional*):
Attention mask for the negative prompt. Required when `negative_prompt_embeds_2` is passed directly.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `"pil"` (`PIL.Image`) and `"np"` (`np.array`),
or pass `"latent"` to return the undecoded latents.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback_on_step_end (`Callable[[int, int, Dict], None]`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
A callback function or a list of callback functions to be called at the end of each denoising step.
callback_on_step_end_tensor_inputs (`List[str]`, *optional*, defaults to `["latents"]`):
The names of the tensors to pass to the callback function. Every name must be listed in the
pipeline's `_callback_tensor_inputs` attribute.
guidance_rescale (`float`, *optional*, defaults to 0.0):
Rescale the guided noise prediction according to `guidance_rescale`. Based on the findings of [Common
Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf), Section 3.4.
Only applied when classifier-free guidance is active and `guidance_rescale > 0`.
original_size (`Tuple[int, int]`, *optional*, defaults to `(1024, 1024)`):
The original size of the image. Used to calculate the time ids.
target_size (`Tuple[int, int]`, *optional*):
The target size of the image. Used to calculate the time ids.
crops_coords_top_left (`Tuple[int, int]`, *optional*, defaults to `(0, 0)`):
The top left coordinates of the crop. Used to calculate the time ids.
use_resolution_binning (`bool`, *optional*, defaults to `True`):
Whether to use resolution binning or not. If `True`, the input resolution will be mapped to the closest
standard resolution. Supported resolutions are 1024x1024, 1280x1280, 1024x768, 1152x864, 1280x960,
768x1024, 864x1152, 960x1280, 1280x768, and 768x1280. It is recommended to set this to `True`.
Examples:
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
otherwise a `tuple` is returned where the first element is a list with the generated images and the
second element is a list of `bool`s indicating whether the corresponding generated image contains
"not-safe-for-work" (nsfw) content.
"""
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
# 0. default height and width
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
height = int((height // 16) * 16)
width = int((width // 16) * 16)
if use_resolution_binning and (height, width) not in SUPPORTED_SHAPE:
width, height = map_to_standard_shapes(width, height)
height = int(height)
width = int(width)
logger.warning(f"Reshaped to (height, width)=({height}, {width}), Supported shapes are {SUPPORTED_SHAPE}")
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
height,
width,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
prompt_attention_mask,
negative_prompt_attention_mask,
prompt_embeds_2,
negative_prompt_embeds_2,
prompt_attention_mask_2,
negative_prompt_attention_mask_2,
callback_on_step_end_tensor_inputs,
)
self._guidance_scale = guidance_scale
self._guidance_rescale = guidance_rescale
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# 3. Encode input prompt
(
prompt_embeds,
negative_prompt_embeds,
prompt_attention_mask,
negative_prompt_attention_mask,
) = self.encode_prompt(
prompt=prompt,
device=device,
dtype=self.transformer.dtype,
num_images_per_prompt=num_images_per_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
max_sequence_length=77,
text_encoder_index=0,
)
(
prompt_embeds_2,
negative_prompt_embeds_2,
prompt_attention_mask_2,
negative_prompt_attention_mask_2,
) = self.encode_prompt(
prompt=prompt,
device=device,
dtype=self.transformer.dtype,
num_images_per_prompt=num_images_per_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds_2,
negative_prompt_embeds=negative_prompt_embeds_2,
prompt_attention_mask=prompt_attention_mask_2,
negative_prompt_attention_mask=negative_prompt_attention_mask_2,
max_sequence_length=256,
text_encoder_index=1,
)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Create image_rotary_emb, style embedding & time ids
grid_height = height // 8 // self.transformer.config.patch_size
grid_width = width // 8 // self.transformer.config.patch_size
base_size = 512 // 8 // self.transformer.config.patch_size
grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size)
image_rotary_emb = get_2d_rotary_pos_embed(
self.transformer.inner_dim // self.transformer.num_heads, grid_crops_coords, (grid_height, grid_width)
)
style = torch.tensor([0], device=device)
target_size = target_size or (height, width)
add_time_ids = list(original_size + target_size + crops_coords_top_left)
add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype)
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2])
add_time_ids = torch.cat([add_time_ids] * 2, dim=0)
style = torch.cat([style] * 2, dim=0)
prompt_embeds = prompt_embeds.to(device=device)
prompt_attention_mask = prompt_attention_mask.to(device=device)
prompt_embeds_2 = prompt_embeds_2.to(device=device)
prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device)
add_time_ids = add_time_ids.to(dtype=prompt_embeds.dtype, device=device).repeat(
batch_size * num_images_per_prompt, 1
)
style = style.to(device=device).repeat(batch_size * num_images_per_prompt)
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to(
dtype=latent_model_input.dtype
)
# predict the noise residual
noise_pred = self.transformer(
latent_model_input,
t_expand,
encoder_hidden_states=prompt_embeds,
text_embedding_mask=prompt_attention_mask,
encoder_hidden_states_t5=prompt_embeds_2,
text_embedding_mask_t5=prompt_attention_mask_2,
image_meta_size=add_time_ids,
style=style,
image_rotary_emb=image_rotary_emb,
return_dict=False,
)[0]
noise_pred, _ = noise_pred.chunk(2, dim=1)
# perform guidance
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
if self.do_classifier_free_guidance and guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
prompt_embeds_2 = callback_outputs.pop("prompt_embeds_2", prompt_embeds_2)
negative_prompt_embeds_2 = callback_outputs.pop(
"negative_prompt_embeds_2", negative_prompt_embeds_2
)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents
has_nsfw_concept = None
if has_nsfw_concept is None:
do_denormalize = [True] * image.shape[0]
else:
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image, has_nsfw_concept)
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
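
Reviewer note: two pieces of arithmetic in `__call__` above are easy to check in isolation. The sketch below is plain Python with scalars and lists standing in for tensors. `latent_grid` and `cfg_combine` are hypothetical helper names, not part of the pipeline, and the patch size of 2 is an assumed value (the pipeline reads it from `self.transformer.config.patch_size`).

```python
def latent_grid(height, width, patch_size=2, vae_factor=8):
    """Mirror the rounding in `__call__`: snap each side down to a multiple
    of 16, then divide by the VAE factor (8) and the transformer patch size."""
    height = (height // 16) * 16
    width = (width // 16) * 16
    return height // vae_factor // patch_size, width // vae_factor // patch_size


def cfg_combine(noise_uncond, noise_text, guidance_scale):
    """Classifier-free guidance on plain lists:
    uncond + guidance_scale * (text - uncond), element by element."""
    return [u + guidance_scale * (t - u) for u, t in zip(noise_uncond, noise_text)]


print(latent_grid(1024, 1024))                    # (64, 64)
print(latent_grid(1000, 1000))                    # (62, 62): 1000 is first rounded down to 992
print(cfg_combine([0.0, 1.0], [1.0, 3.0], 5.0))   # [5.0, 11.0]
```

With `guidance_scale = 1` the combination returns the text-conditioned prediction unchanged, which matches the pipeline enabling guidance only when `guidance_scale > 1`.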

View File

@@ -157,19 +157,19 @@ def compute_dream_and_update_latents(
     with torch.no_grad():
         pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-    noisy_latents, target = (None, None)
+    _noisy_latents, _target = (None, None)
     if noise_scheduler.config.prediction_type == "epsilon":
         predicted_noise = pred
         delta_noise = (noise - predicted_noise).detach()
         delta_noise.mul_(dream_lambda)
-        noisy_latents = noisy_latents.add(sqrt_one_minus_alphas_cumprod * delta_noise)
-        target = target.add(delta_noise)
+        _noisy_latents = noisy_latents.add(sqrt_one_minus_alphas_cumprod * delta_noise)
+        _target = target.add(delta_noise)
     elif noise_scheduler.config.prediction_type == "v_prediction":
         raise NotImplementedError("DREAM has not been implemented for v-prediction")
     else:
         raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
-    return noisy_latents, target
+    return _noisy_latents, _target
 def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
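
Reviewer note: the hunk above appears to fix a name-shadowing bug, where the `(None, None)` placeholders reused the names of the function's inputs and so discarded them before they were read. A minimal sketch of the two patterns follows, with floats standing in for tensors and `+` standing in for `Tensor.add`. Both function names and the `scale` argument are hypothetical. In the real function the pre-fix failure would surface as an `AttributeError` on `None.add` rather than the `TypeError` shown here.

```python
def dream_update_buggy(noisy_latents, target, delta_noise, scale):
    # Pre-fix pattern: the placeholders overwrite the arguments.
    noisy_latents, target = (None, None)
    noisy_latents = noisy_latents + scale * delta_noise  # None + float raises TypeError
    target = target + delta_noise
    return noisy_latents, target


def dream_update_fixed(noisy_latents, target, delta_noise, scale):
    # Post-fix pattern: results go into separate, underscore-prefixed names.
    _noisy_latents, _target = (None, None)
    _noisy_latents = noisy_latents + scale * delta_noise
    _target = target + delta_noise
    return _noisy_latents, _target


try:
    dream_update_buggy(1.0, 2.0, 0.5, 2.0)
except TypeError as err:
    print("pre-fix pattern fails:", err)

print(dream_update_fixed(1.0, 2.0, 0.5, 2.0))  # (2.0, 2.5)
```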

View File

@@ -122,6 +122,21 @@ class DiTTransformer2DModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class HunyuanDiT2DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class I2VGenXLUNet(metaclass=DummyObject):
_backends = ["torch"]

View File

@@ -212,6 +212,21 @@ class CycleDiffusionPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class HunyuanDiTPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class I2VGenXLPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
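
Reviewer note: the dummy classes added in the two hunks above exist so that importing the class name still works when an optional backend is missing, with the error deferred until the class is used. The sketch below shows the general idea in plain Python. It is not the library's `DummyObject` or `requires_backends` implementation; `requires_backends_sketch`, `PlaceholderPipeline`, and the backend name are hypothetical stand-ins.

```python
import importlib.util


def requires_backends_sketch(obj, backends):
    """Raise ImportError if any named backend cannot be found."""
    # Classes carry __name__; instances do not, so fall back to their class.
    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
    missing = [b for b in backends if importlib.util.find_spec(b) is None]
    if missing:
        raise ImportError(f"{name} requires the following backends: {', '.join(missing)}")


class PlaceholderPipeline:
    # A deliberately nonexistent module name, so the check always fails here.
    _backends = ["a_backend_that_is_not_installed"]

    def __init__(self, *args, **kwargs):
        requires_backends_sketch(self, self._backends)

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends_sketch(cls, cls._backends)


try:
    PlaceholderPipeline.from_pretrained("some/checkpoint")
except ImportError as err:
    print(err)
```

Both the constructor and the class-method entry point run the same check, so the placeholder fails with a readable message however it is used.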

View File

@@ -0,0 +1,266 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import tempfile
import unittest
import numpy as np
import torch
from transformers import AutoTokenizer, BertModel, T5EncoderModel
from diffusers import (
AutoencoderKL,
DDPMScheduler,
HunyuanDiT2DModel,
HunyuanDiTPipeline,
)
from diffusers.utils.testing_utils import (
enable_full_determinism,
numpy_cosine_similarity_distance,
require_torch_gpu,
slow,
torch_device,
)
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
enable_full_determinism()
class HunyuanDiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = HunyuanDiTPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
required_optional_params = PipelineTesterMixin.required_optional_params
def get_dummy_components(self):
torch.manual_seed(0)
transformer = HunyuanDiT2DModel(
sample_size=16,
num_layers=2,
patch_size=2,
attention_head_dim=8,
num_attention_heads=3,
in_channels=4,
cross_attention_dim=32,
cross_attention_dim_t5=32,
pooled_projection_dim=16,
hidden_size=24,
activation_fn="gelu-approximate",
)
torch.manual_seed(0)
vae = AutoencoderKL()
scheduler = DDPMScheduler()
text_encoder = BertModel.from_pretrained("hf-internal-testing/tiny-random-BertModel")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BertModel")
text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
components = {
"transformer": transformer.eval(),
"vae": vae.eval(),
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2,
"safety_checker": None,
"feature_extractor": None,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 5.0,
"output_type": "np",
"use_resolution_binning": False,
}
return inputs
def test_inference(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1]
self.assertEqual(image.shape, (1, 16, 16, 3))
expected_slice = np.array(
[0.56939435, 0.34541583, 0.35915792, 0.46489206, 0.38775963, 0.45004836, 0.5957267, 0.59481275, 0.33287364]
)
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
self.assertLessEqual(max_diff, 1e-3)
def test_sequential_cpu_offload_forward_pass(self):
# TODO(YiYi) need to fix later
pass
def test_sequential_offload_forward_pass_twice(self):
# TODO(YiYi) need to fix later
pass
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(
expected_max_diff=1e-3,
)
def test_save_load_optional_components(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
prompt = inputs["prompt"]
generator = inputs["generator"]
num_inference_steps = inputs["num_inference_steps"]
output_type = inputs["output_type"]
(
prompt_embeds,
negative_prompt_embeds,
prompt_attention_mask,
negative_prompt_attention_mask,
) = pipe.encode_prompt(prompt, device=torch_device, dtype=torch.float32, text_encoder_index=0)
(
prompt_embeds_2,
negative_prompt_embeds_2,
prompt_attention_mask_2,
negative_prompt_attention_mask_2,
) = pipe.encode_prompt(
prompt,
device=torch_device,
dtype=torch.float32,
text_encoder_index=1,
)
# inputs with prompt converted to embeddings
inputs = {
"prompt_embeds": prompt_embeds,
"prompt_attention_mask": prompt_attention_mask,
"negative_prompt_embeds": negative_prompt_embeds,
"negative_prompt_attention_mask": negative_prompt_attention_mask,
"prompt_embeds_2": prompt_embeds_2,
"prompt_attention_mask_2": prompt_attention_mask_2,
"negative_prompt_embeds_2": negative_prompt_embeds_2,
"negative_prompt_attention_mask_2": negative_prompt_attention_mask_2,
"generator": generator,
"num_inference_steps": num_inference_steps,
"output_type": output_type,
"use_resolution_binning": False,
}
# set all optional components to None
for optional_component in pipe._optional_components:
setattr(pipe, optional_component, None)
output = pipe(**inputs)[0]
with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir)
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
pipe_loaded.to(torch_device)
pipe_loaded.set_progress_bar_config(disable=None)
for optional_component in pipe._optional_components:
self.assertTrue(
getattr(pipe_loaded, optional_component) is None,
f"`{optional_component}` did not stay set to None after loading.",
)
inputs = self.get_dummy_inputs(torch_device)
generator = inputs["generator"]
num_inference_steps = inputs["num_inference_steps"]
output_type = inputs["output_type"]
# inputs with prompt converted to embeddings
inputs = {
"prompt_embeds": prompt_embeds,
"prompt_attention_mask": prompt_attention_mask,
"negative_prompt_embeds": negative_prompt_embeds,
"negative_prompt_attention_mask": negative_prompt_attention_mask,
"prompt_embeds_2": prompt_embeds_2,
"prompt_attention_mask_2": prompt_attention_mask_2,
"negative_prompt_embeds_2": negative_prompt_embeds_2,
"negative_prompt_attention_mask_2": negative_prompt_attention_mask_2,
"generator": generator,
"num_inference_steps": num_inference_steps,
"output_type": output_type,
"use_resolution_binning": False,
}
output_loaded = pipe_loaded(**inputs)[0]
max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
self.assertLess(max_diff, 1e-4)
@slow
@require_torch_gpu
class HunyuanDiTPipelineIntegrationTests(unittest.TestCase):
prompt = "一个宇航员在骑马"
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_hunyuan_dit_1024(self):
generator = torch.Generator("cpu").manual_seed(0)
pipe = HunyuanDiTPipeline.from_pretrained(
"XCLiu/HunyuanDiT-0523", revision="refs/pr/2", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()
prompt = self.prompt
image = pipe(
prompt=prompt, height=1024, width=1024, generator=generator, num_inference_steps=2, output_type="np"
).images
image_slice = image[0, -3:, -3:, -1]
expected_slice = np.array(
[0.48388672, 0.33789062, 0.30737305, 0.47875977, 0.25097656, 0.30029297, 0.4440918, 0.26953125, 0.30078125]
)
max_diff = numpy_cosine_similarity_distance(image_slice.flatten(), expected_slice)
assert max_diff < 1e-3, f"Max diff is too high. got {image_slice.flatten()}"
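
Reviewer note: the fast test compares image slices by maximum absolute difference, while the slow test uses `numpy_cosine_similarity_distance`. The sketch below contrasts the two metrics in plain Python. It assumes the helper computes one minus the cosine similarity of the flattened arrays, which is not shown in this diff. `max_abs_diff` and `cosine_distance` are hypothetical names.

```python
import math


def max_abs_diff(a, b):
    """Largest element-wise absolute difference."""
    return max(abs(x - y) for x, y in zip(a, b))


def cosine_distance(a, b):
    """One minus cosine similarity: sensitive to direction, not to scale."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / (norm_a * norm_b)


reference = [0.5, 0.25, 0.75]
brighter = [2 * v for v in reference]  # same pattern, doubled intensity

print(max_abs_diff(reference, brighter))     # 0.75
print(cosine_distance(reference, brighter))  # approximately 0
```

Under that assumption the slow test tolerates a uniform scaling of the pixel values, which the element-wise check in the fast test would reject.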