mirror of
https://github.com/huggingface/diffusers.git
synced 2026-02-20 18:00:46 +08:00
Compare commits
27 Commits
modular-sa
...
requiremen
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f274df4fef | ||
|
|
2504341a20 | ||
|
|
e8d4612a25 | ||
|
|
29273538d1 | ||
|
|
445c42eb82 | ||
|
|
79fa0e2bd5 | ||
|
|
60e3284003 | ||
|
|
7b43d0e409 | ||
|
|
3879e32254 | ||
|
|
a88d11bc90 | ||
|
|
a9165eb749 | ||
|
|
eeb3445444 | ||
|
|
5b7d0dfab6 | ||
|
|
1de4402c26 | ||
|
|
024c2b9839 | ||
|
|
35d8d97c02 | ||
|
|
e52cabeff2 | ||
|
|
2c4d73d72d | ||
|
|
046be83946 | ||
|
|
b7fba892f5 | ||
|
|
ecbd907e76 | ||
|
|
d159ae025d | ||
|
|
756a1567f5 | ||
|
|
d2731ababa | ||
|
|
37d3887194 | ||
|
|
127e9a39d8 | ||
|
|
12ceecf077 |
2
.github/workflows/pr_modular_tests.yml
vendored
2
.github/workflows/pr_modular_tests.yml
vendored
@@ -117,7 +117,7 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip install -e ".[quality,test]"
|
||||
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
|
||||
|
||||
6
.github/workflows/pr_tests.yml
vendored
6
.github/workflows/pr_tests.yml
vendored
@@ -114,7 +114,7 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip install -e ".[quality,test]"
|
||||
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
|
||||
@@ -191,7 +191,7 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip install -e ".[quality,test]"
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
@@ -242,7 +242,7 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip install -e ".[quality,test]"
|
||||
# TODO (sayakpaul, DN6): revisit `--no-deps`
|
||||
uv pip install -U peft@git+https://github.com/huggingface/peft.git --no-deps
|
||||
uv pip install -U tokenizers
|
||||
|
||||
5
.github/workflows/pr_tests_gpu.yml
vendored
5
.github/workflows/pr_tests_gpu.yml
vendored
@@ -199,6 +199,11 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
# Install pkgs which depend on setuptools<81 for pkg_resources first with no build isolation
|
||||
uv pip install pip==25.2 setuptools==80.10.2
|
||||
uv pip install --no-build-isolation k-diffusion==0.0.12
|
||||
uv pip install --upgrade pip setuptools
|
||||
# Install the rest as normal
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip install peft@git+https://github.com/huggingface/peft.git
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
|
||||
5
.github/workflows/push_tests.yml
vendored
5
.github/workflows/push_tests.yml
vendored
@@ -126,6 +126,11 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
# Install pkgs which depend on setuptools<81 for pkg_resources first with no build isolation
|
||||
uv pip install pip==25.2 setuptools==80.10.2
|
||||
uv pip install --no-build-isolation k-diffusion==0.0.12
|
||||
uv pip install --upgrade pip setuptools
|
||||
# Install the rest as normal
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip install peft@git+https://github.com/huggingface/peft.git
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
|
||||
2
.github/workflows/push_tests_mps.yml
vendored
2
.github/workflows/push_tests_mps.yml
vendored
@@ -41,7 +41,7 @@ jobs:
|
||||
shell: arch -arch arm64 bash {0}
|
||||
run: |
|
||||
${CONDA_RUN} python -m pip install --upgrade pip uv
|
||||
${CONDA_RUN} python -m uv pip install -e ".[quality]"
|
||||
${CONDA_RUN} python -m uv pip install -e ".[quality,test]"
|
||||
${CONDA_RUN} python -m uv pip install torch torchvision torchaudio
|
||||
${CONDA_RUN} python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
${CONDA_RUN} python -m uv pip install transformers --upgrade
|
||||
|
||||
@@ -29,7 +29,7 @@ Qwen-Image comes in the following variants:
|
||||
| Qwen-Image-Edit Plus | [Qwen/Qwen-Image-Edit-2509](https://huggingface.co/Qwen/Qwen-Image-Edit-2509) |
|
||||
|
||||
> [!TIP]
|
||||
> See the [Caching](../../optimization/cache) guide to speed up inference by storing and reusing intermediate outputs.
|
||||
> [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.
|
||||
|
||||
## LoRA for faster inference
|
||||
|
||||
@@ -190,12 +190,6 @@ For detailed benchmark scripts and results, see [this gist](https://gist.github.
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## QwenImageLayeredPipeline
|
||||
|
||||
[[autodoc]] QwenImageLayeredPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## QwenImagePipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.qwenimage.pipeline_output.QwenImagePipelineOutput
|
||||
@@ -332,4 +332,49 @@ Make your custom block work with Mellon's visual interface. See the [Mellon Cust
|
||||
Browse the [Modular Diffusers Custom Blocks](https://huggingface.co/collections/diffusers/modular-diffusers-custom-blocks) collection for inspiration and ready-to-use blocks.
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
</hfoptions>
|
||||
|
||||
## Dependencies
|
||||
|
||||
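Declaring package dependencies in custom blocks helps catch missing or incompatible packages early, before they surface as import errors at runtime. Diffusers validates the declared dependencies and emits a warning if a package is missing or its installed version doesn't satisfy the requirement.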
|
||||
|
||||
Set a `_requirements` attribute in your block class, mapping package names to version specifiers.
|
||||
|
||||
```py
|
||||
from diffusers.modular_pipelines import PipelineBlock
|
||||
|
||||
class MyCustomBlock(PipelineBlock):
|
||||
_requirements = {
|
||||
"transformers": ">=4.44.0",
|
||||
"sentencepiece": ">=0.2.0"
|
||||
}
|
||||
```
|
||||
|
||||
When there are blocks with different requirements, Diffusers merges their requirements.
|
||||
|
||||
```py
|
||||
from diffusers.modular_pipelines import PipelineBlock, SequentialPipelineBlocks
|
||||
|
||||
class BlockA(PipelineBlock):
|
||||
_requirements = {"transformers": ">=4.44.0"}
|
||||
# ...
|
||||
|
||||
class BlockB(PipelineBlock):
|
||||
_requirements = {"sentencepiece": ">=0.2.0"}
|
||||
# ...
|
||||
|
||||
pipe = SequentialPipelineBlocks.from_blocks_dict({
|
||||
"block_a": BlockA,
|
||||
"block_b": BlockB,
|
||||
})
|
||||
```
|
||||
|
||||
When a block is saved with [`~ModularPipeline.save_pretrained`], its requirements are saved to the `modular_config.json` file. When the block is loaded again, Diffusers checks each requirement against the current environment. If a package isn't found or its installed version doesn't match, Diffusers emits one of the following warnings.
|
||||
|
||||
```md
|
||||
# missing package
|
||||
xyz-package was specified in the requirements but wasn't found in the current environment.
|
||||
|
||||
# version mismatch
|
||||
xyz requirement 'specific-version' is not satisfied by the installed version 'actual-version'. Things might work unexpected.
|
||||
```
|
||||
|
||||
@@ -89,8 +89,6 @@ class CustomBlocksCommand(BaseDiffusersCLICommand):
|
||||
# automap = self._create_automap(parent_class=parent_class, child_class=child_class)
|
||||
# with open(CONFIG, "w") as f:
|
||||
# json.dump(automap, f)
|
||||
with open("requirements.txt", "w") as f:
|
||||
f.write("")
|
||||
|
||||
def _choose_block(self, candidates, chosen=None):
|
||||
for cls, base in candidates:
|
||||
|
||||
@@ -266,10 +266,6 @@ class _HubKernelConfig:
|
||||
function_attr: str
|
||||
revision: str | None = None
|
||||
kernel_fn: Callable | None = None
|
||||
wrapped_forward_attr: str | None = None
|
||||
wrapped_backward_attr: str | None = None
|
||||
wrapped_forward_fn: Callable | None = None
|
||||
wrapped_backward_fn: Callable | None = None
|
||||
|
||||
|
||||
# Registry for hub-based attention kernels
|
||||
@@ -284,11 +280,7 @@ _HUB_KERNELS_REGISTRY: dict["AttentionBackendName", _HubKernelConfig] = {
|
||||
# revision="fake-ops-return-probs",
|
||||
),
|
||||
AttentionBackendName.FLASH_HUB: _HubKernelConfig(
|
||||
repo_id="kernels-community/flash-attn2",
|
||||
function_attr="flash_attn_func",
|
||||
revision=None,
|
||||
wrapped_forward_attr="flash_attn_interface._wrapped_flash_attn_forward",
|
||||
wrapped_backward_attr="flash_attn_interface._wrapped_flash_attn_backward",
|
||||
repo_id="kernels-community/flash-attn2", function_attr="flash_attn_func", revision=None
|
||||
),
|
||||
AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig(
|
||||
repo_id="kernels-community/flash-attn2", function_attr="flash_attn_varlen_func", revision=None
|
||||
@@ -613,39 +605,22 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
|
||||
|
||||
|
||||
# ===== Helpers for downloading kernels =====
|
||||
def _resolve_kernel_attr(module, attr_path: str):
|
||||
target = module
|
||||
for attr in attr_path.split("."):
|
||||
if not hasattr(target, attr):
|
||||
raise AttributeError(f"Kernel module '{module.__name__}' does not define attribute path '{attr_path}'.")
|
||||
target = getattr(target, attr)
|
||||
return target
|
||||
|
||||
|
||||
def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None:
|
||||
if backend not in _HUB_KERNELS_REGISTRY:
|
||||
return
|
||||
config = _HUB_KERNELS_REGISTRY[backend]
|
||||
|
||||
needs_kernel = config.kernel_fn is None
|
||||
needs_wrapped_forward = config.wrapped_forward_attr is not None and config.wrapped_forward_fn is None
|
||||
needs_wrapped_backward = config.wrapped_backward_attr is not None and config.wrapped_backward_fn is None
|
||||
|
||||
if not (needs_kernel or needs_wrapped_forward or needs_wrapped_backward):
|
||||
if config.kernel_fn is not None:
|
||||
return
|
||||
|
||||
try:
|
||||
from kernels import get_kernel
|
||||
|
||||
kernel_module = get_kernel(config.repo_id, revision=config.revision)
|
||||
if needs_kernel:
|
||||
config.kernel_fn = _resolve_kernel_attr(kernel_module, config.function_attr)
|
||||
kernel_func = getattr(kernel_module, config.function_attr)
|
||||
|
||||
if needs_wrapped_forward:
|
||||
config.wrapped_forward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_forward_attr)
|
||||
|
||||
if needs_wrapped_backward:
|
||||
config.wrapped_backward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_backward_attr)
|
||||
# Cache the downloaded kernel function in the config object
|
||||
config.kernel_fn = kernel_func
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"An error occurred while fetching kernel '{config.repo_id}' from the Hub: {e}")
|
||||
@@ -1096,237 +1071,6 @@ def _flash_attention_backward_op(
|
||||
return grad_query, grad_key, grad_value
|
||||
|
||||
|
||||
def _flash_attention_hub_forward_op(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: torch.Tensor | None = None,
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = False,
|
||||
scale: float | None = None,
|
||||
enable_gqa: bool = False,
|
||||
return_lse: bool = False,
|
||||
_save_ctx: bool = True,
|
||||
_parallel_config: "ParallelConfig" | None = None,
|
||||
):
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not yet supported for flash-attn hub kernels.")
|
||||
if enable_gqa:
|
||||
raise ValueError("`enable_gqa` is not yet supported for flash-attn hub kernels.")
|
||||
|
||||
config = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB]
|
||||
wrapped_forward_fn = config.wrapped_forward_fn
|
||||
wrapped_backward_fn = config.wrapped_backward_fn
|
||||
if wrapped_forward_fn is None or wrapped_backward_fn is None:
|
||||
raise RuntimeError(
|
||||
"Flash attention hub kernels must expose `_wrapped_flash_attn_forward` and `_wrapped_flash_attn_backward` "
|
||||
"for context parallel execution."
|
||||
)
|
||||
|
||||
if scale is None:
|
||||
scale = query.shape[-1] ** (-0.5)
|
||||
|
||||
window_size = (-1, -1)
|
||||
softcap = 0.0
|
||||
alibi_slopes = None
|
||||
deterministic = False
|
||||
grad_enabled = any(x.requires_grad for x in (query, key, value))
|
||||
|
||||
if grad_enabled or (_parallel_config is not None and _parallel_config.context_parallel_config._world_size > 1):
|
||||
dropout_p = dropout_p if dropout_p > 0 else 1e-30
|
||||
|
||||
with torch.set_grad_enabled(grad_enabled):
|
||||
out, lse, S_dmask, rng_state = wrapped_forward_fn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
dropout_p,
|
||||
scale,
|
||||
is_causal,
|
||||
window_size[0],
|
||||
window_size[1],
|
||||
softcap,
|
||||
alibi_slopes,
|
||||
return_lse,
|
||||
)
|
||||
lse = lse.permute(0, 2, 1).contiguous()
|
||||
|
||||
if _save_ctx:
|
||||
ctx.save_for_backward(query, key, value, out, lse, rng_state)
|
||||
ctx.dropout_p = dropout_p
|
||||
ctx.scale = scale
|
||||
ctx.is_causal = is_causal
|
||||
ctx.window_size = window_size
|
||||
ctx.softcap = softcap
|
||||
ctx.alibi_slopes = alibi_slopes
|
||||
ctx.deterministic = deterministic
|
||||
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
def _flash_attention_hub_backward_op(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
grad_out: torch.Tensor,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
config = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB]
|
||||
wrapped_backward_fn = config.wrapped_backward_fn
|
||||
if wrapped_backward_fn is None:
|
||||
raise RuntimeError(
|
||||
"Flash attention hub kernels must expose `_wrapped_flash_attn_backward` for context parallel execution."
|
||||
)
|
||||
|
||||
query, key, value, out, lse, rng_state = ctx.saved_tensors
|
||||
grad_query, grad_key, grad_value = torch.empty_like(query), torch.empty_like(key), torch.empty_like(value)
|
||||
|
||||
_ = wrapped_backward_fn(
|
||||
grad_out,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
out,
|
||||
lse,
|
||||
grad_query,
|
||||
grad_key,
|
||||
grad_value,
|
||||
ctx.dropout_p,
|
||||
ctx.scale,
|
||||
ctx.is_causal,
|
||||
ctx.window_size[0],
|
||||
ctx.window_size[1],
|
||||
ctx.softcap,
|
||||
ctx.alibi_slopes,
|
||||
ctx.deterministic,
|
||||
rng_state,
|
||||
)
|
||||
|
||||
grad_query = grad_query[..., : grad_out.shape[-1]]
|
||||
grad_key = grad_key[..., : grad_out.shape[-1]]
|
||||
grad_value = grad_value[..., : grad_out.shape[-1]]
|
||||
|
||||
return grad_query, grad_key, grad_value
|
||||
|
||||
|
||||
def _flash_attention_3_hub_forward_op(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: torch.Tensor | None = None,
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = False,
|
||||
scale: float | None = None,
|
||||
enable_gqa: bool = False,
|
||||
return_lse: bool = False,
|
||||
_save_ctx: bool = True,
|
||||
_parallel_config: "ParallelConfig" | None = None,
|
||||
*,
|
||||
window_size: tuple[int, int] = (-1, -1),
|
||||
softcap: float = 0.0,
|
||||
num_splits: int = 1,
|
||||
pack_gqa: bool | None = None,
|
||||
deterministic: bool = False,
|
||||
sm_margin: int = 0,
|
||||
):
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not yet supported for flash-attn 3 hub kernels.")
|
||||
if dropout_p != 0.0:
|
||||
raise ValueError("`dropout_p` is not yet supported for flash-attn 3 hub kernels.")
|
||||
if enable_gqa:
|
||||
raise ValueError("`enable_gqa` is not yet supported for flash-attn 3 hub kernels.")
|
||||
|
||||
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
|
||||
out = func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
softmax_scale=scale,
|
||||
causal=is_causal,
|
||||
qv=None,
|
||||
q_descale=None,
|
||||
k_descale=None,
|
||||
v_descale=None,
|
||||
window_size=window_size,
|
||||
softcap=softcap,
|
||||
num_splits=num_splits,
|
||||
pack_gqa=pack_gqa,
|
||||
deterministic=deterministic,
|
||||
sm_margin=sm_margin,
|
||||
return_attn_probs=return_lse,
|
||||
)
|
||||
|
||||
lse = None
|
||||
if return_lse:
|
||||
out, lse = out
|
||||
lse = lse.permute(0, 2, 1).contiguous()
|
||||
|
||||
if _save_ctx:
|
||||
ctx.save_for_backward(query, key, value)
|
||||
ctx.scale = scale
|
||||
ctx.is_causal = is_causal
|
||||
ctx._hub_kernel = func
|
||||
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
def _flash_attention_3_hub_backward_op(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
grad_out: torch.Tensor,
|
||||
*args,
|
||||
window_size: tuple[int, int] = (-1, -1),
|
||||
softcap: float = 0.0,
|
||||
num_splits: int = 1,
|
||||
pack_gqa: bool | None = None,
|
||||
deterministic: bool = False,
|
||||
sm_margin: int = 0,
|
||||
):
|
||||
query, key, value = ctx.saved_tensors
|
||||
kernel_fn = ctx._hub_kernel
|
||||
# NOTE: Unlike the FA2 hub kernel, the FA3 hub kernel does not expose separate wrapped forward/backward
|
||||
# primitives (no `wrapped_forward_attr`/`wrapped_backward_attr` in its `_HubKernelConfig`). We
|
||||
# therefore rerun the forward pass under `torch.enable_grad()` and differentiate through it with
|
||||
# `torch.autograd.grad()`. This is a second forward pass during backward; it can be avoided once
|
||||
# the FA3 hub exposes a dedicated fused backward kernel (analogous to `_wrapped_flash_attn_backward`
|
||||
# in the FA2 hub), at which point this can be refactored to match `_flash_attention_hub_backward_op`.
|
||||
with torch.enable_grad():
|
||||
query_r = query.detach().requires_grad_(True)
|
||||
key_r = key.detach().requires_grad_(True)
|
||||
value_r = value.detach().requires_grad_(True)
|
||||
|
||||
out = kernel_fn(
|
||||
q=query_r,
|
||||
k=key_r,
|
||||
v=value_r,
|
||||
softmax_scale=ctx.scale,
|
||||
causal=ctx.is_causal,
|
||||
qv=None,
|
||||
q_descale=None,
|
||||
k_descale=None,
|
||||
v_descale=None,
|
||||
window_size=window_size,
|
||||
softcap=softcap,
|
||||
num_splits=num_splits,
|
||||
pack_gqa=pack_gqa,
|
||||
deterministic=deterministic,
|
||||
sm_margin=sm_margin,
|
||||
return_attn_probs=False,
|
||||
)
|
||||
if isinstance(out, tuple):
|
||||
out = out[0]
|
||||
|
||||
grad_query, grad_key, grad_value = torch.autograd.grad(
|
||||
out,
|
||||
(query_r, key_r, value_r),
|
||||
grad_out,
|
||||
retain_graph=False,
|
||||
allow_unused=False,
|
||||
)
|
||||
|
||||
return grad_query, grad_key, grad_value
|
||||
|
||||
|
||||
def _sage_attention_forward_op(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
query: torch.Tensor,
|
||||
@@ -1365,46 +1109,6 @@ def _sage_attention_forward_op(
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
def _sage_attention_hub_forward_op(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: torch.Tensor | None = None,
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = False,
|
||||
scale: float | None = None,
|
||||
enable_gqa: bool = False,
|
||||
return_lse: bool = False,
|
||||
_save_ctx: bool = True,
|
||||
_parallel_config: "ParallelConfig" | None = None,
|
||||
):
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not yet supported for Sage attention.")
|
||||
if dropout_p > 0.0:
|
||||
raise ValueError("`dropout_p` is not yet supported for Sage attention.")
|
||||
if enable_gqa:
|
||||
raise ValueError("`enable_gqa` is not yet supported for Sage attention.")
|
||||
|
||||
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn
|
||||
out = func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
tensor_layout="NHD",
|
||||
is_causal=is_causal,
|
||||
sm_scale=scale,
|
||||
return_lse=return_lse,
|
||||
)
|
||||
|
||||
lse = None
|
||||
if return_lse:
|
||||
out, lse, *_ = out
|
||||
lse = lse.permute(0, 2, 1).contiguous()
|
||||
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
def _sage_attention_backward_op(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
grad_out: torch.Tensor,
|
||||
@@ -2261,7 +1965,7 @@ def _flash_attention(
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.FLASH_HUB,
|
||||
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
supports_context_parallel=True,
|
||||
supports_context_parallel=False,
|
||||
)
|
||||
def _flash_attention_hub(
|
||||
query: torch.Tensor,
|
||||
@@ -2279,35 +1983,17 @@ def _flash_attention_hub(
|
||||
raise ValueError("`attn_mask` is not supported for flash-attn 2.")
|
||||
|
||||
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].kernel_fn
|
||||
if _parallel_config is None:
|
||||
out = func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
dropout_p=dropout_p,
|
||||
softmax_scale=scale,
|
||||
causal=is_causal,
|
||||
return_attn_probs=return_lse,
|
||||
)
|
||||
if return_lse:
|
||||
out, lse, *_ = out
|
||||
else:
|
||||
out = _templated_context_parallel_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
None,
|
||||
dropout_p,
|
||||
is_causal,
|
||||
scale,
|
||||
False,
|
||||
return_lse,
|
||||
forward_op=_flash_attention_hub_forward_op,
|
||||
backward_op=_flash_attention_hub_backward_op,
|
||||
_parallel_config=_parallel_config,
|
||||
)
|
||||
if return_lse:
|
||||
out, lse = out
|
||||
out = func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
dropout_p=dropout_p,
|
||||
softmax_scale=scale,
|
||||
causal=is_causal,
|
||||
return_attn_probs=return_lse,
|
||||
)
|
||||
if return_lse:
|
||||
out, lse, *_ = out
|
||||
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
@@ -2454,7 +2140,7 @@ def _flash_attention_3(
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName._FLASH_3_HUB,
|
||||
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
supports_context_parallel=True,
|
||||
supports_context_parallel=False,
|
||||
)
|
||||
def _flash_attention_3_hub(
|
||||
query: torch.Tensor,
|
||||
@@ -2469,68 +2155,33 @@ def _flash_attention_3_hub(
|
||||
return_attn_probs: bool = False,
|
||||
_parallel_config: "ParallelConfig" | None = None,
|
||||
) -> torch.Tensor:
|
||||
if _parallel_config:
|
||||
raise NotImplementedError(f"{AttentionBackendName._FLASH_3_HUB.value} is not implemented for parallelism yet.")
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for flash-attn 3.")
|
||||
|
||||
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
|
||||
if _parallel_config is None:
|
||||
out = func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
softmax_scale=scale,
|
||||
causal=is_causal,
|
||||
qv=None,
|
||||
q_descale=None,
|
||||
k_descale=None,
|
||||
v_descale=None,
|
||||
window_size=window_size,
|
||||
softcap=softcap,
|
||||
num_splits=1,
|
||||
pack_gqa=None,
|
||||
deterministic=deterministic,
|
||||
sm_margin=0,
|
||||
return_attn_probs=return_attn_probs,
|
||||
)
|
||||
return (out[0], out[1]) if return_attn_probs else out
|
||||
|
||||
forward_op = functools.partial(
|
||||
_flash_attention_3_hub_forward_op,
|
||||
out = func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
softmax_scale=scale,
|
||||
causal=is_causal,
|
||||
qv=None,
|
||||
q_descale=None,
|
||||
k_descale=None,
|
||||
v_descale=None,
|
||||
window_size=window_size,
|
||||
softcap=softcap,
|
||||
num_splits=1,
|
||||
pack_gqa=None,
|
||||
deterministic=deterministic,
|
||||
sm_margin=0,
|
||||
return_attn_probs=return_attn_probs,
|
||||
)
|
||||
backward_op = functools.partial(
|
||||
_flash_attention_3_hub_backward_op,
|
||||
window_size=window_size,
|
||||
softcap=softcap,
|
||||
num_splits=1,
|
||||
pack_gqa=None,
|
||||
deterministic=deterministic,
|
||||
sm_margin=0,
|
||||
)
|
||||
out = _templated_context_parallel_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
None,
|
||||
0.0,
|
||||
is_causal,
|
||||
scale,
|
||||
False,
|
||||
return_attn_probs,
|
||||
forward_op=forward_op,
|
||||
backward_op=backward_op,
|
||||
_parallel_config=_parallel_config,
|
||||
)
|
||||
if return_attn_probs:
|
||||
out, lse = out
|
||||
return out, lse
|
||||
|
||||
return out
|
||||
# When `return_attn_probs` is True, the above returns a tuple of
|
||||
# actual outputs and lse.
|
||||
return (out[0], out[1]) if return_attn_probs else out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
@@ -3162,7 +2813,7 @@ def _sage_attention(
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.SAGE_HUB,
|
||||
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
supports_context_parallel=True,
|
||||
supports_context_parallel=False,
|
||||
)
|
||||
def _sage_attention_hub(
|
||||
query: torch.Tensor,
|
||||
@@ -3190,23 +2841,6 @@ def _sage_attention_hub(
|
||||
)
|
||||
if return_lse:
|
||||
out, lse, *_ = out
|
||||
else:
|
||||
out = _templated_context_parallel_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
None,
|
||||
0.0,
|
||||
is_causal,
|
||||
scale,
|
||||
False,
|
||||
return_lse,
|
||||
forward_op=_sage_attention_hub_forward_op,
|
||||
backward_op=_sage_attention_backward_op,
|
||||
_parallel_config=_parallel_config,
|
||||
)
|
||||
if return_lse:
|
||||
out, lse = out
|
||||
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
@@ -424,7 +424,7 @@ class Flux2SingleTransformerBlock(nn.Module):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor | None,
|
||||
temb_mod: torch.Tensor,
|
||||
temb_mod_params: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
||||
image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||
joint_attention_kwargs: dict[str, Any] | None = None,
|
||||
split_hidden_states: bool = False,
|
||||
@@ -436,7 +436,7 @@ class Flux2SingleTransformerBlock(nn.Module):
|
||||
text_seq_len = encoder_hidden_states.shape[1]
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
mod_shift, mod_scale, mod_gate = Flux2Modulation.split(temb_mod, 1)[0]
|
||||
mod_shift, mod_scale, mod_gate = temb_mod_params
|
||||
|
||||
norm_hidden_states = self.norm(hidden_states)
|
||||
norm_hidden_states = (1 + mod_scale) * norm_hidden_states + mod_shift
|
||||
@@ -498,18 +498,16 @@ class Flux2TransformerBlock(nn.Module):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb_mod_img: torch.Tensor,
|
||||
temb_mod_txt: torch.Tensor,
|
||||
temb_mod_params_img: tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...],
|
||||
temb_mod_params_txt: tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...],
|
||||
image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||
joint_attention_kwargs: dict[str, Any] | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
joint_attention_kwargs = joint_attention_kwargs or {}
|
||||
|
||||
# Modulation parameters shape: [1, 1, self.dim]
|
||||
(shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp) = Flux2Modulation.split(temb_mod_img, 2)
|
||||
(c_shift_msa, c_scale_msa, c_gate_msa), (c_shift_mlp, c_scale_mlp, c_gate_mlp) = Flux2Modulation.split(
|
||||
temb_mod_txt, 2
|
||||
)
|
||||
(shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp) = temb_mod_params_img
|
||||
(c_shift_msa, c_scale_msa, c_gate_msa), (c_shift_mlp, c_scale_mlp, c_gate_mlp) = temb_mod_params_txt
|
||||
|
||||
# Img stream
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
@@ -629,19 +627,15 @@ class Flux2Modulation(nn.Module):
|
||||
self.linear = nn.Linear(dim, dim * 3 * self.mod_param_sets, bias=bias)
|
||||
self.act_fn = nn.SiLU()
|
||||
|
||||
def forward(self, temb: torch.Tensor) -> torch.Tensor:
|
||||
def forward(self, temb: torch.Tensor) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]:
|
||||
mod = self.act_fn(temb)
|
||||
mod = self.linear(mod)
|
||||
return mod
|
||||
|
||||
@staticmethod
|
||||
# split inside the transformer blocks, to avoid passing tuples into checkpoints https://github.com/huggingface/diffusers/issues/12776
|
||||
def split(mod: torch.Tensor, mod_param_sets: int) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]:
|
||||
if mod.ndim == 2:
|
||||
mod = mod.unsqueeze(1)
|
||||
mod_params = torch.chunk(mod, 3 * mod_param_sets, dim=-1)
|
||||
mod_params = torch.chunk(mod, 3 * self.mod_param_sets, dim=-1)
|
||||
# Return tuple of 3-tuples of modulation params shift/scale/gate
|
||||
return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(mod_param_sets))
|
||||
return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(self.mod_param_sets))
|
||||
|
||||
|
||||
class Flux2Transformer2DModel(
|
||||
@@ -830,7 +824,7 @@ class Flux2Transformer2DModel(
|
||||
|
||||
double_stream_mod_img = self.double_stream_modulation_img(temb)
|
||||
double_stream_mod_txt = self.double_stream_modulation_txt(temb)
|
||||
single_stream_mod = self.single_stream_modulation(temb)
|
||||
single_stream_mod = self.single_stream_modulation(temb)[0]
|
||||
|
||||
# 2. Input projection for image (hidden_states) and conditioning text (encoder_hidden_states)
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
@@ -867,8 +861,8 @@ class Flux2Transformer2DModel(
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb_mod_img=double_stream_mod_img,
|
||||
temb_mod_txt=double_stream_mod_txt,
|
||||
temb_mod_params_img=double_stream_mod_img,
|
||||
temb_mod_params_txt=double_stream_mod_txt,
|
||||
image_rotary_emb=concat_rotary_emb,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
)
|
||||
@@ -890,7 +884,7 @@ class Flux2Transformer2DModel(
|
||||
hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
temb_mod=single_stream_mod,
|
||||
temb_mod_params=single_stream_mod,
|
||||
image_rotary_emb=concat_rotary_emb,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
)
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
import importlib
|
||||
import inspect
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
@@ -29,16 +28,10 @@ from tqdm.auto import tqdm
|
||||
from typing_extensions import Self
|
||||
|
||||
from ..configuration_utils import ConfigMixin, FrozenDict
|
||||
from ..pipelines.pipeline_loading_utils import (
|
||||
LOADABLE_CLASSES,
|
||||
_fetch_class_library_tuple,
|
||||
_unwrap_model,
|
||||
simple_get_class_obj,
|
||||
)
|
||||
from ..pipelines.pipeline_loading_utils import _fetch_class_library_tuple, simple_get_class_obj
|
||||
from ..utils import PushToHubMixin, is_accelerate_available, logging
|
||||
from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code
|
||||
from ..utils.hub_utils import load_or_create_model_card, populate_model_card
|
||||
from ..utils.torch_utils import is_compiled_module
|
||||
from .components_manager import ComponentsManager
|
||||
from .modular_pipeline_utils import (
|
||||
MODULAR_MODEL_CARD_TEMPLATE,
|
||||
@@ -47,6 +40,7 @@ from .modular_pipeline_utils import (
|
||||
InputParam,
|
||||
InsertableDict,
|
||||
OutputParam,
|
||||
_validate_requirements,
|
||||
combine_inputs,
|
||||
combine_outputs,
|
||||
format_components,
|
||||
@@ -297,6 +291,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
|
||||
config_name = "modular_config.json"
|
||||
model_name = None
|
||||
_requirements: dict[str, str] | None = None
|
||||
_workflow_map = None
|
||||
|
||||
@classmethod
|
||||
@@ -411,6 +406,9 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
"Selected model repository does not happear to have any custom code or does not have a valid `config.json` file."
|
||||
)
|
||||
|
||||
if "requirements" in config and config["requirements"] is not None:
|
||||
_ = _validate_requirements(config["requirements"])
|
||||
|
||||
class_ref = config["auto_map"][cls.__name__]
|
||||
module_file, class_name = class_ref.split(".")
|
||||
module_file = module_file + ".py"
|
||||
@@ -435,8 +433,13 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
module = full_mod.rsplit(".", 1)[-1].replace("__dynamic__", "")
|
||||
parent_module = self.save_pretrained.__func__.__qualname__.split(".", 1)[0]
|
||||
auto_map = {f"{parent_module}": f"{module}.{cls_name}"}
|
||||
|
||||
self.register_to_config(auto_map=auto_map)
|
||||
|
||||
# resolve requirements
|
||||
requirements = _validate_requirements(getattr(self, "_requirements", None))
|
||||
if requirements:
|
||||
self.register_to_config(requirements=requirements)
|
||||
|
||||
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
|
||||
config = dict(self.config)
|
||||
self._internal_dict = FrozenDict(config)
|
||||
@@ -1247,6 +1250,14 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
expected_configs=self.expected_configs,
|
||||
)
|
||||
|
||||
@property
def _requirements(self) -> dict[str, dict]:
    """
    Requirements declared by the sub-blocks, grouped by sub-block name.

    Returns:
        `dict[str, dict]`:
            Maps each sub-block name to that sub-block's own `_requirements` value. Sub-blocks that have no
            `_requirements` attribute, or whose value is falsy (`None`, empty mapping), are omitted. The result is
            nested (one level per composite block), not a flat `{package: specifier}` mapping.
    """
    collected = {}
    for block_name, block in self.sub_blocks.items():
        # Read the attribute once: on a composite sub-block this is a property like this one, which rebuilds
        # its mapping on every access.
        block_requirements = getattr(block, "_requirements", None)
        if block_requirements:
            collected[block_name] = block_requirements
    return collected
|
||||
|
||||
|
||||
class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
"""
|
||||
@@ -1826,84 +1837,17 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
)
|
||||
return pipeline
|
||||
|
||||
def save_pretrained(
|
||||
self,
|
||||
save_directory: str | os.PathLike,
|
||||
safe_serialization: bool = True,
|
||||
variant: str | None = None,
|
||||
max_shard_size: int | str | None = None,
|
||||
push_to_hub: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
def save_pretrained(self, save_directory: str | os.PathLike, push_to_hub: bool = False, **kwargs):
|
||||
"""
|
||||
Save the pipeline and all its components to a directory, so that it can be re-loaded using the
|
||||
[`~ModularPipeline.from_pretrained`] class method.
|
||||
Save the pipeline to a directory. It does not save components, you need to save them separately.
|
||||
|
||||
Args:
|
||||
save_directory (`str` or `os.PathLike`):
|
||||
Directory to save the pipeline to. Will be created if it doesn't exist.
|
||||
safe_serialization (`bool`, *optional*, defaults to `True`):
|
||||
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
|
||||
variant (`str`, *optional*):
|
||||
If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
|
||||
max_shard_size (`int` or `str`, defaults to `None`):
|
||||
The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
|
||||
lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5GB"`).
|
||||
If expressed as an integer, the unit is bytes.
|
||||
push_to_hub (`bool`, *optional*, defaults to `False`):
|
||||
Whether to push the pipeline to the Hugging Face model hub after saving it.
|
||||
**kwargs: Additional keyword arguments passed along to the push to hub method.
|
||||
Path to the directory where the pipeline will be saved.
|
||||
push_to_hub (`bool`, optional):
|
||||
Whether to push the pipeline to the huggingface hub.
|
||||
**kwargs: Additional arguments passed to `save_config()` method
|
||||
"""
|
||||
for component_name, component_spec in self._component_specs.items():
|
||||
sub_model = getattr(self, component_name, None)
|
||||
if sub_model is None:
|
||||
continue
|
||||
|
||||
model_cls = sub_model.__class__
|
||||
if is_compiled_module(sub_model):
|
||||
sub_model = _unwrap_model(sub_model)
|
||||
model_cls = sub_model.__class__
|
||||
|
||||
save_method_name = None
|
||||
for library_name, library_classes in LOADABLE_CLASSES.items():
|
||||
if library_name in sys.modules:
|
||||
library = importlib.import_module(library_name)
|
||||
else:
|
||||
logger.info(
|
||||
f"{library_name} is not installed. Cannot save {component_name} as {library_classes} from {library_name}"
|
||||
)
|
||||
continue
|
||||
|
||||
for base_class, save_load_methods in library_classes.items():
|
||||
class_candidate = getattr(library, base_class, None)
|
||||
if class_candidate is not None and issubclass(model_cls, class_candidate):
|
||||
save_method_name = save_load_methods[0]
|
||||
break
|
||||
if save_method_name is not None:
|
||||
break
|
||||
|
||||
if save_method_name is None:
|
||||
logger.warning(f"self.{component_name}={sub_model} of type {type(sub_model)} cannot be saved.")
|
||||
continue
|
||||
|
||||
save_method = getattr(sub_model, save_method_name)
|
||||
save_method_signature = inspect.signature(save_method)
|
||||
save_method_accept_safe = "safe_serialization" in save_method_signature.parameters
|
||||
save_method_accept_variant = "variant" in save_method_signature.parameters
|
||||
save_method_accept_max_shard_size = "max_shard_size" in save_method_signature.parameters
|
||||
|
||||
save_kwargs = {}
|
||||
if save_method_accept_safe:
|
||||
save_kwargs["safe_serialization"] = safe_serialization
|
||||
if save_method_accept_variant:
|
||||
save_kwargs["variant"] = variant
|
||||
if save_method_accept_max_shard_size and max_shard_size is not None:
|
||||
save_kwargs["max_shard_size"] = max_shard_size
|
||||
|
||||
save_method(os.path.join(save_directory, component_name), **save_kwargs)
|
||||
|
||||
self.save_config(save_directory=save_directory)
|
||||
|
||||
if push_to_hub:
|
||||
commit_message = kwargs.pop("commit_message", None)
|
||||
private = kwargs.pop("private", None)
|
||||
@@ -1912,7 +1856,10 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
|
||||
repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
|
||||
|
||||
# Generate modular pipeline card content
|
||||
card_content = generate_modular_model_card_content(self.blocks)
|
||||
|
||||
# Create a new empty model card and eventually tag it
|
||||
model_card = load_or_create_model_card(
|
||||
repo_id,
|
||||
token=token,
|
||||
@@ -1921,8 +1868,13 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
is_modular=True,
|
||||
)
|
||||
model_card = populate_model_card(model_card, tags=card_content["tags"])
|
||||
|
||||
model_card.save(os.path.join(save_directory, "README.md"))
|
||||
|
||||
# YiYi TODO: maybe order the json file to make it more readable: configs first, then components
|
||||
self.save_config(save_directory=save_directory)
|
||||
|
||||
if push_to_hub:
|
||||
self._upload_folder(
|
||||
save_directory,
|
||||
repo_id,
|
||||
|
||||
@@ -22,10 +22,12 @@ from typing import Any, Literal, Type, Union, get_args, get_origin
|
||||
|
||||
import PIL.Image
|
||||
import torch
|
||||
from packaging.specifiers import InvalidSpecifier, SpecifierSet
|
||||
|
||||
from ..configuration_utils import ConfigMixin, FrozenDict
|
||||
from ..loaders.single_file_utils import _is_single_file_path_or_url
|
||||
from ..utils import DIFFUSERS_LOAD_ID_FIELDS, is_torch_available, logging
|
||||
from ..utils.import_utils import _is_package_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -972,6 +974,89 @@ def make_doc_string(
|
||||
return output
|
||||
|
||||
|
||||
def _validate_requirements(reqs):
    """
    Normalize a requirements mapping and check it against the current environment.

    Problems with the environment (package missing, version unknown, version not matching) are only logged as
    warnings; they never raise.

    Args:
        reqs (`dict`, *optional*):
            Mapping of package names to version specifiers. Values may themselves be mappings; the structure is
            flattened by `_normalize_requirements`. `None` means "no requirements".

    Returns:
        `dict[str, str]`:
            The flattened `{package: specifier}` mapping (a specifier may be an empty string), or an empty dict
            when there is nothing to validate.

    Raises:
        `ValueError`: If `reqs` is neither `None` nor a `dict`, or if a specifier cannot be parsed.
    """
    if reqs is None:
        return {}
    if not isinstance(reqs, dict):
        raise ValueError(
            "Requirements must be provided as a dictionary mapping package names to version specifiers."
        )

    normalized_reqs = _normalize_requirements(reqs)
    if not normalized_reqs:
        return {}

    final: dict[str, str] = {}
    for req, specified_ver in normalized_reqs.items():
        req_available, req_actual_ver = _is_package_available(req)
        if not req_available:
            logger.warning(f"{req} was specified in the requirements but wasn't found in the current environment.")

        if specified_ver:
            # Parse the specifier even when the package is missing so that a malformed specifier is always rejected.
            try:
                specifier = SpecifierSet(specified_ver)
            except InvalidSpecifier as err:
                raise ValueError(f"Requirement specifier '{specified_ver}' for {req} is invalid.") from err

            # A missing package has no installed version to compare and was already reported above, so the
            # version checks (and their warnings) only apply to packages that are present.
            if req_available:
                if req_actual_ver == "N/A":
                    logger.warning(
                        f"Version of {req} could not be determined to validate requirement '{specified_ver}'. Things might work unexpected."
                    )
                elif not specifier.contains(req_actual_ver, prereleases=True):
                    logger.warning(
                        f"{req} requirement '{specified_ver}' is not satisfied by the installed version {req_actual_ver}. Things might work unexpected."
                    )

        final[req] = specified_ver

    return final
|
||||
|
||||
|
||||
def _normalize_requirements(reqs):
    """
    Flatten a (possibly nested) requirements mapping into `{package: specifier}`.

    Nested mappings are walked depth-first in insertion order; the key that holds a nested mapping is ignored and
    only its leaf entries contribute. A bare version such as `"1.0"` becomes `"==1.0"`, and `None` becomes an empty
    specifier. When the same package shows up more than once, a non-empty specifier replaces an empty one, and two
    different non-empty specifiers are merged through `SpecifierSet`.

    Args:
        reqs (`dict`, *optional*): Mapping of package names to specifiers or to nested mappings.

    Returns:
        `dict`: An empty plain dict when `reqs` is falsy, otherwise an `OrderedDict` of package name to specifier.

    Raises:
        `ValueError`: If a package name is empty after stripping whitespace.
    """
    if not reqs:
        return {}

    def _iter_leaves(mapping: dict[str, Any]):
        # Blocks are composable, so requirements may arrive nested per block; descend into every nested
        # mapping so that requirements coming from multiple blocks end up merged in a single flat view.
        for key, value in mapping.items():
            if isinstance(value, dict):
                yield from _iter_leaves(value)
            else:
                yield key, value

    merged: "OrderedDict[str, str]" = OrderedDict()

    for raw_name, raw_spec in _iter_leaves(reqs):
        name = str(raw_name).strip()
        if not name:
            raise ValueError("Requirement package name cannot be empty.")

        spec = str(raw_spec).strip() if raw_spec is not None else ""
        # A specifier without a comparison operator is treated as an exact pin.
        if spec and not spec.startswith(("<", ">", "=", "!", "~")):
            spec = f"=={spec}"

        if name not in merged:
            merged[name] = spec
            continue

        previous = merged[name]
        if not spec or previous == spec:
            # Nothing new to record for this package.
            continue
        if not previous:
            merged[name] = spec
            continue

        try:
            combined = SpecifierSet(",".join((previous, spec)))
        except InvalidSpecifier:
            logger.warning(
                f"Conflicting requirements for '{name}' detected: '{previous}' vs '{spec}'. Keeping '{previous}'."
            )
        else:
            merged[name] = str(combined)

    return merged
|
||||
|
||||
|
||||
def combine_inputs(*named_input_lists: list[tuple[str, list[InputParam]]]) -> list[InputParam]:
|
||||
"""
|
||||
Combines multiple lists of InputParam objects from different blocks. For duplicate inputs, updates only if current
|
||||
|
||||
@@ -18,6 +18,7 @@ import re
|
||||
import urllib.parse as ul
|
||||
from typing import Callable
|
||||
|
||||
import ftfy
|
||||
import torch
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
@@ -33,13 +34,13 @@ from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.pipelines.prx.pipeline_output import PRXPipelineOutput
|
||||
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from diffusers.utils import is_ftfy_available, logging, replace_example_docstring
|
||||
from diffusers.utils import (
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
|
||||
if is_ftfy_available():
|
||||
import ftfy
|
||||
|
||||
DEFAULT_RESOLUTION = 512
|
||||
|
||||
ASPECT_RATIO_256_BIN = {
|
||||
|
||||
@@ -516,9 +516,6 @@ def dequantize_gguf_tensor(tensor):
|
||||
|
||||
block_size, type_size = GGML_QUANT_SIZES[quant_type]
|
||||
|
||||
# Convert to plain tensor to avoid unnecessary __torch_function__ overhead.
|
||||
tensor = tensor.as_tensor()
|
||||
|
||||
tensor = tensor.view(torch.uint8)
|
||||
shape = _quant_shape_from_byte_shape(tensor.shape, type_size, block_size)
|
||||
|
||||
@@ -528,7 +525,7 @@ def dequantize_gguf_tensor(tensor):
|
||||
dequant = dequant_fn(blocks, block_size, type_size)
|
||||
dequant = dequant.reshape(shape)
|
||||
|
||||
return dequant
|
||||
return dequant.as_tensor()
|
||||
|
||||
|
||||
class GGUFParameter(torch.nn.Parameter):
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -42,7 +41,7 @@ class FlowMatchLCMSchedulerOutput(BaseOutput):
|
||||
denoising loop.
|
||||
"""
|
||||
|
||||
prev_sample: torch.Tensor
|
||||
prev_sample: torch.FloatTensor
|
||||
|
||||
|
||||
class FlowMatchLCMScheduler(SchedulerMixin, ConfigMixin):
|
||||
@@ -80,11 +79,11 @@ class FlowMatchLCMScheduler(SchedulerMixin, ConfigMixin):
|
||||
use_beta_sigmas (`bool`, defaults to False):
|
||||
Whether to use beta sigmas for step sizes in the noise schedule during sampling.
|
||||
time_shift_type (`str`, defaults to "exponential"):
|
||||
The type of dynamic resolution-dependent timestep shifting to apply.
|
||||
scale_factors (`list[float]`, *optional*, defaults to `None`):
|
||||
The type of dynamic resolution-dependent timestep shifting to apply. Either "exponential" or "linear".
|
||||
scale_factors ('list', defaults to None)
|
||||
It defines how to scale the latents at which predictions are made.
|
||||
upscale_mode (`str`, *optional*, defaults to "bicubic"):
|
||||
Upscaling method, applied if scale-wise generation is considered.
|
||||
upscale_mode ('str', defaults to 'bicubic')
|
||||
Upscaling method, applied if scale-wise generation is considered
|
||||
"""
|
||||
|
||||
_compatibles = []
|
||||
@@ -102,33 +101,16 @@ class FlowMatchLCMScheduler(SchedulerMixin, ConfigMixin):
|
||||
max_image_seq_len: int = 4096,
|
||||
invert_sigmas: bool = False,
|
||||
shift_terminal: float | None = None,
|
||||
use_karras_sigmas: bool | None = False,
|
||||
use_exponential_sigmas: bool | None = False,
|
||||
use_beta_sigmas: bool | None = False,
|
||||
time_shift_type: Literal["exponential", "linear"] = "exponential",
|
||||
use_karras_sigmas: bool = False,
|
||||
use_exponential_sigmas: bool = False,
|
||||
use_beta_sigmas: bool = False,
|
||||
time_shift_type: str = "exponential",
|
||||
scale_factors: list[float] | None = None,
|
||||
upscale_mode: Literal[
|
||||
"nearest",
|
||||
"linear",
|
||||
"bilinear",
|
||||
"bicubic",
|
||||
"trilinear",
|
||||
"area",
|
||||
"nearest-exact",
|
||||
] = "bicubic",
|
||||
upscale_mode: str = "bicubic",
|
||||
):
|
||||
if self.config.use_beta_sigmas and not is_scipy_available():
|
||||
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
|
||||
if (
|
||||
sum(
|
||||
[
|
||||
self.config.use_beta_sigmas,
|
||||
self.config.use_exponential_sigmas,
|
||||
self.config.use_karras_sigmas,
|
||||
]
|
||||
)
|
||||
> 1
|
||||
):
|
||||
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
|
||||
raise ValueError(
|
||||
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
|
||||
)
|
||||
@@ -180,7 +162,7 @@ class FlowMatchLCMScheduler(SchedulerMixin, ConfigMixin):
|
||||
return self._begin_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
||||
def set_begin_index(self, begin_index: int = 0) -> None:
|
||||
def set_begin_index(self, begin_index: int = 0):
|
||||
"""
|
||||
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
||||
|
||||
@@ -190,18 +172,18 @@ class FlowMatchLCMScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
self._begin_index = begin_index
|
||||
|
||||
def set_shift(self, shift: float) -> None:
|
||||
def set_shift(self, shift: float):
|
||||
self._shift = shift
|
||||
|
||||
def set_scale_factors(self, scale_factors: list[float], upscale_mode: str) -> None:
|
||||
def set_scale_factors(self, scale_factors: list, upscale_mode):
|
||||
"""
|
||||
Sets scale factors for a scale-wise generation regime.
|
||||
|
||||
Args:
|
||||
scale_factors (`list[float]`):
|
||||
The scale factors for each step.
|
||||
scale_factors (`list`):
|
||||
The scale factors for each step
|
||||
upscale_mode (`str`):
|
||||
Upscaling method.
|
||||
Upscaling method
|
||||
"""
|
||||
self._scale_factors = scale_factors
|
||||
self._upscale_mode = upscale_mode
|
||||
@@ -256,18 +238,16 @@ class FlowMatchLCMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
return sample
|
||||
|
||||
def _sigma_to_t(self, sigma: float | torch.FloatTensor) -> float | torch.FloatTensor:
|
||||
def _sigma_to_t(self, sigma):
|
||||
return sigma * self.config.num_train_timesteps
|
||||
|
||||
def time_shift(
|
||||
self, mu: float, sigma: float, t: float | np.ndarray | torch.Tensor
|
||||
) -> float | np.ndarray | torch.Tensor:
|
||||
def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
|
||||
if self.config.time_shift_type == "exponential":
|
||||
return self._time_shift_exponential(mu, sigma, t)
|
||||
elif self.config.time_shift_type == "linear":
|
||||
return self._time_shift_linear(mu, sigma, t)
|
||||
|
||||
def stretch_shift_to_terminal(self, t: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor:
|
||||
def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor:
|
||||
r"""
|
||||
Stretches and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal` config
|
||||
value.
|
||||
@@ -276,13 +256,12 @@ class FlowMatchLCMScheduler(SchedulerMixin, ConfigMixin):
|
||||
https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51
|
||||
|
||||
Args:
|
||||
t (`torch.Tensor` or `np.ndarray`):
|
||||
A tensor or numpy array of timesteps to be stretched and shifted.
|
||||
t (`torch.Tensor`):
|
||||
A tensor of timesteps to be stretched and shifted.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor` or `np.ndarray`:
|
||||
A tensor or numpy array of adjusted timesteps such that the final value equals
|
||||
`self.config.shift_terminal`.
|
||||
`torch.Tensor`:
|
||||
A tensor of adjusted timesteps such that the final value equals `self.config.shift_terminal`.
|
||||
"""
|
||||
one_minus_z = 1 - t
|
||||
scale_factor = one_minus_z[-1] / (1 - self.config.shift_terminal)
|
||||
@@ -291,12 +270,12 @@ class FlowMatchLCMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
def set_timesteps(
|
||||
self,
|
||||
num_inference_steps: int | None = None,
|
||||
device: str | torch.device | None = None,
|
||||
num_inference_steps: int = None,
|
||||
device: str | torch.device = None,
|
||||
sigmas: list[float] | None = None,
|
||||
mu: float | None = None,
|
||||
mu: float = None,
|
||||
timesteps: list[float] | None = None,
|
||||
) -> None:
|
||||
):
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
|
||||
@@ -338,45 +317,43 @@ class FlowMatchLCMScheduler(SchedulerMixin, ConfigMixin):
|
||||
is_timesteps_provided = timesteps is not None
|
||||
|
||||
if is_timesteps_provided:
|
||||
timesteps = np.array(timesteps).astype(np.float32) # type: ignore
|
||||
timesteps = np.array(timesteps).astype(np.float32)
|
||||
|
||||
if sigmas is None:
|
||||
if timesteps is None:
|
||||
timesteps = np.linspace( # type: ignore
|
||||
self._sigma_to_t(self.sigma_max),
|
||||
self._sigma_to_t(self.sigma_min),
|
||||
num_inference_steps,
|
||||
timesteps = np.linspace(
|
||||
self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
|
||||
)
|
||||
sigmas = timesteps / self.config.num_train_timesteps # type: ignore
|
||||
sigmas = timesteps / self.config.num_train_timesteps
|
||||
else:
|
||||
sigmas = np.array(sigmas).astype(np.float32) # type: ignore
|
||||
sigmas = np.array(sigmas).astype(np.float32)
|
||||
num_inference_steps = len(sigmas)
|
||||
|
||||
# 2. Perform timestep shifting. Either no shifting is applied, or resolution-dependent shifting of
|
||||
# "exponential" or "linear" type is applied
|
||||
if self.config.use_dynamic_shifting:
|
||||
sigmas = self.time_shift(mu, 1.0, sigmas) # type: ignore
|
||||
sigmas = self.time_shift(mu, 1.0, sigmas)
|
||||
else:
|
||||
sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas) # type: ignore
|
||||
sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
|
||||
|
||||
# 3. If required, stretch the sigmas schedule to terminate at the configured `shift_terminal` value
|
||||
if self.config.shift_terminal:
|
||||
sigmas = self.stretch_shift_to_terminal(sigmas) # type: ignore
|
||||
sigmas = self.stretch_shift_to_terminal(sigmas)
|
||||
|
||||
# 4. If required, convert sigmas to one of karras, exponential, or beta sigma schedules
|
||||
if self.config.use_karras_sigmas:
|
||||
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) # type: ignore
|
||||
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
||||
elif self.config.use_exponential_sigmas:
|
||||
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps) # type: ignore
|
||||
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
||||
elif self.config.use_beta_sigmas:
|
||||
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps) # type: ignore
|
||||
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
||||
|
||||
# 5. Convert sigmas and timesteps to tensors and move to specified device
|
||||
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) # type: ignore
|
||||
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
|
||||
if not is_timesteps_provided:
|
||||
timesteps = sigmas * self.config.num_train_timesteps # type: ignore
|
||||
timesteps = sigmas * self.config.num_train_timesteps
|
||||
else:
|
||||
timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=device) # type: ignore
|
||||
timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=device)
|
||||
|
||||
# 6. Append the terminal sigma value.
|
||||
# If a model requires inverted sigma schedule for denoising but timesteps without inversion, the
|
||||
@@ -393,11 +370,7 @@ class FlowMatchLCMScheduler(SchedulerMixin, ConfigMixin):
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
|
||||
def index_for_timestep(
|
||||
self,
|
||||
timestep: float | torch.Tensor,
|
||||
schedule_timesteps: torch.Tensor | None = None,
|
||||
) -> int:
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
|
||||
@@ -409,9 +382,9 @@ class FlowMatchLCMScheduler(SchedulerMixin, ConfigMixin):
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
pos = 1 if len(indices) > 1 else 0
|
||||
|
||||
return int(indices[pos].item())
|
||||
return indices[pos].item()
|
||||
|
||||
def _init_step_index(self, timestep: float | torch.Tensor) -> None:
|
||||
def _init_step_index(self, timestep):
|
||||
if self.begin_index is None:
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
@@ -486,12 +459,7 @@ class FlowMatchLCMScheduler(SchedulerMixin, ConfigMixin):
|
||||
size = [round(self._scale_factors[self._step_index] * size) for size in self._init_size]
|
||||
x0_pred = torch.nn.functional.interpolate(x0_pred, size=size, mode=self._upscale_mode)
|
||||
|
||||
noise = randn_tensor(
|
||||
x0_pred.shape,
|
||||
generator=generator,
|
||||
device=x0_pred.device,
|
||||
dtype=x0_pred.dtype,
|
||||
)
|
||||
noise = randn_tensor(x0_pred.shape, generator=generator, device=x0_pred.device, dtype=x0_pred.dtype)
|
||||
prev_sample = (1 - sigma_next) * x0_pred + sigma_next * noise
|
||||
|
||||
# upon completion increase step index by one
|
||||
@@ -505,7 +473,7 @@ class FlowMatchLCMScheduler(SchedulerMixin, ConfigMixin):
|
||||
return FlowMatchLCMSchedulerOutput(prev_sample=prev_sample)
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
|
||||
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
|
||||
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
|
||||
"""
|
||||
Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
|
||||
Models](https://huggingface.co/papers/2206.00364).
|
||||
@@ -626,15 +594,11 @@ class FlowMatchLCMScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
return sigmas
|
||||
|
||||
def _time_shift_exponential(
|
||||
self, mu: float, sigma: float, t: float | np.ndarray | torch.Tensor
|
||||
) -> float | np.ndarray | torch.Tensor:
|
||||
def _time_shift_exponential(self, mu, sigma, t):
|
||||
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
||||
|
||||
def _time_shift_linear(
|
||||
self, mu: float, sigma: float, t: float | np.ndarray | torch.Tensor
|
||||
) -> float | np.ndarray | torch.Tensor:
|
||||
def _time_shift_linear(self, mu, sigma, t):
|
||||
return mu / (mu + (1 / t - 1) ** sigma)
|
||||
|
||||
def __len__(self) -> int:
|
||||
def __len__(self):
|
||||
return self.config.num_train_timesteps
|
||||
|
||||
@@ -375,7 +375,7 @@ class LoraHotSwappingForModelTesterMixin:
|
||||
# additionally check if dynamic compilation works.
|
||||
if different_shapes is not None:
|
||||
for height, width in different_shapes:
|
||||
new_inputs_dict = self.get_dummy_inputs(height=height, width=width)
|
||||
new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
|
||||
_ = model(**new_inputs_dict)
|
||||
else:
|
||||
output0_after = model(**inputs_dict)["sample"]
|
||||
@@ -390,7 +390,7 @@ class LoraHotSwappingForModelTesterMixin:
|
||||
with torch.inference_mode():
|
||||
if different_shapes is not None:
|
||||
for height, width in different_shapes:
|
||||
new_inputs_dict = self.get_dummy_inputs(height=height, width=width)
|
||||
new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
|
||||
_ = model(**new_inputs_dict)
|
||||
else:
|
||||
output1_after = model(**inputs_dict)["sample"]
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import gc
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Callable
|
||||
|
||||
@@ -8,6 +10,7 @@ import torch
|
||||
import diffusers
|
||||
from diffusers import AutoModel, ComponentsManager, ModularPipeline, ModularPipelineBlocks
|
||||
from diffusers.guiders import ClassifierFreeGuidance
|
||||
from diffusers.modular_pipelines import SequentialPipelineBlocks
|
||||
from diffusers.modular_pipelines.modular_pipeline_utils import (
|
||||
ComponentSpec,
|
||||
ConfigSpec,
|
||||
@@ -17,7 +20,13 @@ from diffusers.modular_pipelines.modular_pipeline_utils import (
|
||||
)
|
||||
from diffusers.utils import logging
|
||||
|
||||
from ..testing_utils import backend_empty_cache, numpy_cosine_similarity_distance, require_accelerator, torch_device
|
||||
from ..testing_utils import (
|
||||
CaptureLogger,
|
||||
backend_empty_cache,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_accelerator,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
class ModularPipelineTesterMixin:
|
||||
@@ -400,6 +409,56 @@ class ModularGuiderTesterMixin:
|
||||
assert max_diff > expected_max_diff, "Output with CFG must be different from normal inference"
|
||||
|
||||
|
||||
class TestCustomBlockRequirements:
    """Checks that `_requirements` declared on blocks are written to the saved config and trigger warnings."""

    def get_dummy_block_pipe(self):
        """Build a `SequentialPipelineBlocks` from two stub blocks that only declare `_requirements`."""

        class DummyBlockOne:
            # keep two arbitrary deps so that we can test warnings.
            _requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}

        class DummyBlockTwo:
            # keep two dependencies that will be available during testing.
            _requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}

        blocks = {"dummy_block_one": DummyBlockOne, "dummy_block_two": DummyBlockTwo}
        return SequentialPipelineBlocks.from_blocks_dict(blocks)

    def test_custom_requirements_save_load(self):
        """The saved `modular_config.json` holds the flattened requirements of both blocks."""
        pipe = self.get_dummy_block_pipe()
        with tempfile.TemporaryDirectory() as tmpdir:
            pipe.save_pretrained(tmpdir)
            with open(os.path.join(tmpdir, "modular_config.json"), "r") as f:
                saved_config = json.load(f)

        assert "requirements" in saved_config
        assert saved_config["requirements"] == {
            "xyz": ">=0.8.0",
            "abc": ">=10.0.0",
            "transformers": ">=4.44.0",
            "diffusers": ">=0.2.0",
        }

    def test_warnings(self):
        """Saving logs a "not found" warning for each requirement missing from the environment."""
        pipe = self.get_dummy_block_pipe()
        logger = logging.get_logger("diffusers.modular_pipelines.modular_pipeline_utils")
        logger.setLevel(30)  # numeric level, as in the stdlib's WARNING

        with tempfile.TemporaryDirectory() as tmpdir:
            with CaptureLogger(logger) as cap_logger:
                pipe.save_pretrained(tmpdir)

        captured = str(cap_logger.out)
        template = "{req} was specified in the requirements but wasn't found in the current environment"
        for missing in ("xyz", "abc"):
            assert template.format(req=missing) in captured
|
||||
|
||||
|
||||
class TestModularModelCardContent:
|
||||
def create_mock_block(self, name="TestBlock", description="Test block description"):
|
||||
class MockBlock:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
from transformers.models.t5gemma.configuration_t5gemma import T5GemmaConfig, T5GemmaModuleConfig
|
||||
@@ -10,11 +11,17 @@ from diffusers.models import AutoencoderDC, AutoencoderKL
|
||||
from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
|
||||
from diffusers.pipelines.prx.pipeline_prx import PRXPipeline
|
||||
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from diffusers.utils import is_transformers_version
|
||||
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
@pytest.mark.xfail(
|
||||
condition=is_transformers_version(">", "4.57.1"),
|
||||
reason="See https://github.com/huggingface/diffusers/pull/12456#issuecomment-3424228544",
|
||||
strict=False,
|
||||
)
|
||||
class PRXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = PRXPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
|
||||
|
||||
Reference in New Issue
Block a user