mirror of
https://github.com/huggingface/diffusers.git
synced 2026-03-20 23:48:19 +08:00
Compare commits
4 Commits
release-wo
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a9855c4204 | ||
|
|
0b35834351 | ||
|
|
522b523e40 | ||
|
|
e9b9f25f67 |
8
.github/workflows/release_tests_fast.yml
vendored
8
.github/workflows/release_tests_fast.yml
vendored
@@ -4,6 +4,7 @@
|
||||
name: (Release) Fast GPU Tests on main
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
push:
|
||||
branches:
|
||||
- "v*.*.*-release"
|
||||
@@ -33,6 +34,7 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
@@ -74,6 +76,7 @@ jobs:
|
||||
run: |
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
@@ -125,6 +128,7 @@ jobs:
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip install peft@git+https://github.com/huggingface/peft.git
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
@@ -175,6 +179,7 @@ jobs:
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip install peft@git+https://github.com/huggingface/peft.git
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
@@ -232,6 +237,7 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv pip install -e ".[quality,training]"
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
@@ -274,6 +280,7 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv pip install -e ".[quality,training]"
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
@@ -316,6 +323,7 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv pip install -e ".[quality,training]"
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
|
||||
@@ -143,6 +143,7 @@ Refer to the table below for a complete list of available attention backends and
|
||||
| `flash_varlen` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention |
|
||||
| `flash_varlen_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention from kernels |
|
||||
| `aiter` | [AI Tensor Engine for ROCm](https://github.com/ROCm/aiter) | FlashAttention for AMD ROCm |
|
||||
| `flash_4_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-4 |
|
||||
| `_flash_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 |
|
||||
| `_flash_varlen_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 |
|
||||
| `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from kernels |
|
||||
|
||||
@@ -229,6 +229,7 @@ class AttentionBackendName(str, Enum):
|
||||
FLASH_HUB = "flash_hub"
|
||||
FLASH_VARLEN = "flash_varlen"
|
||||
FLASH_VARLEN_HUB = "flash_varlen_hub"
|
||||
FLASH_4_HUB = "flash_4_hub"
|
||||
_FLASH_3 = "_flash_3"
|
||||
_FLASH_VARLEN_3 = "_flash_varlen_3"
|
||||
_FLASH_3_HUB = "_flash_3_hub"
|
||||
@@ -358,6 +359,11 @@ _HUB_KERNELS_REGISTRY: dict["AttentionBackendName", _HubKernelConfig] = {
|
||||
function_attr="sageattn",
|
||||
version=1,
|
||||
),
|
||||
AttentionBackendName.FLASH_4_HUB: _HubKernelConfig(
|
||||
repo_id="kernels-staging/flash-attn4",
|
||||
function_attr="flash_attn_func",
|
||||
version=0,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@@ -521,6 +527,7 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
|
||||
AttentionBackendName._FLASH_3_HUB,
|
||||
AttentionBackendName._FLASH_3_VARLEN_HUB,
|
||||
AttentionBackendName.SAGE_HUB,
|
||||
AttentionBackendName.FLASH_4_HUB,
|
||||
]:
|
||||
if not is_kernels_available():
|
||||
raise RuntimeError(
|
||||
@@ -531,6 +538,11 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
|
||||
f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12. Please update with `pip install -U kernels`."
|
||||
)
|
||||
|
||||
if backend == AttentionBackendName.FLASH_4_HUB and not is_kernels_available(">=", "0.12.3"):
|
||||
raise RuntimeError(
|
||||
f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12.3. Please update with `pip install -U kernels`."
|
||||
)
|
||||
|
||||
elif backend == AttentionBackendName.AITER:
|
||||
if not _CAN_USE_AITER_ATTN:
|
||||
raise RuntimeError(
|
||||
@@ -2676,6 +2688,37 @@ def _flash_attention_3_varlen_hub(
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.FLASH_4_HUB,
|
||||
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
supports_context_parallel=False,
|
||||
)
|
||||
def _flash_attention_4_hub(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: torch.Tensor | None = None,
|
||||
scale: float | None = None,
|
||||
is_causal: bool = False,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: "ParallelConfig" | None = None,
|
||||
) -> torch.Tensor:
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for flash-attn 4.")
|
||||
|
||||
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_4_HUB].kernel_fn
|
||||
out = func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
softmax_scale=scale,
|
||||
causal=is_causal,
|
||||
)
|
||||
if isinstance(out, tuple):
|
||||
return (out[0], out[1]) if return_lse else out[0]
|
||||
return out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName._FLASH_VARLEN_3,
|
||||
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
|
||||
@@ -324,17 +324,18 @@ class AudioLDM2Pipeline(DiffusionPipeline):
|
||||
`inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
||||
The sequence of generated hidden-states.
|
||||
"""
|
||||
cache_position_kwargs = {}
|
||||
if is_transformers_version("<", "4.52.1"):
|
||||
cache_position_kwargs["input_ids"] = inputs_embeds
|
||||
else:
|
||||
cache_position_kwargs["seq_length"] = inputs_embeds.shape[0]
|
||||
cache_position_kwargs["device"] = (
|
||||
self.language_model.device if getattr(self, "language_model", None) is not None else self.device
|
||||
)
|
||||
cache_position_kwargs["model_kwargs"] = model_kwargs
|
||||
max_new_tokens = max_new_tokens if max_new_tokens is not None else self.language_model.config.max_new_tokens
|
||||
model_kwargs = self.language_model._get_initial_cache_position(**cache_position_kwargs)
|
||||
if hasattr(self.language_model, "_get_initial_cache_position"):
|
||||
cache_position_kwargs = {}
|
||||
if is_transformers_version("<", "4.52.1"):
|
||||
cache_position_kwargs["input_ids"] = inputs_embeds
|
||||
else:
|
||||
cache_position_kwargs["seq_length"] = inputs_embeds.shape[0]
|
||||
cache_position_kwargs["device"] = (
|
||||
self.language_model.device if getattr(self, "language_model", None) is not None else self.device
|
||||
)
|
||||
cache_position_kwargs["model_kwargs"] = model_kwargs
|
||||
model_kwargs = self.language_model._get_initial_cache_position(**cache_position_kwargs)
|
||||
|
||||
for _ in range(max_new_tokens):
|
||||
# prepare model inputs
|
||||
|
||||
@@ -28,7 +28,6 @@ from diffusers.utils.import_utils import is_peft_available
|
||||
|
||||
from ..testing_utils import (
|
||||
floats_tensor,
|
||||
is_flaky,
|
||||
require_peft_backend,
|
||||
require_peft_version_greater,
|
||||
skip_mps,
|
||||
@@ -46,7 +45,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
|
||||
|
||||
@require_peft_backend
|
||||
@skip_mps
|
||||
@is_flaky(max_attempts=10, description="very flaky class")
|
||||
class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
pipeline_class = WanVACEPipeline
|
||||
scheduler_cls = FlowMatchEulerDiscreteScheduler
|
||||
@@ -73,8 +71,8 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
"base_dim": 3,
|
||||
"z_dim": 4,
|
||||
"dim_mult": [1, 1, 1, 1],
|
||||
"latents_mean": torch.randn(4).numpy().tolist(),
|
||||
"latents_std": torch.randn(4).numpy().tolist(),
|
||||
"latents_mean": [-0.7571, -0.7089, -0.9113, -0.7245],
|
||||
"latents_std": [2.8184, 1.4541, 2.3275, 2.6558],
|
||||
"num_res_blocks": 1,
|
||||
"temperal_downsample": [False, True, True],
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user