Compare commits

...

5 Commits

Author SHA1 Message Date
Sayak Paul
7de51b826c [lora] support more ZImage LoRAs (#12790)
up

Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>
2025-12-04 09:01:11 -10:00
Jiang
cd00ba685b fix spatial compression ratio error for AutoEncoderKLWan doing tiled encode (#12753)
fix spatial compression ratio compute error for AutoEncoderKLWan

Co-authored-by: lirui.926 <lirui.926@bytedance.com>
2025-12-04 08:57:13 -10:00
David El Malih
2842c14c5f Improve docstrings and type hints in scheduling_unipc_multistep.py (#12767)
refactor: add type hints and update docstrings for UniPCMultistepScheduler parameters and methods.
2025-12-04 10:10:54 -08:00
Sayak Paul
c318686090 Update attention_backends.md to format kernels (#12757) 2025-12-04 07:48:23 -08:00
hlky
6028613226 Z-Image-Turbo from_single_file (#12756)
* Z-Image-Turbo `from_single_file`

* compute_dtype

* -device cast
2025-12-04 20:22:48 +05:30
7 changed files with 126 additions and 38 deletions

View File

@@ -32,7 +32,7 @@ This guide will show you how to set and use the different attention backends.
The [`~ModelMixin.set_attention_backend`] method iterates through all the modules in the model and sets the appropriate attention backend to use. The attention backend setting persists until [`~ModelMixin.reset_attention_backend`] is called.
The example below demonstrates how to enable the `_flash_3_hub` implementation for FlashAttention-3 from the [kernel](https://github.com/huggingface/kernels) library, which allows you to instantly use optimized compute kernels from the Hub without requiring any setup.
The example below demonstrates how to enable the `_flash_3_hub` implementation for FlashAttention-3 from the [`kernels`](https://github.com/huggingface/kernels) library, which allows you to instantly use optimized compute kernels from the Hub without requiring any setup.
> [!NOTE]
> FlashAttention-3 is not supported for non-Hopper architectures, in which case, use FlashAttention with `set_attention_backend("flash")`.
@@ -156,4 +156,4 @@ Refer to the table below for a complete list of available attention backends and
| `_sage_qk_int8_pv_fp16_triton` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP16 PV (Triton) |
| `xformers` | [xFormers](https://github.com/facebookresearch/xformers) | Memory-efficient attention |
</details>
</details>

View File

@@ -2417,6 +2417,17 @@ def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
state_dict = {convert_key(k): v for k, v in state_dict.items()}
def normalize_out_key(k: str) -> str:
if ".to_out" in k:
return k
return re.sub(
r"\.out(?=\.(?:lora_down|lora_up)\.weight$|\.alpha$)",
".to_out.0",
k,
)
state_dict = {normalize_out_key(k): v for k, v in state_dict.items()}
has_default = any("default." in k for k in state_dict)
if has_default:
state_dict = {k.replace("default.", ""): v for k, v in state_dict.items()}

View File

@@ -49,6 +49,7 @@ from .single_file_utils import (
convert_stable_cascade_unet_single_file_to_diffusers,
convert_wan_transformer_to_diffusers,
convert_wan_vae_to_diffusers,
convert_z_image_transformer_checkpoint_to_diffusers,
create_controlnet_diffusers_config_from_ldm,
create_unet_diffusers_config_from_ldm,
create_vae_diffusers_config_from_ldm,
@@ -167,6 +168,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
"checkpoint_mapping_fn": convert_flux2_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
"ZImageTransformer2DModel": {
"checkpoint_mapping_fn": convert_z_image_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
}

View File

@@ -120,6 +120,7 @@ CHECKPOINT_KEY_NAMES = {
"hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
"instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight",
"lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"],
"z-image-turbo": "cap_embedder.0.weight",
"sana": [
"blocks.0.cross_attn.q_linear.weight",
"blocks.0.cross_attn.q_linear.bias",
@@ -218,6 +219,7 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
"cosmos-2.0-t2i-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-14B-Text2Image"},
"cosmos-2.0-v2w-2B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-2B-Video2World"},
"cosmos-2.0-v2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-14B-Video2World"},
"z-image-turbo": {"pretrained_model_name_or_path": "Tongyi-MAI/Z-Image-Turbo"},
}
# Use to configure model sample size when original config is provided
@@ -721,6 +723,12 @@ def infer_diffusers_model_type(checkpoint):
):
model_type = "instruct-pix2pix"
elif (
CHECKPOINT_KEY_NAMES["z-image-turbo"] in checkpoint
and checkpoint[CHECKPOINT_KEY_NAMES["z-image-turbo"]].shape[0] == 2560
):
model_type = "z-image-turbo"
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["lumina2"]):
model_type = "lumina2"
@@ -3824,3 +3832,56 @@ def convert_flux2_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
handler_fn_inplace(key, converted_state_dict)
return converted_state_dict
def convert_z_image_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
Z_IMAGE_KEYS_RENAME_DICT = {
"final_layer.": "all_final_layer.2-1.",
"x_embedder.": "all_x_embedder.2-1.",
".attention.out.bias": ".attention.to_out.0.bias",
".attention.k_norm.weight": ".attention.norm_k.weight",
".attention.q_norm.weight": ".attention.norm_q.weight",
".attention.out.weight": ".attention.to_out.0.weight",
}
def convert_z_image_fused_attention(key: str, state_dict: dict[str, object]) -> None:
if ".attention.qkv.weight" not in key:
return
fused_qkv_weight = state_dict.pop(key)
to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0)
new_q_name = key.replace(".attention.qkv.weight", ".attention.to_q.weight")
new_k_name = key.replace(".attention.qkv.weight", ".attention.to_k.weight")
new_v_name = key.replace(".attention.qkv.weight", ".attention.to_v.weight")
state_dict[new_q_name] = to_q_weight
state_dict[new_k_name] = to_k_weight
state_dict[new_v_name] = to_v_weight
return
TRANSFORMER_SPECIAL_KEYS_REMAP = {
".attention.qkv.weight": convert_z_image_fused_attention,
}
def update_state_dict(state_dict: dict[str, object], old_key: str, new_key: str) -> None:
state_dict[new_key] = state_dict.pop(old_key)
converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
# Handle single file --> diffusers key remapping via the remap dict
for key in list(converted_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in Z_IMAGE_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict(converted_state_dict, key, new_key)
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
# special_keys_remap
for key in list(converted_state_dict.keys()):
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, converted_state_dict)
return converted_state_dict

View File

@@ -1259,14 +1259,20 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
`torch.Tensor`:
The latent representation of the encoded videos.
"""
_, _, num_frames, height, width = x.shape
latent_height = height // self.spatial_compression_ratio
latent_width = width // self.spatial_compression_ratio
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
_, _, num_frames, height, width = x.shape
encode_spatial_compression_ratio = self.spatial_compression_ratio
if self.config.patch_size is not None:
assert encode_spatial_compression_ratio % self.config.patch_size == 0
encode_spatial_compression_ratio = self.spatial_compression_ratio // self.config.patch_size
latent_height = height // encode_spatial_compression_ratio
latent_width = width // encode_spatial_compression_ratio
tile_latent_min_height = self.tile_sample_min_height // encode_spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // encode_spatial_compression_ratio
tile_latent_stride_height = self.tile_sample_stride_height // encode_spatial_compression_ratio
tile_latent_stride_width = self.tile_sample_stride_width // encode_spatial_compression_ratio
blend_height = tile_latent_min_height - tile_latent_stride_height
blend_width = tile_latent_min_width - tile_latent_stride_width

View File

@@ -63,8 +63,11 @@ class TimestepEmbedder(nn.Module):
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
weight_dtype = self.mlp[0].weight.dtype
compute_dtype = getattr(self.mlp[0], "compute_dtype", None)
if weight_dtype.is_floating_point:
t_freq = t_freq.to(weight_dtype)
elif compute_dtype is not None:
t_freq = t_freq.to(compute_dtype)
t_emb = self.mlp(t_freq)
return t_emb

View File

@@ -77,7 +77,7 @@ def betas_for_alpha_bar(
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas):
def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
"""
Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
@@ -127,19 +127,19 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
The starting `beta` value of inference.
beta_end (`float`, defaults to 0.02):
The final `beta` value.
beta_schedule (`str`, defaults to `"linear"`):
beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`):
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
trained_betas (`np.ndarray`, *optional*):
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
solver_order (`int`, default `2`):
solver_order (`int`, defaults to `2`):
The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1`
due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for
unconditional sampling.
prediction_type (`str`, defaults to `epsilon`, *optional*):
prediction_type (`"epsilon"`, `"sample"`, `"v_prediction"`, or `"flow_prediction"`, defaults to `"epsilon"`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
Video](https://huggingface.co/papers/2210.02303) paper).
`sample` (directly predicts the noisy sample`), `v_prediction` (see section 2.4 of [Imagen
Video](https://huggingface.co/papers/2210.02303) paper), or `flow_prediction`.
thresholding (`bool`, defaults to `False`):
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
as Stable Diffusion.
@@ -149,7 +149,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`.
predict_x0 (`bool`, defaults to `True`):
Whether to use the updating algorithm on the predicted x0.
solver_type (`str`, default `bh2`):
solver_type (`"bh1"` or `"bh2"`, defaults to `"bh2"`):
Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2`
otherwise.
lower_order_final (`bool`, default `True`):
@@ -171,12 +171,12 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
use_flow_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use flow sigmas for step sizes in the noise schedule during the sampling process.
timestep_spacing (`str`, defaults to `"linspace"`):
timestep_spacing (`"linspace"`, `"leading"`, or `"trailing"`, defaults to `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
steps_offset (`int`, defaults to 0):
An offset added to the inference steps, as required by some model families.
final_sigmas_type (`str`, defaults to `"zero"`):
final_sigmas_type (`"zero"` or `"sigma_min"`, defaults to `"zero"`):
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
rescale_betas_zero_snr (`bool`, defaults to `False`):
@@ -194,30 +194,30 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
solver_order: int = 2,
prediction_type: str = "epsilon",
prediction_type: Literal["epsilon", "sample", "v_prediction", "flow_prediction"] = "epsilon",
thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
sample_max_value: float = 1.0,
predict_x0: bool = True,
solver_type: str = "bh2",
solver_type: Literal["bh1", "bh2"] = "bh2",
lower_order_final: bool = True,
disable_corrector: List[int] = [],
solver_p: SchedulerMixin = None,
solver_p: Optional[SchedulerMixin] = None,
use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False,
use_beta_sigmas: Optional[bool] = False,
use_flow_sigmas: Optional[bool] = False,
flow_shift: Optional[float] = 1.0,
timestep_spacing: str = "linspace",
timestep_spacing: Literal["linspace", "leading", "trailing"] = "linspace",
steps_offset: int = 0,
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
final_sigmas_type: Optional[Literal["zero", "sigma_min"]] = "zero",
rescale_betas_zero_snr: bool = False,
use_dynamic_shifting: bool = False,
time_shift_type: str = "exponential",
):
time_shift_type: Literal["exponential"] = "exponential",
) -> None:
if self.config.use_beta_sigmas and not is_scipy_available():
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
@@ -279,21 +279,21 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
@property
def step_index(self):
def step_index(self) -> Optional[int]:
"""
The index counter for current timestep. It will increase 1 after each scheduler step.
"""
return self._step_index
@property
def begin_index(self):
def begin_index(self) -> Optional[int]:
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
def set_begin_index(self, begin_index: int = 0) -> None:
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
@@ -304,8 +304,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
self._begin_index = begin_index
def set_timesteps(
self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None
):
self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None, mu: Optional[float] = None
) -> None:
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -314,6 +314,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
mu (`float`, *optional*):
Optional mu parameter for dynamic shifting when using exponential time shift type.
"""
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
if mu is not None:
@@ -475,7 +477,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
return sample
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
"""
Convert sigma values to corresponding timestep values through interpolation.
@@ -512,7 +514,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
return t
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
def _sigma_to_alpha_sigma_t(self, sigma):
def _sigma_to_alpha_sigma_t(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Convert sigma values to alpha_t and sigma_t values.
@@ -534,7 +536,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
return alpha_t, sigma_t
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
"""
Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
Models](https://huggingface.co/papers/2206.00364).
@@ -1030,7 +1032,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
return step_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
def _init_step_index(self, timestep):
def _init_step_index(self, timestep: Union[int, torch.Tensor]) -> None:
"""
Initialize the step_index counter for the scheduler.
@@ -1060,11 +1062,11 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
Args:
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`int`):
timestep (`int` or `torch.Tensor`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
return_dict (`bool`):
return_dict (`bool`, defaults to `True`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
Returns:
@@ -1192,5 +1194,5 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
noisy_samples = alpha_t * original_samples + sigma_t * noise
return noisy_samples
def __len__(self):
def __len__(self) -> int:
return self.config.num_train_timesteps