Mirror of https://github.com/huggingface/diffusers.git, synced 2025-12-25 05:44:52 +08:00

Compare commits: 4 commits (torchao-co ... main)

| Author | SHA1 | Date |
|---|---|---|
|  | f6b6a7181e |  |
|  | 52766e6a69 |  |
|  | 973a077c6a |  |
|  | 0c4f6c9cff |  |

@@ -21,8 +21,8 @@ from transformers import (
    BertModel,
    BertTokenizer,
    CLIPImageProcessor,
    MT5Tokenizer,
    T5EncoderModel,
    T5Tokenizer,
)

from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
@@ -260,7 +260,7 @@ class HunyuanDiTDifferentialImg2ImgPipeline(DiffusionPipeline):
            The HunyuanDiT model designed by Tencent Hunyuan.
        text_encoder_2 (`T5EncoderModel`):
            The mT5 embedder. Specifically, it is 't5-v1_1-xxl'.
        tokenizer_2 (`MT5Tokenizer`):
        tokenizer_2 (`T5Tokenizer`):
            The tokenizer for the mT5 embedder.
        scheduler ([`DDPMScheduler`]):
            A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents.
@@ -295,7 +295,7 @@ class HunyuanDiTDifferentialImg2ImgPipeline(DiffusionPipeline):
        feature_extractor: CLIPImageProcessor,
        requires_safety_checker: bool = True,
        text_encoder_2=T5EncoderModel,
        tokenizer_2=MT5Tokenizer,
        tokenizer_2=T5Tokenizer,
    ):
        super().__init__()

@@ -29,13 +29,52 @@ hf download nvidia/Cosmos-Predict2.5-2B

Convert checkpoint

```bash
# pre-trained
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-2B/snapshots/865baf084d4c9e850eac59a021277d5a9b9e8b63/base/pre-trained/d20b7120-df3e-4911-919d-db6e08bad31c_ema_bf16.pt

python scripts/convert_cosmos_to_diffusers.py \
  --transformer_type Cosmos-2.5-Predict-Base-2B \
  --transformer_ckpt_path $transformer_ckpt_path \
  --vae_type wan2.1 \
  --output_path converted/cosmos-p2.5-base-2b \
  --output_path converted/2b/d20b7120-df3e-4911-919d-db6e08bad31c \
  --save_pipeline

# post-trained
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-2B/snapshots/865baf084d4c9e850eac59a021277d5a9b9e8b63/base/post-trained/81edfebe-bd6a-4039-8c1d-737df1a790bf_ema_bf16.pt

python scripts/convert_cosmos_to_diffusers.py \
  --transformer_type Cosmos-2.5-Predict-Base-2B \
  --transformer_ckpt_path $transformer_ckpt_path \
  --vae_type wan2.1 \
  --output_path converted/2b/81edfebe-bd6a-4039-8c1d-737df1a790bf \
  --save_pipeline
```

## 14B

```bash
hf download nvidia/Cosmos-Predict2.5-14B
```

```bash
# pre-trained
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-14B/snapshots/71ebf3e8af30ecfe440bf0481115975fcc052b46/base/pre-trained/54937b8c-29de-4f04-862c-e67b04ec41e8_ema_bf16.pt

python scripts/convert_cosmos_to_diffusers.py \
  --transformer_type Cosmos-2.5-Predict-Base-14B \
  --transformer_ckpt_path $transformer_ckpt_path \
  --vae_type wan2.1 \
  --output_path converted/14b/54937b8c-29de-4f04-862c-e67b04ec41e8/ \
  --save_pipeline

# post-trained
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-14B/snapshots/71ebf3e8af30ecfe440bf0481115975fcc052b46/base/post-trained/e21d2a49-4747-44c8-ba44-9f6f9243715f_ema_bf16.pt

python scripts/convert_cosmos_to_diffusers.py \
  --transformer_type Cosmos-2.5-Predict-Base-14B \
  --transformer_ckpt_path $transformer_ckpt_path \
  --vae_type wan2.1 \
  --output_path converted/14b/e21d2a49-4747-44c8-ba44-9f6f9243715f/ \
  --save_pipeline
```
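After conversion, the directory written by `--save_pipeline` should be loadable like any other diffusers pipeline. A minimal sketch, assuming the 14B post-trained output path from the command above and bf16 weights (neither is mandated by the script):

```python
import torch
from diffusers import DiffusionPipeline

# Assumption: the local path mirrors the --output_path used in the conversion command above.
pipe = DiffusionPipeline.from_pretrained(
    "converted/14b/e21d2a49-4747-44c8-ba44-9f6f9243715f",
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")
```
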
@@ -298,6 +337,25 @@ TRANSFORMER_CONFIGS = {
        "crossattn_proj_in_channels": 100352,
        "encoder_hidden_states_channels": 1024,
    },
    "Cosmos-2.5-Predict-Base-14B": {
        "in_channels": 16 + 1,
        "out_channels": 16,
        "num_attention_heads": 40,
        "attention_head_dim": 128,
        "num_layers": 36,
        "mlp_ratio": 4.0,
        "text_embed_dim": 1024,
        "adaln_lora_dim": 256,
        "max_size": (128, 240, 240),
        "patch_size": (1, 2, 2),
        "rope_scale": (1.0, 3.0, 3.0),
        "concat_padding_mask": True,
        # NOTE: source config has pos_emb_learnable: 'True' - but params are missing
        "extra_pos_embed_type": None,
        "use_crossattn_projection": True,
        "crossattn_proj_in_channels": 100352,
        "encoder_hidden_states_channels": 1024,
    },
}
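As a quick sanity check on the 14B entry above: the implied hidden size is `num_attention_heads * attention_head_dim = 40 * 128 = 5120`. A throwaway snippet like the following (illustrative only, not part of the conversion script) makes such invariants explicit before running a conversion:

```python
# Illustrative check of the 14B config shown above; the dict literal is copied
# from TRANSFORMER_CONFIGS, and the assertions are this example's own assumptions.
config_14b = {
    "num_attention_heads": 40,
    "attention_head_dim": 128,
    "num_layers": 36,
    "mlp_ratio": 4.0,
    "patch_size": (1, 2, 2),
}

hidden_size = config_14b["num_attention_heads"] * config_14b["attention_head_dim"]
assert hidden_size == 5120
assert all(p > 0 for p in config_14b["patch_size"])
print(f"hidden_size={hidden_size}, mlp_hidden={int(hidden_size * config_14b['mlp_ratio'])}")
```
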

VAE_KEYS_RENAME_DICT = {

@@ -675,6 +675,7 @@ else:
        "ZImageControlNetInpaintPipeline",
        "ZImageControlNetPipeline",
        "ZImageImg2ImgPipeline",
        "ZImageOmniPipeline",
        "ZImagePipeline",
    ]
)
@@ -1386,6 +1387,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        ZImageControlNetInpaintPipeline,
        ZImageControlNetPipeline,
        ZImageImg2ImgPipeline,
        ZImageOmniPipeline,
        ZImagePipeline,
    )

@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import List, Literal, Optional
|
||||
from typing import List, Literal, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -170,6 +170,21 @@ class FeedForward(nn.Module):
|
||||
return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
|
||||
|
||||
|
||||
# Copied from diffusers.models.transformers.transformer_z_image.select_per_token
|
||||
def select_per_token(
|
||||
value_noisy: torch.Tensor,
|
||||
value_clean: torch.Tensor,
|
||||
noise_mask: torch.Tensor,
|
||||
seq_len: int,
|
||||
) -> torch.Tensor:
|
||||
noise_mask_expanded = noise_mask.unsqueeze(-1) # (batch, seq_len, 1)
|
||||
return torch.where(
|
||||
noise_mask_expanded == 1,
|
||||
value_noisy.unsqueeze(1).expand(-1, seq_len, -1),
|
||||
value_clean.unsqueeze(1).expand(-1, seq_len, -1),
|
||||
)
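A toy run of `select_per_token` (shapes follow the `(batch, seq_len, 1)` comment above; the numbers are invented): each token receives the noisy modulation vector where its mask is 1 and the clean one where it is 0.

```python
import torch

# Toy shapes: batch=1, seq_len=4, dim=2.
noisy = torch.tensor([[10.0, 10.0]])       # (batch, dim) modulation for noisy tokens
clean = torch.tensor([[-1.0, -1.0]])       # (batch, dim) modulation for clean tokens
noise_mask = torch.tensor([[1, 1, 0, 0]])  # (batch, seq_len): first two tokens are noisy

out = torch.where(
    noise_mask.unsqueeze(-1) == 1,
    noisy.unsqueeze(1).expand(-1, 4, -1),
    clean.unsqueeze(1).expand(-1, 4, -1),
)
print(out.shape)  # torch.Size([1, 4, 2]); rows: [10, 10], [10, 10], [-1, -1], [-1, -1]
```
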
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
# Copied from diffusers.models.transformers.transformer_z_image.ZImageTransformerBlock
|
||||
class ZImageTransformerBlock(nn.Module):
|
||||
@@ -220,12 +235,37 @@ class ZImageTransformerBlock(nn.Module):
|
||||
attn_mask: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
adaln_input: Optional[torch.Tensor] = None,
|
||||
noise_mask: Optional[torch.Tensor] = None,
|
||||
adaln_noisy: Optional[torch.Tensor] = None,
|
||||
adaln_clean: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if self.modulation:
|
||||
assert adaln_input is not None
|
||||
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2)
|
||||
gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
|
||||
scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
|
||||
seq_len = x.shape[1]
|
||||
|
||||
if noise_mask is not None:
|
||||
# Per-token modulation: different modulation for noisy/clean tokens
|
||||
mod_noisy = self.adaLN_modulation(adaln_noisy)
|
||||
mod_clean = self.adaLN_modulation(adaln_clean)
|
||||
|
||||
scale_msa_noisy, gate_msa_noisy, scale_mlp_noisy, gate_mlp_noisy = mod_noisy.chunk(4, dim=1)
|
||||
scale_msa_clean, gate_msa_clean, scale_mlp_clean, gate_mlp_clean = mod_clean.chunk(4, dim=1)
|
||||
|
||||
gate_msa_noisy, gate_mlp_noisy = gate_msa_noisy.tanh(), gate_mlp_noisy.tanh()
|
||||
gate_msa_clean, gate_mlp_clean = gate_msa_clean.tanh(), gate_mlp_clean.tanh()
|
||||
|
||||
scale_msa_noisy, scale_mlp_noisy = 1.0 + scale_msa_noisy, 1.0 + scale_mlp_noisy
|
||||
scale_msa_clean, scale_mlp_clean = 1.0 + scale_msa_clean, 1.0 + scale_mlp_clean
|
||||
|
||||
scale_msa = select_per_token(scale_msa_noisy, scale_msa_clean, noise_mask, seq_len)
|
||||
scale_mlp = select_per_token(scale_mlp_noisy, scale_mlp_clean, noise_mask, seq_len)
|
||||
gate_msa = select_per_token(gate_msa_noisy, gate_msa_clean, noise_mask, seq_len)
|
||||
gate_mlp = select_per_token(gate_mlp_noisy, gate_mlp_clean, noise_mask, seq_len)
|
||||
else:
|
||||
# Global modulation: same modulation for all tokens (avoid double select)
|
||||
mod = self.adaLN_modulation(adaln_input)
|
||||
scale_msa, gate_msa, scale_mlp, gate_mlp = mod.unsqueeze(1).chunk(4, dim=2)
|
||||
gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
|
||||
scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
|
||||
|
||||
# Attention block
|
||||
attn_out = self.attention(
|
||||
@@ -493,112 +533,93 @@ class ZImageControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
|
||||
def create_coordinate_grid(size, start=None, device=None):
|
||||
if start is None:
|
||||
start = (0 for _ in size)
|
||||
|
||||
axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)]
|
||||
grids = torch.meshgrid(axes, indexing="ij")
|
||||
return torch.stack(grids, dim=-1)
|
||||
|
||||
# Copied from diffusers.models.transformers.transformer_z_image.ZImageTransformer2DModel._patchify_image
|
||||
def _patchify_image(self, image: torch.Tensor, patch_size: int, f_patch_size: int):
|
||||
"""Patchify a single image tensor: (C, F, H, W) -> (num_patches, patch_dim)."""
|
||||
pH, pW, pF = patch_size, patch_size, f_patch_size
|
||||
C, F, H, W = image.size()
|
||||
F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
|
||||
image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
|
||||
image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
|
||||
return image, (F, H, W), (F_tokens, H_tokens, W_tokens)
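To make the `_patchify_image` rearrangement concrete, here is the same reshape on invented sizes: a `(C=16, F=1, H=64, W=64)` latent with `patch_size=2`, `f_patch_size=1` yields `32*32 = 1024` tokens of dimension `2*2*16 = 64`.

```python
import torch

# Illustrative only: sizes chosen for the example, not taken from any checkpoint.
C, F, H, W = 16, 1, 64, 64
pH = pW = 2   # patch_size
pF = 1        # f_patch_size

image = torch.randn(C, F, H, W)
F_t, H_t, W_t = F // pF, H // pH, W // pW
patches = (
    image.view(C, F_t, pF, H_t, pH, W_t, pW)
    .permute(1, 3, 5, 2, 4, 6, 0)  # "c f pf h ph w pw -> f h w pf ph pw c"
    .reshape(F_t * H_t * W_t, pF * pH * pW * C)
)
print(patches.shape)  # torch.Size([1024, 64])
```
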
|
||||
|
||||
# Copied from diffusers.models.transformers.transformer_z_image.ZImageTransformer2DModel._pad_with_ids
|
||||
def _pad_with_ids(
|
||||
self,
|
||||
feat: torch.Tensor,
|
||||
pos_grid_size: Tuple,
|
||||
pos_start: Tuple,
|
||||
device: torch.device,
|
||||
noise_mask_val: Optional[int] = None,
|
||||
):
|
||||
"""Pad feature to SEQ_MULTI_OF, create position IDs and pad mask."""
|
||||
ori_len = len(feat)
|
||||
pad_len = (-ori_len) % SEQ_MULTI_OF
|
||||
total_len = ori_len + pad_len
|
||||
|
||||
# Pos IDs
|
||||
ori_pos_ids = self.create_coordinate_grid(size=pos_grid_size, start=pos_start, device=device).flatten(0, 2)
|
||||
if pad_len > 0:
|
||||
pad_pos_ids = (
|
||||
self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device)
|
||||
.flatten(0, 2)
|
||||
.repeat(pad_len, 1)
|
||||
)
|
||||
pos_ids = torch.cat([ori_pos_ids, pad_pos_ids], dim=0)
|
||||
padded_feat = torch.cat([feat, feat[-1:].repeat(pad_len, 1)], dim=0)
|
||||
pad_mask = torch.cat(
|
||||
[
|
||||
torch.zeros(ori_len, dtype=torch.bool, device=device),
|
||||
torch.ones(pad_len, dtype=torch.bool, device=device),
|
||||
]
|
||||
)
|
||||
else:
|
||||
pos_ids = ori_pos_ids
|
||||
padded_feat = feat
|
||||
pad_mask = torch.zeros(ori_len, dtype=torch.bool, device=device)
|
||||
|
||||
noise_mask = [noise_mask_val] * total_len if noise_mask_val is not None else None # token level
|
||||
return padded_feat, pos_ids, pad_mask, total_len, noise_mask
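The padding arithmetic in `_pad_with_ids` rounds each sequence length up to the next multiple of `SEQ_MULTI_OF` (defined as 32 in `transformer_z_image.py` below); a few invented lengths show the effect:

```python
SEQ_MULTI_OF = 32

for ori_len in (1, 31, 32, 77, 100):
    pad_len = (-ori_len) % SEQ_MULTI_OF  # number of pad tokens appended
    total_len = ori_len + pad_len
    assert total_len % SEQ_MULTI_OF == 0
    print(ori_len, "->", total_len)      # 1->32, 31->32, 32->32, 77->96, 100->128
```
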
|
||||
|
||||
# Copied from diffusers.models.transformers.transformer_z_image.ZImageTransformer2DModel.patchify_and_embed
|
||||
def patchify_and_embed(
|
||||
self,
|
||||
all_image: List[torch.Tensor],
|
||||
all_cap_feats: List[torch.Tensor],
|
||||
patch_size: int,
|
||||
f_patch_size: int,
|
||||
self, all_image: List[torch.Tensor], all_cap_feats: List[torch.Tensor], patch_size: int, f_patch_size: int
|
||||
):
|
||||
pH = pW = patch_size
|
||||
pF = f_patch_size
|
||||
"""Patchify for basic mode: single image per batch item."""
|
||||
device = all_image[0].device
|
||||
all_img_out, all_img_size, all_img_pos_ids, all_img_pad_mask = [], [], [], []
|
||||
all_cap_out, all_cap_pos_ids, all_cap_pad_mask = [], [], []
|
||||
|
||||
all_image_out = []
|
||||
all_image_size = []
|
||||
all_image_pos_ids = []
|
||||
all_image_pad_mask = []
|
||||
all_cap_pos_ids = []
|
||||
all_cap_pad_mask = []
|
||||
all_cap_feats_out = []
|
||||
|
||||
for i, (image, cap_feat) in enumerate(zip(all_image, all_cap_feats)):
|
||||
### Process Caption
|
||||
cap_ori_len = len(cap_feat)
|
||||
cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF
|
||||
# padded position ids
|
||||
cap_padded_pos_ids = self.create_coordinate_grid(
|
||||
size=(cap_ori_len + cap_padding_len, 1, 1),
|
||||
start=(1, 0, 0),
|
||||
device=device,
|
||||
).flatten(0, 2)
|
||||
all_cap_pos_ids.append(cap_padded_pos_ids)
|
||||
# pad mask
|
||||
cap_pad_mask = torch.cat(
|
||||
[
|
||||
torch.zeros((cap_ori_len,), dtype=torch.bool, device=device),
|
||||
torch.ones((cap_padding_len,), dtype=torch.bool, device=device),
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
all_cap_pad_mask.append(
|
||||
cap_pad_mask if cap_padding_len > 0 else torch.zeros((cap_ori_len,), dtype=torch.bool, device=device)
|
||||
for image, cap_feat in zip(all_image, all_cap_feats):
|
||||
# Caption
|
||||
cap_out, cap_pos_ids, cap_pad_mask, cap_len, _ = self._pad_with_ids(
|
||||
cap_feat, (len(cap_feat) + (-len(cap_feat)) % SEQ_MULTI_OF, 1, 1), (1, 0, 0), device
|
||||
)
|
||||
all_cap_out.append(cap_out)
|
||||
all_cap_pos_ids.append(cap_pos_ids)
|
||||
all_cap_pad_mask.append(cap_pad_mask)
|
||||
|
||||
# padded feature
|
||||
cap_padded_feat = torch.cat([cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)], dim=0)
|
||||
all_cap_feats_out.append(cap_padded_feat)
|
||||
|
||||
### Process Image
|
||||
C, F, H, W = image.size()
|
||||
all_image_size.append((F, H, W))
|
||||
F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
|
||||
|
||||
image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
|
||||
# "c f pf h ph w pw -> (f h w) (pf ph pw c)"
|
||||
image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
|
||||
|
||||
image_ori_len = len(image)
|
||||
image_padding_len = (-image_ori_len) % SEQ_MULTI_OF
|
||||
|
||||
image_ori_pos_ids = self.create_coordinate_grid(
|
||||
size=(F_tokens, H_tokens, W_tokens),
|
||||
start=(cap_ori_len + cap_padding_len + 1, 0, 0),
|
||||
device=device,
|
||||
).flatten(0, 2)
|
||||
image_padded_pos_ids = torch.cat(
|
||||
[
|
||||
image_ori_pos_ids,
|
||||
self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device)
|
||||
.flatten(0, 2)
|
||||
.repeat(image_padding_len, 1),
|
||||
],
|
||||
dim=0,
|
||||
# Image
|
||||
img_patches, size, (F_t, H_t, W_t) = self._patchify_image(image, patch_size, f_patch_size)
|
||||
img_out, img_pos_ids, img_pad_mask, _, _ = self._pad_with_ids(
|
||||
img_patches, (F_t, H_t, W_t), (cap_len + 1, 0, 0), device
|
||||
)
|
||||
all_image_pos_ids.append(image_padded_pos_ids if image_padding_len > 0 else image_ori_pos_ids)
|
||||
# pad mask
|
||||
image_pad_mask = torch.cat(
|
||||
[
|
||||
torch.zeros((image_ori_len,), dtype=torch.bool, device=device),
|
||||
torch.ones((image_padding_len,), dtype=torch.bool, device=device),
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
all_image_pad_mask.append(
|
||||
image_pad_mask
|
||||
if image_padding_len > 0
|
||||
else torch.zeros((image_ori_len,), dtype=torch.bool, device=device)
|
||||
)
|
||||
# padded feature
|
||||
image_padded_feat = torch.cat(
|
||||
[image, image[-1:].repeat(image_padding_len, 1)],
|
||||
dim=0,
|
||||
)
|
||||
all_image_out.append(image_padded_feat if image_padding_len > 0 else image)
|
||||
all_img_out.append(img_out)
|
||||
all_img_size.append(size)
|
||||
all_img_pos_ids.append(img_pos_ids)
|
||||
all_img_pad_mask.append(img_pad_mask)
|
||||
|
||||
return (
|
||||
all_image_out,
|
||||
all_cap_feats_out,
|
||||
all_image_size,
|
||||
all_image_pos_ids,
|
||||
all_img_out,
|
||||
all_cap_out,
|
||||
all_img_size,
|
||||
all_img_pos_ids,
|
||||
all_cap_pos_ids,
|
||||
all_image_pad_mask,
|
||||
all_img_pad_mask,
|
||||
all_cap_pad_mask,
|
||||
)
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -32,6 +32,7 @@ from ..modeling_outputs import Transformer2DModelOutput
|
||||
|
||||
ADALN_EMBED_DIM = 256
|
||||
SEQ_MULTI_OF = 32
|
||||
X_PAD_DIM = 64
|
||||
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
@@ -152,6 +153,20 @@ class ZSingleStreamAttnProcessor:
|
||||
return output
|
||||
|
||||
|
||||
def select_per_token(
|
||||
value_noisy: torch.Tensor,
|
||||
value_clean: torch.Tensor,
|
||||
noise_mask: torch.Tensor,
|
||||
seq_len: int,
|
||||
) -> torch.Tensor:
|
||||
noise_mask_expanded = noise_mask.unsqueeze(-1) # (batch, seq_len, 1)
|
||||
return torch.where(
|
||||
noise_mask_expanded == 1,
|
||||
value_noisy.unsqueeze(1).expand(-1, seq_len, -1),
|
||||
value_clean.unsqueeze(1).expand(-1, seq_len, -1),
|
||||
)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim: int, hidden_dim: int):
|
||||
super().__init__()
|
||||
@@ -215,12 +230,37 @@ class ZImageTransformerBlock(nn.Module):
|
||||
attn_mask: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
adaln_input: Optional[torch.Tensor] = None,
|
||||
noise_mask: Optional[torch.Tensor] = None,
|
||||
adaln_noisy: Optional[torch.Tensor] = None,
|
||||
adaln_clean: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if self.modulation:
|
||||
assert adaln_input is not None
|
||||
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2)
|
||||
gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
|
||||
scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
|
||||
seq_len = x.shape[1]
|
||||
|
||||
if noise_mask is not None:
|
||||
# Per-token modulation: different modulation for noisy/clean tokens
|
||||
mod_noisy = self.adaLN_modulation(adaln_noisy)
|
||||
mod_clean = self.adaLN_modulation(adaln_clean)
|
||||
|
||||
scale_msa_noisy, gate_msa_noisy, scale_mlp_noisy, gate_mlp_noisy = mod_noisy.chunk(4, dim=1)
|
||||
scale_msa_clean, gate_msa_clean, scale_mlp_clean, gate_mlp_clean = mod_clean.chunk(4, dim=1)
|
||||
|
||||
gate_msa_noisy, gate_mlp_noisy = gate_msa_noisy.tanh(), gate_mlp_noisy.tanh()
|
||||
gate_msa_clean, gate_mlp_clean = gate_msa_clean.tanh(), gate_mlp_clean.tanh()
|
||||
|
||||
scale_msa_noisy, scale_mlp_noisy = 1.0 + scale_msa_noisy, 1.0 + scale_mlp_noisy
|
||||
scale_msa_clean, scale_mlp_clean = 1.0 + scale_msa_clean, 1.0 + scale_mlp_clean
|
||||
|
||||
scale_msa = select_per_token(scale_msa_noisy, scale_msa_clean, noise_mask, seq_len)
|
||||
scale_mlp = select_per_token(scale_mlp_noisy, scale_mlp_clean, noise_mask, seq_len)
|
||||
gate_msa = select_per_token(gate_msa_noisy, gate_msa_clean, noise_mask, seq_len)
|
||||
gate_mlp = select_per_token(gate_mlp_noisy, gate_mlp_clean, noise_mask, seq_len)
|
||||
else:
|
||||
# Global modulation: same modulation for all tokens (avoid double select)
|
||||
mod = self.adaLN_modulation(adaln_input)
|
||||
scale_msa, gate_msa, scale_mlp, gate_mlp = mod.unsqueeze(1).chunk(4, dim=2)
|
||||
gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
|
||||
scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
|
||||
|
||||
# Attention block
|
||||
attn_out = self.attention(
|
||||
@@ -252,9 +292,21 @@ class FinalLayer(nn.Module):
|
||||
nn.Linear(min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True),
|
||||
)
|
||||
|
||||
def forward(self, x, c):
|
||||
scale = 1.0 + self.adaLN_modulation(c)
|
||||
x = self.norm_final(x) * scale.unsqueeze(1)
|
||||
def forward(self, x, c=None, noise_mask=None, c_noisy=None, c_clean=None):
|
||||
seq_len = x.shape[1]
|
||||
|
||||
if noise_mask is not None:
|
||||
# Per-token modulation
|
||||
scale_noisy = 1.0 + self.adaLN_modulation(c_noisy)
|
||||
scale_clean = 1.0 + self.adaLN_modulation(c_clean)
|
||||
scale = select_per_token(scale_noisy, scale_clean, noise_mask, seq_len)
|
||||
else:
|
||||
# Original global modulation
|
||||
assert c is not None, "Either c or (c_noisy, c_clean) must be provided"
|
||||
scale = 1.0 + self.adaLN_modulation(c)
|
||||
scale = scale.unsqueeze(1)
|
||||
|
||||
x = self.norm_final(x) * scale
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
@@ -325,6 +377,7 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
|
||||
norm_eps=1e-5,
|
||||
qk_norm=True,
|
||||
cap_feat_dim=2560,
|
||||
siglip_feat_dim=None, # Optional: set to enable SigLIP support for Omni
|
||||
rope_theta=256.0,
|
||||
t_scale=1000.0,
|
||||
axes_dims=[32, 48, 48],
|
||||
@@ -386,6 +439,31 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
|
||||
self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024)
|
||||
self.cap_embedder = nn.Sequential(RMSNorm(cap_feat_dim, eps=norm_eps), nn.Linear(cap_feat_dim, dim, bias=True))
|
||||
|
||||
# Optional SigLIP components (for Omni variant)
|
||||
if siglip_feat_dim is not None:
|
||||
self.siglip_embedder = nn.Sequential(
|
||||
RMSNorm(siglip_feat_dim, eps=norm_eps), nn.Linear(siglip_feat_dim, dim, bias=True)
|
||||
)
|
||||
self.siglip_refiner = nn.ModuleList(
|
||||
[
|
||||
ZImageTransformerBlock(
|
||||
2000 + layer_id,
|
||||
dim,
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
modulation=False,
|
||||
)
|
||||
for layer_id in range(n_refiner_layers)
|
||||
]
|
||||
)
|
||||
self.siglip_pad_token = nn.Parameter(torch.empty((1, dim)))
|
||||
else:
|
||||
self.siglip_embedder = None
|
||||
self.siglip_refiner = None
|
||||
self.siglip_pad_token = None
|
||||
|
||||
self.x_pad_token = nn.Parameter(torch.empty((1, dim)))
|
||||
self.cap_pad_token = nn.Parameter(torch.empty((1, dim)))
|
||||
|
||||
@@ -402,259 +480,561 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
|
||||
|
||||
self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens)
|
||||
|
||||
def unpatchify(self, x: List[torch.Tensor], size: List[Tuple], patch_size, f_patch_size) -> List[torch.Tensor]:
|
||||
def unpatchify(
|
||||
self,
|
||||
x: List[torch.Tensor],
|
||||
size: List[Tuple],
|
||||
patch_size,
|
||||
f_patch_size,
|
||||
x_pos_offsets: Optional[List[Tuple[int, int]]] = None,
|
||||
) -> List[torch.Tensor]:
|
||||
pH = pW = patch_size
|
||||
pF = f_patch_size
|
||||
bsz = len(x)
|
||||
assert len(size) == bsz
|
||||
for i in range(bsz):
|
||||
F, H, W = size[i]
|
||||
ori_len = (F // pF) * (H // pH) * (W // pW)
|
||||
# "f h w pf ph pw c -> c (f pf) (h ph) (w pw)"
|
||||
x[i] = (
|
||||
x[i][:ori_len]
|
||||
.view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels)
|
||||
.permute(6, 0, 3, 1, 4, 2, 5)
|
||||
.reshape(self.out_channels, F, H, W)
|
||||
)
|
||||
return x
|
||||
|
||||
if x_pos_offsets is not None:
|
||||
# Omni: extract target image from unified sequence (cond_images + target)
|
||||
result = []
|
||||
for i in range(bsz):
|
||||
unified_x = x[i][x_pos_offsets[i][0] : x_pos_offsets[i][1]]
|
||||
cu_len = 0
|
||||
x_item = None
|
||||
for j in range(len(size[i])):
|
||||
if size[i][j] is None:
|
||||
ori_len = 0
|
||||
pad_len = SEQ_MULTI_OF
|
||||
cu_len += pad_len + ori_len
|
||||
else:
|
||||
F, H, W = size[i][j]
|
||||
ori_len = (F // pF) * (H // pH) * (W // pW)
|
||||
pad_len = (-ori_len) % SEQ_MULTI_OF
|
||||
x_item = (
|
||||
unified_x[cu_len : cu_len + ori_len]
|
||||
.view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels)
|
||||
.permute(6, 0, 3, 1, 4, 2, 5)
|
||||
.reshape(self.out_channels, F, H, W)
|
||||
)
|
||||
cu_len += ori_len + pad_len
|
||||
result.append(x_item) # Return only the last (target) image
|
||||
return result
|
||||
else:
|
||||
# Original mode: simple unpatchify
|
||||
for i in range(bsz):
|
||||
F, H, W = size[i]
|
||||
ori_len = (F // pF) * (H // pH) * (W // pW)
|
||||
# "f h w pf ph pw c -> c (f pf) (h ph) (w pw)"
|
||||
x[i] = (
|
||||
x[i][:ori_len]
|
||||
.view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels)
|
||||
.permute(6, 0, 3, 1, 4, 2, 5)
|
||||
.reshape(self.out_channels, F, H, W)
|
||||
)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def create_coordinate_grid(size, start=None, device=None):
|
||||
if start is None:
|
||||
start = (0 for _ in size)
|
||||
|
||||
axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)]
|
||||
grids = torch.meshgrid(axes, indexing="ij")
|
||||
return torch.stack(grids, dim=-1)
|
||||
|
||||
def patchify_and_embed(
|
||||
def _patchify_image(self, image: torch.Tensor, patch_size: int, f_patch_size: int):
|
||||
"""Patchify a single image tensor: (C, F, H, W) -> (num_patches, patch_dim)."""
|
||||
pH, pW, pF = patch_size, patch_size, f_patch_size
|
||||
C, F, H, W = image.size()
|
||||
F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
|
||||
image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
|
||||
image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
|
||||
return image, (F, H, W), (F_tokens, H_tokens, W_tokens)
|
||||
|
||||
def _pad_with_ids(
|
||||
self,
|
||||
all_image: List[torch.Tensor],
|
||||
all_cap_feats: List[torch.Tensor],
|
||||
patch_size: int,
|
||||
f_patch_size: int,
|
||||
feat: torch.Tensor,
|
||||
pos_grid_size: Tuple,
|
||||
pos_start: Tuple,
|
||||
device: torch.device,
|
||||
noise_mask_val: Optional[int] = None,
|
||||
):
|
||||
pH = pW = patch_size
|
||||
pF = f_patch_size
|
||||
"""Pad feature to SEQ_MULTI_OF, create position IDs and pad mask."""
|
||||
ori_len = len(feat)
|
||||
pad_len = (-ori_len) % SEQ_MULTI_OF
|
||||
total_len = ori_len + pad_len
|
||||
|
||||
# Pos IDs
|
||||
ori_pos_ids = self.create_coordinate_grid(size=pos_grid_size, start=pos_start, device=device).flatten(0, 2)
|
||||
if pad_len > 0:
|
||||
pad_pos_ids = (
|
||||
self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device)
|
||||
.flatten(0, 2)
|
||||
.repeat(pad_len, 1)
|
||||
)
|
||||
pos_ids = torch.cat([ori_pos_ids, pad_pos_ids], dim=0)
|
||||
padded_feat = torch.cat([feat, feat[-1:].repeat(pad_len, 1)], dim=0)
|
||||
pad_mask = torch.cat(
|
||||
[
|
||||
torch.zeros(ori_len, dtype=torch.bool, device=device),
|
||||
torch.ones(pad_len, dtype=torch.bool, device=device),
|
||||
]
|
||||
)
|
||||
else:
|
||||
pos_ids = ori_pos_ids
|
||||
padded_feat = feat
|
||||
pad_mask = torch.zeros(ori_len, dtype=torch.bool, device=device)
|
||||
|
||||
noise_mask = [noise_mask_val] * total_len if noise_mask_val is not None else None # token level
|
||||
return padded_feat, pos_ids, pad_mask, total_len, noise_mask
|
||||
|
||||
def patchify_and_embed(
|
||||
self, all_image: List[torch.Tensor], all_cap_feats: List[torch.Tensor], patch_size: int, f_patch_size: int
|
||||
):
|
||||
"""Patchify for basic mode: single image per batch item."""
|
||||
device = all_image[0].device
|
||||
all_img_out, all_img_size, all_img_pos_ids, all_img_pad_mask = [], [], [], []
|
||||
all_cap_out, all_cap_pos_ids, all_cap_pad_mask = [], [], []
|
||||
|
||||
all_image_out = []
|
||||
all_image_size = []
|
||||
all_image_pos_ids = []
|
||||
all_image_pad_mask = []
|
||||
all_cap_pos_ids = []
|
||||
all_cap_pad_mask = []
|
||||
all_cap_feats_out = []
|
||||
|
||||
for i, (image, cap_feat) in enumerate(zip(all_image, all_cap_feats)):
|
||||
### Process Caption
|
||||
cap_ori_len = len(cap_feat)
|
||||
cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF
|
||||
# padded position ids
|
||||
cap_padded_pos_ids = self.create_coordinate_grid(
|
||||
size=(cap_ori_len + cap_padding_len, 1, 1),
|
||||
start=(1, 0, 0),
|
||||
device=device,
|
||||
).flatten(0, 2)
|
||||
all_cap_pos_ids.append(cap_padded_pos_ids)
|
||||
# pad mask
|
||||
cap_pad_mask = torch.cat(
|
||||
[
|
||||
torch.zeros((cap_ori_len,), dtype=torch.bool, device=device),
|
||||
torch.ones((cap_padding_len,), dtype=torch.bool, device=device),
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
all_cap_pad_mask.append(
|
||||
cap_pad_mask if cap_padding_len > 0 else torch.zeros((cap_ori_len,), dtype=torch.bool, device=device)
|
||||
for image, cap_feat in zip(all_image, all_cap_feats):
|
||||
# Caption
|
||||
cap_out, cap_pos_ids, cap_pad_mask, cap_len, _ = self._pad_with_ids(
|
||||
cap_feat, (len(cap_feat) + (-len(cap_feat)) % SEQ_MULTI_OF, 1, 1), (1, 0, 0), device
|
||||
)
|
||||
all_cap_out.append(cap_out)
|
||||
all_cap_pos_ids.append(cap_pos_ids)
|
||||
all_cap_pad_mask.append(cap_pad_mask)
|
||||
|
||||
# padded feature
|
||||
cap_padded_feat = torch.cat([cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)], dim=0)
|
||||
all_cap_feats_out.append(cap_padded_feat)
|
||||
|
||||
### Process Image
|
||||
C, F, H, W = image.size()
|
||||
all_image_size.append((F, H, W))
|
||||
F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
|
||||
|
||||
image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
|
||||
# "c f pf h ph w pw -> (f h w) (pf ph pw c)"
|
||||
image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
|
||||
|
||||
image_ori_len = len(image)
|
||||
image_padding_len = (-image_ori_len) % SEQ_MULTI_OF
|
||||
|
||||
image_ori_pos_ids = self.create_coordinate_grid(
|
||||
size=(F_tokens, H_tokens, W_tokens),
|
||||
start=(cap_ori_len + cap_padding_len + 1, 0, 0),
|
||||
device=device,
|
||||
).flatten(0, 2)
|
||||
image_padded_pos_ids = torch.cat(
|
||||
[
|
||||
image_ori_pos_ids,
|
||||
self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device)
|
||||
.flatten(0, 2)
|
||||
.repeat(image_padding_len, 1),
|
||||
],
|
||||
dim=0,
|
||||
# Image
|
||||
img_patches, size, (F_t, H_t, W_t) = self._patchify_image(image, patch_size, f_patch_size)
|
||||
img_out, img_pos_ids, img_pad_mask, _, _ = self._pad_with_ids(
|
||||
img_patches, (F_t, H_t, W_t), (cap_len + 1, 0, 0), device
|
||||
)
|
||||
all_image_pos_ids.append(image_padded_pos_ids if image_padding_len > 0 else image_ori_pos_ids)
|
||||
# pad mask
|
||||
image_pad_mask = torch.cat(
|
||||
[
|
||||
torch.zeros((image_ori_len,), dtype=torch.bool, device=device),
|
||||
torch.ones((image_padding_len,), dtype=torch.bool, device=device),
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
all_image_pad_mask.append(
|
||||
image_pad_mask
|
||||
if image_padding_len > 0
|
||||
else torch.zeros((image_ori_len,), dtype=torch.bool, device=device)
|
||||
)
|
||||
# padded feature
|
||||
image_padded_feat = torch.cat(
|
||||
[image, image[-1:].repeat(image_padding_len, 1)],
|
||||
dim=0,
|
||||
)
|
||||
all_image_out.append(image_padded_feat if image_padding_len > 0 else image)
|
||||
all_img_out.append(img_out)
|
||||
all_img_size.append(size)
|
||||
all_img_pos_ids.append(img_pos_ids)
|
||||
all_img_pad_mask.append(img_pad_mask)
|
||||
|
||||
return (
|
||||
all_image_out,
|
||||
all_cap_feats_out,
|
||||
all_image_size,
|
||||
all_image_pos_ids,
|
||||
all_img_out,
|
||||
all_cap_out,
|
||||
all_img_size,
|
||||
all_img_pos_ids,
|
||||
all_cap_pos_ids,
|
||||
all_image_pad_mask,
|
||||
all_img_pad_mask,
|
||||
all_cap_pad_mask,
|
||||
)
|
||||
|
||||
def forward(
|
||||
def patchify_and_embed_omni(
|
||||
self,
|
||||
x: List[torch.Tensor],
|
||||
t,
|
||||
cap_feats: List[torch.Tensor],
|
||||
controlnet_block_samples: Optional[Dict[int, torch.Tensor]] = None,
|
||||
patch_size=2,
|
||||
f_patch_size=1,
|
||||
return_dict: bool = True,
|
||||
all_x: List[List[torch.Tensor]],
|
||||
all_cap_feats: List[List[torch.Tensor]],
|
||||
all_siglip_feats: List[List[torch.Tensor]],
|
||||
patch_size: int,
|
||||
f_patch_size: int,
|
||||
images_noise_mask: List[List[int]],
|
||||
):
|
||||
assert patch_size in self.all_patch_size
|
||||
assert f_patch_size in self.all_f_patch_size
|
||||
"""Patchify for omni mode: multiple images per batch item with noise masks."""
|
||||
bsz = len(all_x)
|
||||
device = all_x[0][-1].device
|
||||
dtype = all_x[0][-1].dtype
|
||||
|
||||
bsz = len(x)
|
||||
device = x[0].device
|
||||
t = t * self.t_scale
|
||||
t = self.t_embedder(t)
|
||||
all_x_out, all_x_size, all_x_pos_ids, all_x_pad_mask, all_x_len, all_x_noise_mask = [], [], [], [], [], []
|
||||
all_cap_out, all_cap_pos_ids, all_cap_pad_mask, all_cap_len, all_cap_noise_mask = [], [], [], [], []
|
||||
all_sig_out, all_sig_pos_ids, all_sig_pad_mask, all_sig_len, all_sig_noise_mask = [], [], [], [], []
|
||||
|
||||
(
|
||||
x,
|
||||
cap_feats,
|
||||
x_size,
|
||||
x_pos_ids,
|
||||
cap_pos_ids,
|
||||
x_inner_pad_mask,
|
||||
cap_inner_pad_mask,
|
||||
) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)
|
||||
for i in range(bsz):
|
||||
num_images = len(all_x[i])
|
||||
cap_feats_list, cap_pos_list, cap_mask_list, cap_lens, cap_noise = [], [], [], [], []
|
||||
cap_end_pos = []
|
||||
cap_cu_len = 1
|
||||
|
||||
# x embed & refine
|
||||
x_item_seqlens = [len(_) for _ in x]
|
||||
assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens)
|
||||
x_max_item_seqlen = max(x_item_seqlens)
|
||||
# Process captions
|
||||
for j, cap_item in enumerate(all_cap_feats[i]):
|
||||
noise_val = images_noise_mask[i][j] if j < len(images_noise_mask[i]) else 1
|
||||
cap_out, cap_pos, cap_mask, cap_len, cap_nm = self._pad_with_ids(
|
||||
cap_item,
|
||||
(len(cap_item) + (-len(cap_item)) % SEQ_MULTI_OF, 1, 1),
|
||||
(cap_cu_len, 0, 0),
|
||||
device,
|
||||
noise_val,
|
||||
)
|
||||
cap_feats_list.append(cap_out)
|
||||
cap_pos_list.append(cap_pos)
|
||||
cap_mask_list.append(cap_mask)
|
||||
cap_lens.append(cap_len)
|
||||
cap_noise.extend(cap_nm)
|
||||
cap_cu_len += len(cap_item)
|
||||
cap_end_pos.append(cap_cu_len)
|
||||
cap_cu_len += 2 # for image vae and siglip tokens
|
||||
|
||||
x = torch.cat(x, dim=0)
|
||||
x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x)
|
||||
all_cap_out.append(torch.cat(cap_feats_list, dim=0))
|
||||
all_cap_pos_ids.append(torch.cat(cap_pos_list, dim=0))
|
||||
all_cap_pad_mask.append(torch.cat(cap_mask_list, dim=0))
|
||||
all_cap_len.append(cap_lens)
|
||||
all_cap_noise_mask.append(cap_noise)
|
||||
|
||||
# Match t_embedder output dtype to x for layerwise casting compatibility
|
||||
adaln_input = t.type_as(x)
|
||||
x[torch.cat(x_inner_pad_mask)] = self.x_pad_token
|
||||
x = list(x.split(x_item_seqlens, dim=0))
|
||||
x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split([len(_) for _ in x_pos_ids], dim=0))
|
||||
# Process images
|
||||
x_feats_list, x_pos_list, x_mask_list, x_lens, x_size, x_noise = [], [], [], [], [], []
|
||||
for j, x_item in enumerate(all_x[i]):
|
||||
noise_val = images_noise_mask[i][j]
|
||||
if x_item is not None:
|
||||
x_patches, size, (F_t, H_t, W_t) = self._patchify_image(x_item, patch_size, f_patch_size)
|
||||
x_out, x_pos, x_mask, x_len, x_nm = self._pad_with_ids(
|
||||
x_patches, (F_t, H_t, W_t), (cap_end_pos[j], 0, 0), device, noise_val
|
||||
)
|
||||
x_size.append(size)
|
||||
else:
|
||||
x_len = SEQ_MULTI_OF
|
||||
x_out = torch.zeros((x_len, X_PAD_DIM), dtype=dtype, device=device)
|
||||
x_pos = self.create_coordinate_grid((1, 1, 1), (0, 0, 0), device).flatten(0, 2).repeat(x_len, 1)
|
||||
x_mask = torch.ones(x_len, dtype=torch.bool, device=device)
|
||||
x_nm = [noise_val] * x_len
|
||||
x_size.append(None)
|
||||
x_feats_list.append(x_out)
|
||||
x_pos_list.append(x_pos)
|
||||
x_mask_list.append(x_mask)
|
||||
x_lens.append(x_len)
|
||||
x_noise.extend(x_nm)
|
||||
|
||||
x = pad_sequence(x, batch_first=True, padding_value=0.0)
|
||||
x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
|
||||
# Clarify the length matches to satisfy Dynamo due to "Symbolic Shape Inference" to avoid compilation errors
|
||||
x_freqs_cis = x_freqs_cis[:, : x.shape[1]]
|
||||
all_x_out.append(torch.cat(x_feats_list, dim=0))
|
||||
all_x_pos_ids.append(torch.cat(x_pos_list, dim=0))
|
||||
all_x_pad_mask.append(torch.cat(x_mask_list, dim=0))
|
||||
all_x_size.append(x_size)
|
||||
all_x_len.append(x_lens)
|
||||
all_x_noise_mask.append(x_noise)
|
||||
|
||||
x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(x_item_seqlens):
|
||||
x_attn_mask[i, :seq_len] = 1
|
||||
# Process siglip
|
||||
if all_siglip_feats[i] is None:
|
||||
all_sig_len.append([0] * num_images)
|
||||
all_sig_out.append(None)
|
||||
else:
|
||||
sig_feats_list, sig_pos_list, sig_mask_list, sig_lens, sig_noise = [], [], [], [], []
|
||||
for j, sig_item in enumerate(all_siglip_feats[i]):
|
||||
noise_val = images_noise_mask[i][j]
|
||||
if sig_item is not None:
|
||||
sig_H, sig_W, sig_C = sig_item.size()
|
||||
sig_flat = sig_item.permute(2, 0, 1).reshape(sig_H * sig_W, sig_C)
|
||||
sig_out, sig_pos, sig_mask, sig_len, sig_nm = self._pad_with_ids(
|
||||
sig_flat, (1, sig_H, sig_W), (cap_end_pos[j] + 1, 0, 0), device, noise_val
|
||||
)
|
||||
# Scale position IDs to match x resolution
|
||||
if x_size[j] is not None:
|
||||
sig_pos = sig_pos.float()
|
||||
sig_pos[..., 1] = sig_pos[..., 1] / max(sig_H - 1, 1) * (x_size[j][1] - 1)
|
||||
sig_pos[..., 2] = sig_pos[..., 2] / max(sig_W - 1, 1) * (x_size[j][2] - 1)
|
||||
sig_pos = sig_pos.to(torch.int32)
|
||||
else:
|
||||
sig_len = SEQ_MULTI_OF
|
||||
sig_out = torch.zeros((sig_len, self.config.siglip_feat_dim), dtype=dtype, device=device)
|
||||
sig_pos = (
|
||||
self.create_coordinate_grid((1, 1, 1), (0, 0, 0), device).flatten(0, 2).repeat(sig_len, 1)
|
||||
)
|
||||
sig_mask = torch.ones(sig_len, dtype=torch.bool, device=device)
|
||||
sig_nm = [noise_val] * sig_len
|
||||
sig_feats_list.append(sig_out)
|
||||
sig_pos_list.append(sig_pos)
|
||||
sig_mask_list.append(sig_mask)
|
||||
sig_lens.append(sig_len)
|
||||
sig_noise.extend(sig_nm)
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
for layer in self.noise_refiner:
|
||||
x = self._gradient_checkpointing_func(layer, x, x_attn_mask, x_freqs_cis, adaln_input)
|
||||
else:
|
||||
for layer in self.noise_refiner:
|
||||
x = layer(x, x_attn_mask, x_freqs_cis, adaln_input)
|
||||
all_sig_out.append(torch.cat(sig_feats_list, dim=0))
|
||||
all_sig_pos_ids.append(torch.cat(sig_pos_list, dim=0))
|
||||
all_sig_pad_mask.append(torch.cat(sig_mask_list, dim=0))
|
||||
all_sig_len.append(sig_lens)
|
||||
all_sig_noise_mask.append(sig_noise)
|
||||
|
||||
# cap embed & refine
|
||||
cap_item_seqlens = [len(_) for _ in cap_feats]
|
||||
cap_max_item_seqlen = max(cap_item_seqlens)
|
||||
# Compute x position offsets
|
||||
all_x_pos_offsets = [(sum(all_cap_len[i]), sum(all_cap_len[i]) + sum(all_x_len[i])) for i in range(bsz)]
|
||||
|
||||
cap_feats = torch.cat(cap_feats, dim=0)
|
||||
cap_feats = self.cap_embedder(cap_feats)
|
||||
cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token
|
||||
cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
|
||||
cap_freqs_cis = list(
|
||||
self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split([len(_) for _ in cap_pos_ids], dim=0)
|
||||
return (
|
||||
all_x_out,
|
||||
all_cap_out,
|
||||
all_sig_out,
|
||||
all_x_size,
|
||||
all_x_pos_ids,
|
||||
all_cap_pos_ids,
|
||||
all_sig_pos_ids,
|
||||
all_x_pad_mask,
|
||||
all_cap_pad_mask,
|
||||
all_sig_pad_mask,
|
||||
all_x_pos_offsets,
|
||||
all_x_noise_mask,
|
||||
all_cap_noise_mask,
|
||||
all_sig_noise_mask,
|
||||
)
|
||||
|
||||
cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0)
|
||||
cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
|
||||
# Clarify the length matches to satisfy Dynamo due to "Symbolic Shape Inference" to avoid compilation errors
|
||||
cap_freqs_cis = cap_freqs_cis[:, : cap_feats.shape[1]]
|
||||
def _prepare_sequence(
|
||||
self,
|
||||
feats: List[torch.Tensor],
|
||||
pos_ids: List[torch.Tensor],
|
||||
inner_pad_mask: List[torch.Tensor],
|
||||
pad_token: torch.nn.Parameter,
|
||||
noise_mask: Optional[List[List[int]]] = None,
|
||||
device: torch.device = None,
|
||||
):
|
||||
"""Prepare sequence: apply pad token, RoPE embed, pad to batch, create attention mask."""
|
||||
item_seqlens = [len(f) for f in feats]
|
||||
max_seqlen = max(item_seqlens)
|
||||
bsz = len(feats)
|
||||
|
||||
cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(cap_item_seqlens):
|
||||
cap_attn_mask[i, :seq_len] = 1
|
||||
# Pad token
|
||||
feats_cat = torch.cat(feats, dim=0)
|
||||
feats_cat[torch.cat(inner_pad_mask)] = pad_token
|
||||
feats = list(feats_cat.split(item_seqlens, dim=0))
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
for layer in self.context_refiner:
|
||||
cap_feats = self._gradient_checkpointing_func(layer, cap_feats, cap_attn_mask, cap_freqs_cis)
|
||||
else:
|
||||
for layer in self.context_refiner:
|
||||
cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis)
|
||||
# RoPE
|
||||
freqs_cis = list(self.rope_embedder(torch.cat(pos_ids, dim=0)).split([len(p) for p in pos_ids], dim=0))
|
||||
|
||||
# unified
|
||||
# Pad to batch
|
||||
feats = pad_sequence(feats, batch_first=True, padding_value=0.0)
|
||||
freqs_cis = pad_sequence(freqs_cis, batch_first=True, padding_value=0.0)[:, : feats.shape[1]]
|
||||
|
||||
# Attention mask
|
||||
attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(item_seqlens):
|
||||
attn_mask[i, :seq_len] = 1
|
||||
|
||||
# Noise mask
|
||||
noise_mask_tensor = None
|
||||
if noise_mask is not None:
|
||||
noise_mask_tensor = pad_sequence(
|
||||
[torch.tensor(m, dtype=torch.long, device=device) for m in noise_mask],
|
||||
batch_first=True,
|
||||
padding_value=0,
|
||||
)[:, : feats.shape[1]]
|
||||
|
||||
return feats, freqs_cis, attn_mask, item_seqlens, noise_mask_tensor
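The batching pattern in `_prepare_sequence` — pad variable-length items to a common length, then mark the valid positions with a boolean mask — looks like this in isolation (toy tensors, not the model's real features):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Toy variable-length "features": three items of lengths 2, 5 and 3, feature dim 4.
feats = [torch.randn(n, 4) for n in (2, 5, 3)]
item_seqlens = [len(f) for f in feats]

batch = pad_sequence(feats, batch_first=True, padding_value=0.0)  # (3, 5, 4)

attn_mask = torch.zeros((len(feats), max(item_seqlens)), dtype=torch.bool)
for i, seq_len in enumerate(item_seqlens):
    attn_mask[i, :seq_len] = 1  # valid positions are True, padding stays False

print(batch.shape, attn_mask.sum(dim=1))  # torch.Size([3, 5, 4]) tensor([2, 5, 3])
```
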
|
||||
|
||||
def _build_unified_sequence(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_freqs: torch.Tensor,
|
||||
x_seqlens: List[int],
|
||||
x_noise_mask: Optional[List[List[int]]],
|
||||
cap: torch.Tensor,
|
||||
cap_freqs: torch.Tensor,
|
||||
cap_seqlens: List[int],
|
||||
cap_noise_mask: Optional[List[List[int]]],
|
||||
siglip: Optional[torch.Tensor],
|
||||
siglip_freqs: Optional[torch.Tensor],
|
||||
siglip_seqlens: Optional[List[int]],
|
||||
siglip_noise_mask: Optional[List[List[int]]],
|
||||
omni_mode: bool,
|
||||
device: torch.device,
|
||||
):
|
||||
"""Build unified sequence: x, cap, and optionally siglip.
|
||||
Basic mode order: [x, cap]; Omni mode order: [cap, x, siglip]
|
||||
"""
|
||||
bsz = len(x_seqlens)
|
||||
unified = []
|
||||
unified_freqs_cis = []
|
||||
unified_freqs = []
|
||||
unified_noise_mask = []
|
||||
|
||||
for i in range(bsz):
|
||||
x_len = x_item_seqlens[i]
|
||||
cap_len = cap_item_seqlens[i]
|
||||
unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]]))
|
||||
unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]]))
|
||||
unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)]
|
||||
assert unified_item_seqlens == [len(_) for _ in unified]
|
||||
unified_max_item_seqlen = max(unified_item_seqlens)
|
||||
x_len, cap_len = x_seqlens[i], cap_seqlens[i]
|
||||
|
||||
unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
|
||||
unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0)
|
||||
unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(unified_item_seqlens):
|
||||
unified_attn_mask[i, :seq_len] = 1
|
||||
if omni_mode:
|
||||
# Omni: [cap, x, siglip]
|
||||
if siglip is not None and siglip_seqlens is not None:
|
||||
sig_len = siglip_seqlens[i]
|
||||
unified.append(torch.cat([cap[i][:cap_len], x[i][:x_len], siglip[i][:sig_len]]))
|
||||
unified_freqs.append(
|
||||
torch.cat([cap_freqs[i][:cap_len], x_freqs[i][:x_len], siglip_freqs[i][:sig_len]])
|
||||
)
|
||||
unified_noise_mask.append(
|
||||
torch.tensor(
|
||||
cap_noise_mask[i] + x_noise_mask[i] + siglip_noise_mask[i], dtype=torch.long, device=device
|
||||
)
|
||||
)
|
||||
else:
|
||||
unified.append(torch.cat([cap[i][:cap_len], x[i][:x_len]]))
|
||||
unified_freqs.append(torch.cat([cap_freqs[i][:cap_len], x_freqs[i][:x_len]]))
|
||||
unified_noise_mask.append(
|
||||
torch.tensor(cap_noise_mask[i] + x_noise_mask[i], dtype=torch.long, device=device)
|
||||
)
|
||||
else:
|
||||
# Basic: [x, cap]
|
||||
unified.append(torch.cat([x[i][:x_len], cap[i][:cap_len]]))
|
||||
unified_freqs.append(torch.cat([x_freqs[i][:x_len], cap_freqs[i][:cap_len]]))
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
for layer_idx, layer in enumerate(self.layers):
|
||||
unified = self._gradient_checkpointing_func(
|
||||
layer, unified, unified_attn_mask, unified_freqs_cis, adaln_input
|
||||
)
|
||||
if controlnet_block_samples is not None:
|
||||
if layer_idx in controlnet_block_samples:
|
||||
unified = unified + controlnet_block_samples[layer_idx]
|
||||
# Compute unified seqlens
|
||||
if omni_mode:
|
||||
if siglip is not None and siglip_seqlens is not None:
|
||||
unified_seqlens = [a + b + c for a, b, c in zip(cap_seqlens, x_seqlens, siglip_seqlens)]
|
||||
else:
|
||||
unified_seqlens = [a + b for a, b in zip(cap_seqlens, x_seqlens)]
|
||||
else:
|
||||
for layer_idx, layer in enumerate(self.layers):
|
||||
unified = layer(unified, unified_attn_mask, unified_freqs_cis, adaln_input)
|
||||
if controlnet_block_samples is not None:
|
||||
if layer_idx in controlnet_block_samples:
|
||||
unified = unified + controlnet_block_samples[layer_idx]
|
||||
unified_seqlens = [a + b for a, b in zip(x_seqlens, cap_seqlens)]
|
||||
|
||||
unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input)
|
||||
unified = list(unified.unbind(dim=0))
|
||||
x = self.unpatchify(unified, x_size, patch_size, f_patch_size)
|
||||
max_seqlen = max(unified_seqlens)
|
||||
|
||||
if not return_dict:
|
||||
return (x,)
|
||||
# Pad to batch
|
||||
unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
|
||||
unified_freqs = pad_sequence(unified_freqs, batch_first=True, padding_value=0.0)
|
||||
|
||||
return Transformer2DModelOutput(sample=x)
|
||||
# Attention mask
|
||||
attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(unified_seqlens):
|
||||
attn_mask[i, :seq_len] = 1
|
||||
|
||||
# Noise mask
|
||||
noise_mask_tensor = None
|
||||
if omni_mode:
|
||||
noise_mask_tensor = pad_sequence(unified_noise_mask, batch_first=True, padding_value=0)[
|
||||
:, : unified.shape[1]
|
||||
]
|
||||
|
||||
return unified, unified_freqs, attn_mask, noise_mask_tensor
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: Union[List[torch.Tensor], List[List[torch.Tensor]]],
|
||||
t,
|
||||
cap_feats: Union[List[torch.Tensor], List[List[torch.Tensor]]],
|
||||
return_dict: bool = True,
|
||||
controlnet_block_samples: Optional[Dict[int, torch.Tensor]] = None,
|
||||
siglip_feats: Optional[List[List[torch.Tensor]]] = None,
|
||||
image_noise_mask: Optional[List[List[int]]] = None,
|
||||
patch_size: int = 2,
|
||||
f_patch_size: int = 1,
|
||||
):
|
||||
"""
|
||||
Flow: patchify -> t_embed -> x_embed -> x_refine -> cap_embed -> cap_refine
|
||||
-> [siglip_embed -> siglip_refine] -> build_unified -> main_layers -> final_layer -> unpatchify
|
||||
"""
|
||||
assert patch_size in self.all_patch_size and f_patch_size in self.all_f_patch_size
|
||||
omni_mode = isinstance(x[0], list)
|
||||
device = x[0][-1].device if omni_mode else x[0].device
|
||||
|
||||
if omni_mode:
|
||||
# Dual embeddings: noisy (t) and clean (t=1)
|
||||
t_noisy = self.t_embedder(t * self.t_scale).type_as(x[0][-1])
|
||||
t_clean = self.t_embedder(torch.ones_like(t) * self.t_scale).type_as(x[0][-1])
|
||||
adaln_input = None
|
||||
else:
|
||||
# Single embedding for all tokens
|
||||
adaln_input = self.t_embedder(t * self.t_scale).type_as(x[0])
|
||||
t_noisy = t_clean = None
|
||||
|
||||
# Patchify
|
||||
if omni_mode:
|
||||
(
|
||||
x,
|
||||
cap_feats,
|
||||
siglip_feats,
|
||||
x_size,
|
||||
x_pos_ids,
|
||||
cap_pos_ids,
|
||||
siglip_pos_ids,
|
||||
x_pad_mask,
|
||||
cap_pad_mask,
|
||||
siglip_pad_mask,
|
||||
x_pos_offsets,
|
||||
x_noise_mask,
|
||||
cap_noise_mask,
|
||||
siglip_noise_mask,
|
||||
) = self.patchify_and_embed_omni(x, cap_feats, siglip_feats, patch_size, f_patch_size, image_noise_mask)
|
||||
else:
|
||||
(
|
||||
x,
|
||||
cap_feats,
|
||||
x_size,
|
||||
x_pos_ids,
|
||||
cap_pos_ids,
|
||||
x_pad_mask,
|
||||
cap_pad_mask,
|
||||
) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)
|
||||
x_pos_offsets = x_noise_mask = cap_noise_mask = siglip_noise_mask = None
|
||||
|
||||
# X embed & refine
|
||||
x_seqlens = [len(xi) for xi in x]
|
||||
x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](torch.cat(x, dim=0)) # embed
|
||||
x, x_freqs, x_mask, _, x_noise_tensor = self._prepare_sequence(
|
||||
list(x.split(x_seqlens, dim=0)), x_pos_ids, x_pad_mask, self.x_pad_token, x_noise_mask, device
|
||||
)
|
||||
|
||||
for layer in self.noise_refiner:
|
||||
x = (
|
||||
self._gradient_checkpointing_func(
|
||||
layer, x, x_mask, x_freqs, adaln_input, x_noise_tensor, t_noisy, t_clean
|
||||
)
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing
|
||||
else layer(x, x_mask, x_freqs, adaln_input, x_noise_tensor, t_noisy, t_clean)
|
||||
)
|
||||
|
||||
# Cap embed & refine
|
||||
cap_seqlens = [len(ci) for ci in cap_feats]
|
||||
cap_feats = self.cap_embedder(torch.cat(cap_feats, dim=0)) # embed
|
||||
cap_feats, cap_freqs, cap_mask, _, _ = self._prepare_sequence(
|
||||
list(cap_feats.split(cap_seqlens, dim=0)), cap_pos_ids, cap_pad_mask, self.cap_pad_token, None, device
|
||||
)
|
||||
|
||||
for layer in self.context_refiner:
|
||||
cap_feats = (
|
||||
self._gradient_checkpointing_func(layer, cap_feats, cap_mask, cap_freqs)
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing
|
||||
else layer(cap_feats, cap_mask, cap_freqs)
|
||||
)
|
||||
|
||||
# Siglip embed & refine
|
||||
siglip_seqlens = siglip_freqs = None
|
||||
if omni_mode and siglip_feats[0] is not None and self.siglip_embedder is not None:
|
||||
siglip_seqlens = [len(si) for si in siglip_feats]
|
||||
siglip_feats = self.siglip_embedder(torch.cat(siglip_feats, dim=0)) # embed
|
||||
siglip_feats, siglip_freqs, siglip_mask, _, _ = self._prepare_sequence(
|
||||
list(siglip_feats.split(siglip_seqlens, dim=0)),
|
||||
siglip_pos_ids,
|
||||
siglip_pad_mask,
|
||||
self.siglip_pad_token,
|
||||
None,
|
||||
device,
|
||||
)
|
||||
|
||||
for layer in self.siglip_refiner:
|
||||
siglip_feats = (
|
||||
self._gradient_checkpointing_func(layer, siglip_feats, siglip_mask, siglip_freqs)
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing
|
||||
else layer(siglip_feats, siglip_mask, siglip_freqs)
|
||||
)
|
||||
|
||||
# Unified sequence
|
||||
unified, unified_freqs, unified_mask, unified_noise_tensor = self._build_unified_sequence(
|
||||
x,
|
||||
x_freqs,
|
||||
x_seqlens,
|
||||
x_noise_mask,
|
||||
cap_feats,
|
||||
cap_freqs,
|
||||
cap_seqlens,
|
||||
cap_noise_mask,
|
||||
siglip_feats,
|
||||
siglip_freqs,
|
||||
siglip_seqlens,
|
||||
siglip_noise_mask,
|
||||
omni_mode,
|
||||
device,
|
||||
)
|
||||
|
||||
# Main transformer layers
|
||||
for layer_idx, layer in enumerate(self.layers):
|
||||
unified = (
|
||||
self._gradient_checkpointing_func(
|
||||
layer, unified, unified_mask, unified_freqs, adaln_input, unified_noise_tensor, t_noisy, t_clean
|
||||
)
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing
|
||||
else layer(unified, unified_mask, unified_freqs, adaln_input, unified_noise_tensor, t_noisy, t_clean)
|
||||
)
|
||||
if controlnet_block_samples is not None and layer_idx in controlnet_block_samples:
|
||||
unified = unified + controlnet_block_samples[layer_idx]
|
||||
|
||||
unified = (
|
||||
self.all_final_layer[f"{patch_size}-{f_patch_size}"](
|
||||
unified, noise_mask=unified_noise_tensor, c_noisy=t_noisy, c_clean=t_clean
|
||||
)
|
||||
if omni_mode
|
||||
else self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, c=adaln_input)
|
||||
)
|
||||
|
||||
# Unpatchify
|
||||
x = self.unpatchify(list(unified.unbind(dim=0)), x_size, patch_size, f_patch_size, x_pos_offsets)
|
||||
|
||||
return (x,) if not return_dict else Transformer2DModelOutput(sample=x)
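The forward pass above dispatches on the nesting of `x`: a flat list of latents takes the basic path, while a list of lists (conditioning images first, target last, per batch item) takes the omni path together with a parallel `image_noise_mask`. A schematic of the two input layouts, with invented tensor sizes:

```python
import torch

# Basic mode: one latent per batch item -> List[Tensor]
x_basic = [torch.randn(16, 1, 64, 64), torch.randn(16, 1, 96, 64)]
assert not isinstance(x_basic[0], list)

# Omni mode: per batch item, a list of images (conditioning images first, target last),
# with a parallel noise mask marking noisy (1) vs clean (0) entries.
x_omni = [[torch.randn(16, 1, 64, 64), torch.randn(16, 1, 64, 64)]]
image_noise_mask = [[0, 1]]  # first image is a clean condition, last is the noisy target
assert isinstance(x_omni[0], list)
```
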
|
||||
|
||||
@@ -411,6 +411,7 @@ else:
        "ZImagePipeline",
        "ZImageControlNetPipeline",
        "ZImageControlNetInpaintPipeline",
        "ZImageOmniPipeline",
    ]
    _import_structure["skyreels_v2"] = [
        "SkyReelsV2DiffusionForcingPipeline",
@@ -856,6 +857,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        ZImageControlNetInpaintPipeline,
        ZImageControlNetPipeline,
        ZImageImg2ImgPipeline,
        ZImageOmniPipeline,
        ZImagePipeline,
    )

@@ -73,6 +73,7 @@ from .kandinsky3 import Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline
from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline
from .lumina import LuminaPipeline
from .lumina2 import Lumina2Pipeline
from .ovis_image import OvisImagePipeline
from .pag import (
    HunyuanDiTPAGPipeline,
    PixArtSigmaPAGPipeline,
@@ -119,7 +120,13 @@ from .stable_diffusion_xl import (
)
from .wan import WanImageToVideoPipeline, WanPipeline, WanVideoToVideoPipeline
from .wuerstchen import WuerstchenCombinedPipeline, WuerstchenDecoderPipeline
from .z_image import ZImageImg2ImgPipeline, ZImagePipeline
from .z_image import (
    ZImageControlNetInpaintPipeline,
    ZImageControlNetPipeline,
    ZImageImg2ImgPipeline,
    ZImageOmniPipeline,
    ZImagePipeline,
)


AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
@@ -164,6 +171,10 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
        ("qwenimage", QwenImagePipeline),
        ("qwenimage-controlnet", QwenImageControlNetPipeline),
        ("z-image", ZImagePipeline),
        ("z-image-controlnet", ZImageControlNetPipeline),
        ("z-image-controlnet-inpaint", ZImageControlNetInpaintPipeline),
        ("z-image-omni", ZImageOmniPipeline),
        ("ovis", OvisImagePipeline),
    ]
)
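With these mappings registered, the Z-Image pipelines become reachable through the auto-pipeline API. A hedged sketch — the checkpoint id below is a placeholder, not a published repo:

```python
import torch
from diffusers import AutoPipelineForText2Image

# "<your-z-image-checkpoint>" is a placeholder for a repo or local folder whose
# model_index.json points at ZImagePipeline (or one of the variants mapped above).
pipe = AutoPipelineForText2Image.from_pretrained(
    "<your-z-image-checkpoint>", torch_dtype=torch.bfloat16
)
image = pipe(prompt="a watercolor fox in a snowy forest").images[0]
```
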
@@ -17,7 +17,7 @@ from typing import Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from transformers import BertModel, BertTokenizer, CLIPImageProcessor, MT5Tokenizer, T5EncoderModel
from transformers import BertModel, BertTokenizer, CLIPImageProcessor, T5EncoderModel, T5Tokenizer

from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput

@@ -185,7 +185,7 @@ class HunyuanDiTControlNetPipeline(DiffusionPipeline):
            The HunyuanDiT model designed by Tencent Hunyuan.
        text_encoder_2 (`T5EncoderModel`):
            The mT5 embedder. Specifically, it is 't5-v1_1-xxl'.
        tokenizer_2 (`MT5Tokenizer`):
        tokenizer_2 (`T5Tokenizer`):
            The tokenizer for the mT5 embedder.
        scheduler ([`DDPMScheduler`]):
            A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents.
@@ -229,7 +229,7 @@ class HunyuanDiTControlNetPipeline(DiffusionPipeline):
            HunyuanDiT2DMultiControlNetModel,
        ],
        text_encoder_2: Optional[T5EncoderModel] = None,
        tokenizer_2: Optional[MT5Tokenizer] = None,
        tokenizer_2: Optional[T5Tokenizer] = None,
        requires_safety_checker: bool = True,
    ):
        super().__init__()

@@ -133,7 +133,7 @@ EXAMPLE_DOC_STRING = """
        ... num_frames=93,
        ... generator=torch.Generator().manual_seed(1),
        ... ).frames[0]
        >>> # export_to_video(video, "image2world.mp4", fps=16)
        >>> export_to_video(video, "image2world.mp4", fps=16)

        >>> # Video2World: condition on an input clip and predict a 93-frame world video.
        >>> prompt = (

@@ -17,7 +17,7 @@ from typing import Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from transformers import BertModel, BertTokenizer, CLIPImageProcessor, MT5Tokenizer, T5EncoderModel
from transformers import BertModel, BertTokenizer, CLIPImageProcessor, T5EncoderModel, T5Tokenizer

from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput

@@ -169,7 +169,7 @@ class HunyuanDiTPipeline(DiffusionPipeline):
            The HunyuanDiT model designed by Tencent Hunyuan.
        text_encoder_2 (`T5EncoderModel`):
            The mT5 embedder. Specifically, it is 't5-v1_1-xxl'.
        tokenizer_2 (`MT5Tokenizer`):
        tokenizer_2 (`T5Tokenizer`):
            The tokenizer for the mT5 embedder.
        scheduler ([`DDPMScheduler`]):
            A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents.
@@ -204,7 +204,7 @@ class HunyuanDiTPipeline(DiffusionPipeline):
        feature_extractor: CLIPImageProcessor,
        requires_safety_checker: bool = True,
        text_encoder_2: Optional[T5EncoderModel] = None,
        tokenizer_2: Optional[MT5Tokenizer] = None,
        tokenizer_2: Optional[T5Tokenizer] = None,
    ):
        super().__init__()

@@ -17,7 +17,7 @@ from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import BertModel, BertTokenizer, CLIPImageProcessor, MT5Tokenizer, T5EncoderModel
|
||||
from transformers import BertModel, BertTokenizer, CLIPImageProcessor, T5EncoderModel, T5Tokenizer
|
||||
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
|
||||
@@ -173,7 +173,7 @@ class HunyuanDiTPAGPipeline(DiffusionPipeline, PAGMixin):
|
||||
The HunyuanDiT model designed by Tencent Hunyuan.
|
||||
text_encoder_2 (`T5EncoderModel`):
|
||||
The mT5 embedder. Specifically, it is 't5-v1_1-xxl'.
|
||||
tokenizer_2 (`MT5Tokenizer`):
|
||||
tokenizer_2 (`T5Tokenizer`):
|
||||
The tokenizer for the mT5 embedder.
|
||||
scheduler ([`DDPMScheduler`]):
|
||||
A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents.
|
||||
@@ -208,7 +208,7 @@ class HunyuanDiTPAGPipeline(DiffusionPipeline, PAGMixin):
|
||||
feature_extractor: Optional[CLIPImageProcessor] = None,
|
||||
requires_safety_checker: bool = True,
|
||||
text_encoder_2: Optional[T5EncoderModel] = None,
|
||||
tokenizer_2: Optional[MT5Tokenizer] = None,
|
||||
tokenizer_2: Optional[T5Tokenizer] = None,
|
||||
pag_applied_layers: Union[str, List[str]] = "blocks.1", # "blocks.16.attn1", "blocks.16", "16", 16
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -26,6 +26,7 @@ else:
|
||||
_import_structure["pipeline_z_image_controlnet"] = ["ZImageControlNetPipeline"]
|
||||
_import_structure["pipeline_z_image_controlnet_inpaint"] = ["ZImageControlNetInpaintPipeline"]
|
||||
_import_structure["pipeline_z_image_img2img"] = ["ZImageImg2ImgPipeline"]
|
||||
_import_structure["pipeline_z_image_omni"] = ["ZImageOmniPipeline"]
|
||||
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
@@ -41,7 +42,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .pipeline_z_image_controlnet import ZImageControlNetPipeline
|
||||
from .pipeline_z_image_controlnet_inpaint import ZImageControlNetInpaintPipeline
|
||||
from .pipeline_z_image_img2img import ZImageImg2ImgPipeline
|
||||
|
||||
from .pipeline_z_image_omni import ZImageOmniPipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
|
||||
742
src/diffusers/pipelines/z_image/pipeline_z_image_omni.py
Normal file
742
src/diffusers/pipelines/z_image/pipeline_z_image_omni.py
Normal file
@@ -0,0 +1,742 @@
|
||||
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import PIL
|
||||
import torch
|
||||
from transformers import AutoTokenizer, PreTrainedModel, Siglip2ImageProcessorFast, Siglip2VisionModel
|
||||
|
||||
from ...loaders import FromSingleFileMixin, ZImageLoraLoaderMixin
|
||||
from ...models.autoencoders import AutoencoderKL
|
||||
from ...models.transformers import ZImageTransformer2DModel
|
||||
from ...pipelines.pipeline_utils import DiffusionPipeline
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..flux2.image_processor import Flux2ImageProcessor
|
||||
from .pipeline_output import ZImagePipelineOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import ZImageOmniPipeline
|
||||
|
||||
>>> pipe = ZImageOmniPipeline.from_pretrained("Z-a-o/Z-Image-Turbo", torch_dtype=torch.bfloat16)
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> # Optionally, set the attention backend to flash-attn 2 or 3, default is SDPA in PyTorch.
|
||||
>>> # (1) Use flash attention 2
|
||||
>>> # pipe.transformer.set_attention_backend("flash")
|
||||
>>> # (2) Use flash attention 3
|
||||
>>> # pipe.transformer.set_attention_backend("_flash_3")
|
||||
|
||||
>>> prompt = "一幅为名为“造相「Z-IMAGE-TURBO」”的项目设计的创意海报。画面巧妙地将文字概念视觉化:一辆复古蒸汽小火车化身为巨大的拉链头,正拉开厚厚的冬日积雪,展露出一个生机盎然的春天。"
|
||||
>>> image = pipe(
|
||||
... prompt,
|
||||
... height=1024,
|
||||
... width=1024,
|
||||
... num_inference_steps=9,
|
||||
... guidance_scale=0.0,
|
||||
... generator=torch.Generator("cuda").manual_seed(42),
|
||||
... ).images[0]
|
||||
>>> image.save("zimage.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
|
||||
def calculate_shift(
|
||||
image_seq_len,
|
||||
base_seq_len: int = 256,
|
||||
max_seq_len: int = 4096,
|
||||
base_shift: float = 0.5,
|
||||
max_shift: float = 1.15,
|
||||
):
|
||||
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
||||
b = base_shift - m * base_seq_len
|
||||
mu = image_seq_len * m + b
|
||||
return mu
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class ZImageOmniPipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMixin):
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_optional_components = []
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: PreTrainedModel,
|
||||
tokenizer: AutoTokenizer,
|
||||
transformer: ZImageTransformer2DModel,
|
||||
siglip: Siglip2VisionModel,
|
||||
siglip_processor: Siglip2ImageProcessorFast,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
scheduler=scheduler,
|
||||
transformer=transformer,
|
||||
siglip=siglip,
|
||||
siglip_processor=siglip_processor,
|
||||
)
|
||||
self.vae_scale_factor = (
|
||||
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
||||
)
|
||||
# self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
||||
self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
device: Optional[torch.device] = None,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
max_sequence_length: int = 512,
|
||||
num_condition_images: int = 0,
|
||||
):
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
prompt_embeds = self._encode_prompt(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
prompt_embeds=prompt_embeds,
|
||||
max_sequence_length=max_sequence_length,
|
||||
num_condition_images=num_condition_images,
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
if negative_prompt is None:
|
||||
negative_prompt = ["" for _ in prompt]
|
||||
else:
|
||||
negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
assert len(prompt) == len(negative_prompt)
|
||||
negative_prompt_embeds = self._encode_prompt(
|
||||
prompt=negative_prompt,
|
||||
device=device,
|
||||
prompt_embeds=negative_prompt_embeds,
|
||||
max_sequence_length=max_sequence_length,
|
||||
num_condition_images=num_condition_images,
|
||||
)
|
||||
else:
|
||||
negative_prompt_embeds = []
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
device: Optional[torch.device] = None,
|
||||
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
||||
max_sequence_length: int = 512,
|
||||
num_condition_images: int = 0,
|
||||
) -> List[torch.FloatTensor]:
|
||||
device = device or self._execution_device
|
||||
|
||||
if prompt_embeds is not None:
|
||||
return prompt_embeds
|
||||
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
|
||||
for i, prompt_item in enumerate(prompt):
|
||||
if num_condition_images == 0:
|
||||
prompt[i] = ["<|im_start|>user\n" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n"]
|
||||
elif num_condition_images > 0:
|
||||
prompt_list = ["<|im_start|>user\n<|vision_start|>"]
|
||||
prompt_list += ["<|vision_end|><|vision_start|>"] * (num_condition_images - 1)
|
||||
prompt_list += ["<|vision_end|>" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n<|vision_start|>"]
|
||||
prompt_list += ["<|vision_end|><|im_end|>"]
|
||||
prompt[i] = prompt_list
|
||||
|
||||
flattened_prompt = []
|
||||
prompt_list_lengths = []
|
||||
|
||||
for i in range(len(prompt)):
|
||||
prompt_list_lengths.append(len(prompt[i]))
|
||||
flattened_prompt.extend(prompt[i])
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
flattened_prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
text_input_ids = text_inputs.input_ids.to(device)
|
||||
prompt_masks = text_inputs.attention_mask.to(device).bool()
|
||||
|
||||
prompt_embeds = self.text_encoder(
|
||||
input_ids=text_input_ids,
|
||||
attention_mask=prompt_masks,
|
||||
output_hidden_states=True,
|
||||
).hidden_states[-2]
|
||||
|
||||
embeddings_list = []
|
||||
start_idx = 0
|
||||
for i in range(len(prompt_list_lengths)):
|
||||
batch_embeddings = []
|
||||
end_idx = start_idx + prompt_list_lengths[i]
|
||||
for j in range(start_idx, end_idx):
|
||||
batch_embeddings.append(prompt_embeds[j][prompt_masks[j]])
|
||||
embeddings_list.append(batch_embeddings)
|
||||
start_idx = end_idx
|
||||
|
||||
return embeddings_list
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
):
|
||||
height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
||||
width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
||||
|
||||
shape = (batch_size, num_channels_latents, height, width)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
if latents.shape != shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
||||
latents = latents.to(device)
|
||||
return latents
|
||||
|
||||
def prepare_image_latents(
|
||||
self,
|
||||
images: List[torch.Tensor],
|
||||
batch_size,
|
||||
device,
|
||||
dtype,
|
||||
):
|
||||
image_latents = []
|
||||
for image in images:
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
image_latent = (
|
||||
self.vae.encode(image.bfloat16()).latent_dist.mode()[0] - self.vae.config.shift_factor
|
||||
) * self.vae.config.scaling_factor
|
||||
image_latent = image_latent.unsqueeze(1).to(dtype)
|
||||
image_latents.append(image_latent) # (16, 128, 128)
|
||||
|
||||
# image_latents = [image_latents] * batch_size
|
||||
image_latents = [image_latents.copy() for _ in range(batch_size)]
|
||||
|
||||
return image_latents
|
||||
|
||||
def prepare_siglip_embeds(
|
||||
self,
|
||||
images: List[torch.Tensor],
|
||||
batch_size,
|
||||
device,
|
||||
dtype,
|
||||
):
|
||||
siglip_embeds = []
|
||||
for image in images:
|
||||
siglip_inputs = self.siglip_processor(images=[image], return_tensors="pt").to(device)
|
||||
shape = siglip_inputs.spatial_shapes[0]
|
||||
hidden_state = self.siglip(**siglip_inputs).last_hidden_state
|
||||
B, N, C = hidden_state.shape
|
||||
hidden_state = hidden_state[:, : shape[0] * shape[1]]
|
||||
hidden_state = hidden_state.view(shape[0], shape[1], C)
|
||||
siglip_embeds.append(hidden_state.to(dtype))
|
||||
|
||||
# siglip_embeds = [siglip_embeds] * batch_size
|
||||
siglip_embeds = [siglip_embeds.copy() for _ in range(batch_size)]
|
||||
|
||||
return siglip_embeds
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1
|
||||
|
||||
@property
|
||||
def joint_attention_kwargs(self):
|
||||
return self._joint_attention_kwargs
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
image: Optional[Union[List[PIL.Image.Image], PIL.Image.Image]] = None,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
guidance_scale: float = 5.0,
|
||||
cfg_normalization: bool = False,
|
||||
cfg_truncation: float = 1.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
||||
negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
||||
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
|
||||
numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
|
||||
or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
|
||||
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
|
||||
latents as `image`, but if passing latents directly it is not encoded again.
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
height (`int`, *optional*, defaults to 1024):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to 1024):
|
||||
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
||||
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||
will be used.
|
||||
guidance_scale (`float`, *optional*, defaults to 5.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
cfg_normalization (`bool`, *optional*, defaults to False):
|
||||
Whether to apply configuration normalization.
|
||||
cfg_truncation (`float`, *optional*, defaults to 1.0):
|
||||
The truncation value for configuration.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will be generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`List[torch.FloatTensor]`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`List[torch.FloatTensor]`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.ZImagePipelineOutput`] instead of a plain
|
||||
tuple.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int`, *optional*, defaults to 512):
|
||||
Maximum sequence length to use with the `prompt`.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`: [`~pipelines.z_image.ZImagePipelineOutput`] if
|
||||
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
|
||||
generated images.
|
||||
"""
|
||||
|
||||
if image is not None and not isinstance(image, list):
|
||||
image = [image]
|
||||
num_condition_images = len(image) if image is not None else 0
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._joint_attention_kwargs = joint_attention_kwargs
|
||||
self._interrupt = False
|
||||
self._cfg_normalization = cfg_normalization
|
||||
self._cfg_truncation = cfg_truncation
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = len(prompt_embeds)
|
||||
|
||||
# If prompt_embeds is provided and prompt is None, skip encoding
|
||||
if prompt_embeds is not None and prompt is None:
|
||||
if self.do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"When `prompt_embeds` is provided without `prompt`, "
|
||||
"`negative_prompt_embeds` must also be provided for classifier-free guidance."
|
||||
)
|
||||
else:
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
device=device,
|
||||
max_sequence_length=max_sequence_length,
|
||||
num_condition_images=num_condition_images,
|
||||
)
|
||||
|
||||
# 3. Process condition images. Copied from diffusers.pipelines.flux2.pipeline_flux2
|
||||
condition_images = []
|
||||
resized_images = []
|
||||
if image is not None:
|
||||
for img in image:
|
||||
self.image_processor.check_image_input(img)
|
||||
for img in image:
|
||||
image_width, image_height = img.size
|
||||
if image_width * image_height > 1024 * 1024:
|
||||
if height is not None and width is not None:
|
||||
img = self.image_processor._resize_to_target_area(img, height * width)
|
||||
else:
|
||||
img = self.image_processor._resize_to_target_area(img, 1024 * 1024)
|
||||
image_width, image_height = img.size
|
||||
resized_images.append(img)
|
||||
|
||||
multiple_of = self.vae_scale_factor * 2
|
||||
image_width = (image_width // multiple_of) * multiple_of
|
||||
image_height = (image_height // multiple_of) * multiple_of
|
||||
img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop")
|
||||
condition_images.append(img)
|
||||
|
||||
if len(condition_images) > 0:
|
||||
height = height or image_height
|
||||
width = width or image_width
|
||||
|
||||
else:
|
||||
height = height or 1024
|
||||
width = width or 1024
|
||||
|
||||
vae_scale = self.vae_scale_factor * 2
|
||||
if height % vae_scale != 0:
|
||||
raise ValueError(
|
||||
f"Height must be divisible by {vae_scale} (got {height}). "
|
||||
f"Please adjust the height to a multiple of {vae_scale}."
|
||||
)
|
||||
if width % vae_scale != 0:
|
||||
raise ValueError(
|
||||
f"Width must be divisible by {vae_scale} (got {width}). "
|
||||
f"Please adjust the width to a multiple of {vae_scale}."
|
||||
)
|
||||
|
||||
# 4. Prepare latent variables
|
||||
num_channels_latents = self.transformer.in_channels
|
||||
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
torch.float32,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
condition_latents = self.prepare_image_latents(
|
||||
images=condition_images,
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
condition_latents = [[lat.to(self.transformer.dtype) for lat in lats] for lats in condition_latents]
|
||||
if self.do_classifier_free_guidance:
|
||||
negative_condition_latents = [[lat.clone() for lat in batch] for batch in condition_latents]
|
||||
|
||||
condition_siglip_embeds = self.prepare_siglip_embeds(
|
||||
images=resized_images,
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
condition_siglip_embeds = [[se.to(self.transformer.dtype) for se in sels] for sels in condition_siglip_embeds]
|
||||
if self.do_classifier_free_guidance:
|
||||
negative_condition_siglip_embeds = [[se.clone() for se in batch] for batch in condition_siglip_embeds]
|
||||
|
||||
# Repeat prompt_embeds for num_images_per_prompt
|
||||
if num_images_per_prompt > 1:
|
||||
prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)]
|
||||
if self.do_classifier_free_guidance and negative_prompt_embeds:
|
||||
negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)]
|
||||
|
||||
condition_siglip_embeds = [None if sels == [] else sels + [None] for sels in condition_siglip_embeds]
|
||||
negative_condition_siglip_embeds = [
|
||||
None if sels == [] else sels + [None] for sels in negative_condition_siglip_embeds
|
||||
]
|
||||
|
||||
actual_batch_size = batch_size * num_images_per_prompt
|
||||
image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2)
|
||||
|
||||
# 5. Prepare timesteps
|
||||
mu = calculate_shift(
|
||||
image_seq_len,
|
||||
self.scheduler.config.get("base_image_seq_len", 256),
|
||||
self.scheduler.config.get("max_image_seq_len", 4096),
|
||||
self.scheduler.config.get("base_shift", 0.5),
|
||||
self.scheduler.config.get("max_shift", 1.15),
|
||||
)
|
||||
self.scheduler.sigma_min = 0.0
|
||||
scheduler_kwargs = {"mu": mu}
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
num_inference_steps,
|
||||
device,
|
||||
sigmas=sigmas,
|
||||
**scheduler_kwargs,
|
||||
)
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 6. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0])
|
||||
timestep = (1000 - timestep) / 1000
|
||||
# Normalized time for time-aware config (0 at start, 1 at end)
|
||||
t_norm = timestep[0].item()
|
||||
|
||||
# Handle cfg truncation
|
||||
current_guidance_scale = self.guidance_scale
|
||||
if (
|
||||
self.do_classifier_free_guidance
|
||||
and self._cfg_truncation is not None
|
||||
and float(self._cfg_truncation) <= 1
|
||||
):
|
||||
if t_norm > self._cfg_truncation:
|
||||
current_guidance_scale = 0.0
|
||||
|
||||
# Run CFG only if configured AND scale is non-zero
|
||||
apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0
|
||||
|
||||
if apply_cfg:
|
||||
latents_typed = latents.to(self.transformer.dtype)
|
||||
latent_model_input = latents_typed.repeat(2, 1, 1, 1)
|
||||
prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds
|
||||
condition_latents_model_input = condition_latents + negative_condition_latents
|
||||
condition_siglip_embeds_model_input = condition_siglip_embeds + negative_condition_siglip_embeds
|
||||
timestep_model_input = timestep.repeat(2)
|
||||
else:
|
||||
latent_model_input = latents.to(self.transformer.dtype)
|
||||
prompt_embeds_model_input = prompt_embeds
|
||||
condition_latents_model_input = condition_latents
|
||||
condition_siglip_embeds_model_input = condition_siglip_embeds
|
||||
timestep_model_input = timestep
|
||||
|
||||
latent_model_input = latent_model_input.unsqueeze(2)
|
||||
latent_model_input_list = list(latent_model_input.unbind(dim=0))
|
||||
|
||||
# Combine condition latents with target latent
|
||||
current_batch_size = len(latent_model_input_list)
|
||||
x_combined = [
|
||||
condition_latents_model_input[i] + [latent_model_input_list[i]] for i in range(current_batch_size)
|
||||
]
|
||||
# Create noise mask: 0 for condition images (clean), 1 for target image (noisy)
|
||||
image_noise_mask = [
|
||||
[0] * len(condition_latents_model_input[i]) + [1] for i in range(current_batch_size)
|
||||
]
|
||||
|
||||
model_out_list = self.transformer(
|
||||
x=x_combined,
|
||||
t=timestep_model_input,
|
||||
cap_feats=prompt_embeds_model_input,
|
||||
siglip_feats=condition_siglip_embeds_model_input,
|
||||
image_noise_mask=image_noise_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if apply_cfg:
|
||||
# Perform CFG
|
||||
pos_out = model_out_list[:actual_batch_size]
|
||||
neg_out = model_out_list[actual_batch_size:]
|
||||
|
||||
noise_pred = []
|
||||
for j in range(actual_batch_size):
|
||||
pos = pos_out[j].float()
|
||||
neg = neg_out[j].float()
|
||||
|
||||
pred = pos + current_guidance_scale * (pos - neg)
|
||||
|
||||
# Renormalization
|
||||
if self._cfg_normalization and float(self._cfg_normalization) > 0.0:
|
||||
ori_pos_norm = torch.linalg.vector_norm(pos)
|
||||
new_pos_norm = torch.linalg.vector_norm(pred)
|
||||
max_new_norm = ori_pos_norm * float(self._cfg_normalization)
|
||||
if new_pos_norm > max_new_norm:
|
||||
pred = pred * (max_new_norm / new_pos_norm)
|
||||
|
||||
noise_pred.append(pred)
|
||||
|
||||
noise_pred = torch.stack(noise_pred, dim=0)
|
||||
else:
|
||||
noise_pred = torch.stack([t.float() for t in model_out_list], dim=0)
|
||||
|
||||
noise_pred = noise_pred.squeeze(2)
|
||||
noise_pred = -noise_pred
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0]
|
||||
assert latents.dtype == torch.float32
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
|
||||
else:
|
||||
latents = latents.to(self.vae.dtype)
|
||||
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
||||
|
||||
image = self.vae.decode(latents, return_dict=False)[0]
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return ZImagePipelineOutput(images=image)
|
||||
@@ -3917,6 +3917,21 @@ class ZImageImg2ImgPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class ZImageOmniPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class ZImagePipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -671,46 +671,44 @@ class TorchAoSerializationTest(unittest.TestCase):
|
||||
class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):
|
||||
@property
|
||||
def quantization_config(self):
|
||||
from torchao.quantization import Int8WeightOnlyConfig
|
||||
|
||||
return PipelineQuantizationConfig(
|
||||
quant_mapping={
|
||||
"transformer": TorchAoConfig(Int8WeightOnlyConfig()),
|
||||
"transformer": TorchAoConfig(quant_type="int8_weight_only"),
|
||||
},
|
||||
)
|
||||
|
||||
# @unittest.skip(
|
||||
# "Changing the device of AQT tensor with module._apply (called from doing module.to() in accelerate) does not work "
|
||||
# "when compiling."
|
||||
# )
|
||||
# def test_torch_compile_with_cpu_offload(self):
|
||||
# # RuntimeError: _apply(): Couldn't swap Linear.weight
|
||||
# super().test_torch_compile_with_cpu_offload()
|
||||
@unittest.skip(
|
||||
"Changing the device of AQT tensor with module._apply (called from doing module.to() in accelerate) does not work "
|
||||
"when compiling."
|
||||
)
|
||||
def test_torch_compile_with_cpu_offload(self):
|
||||
# RuntimeError: _apply(): Couldn't swap Linear.weight
|
||||
super().test_torch_compile_with_cpu_offload()
|
||||
|
||||
# @parameterized.expand([False, True])
|
||||
# @unittest.skip(
|
||||
# """
|
||||
# For `use_stream=False`:
|
||||
# - Changing the device of AQT tensor, with `param.data = param.data.to(device)` as done in group offloading implementation
|
||||
# is unsupported in TorchAO. When compiling, FakeTensor device mismatch causes failure.
|
||||
# For `use_stream=True`:
|
||||
# Using non-default stream requires ability to pin tensors. AQT does not seem to support this yet in TorchAO.
|
||||
# """
|
||||
# )
|
||||
# def test_torch_compile_with_group_offload_leaf(self, use_stream):
|
||||
# # For use_stream=False:
|
||||
# # If we run group offloading without compilation, we will see:
|
||||
# # RuntimeError: Attempted to set the storage of a tensor on device "cpu" to a storage on different device "cuda:0". This is no longer allowed; the devices must match.
|
||||
# # When running with compilation, the error ends up being different:
|
||||
# # Dynamo failed to run FX node with fake tensors: call_function <built-in function linear>(*(FakeTensor(..., device='cuda:0', size=(s0, 256), dtype=torch.bfloat16), AffineQuantizedTensor(tensor_impl=PlainAQTTensorImpl(data=FakeTensor(..., size=(1536, 256), dtype=torch.int8)... , scale=FakeTensor(..., size=(1536,), dtype=torch.bfloat16)... , zero_point=FakeTensor(..., size=(1536,), dtype=torch.int64)... , _layout=PlainLayout()), block_size=(1, 256), shape=torch.Size([1536, 256]), device=cpu, dtype=torch.bfloat16, requires_grad=False), Parameter(FakeTensor(..., device='cuda:0', size=(1536,), dtype=torch.bfloat16,
|
||||
# # requires_grad=True))), **{}): got RuntimeError('Unhandled FakeTensor Device Propagation for aten.mm.default, found two different devices cuda:0, cpu')
|
||||
# # Looks like something that will have to be looked into upstream.
|
||||
# # for linear layers, weight.tensor_impl shows cuda... but:
|
||||
# # weight.tensor_impl.{data,scale,zero_point}.device will be cpu
|
||||
@parameterized.expand([False, True])
|
||||
@unittest.skip(
|
||||
"""
|
||||
For `use_stream=False`:
|
||||
- Changing the device of AQT tensor, with `param.data = param.data.to(device)` as done in group offloading implementation
|
||||
is unsupported in TorchAO. When compiling, FakeTensor device mismatch causes failure.
|
||||
For `use_stream=True`:
|
||||
Using non-default stream requires ability to pin tensors. AQT does not seem to support this yet in TorchAO.
|
||||
"""
|
||||
)
|
||||
def test_torch_compile_with_group_offload_leaf(self, use_stream):
|
||||
# For use_stream=False:
|
||||
# If we run group offloading without compilation, we will see:
|
||||
# RuntimeError: Attempted to set the storage of a tensor on device "cpu" to a storage on different device "cuda:0". This is no longer allowed; the devices must match.
|
||||
# When running with compilation, the error ends up being different:
|
||||
# Dynamo failed to run FX node with fake tensors: call_function <built-in function linear>(*(FakeTensor(..., device='cuda:0', size=(s0, 256), dtype=torch.bfloat16), AffineQuantizedTensor(tensor_impl=PlainAQTTensorImpl(data=FakeTensor(..., size=(1536, 256), dtype=torch.int8)... , scale=FakeTensor(..., size=(1536,), dtype=torch.bfloat16)... , zero_point=FakeTensor(..., size=(1536,), dtype=torch.int64)... , _layout=PlainLayout()), block_size=(1, 256), shape=torch.Size([1536, 256]), device=cpu, dtype=torch.bfloat16, requires_grad=False), Parameter(FakeTensor(..., device='cuda:0', size=(1536,), dtype=torch.bfloat16,
|
||||
# requires_grad=True))), **{}): got RuntimeError('Unhandled FakeTensor Device Propagation for aten.mm.default, found two different devices cuda:0, cpu')
|
||||
# Looks like something that will have to be looked into upstream.
|
||||
# for linear layers, weight.tensor_impl shows cuda... but:
|
||||
# weight.tensor_impl.{data,scale,zero_point}.device will be cpu
|
||||
|
||||
# # For use_stream=True:
|
||||
# # NotImplementedError: AffineQuantizedTensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.is_pinned', overload='default')>, types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), arg_types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), kwarg_types={}
|
||||
# super()._test_torch_compile_with_group_offload_leaf(use_stream=use_stream)
|
||||
# For use_stream=True:
|
||||
# NotImplementedError: AffineQuantizedTensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.is_pinned', overload='default')>, types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), arg_types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), kwarg_types={}
|
||||
super()._test_torch_compile_with_group_offload_leaf(use_stream=use_stream)
|
||||
|
||||
|
||||
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
|
||||
|
||||
Reference in New Issue
Block a user