Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-21 03:44:49 +08:00)
Compare commits: v0.35.0-re...v0.35.2-pa (8 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | b71269675e |  |
|  | 36059182f1 |  |
|  | 9169e81609 |  |
|  | 8160289373 |  |
|  | 08782bf3bf |  |
|  | 0f252be0ed |  |
|  | e3d4a6b070 |  |
|  | ad00c565b7 |  |
setup.py
```diff
@@ -269,7 +269,7 @@ version_range_max = max(sys.version_info[1], 10) + 1

 setup(
     name="diffusers",
-    version="0.35.0",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
+    version="0.35.2",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
     description="State-of-the-art diffusion in PyTorch and JAX.",
     long_description=open("README.md", "r", encoding="utf-8").read(),
     long_description_content_type="text/markdown",
```
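The inline comment pins down the accepted version formats. As a hedged illustration (this checker is not part of the repository), a regex that enforces the same rule:

```python
import re

# Hypothetical check mirroring the comment's rule: x.y.z, x.y.z.dev0,
# or x.y.z.rc1 -- dots only, no dashes.
VERSION_PATTERN = re.compile(r"^\d+\.\d+\.\d+(\.(dev|rc)\d+)?$")

assert VERSION_PATTERN.match("0.35.2")            # this release
assert VERSION_PATTERN.match("0.36.0.dev0")       # dev builds
assert not VERSION_PATTERN.match("0.35.2-patch")  # dashes rejected
```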
```diff
@@ -1,4 +1,4 @@
-__version__ = "0.35.0"
+__version__ = "0.35.2"

 from typing import TYPE_CHECKING

```
```diff
@@ -110,6 +110,27 @@ if _CAN_USE_XFORMERS_ATTN:
 else:
     xops = None

+
+# Version guard for PyTorch compatibility - custom_op was added in PyTorch 2.4
+if torch.__version__ >= "2.4.0":
+    _custom_op = torch.library.custom_op
+    _register_fake = torch.library.register_fake
+else:
+
+    def custom_op_no_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
+        def wrap(func):
+            return func
+
+        return wrap if fn is None else fn
+
+    def register_fake_no_op(op, fn=None, /, *, lib=None, _stacklevel=1):
+        def wrap(func):
+            return func
+
+        return wrap if fn is None else fn
+
+    _custom_op = custom_op_no_op
+    _register_fake = register_fake_no_op
+

 logger = get_logger(__name__)  # pylint: disable=invalid-name
```
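The fallback branch makes `_custom_op` and `_register_fake` behave as transparent no-ops on PyTorch builds older than 2.4, so the decorators further down can be applied unconditionally. A minimal sketch of why that works (plain Python, no torch required):

```python
# Sketch: the no-op decorator returns the function unchanged whether it is
# used as @decorator(...) (fn is None) or called directly on a function.
def custom_op_no_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
    def wrap(func):
        return func

    return wrap if fn is None else fn


@custom_op_no_op("demo::add", mutates_args=())
def add(a, b):
    return a + b


assert add(1, 2) == 3  # still an ordinary function; only tracing support is lost
```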
```diff
@@ -473,12 +494,11 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):

 # ===== torch op registrations =====
 # Registrations are required for fullgraph tracing compatibility
-# TODO: library.custom_op and register_fake probably need version guards?
 # TODO: this is only required because the beta release FA3 does not have it. There is a PR adding
 # this but it was never merged: https://github.com/Dao-AILab/flash-attention/pull/1590
-@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
+@_custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
 def _wrapped_flash_attn_3_original(
     query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
 ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -487,7 +507,7 @@ def _wrapped_flash_attn_3_original(
     return out, lse


-@torch.library.register_fake("flash_attn_3::_flash_attn_forward")
+@_register_fake("flash_attn_3::_flash_attn_forward")
 def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     batch_size, seq_len, num_heads, head_dim = query.shape
     lse_shape = (batch_size, seq_len, num_heads)
```
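The fake registration exists so that tracing (e.g. `torch.compile` with fullgraph) can infer output metadata without launching the CUDA kernel. A hedged sketch of the shape contract the fake function declares, using meta tensors:

```python
import torch

# Sketch of the metadata the fake kernel promises: an output shaped like
# the query plus a log-sum-exp tensor of (batch, seq_len, num_heads).
# Meta tensors carry shapes/dtypes only, so nothing runs on a GPU here.
query = torch.empty(2, 128, 8, 64, device="meta", dtype=torch.bfloat16)
out = torch.empty_like(query)
lse = torch.empty(query.shape[0], query.shape[1], query.shape[2], device="meta", dtype=torch.float32)
print(out.shape, lse.shape)  # torch.Size([2, 128, 8, 64]) torch.Size([2, 128, 8])
```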
```diff
@@ -350,7 +350,9 @@ class LTXVideoTransformerBlock(nn.Module):
         norm_hidden_states = self.norm1(hidden_states)

         num_ada_params = self.scale_shift_table.shape[0]
-        ada_values = self.scale_shift_table[None, None] + temb.reshape(batch_size, temb.size(1), num_ada_params, -1)
+        ada_values = self.scale_shift_table[None, None].to(temb.device) + temb.reshape(
+            batch_size, temb.size(1), num_ada_params, -1
+        )
         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
         norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
```
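The `.to(temb.device)` cast guards against `scale_shift_table` living on a different device than the activations, as can happen with offloading or multi-GPU dispatch. A small repro-style sketch of the arithmetic (shapes are illustrative, not taken from the model config):

```python
import torch

# Illustrative shapes: a (num_ada_params, dim) table broadcast against a
# reshaped conditioning tensor. Without .to(temb.device), a table left on
# CPU while temb sits on an accelerator raises a device-mismatch error.
scale_shift_table = torch.zeros(6, 16)  # may still be on CPU
temb = torch.randn(1, 4, 6 * 16)        # may live on "cuda"
ada_values = scale_shift_table[None, None].to(temb.device) + temb.reshape(1, 4, 6, -1)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
print(shift_msa.shape)  # torch.Size([1, 4, 16])
```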
```diff
@@ -665,12 +665,12 @@ class WanTransformer3DModel(
         # 5. Output norm, projection & unpatchify
         if temb.ndim == 3:
             # batch_size, seq_len, inner_dim (wan 2.2 ti2v)
-            shift, scale = (self.scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)).chunk(2, dim=2)
+            shift, scale = (self.scale_shift_table.unsqueeze(0).to(temb.device) + temb.unsqueeze(2)).chunk(2, dim=2)
             shift = shift.squeeze(2)
             scale = scale.squeeze(2)
         else:
             # batch_size, inner_dim
-            shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
+            shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)

         # Move the shift and scale tensors to the same device as hidden_states.
         # When using multi-GPU inference via accelerate these will be on the
```
```diff
@@ -103,7 +103,7 @@ class WanVACETransformerBlock(nn.Module):
         control_hidden_states = control_hidden_states + hidden_states

         shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
-            self.scale_shift_table + temb.float()
+            self.scale_shift_table.to(temb.device) + temb.float()
         ).chunk(6, dim=1)

         # 1. Self-attention
@@ -359,7 +359,7 @@ class WanVACETransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO
         hidden_states = hidden_states + control_hint * scale

         # 6. Output norm, projection & unpatchify
-        shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
+        shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)

         # Move the shift and scale tensors to the same device as hidden_states.
         # When using multi-GPU inference via accelerate these will be on the
```
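The surrounding comments point at multi-GPU inference via accelerate as the scenario these casts fix: with a sharded placement, a module's buffers and its incoming activations can end up on different devices. A hedged usage sketch (the checkpoint id and its availability are assumptions):

```python
import torch
from diffusers import WanPipeline

# Hypothetical multi-GPU load: with device_map, accelerate may place
# submodules (and their buffers, like scale_shift_table) on different
# GPUs than the tensors flowing into them -- hence the .to(temb.device).
pipe = WanPipeline.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",  # assumed checkpoint id
    torch_dtype=torch.bfloat16,
    device_map="balanced",
)
```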
```diff
@@ -48,10 +48,12 @@ from .transformers_loading_utils import _load_tokenizer_from_dduf, _load_transfo
 if is_transformers_available():
     import transformers
     from transformers import PreTrainedModel, PreTrainedTokenizerBase
-    from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
     from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
     from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME

+    if is_transformers_version("<=", "4.56.2"):
+        from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
+
 if is_accelerate_available():
     import accelerate
     from accelerate import dispatch_model
```
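Recent transformers releases dropped the Flax weight constants, so the import is now version-guarded. A minimal sketch of the same pattern, assuming `is_transformers_version` is importable from `diffusers.utils`:

```python
from diffusers.utils import is_transformers_version

# Guarded import: only reference symbols that still exist in the
# installed transformers release; fall back to a sentinel otherwise.
if is_transformers_version("<=", "4.56.2"):
    from transformers.utils import FLAX_WEIGHTS_NAME
else:
    FLAX_WEIGHTS_NAME = None  # removed in newer transformers
```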
```diff
@@ -112,7 +114,9 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
     ]

     if is_transformers_available():
-        weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
+        weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME]
+        if is_transformers_version("<=", "4.56.2"):
+            weight_names += [TRANSFORMERS_FLAX_WEIGHTS_NAME]

     # model_pytorch, diffusion_model_pytorch, ...
     weight_prefixes = [w.split(".")[0] for w in weight_names]
```
```diff
@@ -191,7 +195,9 @@ def filter_model_files(filenames):
     ]

     if is_transformers_available():
-        weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
+        weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME]
+        if is_transformers_version("<=", "4.56.2"):
+            weight_names += [TRANSFORMERS_FLAX_WEIGHTS_NAME]

     allowed_extensions = [wn.split(".")[-1] for wn in weight_names]
```
```diff
@@ -212,7 +218,9 @@ def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -
     ]

     if is_transformers_available():
-        weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
+        weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME]
+        if is_transformers_version("<=", "4.56.2"):
+            weight_names += [TRANSFORMERS_FLAX_WEIGHTS_NAME]

     # model_pytorch, diffusion_model_pytorch, ...
     weight_prefixes = [w.split(".")[0] for w in weight_names]
```
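All three call sites build the same `weight_names` list and then derive filename prefixes (and, in `filter_model_files`, extensions) from it. A quick sketch of that derivation with representative names (the exact constants come from diffusers/transformers):

```python
# Representative weight-file names; the real lists come from library constants.
weight_names = [
    "diffusion_pytorch_model.bin",
    "diffusion_pytorch_model.safetensors",
    "model.safetensors",
]

# Everything before the first dot becomes a prefix, the last segment an extension.
weight_prefixes = [w.split(".")[0] for w in weight_names]
allowed_extensions = [w.split(".")[-1] for w in weight_names]
print(weight_prefixes)     # ['diffusion_pytorch_model', 'diffusion_pytorch_model', 'model']
print(allowed_extensions)  # ['bin', 'safetensors', 'safetensors']
```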
```diff
@@ -830,6 +838,9 @@ def load_sub_model(
     else:
         loading_kwargs["low_cpu_mem_usage"] = False

+    if is_transformers_model and is_transformers_version(">=", "4.57.0"):
+        loading_kwargs.pop("offload_state_dict")
+
     if (
         quantization_config is not None
         and isinstance(quantization_config, PipelineQuantizationConfig)
```
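The guard implies that transformers >= 4.57.0 no longer accepts `offload_state_dict`, so the loader strips it before forwarding kwargs. The defensive pattern in isolation (version check stubbed out):

```python
# Sketch of the pattern above: drop kwargs a newer dependency no longer
# accepts before forwarding them. The boolean stands in for
# is_transformers_version(">=", "4.57.0").
loading_kwargs = {"low_cpu_mem_usage": True, "offload_state_dict": False}
transformers_requires_drop = True

if transformers_requires_drop:
    loading_kwargs.pop("offload_state_dict", None)  # the diff uses pop without a default

print(loading_kwargs)  # {'low_cpu_mem_usage': True}
```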
````diff
@@ -62,25 +62,6 @@ EXAMPLE_DOC_STRING = """
         >>> image.save("qwenimage_edit.png")
 ```
 """
-PREFERRED_QWENIMAGE_RESOLUTIONS = [
-    (672, 1568),
-    (688, 1504),
-    (720, 1456),
-    (752, 1392),
-    (800, 1328),
-    (832, 1248),
-    (880, 1184),
-    (944, 1104),
-    (1024, 1024),
-    (1104, 944),
-    (1184, 880),
-    (1248, 832),
-    (1328, 800),
-    (1392, 752),
-    (1456, 720),
-    (1504, 688),
-    (1568, 672),
-]


 # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
````
```diff
@@ -565,7 +546,6 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 512,
-        _auto_resize: bool = True,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
```
```diff
@@ -646,8 +626,7 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
             returning a tuple, the first element is a list with the generated images.
         """
         image_size = image[0].size if isinstance(image, list) else image.size
-        width, height = image_size
-        calculated_width, calculated_height, _ = calculate_dimensions(1024 * 1024, width / height)
+        calculated_width, calculated_height, _ = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1])
         height = height or calculated_height
         width = width or calculated_width

```
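`calculate_dimensions` itself is not shown in this diff; what follows is a hedged reimplementation of what it plausibly computes, a width/height near a fixed ~1 MP target area at the input aspect ratio (the rounding behavior and the `multiple_of` default are assumptions):

```python
import math

# Assumed behavior of calculate_dimensions: hit ~target_area pixels at
# the given aspect ratio, snapped to a multiple. The real helper may
# round differently; it returns a third value the caller ignores.
def calculate_dimensions(target_area, aspect_ratio, multiple_of=32):
    width = math.sqrt(target_area * aspect_ratio)
    height = width / aspect_ratio
    width = round(width / multiple_of) * multiple_of
    height = round(height / multiple_of) * multiple_of
    return width, height, None

print(calculate_dimensions(1024 * 1024, 16 / 9))  # (1376, 768, None)
```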
```diff
@@ -685,18 +664,9 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
         device = self._execution_device
         # 3. Preprocess image
         if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
-            img = image[0] if isinstance(image, list) else image
-            image_height, image_width = self.image_processor.get_default_height_width(img)
-            aspect_ratio = image_width / image_height
-            if _auto_resize:
-                _, image_width, image_height = min(
-                    (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_QWENIMAGE_RESOLUTIONS
-                )
-            image_width = image_width // multiple_of * multiple_of
-            image_height = image_height // multiple_of * multiple_of
-            image = self.image_processor.resize(image, image_height, image_width)
+            image = self.image_processor.resize(image, calculated_height, calculated_width)
             prompt_image = image
-            image = self.image_processor.preprocess(image, image_height, image_width)
+            image = self.image_processor.preprocess(image, calculated_height, calculated_width)
             image = image.unsqueeze(2)

         has_neg_prompt = negative_prompt is not None or (
@@ -713,9 +683,6 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
             max_sequence_length=max_sequence_length,
         )
         if do_true_cfg:
-            # negative image is the same size as the original image, but all pixels are white
-            # negative_image = Image.new("RGB", (image.width, image.height), (255, 255, 255))
-
             negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
                 image=prompt_image,
                 prompt=negative_prompt,
```
```diff
@@ -742,7 +709,7 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
         img_shapes = [
             [
                 (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2),
-                (1, image_height // self.vae_scale_factor // 2, image_width // self.vae_scale_factor // 2),
+                (1, calculated_height // self.vae_scale_factor // 2, calculated_width // self.vae_scale_factor // 2),
             ]
         ] * batch_size
```
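Each `img_shapes` entry describes a packed latent grid: the VAE downsamples by `vae_scale_factor` and the transformer packs 2x2 latent patches. Quick arithmetic with assumed values:

```python
# Assumed values: vae_scale_factor=8 is typical for this family of VAEs;
# the actual value comes from the loaded pipeline.
vae_scale_factor = 8
height, width = 1024, 1024
calculated_height, calculated_width = 1024, 1024

img_shape = [
    (1, height // vae_scale_factor // 2, width // vae_scale_factor // 2),
    (1, calculated_height // vae_scale_factor // 2, calculated_width // vae_scale_factor // 2),
]
print(img_shape)  # [(1, 64, 64), (1, 64, 64)]
```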