Mirror of https://github.com/huggingface/diffusers.git
Sd35 controlnet (#10020)
* add model/pipeline
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
scripts/convert_sd3_controlnet_to_diffusers.py (new file, 185 lines)
@@ -0,0 +1,185 @@
"""
A script to convert Stable Diffusion 3.5 ControlNet checkpoints to the Diffusers format.

Example:
    Convert an SD3.5 ControlNet checkpoint to the Diffusers format using a local file:
    ```bash
    python scripts/convert_sd3_controlnet_to_diffusers.py \
        --checkpoint_path "path/to/local/sd3.5_large_controlnet_canny.safetensors" \
        --output_path "output/sd35-controlnet-canny" \
        --dtype "fp16" # optional, defaults to fp32
    ```

    Or download and convert from a HuggingFace repository:
    ```bash
    python scripts/convert_sd3_controlnet_to_diffusers.py \
        --original_state_dict_repo_id "stabilityai/stable-diffusion-3.5-controlnets" \
        --filename "sd3.5_large_controlnet_canny.safetensors" \
        --output_path "/raid/yiyi/sd35-controlnet-canny-diffusers" \
        --dtype "fp32" # optional, defaults to fp32
    ```

Note:
    The script supports the following ControlNet types from SD3.5:
    - Canny edge detection
    - Depth estimation
    - Blur detection

    The checkpoint files can be downloaded from:
    https://huggingface.co/stabilityai/stable-diffusion-3.5-controlnets
"""

import argparse

import safetensors.torch
import torch
from huggingface_hub import hf_hub_download

from diffusers import SD3ControlNetModel


parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint_path", type=str, default=None, help="Path to local checkpoint file")
parser.add_argument(
    "--original_state_dict_repo_id", type=str, default=None, help="HuggingFace repo ID containing the checkpoint"
)
parser.add_argument("--filename", type=str, default=None, help="Filename of the checkpoint in the HF repo")
parser.add_argument("--output_path", type=str, required=True, help="Path to save the converted model")
parser.add_argument(
    "--dtype", type=str, default="fp32", help="Data type for the converted model (fp16, bf16, or fp32)"
)

args = parser.parse_args()


def load_original_checkpoint(args):
    if args.original_state_dict_repo_id is not None:
        if args.filename is None:
            raise ValueError("When using `original_state_dict_repo_id`, `filename` must also be specified")
        print(f"Downloading checkpoint from {args.original_state_dict_repo_id}/{args.filename}")
        ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename)
    elif args.checkpoint_path is not None:
        print(f"Loading checkpoint from local path: {args.checkpoint_path}")
        ckpt_path = args.checkpoint_path
    else:
        raise ValueError("Please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")

    original_state_dict = safetensors.torch.load_file(ckpt_path)
    return original_state_dict


def convert_sd3_controlnet_checkpoint_to_diffusers(original_state_dict):
    converted_state_dict = {}

    # Direct mappings for controlnet blocks
    for i in range(19):  # 19 controlnet blocks
        converted_state_dict[f"controlnet_blocks.{i}.weight"] = original_state_dict[f"controlnet_blocks.{i}.weight"]
        converted_state_dict[f"controlnet_blocks.{i}.bias"] = original_state_dict[f"controlnet_blocks.{i}.bias"]

    # Positional embeddings
    converted_state_dict["pos_embed_input.proj.weight"] = original_state_dict["pos_embed_input.proj.weight"]
    converted_state_dict["pos_embed_input.proj.bias"] = original_state_dict["pos_embed_input.proj.bias"]

    # Time and text embeddings
    time_text_mappings = {
        "time_text_embed.timestep_embedder.linear_1.weight": "time_text_embed.timestep_embedder.linear_1.weight",
        "time_text_embed.timestep_embedder.linear_1.bias": "time_text_embed.timestep_embedder.linear_1.bias",
        "time_text_embed.timestep_embedder.linear_2.weight": "time_text_embed.timestep_embedder.linear_2.weight",
        "time_text_embed.timestep_embedder.linear_2.bias": "time_text_embed.timestep_embedder.linear_2.bias",
        "time_text_embed.text_embedder.linear_1.weight": "time_text_embed.text_embedder.linear_1.weight",
        "time_text_embed.text_embedder.linear_1.bias": "time_text_embed.text_embedder.linear_1.bias",
        "time_text_embed.text_embedder.linear_2.weight": "time_text_embed.text_embedder.linear_2.weight",
        "time_text_embed.text_embedder.linear_2.bias": "time_text_embed.text_embedder.linear_2.bias",
    }

    for new_key, old_key in time_text_mappings.items():
        if old_key in original_state_dict:
            converted_state_dict[new_key] = original_state_dict[old_key]

    # Transformer blocks
    for i in range(19):
        # Split QKV into separate Q, K, V
        qkv_weight = original_state_dict[f"transformer_blocks.{i}.attn.qkv.weight"]
        qkv_bias = original_state_dict[f"transformer_blocks.{i}.attn.qkv.bias"]
        q, k, v = torch.chunk(qkv_weight, 3, dim=0)
        q_bias, k_bias, v_bias = torch.chunk(qkv_bias, 3, dim=0)

        block_mappings = {
            f"transformer_blocks.{i}.attn.to_q.weight": q,
            f"transformer_blocks.{i}.attn.to_q.bias": q_bias,
            f"transformer_blocks.{i}.attn.to_k.weight": k,
            f"transformer_blocks.{i}.attn.to_k.bias": k_bias,
            f"transformer_blocks.{i}.attn.to_v.weight": v,
            f"transformer_blocks.{i}.attn.to_v.bias": v_bias,
            # Output projections
            f"transformer_blocks.{i}.attn.to_out.0.weight": original_state_dict[
                f"transformer_blocks.{i}.attn.proj.weight"
            ],
            f"transformer_blocks.{i}.attn.to_out.0.bias": original_state_dict[
                f"transformer_blocks.{i}.attn.proj.bias"
            ],
            # Feed forward
            f"transformer_blocks.{i}.ff.net.0.proj.weight": original_state_dict[
                f"transformer_blocks.{i}.mlp.fc1.weight"
            ],
            f"transformer_blocks.{i}.ff.net.0.proj.bias": original_state_dict[f"transformer_blocks.{i}.mlp.fc1.bias"],
            f"transformer_blocks.{i}.ff.net.2.weight": original_state_dict[f"transformer_blocks.{i}.mlp.fc2.weight"],
            f"transformer_blocks.{i}.ff.net.2.bias": original_state_dict[f"transformer_blocks.{i}.mlp.fc2.bias"],
            # Norms
            f"transformer_blocks.{i}.norm1.linear.weight": original_state_dict[
                f"transformer_blocks.{i}.adaLN_modulation.1.weight"
            ],
            f"transformer_blocks.{i}.norm1.linear.bias": original_state_dict[
                f"transformer_blocks.{i}.adaLN_modulation.1.bias"
            ],
        }
        converted_state_dict.update(block_mappings)

    return converted_state_dict


def main(args):
    original_ckpt = load_original_checkpoint(args)
    original_dtype = next(iter(original_ckpt.values())).dtype

    # Initialize dtype with fp32 as default
    if args.dtype == "fp16":
        dtype = torch.float16
    elif args.dtype == "bf16":
        dtype = torch.bfloat16
    elif args.dtype == "fp32":
        dtype = torch.float32
    else:
        raise ValueError(f"Unsupported dtype: {args.dtype}. Must be one of: fp16, bf16, fp32")

    if dtype != original_dtype:
        print(
            f"Converting checkpoint from {original_dtype} to {dtype}. This can lead to unexpected results, proceed with caution."
        )

    converted_controlnet_state_dict = convert_sd3_controlnet_checkpoint_to_diffusers(original_ckpt)

    controlnet = SD3ControlNetModel(
        patch_size=2,
        in_channels=16,
        num_layers=19,
        attention_head_dim=64,
        num_attention_heads=38,
        joint_attention_dim=None,
        caption_projection_dim=2048,
        pooled_projection_dim=2048,
        out_channels=16,
        pos_embed_max_size=None,
        pos_embed_type=None,
        use_pos_embed=False,
        force_zeros_for_pooled_projection=False,
    )

    controlnet.load_state_dict(converted_controlnet_state_dict, strict=True)

    print(f"Saving SD3 ControlNet in Diffusers format in {args.output_path}.")
    controlnet.to(dtype).save_pretrained(args.output_path)


if __name__ == "__main__":
    main(args)
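For reference (not part of this commit), below is a minimal usage sketch of a converted checkpoint. It assumes the model was saved to `output/sd35-controlnet-canny` as in the first example above, that the `stabilityai/stable-diffusion-3.5-large` base weights are available, and that the control-image URL is only a placeholder for a preprocessed Canny edge map.

```python
# Illustrative sketch only; paths, repo ids, and the image URL are assumptions, not part of the PR.
import torch
from diffusers import SD3ControlNetModel, StableDiffusion3ControlNetPipeline
from diffusers.utils import load_image

controlnet = SD3ControlNetModel.from_pretrained("output/sd35-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

# Any preprocessed Canny edge map at the target resolution works here; this URL is a placeholder.
control_image = load_image("https://example.com/canny_edges.png")
image = pipe(
    prompt="a photo of a cat",
    control_image=control_image,
    controlnet_conditioning_scale=1.0,
    num_inference_steps=28,
    guidance_scale=5.0,
).images[0]
image.save("cat_canny.png")
```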
@@ -27,6 +27,7 @@ from ..attention_processor import Attention, AttentionProcessor, FusedJointAttnP
 from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
+from ..transformers.transformer_sd3 import SD3SingleTransformerBlock
 from .controlnet import BaseOutput, zero_module

@@ -58,40 +59,60 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
         extra_conditioning_channels: int = 0,
         dual_attention_layers: Tuple[int, ...] = (),
         qk_norm: Optional[str] = None,
+        pos_embed_type: Optional[str] = "sincos",
+        use_pos_embed: bool = True,
+        force_zeros_for_pooled_projection: bool = True,
     ):
         super().__init__()
         default_out_channels = in_channels
         self.out_channels = out_channels if out_channels is not None else default_out_channels
         self.inner_dim = num_attention_heads * attention_head_dim

-        self.pos_embed = PatchEmbed(
-            height=sample_size,
-            width=sample_size,
-            patch_size=patch_size,
-            in_channels=in_channels,
-            embed_dim=self.inner_dim,
-            pos_embed_max_size=pos_embed_max_size,
-        )
+        if use_pos_embed:
+            self.pos_embed = PatchEmbed(
+                height=sample_size,
+                width=sample_size,
+                patch_size=patch_size,
+                in_channels=in_channels,
+                embed_dim=self.inner_dim,
+                pos_embed_max_size=pos_embed_max_size,
+                pos_embed_type=pos_embed_type,
+            )
+        else:
+            self.pos_embed = None
         self.time_text_embed = CombinedTimestepTextProjEmbeddings(
             embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
         )
-        self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)
+        if joint_attention_dim is not None:
+            self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)

-        # `attention_head_dim` is doubled to account for the mixing.
-        # It needs to be crafted when we get the actual checkpoints.
-        self.transformer_blocks = nn.ModuleList(
-            [
-                JointTransformerBlock(
-                    dim=self.inner_dim,
-                    num_attention_heads=num_attention_heads,
-                    attention_head_dim=self.config.attention_head_dim,
-                    context_pre_only=False,
-                    qk_norm=qk_norm,
-                    use_dual_attention=True if i in dual_attention_layers else False,
-                )
-                for i in range(num_layers)
-            ]
-        )
+            # `attention_head_dim` is doubled to account for the mixing.
+            # It needs to be crafted when we get the actual checkpoints.
+            self.transformer_blocks = nn.ModuleList(
+                [
+                    JointTransformerBlock(
+                        dim=self.inner_dim,
+                        num_attention_heads=num_attention_heads,
+                        attention_head_dim=self.config.attention_head_dim,
+                        context_pre_only=False,
+                        qk_norm=qk_norm,
+                        use_dual_attention=True if i in dual_attention_layers else False,
+                    )
+                    for i in range(num_layers)
+                ]
+            )
+        else:
+            self.context_embedder = None
+            self.transformer_blocks = nn.ModuleList(
+                [
+                    SD3SingleTransformerBlock(
+                        dim=self.inner_dim,
+                        num_attention_heads=num_attention_heads,
+                        attention_head_dim=self.config.attention_head_dim,
+                    )
+                    for _ in range(num_layers)
+                ]
+            )

         # controlnet_blocks
         self.controlnet_blocks = nn.ModuleList([])
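Illustrative sketch (not part of this commit) of the two construction modes the updated `SD3ControlNetModel` supports; the layer counts are kept tiny for the example and all remaining config fields are assumed to keep their defaults.

```python
# Rough sketch under stated assumptions; real checkpoints use many more layers.
from diffusers import SD3ControlNetModel

# InstantX-style SD3 ControlNet: joint transformer blocks with a context embedder and its own pos_embed.
joint_variant = SD3ControlNetModel(num_layers=2)

# SD3.5 Large (8b) official ControlNet: single transformer blocks, no context embedder,
# and no internal pos_embed (the pipeline applies the transformer's pos_embed first).
single_variant = SD3ControlNetModel(
    num_layers=2,
    joint_attention_dim=None,
    pos_embed_type=None,
    use_pos_embed=False,
    force_zeros_for_pooled_projection=False,
)
```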
@@ -318,9 +339,27 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
                 "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
             )

-        hidden_states = self.pos_embed(hidden_states)  # takes care of adding positional embeddings too.
+        if self.pos_embed is not None and hidden_states.ndim != 4:
+            raise ValueError("hidden_states must be 4D when pos_embed is used")
+
+        # SD3.5 8b controlnet does not have a `pos_embed`,
+        # it uses the `pos_embed` from the transformer to process input before passing to controlnet
+        elif self.pos_embed is None and hidden_states.ndim != 3:
+            raise ValueError("hidden_states must be 3D when pos_embed is not used")
+
+        if self.context_embedder is not None and encoder_hidden_states is None:
+            raise ValueError("encoder_hidden_states must be provided when context_embedder is used")
+        # SD3.5 8b controlnet does not have a `context_embedder`, it does not use `encoder_hidden_states`
+        elif self.context_embedder is None and encoder_hidden_states is not None:
+            raise ValueError("encoder_hidden_states should not be provided when context_embedder is not used")
+
+        if self.pos_embed is not None:
+            hidden_states = self.pos_embed(hidden_states)  # takes care of adding positional embeddings too.
+
         temb = self.time_text_embed(timestep, pooled_projections)
-        encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+        if self.context_embedder is not None:
+            encoder_hidden_states = self.context_embedder(encoder_hidden_states)

         # add
         hidden_states = hidden_states + self.pos_embed_input(controlnet_cond)
@@ -349,9 +388,13 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
                 )

             else:
-                encoder_hidden_states, hidden_states = block(
-                    hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
-                )
+                if self.context_embedder is not None:
+                    encoder_hidden_states, hidden_states = block(
+                        hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
+                    )
+                else:
+                    # SD3.5 8b controlnet uses a single transformer block, which does not use `encoder_hidden_states`
+                    hidden_states = block(hidden_states, temb)

             block_res_samples = block_res_samples + (hidden_states,)

@@ -18,14 +18,21 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import numpy as np
 import torch
 import torch.nn as nn
+import torch.nn.functional as F

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...models.attention import JointTransformerBlock
-from ...models.attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
+from ...models.attention import FeedForward, JointTransformerBlock
+from ...models.attention_processor import (
+    Attention,
+    AttentionProcessor,
+    FusedJointAttnProcessor2_0,
+    JointAttnProcessor2_0,
+)
 from ...models.modeling_utils import ModelMixin
-from ...models.normalization import AdaLayerNormContinuous
+from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero
 from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput

@@ -33,6 +40,72 @@ from ..modeling_outputs import Transformer2DModelOutput
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


+@maybe_allow_in_graph
+class SD3SingleTransformerBlock(nn.Module):
+    r"""
+    A Single Transformer block as part of the MMDiT architecture, used in Stable Diffusion 3 ControlNet.
+
+    Reference: https://arxiv.org/abs/2403.03206
+
+    Parameters:
+        dim (`int`): The number of channels in the input and output.
+        num_attention_heads (`int`): The number of heads to use for multi-head attention.
+        attention_head_dim (`int`): The number of channels in each head.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+    ):
+        super().__init__()
+
+        self.norm1 = AdaLayerNormZero(dim)
+
+        if hasattr(F, "scaled_dot_product_attention"):
+            processor = JointAttnProcessor2_0()
+        else:
+            raise ValueError(
+                "The current PyTorch version does not support the `scaled_dot_product_attention` function."
+            )
+
+        self.attn = Attention(
+            query_dim=dim,
+            dim_head=attention_head_dim,
+            heads=num_attention_heads,
+            out_dim=dim,
+            bias=True,
+            processor=processor,
+            eps=1e-6,
+        )
+
+        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+    def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor):
+        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
+        # Attention.
+        attn_output = self.attn(
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=None,
+        )
+
+        # Process attention outputs for the `hidden_states`.
+        attn_output = gate_msa.unsqueeze(1) * attn_output
+        hidden_states = hidden_states + attn_output
+
+        norm_hidden_states = self.norm2(hidden_states)
+        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+        ff_output = self.ff(norm_hidden_states)
+        ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+        hidden_states = hidden_states + ff_output
+
+        return hidden_states
+
+
 class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
     """
     The Transformer model introduced in Stable Diffusion 3.
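Illustrative shape check (not part of this commit) for `SD3SingleTransformerBlock`; the dimensions below are made up for the example, and the import path assumes the block lives in `diffusers.models.transformers.transformer_sd3` as added here.

```python
# Rough sketch under stated assumptions.
import torch
from diffusers.models.transformers.transformer_sd3 import SD3SingleTransformerBlock

# Small, hypothetical dimensions; the SD3.5 Large 8b ControlNet uses
# dim = num_attention_heads * attention_head_dim = 38 * 64 = 2432.
block = SD3SingleTransformerBlock(dim=256, num_attention_heads=4, attention_head_dim=64)
hidden_states = torch.randn(2, 1024, 256)  # (batch, tokens, dim): latents already patchified upstream
temb = torch.randn(2, 256)                 # combined timestep + pooled-text embedding
out = block(hidden_states, temb)           # same shape as hidden_states
```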
@@ -858,6 +858,12 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
         height = height or self.default_sample_size * self.vae_scale_factor
         width = width or self.default_sample_size * self.vae_scale_factor

+        controlnet_config = (
+            self.controlnet.config
+            if isinstance(self.controlnet, SD3ControlNetModel)
+            else self.controlnet.nets[0].config
+        )
+
         # align format for control guidance
         if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
             control_guidance_start = len(control_guidance_end) * [control_guidance_start]
@@ -932,6 +938,11 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
             pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)

         # 3. Prepare control image
+        if controlnet_config.force_zeros_for_pooled_projection:
+            # instantx sd3 controlnet does not apply shift factor
+            vae_shift_factor = 0
+        else:
+            vae_shift_factor = self.vae.config.shift_factor
         if isinstance(self.controlnet, SD3ControlNetModel):
             control_image = self.prepare_image(
                 image=control_image,
@@ -947,8 +958,7 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
             height, width = control_image.shape[-2:]

             control_image = self.vae.encode(control_image).latent_dist.sample()
-            control_image = control_image * self.vae.config.scaling_factor
-
+            control_image = (control_image - vae_shift_factor) * self.vae.config.scaling_factor
         elif isinstance(self.controlnet, SD3MultiControlNetModel):
             control_images = []

@@ -966,7 +976,7 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
                 )

                 control_image_ = self.vae.encode(control_image_).latent_dist.sample()
-                control_image_ = control_image_ * self.vae.config.scaling_factor
+                control_image_ = (control_image_ - vae_shift_factor) * self.vae.config.scaling_factor

                 control_images.append(control_image_)

@@ -974,11 +984,6 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
         else:
             assert False

-        if controlnet_pooled_projections is None:
-            controlnet_pooled_projections = torch.zeros_like(pooled_prompt_embeds)
-        else:
-            controlnet_pooled_projections = controlnet_pooled_projections or pooled_prompt_embeds
-
         # 4. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
@@ -1006,6 +1011,18 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
             ]
             controlnet_keep.append(keeps[0] if isinstance(self.controlnet, SD3ControlNetModel) else keeps)

+        if controlnet_config.force_zeros_for_pooled_projection:
+            # instantx sd3 controlnet uses a zero pooled projection
+            controlnet_pooled_projections = torch.zeros_like(pooled_prompt_embeds)
+        else:
+            controlnet_pooled_projections = controlnet_pooled_projections or pooled_prompt_embeds
+
+        if controlnet_config.joint_attention_dim is not None:
+            controlnet_encoder_hidden_states = prompt_embeds
+        else:
+            # SD35 official 8b controlnet does not use encoder_hidden_states
+            controlnet_encoder_hidden_states = None
+
         # 7. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
@@ -1025,11 +1042,17 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
                     controlnet_cond_scale = controlnet_cond_scale[0]
                 cond_scale = controlnet_cond_scale * controlnet_keep[i]

+                if controlnet_config.use_pos_embed is False:
+                    # sd35 (official) 8b controlnet
+                    controlnet_model_input = self.transformer.pos_embed(latent_model_input)
+                else:
+                    controlnet_model_input = latent_model_input
+
                 # controlnet(s) inference
                 control_block_samples = self.controlnet(
-                    hidden_states=latent_model_input,
+                    hidden_states=controlnet_model_input,
                     timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
+                    encoder_hidden_states=controlnet_encoder_hidden_states,
                     pooled_projections=controlnet_pooled_projections,
                     joint_attention_kwargs=self.joint_attention_kwargs,
                     controlnet_cond=control_image,