Compare commits

...

1 Commits

Author SHA1 Message Date
Aryan
07798babac update 2025-02-23 15:24:24 +01:00
3 changed files with 100 additions and 10 deletions

View File

@@ -0,0 +1,65 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.distributed as dist
from ..utils import get_logger
from ._common import _BATCHED_INPUT_IDENTIFIERS
from .hooks import HookRegistry, ModelHook
logger = get_logger(__name__) # pylint: disable=invalid-name
_CFG_PARALLEL = "cfg_parallel"
class CFGParallelHook(ModelHook):
def initialize_hook(self, module):
if not dist.is_initialized():
raise RuntimeError("Distributed environment not initialized.")
return module
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
if len(args) > 0:
logger.warning(
"CFGParallelHook is an example hook that does not work with batched positional arguments. Please use with caution."
)
world_size = dist.get_world_size()
rank = dist.get_rank()
assert world_size == 2, "This is an example hook designed to only work with 2 processes."
for key in list(kwargs.keys()):
if key not in _BATCHED_INPUT_IDENTIFIERS or kwargs[key] is None:
continue
kwargs[key] = torch.chunk(kwargs[key], world_size, dim=0)[rank].contiguous()
output = self.fn_ref.original_forward(*args, **kwargs)
sample = output[0]
sample_list = [torch.empty_like(sample) for _ in range(world_size)]
dist.all_gather(sample_list, sample)
sample = torch.cat(sample_list, dim=0).contiguous()
return_dict = kwargs.get("return_dict", False)
if not return_dict:
return (sample, *output[1:])
return output.__class__(sample, *output[1:])
def apply_cfg_parallel(module: torch.nn.Module) -> None:
registry = HookRegistry.check_if_exists_or_initialize(module)
hook = CFGParallelHook()
registry.register_hook(hook, _CFG_PARALLEL)

View File

@@ -0,0 +1,26 @@
from ..models.attention_processor import Attention, MochiAttention
_ATTENTION_CLASSES = (Attention, MochiAttention)
_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers")
_ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple(
{
*_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
*_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
*_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
}
)
_BATCHED_INPUT_IDENTIFIERS = (
"hidden_states",
"encoder_hidden_states",
"pooled_projections",
"timestep",
"attention_mask",
"encoder_attention_mask",
"guidance",
)

View File

@@ -20,19 +20,18 @@ import torch
from ..models.attention_processor import Attention, MochiAttention
from ..utils import logging
from ._common import (
_ATTENTION_CLASSES,
_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
)
from .hooks import HookRegistry, ModelHook
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
_ATTENTION_CLASSES = (Attention, MochiAttention)
_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks")
_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")
@dataclass
class PyramidAttentionBroadcastConfig:
r"""
@@ -76,9 +75,9 @@ class PyramidAttentionBroadcastConfig:
temporal_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
cross_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_ATTENTION_BLOCK_IDENTIFIERS
spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS
temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS
cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS
current_timestep_callback: Callable[[], int] = None