Compare commits

..

4 Commits

Author SHA1 Message Date
yiyixuxu
e77f1c56dc by default, skip loading the components that do not have a repo id 2026-01-24 05:56:09 +01:00
yiyixuxu
372222a4b6 load_components by default only loads components that are not already loaded 2026-01-24 05:42:51 +01:00
yiyixuxu
1f57b175ae style 2026-01-24 03:50:01 +01:00
yiyixuxu
581a425130 tag load_id from AutoModel 2026-01-24 03:49:29 +01:00
11 changed files with 111 additions and 453 deletions

View File

@@ -260,10 +260,6 @@ class _HubKernelConfig:
function_attr: str function_attr: str
revision: Optional[str] = None revision: Optional[str] = None
kernel_fn: Optional[Callable] = None kernel_fn: Optional[Callable] = None
wrapped_forward_attr: Optional[str] = None
wrapped_backward_attr: Optional[str] = None
wrapped_forward_fn: Optional[Callable] = None
wrapped_backward_fn: Optional[Callable] = None
# Registry for hub-based attention kernels # Registry for hub-based attention kernels
@@ -278,11 +274,7 @@ _HUB_KERNELS_REGISTRY: Dict["AttentionBackendName", _HubKernelConfig] = {
# revision="fake-ops-return-probs", # revision="fake-ops-return-probs",
), ),
AttentionBackendName.FLASH_HUB: _HubKernelConfig( AttentionBackendName.FLASH_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn2", repo_id="kernels-community/flash-attn2", function_attr="flash_attn_func", revision=None
function_attr="flash_attn_func",
revision=None,
wrapped_forward_attr="flash_attn_interface._wrapped_flash_attn_forward",
wrapped_backward_attr="flash_attn_interface._wrapped_flash_attn_backward",
), ),
AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig( AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn2", function_attr="flash_attn_varlen_func", revision=None repo_id="kernels-community/flash-attn2", function_attr="flash_attn_varlen_func", revision=None
@@ -607,39 +599,22 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
# ===== Helpers for downloading kernels ===== # ===== Helpers for downloading kernels =====
def _resolve_kernel_attr(module, attr_path: str):
target = module
for attr in attr_path.split("."):
if not hasattr(target, attr):
raise AttributeError(f"Kernel module '{module.__name__}' does not define attribute path '{attr_path}'.")
target = getattr(target, attr)
return target
def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None: def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None:
if backend not in _HUB_KERNELS_REGISTRY: if backend not in _HUB_KERNELS_REGISTRY:
return return
config = _HUB_KERNELS_REGISTRY[backend] config = _HUB_KERNELS_REGISTRY[backend]
needs_kernel = config.kernel_fn is None if config.kernel_fn is not None:
needs_wrapped_forward = config.wrapped_forward_attr is not None and config.wrapped_forward_fn is None
needs_wrapped_backward = config.wrapped_backward_attr is not None and config.wrapped_backward_fn is None
if not (needs_kernel or needs_wrapped_forward or needs_wrapped_backward):
return return
try: try:
from kernels import get_kernel from kernels import get_kernel
kernel_module = get_kernel(config.repo_id, revision=config.revision) kernel_module = get_kernel(config.repo_id, revision=config.revision)
if needs_kernel: kernel_func = getattr(kernel_module, config.function_attr)
config.kernel_fn = _resolve_kernel_attr(kernel_module, config.function_attr)
if needs_wrapped_forward: # Cache the downloaded kernel function in the config object
config.wrapped_forward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_forward_attr) config.kernel_fn = kernel_func
if needs_wrapped_backward:
config.wrapped_backward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_backward_attr)
except Exception as e: except Exception as e:
logger.error(f"An error occurred while fetching kernel '{config.repo_id}' from the Hub: {e}") logger.error(f"An error occurred while fetching kernel '{config.repo_id}' from the Hub: {e}")
@@ -1090,231 +1065,6 @@ def _flash_attention_backward_op(
return grad_query, grad_key, grad_value return grad_query, grad_key, grad_value
def _flash_attention_hub_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
return_lse: bool = False,
_save_ctx: bool = True,
_parallel_config: Optional["ParallelConfig"] = None,
):
if attn_mask is not None:
raise ValueError("`attn_mask` is not yet supported for flash-attn hub kernels.")
if enable_gqa:
raise ValueError("`enable_gqa` is not yet supported for flash-attn hub kernels.")
config = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB]
wrapped_forward_fn = config.wrapped_forward_fn
wrapped_backward_fn = config.wrapped_backward_fn
if wrapped_forward_fn is None or wrapped_backward_fn is None:
raise RuntimeError(
"Flash attention hub kernels must expose `_wrapped_flash_attn_forward` and `_wrapped_flash_attn_backward` "
"for context parallel execution."
)
if scale is None:
scale = query.shape[-1] ** (-0.5)
window_size = (-1, -1)
softcap = 0.0
alibi_slopes = None
deterministic = False
grad_enabled = any(x.requires_grad for x in (query, key, value))
if grad_enabled or (_parallel_config is not None and _parallel_config.context_parallel_config._world_size > 1):
dropout_p = dropout_p if dropout_p > 0 else 1e-30
with torch.set_grad_enabled(grad_enabled):
out, lse, S_dmask, rng_state = wrapped_forward_fn(
query,
key,
value,
dropout_p,
scale,
is_causal,
window_size[0],
window_size[1],
softcap,
alibi_slopes,
return_lse,
)
lse = lse.permute(0, 2, 1).contiguous()
if _save_ctx:
ctx.save_for_backward(query, key, value, out, lse, rng_state)
ctx.dropout_p = dropout_p
ctx.scale = scale
ctx.is_causal = is_causal
ctx.window_size = window_size
ctx.softcap = softcap
ctx.alibi_slopes = alibi_slopes
ctx.deterministic = deterministic
return (out, lse) if return_lse else out
def _flash_attention_hub_backward_op(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
*args,
**kwargs,
):
config = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB]
wrapped_backward_fn = config.wrapped_backward_fn
if wrapped_backward_fn is None:
raise RuntimeError(
"Flash attention hub kernels must expose `_wrapped_flash_attn_backward` for context parallel execution."
)
query, key, value, out, lse, rng_state = ctx.saved_tensors
grad_query, grad_key, grad_value = torch.empty_like(query), torch.empty_like(key), torch.empty_like(value)
_ = wrapped_backward_fn(
grad_out,
query,
key,
value,
out,
lse,
grad_query,
grad_key,
grad_value,
ctx.dropout_p,
ctx.scale,
ctx.is_causal,
ctx.window_size[0],
ctx.window_size[1],
ctx.softcap,
ctx.alibi_slopes,
ctx.deterministic,
rng_state,
)
grad_query = grad_query[..., : grad_out.shape[-1]]
grad_key = grad_key[..., : grad_out.shape[-1]]
grad_value = grad_value[..., : grad_out.shape[-1]]
return grad_query, grad_key, grad_value
def _flash_attention_3_hub_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
return_lse: bool = False,
_save_ctx: bool = True,
_parallel_config: Optional["ParallelConfig"] = None,
*,
window_size: Tuple[int, int] = (-1, -1),
softcap: float = 0.0,
num_splits: int = 1,
pack_gqa: Optional[bool] = None,
deterministic: bool = False,
sm_margin: int = 0,
):
if attn_mask is not None:
raise ValueError("`attn_mask` is not yet supported for flash-attn 3 hub kernels.")
if dropout_p != 0.0:
raise ValueError("`dropout_p` is not yet supported for flash-attn 3 hub kernels.")
if enable_gqa:
raise ValueError("`enable_gqa` is not yet supported for flash-attn 3 hub kernels.")
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
out = func(
q=query,
k=key,
v=value,
softmax_scale=scale,
causal=is_causal,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
window_size=window_size,
softcap=softcap,
num_splits=num_splits,
pack_gqa=pack_gqa,
deterministic=deterministic,
sm_margin=sm_margin,
return_attn_probs=return_lse,
)
lse = None
if return_lse:
out, lse = out
lse = lse.permute(0, 2, 1).contiguous()
if _save_ctx:
ctx.save_for_backward(query, key, value)
ctx.scale = scale
ctx.is_causal = is_causal
ctx._hub_kernel = func
return (out, lse) if return_lse else out
def _flash_attention_3_hub_backward_op(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
*args,
window_size: Tuple[int, int] = (-1, -1),
softcap: float = 0.0,
num_splits: int = 1,
pack_gqa: Optional[bool] = None,
deterministic: bool = False,
sm_margin: int = 0,
):
query, key, value = ctx.saved_tensors
kernel_fn = ctx._hub_kernel
with torch.enable_grad():
query_r = query.detach().requires_grad_(True)
key_r = key.detach().requires_grad_(True)
value_r = value.detach().requires_grad_(True)
out = kernel_fn(
q=query_r,
k=key_r,
v=value_r,
softmax_scale=ctx.scale,
causal=ctx.is_causal,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
window_size=window_size,
softcap=softcap,
num_splits=num_splits,
pack_gqa=pack_gqa,
deterministic=deterministic,
sm_margin=sm_margin,
return_attn_probs=False,
)
if isinstance(out, tuple):
out = out[0]
grad_query, grad_key, grad_value = torch.autograd.grad(
out,
(query_r, key_r, value_r),
grad_out,
retain_graph=False,
allow_unused=False,
)
return grad_query, grad_key, grad_value
def _sage_attention_forward_op( def _sage_attention_forward_op(
ctx: torch.autograd.function.FunctionCtx, ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor, query: torch.Tensor,
@@ -1353,46 +1103,6 @@ def _sage_attention_forward_op(
return (out, lse) if return_lse else out return (out, lse) if return_lse else out
def _sage_attention_hub_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
return_lse: bool = False,
_save_ctx: bool = True,
_parallel_config: Optional["ParallelConfig"] = None,
):
if attn_mask is not None:
raise ValueError("`attn_mask` is not yet supported for Sage attention.")
if dropout_p > 0.0:
raise ValueError("`dropout_p` is not yet supported for Sage attention.")
if enable_gqa:
raise ValueError("`enable_gqa` is not yet supported for Sage attention.")
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn
out = func(
q=query,
k=key,
v=value,
tensor_layout="NHD",
is_causal=is_causal,
sm_scale=scale,
return_lse=return_lse,
)
lse = None
if return_lse:
out, lse, *_ = out
lse = lse.permute(0, 2, 1).contiguous()
return (out, lse) if return_lse else out
def _sage_attention_backward_op( def _sage_attention_backward_op(
ctx: torch.autograd.function.FunctionCtx, ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor, grad_out: torch.Tensor,
@@ -1985,7 +1695,7 @@ def _flash_attention(
@_AttentionBackendRegistry.register( @_AttentionBackendRegistry.register(
AttentionBackendName.FLASH_HUB, AttentionBackendName.FLASH_HUB,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_context_parallel=True, supports_context_parallel=False,
) )
def _flash_attention_hub( def _flash_attention_hub(
query: torch.Tensor, query: torch.Tensor,
@@ -2003,35 +1713,17 @@ def _flash_attention_hub(
raise ValueError("`attn_mask` is not supported for flash-attn 2.") raise ValueError("`attn_mask` is not supported for flash-attn 2.")
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].kernel_fn func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].kernel_fn
if _parallel_config is None: out = func(
out = func( q=query,
q=query, k=key,
k=key, v=value,
v=value, dropout_p=dropout_p,
dropout_p=dropout_p, softmax_scale=scale,
softmax_scale=scale, causal=is_causal,
causal=is_causal, return_attn_probs=return_lse,
return_attn_probs=return_lse, )
) if return_lse:
if return_lse: out, lse, *_ = out
out, lse, *_ = out
else:
out = _templated_context_parallel_attention(
query,
key,
value,
None,
dropout_p,
is_causal,
scale,
False,
return_lse,
forward_op=_flash_attention_hub_forward_op,
backward_op=_flash_attention_hub_backward_op,
_parallel_config=_parallel_config,
)
if return_lse:
out, lse = out
return (out, lse) if return_lse else out return (out, lse) if return_lse else out
@@ -2178,7 +1870,7 @@ def _flash_attention_3(
@_AttentionBackendRegistry.register( @_AttentionBackendRegistry.register(
AttentionBackendName._FLASH_3_HUB, AttentionBackendName._FLASH_3_HUB,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_context_parallel=True, supports_context_parallel=False,
) )
def _flash_attention_3_hub( def _flash_attention_3_hub(
query: torch.Tensor, query: torch.Tensor,
@@ -2193,68 +1885,33 @@ def _flash_attention_3_hub(
return_attn_probs: bool = False, return_attn_probs: bool = False,
_parallel_config: Optional["ParallelConfig"] = None, _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if _parallel_config:
raise NotImplementedError(f"{AttentionBackendName._FLASH_3_HUB.value} is not implemented for parallelism yet.")
if attn_mask is not None: if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for flash-attn 3.") raise ValueError("`attn_mask` is not supported for flash-attn 3.")
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
if _parallel_config is None: out = func(
out = func( q=query,
q=query, k=key,
k=key, v=value,
v=value, softmax_scale=scale,
softmax_scale=scale, causal=is_causal,
causal=is_causal, qv=None,
qv=None, q_descale=None,
q_descale=None, k_descale=None,
k_descale=None, v_descale=None,
v_descale=None,
window_size=window_size,
softcap=softcap,
num_splits=1,
pack_gqa=None,
deterministic=deterministic,
sm_margin=0,
return_attn_probs=return_attn_probs,
)
return (out[0], out[1]) if return_attn_probs else out
forward_op = functools.partial(
_flash_attention_3_hub_forward_op,
window_size=window_size, window_size=window_size,
softcap=softcap, softcap=softcap,
num_splits=1, num_splits=1,
pack_gqa=None, pack_gqa=None,
deterministic=deterministic, deterministic=deterministic,
sm_margin=0, sm_margin=0,
return_attn_probs=return_attn_probs,
) )
backward_op = functools.partial( # When `return_attn_probs` is True, the above returns a tuple of
_flash_attention_3_hub_backward_op, # actual outputs and lse.
window_size=window_size, return (out[0], out[1]) if return_attn_probs else out
softcap=softcap,
num_splits=1,
pack_gqa=None,
deterministic=deterministic,
sm_margin=0,
)
out = _templated_context_parallel_attention(
query,
key,
value,
None,
0.0,
is_causal,
scale,
False,
return_attn_probs,
forward_op=forward_op,
backward_op=backward_op,
_parallel_config=_parallel_config,
)
if return_attn_probs:
out, lse = out
return out, lse
return out
@_AttentionBackendRegistry.register( @_AttentionBackendRegistry.register(
@@ -2885,7 +2542,7 @@ def _sage_attention(
@_AttentionBackendRegistry.register( @_AttentionBackendRegistry.register(
AttentionBackendName.SAGE_HUB, AttentionBackendName.SAGE_HUB,
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape], constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_context_parallel=True, supports_context_parallel=False,
) )
def _sage_attention_hub( def _sage_attention_hub(
query: torch.Tensor, query: torch.Tensor,
@@ -2913,23 +2570,6 @@ def _sage_attention_hub(
) )
if return_lse: if return_lse:
out, lse, *_ = out out, lse, *_ = out
else:
out = _templated_context_parallel_attention(
query,
key,
value,
None,
0.0,
is_causal,
scale,
False,
return_lse,
forward_op=_sage_attention_hub_forward_op,
backward_op=_sage_attention_backward_op,
_parallel_config=_parallel_config,
)
if return_lse:
out, lse = out
return (out, lse) if return_lse else out return (out, lse) if return_lse else out

View File

@@ -18,7 +18,7 @@ from typing import Optional, Union
from huggingface_hub.utils import validate_hf_hub_args from huggingface_hub.utils import validate_hf_hub_args
from ..configuration_utils import ConfigMixin from ..configuration_utils import ConfigMixin
from ..utils import logging from ..utils import DIFFUSERS_LOAD_ID_FIELDS, logging
from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code
@@ -220,4 +220,11 @@ class AutoModel(ConfigMixin):
raise ValueError(f"AutoModel can't find a model linked to {orig_class_name}.") raise ValueError(f"AutoModel can't find a model linked to {orig_class_name}.")
kwargs = {**load_config_kwargs, **kwargs} kwargs = {**load_config_kwargs, **kwargs}
return model_cls.from_pretrained(pretrained_model_or_path, **kwargs) model = model_cls.from_pretrained(pretrained_model_or_path, **kwargs)
load_id_kwargs = {"pretrained_model_name_or_path": pretrained_model_or_path, **kwargs}
parts = [load_id_kwargs.get(field, "null") for field in DIFFUSERS_LOAD_ID_FIELDS]
load_id = "|".join("null" if p is None else p for p in parts)
model._diffusers_load_id = load_id
return model
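As a rough illustration of the tag built here (field order taken from DIFFUSERS_LOAD_ID_FIELDS; the repo id and subfolder below are just example values):

    DIFFUSERS_LOAD_ID_FIELDS = ["pretrained_model_name_or_path", "subfolder", "variant", "revision"]

    kwargs = {"pretrained_model_name_or_path": "some-org/some-model", "subfolder": "transformer"}
    parts = [kwargs.get(field, "null") for field in DIFFUSERS_LOAD_ID_FIELDS]
    load_id = "|".join("null" if p is None else p for p in parts)
    # -> "some-org/some-model|transformer|null|null"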

View File

@@ -2142,6 +2142,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
name name
for name in self._component_specs.keys() for name in self._component_specs.keys()
if self._component_specs[name].default_creation_method == "from_pretrained" if self._component_specs[name].default_creation_method == "from_pretrained"
and self._component_specs[name].pretrained_model_name_or_path is not None
and getattr(self, name, None) is None
] ]
elif isinstance(names, str): elif isinstance(names, str):
names = [names] names = [names]
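Put differently, a sketch of the filter the default branch now applies (same spec attributes as in the hunk above): with no names given, only components that still need loading are selected.

    default_names = [
        name
        for name, spec in self._component_specs.items()
        if spec.default_creation_method == "from_pretrained"    # only from_pretrained specs
        and spec.pretrained_model_name_or_path is not None      # skip specs without a repo id
        and getattr(self, name, None) is None                   # skip components already loaded
    ]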

View File

@@ -15,14 +15,14 @@
import inspect import inspect
import re import re
from collections import OrderedDict from collections import OrderedDict
from dataclasses import dataclass, field, fields from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal, Optional, Type, Union from typing import Any, Dict, List, Literal, Optional, Type, Union
import torch import torch
from ..configuration_utils import ConfigMixin, FrozenDict from ..configuration_utils import ConfigMixin, FrozenDict
from ..loaders.single_file_utils import _is_single_file_path_or_url from ..loaders.single_file_utils import _is_single_file_path_or_url
from ..utils import is_torch_available, logging from ..utils import DIFFUSERS_LOAD_ID_FIELDS, is_torch_available, logging
if is_torch_available(): if is_torch_available():
@@ -185,7 +185,7 @@ class ComponentSpec:
""" """
Return the names of all loading-related fields (i.e. those whose field.metadata["loading"] is True). Return the names of all loading-related fields (i.e. those whose field.metadata["loading"] is True).
""" """
return [f.name for f in fields(cls) if f.metadata.get("loading", False)] return DIFFUSERS_LOAD_ID_FIELDS.copy()
@property @property
def load_id(self) -> str: def load_id(self) -> str:
@@ -197,7 +197,7 @@ class ComponentSpec:
return "null" return "null"
parts = [getattr(self, k) for k in self.loading_fields()] parts = [getattr(self, k) for k in self.loading_fields()]
parts = ["null" if p is None else p for p in parts] parts = ["null" if p is None else p for p in parts]
return "|".join(p for p in parts if p) return "|".join(parts)
@classmethod @classmethod
def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]: def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]:
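A rough sketch of why the join no longer drops empty parts: keeping a "null" placeholder for each unset field keeps every position aligned with its loading field, which is what lets decode_load_id map the string back (assuming the four fields from DIFFUSERS_LOAD_ID_FIELDS; the values are hypothetical):

    fields = ["pretrained_model_name_or_path", "subfolder", "variant", "revision"]
    values = {"pretrained_model_name_or_path": "some-org/some-model", "subfolder": "unet",
              "variant": None, "revision": None}

    parts = ["null" if values[f] is None else values[f] for f in fields]
    load_id = "|".join(parts)  # "some-org/some-model|unet|null|null"

    # Decoding relies on the fixed positions, which dropping empty parts would break.
    decoded = dict(zip(fields, (None if p == "null" else p for p in load_id.split("|"))))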

View File

@@ -482,6 +482,8 @@ class ChromaInpaintPipeline(
negative_prompt=None, negative_prompt=None,
prompt_embeds=None, prompt_embeds=None,
negative_prompt_embeds=None, negative_prompt_embeds=None,
pooled_prompt_embeds=None,
negative_pooled_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None, callback_on_step_end_tensor_inputs=None,
padding_mask_crop=None, padding_mask_crop=None,
max_sequence_length=None, max_sequence_length=None,
@@ -529,6 +531,15 @@ class ChromaInpaintPipeline(
f" {negative_prompt_embeds.shape}." f" {negative_prompt_embeds.shape}."
) )
if prompt_embeds is not None and pooled_prompt_embeds is None:
raise ValueError(
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
)
if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
raise ValueError(
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
)
if prompt_embeds is not None and prompt_attention_mask is None: if prompt_embeds is not None and prompt_attention_mask is None:
raise ValueError("Cannot provide `prompt_embeds` without also providing `prompt_attention_mask") raise ValueError("Cannot provide `prompt_embeds` without also providing `prompt_attention_mask")
@@ -782,11 +793,13 @@ class ChromaInpaintPipeline(
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None, latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None, ip_adapter_image: Optional[PipelineImageInput] = None,
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
negative_ip_adapter_image: Optional[PipelineImageInput] = None, negative_ip_adapter_image: Optional[PipelineImageInput] = None,
negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
return_dict: bool = True, return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None, joint_attention_kwargs: Optional[Dict[str, Any]] = None,

View File

@@ -281,7 +281,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
Sets the discrete timesteps used for the diffusion chain (to be run before inference). Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args: Args:
num_inference_steps (`int`, *optional*): num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, The number of diffusion steps used when generating samples with a pre-trained model. If used,
`timesteps` must be `None`. `timesteps` must be `None`.
device (`str` or `torch.device`, *optional*): device (`str` or `torch.device`, *optional*):
@@ -646,7 +646,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
def __len__(self) -> int: def __len__(self) -> int:
return self.config.num_train_timesteps return self.config.num_train_timesteps
def previous_timestep(self, timestep: int) -> Union[int, torch.Tensor]: def previous_timestep(self, timestep: int) -> int:
""" """
Compute the previous timestep in the diffusion chain. Compute the previous timestep in the diffusion chain.
@@ -655,7 +655,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
The current timestep. The current timestep.
Returns: Returns:
`int` or `torch.Tensor`: `int`:
The previous timestep. The previous timestep.
""" """
if self.custom_timesteps or self.num_inference_steps: if self.custom_timesteps or self.num_inference_steps:

View File

@@ -149,41 +149,38 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
For more details, see the original paper: https://huggingface.co/papers/2006.11239 For more details, see the original paper: https://huggingface.co/papers/2006.11239
Args: Args:
num_train_timesteps (`int`, defaults to 1000): num_train_timesteps (`int`): number of diffusion steps used to train the model.
The number of diffusion steps to train the model. beta_start (`float`): the starting `beta` value of inference.
beta_start (`float`, defaults to 0.0001): beta_end (`float`): the final `beta` value.
The starting `beta` value of inference. beta_schedule (`str`):
beta_end (`float`, defaults to 0.02): the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
The final `beta` value.
beta_schedule (`str`, defaults to `"linear"`):
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear`, `scaled_linear`, `squaredcos_cap_v2` or `sigmoid`. `linear`, `scaled_linear`, `squaredcos_cap_v2` or `sigmoid`.
trained_betas (`np.ndarray`, *optional*): trained_betas (`np.ndarray`, optional):
Option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
variance_type (`str`, defaults to `"fixed_small"`): variance_type (`str`):
Options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
clip_sample (`bool`, defaults to `True`): clip_sample (`bool`, default `True`):
Option to clip predicted sample for numerical stability. option to clip predicted sample for numerical stability.
prediction_type (`str`, defaults to `"epsilon"`): clip_sample_range (`float`, default `1.0`):
Prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
prediction_type (`str`, default `epsilon`, optional):
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
https://huggingface.co/papers/2210.02303) https://huggingface.co/papers/2210.02303)
thresholding (`bool`, defaults to `False`): thresholding (`bool`, default `False`):
Whether to use the "dynamic thresholding" method (introduced by Imagen, whether to use the "dynamic thresholding" method (introduced by Imagen,
https://huggingface.co/papers/2205.11487). Note that the thresholding method is unsuitable for latent-space https://huggingface.co/papers/2205.11487). Note that the thresholding method is unsuitable for latent-space
diffusion models (such as stable-diffusion). diffusion models (such as stable-diffusion).
dynamic_thresholding_ratio (`float`, defaults to 0.995): dynamic_thresholding_ratio (`float`, default `0.995`):
The ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
(https://huggingface.co/papers/2205.11487). Valid only when `thresholding=True`. (https://huggingface.co/papers/2205.11487). Valid only when `thresholding=True`.
clip_sample_range (`float`, defaults to 1.0): sample_max_value (`float`, default `1.0`):
The maximum magnitude for sample clipping. Valid only when `clip_sample=True`. the threshold value for dynamic thresholding. Valid only when `thresholding=True`.
sample_max_value (`float`, defaults to 1.0): timestep_spacing (`str`, default `"leading"`):
The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
timestep_spacing (`str`, defaults to `"leading"`):
The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
steps_offset (`int`, defaults to 0): steps_offset (`int`, default `0`):
An offset added to the inference steps, as required by some model families. An offset added to the inference steps, as required by some model families.
rescale_betas_zero_snr (`bool`, defaults to `False`): rescale_betas_zero_snr (`bool`, defaults to `False`):
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
@@ -296,7 +293,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
Sets the discrete timesteps used for the diffusion chain (to be run before inference). Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args: Args:
num_inference_steps (`int`, *optional*): num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, The number of diffusion steps used when generating samples with a pre-trained model. If used,
`timesteps` must be `None`. `timesteps` must be `None`.
device (`str` or `torch.device`, *optional*): device (`str` or `torch.device`, *optional*):
@@ -481,7 +478,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
model_output: torch.Tensor, model_output: torch.Tensor,
timestep: int, timestep: int,
sample: torch.Tensor, sample: torch.Tensor,
generator: Optional[torch.Generator] = None, generator=None,
return_dict: bool = True, return_dict: bool = True,
) -> Union[DDPMParallelSchedulerOutput, Tuple]: ) -> Union[DDPMParallelSchedulerOutput, Tuple]:
""" """
@@ -493,8 +490,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
timestep (`int`): current discrete timestep in the diffusion chain. timestep (`int`): current discrete timestep in the diffusion chain.
sample (`torch.Tensor`): sample (`torch.Tensor`):
current instance of sample being created by diffusion process. current instance of sample being created by diffusion process.
generator (`torch.Generator`, *optional*): generator: random number generator.
Random number generator.
return_dict (`bool`): option for returning tuple rather than DDPMParallelSchedulerOutput class return_dict (`bool`): option for returning tuple rather than DDPMParallelSchedulerOutput class
Returns: Returns:
@@ -507,10 +503,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
prev_t = self.previous_timestep(t) prev_t = self.previous_timestep(t)
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in [ if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
"learned",
"learned_range",
]:
model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
else: else:
predicted_variance = None predicted_variance = None
@@ -559,10 +552,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
if t > 0: if t > 0:
device = model_output.device device = model_output.device
variance_noise = randn_tensor( variance_noise = randn_tensor(
model_output.shape, model_output.shape, generator=generator, device=device, dtype=model_output.dtype
generator=generator,
device=device,
dtype=model_output.dtype,
) )
if self.variance_type == "fixed_small_log": if self.variance_type == "fixed_small_log":
variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise
@@ -585,7 +575,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
def batch_step_no_noise( def batch_step_no_noise(
self, self,
model_output: torch.Tensor, model_output: torch.Tensor,
timesteps: torch.Tensor, timesteps: List[int],
sample: torch.Tensor, sample: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
@@ -598,8 +588,8 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
Args: Args:
model_output (`torch.Tensor`): direct output from learned diffusion model. model_output (`torch.Tensor`): direct output from learned diffusion model.
timesteps (`torch.Tensor`): timesteps (`List[int]`):
Current discrete timesteps in the diffusion chain. This is a tensor of integers. current discrete timesteps in the diffusion chain. This is now a list of integers.
sample (`torch.Tensor`): sample (`torch.Tensor`):
current instance of sample being created by diffusion process. current instance of sample being created by diffusion process.
@@ -613,10 +603,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
t = t.view(-1, *([1] * (model_output.ndim - 1))) t = t.view(-1, *([1] * (model_output.ndim - 1)))
prev_t = prev_t.view(-1, *([1] * (model_output.ndim - 1))) prev_t = prev_t.view(-1, *([1] * (model_output.ndim - 1)))
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in [ if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
"learned",
"learned_range",
]:
model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
else: else:
pass pass
@@ -747,7 +734,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
return self.config.num_train_timesteps return self.config.num_train_timesteps
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
def previous_timestep(self, timestep: int) -> Union[int, torch.Tensor]: def previous_timestep(self, timestep):
""" """
Compute the previous timestep in the diffusion chain. Compute the previous timestep in the diffusion chain.
@@ -756,7 +743,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
The current timestep. The current timestep.
Returns: Returns:
`int` or `torch.Tensor`: `int`:
The previous timestep. The previous timestep.
""" """
if self.custom_timesteps or self.num_inference_steps: if self.custom_timesteps or self.num_inference_steps:

View File

@@ -722,7 +722,7 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
The current timestep. The current timestep.
Returns: Returns:
`int` or `torch.Tensor`: `int`:
The previous timestep. The previous timestep.
""" """
if self.custom_timesteps or self.num_inference_steps: if self.custom_timesteps or self.num_inference_steps:

View File

@@ -777,7 +777,7 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
The current timestep. The current timestep.
Returns: Returns:
`int` or `torch.Tensor`: `int`:
The previous timestep. The previous timestep.
""" """
if self.custom_timesteps or self.num_inference_steps: if self.custom_timesteps or self.num_inference_steps:

View File

@@ -23,6 +23,7 @@ from .constants import (
DEFAULT_HF_PARALLEL_LOADING_WORKERS, DEFAULT_HF_PARALLEL_LOADING_WORKERS,
DEPRECATED_REVISION_ARGS, DEPRECATED_REVISION_ARGS,
DIFFUSERS_DYNAMIC_MODULE_NAME, DIFFUSERS_DYNAMIC_MODULE_NAME,
DIFFUSERS_LOAD_ID_FIELDS,
FLAX_WEIGHTS_NAME, FLAX_WEIGHTS_NAME,
GGUF_FILE_EXTENSION, GGUF_FILE_EXTENSION,
HF_ENABLE_PARALLEL_LOADING, HF_ENABLE_PARALLEL_LOADING,

View File

@@ -73,3 +73,11 @@ DECODE_ENDPOINT_HUNYUAN_VIDEO = "https://o7ywnmrahorts457.us-east-1.aws.endpoint
ENCODE_ENDPOINT_SD_V1 = "https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/" ENCODE_ENDPOINT_SD_V1 = "https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/"
ENCODE_ENDPOINT_SD_XL = "https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud/" ENCODE_ENDPOINT_SD_XL = "https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud/"
ENCODE_ENDPOINT_FLUX = "https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud/" ENCODE_ENDPOINT_FLUX = "https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud/"
DIFFUSERS_LOAD_ID_FIELDS = [
"pretrained_model_name_or_path",
"subfolder",
"variant",
"revision",
]