Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-07 04:54:47 +08:00)

Compare commits: qwen-image ... integratio (1 commit)

Commit 98771d3611
src/diffusers/hooks/_common.py (new file, 30 lines)
@@ -0,0 +1,30 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ..models.attention_processor import Attention, MochiAttention


_ATTENTION_CLASSES = (Attention, MochiAttention)

_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers")

_ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple(
    {
        *_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
        *_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
        *_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
    }
)
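Note (not part of the commit): a minimal sketch of how these identifier tuples are meant to be used: scan a model's immediate children for ModuleList attributes whose names match a known block-container name. The helper name is illustrative and assumes this branch is installed.

import torch
from diffusers.hooks._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS

def find_block_lists(model: torch.nn.Module) -> dict:
    # Map container names (e.g. "transformer_blocks") to their ModuleLists.
    return {
        name: child
        for name, child in model.named_children()
        if name in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS and isinstance(child, torch.nn.ModuleList)
    }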
src/diffusers/hooks/first_block_cache.py (new file, 262 lines)
@@ -0,0 +1,262 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from dataclasses import dataclass
from typing import Tuple, Union

import torch

from ..utils import get_logger
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
from .hooks import HookRegistry, ModelHook
from .utils import _extract_return_information


logger = get_logger(__name__)  # pylint: disable=invalid-name


_FBC_LEADER_BLOCK_HOOK = "fbc_leader_block_hook"
_FBC_BLOCK_HOOK = "fbc_block_hook"


@dataclass
class FirstBlockCacheConfig:
    r"""
    Configuration for [First Block
    Cache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching).

    Args:
        threshold (`float`, defaults to `0.05`):
            The threshold used to determine whether a forward pass through all layers of the model is required. A
            higher threshold usually results in fewer forward passes and faster inference, but may lead to poorer
            generation quality. A lower threshold may not result in a significant speedup. The threshold is
            compared against the absmean difference of the residuals between the current and cached outputs of the
            first transformer block. If the difference is below the threshold, the forward pass is skipped.
    """

    threshold: float = 0.05
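Note (not part of the commit): a hedged usage sketch for enabling First Block Cache on a Flux denoiser, assuming this branch is installed and the (gated) checkpoint is accessible. apply_first_block_cache is defined later in this file.

import torch
from diffusers import FluxPipeline
from diffusers.hooks.first_block_cache import FirstBlockCacheConfig, apply_first_block_cache

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
# Hook the transformer so that, whenever the first block's residual barely changes
# between steps, the remaining blocks are skipped and cached residuals are reused.
apply_first_block_cache(pipe.transformer, FirstBlockCacheConfig(threshold=0.05))
image = pipe("a photo of a cat", num_inference_steps=28).images[0]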
class FBCSharedBlockState:
    def __init__(self) -> None:
        self.head_block_output: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
        self.head_block_residual: torch.Tensor = None
        self.tail_block_residuals: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
        self.should_compute: bool = True

    def reset(self):
        self.tail_block_residuals = None
        self.should_compute = True

    def __repr__(self):
        return f"FBCSharedBlockState(should_compute={self.should_compute})"
class FBCHeadBlockHook(ModelHook):
    _is_stateful = True

    def __init__(self, shared_state: FBCSharedBlockState, threshold: float):
        self.shared_state = shared_state
        self.threshold = threshold

    def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        inputs = inspect.signature(module.__class__.forward)
        inputs_index_to_str = dict(enumerate(inputs.parameters.keys()))
        inputs_str_to_index = {v: k for k, v in inputs_index_to_str.items()}

        try:
            outputs = _extract_return_information(module.__class__.forward)
            outputs_index_to_str = dict(enumerate(outputs))
            outputs_str_to_index = {v: k for k, v in outputs_index_to_str.items()}
        except RuntimeError:
            logger.error(f"Failed to extract return information for {module.__class__}")
            raise NotImplementedError(
                f"Module {module.__class__} is not supported with FirstBlockCache. Please open an issue at "
                f"https://github.com/huggingface/diffusers to notify us about the error with a minimal example "
                f"in order for us to add support for this module."
            )

        self._inputs_index_to_str = inputs_index_to_str
        self._inputs_str_to_index = inputs_str_to_index
        self._outputs_index_to_str = outputs_index_to_str
        self._outputs_str_to_index = outputs_str_to_index
        return module

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        hs_input_idx = self._inputs_str_to_index.get("hidden_states")
        ehs_input_idx = self._inputs_str_to_index.get("encoder_hidden_states", None)
        original_hs = kwargs.get("hidden_states", None)
        original_ehs = kwargs.get("encoder_hidden_states", None)
        original_hs = original_hs if original_hs is not None else args[hs_input_idx]
        if ehs_input_idx is not None:
            original_ehs = original_ehs if original_ehs is not None else args[ehs_input_idx]

        hs_output_idx = self._outputs_str_to_index.get("hidden_states")
        ehs_output_idx = self._outputs_str_to_index.get("encoder_hidden_states", None)
        assert (ehs_input_idx is None) == (ehs_output_idx is None)

        output = self.fn_ref.original_forward(*args, **kwargs)

        hs_residual = None
        if isinstance(output, tuple):
            hs_residual = output[hs_output_idx] - original_hs
        else:
            hs_residual = output - original_hs

        should_compute = self._should_compute_remaining_blocks(hs_residual)
        self.shared_state.should_compute = should_compute

        hs, ehs = None, None
        if not should_compute:
            # Apply caching
            logger.info("Skipping forward pass through remaining blocks")
            hs = self.shared_state.tail_block_residuals[0] + output[hs_output_idx]
            if ehs_output_idx is not None:
                ehs = self.shared_state.tail_block_residuals[1] + output[ehs_output_idx]

            if isinstance(output, tuple):
                return_output = [None] * len(output)
                return_output[hs_output_idx] = hs
                return_output[ehs_output_idx] = ehs
                return_output = tuple(return_output)
            else:
                return_output = hs
            return return_output
        else:
            logger.info("Computing forward pass through remaining blocks")
            if isinstance(output, tuple):
                head_block_output = [None] * len(output)
                head_block_output[0] = output[hs_output_idx]
                head_block_output[1] = output[ehs_output_idx]
            else:
                head_block_output = output
            self.shared_state.head_block_output = head_block_output
            self.shared_state.head_block_residual = hs_residual
            return output

    def reset_state(self, module):
        self.shared_state.reset()
        return module

    def _should_compute_remaining_blocks(self, hs_residual: torch.Tensor) -> bool:
        if self.shared_state.head_block_residual is None:
            return True
        prev_hs_residual = self.shared_state.head_block_residual
        hs_absmean = (hs_residual - prev_hs_residual).abs().mean()
        prev_hs_mean = prev_hs_residual.abs().mean()
        diff = (hs_absmean / prev_hs_mean).item()
        logger.info(f"Diff: {diff}, Threshold: {self.threshold}")
        return diff > self.threshold
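Note (not part of the commit): a small numeric sketch of the decision rule in _should_compute_remaining_blocks. The diff is the mean absolute change between consecutive first-block residuals, normalized by the previous residual's mean absolute value; only when it exceeds the threshold are the remaining blocks recomputed.

import torch

prev_residual = torch.full((4,), 1.00)   # cached first-block residual from the previous step
curr_residual = torch.full((4,), 1.03)   # first-block residual at the current step

diff = ((curr_residual - prev_residual).abs().mean() / prev_residual.abs().mean()).item()
print(diff)          # ~0.03
print(diff > 0.05)   # False, so the remaining blocks are skipped at the default threshold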
class FBCBlockHook(ModelHook):
    def __init__(self, shared_state: FBCSharedBlockState, is_tail: bool = False):
        super().__init__()
        self.shared_state = shared_state
        self.is_tail = is_tail

    def initialize_hook(self, module):
        inputs = inspect.signature(module.__class__.forward)
        inputs_index_to_str = dict(enumerate(inputs.parameters.keys()))
        inputs_str_to_index = {v: k for k, v in inputs_index_to_str.items()}

        try:
            outputs = _extract_return_information(module.__class__.forward)
            outputs_index_to_str = dict(enumerate(outputs))
            outputs_str_to_index = {v: k for k, v in outputs_index_to_str.items()}
        except RuntimeError:
            logger.error(f"Failed to extract return information for {module.__class__}")
            raise NotImplementedError(
                f"Module {module.__class__} is not supported with FirstBlockCache. Please open an issue at "
                f"https://github.com/huggingface/diffusers to notify us about the error with a minimal example "
                f"in order for us to add support for this module."
            )

        self._inputs_index_to_str = inputs_index_to_str
        self._inputs_str_to_index = inputs_str_to_index
        self._outputs_index_to_str = outputs_index_to_str
        self._outputs_str_to_index = outputs_str_to_index
        return module

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        hs_input_idx = self._inputs_str_to_index.get("hidden_states")
        ehs_input_idx = self._inputs_str_to_index.get("encoder_hidden_states", None)
        original_hs = kwargs.get("hidden_states", None)
        original_ehs = kwargs.get("encoder_hidden_states", None)
        original_hs = original_hs if original_hs is not None else args[hs_input_idx]
        if ehs_input_idx is not None:
            original_ehs = original_ehs if original_ehs is not None else args[ehs_input_idx]

        hs_output_idx = self._outputs_str_to_index.get("hidden_states")
        ehs_output_idx = self._outputs_str_to_index.get("encoder_hidden_states", None)
        assert (ehs_input_idx is None) == (ehs_output_idx is None)

        if self.shared_state.should_compute:
            output = self.fn_ref.original_forward(*args, **kwargs)
            if self.is_tail:
                hs_residual, ehs_residual = None, None
                if isinstance(output, tuple):
                    hs_residual = output[hs_output_idx] - self.shared_state.head_block_output[0]
                    ehs_residual = output[ehs_output_idx] - self.shared_state.head_block_output[1]
                else:
                    hs_residual = output - self.shared_state.head_block_output
                self.shared_state.tail_block_residuals = (hs_residual, ehs_residual)
            return output

        output_count = len(self._outputs_index_to_str.keys())
        return_output = [None] * output_count if output_count > 1 else original_hs
        if output_count == 1:
            return_output = original_hs
        else:
            return_output[hs_output_idx] = original_hs
            return_output[ehs_output_idx] = original_ehs
        return return_output
def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConfig) -> None:
    shared_state = FBCSharedBlockState()
    remaining_blocks = []

    for name, submodule in module.named_children():
        if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList):
            continue
        for block in submodule:
            remaining_blocks.append((name, block))

    head_block_name, head_block = remaining_blocks.pop(0)
    tail_block_name, tail_block = remaining_blocks.pop(-1)

    logger.debug(f"Apply FBCHeadBlockHook to '{head_block_name}'")
    apply_fbc_head_block_hook(head_block, shared_state, config.threshold)

    for name, block in remaining_blocks:
        logger.debug(f"Apply FBCBlockHook to '{name}'")
        apply_fbc_block_hook(block, shared_state)

    logger.debug(f"Apply FBCBlockHook to tail block '{tail_block_name}'")
    apply_fbc_block_hook(tail_block, shared_state, is_tail=True)


def apply_fbc_head_block_hook(block: torch.nn.Module, state: FBCSharedBlockState, threshold: float) -> None:
    registry = HookRegistry.check_if_exists_or_initialize(block)
    hook = FBCHeadBlockHook(state, threshold)
    registry.register_hook(hook, _FBC_LEADER_BLOCK_HOOK)


def apply_fbc_block_hook(block: torch.nn.Module, state: FBCSharedBlockState, is_tail: bool = False) -> None:
    registry = HookRegistry.check_if_exists_or_initialize(block)
    hook = FBCBlockHook(state, is_tail)
    registry.register_hook(hook, _FBC_BLOCK_HOOK)
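Note (not part of the commit): a toy sketch of what apply_first_block_cache does to a module, assuming this branch is installed. The model and block names are illustrative; "transformer_blocks" is used so the ModuleList is discovered via _ALL_TRANSFORMER_BLOCK_IDENTIFIERS.

import torch
from diffusers.hooks.first_block_cache import FirstBlockCacheConfig, apply_first_block_cache

class ToyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = hidden_states + self.proj(hidden_states)
        return hidden_states

class ToyTransformer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.transformer_blocks = torch.nn.ModuleList(ToyBlock() for _ in range(4))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for block in self.transformer_blocks:
            hidden_states = block(hidden_states)
        return hidden_states

model = ToyTransformer()
apply_first_block_cache(model, FirstBlockCacheConfig(threshold=0.1))
# Block 0 now carries FBCHeadBlockHook; blocks 1 and 2 carry FBCBlockHook;
# block 3 carries FBCBlockHook with is_tail=True.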
src/diffusers/hooks/pyramid_attention_broadcast.py
@@ -20,19 +20,18 @@ import torch
 
 from ..models.attention_processor import Attention, MochiAttention
 from ..utils import logging
+from ._common import (
+    _ATTENTION_CLASSES,
+    _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
+    _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
+    _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
+)
 from .hooks import HookRegistry, ModelHook
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-_ATTENTION_CLASSES = (Attention, MochiAttention)
-
-_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks")
-_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
-_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")
-
-
 @dataclass
 class PyramidAttentionBroadcastConfig:
     r"""
@@ -76,9 +75,9 @@ class PyramidAttentionBroadcastConfig:
     temporal_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
     cross_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
 
-    spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
-    temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
-    cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_ATTENTION_BLOCK_IDENTIFIERS
+    spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS
+    temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS
+    cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS
 
     current_timestep_callback: Callable[[], int] = None
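Note (not part of the commit): a hedged sketch of constructing the config with the renamed defaults that now come from hooks/_common.py. The timestep callback is illustrative; in practice it is wired to the pipeline's current timestep.

from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastConfig

config = PyramidAttentionBroadcastConfig(
    spatial_attention_timestep_skip_range=(100, 800),
    cross_attention_timestep_skip_range=(100, 800),
    current_timestep_callback=lambda: 500,  # illustrative stand-in
)
print(config.spatial_attention_block_identifiers)  # ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")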
src/diffusers/hooks/utils.py (new file, 59 lines)
@@ -0,0 +1,59 @@
import ast
import inspect
import textwrap
from typing import List


def _extract_return_information(func) -> List[str]:
    """Extracts return variable names in order from a function."""
    try:
        source = inspect.getsource(func)
        source = textwrap.dedent(source)  # Modify indentation to make parsing compatible
    except (OSError, TypeError):
        try:
            source_file = inspect.getfile(func)
            with open(source_file, "r", encoding="utf-8") as f:
                source = f.read()

            # Extract function definition manually
            source_lines = source.splitlines()
            func_name = func.__name__
            start_line = None
            indent_level = None
            extracted_lines = []

            for i, line in enumerate(source_lines):
                stripped = line.strip()
                if stripped.startswith(f"def {func_name}("):
                    start_line = i
                    indent_level = len(line) - len(line.lstrip())
                    extracted_lines.append(line)
                    continue

                if start_line is not None:
                    # Stop when indentation level decreases (end of function)
                    current_indent = len(line) - len(line.lstrip())
                    if current_indent <= indent_level and line.strip():
                        break
                    extracted_lines.append(line)

            source = "\n".join(extracted_lines)
        except Exception as e:
            raise RuntimeError(f"Failed to retrieve function source: {e}")

    # Parse source code using AST
    tree = ast.parse(source)
    return_vars = []

    class ReturnVisitor(ast.NodeVisitor):
        def visit_Return(self, node):
            if isinstance(node.value, ast.Tuple):
                # Multiple return values
                return_vars.extend(var.id for var in node.value.elts if isinstance(var, ast.Name))
            elif isinstance(node.value, ast.Name):
                # Single return value
                return_vars.append(node.value.id)

    visitor = ReturnVisitor()
    visitor.visit(tree)
    return return_vars
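Note (not part of the commit): a quick sketch of what the helper returns for a typical block forward, assuming this branch is installed.

from diffusers.hooks.utils import _extract_return_information

def forward(self, hidden_states, encoder_hidden_states, temb):
    # ...block computation elided...
    return hidden_states, encoder_hidden_states

print(_extract_return_information(forward))  # ['hidden_states', 'encoder_hidden_states']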
src/diffusers/models/transformers/transformer_flux.py
@@ -87,10 +87,13 @@ class FluxSingleTransformerBlock(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
     ) -> torch.Tensor:
+        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
         residual = hidden_states
         norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
         mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
@@ -108,7 +111,10 @@ class FluxSingleTransformerBlock(nn.Module):
         if hidden_states.dtype == torch.float16:
             hidden_states = hidden_states.clip(-65504, 65504)
 
-        return hidden_states
+        encoder_hidden_states, hidden_states = hidden_states.split(
+            [encoder_hidden_states.size(1), hidden_states.size(1) - encoder_hidden_states.size(1)], dim=1
+        )
+        return hidden_states, encoder_hidden_states
 
 
 @maybe_allow_in_graph
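Note (not part of the commit): a small sketch of the concatenate/split round trip the single-stream block now performs. The text and image streams are fused for the block computation, then split back so the block can return both streams separately, as the dual-stream blocks now do.

import torch

batch, text_len, img_len, dim = 1, 4, 6, 8
encoder_hidden_states = torch.randn(batch, text_len, dim)
hidden_states = torch.randn(batch, img_len, dim)

combined = torch.cat([encoder_hidden_states, hidden_states], dim=1)
# ...block computation would operate on `combined` here...
encoder_out, hidden_out = combined.split([text_len, combined.size(1) - text_len], dim=1)

assert torch.equal(encoder_out, encoder_hidden_states)
assert torch.equal(hidden_out, hidden_states)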
@@ -224,7 +230,7 @@ class FluxTransformerBlock(nn.Module):
         if encoder_hidden_states.dtype == torch.float16:
             encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
 
-        return encoder_hidden_states, hidden_states
+        return hidden_states, encoder_hidden_states
 
 
 class FluxTransformer2DModel(
@@ -517,7 +523,7 @@ class FluxTransformer2DModel(
 
         for index_block, block in enumerate(self.transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
+                hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
                     block,
                     hidden_states,
                     encoder_hidden_states,
@@ -526,7 +532,7 @@ class FluxTransformer2DModel(
                 )
 
             else:
-                encoder_hidden_states, hidden_states = block(
+                hidden_states, encoder_hidden_states = block(
                     hidden_states=hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
                     temb=temb,
@@ -545,20 +551,21 @@ class FluxTransformer2DModel(
                     )
                 else:
                     hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
-        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
 
         for index_block, block in enumerate(self.single_transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                hidden_states = self._gradient_checkpointing_func(
+                hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
                     block,
                     hidden_states,
+                    encoder_hidden_states,
                     temb,
                     image_rotary_emb,
                 )
 
             else:
-                hidden_states = block(
+                hidden_states, encoder_hidden_states = block(
                     hidden_states=hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
                     joint_attention_kwargs=joint_attention_kwargs,
@@ -568,12 +575,7 @@ class FluxTransformer2DModel(
             if controlnet_single_block_samples is not None:
                 interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
                 interval_control = int(np.ceil(interval_control))
-                hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
-                    hidden_states[:, encoder_hidden_states.shape[1] :, ...]
-                    + controlnet_single_block_samples[index_block // interval_control]
-                )
+                hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control]
 
-        hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
-
         hidden_states = self.norm_out(hidden_states, temb)
         output = self.proj_out(hidden_states)
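Note (not part of the commit): a sketch of why the simplified controlnet addition is equivalent. Previously the text and image tokens shared one concatenated sequence, so the controlnet residual had to be added to the slice past the text tokens; now hidden_states holds only the image tokens, so the residual is added directly.

import torch

batch, text_len, img_len, dim = 1, 4, 6, 8
encoder_hidden_states = torch.randn(batch, text_len, dim)
hidden_states = torch.randn(batch, img_len, dim)
control = torch.randn(batch, img_len, dim)

# Old layout: one concatenated sequence, slice-add past the text tokens.
combined = torch.cat([encoder_hidden_states, hidden_states], dim=1)
combined[:, text_len:, :] = combined[:, text_len:, :] + control

# New layout: image tokens kept separate, direct add.
new_hidden_states = hidden_states + control

assert torch.allclose(combined[:, text_len:, :], new_hidden_states)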