Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-10 06:24:19 +08:00)

Compare commits: 1 commit (`wan-cache` ... integratio)
| Author | SHA1 | Date |
|---|---|---|
|  | 98771d3611 |  |
src/diffusers/hooks/_common.py (new file, 30 lines)

@@ -0,0 +1,30 @@
```python
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ..models.attention_processor import Attention, MochiAttention


_ATTENTION_CLASSES = (Attention, MochiAttention)

_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers")

_ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple(
    {
        *_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
        *_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
        *_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
    }
)
```
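For reference, a small sketch (not part of the commit) of what the deduplicated tuple contains, assuming `diffusers.hooks._common` is importable from a checkout of this branch:

```python
# Sketch: the union drops identifiers shared between the spatial and cross tuples
# ("blocks", "transformer_blocks", "layers"). It is built from a set, so ordering
# is not guaranteed; sort before comparing.
from diffusers.hooks._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS

print(sorted(_ALL_TRANSFORMER_BLOCK_IDENTIFIERS))
# ['blocks', 'layers', 'single_transformer_blocks', 'temporal_transformer_blocks', 'transformer_blocks']
```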
src/diffusers/hooks/first_block_cache.py (new file, 262 lines)

@@ -0,0 +1,262 @@
```python
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from dataclasses import dataclass
from typing import Tuple, Union

import torch

from ..utils import get_logger
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
from .hooks import HookRegistry, ModelHook
from .utils import _extract_return_information


logger = get_logger(__name__)  # pylint: disable=invalid-name


_FBC_LEADER_BLOCK_HOOK = "fbc_leader_block_hook"
_FBC_BLOCK_HOOK = "fbc_block_hook"


@dataclass
class FirstBlockCacheConfig:
    r"""
    Configuration for [First Block
    Cache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching).

    Args:
        threshold (`float`, defaults to `0.05`):
            The threshold used to decide whether a forward pass through all remaining layers of the model is
            required. A higher threshold usually results in fewer forward passes and faster inference, but may
            lead to poorer generation quality. A lower threshold may not yield a significant speedup. The
            threshold is compared against the absmean difference between the current and cached residuals of the
            first transformer block's output. If the difference is below the threshold, the forward pass through
            the remaining blocks is skipped.
    """

    threshold: float = 0.05


class FBCSharedBlockState:
    def __init__(self) -> None:
        self.head_block_output: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
        self.head_block_residual: torch.Tensor = None
        self.tail_block_residuals: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
        self.should_compute: bool = True

    def reset(self):
        self.tail_block_residuals = None
        self.should_compute = True

    def __repr__(self):
        return f"FBCSharedBlockState(should_compute={self.should_compute})"


class FBCHeadBlockHook(ModelHook):
    _is_stateful = True

    def __init__(self, shared_state: FBCSharedBlockState, threshold: float):
        super().__init__()
        self.shared_state = shared_state
        self.threshold = threshold

    def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        inputs = inspect.signature(module.__class__.forward)
        inputs_index_to_str = dict(enumerate(inputs.parameters.keys()))
        inputs_str_to_index = {v: k for k, v in inputs_index_to_str.items()}

        try:
            outputs = _extract_return_information(module.__class__.forward)
            outputs_index_to_str = dict(enumerate(outputs))
            outputs_str_to_index = {v: k for k, v in outputs_index_to_str.items()}
        except RuntimeError:
            logger.error(f"Failed to extract return information for {module.__class__}")
            raise NotImplementedError(
                f"Module {module.__class__} is not supported with FirstBlockCache. Please open an issue at "
                f"https://github.com/huggingface/diffusers to notify us about the error with a minimal example "
                f"in order for us to add support for this module."
            )

        self._inputs_index_to_str = inputs_index_to_str
        self._inputs_str_to_index = inputs_str_to_index
        self._outputs_index_to_str = outputs_index_to_str
        self._outputs_str_to_index = outputs_str_to_index
        return module

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        hs_input_idx = self._inputs_str_to_index.get("hidden_states")
        ehs_input_idx = self._inputs_str_to_index.get("encoder_hidden_states", None)
        original_hs = kwargs.get("hidden_states", None)
        original_ehs = kwargs.get("encoder_hidden_states", None)
        original_hs = original_hs if original_hs is not None else args[hs_input_idx]
        if ehs_input_idx is not None:
            original_ehs = original_ehs if original_ehs is not None else args[ehs_input_idx]

        hs_output_idx = self._outputs_str_to_index.get("hidden_states")
        ehs_output_idx = self._outputs_str_to_index.get("encoder_hidden_states", None)
        assert (ehs_input_idx is None) == (ehs_output_idx is None)

        output = self.fn_ref.original_forward(*args, **kwargs)

        hs_residual = None
        if isinstance(output, tuple):
            hs_residual = output[hs_output_idx] - original_hs
        else:
            hs_residual = output - original_hs

        should_compute = self._should_compute_remaining_blocks(hs_residual)
        self.shared_state.should_compute = should_compute

        hs, ehs = None, None
        if not should_compute:
            # Apply caching
            logger.info("Skipping forward pass through remaining blocks")
            hs = self.shared_state.tail_block_residuals[0] + output[hs_output_idx]
            if ehs_output_idx is not None:
                ehs = self.shared_state.tail_block_residuals[1] + output[ehs_output_idx]

            if isinstance(output, tuple):
                return_output = [None] * len(output)
                return_output[hs_output_idx] = hs
                return_output[ehs_output_idx] = ehs
                return_output = tuple(return_output)
            else:
                return_output = hs
            return return_output
        else:
            logger.info("Computing forward pass through remaining blocks")
            if isinstance(output, tuple):
                head_block_output = [None] * len(output)
                head_block_output[0] = output[hs_output_idx]
                head_block_output[1] = output[ehs_output_idx]
            else:
                head_block_output = output
            self.shared_state.head_block_output = head_block_output
            self.shared_state.head_block_residual = hs_residual
            return output

    def reset_state(self, module):
        self.shared_state.reset()
        return module

    def _should_compute_remaining_blocks(self, hs_residual: torch.Tensor) -> bool:
        if self.shared_state.head_block_residual is None:
            return True
        prev_hs_residual = self.shared_state.head_block_residual
        hs_absmean = (hs_residual - prev_hs_residual).abs().mean()
        prev_hs_mean = prev_hs_residual.abs().mean()
        diff = (hs_absmean / prev_hs_mean).item()
        logger.info(f"Diff: {diff}, Threshold: {self.threshold}")
        return diff > self.threshold


class FBCBlockHook(ModelHook):
    def __init__(self, shared_state: FBCSharedBlockState, is_tail: bool = False):
        super().__init__()
        self.shared_state = shared_state
        self.is_tail = is_tail

    def initialize_hook(self, module):
        inputs = inspect.signature(module.__class__.forward)
        inputs_index_to_str = dict(enumerate(inputs.parameters.keys()))
        inputs_str_to_index = {v: k for k, v in inputs_index_to_str.items()}

        try:
            outputs = _extract_return_information(module.__class__.forward)
            outputs_index_to_str = dict(enumerate(outputs))
            outputs_str_to_index = {v: k for k, v in outputs_index_to_str.items()}
        except RuntimeError:
            logger.error(f"Failed to extract return information for {module.__class__}")
            raise NotImplementedError(
                f"Module {module.__class__} is not supported with FirstBlockCache. Please open an issue at "
                f"https://github.com/huggingface/diffusers to notify us about the error with a minimal example "
                f"in order for us to add support for this module."
            )

        self._inputs_index_to_str = inputs_index_to_str
        self._inputs_str_to_index = inputs_str_to_index
        self._outputs_index_to_str = outputs_index_to_str
        self._outputs_str_to_index = outputs_str_to_index
        return module

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        hs_input_idx = self._inputs_str_to_index.get("hidden_states")
        ehs_input_idx = self._inputs_str_to_index.get("encoder_hidden_states", None)
        original_hs = kwargs.get("hidden_states", None)
        original_ehs = kwargs.get("encoder_hidden_states", None)
        original_hs = original_hs if original_hs is not None else args[hs_input_idx]
        if ehs_input_idx is not None:
            original_ehs = original_ehs if original_ehs is not None else args[ehs_input_idx]

        hs_output_idx = self._outputs_str_to_index.get("hidden_states")
        ehs_output_idx = self._outputs_str_to_index.get("encoder_hidden_states", None)
        assert (ehs_input_idx is None) == (ehs_output_idx is None)

        if self.shared_state.should_compute:
            output = self.fn_ref.original_forward(*args, **kwargs)
            if self.is_tail:
                hs_residual, ehs_residual = None, None
                if isinstance(output, tuple):
                    hs_residual = output[hs_output_idx] - self.shared_state.head_block_output[0]
                    ehs_residual = output[ehs_output_idx] - self.shared_state.head_block_output[1]
                else:
                    hs_residual = output - self.shared_state.head_block_output
                self.shared_state.tail_block_residuals = (hs_residual, ehs_residual)
            return output

        output_count = len(self._outputs_index_to_str.keys())
        if output_count == 1:
            return_output = original_hs
        else:
            return_output = [None] * output_count
            return_output[hs_output_idx] = original_hs
            return_output[ehs_output_idx] = original_ehs
        return return_output


def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConfig) -> None:
    shared_state = FBCSharedBlockState()
    remaining_blocks = []

    for name, submodule in module.named_children():
        if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList):
            continue
        for block in submodule:
            remaining_blocks.append((name, block))

    head_block_name, head_block = remaining_blocks.pop(0)
    tail_block_name, tail_block = remaining_blocks.pop(-1)

    logger.debug(f"Apply FBCHeadBlockHook to '{head_block_name}'")
    apply_fbc_head_block_hook(head_block, shared_state, config.threshold)

    for name, block in remaining_blocks:
        logger.debug(f"Apply FBCBlockHook to '{name}'")
        apply_fbc_block_hook(block, shared_state)

    logger.debug(f"Apply FBCBlockHook to tail block '{tail_block_name}'")
    apply_fbc_block_hook(tail_block, shared_state, is_tail=True)


def apply_fbc_head_block_hook(block: torch.nn.Module, state: FBCSharedBlockState, threshold: float) -> None:
    registry = HookRegistry.check_if_exists_or_initialize(block)
    hook = FBCHeadBlockHook(state, threshold)
    registry.register_hook(hook, _FBC_LEADER_BLOCK_HOOK)


def apply_fbc_block_hook(block: torch.nn.Module, state: FBCSharedBlockState, is_tail: bool = False) -> None:
    registry = HookRegistry.check_if_exists_or_initialize(block)
    hook = FBCBlockHook(state, is_tail)
    registry.register_hook(hook, _FBC_BLOCK_HOOK)
```
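A quick usage sketch (not part of the commit) of how these hooks might be attached to a pipeline's transformer. Only `apply_first_block_cache` and `FirstBlockCacheConfig` come from this diff; the pipeline, checkpoint, prompt, and threshold value are illustrative assumptions.

```python
import torch
from diffusers import FluxPipeline
from diffusers.hooks.first_block_cache import FirstBlockCacheConfig, apply_first_block_cache

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")

# Hook the denoiser's transformer block lists; a larger threshold skips more steps.
apply_first_block_cache(pipe.transformer, FirstBlockCacheConfig(threshold=0.05))

image = pipe("A cat holding a sign that says hello world", num_inference_steps=28).images[0]
image.save("output.png")
```

Raising the threshold trades quality for speed: more denoising steps reuse the cached tail-block residuals instead of running every block.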
Changes to src/diffusers/hooks/pyramid_attention_broadcast.py (the attention-class and block-identifier constants move into the shared `_common` module):

```diff
@@ -20,19 +20,18 @@ import torch
 
 from ..models.attention_processor import Attention, MochiAttention
 from ..utils import logging
+from ._common import (
+    _ATTENTION_CLASSES,
+    _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
+    _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
+    _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
+)
 from .hooks import HookRegistry, ModelHook
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-_ATTENTION_CLASSES = (Attention, MochiAttention)
-
-_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks")
-_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
-_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")
-
-
 @dataclass
 class PyramidAttentionBroadcastConfig:
     r"""
@@ -76,9 +75,9 @@ class PyramidAttentionBroadcastConfig:
     temporal_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
     cross_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
 
-    spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
-    temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
-    cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_ATTENTION_BLOCK_IDENTIFIERS
+    spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS
+    temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS
+    cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS
 
     current_timestep_callback: Callable[[], int] = None
```
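A small sketch (assumed import path, matching the hunk context above) showing that the Pyramid Attention Broadcast defaults now come from the shared constants rather than module-local copies:

```python
# Assumes the config lives in diffusers.hooks.pyramid_attention_broadcast,
# as the hunk context suggests; the callback value is illustrative.
from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastConfig

config = PyramidAttentionBroadcastConfig(current_timestep_callback=lambda: 500)
print(config.spatial_attention_block_identifiers)
# ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
```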
src/diffusers/hooks/utils.py (new file, 59 lines)

@@ -0,0 +1,59 @@
```python
import ast
import inspect
import textwrap
from typing import List


def _extract_return_information(func) -> List[str]:
    """Extracts return variable names in order from a function."""
    try:
        source = inspect.getsource(func)
        source = textwrap.dedent(source)  # Modify indentation to make parsing compatible
    except (OSError, TypeError):
        try:
            source_file = inspect.getfile(func)
            with open(source_file, "r", encoding="utf-8") as f:
                source = f.read()

            # Extract function definition manually
            source_lines = source.splitlines()
            func_name = func.__name__
            start_line = None
            indent_level = None
            extracted_lines = []

            for i, line in enumerate(source_lines):
                stripped = line.strip()
                if stripped.startswith(f"def {func_name}("):
                    start_line = i
                    indent_level = len(line) - len(line.lstrip())
                    extracted_lines.append(line)
                    continue

                if start_line is not None:
                    # Stop when indentation level decreases (end of function)
                    current_indent = len(line) - len(line.lstrip())
                    if current_indent <= indent_level and line.strip():
                        break
                    extracted_lines.append(line)

            source = "\n".join(extracted_lines)
        except Exception as e:
            raise RuntimeError(f"Failed to retrieve function source: {e}")

    # Parse source code using AST
    tree = ast.parse(source)
    return_vars = []

    class ReturnVisitor(ast.NodeVisitor):
        def visit_Return(self, node):
            if isinstance(node.value, ast.Tuple):
                # Multiple return values
                return_vars.extend(var.id for var in node.value.elts if isinstance(var, ast.Name))
            elif isinstance(node.value, ast.Name):
                # Single return value
                return_vars.append(node.value.id)

    visitor = ReturnVisitor()
    visitor.visit(tree)
    return return_vars
```
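A minimal sketch (not in the diff) of what `_extract_return_information` recovers for a typical block-style `forward`; the toy function below is hypothetical and must live in a file on disk, since `inspect.getsource` cannot read definitions typed into a REPL:

```python
from diffusers.hooks.utils import _extract_return_information


# Hypothetical toy function, used only to illustrate the helper's output.
def forward(hidden_states, encoder_hidden_states, temb):
    hidden_states = hidden_states + temb
    return hidden_states, encoder_hidden_states


print(_extract_return_information(forward))
# ['hidden_states', 'encoder_hidden_states']
```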
Changes to src/diffusers/models/transformers/transformer_flux.py (`FluxSingleTransformerBlock` now accepts `encoder_hidden_states` separately, and both block types return `(hidden_states, encoder_hidden_states)`):

```diff
@@ -87,10 +87,13 @@ class FluxSingleTransformerBlock(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
     ) -> torch.Tensor:
+        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
         residual = hidden_states
         norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
         mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
@@ -108,7 +111,10 @@ class FluxSingleTransformerBlock(nn.Module):
         if hidden_states.dtype == torch.float16:
             hidden_states = hidden_states.clip(-65504, 65504)
 
-        return hidden_states
+        encoder_hidden_states, hidden_states = hidden_states.split(
+            [encoder_hidden_states.size(1), hidden_states.size(1) - encoder_hidden_states.size(1)], dim=1
+        )
+        return hidden_states, encoder_hidden_states
 
 
 @maybe_allow_in_graph
@@ -224,7 +230,7 @@ class FluxTransformerBlock(nn.Module):
         if encoder_hidden_states.dtype == torch.float16:
             encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
 
-        return encoder_hidden_states, hidden_states
+        return hidden_states, encoder_hidden_states
 
 
 class FluxTransformer2DModel(
@@ -517,7 +523,7 @@ class FluxTransformer2DModel(
 
         for index_block, block in enumerate(self.transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
+                hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
                     block,
                     hidden_states,
                     encoder_hidden_states,
@@ -526,7 +532,7 @@ class FluxTransformer2DModel(
                 )
             else:
-                encoder_hidden_states, hidden_states = block(
+                hidden_states, encoder_hidden_states = block(
                     hidden_states=hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
                     temb=temb,
@@ -545,20 +551,21 @@ class FluxTransformer2DModel(
                 )
             else:
                 hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
-        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
 
         for index_block, block in enumerate(self.single_transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                hidden_states = self._gradient_checkpointing_func(
+                hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
                     block,
                     hidden_states,
+                    encoder_hidden_states,
                     temb,
                     image_rotary_emb,
                 )
             else:
-                hidden_states = block(
+                hidden_states, encoder_hidden_states = block(
                     hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
                     joint_attention_kwargs=joint_attention_kwargs,
@@ -568,12 +575,7 @@ class FluxTransformer2DModel(
         if controlnet_single_block_samples is not None:
             interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
             interval_control = int(np.ceil(interval_control))
-            hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
-                hidden_states[:, encoder_hidden_states.shape[1] :, ...]
-                + controlnet_single_block_samples[index_block // interval_control]
-            )
+            hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control]
 
-        hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
 
         hidden_states = self.norm_out(hidden_states, temb)
         output = self.proj_out(hidden_states)
```
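To make the new single-block contract concrete, here is a standalone sketch (illustrative shapes, not from the diff) of the concatenate/split round-trip the updated `FluxSingleTransformerBlock.forward` performs so it can accept and return the text and image streams separately:

```python
import torch

encoder_hidden_states = torch.randn(1, 512, 3072)   # text tokens (illustrative shape)
hidden_states = torch.randn(1, 1024, 3072)          # image tokens (illustrative shape)

# Inside forward: both streams are concatenated and processed jointly...
combined = torch.cat([encoder_hidden_states, hidden_states], dim=1)

# ...then split back apart so the block can return (hidden_states, encoder_hidden_states),
# mirroring FluxTransformerBlock's return order after this change.
encoder_hidden_states, hidden_states = combined.split(
    [encoder_hidden_states.size(1), combined.size(1) - encoder_hidden_states.size(1)], dim=1
)
print(hidden_states.shape, encoder_hidden_states.shape)
# torch.Size([1, 1024, 3072]) torch.Size([1, 512, 3072])
```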