Compare commits

12 Commits

Author            SHA1        Date                        Message
sayakpaul         5fa3204fe0  2026-01-02 21:38:10 +05:30  up
Sayak Paul        75c61d1a68  2026-01-02 21:31:11 +05:30  Merge branch 'main' into cp-fixes-attn-backends
Maxim Balabanski  208cda8f6d  2026-01-02 12:59:11 +05:30  fix Qwen Image Transformer single file loading mapping function to be consistent with other loader APIs (#12894)
                                                          fix Qwen single file loading to be consistent with other loader API
sayakpaul         af7c9a4817  2025-12-30 14:12:57 +05:30  address PR feedback.
Sayak Paul        2fe5bec206  2025-12-30 13:45:39 +05:30  Merge branch 'main' into cp-fixes-attn-backends
Sayak Paul        6b5b3f705e  2025-12-21 22:06:57 +05:30  Merge branch 'main' into cp-fixes-attn-backends
Sayak Paul        301c223318  2025-12-18 14:21:52 +08:00  Merge branch 'main' into cp-fixes-attn-backends
Sayak Paul        3b1ccd79a5  2025-12-15 20:30:22 +08:00  Merge branch 'main' into cp-fixes-attn-backends
sayakpaul         0c35ed4708  2025-12-12 15:26:43 +05:30  up
sayakpaul         738f278d93  2025-12-12 15:25:59 +05:30  gracefully error out when attn-backend x cp combo isn't supported.
sayakpaul         23251d6cf6  2025-12-12 15:24:09 +05:30  Revert "gracefully error out when attn-backend x cp combo isn't supported."
                                                          This reverts commit c8abb5d7c0.
sayakpaul         c8abb5d7c0  2025-12-12 15:20:18 +05:30  gracefully error out when attn-backend x cp combo isn't supported.
4 changed files with 34 additions and 7 deletions

View File

@@ -263,8 +263,8 @@ def main():
     world_size = dist.get_world_size()
     pipeline = DiffusionPipeline.from_pretrained(
-        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
-    ).to(device)
+        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, device_map=device
+    )
     pipeline.transformer.set_attention_backend("_native_cudnn")
     cp_config = ContextParallelConfig(ring_degree=world_size)
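
This hunk switches the FLUX context-parallel example from moving the pipeline with `.to(device)` to loading it directly onto the target device via `device_map`. For orientation, here is a hedged sketch of the kind of torchrun-launched script these lines sit in; everything outside the lines shown above (process-group setup, the `enable_parallelism(config=...)` call, the prompt, teardown) is assumed for illustration and is not part of this diff.

```python
# Hypothetical scaffolding around the changed lines; launch with e.g.
#   torchrun --nproc_per_node=2 run_flux_cp.py
import torch
import torch.distributed as dist

from diffusers import ContextParallelConfig, DiffusionPipeline  # top-level import path assumed


def main():
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    device = torch.device("cuda", rank % torch.cuda.device_count())
    torch.cuda.set_device(device)

    world_size = dist.get_world_size()
    # The change above: place weights on the target device at load time via
    # `device_map` instead of calling `.to(device)` afterwards.
    pipeline = DiffusionPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, device_map=device
    )
    pipeline.transformer.set_attention_backend("_native_cudnn")
    cp_config = ContextParallelConfig(ring_degree=world_size)
    pipeline.transformer.enable_parallelism(config=cp_config)  # call shape assumed

    image = pipeline("a photo of a cat", num_inference_steps=28).images[0]
    if rank == 0:
        image.save("flux_cp.png")
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```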

View File

@@ -162,7 +162,7 @@ SINGLE_FILE_LOADABLE_CLASSES = {
         "default_subfolder": "transformer",
     },
     "QwenImageTransformer2DModel": {
-        "checkpoint_mapping_fn": lambda x: x,
+        "checkpoint_mapping_fn": lambda checkpoint, **kwargs: checkpoint,
         "default_subfolder": "transformer",
     },
     "Flux2Transformer2DModel": {

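The commit message for 208cda8f6d ("consistent with other loader APIs") points at the signature: the single-file loader is assumed to call `checkpoint_mapping_fn` with keyword arguments (at least `checkpoint=...`, plus extras it passes through), which the positional-only identity `lambda x: x` cannot accept. A minimal, self-contained illustration of that failure mode, with toy keyword arguments standing in for whatever the loader actually passes:

```python
# Hypothetical call shapes; the loader's keyword names beyond `checkpoint` are assumptions.
old_fn = lambda x: x                              # positional-only identity
new_fn = lambda checkpoint, **kwargs: checkpoint  # keyword-friendly identity

state_dict = {"transformer_blocks.0.attn.to_q.weight": "..."}

try:
    old_fn(checkpoint=state_dict, config={"num_layers": 2})
except TypeError as err:
    # e.g. "<lambda>() got an unexpected keyword argument 'checkpoint'"
    print("old mapping fn:", err)

# The new signature passes the checkpoint through and ignores extra kwargs.
print("new mapping fn:", list(new_fn(checkpoint=state_dict, config={"num_layers": 2})))
```

With that signature, `QwenImageTransformer2DModel.from_single_file(...)` should go through the same mapping path as the other entries in `SINGLE_FILE_LOADABLE_CLASSES`.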
View File

@@ -235,6 +235,10 @@ class _AttentionBackendRegistry:
     def get_active_backend(cls):
         return cls._active_backend, cls._backends[cls._active_backend]

+    @classmethod
+    def set_active_backend(cls, backend: str):
+        cls._active_backend = backend
+
     @classmethod
     def list_backends(cls):
         return list(cls._backends.keys())
@@ -294,12 +298,12 @@ def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBacke
     _maybe_download_kernel_for_backend(backend)

     old_backend = _AttentionBackendRegistry._active_backend
-    _AttentionBackendRegistry._active_backend = backend
+    _AttentionBackendRegistry.set_active_backend(backend)

     try:
         yield
     finally:
-        _AttentionBackendRegistry._active_backend = old_backend
+        _AttentionBackendRegistry.set_active_backend(old_backend)


def dispatch_attention_fn(
@@ -348,6 +352,7 @@ def dispatch_attention_fn(
         check(**kwargs)

     kwargs = {k: v for k, v in kwargs.items() if k in _AttentionBackendRegistry._supported_arg_names[backend_name]}
+
     return backend_fn(**kwargs)
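
The theme of this file's changes: writes to the registry's active backend now go through a single `set_active_backend` classmethod instead of assigning `_active_backend` directly, so the `attention_backend` context manager and `ModelMixin.set_attention_backend` (next file) stay in sync. A self-contained toy mirror of that pattern, with made-up names standing in for the real registry internals:

```python
# Toy mirror of the setter-plus-context-manager pattern; names are illustrative only.
from contextlib import contextmanager


class _ToyBackendRegistry:
    _backends = {"native": "native_attention_fn", "_native_cudnn": "cudnn_attention_fn"}
    _active_backend = "native"

    @classmethod
    def get_active_backend(cls):
        return cls._active_backend, cls._backends[cls._active_backend]

    @classmethod
    def set_active_backend(cls, backend: str):
        # Single choke point for changing the active backend.
        cls._active_backend = backend


@contextmanager
def toy_attention_backend(backend: str):
    # Save, switch through the setter, and always restore on exit.
    old_backend = _ToyBackendRegistry._active_backend
    _ToyBackendRegistry.set_active_backend(backend)
    try:
        yield
    finally:
        _ToyBackendRegistry.set_active_backend(old_backend)


with toy_attention_backend("_native_cudnn"):
    print(_ToyBackendRegistry.get_active_backend())  # ('_native_cudnn', 'cudnn_attention_fn')
print(_ToyBackendRegistry.get_active_backend())      # ('native', 'native_attention_fn')
```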

View File

@@ -602,6 +602,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         from .attention import AttentionModuleMixin
         from .attention_dispatch import (
             AttentionBackendName,
+            _AttentionBackendRegistry,
             _check_attention_backend_requirements,
             _maybe_download_kernel_for_backend,
         )
@@ -610,6 +611,16 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         from .attention_processor import Attention, MochiAttention

         logger.warning("Attention backends are an experimental feature and the API may be subject to change.")

+        attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
+        parallel_config_set = False
+        for module in self.modules():
+            if not isinstance(module, attention_classes):
+                continue
+            processor = module.processor
+            if getattr(processor, "_parallel_config", None) is not None:
+                parallel_config_set = True
+                break
+
         backend = backend.lower()
         available_backends = {x.value for x in AttentionBackendName.__members__.values()}
@@ -617,10 +628,18 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
         backend = AttentionBackendName(backend)

+        if parallel_config_set and not _AttentionBackendRegistry._is_context_parallel_available(backend):
+            compatible_backends = sorted(_AttentionBackendRegistry._supports_context_parallel)
+            raise ValueError(
+                f"Context parallelism is enabled but backend '{backend.value}' "
+                f"does not support context parallelism. "
+                f"Please set a compatible attention backend: {compatible_backends} using `model.set_attention_backend()` before "
+                f"calling `model.enable_parallelism()`."
+            )
+
         _check_attention_backend_requirements(backend)
         _maybe_download_kernel_for_backend(backend)

-        attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
         for module in self.modules():
             if not isinstance(module, attention_classes):
                 continue
@@ -629,6 +648,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 continue
             processor._attention_backend = backend

+        # Important to set the active backend so that it propagates gracefully throughout.
+        _AttentionBackendRegistry.set_active_backend(backend)
+
     def reset_attention_backend(self) -> None:
         """
         Resets the attention backend for the model. Following calls to `forward` will use the environment default, if
@@ -1541,7 +1563,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                     f"Context parallelism is enabled but the attention processor '{processor.__class__.__name__}' "
                     f"is using backend '{attention_backend.value}' which does not support context parallelism. "
                     f"Please set a compatible attention backend: {compatible_backends} using `model.set_attention_backend()` before "
-                    f"calling `enable_parallelism()`."
+                    f"calling `model.enable_parallelism()`."
                 )

         # All modules use the same attention processor and backend. We don't need to
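
Together with the check whose message is touched up in the last hunk, the new guard in `set_attention_backend` means an unsupported backend/context-parallel combination is rejected whichever call comes second, instead of failing deep inside the first forward pass. A self-contained sketch of the guard's logic, using an invented processor class and invented backend names rather than the real diffusers objects:

```python
# Toy mirror of the new guard; class, names, and the supported-backend set are illustrative only.
class ToyProcessor:
    def __init__(self, parallel_config=None):
        self._parallel_config = parallel_config  # set once context parallelism is enabled
        self._attention_backend = None


def set_attention_backend(processors, backend, supports_context_parallel):
    # If any processor already carries a parallel config, only CP-capable backends are allowed.
    parallel_config_set = any(
        getattr(p, "_parallel_config", None) is not None for p in processors
    )
    if parallel_config_set and backend not in supports_context_parallel:
        compatible = sorted(supports_context_parallel)
        raise ValueError(
            f"Context parallelism is enabled but backend '{backend}' does not support it. "
            f"Please set a compatible attention backend: {compatible}."
        )
    for p in processors:
        p._attention_backend = backend


processors = [ToyProcessor(parallel_config={"ring_degree": 2})]
try:
    set_attention_backend(processors, "toy_native", {"toy_cudnn", "toy_flash"})
except ValueError as err:
    print(err)                           # rejected: backend is not CP-capable
set_attention_backend(processors, "toy_cudnn", {"toy_cudnn", "toy_flash"})
print(processors[0]._attention_backend)  # 'toy_cudnn'
```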