mirror of https://github.com/huggingface/diffusers.git
synced 2025-12-14 00:14:23 +08:00

Compare commits: kernelize...integratio (1 commit)
Commit 07798babac

src/diffusers/hooks/_cfg_parallel.py (new file, 65 lines)
@@ -0,0 +1,65 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.distributed as dist

from ..utils import get_logger
from ._common import _BATCHED_INPUT_IDENTIFIERS
from .hooks import HookRegistry, ModelHook


logger = get_logger(__name__)  # pylint: disable=invalid-name

_CFG_PARALLEL = "cfg_parallel"


class CFGParallelHook(ModelHook):
    def initialize_hook(self, module):
        if not dist.is_initialized():
            raise RuntimeError("Distributed environment not initialized.")
        return module

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        if len(args) > 0:
            logger.warning(
                "CFGParallelHook is an example hook that does not work with batched positional arguments. Please use with caution."
            )

        world_size = dist.get_world_size()
        rank = dist.get_rank()

        assert world_size == 2, "This is an example hook designed to only work with 2 processes."

        for key in list(kwargs.keys()):
            if key not in _BATCHED_INPUT_IDENTIFIERS or kwargs[key] is None:
                continue
            kwargs[key] = torch.chunk(kwargs[key], world_size, dim=0)[rank].contiguous()

        output = self.fn_ref.original_forward(*args, **kwargs)
        sample = output[0]
        sample_list = [torch.empty_like(sample) for _ in range(world_size)]
        dist.all_gather(sample_list, sample)
        sample = torch.cat(sample_list, dim=0).contiguous()

        return_dict = kwargs.get("return_dict", False)
        if not return_dict:
            return (sample, *output[1:])
        return output.__class__(sample, *output[1:])


def apply_cfg_parallel(module: torch.nn.Module) -> None:
    registry = HookRegistry.check_if_exists_or_initialize(module)
    hook = CFGParallelHook()
    registry.register_hook(hook, _CFG_PARALLEL)
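For orientation, here is a minimal usage sketch of the hook above. It is hypothetical and not part of this commit: the ToyDenoiser module, the script name, and the gloo backend are assumptions, and only apply_cfg_parallel plus the batched kwarg name hidden_states come from the diff. Since the hook asserts a world size of exactly 2, the sketch is meant to be launched with torchrun --nproc_per_node=2.

# Hypothetical usage sketch, not part of this commit.
# Launch with: torchrun --nproc_per_node=2 cfg_parallel_demo.py
import torch
import torch.distributed as dist

from diffusers.hooks._cfg_parallel import apply_cfg_parallel


class ToyDenoiser(torch.nn.Module):
    """Stand-in for a diffusers transformer; only the kwarg name `hidden_states` matters here."""

    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(16, 16)

    def forward(self, hidden_states: torch.Tensor):
        return (self.proj(hidden_states),)


def main():
    # The hook requires an initialized process group and exactly two ranks.
    dist.init_process_group("gloo")
    torch.manual_seed(0)  # identical weights and inputs on both ranks

    model = ToyDenoiser()
    apply_cfg_parallel(model)

    # Batch of 2, as produced by classifier-free guidance (conditional + unconditional).
    hidden_states = torch.randn(2, 16)
    output = model(hidden_states=hidden_states)

    # Each rank ran one half of the batch; the hook all-gathers the halves back together.
    assert output[0].shape == (2, 16)
    dist.destroy_process_group()


if __name__ == "__main__":
    main()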
src/diffusers/hooks/_common.py (new file, 26 lines)
@@ -0,0 +1,26 @@
from ..models.attention_processor import Attention, MochiAttention


_ATTENTION_CLASSES = (Attention, MochiAttention)

_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers")

_ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple(
    {
        *_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
        *_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
        *_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
    }
)

_BATCHED_INPUT_IDENTIFIERS = (
    "hidden_states",
    "encoder_hidden_states",
    "pooled_projections",
    "timestep",
    "attention_mask",
    "encoder_attention_mask",
    "guidance",
)
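These tuples are plain name fragments that hooks match against a model's submodule attribute names and forward kwargs. The helper below is a hedged sketch, not part of this commit, showing the kind of lookup _ALL_TRANSFORMER_BLOCK_IDENTIFIERS enables.

# Hypothetical helper, not part of this commit: locate ModuleList containers whose
# attribute name matches one of the shared transformer-block identifiers.
import torch

from diffusers.hooks._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS


def find_transformer_block_lists(model: torch.nn.Module) -> dict:
    found = {}
    for name, child in model.named_children():
        if name in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS and isinstance(child, torch.nn.ModuleList):
            found[name] = len(child)  # e.g. {"transformer_blocks": 19}
    return found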
src/diffusers/hooks/pyramid_attention_broadcast.py
@@ -20,19 +20,18 @@ import torch
 
 from ..models.attention_processor import Attention, MochiAttention
 from ..utils import logging
+from ._common import (
+    _ATTENTION_CLASSES,
+    _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
+    _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
+    _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
+)
 from .hooks import HookRegistry, ModelHook
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-_ATTENTION_CLASSES = (Attention, MochiAttention)
-
-_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks")
-_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
-_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")
-
-
 @dataclass
 class PyramidAttentionBroadcastConfig:
     r"""
@@ -76,9 +75,9 @@ class PyramidAttentionBroadcastConfig:
     temporal_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
     cross_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
 
-    spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
-    temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
-    cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_ATTENTION_BLOCK_IDENTIFIERS
+    spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS
+    temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS
+    cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS
 
     current_timestep_callback: Callable[[], int] = None
 
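As a hedged illustration of the rename's effect, the sketch below constructs the config with the shared defaults from ._common. It is not part of this commit: only field names visible in the hunk above are used, the module paths follow the files shown on this page, and the lambda callback is a placeholder rather than how a pipeline actually wires it.

# Hypothetical sketch, not part of this commit: the identifier defaults now come from
# diffusers.hooks._common, so passing them explicitly is optional.
from diffusers.hooks._common import _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS
from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastConfig

config = PyramidAttentionBroadcastConfig(
    temporal_attention_timestep_skip_range=(100, 800),
    cross_attention_timestep_skip_range=(100, 800),
    spatial_attention_block_identifiers=_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,  # same as the new default
    current_timestep_callback=lambda: 0,  # placeholder; a pipeline supplies the real callback
)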