Mirror of https://github.com/huggingface/diffusers.git
Compare commits
4 Commits
enable-cp- ... modular-lo
| Author | SHA1 | Date |
|---|---|---|
|  | e77f1c56dc |  |
|  | 372222a4b6 |  |
|  | 1f57b175ae |  |
|  | 581a425130 |  |
@@ -260,10 +260,6 @@ class _HubKernelConfig:
     function_attr: str
     revision: Optional[str] = None
     kernel_fn: Optional[Callable] = None
-    wrapped_forward_attr: Optional[str] = None
-    wrapped_backward_attr: Optional[str] = None
-    wrapped_forward_fn: Optional[Callable] = None
-    wrapped_backward_fn: Optional[Callable] = None


 # Registry for hub-based attention kernels
@@ -278,11 +274,7 @@ _HUB_KERNELS_REGISTRY: Dict["AttentionBackendName", _HubKernelConfig] = {
         # revision="fake-ops-return-probs",
     ),
     AttentionBackendName.FLASH_HUB: _HubKernelConfig(
-        repo_id="kernels-community/flash-attn2",
-        function_attr="flash_attn_func",
-        revision=None,
-        wrapped_forward_attr="flash_attn_interface._wrapped_flash_attn_forward",
-        wrapped_backward_attr="flash_attn_interface._wrapped_flash_attn_backward",
+        repo_id="kernels-community/flash-attn2", function_attr="flash_attn_func", revision=None
     ),
     AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig(
         repo_id="kernels-community/flash-attn2", function_attr="flash_attn_varlen_func", revision=None
@@ -607,39 +599,22 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):


 # ===== Helpers for downloading kernels =====
-def _resolve_kernel_attr(module, attr_path: str):
-    target = module
-    for attr in attr_path.split("."):
-        if not hasattr(target, attr):
-            raise AttributeError(f"Kernel module '{module.__name__}' does not define attribute path '{attr_path}'.")
-        target = getattr(target, attr)
-    return target
-
-
 def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None:
     if backend not in _HUB_KERNELS_REGISTRY:
         return
     config = _HUB_KERNELS_REGISTRY[backend]

-    needs_kernel = config.kernel_fn is None
-    needs_wrapped_forward = config.wrapped_forward_attr is not None and config.wrapped_forward_fn is None
-    needs_wrapped_backward = config.wrapped_backward_attr is not None and config.wrapped_backward_fn is None
-
-    if not (needs_kernel or needs_wrapped_forward or needs_wrapped_backward):
+    if config.kernel_fn is not None:
         return

     try:
         from kernels import get_kernel

         kernel_module = get_kernel(config.repo_id, revision=config.revision)
-        if needs_kernel:
-            config.kernel_fn = _resolve_kernel_attr(kernel_module, config.function_attr)
+        kernel_func = getattr(kernel_module, config.function_attr)

-        if needs_wrapped_forward:
-            config.wrapped_forward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_forward_attr)
-
-        if needs_wrapped_backward:
-            config.wrapped_backward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_backward_attr)
-
+        # Cache the downloaded kernel function in the config object
+        config.kernel_fn = kernel_func
     except Exception as e:
         logger.error(f"An error occurred while fetching kernel '{config.repo_id}' from the Hub: {e}")
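Note: a minimal sketch of the fetch-and-cache flow used by `_maybe_download_kernel_for_backend` above, written only against the `kernels` package; the repo id and attribute name are copied from the `FLASH_HUB` registry entry, and a network-enabled environment is assumed.

    from kernels import get_kernel

    def fetch_hub_kernel_fn(repo_id="kernels-community/flash-attn2", function_attr="flash_attn_func", revision=None):
        # Download (or reuse a cached copy of) the kernel module from the Hub ...
        kernel_module = get_kernel(repo_id, revision=revision)
        # ... and resolve the callable it exports, mirroring the getattr in the diff above.
        return getattr(kernel_module, function_attr)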
@@ -1090,231 +1065,6 @@ def _flash_attention_backward_op(
     return grad_query, grad_key, grad_value


-def _flash_attention_hub_forward_op(
-    ctx: torch.autograd.function.FunctionCtx,
-    query: torch.Tensor,
-    key: torch.Tensor,
-    value: torch.Tensor,
-    attn_mask: Optional[torch.Tensor] = None,
-    dropout_p: float = 0.0,
-    is_causal: bool = False,
-    scale: Optional[float] = None,
-    enable_gqa: bool = False,
-    return_lse: bool = False,
-    _save_ctx: bool = True,
-    _parallel_config: Optional["ParallelConfig"] = None,
-):
-    if attn_mask is not None:
-        raise ValueError("`attn_mask` is not yet supported for flash-attn hub kernels.")
-    if enable_gqa:
-        raise ValueError("`enable_gqa` is not yet supported for flash-attn hub kernels.")
-
-    config = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB]
-    wrapped_forward_fn = config.wrapped_forward_fn
-    wrapped_backward_fn = config.wrapped_backward_fn
-    if wrapped_forward_fn is None or wrapped_backward_fn is None:
-        raise RuntimeError(
-            "Flash attention hub kernels must expose `_wrapped_flash_attn_forward` and `_wrapped_flash_attn_backward` "
-            "for context parallel execution."
-        )
-
-    if scale is None:
-        scale = query.shape[-1] ** (-0.5)
-
-    window_size = (-1, -1)
-    softcap = 0.0
-    alibi_slopes = None
-    deterministic = False
-    grad_enabled = any(x.requires_grad for x in (query, key, value))
-
-    if grad_enabled or (_parallel_config is not None and _parallel_config.context_parallel_config._world_size > 1):
-        dropout_p = dropout_p if dropout_p > 0 else 1e-30
-
-    with torch.set_grad_enabled(grad_enabled):
-        out, lse, S_dmask, rng_state = wrapped_forward_fn(
-            query,
-            key,
-            value,
-            dropout_p,
-            scale,
-            is_causal,
-            window_size[0],
-            window_size[1],
-            softcap,
-            alibi_slopes,
-            return_lse,
-        )
-        lse = lse.permute(0, 2, 1).contiguous()
-
-    if _save_ctx:
-        ctx.save_for_backward(query, key, value, out, lse, rng_state)
-        ctx.dropout_p = dropout_p
-        ctx.scale = scale
-        ctx.is_causal = is_causal
-        ctx.window_size = window_size
-        ctx.softcap = softcap
-        ctx.alibi_slopes = alibi_slopes
-        ctx.deterministic = deterministic
-
-    return (out, lse) if return_lse else out
-
-
-def _flash_attention_hub_backward_op(
-    ctx: torch.autograd.function.FunctionCtx,
-    grad_out: torch.Tensor,
-    *args,
-    **kwargs,
-):
-    config = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB]
-    wrapped_backward_fn = config.wrapped_backward_fn
-    if wrapped_backward_fn is None:
-        raise RuntimeError(
-            "Flash attention hub kernels must expose `_wrapped_flash_attn_backward` for context parallel execution."
-        )
-
-    query, key, value, out, lse, rng_state = ctx.saved_tensors
-    grad_query, grad_key, grad_value = torch.empty_like(query), torch.empty_like(key), torch.empty_like(value)
-
-    _ = wrapped_backward_fn(
-        grad_out,
-        query,
-        key,
-        value,
-        out,
-        lse,
-        grad_query,
-        grad_key,
-        grad_value,
-        ctx.dropout_p,
-        ctx.scale,
-        ctx.is_causal,
-        ctx.window_size[0],
-        ctx.window_size[1],
-        ctx.softcap,
-        ctx.alibi_slopes,
-        ctx.deterministic,
-        rng_state,
-    )
-
-    grad_query = grad_query[..., : grad_out.shape[-1]]
-    grad_key = grad_key[..., : grad_out.shape[-1]]
-    grad_value = grad_value[..., : grad_out.shape[-1]]
-
-    return grad_query, grad_key, grad_value
-
-
-def _flash_attention_3_hub_forward_op(
-    ctx: torch.autograd.function.FunctionCtx,
-    query: torch.Tensor,
-    key: torch.Tensor,
-    value: torch.Tensor,
-    attn_mask: Optional[torch.Tensor] = None,
-    dropout_p: float = 0.0,
-    is_causal: bool = False,
-    scale: Optional[float] = None,
-    enable_gqa: bool = False,
-    return_lse: bool = False,
-    _save_ctx: bool = True,
-    _parallel_config: Optional["ParallelConfig"] = None,
-    *,
-    window_size: Tuple[int, int] = (-1, -1),
-    softcap: float = 0.0,
-    num_splits: int = 1,
-    pack_gqa: Optional[bool] = None,
-    deterministic: bool = False,
-    sm_margin: int = 0,
-):
-    if attn_mask is not None:
-        raise ValueError("`attn_mask` is not yet supported for flash-attn 3 hub kernels.")
-    if dropout_p != 0.0:
-        raise ValueError("`dropout_p` is not yet supported for flash-attn 3 hub kernels.")
-    if enable_gqa:
-        raise ValueError("`enable_gqa` is not yet supported for flash-attn 3 hub kernels.")
-
-    func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
-    out = func(
-        q=query,
-        k=key,
-        v=value,
-        softmax_scale=scale,
-        causal=is_causal,
-        qv=None,
-        q_descale=None,
-        k_descale=None,
-        v_descale=None,
-        window_size=window_size,
-        softcap=softcap,
-        num_splits=num_splits,
-        pack_gqa=pack_gqa,
-        deterministic=deterministic,
-        sm_margin=sm_margin,
-        return_attn_probs=return_lse,
-    )
-
-    lse = None
-    if return_lse:
-        out, lse = out
-        lse = lse.permute(0, 2, 1).contiguous()
-
-    if _save_ctx:
-        ctx.save_for_backward(query, key, value)
-        ctx.scale = scale
-        ctx.is_causal = is_causal
-        ctx._hub_kernel = func
-
-    return (out, lse) if return_lse else out
-
-
-def _flash_attention_3_hub_backward_op(
-    ctx: torch.autograd.function.FunctionCtx,
-    grad_out: torch.Tensor,
-    *args,
-    window_size: Tuple[int, int] = (-1, -1),
-    softcap: float = 0.0,
-    num_splits: int = 1,
-    pack_gqa: Optional[bool] = None,
-    deterministic: bool = False,
-    sm_margin: int = 0,
-):
-    query, key, value = ctx.saved_tensors
-    kernel_fn = ctx._hub_kernel
-    with torch.enable_grad():
-        query_r = query.detach().requires_grad_(True)
-        key_r = key.detach().requires_grad_(True)
-        value_r = value.detach().requires_grad_(True)
-
-        out = kernel_fn(
-            q=query_r,
-            k=key_r,
-            v=value_r,
-            softmax_scale=ctx.scale,
-            causal=ctx.is_causal,
-            qv=None,
-            q_descale=None,
-            k_descale=None,
-            v_descale=None,
-            window_size=window_size,
-            softcap=softcap,
-            num_splits=num_splits,
-            pack_gqa=pack_gqa,
-            deterministic=deterministic,
-            sm_margin=sm_margin,
-            return_attn_probs=False,
-        )
-        if isinstance(out, tuple):
-            out = out[0]
-
-    grad_query, grad_key, grad_value = torch.autograd.grad(
-        out,
-        (query_r, key_r, value_r),
-        grad_out,
-        retain_graph=False,
-        allow_unused=False,
-    )
-
-    return grad_query, grad_key, grad_value
-
-
 def _sage_attention_forward_op(
     ctx: torch.autograd.function.FunctionCtx,
     query: torch.Tensor,
@@ -1353,46 +1103,6 @@ def _sage_attention_forward_op(
     return (out, lse) if return_lse else out


-def _sage_attention_hub_forward_op(
-    ctx: torch.autograd.function.FunctionCtx,
-    query: torch.Tensor,
-    key: torch.Tensor,
-    value: torch.Tensor,
-    attn_mask: Optional[torch.Tensor] = None,
-    dropout_p: float = 0.0,
-    is_causal: bool = False,
-    scale: Optional[float] = None,
-    enable_gqa: bool = False,
-    return_lse: bool = False,
-    _save_ctx: bool = True,
-    _parallel_config: Optional["ParallelConfig"] = None,
-):
-    if attn_mask is not None:
-        raise ValueError("`attn_mask` is not yet supported for Sage attention.")
-    if dropout_p > 0.0:
-        raise ValueError("`dropout_p` is not yet supported for Sage attention.")
-    if enable_gqa:
-        raise ValueError("`enable_gqa` is not yet supported for Sage attention.")
-
-    func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn
-    out = func(
-        q=query,
-        k=key,
-        v=value,
-        tensor_layout="NHD",
-        is_causal=is_causal,
-        sm_scale=scale,
-        return_lse=return_lse,
-    )
-
-    lse = None
-    if return_lse:
-        out, lse, *_ = out
-        lse = lse.permute(0, 2, 1).contiguous()
-
-    return (out, lse) if return_lse else out
-
-
 def _sage_attention_backward_op(
     ctx: torch.autograd.function.FunctionCtx,
     grad_out: torch.Tensor,
@@ -1985,7 +1695,7 @@ def _flash_attention(
 @_AttentionBackendRegistry.register(
     AttentionBackendName.FLASH_HUB,
     constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
-    supports_context_parallel=True,
+    supports_context_parallel=False,
 )
 def _flash_attention_hub(
     query: torch.Tensor,
@@ -2003,35 +1713,17 @@ def _flash_attention_hub(
         raise ValueError("`attn_mask` is not supported for flash-attn 2.")

     func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].kernel_fn
-    if _parallel_config is None:
-        out = func(
-            q=query,
-            k=key,
-            v=value,
-            dropout_p=dropout_p,
-            softmax_scale=scale,
-            causal=is_causal,
-            return_attn_probs=return_lse,
-        )
-        if return_lse:
-            out, lse, *_ = out
-    else:
-        out = _templated_context_parallel_attention(
-            query,
-            key,
-            value,
-            None,
-            dropout_p,
-            is_causal,
-            scale,
-            False,
-            return_lse,
-            forward_op=_flash_attention_hub_forward_op,
-            backward_op=_flash_attention_hub_backward_op,
-            _parallel_config=_parallel_config,
-        )
-        if return_lse:
-            out, lse = out
+    out = func(
+        q=query,
+        k=key,
+        v=value,
+        dropout_p=dropout_p,
+        softmax_scale=scale,
+        causal=is_causal,
+        return_attn_probs=return_lse,
+    )
+    if return_lse:
+        out, lse, *_ = out

     return (out, lse) if return_lse else out

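Note: a hedged sketch of calling the fetched flash-attn 2 kernel the same way `_flash_attention_hub` does; it assumes the `kernels` package, a CUDA device, and bf16 tensors in the (batch, seq_len, heads, head_dim) layout, and it relies on `return_attn_probs=True` returning the log-sum-exp plus extra values, which is why the backend unpacks `out, lse, *_`.

    import torch
    from kernels import get_kernel

    flash_attn_func = getattr(get_kernel("kernels-community/flash-attn2"), "flash_attn_func")
    q = k = v = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
    out, lse, *_ = flash_attn_func(q=q, k=k, v=v, dropout_p=0.0, causal=False, return_attn_probs=True)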
@@ -2178,7 +1870,7 @@ def _flash_attention_3(
 @_AttentionBackendRegistry.register(
     AttentionBackendName._FLASH_3_HUB,
     constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
-    supports_context_parallel=True,
+    supports_context_parallel=False,
 )
 def _flash_attention_3_hub(
     query: torch.Tensor,
@@ -2193,68 +1885,33 @@ def _flash_attention_3_hub(
     return_attn_probs: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
+    if _parallel_config:
+        raise NotImplementedError(f"{AttentionBackendName._FLASH_3_HUB.value} is not implemented for parallelism yet.")
     if attn_mask is not None:
         raise ValueError("`attn_mask` is not supported for flash-attn 3.")

     func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
-    if _parallel_config is None:
-        out = func(
-            q=query,
-            k=key,
-            v=value,
-            softmax_scale=scale,
-            causal=is_causal,
-            qv=None,
-            q_descale=None,
-            k_descale=None,
-            v_descale=None,
-            window_size=window_size,
-            softcap=softcap,
-            num_splits=1,
-            pack_gqa=None,
-            deterministic=deterministic,
-            sm_margin=0,
-            return_attn_probs=return_attn_probs,
-        )
-        return (out[0], out[1]) if return_attn_probs else out
-
-    forward_op = functools.partial(
-        _flash_attention_3_hub_forward_op,
+    out = func(
+        q=query,
+        k=key,
+        v=value,
+        softmax_scale=scale,
+        causal=is_causal,
+        qv=None,
+        q_descale=None,
+        k_descale=None,
+        v_descale=None,
         window_size=window_size,
         softcap=softcap,
         num_splits=1,
         pack_gqa=None,
         deterministic=deterministic,
         sm_margin=0,
+        return_attn_probs=return_attn_probs,
     )
-    backward_op = functools.partial(
-        _flash_attention_3_hub_backward_op,
-        window_size=window_size,
-        softcap=softcap,
-        num_splits=1,
-        pack_gqa=None,
-        deterministic=deterministic,
-        sm_margin=0,
-    )
-    out = _templated_context_parallel_attention(
-        query,
-        key,
-        value,
-        None,
-        0.0,
-        is_causal,
-        scale,
-        False,
-        return_attn_probs,
-        forward_op=forward_op,
-        backward_op=backward_op,
-        _parallel_config=_parallel_config,
-    )
-    if return_attn_probs:
-        out, lse = out
-        return out, lse
-
-    return out
+    # When `return_attn_probs` is True, the above returns a tuple of
+    # actual outputs and lse.
+    return (out[0], out[1]) if return_attn_probs else out


 @_AttentionBackendRegistry.register(
@@ -2885,7 +2542,7 @@ def _sage_attention(
 @_AttentionBackendRegistry.register(
     AttentionBackendName.SAGE_HUB,
     constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
-    supports_context_parallel=True,
+    supports_context_parallel=False,
 )
 def _sage_attention_hub(
     query: torch.Tensor,
@@ -2913,23 +2570,6 @@ def _sage_attention_hub(
     )
     if return_lse:
         out, lse, *_ = out
-    else:
-        out = _templated_context_parallel_attention(
-            query,
-            key,
-            value,
-            None,
-            0.0,
-            is_causal,
-            scale,
-            False,
-            return_lse,
-            forward_op=_sage_attention_hub_forward_op,
-            backward_op=_sage_attention_backward_op,
-            _parallel_config=_parallel_config,
-        )
-        if return_lse:
-            out, lse = out

     return (out, lse) if return_lse else out

@@ -18,7 +18,7 @@ from typing import Optional, Union
 from huggingface_hub.utils import validate_hf_hub_args

 from ..configuration_utils import ConfigMixin
-from ..utils import logging
+from ..utils import DIFFUSERS_LOAD_ID_FIELDS, logging
 from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code


@@ -220,4 +220,11 @@ class AutoModel(ConfigMixin):
             raise ValueError(f"AutoModel can't find a model linked to {orig_class_name}.")

         kwargs = {**load_config_kwargs, **kwargs}
-        return model_cls.from_pretrained(pretrained_model_or_path, **kwargs)
+        model = model_cls.from_pretrained(pretrained_model_or_path, **kwargs)
+
+        load_id_kwargs = {"pretrained_model_name_or_path": pretrained_model_or_path, **kwargs}
+        parts = [load_id_kwargs.get(field, "null") for field in DIFFUSERS_LOAD_ID_FIELDS]
+        load_id = "|".join("null" if p is None else p for p in parts)
+        model._diffusers_load_id = load_id
+
+        return model
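Note: a small, self-contained illustration of the load-id string assembled above; the model id and subfolder are made up, and fields that are absent or None collapse to the literal "null".

    DIFFUSERS_LOAD_ID_FIELDS = ["pretrained_model_name_or_path", "subfolder", "variant", "revision"]
    kwargs = {"pretrained_model_name_or_path": "some-org/some-model", "subfolder": "transformer", "variant": None}
    parts = [kwargs.get(field, "null") for field in DIFFUSERS_LOAD_ID_FIELDS]
    print("|".join("null" if p is None else p for p in parts))
    # some-org/some-model|transformer|null|null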
@@ -2142,6 +2142,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
                 name
                 for name in self._component_specs.keys()
                 if self._component_specs[name].default_creation_method == "from_pretrained"
+                and self._component_specs[name].pretrained_model_name_or_path is not None
+                and getattr(self, name, None) is None
             ]
         elif isinstance(names, str):
             names = [names]
@@ -15,14 +15,14 @@
 import inspect
 import re
 from collections import OrderedDict
-from dataclasses import dataclass, field, fields
+from dataclasses import dataclass, field
 from typing import Any, Dict, List, Literal, Optional, Type, Union

 import torch

 from ..configuration_utils import ConfigMixin, FrozenDict
 from ..loaders.single_file_utils import _is_single_file_path_or_url
-from ..utils import is_torch_available, logging
+from ..utils import DIFFUSERS_LOAD_ID_FIELDS, is_torch_available, logging


 if is_torch_available():
@@ -185,7 +185,7 @@ class ComponentSpec:
         """
         Return the names of all loading‐related fields (i.e. those whose field.metadata["loading"] is True).
         """
-        return [f.name for f in fields(cls) if f.metadata.get("loading", False)]
+        return DIFFUSERS_LOAD_ID_FIELDS.copy()

     @property
     def load_id(self) -> str:
@@ -197,7 +197,7 @@ class ComponentSpec:
             return "null"
         parts = [getattr(self, k) for k in self.loading_fields()]
         parts = ["null" if p is None else p for p in parts]
-        return "|".join(p for p in parts if p)
+        return "|".join(parts)

     @classmethod
     def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]:
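Note: the behavioural difference between the old and new join only shows up when a loading field is an empty string; keeping every segment (the new form) means the encoded id always has one slot per entry of DIFFUSERS_LOAD_ID_FIELDS, which keeps the segments aligned with the field order if the id is split again. The values below are invented for illustration.

    values = ["some-org/some-model", "", "fp16", None]  # name, subfolder, variant, revision
    parts = ["null" if p is None else p for p in values]
    print("|".join(p for p in parts if p))  # old: some-org/some-model|fp16|null   (empty segment silently dropped)
    print("|".join(parts))                  # new: some-org/some-model||fp16|null  (always four segments)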
@@ -482,6 +482,8 @@ class ChromaInpaintPipeline(
         negative_prompt=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
+        pooled_prompt_embeds=None,
+        negative_pooled_prompt_embeds=None,
         callback_on_step_end_tensor_inputs=None,
         padding_mask_crop=None,
         max_sequence_length=None,
@@ -529,6 +531,15 @@ class ChromaInpaintPipeline(
                     f" {negative_prompt_embeds.shape}."
                 )

+        if prompt_embeds is not None and pooled_prompt_embeds is None:
+            raise ValueError(
+                "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+            )
+        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+            raise ValueError(
+                "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+            )
+
         if prompt_embeds is not None and prompt_attention_mask is None:
             raise ValueError("Cannot provide `prompt_embeds` without also providing `prompt_attention_mask")

@@ -782,11 +793,13 @@ class ChromaInpaintPipeline(
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
+        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         ip_adapter_image: Optional[PipelineImageInput] = None,
         ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
         negative_ip_adapter_image: Optional[PipelineImageInput] = None,
         negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
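Note: a usage sketch of the newly accepted arguments; `pipe`, the images, and the embedding tensors are assumed to exist already (the embeddings coming from the pipeline's own text encoder), and other required inputs of the inpaint call are omitted for brevity.

    # Sketch only: when passing precomputed prompt embeddings, also pass the pooled embeddings
    # (and their negative counterparts) from the same text encoder, otherwise the check added
    # to check_inputs above raises a ValueError.
    result = pipe(
        image=init_image,
        mask_image=mask_image,
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
    )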
@@ -281,7 +281,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).

         Args:
-            num_inference_steps (`int`, *optional*):
+            num_inference_steps (`int`):
                 The number of diffusion steps used when generating samples with a pre-trained model. If used,
                 `timesteps` must be `None`.
             device (`str` or `torch.device`, *optional*):
@@ -646,7 +646,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
     def __len__(self) -> int:
         return self.config.num_train_timesteps

-    def previous_timestep(self, timestep: int) -> Union[int, torch.Tensor]:
+    def previous_timestep(self, timestep: int) -> int:
         """
         Compute the previous timestep in the diffusion chain.

@@ -655,7 +655,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
                 The current timestep.

         Returns:
-            `int` or `torch.Tensor`:
+            `int`:
                 The previous timestep.
         """
         if self.custom_timesteps or self.num_inference_steps:
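Note: an indicative illustration of `previous_timestep`; the numbers assume the default 1000 training timesteps with 50 inference steps and are meant to show the spacing, not to pin down the exact return type.

    from diffusers import DDPMScheduler

    sched = DDPMScheduler(num_train_timesteps=1000)
    sched.set_timesteps(num_inference_steps=50)
    # Timesteps are spaced 1000 // 50 = 20 apart (980, 960, ..., 0), so the step that
    # follows 980 in the denoising chain is 960, and the final step maps to -1.
    print(sched.previous_timestep(980), sched.previous_timestep(0))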
@@ -149,41 +149,38 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
    For more details, see the original paper: https://huggingface.co/papers/2006.11239

    Args:
-        num_train_timesteps (`int`, defaults to 1000):
-            The number of diffusion steps to train the model.
-        beta_start (`float`, defaults to 0.0001):
-            The starting `beta` value of inference.
-        beta_end (`float`, defaults to 0.02):
-            The final `beta` value.
-        beta_schedule (`str`, defaults to `"linear"`):
-            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+        num_train_timesteps (`int`): number of diffusion steps used to train the model.
+        beta_start (`float`): the starting `beta` value of inference.
+        beta_end (`float`): the final `beta` value.
+        beta_schedule (`str`):
+            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, `squaredcos_cap_v2` or `sigmoid`.
-        trained_betas (`np.ndarray`, *optional*):
-            Option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
-        variance_type (`str`, defaults to `"fixed_small"`):
-            Options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
+        trained_betas (`np.ndarray`, optional):
+            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
+        variance_type (`str`):
+            options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
            `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
-        clip_sample (`bool`, defaults to `True`):
-            Option to clip predicted sample for numerical stability.
-        prediction_type (`str`, defaults to `"epsilon"`):
-            Prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
+        clip_sample (`bool`, default `True`):
+            option to clip predicted sample for numerical stability.
+        clip_sample_range (`float`, default `1.0`):
+            the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
+        prediction_type (`str`, default `epsilon`, optional):
+            prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
            process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
            https://huggingface.co/papers/2210.02303)
-        thresholding (`bool`, defaults to `False`):
-            Whether to use the "dynamic thresholding" method (introduced by Imagen,
+        thresholding (`bool`, default `False`):
+            whether to use the "dynamic thresholding" method (introduced by Imagen,
            https://huggingface.co/papers/2205.11487). Note that the thresholding method is unsuitable for latent-space
            diffusion models (such as stable-diffusion).
-        dynamic_thresholding_ratio (`float`, defaults to 0.995):
-            The ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
+        dynamic_thresholding_ratio (`float`, default `0.995`):
+            the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
            (https://huggingface.co/papers/2205.11487). Valid only when `thresholding=True`.
-        clip_sample_range (`float`, defaults to 1.0):
-            The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
-        sample_max_value (`float`, defaults to 1.0):
-            The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
-        timestep_spacing (`str`, defaults to `"leading"`):
+        sample_max_value (`float`, default `1.0`):
+            the threshold value for dynamic thresholding. Valid only when `thresholding=True`.
+        timestep_spacing (`str`, default `"leading"`):
            The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
            Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
-        steps_offset (`int`, defaults to 0):
+        steps_offset (`int`, default `0`):
            An offset added to the inference steps, as required by some model families.
        rescale_betas_zero_snr (`bool`, defaults to `False`):
            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
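Note: a minimal construction sketch using a handful of the arguments documented above; it assumes `DDPMParallelScheduler` is importable from the top-level `diffusers` namespace like the other schedulers.

    from diffusers import DDPMParallelScheduler

    scheduler = DDPMParallelScheduler(
        num_train_timesteps=1000,
        beta_schedule="squaredcos_cap_v2",
        clip_sample=True,
        prediction_type="epsilon",
    )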
@@ -296,7 +293,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).

         Args:
-            num_inference_steps (`int`, *optional*):
+            num_inference_steps (`int`):
                 The number of diffusion steps used when generating samples with a pre-trained model. If used,
                 `timesteps` must be `None`.
             device (`str` or `torch.device`, *optional*):
@@ -481,7 +478,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
         model_output: torch.Tensor,
         timestep: int,
         sample: torch.Tensor,
-        generator: Optional[torch.Generator] = None,
+        generator=None,
         return_dict: bool = True,
     ) -> Union[DDPMParallelSchedulerOutput, Tuple]:
         """
@@ -493,8 +490,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
             timestep (`int`): current discrete timestep in the diffusion chain.
             sample (`torch.Tensor`):
                 current instance of sample being created by diffusion process.
-            generator (`torch.Generator`, *optional*):
-                Random number generator.
+            generator: random number generator.
             return_dict (`bool`): option for returning tuple rather than DDPMParallelSchedulerOutput class

         Returns:
@@ -507,10 +503,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):

         prev_t = self.previous_timestep(t)

-        if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in [
-            "learned",
-            "learned_range",
-        ]:
+        if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
             model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
         else:
             predicted_variance = None
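Note: a tiny self-contained sketch of the split guarded above, with made-up shapes; a model that learns the variance emits twice as many channels as the sample, and `torch.split` recovers the mean prediction and the variance logits.

    import torch

    sample = torch.randn(2, 4, 8, 8)        # (batch, channels, height, width)
    model_output = torch.randn(2, 8, 8, 8)  # 2 * channels when variance_type is "learned"/"learned_range"
    if model_output.shape[1] == sample.shape[1] * 2:
        model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
    print(model_output.shape, predicted_variance.shape)  # both torch.Size([2, 4, 8, 8])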
@@ -559,10 +552,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
         if t > 0:
             device = model_output.device
             variance_noise = randn_tensor(
-                model_output.shape,
-                generator=generator,
-                device=device,
-                dtype=model_output.dtype,
+                model_output.shape, generator=generator, device=device, dtype=model_output.dtype
             )
             if self.variance_type == "fixed_small_log":
                 variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise
@@ -585,7 +575,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
     def batch_step_no_noise(
         self,
         model_output: torch.Tensor,
-        timesteps: torch.Tensor,
+        timesteps: List[int],
         sample: torch.Tensor,
     ) -> torch.Tensor:
         """
@@ -598,8 +588,8 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):

         Args:
             model_output (`torch.Tensor`): direct output from learned diffusion model.
-            timesteps (`torch.Tensor`):
-                Current discrete timesteps in the diffusion chain. This is a tensor of integers.
+            timesteps (`List[int]`):
+                current discrete timesteps in the diffusion chain. This is now a list of integers.
             sample (`torch.Tensor`):
                 current instance of sample being created by diffusion process.

@@ -613,10 +603,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
         t = t.view(-1, *([1] * (model_output.ndim - 1)))
         prev_t = prev_t.view(-1, *([1] * (model_output.ndim - 1)))

-        if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in [
-            "learned",
-            "learned_range",
-        ]:
+        if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
             model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
         else:
             pass
@@ -747,7 +734,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
         return self.config.num_train_timesteps

     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
-    def previous_timestep(self, timestep: int) -> Union[int, torch.Tensor]:
+    def previous_timestep(self, timestep):
         """
         Compute the previous timestep in the diffusion chain.

@@ -756,7 +743,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
                 The current timestep.

         Returns:
-            `int` or `torch.Tensor`:
+            `int`:
                 The previous timestep.
         """
         if self.custom_timesteps or self.num_inference_steps:
@@ -722,7 +722,7 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
                 The current timestep.

         Returns:
-            `int` or `torch.Tensor`:
+            `int`:
                 The previous timestep.
         """
         if self.custom_timesteps or self.num_inference_steps:
@@ -777,7 +777,7 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
                 The current timestep.

         Returns:
-            `int` or `torch.Tensor`:
+            `int`:
                 The previous timestep.
         """
         if self.custom_timesteps or self.num_inference_steps:
@@ -23,6 +23,7 @@ from .constants import (
     DEFAULT_HF_PARALLEL_LOADING_WORKERS,
     DEPRECATED_REVISION_ARGS,
     DIFFUSERS_DYNAMIC_MODULE_NAME,
+    DIFFUSERS_LOAD_ID_FIELDS,
     FLAX_WEIGHTS_NAME,
     GGUF_FILE_EXTENSION,
     HF_ENABLE_PARALLEL_LOADING,
@@ -73,3 +73,11 @@ DECODE_ENDPOINT_HUNYUAN_VIDEO = "https://o7ywnmrahorts457.us-east-1.aws.endpoint
 ENCODE_ENDPOINT_SD_V1 = "https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/"
 ENCODE_ENDPOINT_SD_XL = "https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud/"
 ENCODE_ENDPOINT_FLUX = "https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud/"
+
+
+DIFFUSERS_LOAD_ID_FIELDS = [
+    "pretrained_model_name_or_path",
+    "subfolder",
+    "variant",
+    "revision",
+]