Mirror of https://github.com/huggingface/diffusers.git, synced 2025-12-20 11:24:43 +08:00

Compare commits: qwenimage-… ... cp-fixes-a (6 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 301c223318 | |
| | 3b1ccd79a5 | |
| | 0c35ed4708 | |
| | 738f278d93 | |
| | 23251d6cf6 | |
| | c8abb5d7c0 | |
@@ -235,6 +235,10 @@ class _AttentionBackendRegistry:
     def get_active_backend(cls):
         return cls._active_backend, cls._backends[cls._active_backend]
 
+    @classmethod
+    def set_active_backend(cls, backend: str):
+        cls._active_backend = backend
+
     @classmethod
     def list_backends(cls):
         return list(cls._backends.keys())
@@ -294,12 +298,12 @@ def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBacke
     _maybe_download_kernel_for_backend(backend)
 
     old_backend = _AttentionBackendRegistry._active_backend
-    _AttentionBackendRegistry._active_backend = backend
+    _AttentionBackendRegistry.set_active_backend(backend)
 
     try:
         yield
     finally:
-        _AttentionBackendRegistry._active_backend = old_backend
+        _AttentionBackendRegistry.set_active_backend(old_backend)
 
 
 def dispatch_attention_fn(
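For readers skimming the change: the context manager now saves the previous backend, installs the new one through the setter, and restores the old one on exit. A minimal, self-contained sketch of that save/set/restore pattern (the `_SimpleRegistry` and `use_backend` names are illustrative stand-ins, not diffusers APIs):

```python
from contextlib import contextmanager


# Minimal sketch of the save/set/restore pattern above.
# `_SimpleRegistry` and `use_backend` are illustrative, not diffusers APIs.
class _SimpleRegistry:
    _backends = {"native": object(), "flash": object()}  # placeholder backend table
    _active_backend = "native"

    @classmethod
    def set_active_backend(cls, backend: str):
        cls._active_backend = backend


@contextmanager
def use_backend(backend: str):
    old_backend = _SimpleRegistry._active_backend
    _SimpleRegistry.set_active_backend(backend)
    try:
        yield
    finally:
        # Restore the previous backend even if the body raises.
        _SimpleRegistry.set_active_backend(old_backend)


with use_backend("flash"):
    assert _SimpleRegistry._active_backend == "flash"
assert _SimpleRegistry._active_backend == "native"
```

Routing through `set_active_backend` instead of assigning `_active_backend` directly keeps a single place where backend switches can later be validated or logged.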
@@ -348,6 +352,18 @@ def dispatch_attention_fn(
         check(**kwargs)
 
     kwargs = {k: v for k, v in kwargs.items() if k in _AttentionBackendRegistry._supported_arg_names[backend_name]}
 
+    if "_parallel_config" in kwargs and kwargs["_parallel_config"] is not None:
+        attention_backend = AttentionBackendName(backend_name)
+        if not _AttentionBackendRegistry._is_context_parallel_available(attention_backend):
+            compatible_backends = sorted(_AttentionBackendRegistry._supports_context_parallel)
+            raise ValueError(
+                f"Context parallelism is enabled but backend '{attention_backend.value}' "
+                f"which does not support context parallelism. "
+                f"Please set a compatible attention backend: {compatible_backends} using `model.set_attention_backend()` before "
+                f"calling `model.enable_parallelism()`."
+            )
+
     return backend_fn(**kwargs)
 
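The hunk above adds a fail-fast check in `dispatch_attention_fn`: when a parallel config is present but the selected backend cannot do context parallelism, a descriptive error is raised before the backend function runs. A standalone sketch of that kind of validation, assuming the capable backends are just a set of names (the names below are placeholders, not the actual `_supports_context_parallel` contents):

```python
# Standalone sketch of the fail-fast validation added above. The backend names and
# the contents of the "supports context parallel" set are placeholders, not taken
# from diffusers.
_SUPPORTS_CONTEXT_PARALLEL = {"flash", "native-cudnn"}


def validate_cp_backend(backend_name: str, parallel_config) -> None:
    if parallel_config is None:
        return  # context parallelism not requested, nothing to validate
    if backend_name not in _SUPPORTS_CONTEXT_PARALLEL:
        compatible = sorted(_SUPPORTS_CONTEXT_PARALLEL)
        raise ValueError(
            f"Context parallelism is enabled but backend '{backend_name}' does not "
            f"support it. Please set a compatible attention backend: {compatible}."
        )


validate_cp_backend("flash", parallel_config=object())   # passes
# validate_cp_backend("native", parallel_config=object())  # would raise ValueError
```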
@@ -602,6 +602,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         from .attention import AttentionModuleMixin
         from .attention_dispatch import (
             AttentionBackendName,
+            _AttentionBackendRegistry,
             _check_attention_backend_requirements,
             _maybe_download_kernel_for_backend,
         )
@@ -629,6 +630,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 continue
             processor._attention_backend = backend
 
+        # Important to set the active backend so that it propagates gracefully throughout.
+        _AttentionBackendRegistry.set_active_backend(backend)
+
     def reset_attention_backend(self) -> None:
         """
         Resets the attention backend for the model. Following calls to `forward` will use the environment default, if
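With the added call, `set_attention_backend` updates both the per-processor attribute and the registry's process-wide active backend, so later dispatches and checks see a consistent value. A self-contained sketch of that propagation (the `_ToyRegistry` and `_ToyProcessor` classes are illustrative only, not diffusers types):

```python
# Self-contained sketch of the propagation pattern above: per-module attribute plus a
# process-wide registry update. `_ToyRegistry` and `_ToyProcessor` are illustrative only.
class _ToyRegistry:
    _active_backend = "native"

    @classmethod
    def set_active_backend(cls, backend: str):
        cls._active_backend = backend


class _ToyProcessor:
    _attention_backend = None


def set_attention_backend(processors, backend: str):
    for processor in processors:
        processor._attention_backend = backend
    # Important to set the active backend so that it propagates gracefully throughout.
    _ToyRegistry.set_active_backend(backend)


procs = [_ToyProcessor(), _ToyProcessor()]
set_attention_backend(procs, "flash")
assert all(p._attention_backend == "flash" for p in procs)
assert _ToyRegistry._active_backend == "flash"
```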
@@ -1541,7 +1545,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                     f"Context parallelism is enabled but the attention processor '{processor.__class__.__name__}' "
                     f"is using backend '{attention_backend.value}' which does not support context parallelism. "
                     f"Please set a compatible attention backend: {compatible_backends} using `model.set_attention_backend()` before "
-                    f"calling `enable_parallelism()`."
+                    f"calling `model.enable_parallelism()`."
                 )
 
             # All modules use the same attention processor and backend. We don't need to