mirror of
https://github.com/huggingface/diffusers.git
synced 2026-03-25 09:58:17 +08:00
Compare commits
12 Commits
fix-torcha
...
klein-lora
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7fffc78d4d | ||
|
|
da6718f080 | ||
|
|
832676d35e | ||
|
|
81f942d878 | ||
|
|
7bbd96da5d | ||
|
|
13f8697ce5 | ||
|
|
62777fa819 | ||
|
|
f1fd515257 | ||
|
|
afdda57f61 | ||
|
|
5fc2bd2c8f | ||
|
|
6350a7690a | ||
|
|
ec739c0441 |
@@ -22,7 +22,7 @@ from typing import Set
|
||||
import safetensors.torch
|
||||
import torch
|
||||
|
||||
from ..utils import get_logger, is_accelerate_available, is_torchao_available
|
||||
from ..utils import get_logger, is_accelerate_available
|
||||
from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
|
||||
from .hooks import HookRegistry, ModelHook
|
||||
|
||||
@@ -35,54 +35,6 @@ if is_accelerate_available():
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def _is_torchao_tensor(tensor: torch.Tensor) -> bool:
|
||||
if not is_torchao_available():
|
||||
return False
|
||||
from torchao.utils import TorchAOBaseTensor
|
||||
|
||||
return isinstance(tensor, TorchAOBaseTensor)
|
||||
|
||||
|
||||
def _get_torchao_inner_tensor_names(tensor: torch.Tensor) -> list[str]:
|
||||
"""Get names of all internal tensor data attributes from a TorchAO tensor."""
|
||||
cls = type(tensor)
|
||||
names = list(getattr(cls, "tensor_data_names", []))
|
||||
for attr_name in getattr(cls, "optional_tensor_data_names", []):
|
||||
if getattr(tensor, attr_name, None) is not None:
|
||||
names.append(attr_name)
|
||||
return names
|
||||
|
||||
|
||||
def _swap_torchao_tensor(param: torch.Tensor, source: torch.Tensor) -> None:
|
||||
"""Move a TorchAO parameter to the device of `source` via `swap_tensors`.
|
||||
|
||||
`param.data = source` does not work for `_make_wrapper_subclass` tensors because the `.data` setter only replaces
|
||||
the outer wrapper storage while leaving the subclass's internal attributes (e.g. `.qdata`, `.scale`) on the
|
||||
original device. `swap_tensors` swaps the full tensor contents in-place, preserving the parameter's identity so
|
||||
that any dict keyed by `id(param)` remains valid.
|
||||
|
||||
Refer to https://github.com/huggingface/diffusers/pull/13276#discussion_r2944471548 for the full discussion.
|
||||
"""
|
||||
torch.utils.swap_tensors(param, source)
|
||||
|
||||
|
||||
def _restore_torchao_tensor(param: torch.Tensor, source: torch.Tensor) -> None:
|
||||
"""Restore internal tensor data of a TorchAO parameter from `source` without mutating `source`.
|
||||
|
||||
Unlike `_swap_torchao_tensor` this copies attribute references one-by-one via `setattr` so that `source` is **not**
|
||||
modified. Use this when `source` is a cached tensor that must remain unchanged (e.g. a pinned CPU copy in
|
||||
`cpu_param_dict`).
|
||||
"""
|
||||
for attr_name in _get_torchao_inner_tensor_names(source):
|
||||
setattr(param, attr_name, getattr(source, attr_name))
|
||||
|
||||
|
||||
def _record_stream_torchao_tensor(param: torch.Tensor, stream) -> None:
|
||||
"""Record stream for all internal tensors of a TorchAO parameter."""
|
||||
for attr_name in _get_torchao_inner_tensor_names(param):
|
||||
getattr(param, attr_name).record_stream(stream)
|
||||
|
||||
|
||||
# fmt: off
|
||||
_GROUP_OFFLOADING = "group_offloading"
|
||||
_LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
|
||||
@@ -205,16 +157,9 @@ class ModuleGroup:
|
||||
pinned_dict = None
|
||||
|
||||
def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream):
|
||||
moved = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
if _is_torchao_tensor(tensor):
|
||||
_swap_torchao_tensor(tensor, moved)
|
||||
else:
|
||||
tensor.data = moved
|
||||
tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
if self.record_stream:
|
||||
if _is_torchao_tensor(tensor):
|
||||
_record_stream_torchao_tensor(tensor, default_stream)
|
||||
else:
|
||||
tensor.data.record_stream(default_stream)
|
||||
tensor.data.record_stream(default_stream)
|
||||
|
||||
def _process_tensors_from_modules(self, pinned_memory=None, default_stream=None):
|
||||
for group_module in self.modules:
|
||||
@@ -300,35 +245,18 @@ class ModuleGroup:
|
||||
|
||||
for group_module in self.modules:
|
||||
for param in group_module.parameters():
|
||||
if _is_torchao_tensor(param):
|
||||
_restore_torchao_tensor(param, self.cpu_param_dict[param])
|
||||
else:
|
||||
param.data = self.cpu_param_dict[param]
|
||||
for param in self.parameters:
|
||||
if _is_torchao_tensor(param):
|
||||
_restore_torchao_tensor(param, self.cpu_param_dict[param])
|
||||
else:
|
||||
param.data = self.cpu_param_dict[param]
|
||||
for param in self.parameters:
|
||||
param.data = self.cpu_param_dict[param]
|
||||
for buffer in self.buffers:
|
||||
if _is_torchao_tensor(buffer):
|
||||
_restore_torchao_tensor(buffer, self.cpu_param_dict[buffer])
|
||||
else:
|
||||
buffer.data = self.cpu_param_dict[buffer]
|
||||
buffer.data = self.cpu_param_dict[buffer]
|
||||
else:
|
||||
for group_module in self.modules:
|
||||
group_module.to(self.offload_device, non_blocking=False)
|
||||
for param in self.parameters:
|
||||
if _is_torchao_tensor(param):
|
||||
moved = param.data.to(self.offload_device, non_blocking=False)
|
||||
_swap_torchao_tensor(param, moved)
|
||||
else:
|
||||
param.data = param.data.to(self.offload_device, non_blocking=False)
|
||||
param.data = param.data.to(self.offload_device, non_blocking=False)
|
||||
for buffer in self.buffers:
|
||||
if _is_torchao_tensor(buffer):
|
||||
moved = buffer.data.to(self.offload_device, non_blocking=False)
|
||||
_swap_torchao_tensor(buffer, moved)
|
||||
else:
|
||||
buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
|
||||
buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
|
||||
|
||||
@torch.compiler.disable()
|
||||
def onload_(self):
|
||||
|
||||
@@ -2443,6 +2443,191 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def _convert_kohya_flux2_lora_to_diffusers(state_dict):
|
||||
def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
|
||||
if sds_key + ".lora_down.weight" not in sds_sd:
|
||||
return
|
||||
down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
|
||||
|
||||
# scale weight by alpha and dim
|
||||
rank = down_weight.shape[0]
|
||||
default_alpha = torch.tensor(rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False)
|
||||
alpha = sds_sd.pop(sds_key + ".alpha", default_alpha).item()
|
||||
scale = alpha / rank
|
||||
|
||||
scale_down = scale
|
||||
scale_up = 1.0
|
||||
while scale_down * 2 < scale_up:
|
||||
scale_down *= 2
|
||||
scale_up /= 2
|
||||
|
||||
ait_sd[ait_key + ".lora_A.weight"] = down_weight * scale_down
|
||||
ait_sd[ait_key + ".lora_B.weight"] = sds_sd.pop(sds_key + ".lora_up.weight") * scale_up
|
||||
|
||||
def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
|
||||
if sds_key + ".lora_down.weight" not in sds_sd:
|
||||
return
|
||||
down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
|
||||
up_weight = sds_sd.pop(sds_key + ".lora_up.weight")
|
||||
sd_lora_rank = down_weight.shape[0]
|
||||
|
||||
default_alpha = torch.tensor(
|
||||
sd_lora_rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False
|
||||
)
|
||||
alpha = sds_sd.pop(sds_key + ".alpha", default_alpha)
|
||||
scale = alpha / sd_lora_rank
|
||||
|
||||
scale_down = scale
|
||||
scale_up = 1.0
|
||||
while scale_down * 2 < scale_up:
|
||||
scale_down *= 2
|
||||
scale_up /= 2
|
||||
|
||||
down_weight = down_weight * scale_down
|
||||
up_weight = up_weight * scale_up
|
||||
|
||||
num_splits = len(ait_keys)
|
||||
if dims is None:
|
||||
dims = [up_weight.shape[0] // num_splits] * num_splits
|
||||
else:
|
||||
assert sum(dims) == up_weight.shape[0]
|
||||
|
||||
# check if upweight is sparse
|
||||
is_sparse = False
|
||||
if sd_lora_rank % num_splits == 0:
|
||||
ait_rank = sd_lora_rank // num_splits
|
||||
is_sparse = True
|
||||
i = 0
|
||||
for j in range(len(dims)):
|
||||
for k in range(len(dims)):
|
||||
if j == k:
|
||||
continue
|
||||
is_sparse = is_sparse and torch.all(
|
||||
up_weight[i : i + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0
|
||||
)
|
||||
i += dims[j]
|
||||
if is_sparse:
|
||||
logger.info(f"weight is sparse: {sds_key}")
|
||||
|
||||
ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
|
||||
ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
|
||||
if not is_sparse:
|
||||
ait_sd.update(dict.fromkeys(ait_down_keys, down_weight))
|
||||
ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416
|
||||
else:
|
||||
ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))}) # noqa: C416
|
||||
i = 0
|
||||
for j in range(len(dims)):
|
||||
ait_sd[ait_up_keys[j]] = up_weight[i : i + dims[j], j * ait_rank : (j + 1) * ait_rank].contiguous()
|
||||
i += dims[j]
|
||||
|
||||
# Detect number of blocks from keys
|
||||
num_double_layers = 0
|
||||
num_single_layers = 0
|
||||
for key in state_dict.keys():
|
||||
if key.startswith("lora_unet_double_blocks_"):
|
||||
block_idx = int(key.split("_")[4])
|
||||
num_double_layers = max(num_double_layers, block_idx + 1)
|
||||
elif key.startswith("lora_unet_single_blocks_"):
|
||||
block_idx = int(key.split("_")[4])
|
||||
num_single_layers = max(num_single_layers, block_idx + 1)
|
||||
|
||||
ait_sd = {}
|
||||
|
||||
for i in range(num_double_layers):
|
||||
# Attention projections
|
||||
_convert_to_ai_toolkit(
|
||||
state_dict,
|
||||
ait_sd,
|
||||
f"lora_unet_double_blocks_{i}_img_attn_proj",
|
||||
f"transformer.transformer_blocks.{i}.attn.to_out.0",
|
||||
)
|
||||
_convert_to_ai_toolkit_cat(
|
||||
state_dict,
|
||||
ait_sd,
|
||||
f"lora_unet_double_blocks_{i}_img_attn_qkv",
|
||||
[
|
||||
f"transformer.transformer_blocks.{i}.attn.to_q",
|
||||
f"transformer.transformer_blocks.{i}.attn.to_k",
|
||||
f"transformer.transformer_blocks.{i}.attn.to_v",
|
||||
],
|
||||
)
|
||||
_convert_to_ai_toolkit(
|
||||
state_dict,
|
||||
ait_sd,
|
||||
f"lora_unet_double_blocks_{i}_txt_attn_proj",
|
||||
f"transformer.transformer_blocks.{i}.attn.to_add_out",
|
||||
)
|
||||
_convert_to_ai_toolkit_cat(
|
||||
state_dict,
|
||||
ait_sd,
|
||||
f"lora_unet_double_blocks_{i}_txt_attn_qkv",
|
||||
[
|
||||
f"transformer.transformer_blocks.{i}.attn.add_q_proj",
|
||||
f"transformer.transformer_blocks.{i}.attn.add_k_proj",
|
||||
f"transformer.transformer_blocks.{i}.attn.add_v_proj",
|
||||
],
|
||||
)
|
||||
# MLP layers (Flux2 uses ff.linear_in/linear_out)
|
||||
_convert_to_ai_toolkit(
|
||||
state_dict,
|
||||
ait_sd,
|
||||
f"lora_unet_double_blocks_{i}_img_mlp_0",
|
||||
f"transformer.transformer_blocks.{i}.ff.linear_in",
|
||||
)
|
||||
_convert_to_ai_toolkit(
|
||||
state_dict,
|
||||
ait_sd,
|
||||
f"lora_unet_double_blocks_{i}_img_mlp_2",
|
||||
f"transformer.transformer_blocks.{i}.ff.linear_out",
|
||||
)
|
||||
_convert_to_ai_toolkit(
|
||||
state_dict,
|
||||
ait_sd,
|
||||
f"lora_unet_double_blocks_{i}_txt_mlp_0",
|
||||
f"transformer.transformer_blocks.{i}.ff_context.linear_in",
|
||||
)
|
||||
_convert_to_ai_toolkit(
|
||||
state_dict,
|
||||
ait_sd,
|
||||
f"lora_unet_double_blocks_{i}_txt_mlp_2",
|
||||
f"transformer.transformer_blocks.{i}.ff_context.linear_out",
|
||||
)
|
||||
|
||||
for i in range(num_single_layers):
|
||||
# Single blocks: linear1 -> attn.to_qkv_mlp_proj (fused, no split needed)
|
||||
_convert_to_ai_toolkit(
|
||||
state_dict,
|
||||
ait_sd,
|
||||
f"lora_unet_single_blocks_{i}_linear1",
|
||||
f"transformer.single_transformer_blocks.{i}.attn.to_qkv_mlp_proj",
|
||||
)
|
||||
# Single blocks: linear2 -> attn.to_out
|
||||
_convert_to_ai_toolkit(
|
||||
state_dict,
|
||||
ait_sd,
|
||||
f"lora_unet_single_blocks_{i}_linear2",
|
||||
f"transformer.single_transformer_blocks.{i}.attn.to_out",
|
||||
)
|
||||
|
||||
# Handle optional extra keys
|
||||
extra_mappings = {
|
||||
"lora_unet_img_in": "transformer.x_embedder",
|
||||
"lora_unet_txt_in": "transformer.context_embedder",
|
||||
"lora_unet_time_in_in_layer": "transformer.time_guidance_embed.timestep_embedder.linear_1",
|
||||
"lora_unet_time_in_out_layer": "transformer.time_guidance_embed.timestep_embedder.linear_2",
|
||||
"lora_unet_final_layer_linear": "transformer.proj_out",
|
||||
}
|
||||
for sds_key, ait_key in extra_mappings.items():
|
||||
_convert_to_ai_toolkit(state_dict, ait_sd, sds_key, ait_key)
|
||||
|
||||
remaining_keys = list(state_dict.keys())
|
||||
if remaining_keys:
|
||||
logger.warning(f"Unsupported keys for Kohya Flux2 LoRA conversion: {remaining_keys}")
|
||||
|
||||
return ait_sd
|
||||
|
||||
|
||||
def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
|
||||
"""
|
||||
Convert non-diffusers ZImage LoRA state dict to diffusers format.
|
||||
|
||||
@@ -43,6 +43,7 @@ from .lora_conversion_utils import (
|
||||
_convert_bfl_flux_control_lora_to_diffusers,
|
||||
_convert_fal_kontext_lora_to_diffusers,
|
||||
_convert_hunyuan_video_lora_to_diffusers,
|
||||
_convert_kohya_flux2_lora_to_diffusers,
|
||||
_convert_kohya_flux_lora_to_diffusers,
|
||||
_convert_musubi_wan_lora_to_diffusers,
|
||||
_convert_non_diffusers_flux2_lora_to_diffusers,
|
||||
@@ -5673,6 +5674,13 @@ class Flux2LoraLoaderMixin(LoraBaseMixin):
|
||||
logger.warning(warn_msg)
|
||||
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
||||
|
||||
is_kohya = any(".lora_down.weight" in k for k in state_dict)
|
||||
if is_kohya:
|
||||
state_dict = _convert_kohya_flux2_lora_to_diffusers(state_dict)
|
||||
# Kohya already takes care of scaling the LoRA parameters with alpha.
|
||||
out = (state_dict, metadata) if return_lora_metadata else state_dict
|
||||
return out
|
||||
|
||||
is_peft_format = any(k.startswith("base_model.model.") for k in state_dict)
|
||||
if is_peft_format:
|
||||
state_dict = {k.replace("base_model.model.", "diffusion_model."): v for k, v in state_dict.items()}
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
@@ -44,33 +45,13 @@ from .unet_loader_utils import _maybe_expand_lora_scales
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_SET_ADAPTER_SCALE_FN_MAPPING = {
|
||||
"UNet2DConditionModel": _maybe_expand_lora_scales,
|
||||
"UNetMotionModel": _maybe_expand_lora_scales,
|
||||
"SD3Transformer2DModel": lambda model_cls, weights: weights,
|
||||
"FluxTransformer2DModel": lambda model_cls, weights: weights,
|
||||
"CogVideoXTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"ConsisIDTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"HeliosTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"MochiTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"HunyuanVideoTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"LTXVideoTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"SanaTransformer2DModel": lambda model_cls, weights: weights,
|
||||
"AuraFlowTransformer2DModel": lambda model_cls, weights: weights,
|
||||
"Lumina2Transformer2DModel": lambda model_cls, weights: weights,
|
||||
"WanTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"CogView4Transformer2DModel": lambda model_cls, weights: weights,
|
||||
"HiDreamImageTransformer2DModel": lambda model_cls, weights: weights,
|
||||
"HunyuanVideoFramepackTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"WanVACETransformer3DModel": lambda model_cls, weights: weights,
|
||||
"ChromaTransformer2DModel": lambda model_cls, weights: weights,
|
||||
"ChronoEditTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"QwenImageTransformer2DModel": lambda model_cls, weights: weights,
|
||||
"Flux2Transformer2DModel": lambda model_cls, weights: weights,
|
||||
"ZImageTransformer2DModel": lambda model_cls, weights: weights,
|
||||
"LTX2VideoTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"LTX2TextConnectors": lambda model_cls, weights: weights,
|
||||
}
|
||||
_SET_ADAPTER_SCALE_FN_MAPPING = defaultdict(
|
||||
lambda: (lambda model_cls, weights: weights),
|
||||
{
|
||||
"UNet2DConditionModel": _maybe_expand_lora_scales,
|
||||
"UNetMotionModel": _maybe_expand_lora_scales,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class PeftAdapterMixin:
|
||||
|
||||
@@ -87,7 +87,14 @@ class HunyuanImageRefinerRMS_norm(nn.Module):
|
||||
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
|
||||
|
||||
def forward(self, x):
|
||||
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
|
||||
needs_fp32_normalize = x.dtype in (torch.float16, torch.bfloat16) or any(
|
||||
t in str(x.dtype) for t in ("float4_", "float8_")
|
||||
)
|
||||
normalized = F.normalize(x.float() if needs_fp32_normalize else x, dim=(1 if self.channel_first else -1)).to(
|
||||
x.dtype
|
||||
)
|
||||
|
||||
return normalized * self.scale * self.gamma + self.bias
|
||||
|
||||
|
||||
class HunyuanImageRefinerAttnBlock(nn.Module):
|
||||
|
||||
@@ -87,7 +87,14 @@ class HunyuanVideo15RMS_norm(nn.Module):
|
||||
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
|
||||
|
||||
def forward(self, x):
|
||||
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
|
||||
needs_fp32_normalize = x.dtype in (torch.float16, torch.bfloat16) or any(
|
||||
t in str(x.dtype) for t in ("float4_", "float8_")
|
||||
)
|
||||
normalized = F.normalize(x.float() if needs_fp32_normalize else x, dim=(1 if self.channel_first else -1)).to(
|
||||
x.dtype
|
||||
)
|
||||
|
||||
return normalized * self.scale * self.gamma + self.bias
|
||||
|
||||
|
||||
class HunyuanVideo15AttnBlock(nn.Module):
|
||||
|
||||
@@ -105,7 +105,14 @@ class QwenImageRMS_norm(nn.Module):
|
||||
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
|
||||
|
||||
def forward(self, x):
|
||||
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
|
||||
needs_fp32_normalize = x.dtype in (torch.float16, torch.bfloat16) or any(
|
||||
t in str(x.dtype) for t in ("float4_", "float8_")
|
||||
)
|
||||
normalized = F.normalize(x.float() if needs_fp32_normalize else x, dim=(1 if self.channel_first else -1)).to(
|
||||
x.dtype
|
||||
)
|
||||
|
||||
return normalized * self.scale * self.gamma + self.bias
|
||||
|
||||
|
||||
class QwenImageUpsample(nn.Upsample):
|
||||
|
||||
@@ -196,7 +196,14 @@ class WanRMS_norm(nn.Module):
|
||||
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
|
||||
|
||||
def forward(self, x):
|
||||
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
|
||||
needs_fp32_normalize = x.dtype in (torch.float16, torch.bfloat16) or any(
|
||||
t in str(x.dtype) for t in ("float4_", "float8_")
|
||||
)
|
||||
normalized = F.normalize(x.float() if needs_fp32_normalize else x, dim=(1 if self.channel_first else -1)).to(
|
||||
x.dtype
|
||||
)
|
||||
|
||||
return normalized * self.scale * self.gamma + self.bias
|
||||
|
||||
|
||||
class WanUpsample(nn.Upsample):
|
||||
|
||||
@@ -933,6 +933,7 @@ class QwenImageTransformer2DModel(
|
||||
batch_size, image_seq_len = hidden_states.shape[:2]
|
||||
image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device)
|
||||
joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1)
|
||||
joint_attention_mask = joint_attention_mask[:, None, None, :]
|
||||
block_attention_kwargs["attention_mask"] = joint_attention_mask
|
||||
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
|
||||
@@ -788,9 +788,12 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
|
||||
freqs_cis = pad_sequence(freqs_cis, batch_first=True, padding_value=0.0)[:, : feats.shape[1]]
|
||||
|
||||
# Attention mask
|
||||
attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(item_seqlens):
|
||||
attn_mask[i, :seq_len] = 1
|
||||
if all(seq == max_seqlen for seq in item_seqlens):
|
||||
attn_mask = None
|
||||
else:
|
||||
attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(item_seqlens):
|
||||
attn_mask[i, :seq_len] = 1
|
||||
|
||||
# Noise mask
|
||||
noise_mask_tensor = None
|
||||
@@ -871,9 +874,12 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
|
||||
unified_freqs = pad_sequence(unified_freqs, batch_first=True, padding_value=0.0)
|
||||
|
||||
# Attention mask
|
||||
attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(unified_seqlens):
|
||||
attn_mask[i, :seq_len] = 1
|
||||
if all(seq == max_seqlen for seq in unified_seqlens):
|
||||
attn_mask = None
|
||||
else:
|
||||
attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(unified_seqlens):
|
||||
attn_mask[i, :seq_len] = 1
|
||||
|
||||
# Noise mask
|
||||
noise_mask_tensor = None
|
||||
|
||||
@@ -16,22 +16,29 @@ from typing import Callable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision
|
||||
import torchvision.transforms
|
||||
import torchvision.transforms.functional
|
||||
from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...image_processor import PipelineImageInput
|
||||
from ...models import AutoencoderKLWan, CosmosTransformer3DModel
|
||||
from ...schedulers import UniPCMultistepScheduler
|
||||
from ...utils import is_cosmos_guardrail_available, is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils import (
|
||||
is_cosmos_guardrail_available,
|
||||
is_torch_xla_available,
|
||||
is_torchvision_available,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .pipeline_output import CosmosPipelineOutput
|
||||
|
||||
|
||||
if is_torchvision_available():
|
||||
import torchvision.transforms.functional
|
||||
|
||||
|
||||
if is_cosmos_guardrail_available():
|
||||
from cosmos_guardrail import CosmosSafetyChecker
|
||||
else:
|
||||
|
||||
@@ -29,6 +29,7 @@ from numpy.linalg import norm
|
||||
from packaging import version
|
||||
|
||||
from .constants import DIFFUSERS_REQUEST_TIMEOUT
|
||||
from .deprecation_utils import deprecate
|
||||
from .import_utils import (
|
||||
BACKENDS_MAPPING,
|
||||
is_accelerate_available,
|
||||
@@ -67,9 +68,11 @@ else:
|
||||
global_rng = random.Random()
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger.warning(
|
||||
"diffusers.utils.testing_utils' is deprecated and will be removed in a future version. "
|
||||
"Determinism and device backend utilities have been moved to `diffusers.utils.torch_utils`. "
|
||||
deprecate(
|
||||
"diffusers.utils.testing_utils",
|
||||
"1.0.0",
|
||||
"diffusers.utils.testing_utils is deprecated and will be removed in a future version. "
|
||||
"Determinism and device backend utilities have been moved to `diffusers.utils.torch_utils`. ",
|
||||
)
|
||||
_required_peft_version = is_peft_available() and version.parse(
|
||||
version.parse(importlib.metadata.version("peft")).base_version
|
||||
|
||||
@@ -481,6 +481,8 @@ class LoraHotSwappingForModelTesterMixin:
|
||||
# ensure that enable_lora_hotswap is called before loading the first adapter
|
||||
import logging
|
||||
|
||||
from diffusers.utils import logging as diffusers_logging
|
||||
|
||||
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
@@ -488,21 +490,31 @@ class LoraHotSwappingForModelTesterMixin:
|
||||
msg = (
|
||||
"It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
|
||||
)
|
||||
with caplog.at_level(logging.WARNING):
|
||||
model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
|
||||
assert any(msg in record.message for record in caplog.records)
|
||||
diffusers_logging.enable_propagation()
|
||||
try:
|
||||
with caplog.at_level(logging.WARNING):
|
||||
model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
|
||||
assert any(msg in record.message for record in caplog.records)
|
||||
finally:
|
||||
diffusers_logging.disable_propagation()
|
||||
|
||||
def test_enable_lora_hotswap_called_after_adapter_added_ignore(self, caplog):
|
||||
# check possibility to ignore the error/warning
|
||||
import logging
|
||||
|
||||
from diffusers.utils import logging as diffusers_logging
|
||||
|
||||
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.add_adapter(lora_config)
|
||||
with caplog.at_level(logging.WARNING):
|
||||
model.enable_lora_hotswap(target_rank=32, check_compiled="ignore")
|
||||
assert len(caplog.records) == 0
|
||||
diffusers_logging.enable_propagation()
|
||||
try:
|
||||
with caplog.at_level(logging.WARNING):
|
||||
model.enable_lora_hotswap(target_rank=32, check_compiled="ignore")
|
||||
assert len(caplog.records) == 0
|
||||
finally:
|
||||
diffusers_logging.disable_propagation()
|
||||
|
||||
def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self):
|
||||
# check that wrong argument value raises an error
|
||||
@@ -518,20 +530,26 @@ class LoraHotSwappingForModelTesterMixin:
|
||||
# check the error and log
|
||||
import logging
|
||||
|
||||
from diffusers.utils import logging as diffusers_logging
|
||||
|
||||
# at the moment, PEFT requires the 2nd adapter to target the same or a subset of layers
|
||||
target_modules0 = ["to_q"]
|
||||
target_modules1 = ["to_q", "to_k"]
|
||||
with pytest.raises(RuntimeError): # peft raises RuntimeError
|
||||
with caplog.at_level(logging.ERROR):
|
||||
self._check_model_hotswap(
|
||||
tmp_path,
|
||||
do_compile=True,
|
||||
rank0=8,
|
||||
rank1=8,
|
||||
target_modules0=target_modules0,
|
||||
target_modules1=target_modules1,
|
||||
)
|
||||
assert any("Hotswapping adapter0 was unsuccessful" in record.message for record in caplog.records)
|
||||
diffusers_logging.enable_propagation()
|
||||
try:
|
||||
with pytest.raises(RuntimeError): # peft raises RuntimeError
|
||||
with caplog.at_level(logging.ERROR):
|
||||
self._check_model_hotswap(
|
||||
tmp_path,
|
||||
do_compile=True,
|
||||
rank0=8,
|
||||
rank1=8,
|
||||
target_modules0=target_modules0,
|
||||
target_modules1=target_modules1,
|
||||
)
|
||||
assert any("Hotswapping adapter0 was unsuccessful" in record.message for record in caplog.records)
|
||||
finally:
|
||||
diffusers_logging.disable_propagation()
|
||||
|
||||
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
|
||||
@require_torch_version_greater("2.7.1")
|
||||
|
||||
@@ -200,7 +200,6 @@ class ContextParallelTesterMixin:
|
||||
f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
|
||||
)
|
||||
|
||||
@pytest.mark.xfail(reason="Context parallel may not support batch_size > 1")
|
||||
@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
|
||||
def test_context_parallel_batch_inputs(self, cp_type):
|
||||
self.test_context_parallel_inference(cp_type, batch_size=2)
|
||||
|
||||
@@ -286,6 +286,14 @@ class TestQwenImageTransformerLoRA(QwenImageTransformerTesterConfig, LoraTesterM
|
||||
class TestQwenImageTransformerLoRAHotSwap(QwenImageTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
|
||||
"""LoRA hot-swapping tests for QwenImage Transformer."""
|
||||
|
||||
@pytest.mark.xfail(True, reason="Recompilation issues.", strict=True)
|
||||
def test_hotswapping_compiled_model_linear(self):
|
||||
super().test_hotswapping_compiled_model_linear()
|
||||
|
||||
@pytest.mark.xfail(True, reason="Recompilation issues.", strict=True)
|
||||
def test_hotswapping_compiled_model_both_linear_and_other(self):
|
||||
super().test_hotswapping_compiled_model_both_linear_and_other()
|
||||
|
||||
@property
|
||||
def different_shapes_for_compilation(self):
|
||||
return [(4, 4), (4, 8), (8, 8)]
|
||||
|
||||
@@ -13,8 +13,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import importlib
|
||||
import os
|
||||
import unittest
|
||||
import warnings
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -182,6 +184,25 @@ class DeprecateTester(unittest.TestCase):
|
||||
assert str(warning.warning) == "This message is better!!!"
|
||||
assert "diffusers/tests/others/test_utils.py" in warning.filename
|
||||
|
||||
def test_deprecate_testing_utils_module(self):
|
||||
import diffusers.utils.testing_utils
|
||||
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
warnings.simplefilter("always")
|
||||
importlib.reload(diffusers.utils.testing_utils)
|
||||
|
||||
deprecation_warnings = [w for w in caught_warnings if issubclass(w.category, FutureWarning)]
|
||||
assert len(deprecation_warnings) >= 1, "Expected at least one FutureWarning from diffusers.utils.testing_utils"
|
||||
|
||||
messages = [str(w.message) for w in deprecation_warnings]
|
||||
assert any("diffusers.utils.testing_utils" in msg for msg in messages), (
|
||||
f"Expected a deprecation warning mentioning 'diffusers.utils.testing_utils', got: {messages}"
|
||||
)
|
||||
assert any(
|
||||
"diffusers.utils.testing_utils is deprecated and will be removed in a future version." in msg
|
||||
for msg in messages
|
||||
), f"Expected deprecation message substring not found, got: {messages}"
|
||||
|
||||
|
||||
# Copied from https://github.com/huggingface/transformers/blob/main/tests/utils/test_expectations.py
|
||||
class ExpectationsTester(unittest.TestCase):
|
||||
|
||||
@@ -43,7 +43,7 @@ def filter_pipelines(usage_dict, usage_cutoff=10000):
|
||||
|
||||
|
||||
def fetch_pipeline_objects():
|
||||
models = api.list_models(library="diffusers")
|
||||
models = api.list_models(filter="diffusers")
|
||||
downloads = defaultdict(int)
|
||||
|
||||
for model in models:
|
||||
|
||||
Reference in New Issue
Block a user