mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 23:15:00 +08:00
Compare commits
14 Commits
cuda-128-g
...
enable-cp-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dab372dd27 | ||
|
|
2af7baa040 | ||
|
|
a7cb14efbe | ||
|
|
e8e88ff2ce | ||
|
|
6e24cd842c | ||
|
|
981eb802c6 | ||
|
|
1eb40c6dbd | ||
|
|
79438572e0 | ||
|
|
2268583f39 | ||
|
|
dfbd4857b2 | ||
|
|
9bd83616bf | ||
|
|
f732ff1144 | ||
|
|
7a8f85b047 | ||
|
|
82d20e64a5 |
@@ -260,6 +260,10 @@ class _HubKernelConfig:
|
||||
function_attr: str
|
||||
revision: Optional[str] = None
|
||||
kernel_fn: Optional[Callable] = None
|
||||
wrapped_forward_attr: Optional[str] = None
|
||||
wrapped_backward_attr: Optional[str] = None
|
||||
wrapped_forward_fn: Optional[Callable] = None
|
||||
wrapped_backward_fn: Optional[Callable] = None
|
||||
|
||||
|
||||
# Registry for hub-based attention kernels
|
||||
@@ -274,7 +278,11 @@ _HUB_KERNELS_REGISTRY: Dict["AttentionBackendName", _HubKernelConfig] = {
|
||||
# revision="fake-ops-return-probs",
|
||||
),
|
||||
AttentionBackendName.FLASH_HUB: _HubKernelConfig(
|
||||
repo_id="kernels-community/flash-attn2", function_attr="flash_attn_func", revision=None
|
||||
repo_id="kernels-community/flash-attn2",
|
||||
function_attr="flash_attn_func",
|
||||
revision=None,
|
||||
wrapped_forward_attr="flash_attn_interface._wrapped_flash_attn_forward",
|
||||
wrapped_backward_attr="flash_attn_interface._wrapped_flash_attn_backward",
|
||||
),
|
||||
AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig(
|
||||
repo_id="kernels-community/flash-attn2", function_attr="flash_attn_varlen_func", revision=None
|
||||
@@ -599,22 +607,39 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
|
||||
|
||||
|
||||
# ===== Helpers for downloading kernels =====
|
||||
def _resolve_kernel_attr(module, attr_path: str):
|
||||
target = module
|
||||
for attr in attr_path.split("."):
|
||||
if not hasattr(target, attr):
|
||||
raise AttributeError(f"Kernel module '{module.__name__}' does not define attribute path '{attr_path}'.")
|
||||
target = getattr(target, attr)
|
||||
return target
|
||||
|
||||
|
||||
def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None:
|
||||
if backend not in _HUB_KERNELS_REGISTRY:
|
||||
return
|
||||
config = _HUB_KERNELS_REGISTRY[backend]
|
||||
|
||||
if config.kernel_fn is not None:
|
||||
needs_kernel = config.kernel_fn is None
|
||||
needs_wrapped_forward = config.wrapped_forward_attr is not None and config.wrapped_forward_fn is None
|
||||
needs_wrapped_backward = config.wrapped_backward_attr is not None and config.wrapped_backward_fn is None
|
||||
|
||||
if not (needs_kernel or needs_wrapped_forward or needs_wrapped_backward):
|
||||
return
|
||||
|
||||
try:
|
||||
from kernels import get_kernel
|
||||
|
||||
kernel_module = get_kernel(config.repo_id, revision=config.revision)
|
||||
kernel_func = getattr(kernel_module, config.function_attr)
|
||||
if needs_kernel:
|
||||
config.kernel_fn = _resolve_kernel_attr(kernel_module, config.function_attr)
|
||||
|
||||
# Cache the downloaded kernel function in the config object
|
||||
config.kernel_fn = kernel_func
|
||||
if needs_wrapped_forward:
|
||||
config.wrapped_forward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_forward_attr)
|
||||
|
||||
if needs_wrapped_backward:
|
||||
config.wrapped_backward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_backward_attr)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"An error occurred while fetching kernel '{config.repo_id}' from the Hub: {e}")
|
||||
@@ -1065,6 +1090,231 @@ def _flash_attention_backward_op(
|
||||
return grad_query, grad_key, grad_value
|
||||
|
||||
|
||||
def _flash_attention_hub_forward_op(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
enable_gqa: bool = False,
|
||||
return_lse: bool = False,
|
||||
_save_ctx: bool = True,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
):
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not yet supported for flash-attn hub kernels.")
|
||||
if enable_gqa:
|
||||
raise ValueError("`enable_gqa` is not yet supported for flash-attn hub kernels.")
|
||||
|
||||
config = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB]
|
||||
wrapped_forward_fn = config.wrapped_forward_fn
|
||||
wrapped_backward_fn = config.wrapped_backward_fn
|
||||
if wrapped_forward_fn is None or wrapped_backward_fn is None:
|
||||
raise RuntimeError(
|
||||
"Flash attention hub kernels must expose `_wrapped_flash_attn_forward` and `_wrapped_flash_attn_backward` "
|
||||
"for context parallel execution."
|
||||
)
|
||||
|
||||
if scale is None:
|
||||
scale = query.shape[-1] ** (-0.5)
|
||||
|
||||
window_size = (-1, -1)
|
||||
softcap = 0.0
|
||||
alibi_slopes = None
|
||||
deterministic = False
|
||||
grad_enabled = any(x.requires_grad for x in (query, key, value))
|
||||
|
||||
if grad_enabled or (_parallel_config is not None and _parallel_config.context_parallel_config._world_size > 1):
|
||||
dropout_p = dropout_p if dropout_p > 0 else 1e-30
|
||||
|
||||
with torch.set_grad_enabled(grad_enabled):
|
||||
out, lse, S_dmask, rng_state = wrapped_forward_fn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
dropout_p,
|
||||
scale,
|
||||
is_causal,
|
||||
window_size[0],
|
||||
window_size[1],
|
||||
softcap,
|
||||
alibi_slopes,
|
||||
return_lse,
|
||||
)
|
||||
lse = lse.permute(0, 2, 1).contiguous()
|
||||
|
||||
if _save_ctx:
|
||||
ctx.save_for_backward(query, key, value, out, lse, rng_state)
|
||||
ctx.dropout_p = dropout_p
|
||||
ctx.scale = scale
|
||||
ctx.is_causal = is_causal
|
||||
ctx.window_size = window_size
|
||||
ctx.softcap = softcap
|
||||
ctx.alibi_slopes = alibi_slopes
|
||||
ctx.deterministic = deterministic
|
||||
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
def _flash_attention_hub_backward_op(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
grad_out: torch.Tensor,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
config = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB]
|
||||
wrapped_backward_fn = config.wrapped_backward_fn
|
||||
if wrapped_backward_fn is None:
|
||||
raise RuntimeError(
|
||||
"Flash attention hub kernels must expose `_wrapped_flash_attn_backward` for context parallel execution."
|
||||
)
|
||||
|
||||
query, key, value, out, lse, rng_state = ctx.saved_tensors
|
||||
grad_query, grad_key, grad_value = torch.empty_like(query), torch.empty_like(key), torch.empty_like(value)
|
||||
|
||||
_ = wrapped_backward_fn(
|
||||
grad_out,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
out,
|
||||
lse,
|
||||
grad_query,
|
||||
grad_key,
|
||||
grad_value,
|
||||
ctx.dropout_p,
|
||||
ctx.scale,
|
||||
ctx.is_causal,
|
||||
ctx.window_size[0],
|
||||
ctx.window_size[1],
|
||||
ctx.softcap,
|
||||
ctx.alibi_slopes,
|
||||
ctx.deterministic,
|
||||
rng_state,
|
||||
)
|
||||
|
||||
grad_query = grad_query[..., : grad_out.shape[-1]]
|
||||
grad_key = grad_key[..., : grad_out.shape[-1]]
|
||||
grad_value = grad_value[..., : grad_out.shape[-1]]
|
||||
|
||||
return grad_query, grad_key, grad_value
|
||||
|
||||
|
||||
def _flash_attention_3_hub_forward_op(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
enable_gqa: bool = False,
|
||||
return_lse: bool = False,
|
||||
_save_ctx: bool = True,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
*,
|
||||
window_size: Tuple[int, int] = (-1, -1),
|
||||
softcap: float = 0.0,
|
||||
num_splits: int = 1,
|
||||
pack_gqa: Optional[bool] = None,
|
||||
deterministic: bool = False,
|
||||
sm_margin: int = 0,
|
||||
):
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not yet supported for flash-attn 3 hub kernels.")
|
||||
if dropout_p != 0.0:
|
||||
raise ValueError("`dropout_p` is not yet supported for flash-attn 3 hub kernels.")
|
||||
if enable_gqa:
|
||||
raise ValueError("`enable_gqa` is not yet supported for flash-attn 3 hub kernels.")
|
||||
|
||||
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
|
||||
out = func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
softmax_scale=scale,
|
||||
causal=is_causal,
|
||||
qv=None,
|
||||
q_descale=None,
|
||||
k_descale=None,
|
||||
v_descale=None,
|
||||
window_size=window_size,
|
||||
softcap=softcap,
|
||||
num_splits=num_splits,
|
||||
pack_gqa=pack_gqa,
|
||||
deterministic=deterministic,
|
||||
sm_margin=sm_margin,
|
||||
return_attn_probs=return_lse,
|
||||
)
|
||||
|
||||
lse = None
|
||||
if return_lse:
|
||||
out, lse = out
|
||||
lse = lse.permute(0, 2, 1).contiguous()
|
||||
|
||||
if _save_ctx:
|
||||
ctx.save_for_backward(query, key, value)
|
||||
ctx.scale = scale
|
||||
ctx.is_causal = is_causal
|
||||
ctx._hub_kernel = func
|
||||
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
def _flash_attention_3_hub_backward_op(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
grad_out: torch.Tensor,
|
||||
*args,
|
||||
window_size: Tuple[int, int] = (-1, -1),
|
||||
softcap: float = 0.0,
|
||||
num_splits: int = 1,
|
||||
pack_gqa: Optional[bool] = None,
|
||||
deterministic: bool = False,
|
||||
sm_margin: int = 0,
|
||||
):
|
||||
query, key, value = ctx.saved_tensors
|
||||
kernel_fn = ctx._hub_kernel
|
||||
with torch.enable_grad():
|
||||
query_r = query.detach().requires_grad_(True)
|
||||
key_r = key.detach().requires_grad_(True)
|
||||
value_r = value.detach().requires_grad_(True)
|
||||
|
||||
out = kernel_fn(
|
||||
q=query_r,
|
||||
k=key_r,
|
||||
v=value_r,
|
||||
softmax_scale=ctx.scale,
|
||||
causal=ctx.is_causal,
|
||||
qv=None,
|
||||
q_descale=None,
|
||||
k_descale=None,
|
||||
v_descale=None,
|
||||
window_size=window_size,
|
||||
softcap=softcap,
|
||||
num_splits=num_splits,
|
||||
pack_gqa=pack_gqa,
|
||||
deterministic=deterministic,
|
||||
sm_margin=sm_margin,
|
||||
return_attn_probs=False,
|
||||
)
|
||||
if isinstance(out, tuple):
|
||||
out = out[0]
|
||||
|
||||
grad_query, grad_key, grad_value = torch.autograd.grad(
|
||||
out,
|
||||
(query_r, key_r, value_r),
|
||||
grad_out,
|
||||
retain_graph=False,
|
||||
allow_unused=False,
|
||||
)
|
||||
|
||||
return grad_query, grad_key, grad_value
|
||||
|
||||
|
||||
def _sage_attention_forward_op(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
query: torch.Tensor,
|
||||
@@ -1103,6 +1353,46 @@ def _sage_attention_forward_op(
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
def _sage_attention_hub_forward_op(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
enable_gqa: bool = False,
|
||||
return_lse: bool = False,
|
||||
_save_ctx: bool = True,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
):
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not yet supported for Sage attention.")
|
||||
if dropout_p > 0.0:
|
||||
raise ValueError("`dropout_p` is not yet supported for Sage attention.")
|
||||
if enable_gqa:
|
||||
raise ValueError("`enable_gqa` is not yet supported for Sage attention.")
|
||||
|
||||
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn
|
||||
out = func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
tensor_layout="NHD",
|
||||
is_causal=is_causal,
|
||||
sm_scale=scale,
|
||||
return_lse=return_lse,
|
||||
)
|
||||
|
||||
lse = None
|
||||
if return_lse:
|
||||
out, lse, *_ = out
|
||||
lse = lse.permute(0, 2, 1).contiguous()
|
||||
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
def _sage_attention_backward_op(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
grad_out: torch.Tensor,
|
||||
@@ -1695,7 +1985,7 @@ def _flash_attention(
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.FLASH_HUB,
|
||||
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
supports_context_parallel=False,
|
||||
supports_context_parallel=True,
|
||||
)
|
||||
def _flash_attention_hub(
|
||||
query: torch.Tensor,
|
||||
@@ -1713,17 +2003,35 @@ def _flash_attention_hub(
|
||||
raise ValueError("`attn_mask` is not supported for flash-attn 2.")
|
||||
|
||||
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].kernel_fn
|
||||
out = func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
dropout_p=dropout_p,
|
||||
softmax_scale=scale,
|
||||
causal=is_causal,
|
||||
return_attn_probs=return_lse,
|
||||
)
|
||||
if return_lse:
|
||||
out, lse, *_ = out
|
||||
if _parallel_config is None:
|
||||
out = func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
dropout_p=dropout_p,
|
||||
softmax_scale=scale,
|
||||
causal=is_causal,
|
||||
return_attn_probs=return_lse,
|
||||
)
|
||||
if return_lse:
|
||||
out, lse, *_ = out
|
||||
else:
|
||||
out = _templated_context_parallel_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
None,
|
||||
dropout_p,
|
||||
is_causal,
|
||||
scale,
|
||||
False,
|
||||
return_lse,
|
||||
forward_op=_flash_attention_hub_forward_op,
|
||||
backward_op=_flash_attention_hub_backward_op,
|
||||
_parallel_config=_parallel_config,
|
||||
)
|
||||
if return_lse:
|
||||
out, lse = out
|
||||
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
@@ -1870,7 +2178,7 @@ def _flash_attention_3(
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName._FLASH_3_HUB,
|
||||
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
supports_context_parallel=False,
|
||||
supports_context_parallel=True,
|
||||
)
|
||||
def _flash_attention_3_hub(
|
||||
query: torch.Tensor,
|
||||
@@ -1885,33 +2193,68 @@ def _flash_attention_3_hub(
|
||||
return_attn_probs: bool = False,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
if _parallel_config:
|
||||
raise NotImplementedError(f"{AttentionBackendName._FLASH_3_HUB.value} is not implemented for parallelism yet.")
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for flash-attn 3.")
|
||||
|
||||
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
|
||||
out = func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
softmax_scale=scale,
|
||||
causal=is_causal,
|
||||
qv=None,
|
||||
q_descale=None,
|
||||
k_descale=None,
|
||||
v_descale=None,
|
||||
if _parallel_config is None:
|
||||
out = func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
softmax_scale=scale,
|
||||
causal=is_causal,
|
||||
qv=None,
|
||||
q_descale=None,
|
||||
k_descale=None,
|
||||
v_descale=None,
|
||||
window_size=window_size,
|
||||
softcap=softcap,
|
||||
num_splits=1,
|
||||
pack_gqa=None,
|
||||
deterministic=deterministic,
|
||||
sm_margin=0,
|
||||
return_attn_probs=return_attn_probs,
|
||||
)
|
||||
return (out[0], out[1]) if return_attn_probs else out
|
||||
|
||||
forward_op = functools.partial(
|
||||
_flash_attention_3_hub_forward_op,
|
||||
window_size=window_size,
|
||||
softcap=softcap,
|
||||
num_splits=1,
|
||||
pack_gqa=None,
|
||||
deterministic=deterministic,
|
||||
sm_margin=0,
|
||||
return_attn_probs=return_attn_probs,
|
||||
)
|
||||
# When `return_attn_probs` is True, the above returns a tuple of
|
||||
# actual outputs and lse.
|
||||
return (out[0], out[1]) if return_attn_probs else out
|
||||
backward_op = functools.partial(
|
||||
_flash_attention_3_hub_backward_op,
|
||||
window_size=window_size,
|
||||
softcap=softcap,
|
||||
num_splits=1,
|
||||
pack_gqa=None,
|
||||
deterministic=deterministic,
|
||||
sm_margin=0,
|
||||
)
|
||||
out = _templated_context_parallel_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
None,
|
||||
0.0,
|
||||
is_causal,
|
||||
scale,
|
||||
False,
|
||||
return_attn_probs,
|
||||
forward_op=forward_op,
|
||||
backward_op=backward_op,
|
||||
_parallel_config=_parallel_config,
|
||||
)
|
||||
if return_attn_probs:
|
||||
out, lse = out
|
||||
return out, lse
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
@@ -2542,7 +2885,7 @@ def _sage_attention(
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.SAGE_HUB,
|
||||
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
supports_context_parallel=False,
|
||||
supports_context_parallel=True,
|
||||
)
|
||||
def _sage_attention_hub(
|
||||
query: torch.Tensor,
|
||||
@@ -2570,6 +2913,23 @@ def _sage_attention_hub(
|
||||
)
|
||||
if return_lse:
|
||||
out, lse, *_ = out
|
||||
else:
|
||||
out = _templated_context_parallel_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
None,
|
||||
0.0,
|
||||
is_causal,
|
||||
scale,
|
||||
False,
|
||||
return_lse,
|
||||
forward_op=_sage_attention_hub_forward_op,
|
||||
backward_op=_sage_attention_backward_op,
|
||||
_parallel_config=_parallel_config,
|
||||
)
|
||||
if return_lse:
|
||||
out, lse = out
|
||||
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
@@ -366,7 +366,12 @@ class ResnetBlock2D(nn.Module):
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
|
||||
if self.conv_shortcut is not None:
|
||||
input_tensor = self.conv_shortcut(input_tensor.contiguous())
|
||||
# Only use contiguous() during training to avoid DDP gradient stride mismatch warning.
|
||||
# In inference mode (eval or no_grad), skip contiguous() for better performance, especially on CPU.
|
||||
# Issue: https://github.com/huggingface/diffusers/issues/12975
|
||||
if self.training:
|
||||
input_tensor = input_tensor.contiguous()
|
||||
input_tensor = self.conv_shortcut(input_tensor)
|
||||
|
||||
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import AttentionModuleMixin, FeedForward
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed
|
||||
@@ -400,6 +400,7 @@ class LongCatImageTransformer2DModel(
|
||||
PeftAdapterMixin,
|
||||
FromOriginalModelMixin,
|
||||
CacheMixin,
|
||||
AttentionMixin,
|
||||
):
|
||||
"""
|
||||
The Transformer model introduced in Longcat-Image.
|
||||
|
||||
@@ -482,8 +482,6 @@ class ChromaInpaintPipeline(
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
pooled_prompt_embeds=None,
|
||||
negative_pooled_prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
padding_mask_crop=None,
|
||||
max_sequence_length=None,
|
||||
@@ -531,15 +529,6 @@ class ChromaInpaintPipeline(
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and pooled_prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
||||
)
|
||||
if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and prompt_attention_mask is None:
|
||||
raise ValueError("Cannot provide `prompt_embeds` without also providing `prompt_attention_mask")
|
||||
|
||||
@@ -793,13 +782,11 @@ class ChromaInpaintPipeline(
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
ip_adapter_image: Optional[PipelineImageInput] = None,
|
||||
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
|
||||
negative_ip_adapter_image: Optional[PipelineImageInput] = None,
|
||||
negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
|
||||
@@ -101,7 +101,7 @@ def betas_for_alpha_bar(
|
||||
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
|
||||
def rescale_zero_terminal_snr(betas):
|
||||
def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
|
||||
|
||||
@@ -266,7 +266,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return sample
|
||||
|
||||
def _get_variance(self, timestep, prev_timestep=None):
|
||||
def _get_variance(self, timestep: int, prev_timestep: Optional[int] = None) -> torch.Tensor:
|
||||
if prev_timestep is None:
|
||||
prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
|
||||
|
||||
@@ -279,7 +279,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
return variance
|
||||
|
||||
def _batch_get_variance(self, t, prev_t):
|
||||
def _batch_get_variance(self, t: torch.Tensor, prev_t: torch.Tensor) -> torch.Tensor:
|
||||
alpha_prod_t = self.alphas_cumprod[t]
|
||||
alpha_prod_t_prev = self.alphas_cumprod[torch.clip(prev_t, min=0)]
|
||||
alpha_prod_t_prev[prev_t < 0] = torch.tensor(1.0)
|
||||
@@ -335,7 +335,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
return sample
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.set_timesteps
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None) -> None:
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
|
||||
@@ -392,7 +392,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
sample: torch.Tensor,
|
||||
eta: float = 0.0,
|
||||
use_clipped_model_output: bool = False,
|
||||
generator=None,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
variance_noise: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[DDIMParallelSchedulerOutput, Tuple]:
|
||||
@@ -406,11 +406,13 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
sample (`torch.Tensor`):
|
||||
current instance of sample being created by diffusion process.
|
||||
eta (`float`): weight of noise for added noise in diffusion step.
|
||||
use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
|
||||
predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when
|
||||
`self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would
|
||||
coincide with the one provided as input and `use_clipped_model_output` will have not effect.
|
||||
generator: random number generator.
|
||||
use_clipped_model_output (`bool`, defaults to `False`):
|
||||
If `True`, compute "corrected" `model_output` from the clipped predicted original sample. This
|
||||
correction is necessary because the predicted original sample is clipped to [-1, 1] when
|
||||
`self.config.clip_sample` is `True`. If no clipping occurred, the "corrected" `model_output` matches
|
||||
the input and `use_clipped_model_output` has no effect.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
Random number generator.
|
||||
variance_noise (`torch.Tensor`): instead of generating noise for the variance using `generator`, we
|
||||
can directly provide the noise for the variance itself. This is useful for methods such as
|
||||
CycleDiffusion. (https://huggingface.co/papers/2210.05559)
|
||||
@@ -496,7 +498,10 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
if variance_noise is None:
|
||||
variance_noise = randn_tensor(
|
||||
model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
|
||||
model_output.shape,
|
||||
generator=generator,
|
||||
device=model_output.device,
|
||||
dtype=model_output.dtype,
|
||||
)
|
||||
variance = std_dev_t * variance_noise
|
||||
|
||||
@@ -513,7 +518,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
def batch_step_no_noise(
|
||||
self,
|
||||
model_output: torch.Tensor,
|
||||
timesteps: List[int],
|
||||
timesteps: torch.Tensor,
|
||||
sample: torch.Tensor,
|
||||
eta: float = 0.0,
|
||||
use_clipped_model_output: bool = False,
|
||||
@@ -528,7 +533,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
Args:
|
||||
model_output (`torch.Tensor`): direct output from learned diffusion model.
|
||||
timesteps (`List[int]`):
|
||||
timesteps (`torch.Tensor`):
|
||||
current discrete timesteps in the diffusion chain. This is now a list of integers.
|
||||
sample (`torch.Tensor`):
|
||||
current instance of sample being created by diffusion process.
|
||||
@@ -696,5 +701,5 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
|
||||
return velocity
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
return self.config.num_train_timesteps
|
||||
|
||||
@@ -281,7 +281,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
|
||||
Args:
|
||||
num_inference_steps (`int`):
|
||||
num_inference_steps (`int`, *optional*):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used,
|
||||
`timesteps` must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
@@ -646,7 +646,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
def __len__(self) -> int:
|
||||
return self.config.num_train_timesteps
|
||||
|
||||
def previous_timestep(self, timestep: int) -> int:
|
||||
def previous_timestep(self, timestep: int) -> Union[int, torch.Tensor]:
|
||||
"""
|
||||
Compute the previous timestep in the diffusion chain.
|
||||
|
||||
@@ -655,7 +655,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
The current timestep.
|
||||
|
||||
Returns:
|
||||
`int`:
|
||||
`int` or `torch.Tensor`:
|
||||
The previous timestep.
|
||||
"""
|
||||
if self.custom_timesteps or self.num_inference_steps:
|
||||
|
||||
@@ -22,6 +22,7 @@ import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import logging
|
||||
from .scheduling_utils_flax import (
|
||||
CommonSchedulerState,
|
||||
FlaxKarrasDiffusionSchedulers,
|
||||
@@ -32,6 +33,9 @@ from .scheduling_utils_flax import (
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@flax.struct.dataclass
|
||||
class DDPMSchedulerState:
|
||||
common: CommonSchedulerState
|
||||
@@ -42,7 +46,12 @@ class DDPMSchedulerState:
|
||||
num_inference_steps: Optional[int] = None
|
||||
|
||||
@classmethod
|
||||
def create(cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray):
|
||||
def create(
|
||||
cls,
|
||||
common: CommonSchedulerState,
|
||||
init_noise_sigma: jnp.ndarray,
|
||||
timesteps: jnp.ndarray,
|
||||
):
|
||||
return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps)
|
||||
|
||||
|
||||
@@ -105,6 +114,10 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
prediction_type: str = "epsilon",
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
):
|
||||
logger.warning(
|
||||
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
|
||||
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
|
||||
)
|
||||
self.dtype = dtype
|
||||
|
||||
def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDPMSchedulerState:
|
||||
@@ -123,7 +136,10 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
)
|
||||
|
||||
def scale_model_input(
|
||||
self, state: DDPMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
|
||||
self,
|
||||
state: DDPMSchedulerState,
|
||||
sample: jnp.ndarray,
|
||||
timestep: Optional[int] = None,
|
||||
) -> jnp.ndarray:
|
||||
"""
|
||||
Args:
|
||||
|
||||
@@ -149,38 +149,41 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
For more details, see the original paper: https://huggingface.co/papers/2006.11239
|
||||
|
||||
Args:
|
||||
num_train_timesteps (`int`): number of diffusion steps used to train the model.
|
||||
beta_start (`float`): the starting `beta` value of inference.
|
||||
beta_end (`float`): the final `beta` value.
|
||||
beta_schedule (`str`):
|
||||
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
|
||||
num_train_timesteps (`int`, defaults to 1000):
|
||||
The number of diffusion steps to train the model.
|
||||
beta_start (`float`, defaults to 0.0001):
|
||||
The starting `beta` value of inference.
|
||||
beta_end (`float`, defaults to 0.02):
|
||||
The final `beta` value.
|
||||
beta_schedule (`str`, defaults to `"linear"`):
|
||||
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
|
||||
`linear`, `scaled_linear`, `squaredcos_cap_v2` or `sigmoid`.
|
||||
trained_betas (`np.ndarray`, optional):
|
||||
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
|
||||
variance_type (`str`):
|
||||
options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
|
||||
trained_betas (`np.ndarray`, *optional*):
|
||||
Option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
|
||||
variance_type (`str`, defaults to `"fixed_small"`):
|
||||
Options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
|
||||
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
|
||||
clip_sample (`bool`, default `True`):
|
||||
option to clip predicted sample for numerical stability.
|
||||
clip_sample_range (`float`, default `1.0`):
|
||||
the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
|
||||
prediction_type (`str`, default `epsilon`, optional):
|
||||
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
|
||||
clip_sample (`bool`, defaults to `True`):
|
||||
Option to clip predicted sample for numerical stability.
|
||||
prediction_type (`str`, defaults to `"epsilon"`):
|
||||
Prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
|
||||
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
|
||||
https://huggingface.co/papers/2210.02303)
|
||||
thresholding (`bool`, default `False`):
|
||||
whether to use the "dynamic thresholding" method (introduced by Imagen,
|
||||
thresholding (`bool`, defaults to `False`):
|
||||
Whether to use the "dynamic thresholding" method (introduced by Imagen,
|
||||
https://huggingface.co/papers/2205.11487). Note that the thresholding method is unsuitable for latent-space
|
||||
diffusion models (such as stable-diffusion).
|
||||
dynamic_thresholding_ratio (`float`, default `0.995`):
|
||||
the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
|
||||
dynamic_thresholding_ratio (`float`, defaults to 0.995):
|
||||
The ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
|
||||
(https://huggingface.co/papers/2205.11487). Valid only when `thresholding=True`.
|
||||
sample_max_value (`float`, default `1.0`):
|
||||
the threshold value for dynamic thresholding. Valid only when `thresholding=True`.
|
||||
timestep_spacing (`str`, default `"leading"`):
|
||||
clip_sample_range (`float`, defaults to 1.0):
|
||||
The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
|
||||
sample_max_value (`float`, defaults to 1.0):
|
||||
The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
|
||||
timestep_spacing (`str`, defaults to `"leading"`):
|
||||
The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
|
||||
Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
||||
steps_offset (`int`, default `0`):
|
||||
steps_offset (`int`, defaults to 0):
|
||||
An offset added to the inference steps, as required by some model families.
|
||||
rescale_betas_zero_snr (`bool`, defaults to `False`):
|
||||
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
|
||||
@@ -293,7 +296,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
|
||||
Args:
|
||||
num_inference_steps (`int`):
|
||||
num_inference_steps (`int`, *optional*):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used,
|
||||
`timesteps` must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
@@ -478,7 +481,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
model_output: torch.Tensor,
|
||||
timestep: int,
|
||||
sample: torch.Tensor,
|
||||
generator=None,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[DDPMParallelSchedulerOutput, Tuple]:
|
||||
"""
|
||||
@@ -490,7 +493,8 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
timestep (`int`): current discrete timestep in the diffusion chain.
|
||||
sample (`torch.Tensor`):
|
||||
current instance of sample being created by diffusion process.
|
||||
generator: random number generator.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
Random number generator.
|
||||
return_dict (`bool`): option for returning tuple rather than DDPMParallelSchedulerOutput class
|
||||
|
||||
Returns:
|
||||
@@ -503,7 +507,10 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
prev_t = self.previous_timestep(t)
|
||||
|
||||
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
|
||||
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in [
|
||||
"learned",
|
||||
"learned_range",
|
||||
]:
|
||||
model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
|
||||
else:
|
||||
predicted_variance = None
|
||||
@@ -552,7 +559,10 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
if t > 0:
|
||||
device = model_output.device
|
||||
variance_noise = randn_tensor(
|
||||
model_output.shape, generator=generator, device=device, dtype=model_output.dtype
|
||||
model_output.shape,
|
||||
generator=generator,
|
||||
device=device,
|
||||
dtype=model_output.dtype,
|
||||
)
|
||||
if self.variance_type == "fixed_small_log":
|
||||
variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise
|
||||
@@ -575,7 +585,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
def batch_step_no_noise(
|
||||
self,
|
||||
model_output: torch.Tensor,
|
||||
timesteps: List[int],
|
||||
timesteps: torch.Tensor,
|
||||
sample: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
@@ -588,8 +598,8 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
Args:
|
||||
model_output (`torch.Tensor`): direct output from learned diffusion model.
|
||||
timesteps (`List[int]`):
|
||||
current discrete timesteps in the diffusion chain. This is now a list of integers.
|
||||
timesteps (`torch.Tensor`):
|
||||
Current discrete timesteps in the diffusion chain. This is a tensor of integers.
|
||||
sample (`torch.Tensor`):
|
||||
current instance of sample being created by diffusion process.
|
||||
|
||||
@@ -603,7 +613,10 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
t = t.view(-1, *([1] * (model_output.ndim - 1)))
|
||||
prev_t = prev_t.view(-1, *([1] * (model_output.ndim - 1)))
|
||||
|
||||
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
|
||||
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in [
|
||||
"learned",
|
||||
"learned_range",
|
||||
]:
|
||||
model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
|
||||
else:
|
||||
pass
|
||||
@@ -734,7 +747,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
return self.config.num_train_timesteps
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
|
||||
def previous_timestep(self, timestep):
|
||||
def previous_timestep(self, timestep: int) -> Union[int, torch.Tensor]:
|
||||
"""
|
||||
Compute the previous timestep in the diffusion chain.
|
||||
|
||||
@@ -743,7 +756,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
The current timestep.
|
||||
|
||||
Returns:
|
||||
`int`:
|
||||
`int` or `torch.Tensor`:
|
||||
The previous timestep.
|
||||
"""
|
||||
if self.custom_timesteps or self.num_inference_steps:
|
||||
|
||||
@@ -722,7 +722,7 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
|
||||
The current timestep.
|
||||
|
||||
Returns:
|
||||
`int`:
|
||||
`int` or `torch.Tensor`:
|
||||
The previous timestep.
|
||||
"""
|
||||
if self.custom_timesteps or self.num_inference_steps:
|
||||
|
||||
@@ -777,7 +777,7 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
|
||||
The current timestep.
|
||||
|
||||
Returns:
|
||||
`int`:
|
||||
`int` or `torch.Tensor`:
|
||||
The previous timestep.
|
||||
"""
|
||||
if self.custom_timesteps or self.num_inference_steps:
|
||||
|
||||
@@ -248,6 +248,9 @@ class KandinskyV22InpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCas
|
||||
def test_float16_inference(self):
|
||||
super().test_float16_inference(expected_max_diff=5e-1)
|
||||
|
||||
def test_save_load_dduf(self):
|
||||
super().test_save_load_dduf(atol=1e-3, rtol=1e-3)
|
||||
|
||||
@is_flaky()
|
||||
def test_model_cpu_offload_forward_pass(self):
|
||||
super().test_inference_batch_single_identical(expected_max_diff=8e-4)
|
||||
|
||||
@@ -191,6 +191,9 @@ class Kandinsky3Img2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase)
|
||||
def test_inference_batch_single_identical(self):
|
||||
super().test_inference_batch_single_identical(expected_max_diff=1e-2)
|
||||
|
||||
def test_save_load_dduf(self):
|
||||
super().test_save_load_dduf(atol=1e-3, rtol=1e-3)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_accelerator
|
||||
|
||||
Reference in New Issue
Block a user