Sd35 controlnet (#10020)

* add model/pipeline

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
YiYi Xu
2024-11-27 10:44:48 -10:00
committed by GitHub
parent 8d477daed5
commit 75bd1e83cb
4 changed files with 366 additions and 42 deletions

View File

@@ -0,0 +1,185 @@
"""
A script to convert Stable Diffusion 3.5 ControlNet checkpoints to the Diffusers format.
Example:
Convert a SD3.5 ControlNet checkpoint to Diffusers format using local file:
```bash
python scripts/convert_sd3_controlnet_to_diffusers.py \
--checkpoint_path "path/to/local/sd3.5_large_controlnet_canny.safetensors" \
--output_path "output/sd35-controlnet-canny" \
--dtype "fp16" # optional, defaults to fp32
```
Or download and convert from HuggingFace repository:
```bash
python scripts/convert_sd3_controlnet_to_diffusers.py \
--original_state_dict_repo_id "stabilityai/stable-diffusion-3.5-controlnets" \
--filename "sd3.5_large_controlnet_canny.safetensors" \
--output_path "/raid/yiyi/sd35-controlnet-canny-diffusers" \
--dtype "fp32" # optional, defaults to fp32
```
Note:
The script supports the following ControlNet types from SD3.5:
- Canny edge detection
- Depth estimation
- Blur detection
The checkpoint files can be downloaded from:
https://huggingface.co/stabilityai/stable-diffusion-3.5-controlnets
"""
import argparse
import safetensors.torch
import torch
from huggingface_hub import hf_hub_download
from diffusers import SD3ControlNetModel
parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint_path", type=str, default=None, help="Path to local checkpoint file")
parser.add_argument(
"--original_state_dict_repo_id", type=str, default=None, help="HuggingFace repo ID containing the checkpoint"
)
parser.add_argument("--filename", type=str, default=None, help="Filename of the checkpoint in the HF repo")
parser.add_argument("--output_path", type=str, required=True, help="Path to save the converted model")
parser.add_argument(
"--dtype", type=str, default="fp32", help="Data type for the converted model (fp16, bf16, or fp32)"
)
args = parser.parse_args()
def load_original_checkpoint(args):
if args.original_state_dict_repo_id is not None:
if args.filename is None:
raise ValueError("When using `original_state_dict_repo_id`, `filename` must also be specified")
print(f"Downloading checkpoint from {args.original_state_dict_repo_id}/{args.filename}")
ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename)
elif args.checkpoint_path is not None:
print(f"Loading checkpoint from local path: {args.checkpoint_path}")
ckpt_path = args.checkpoint_path
else:
raise ValueError("Please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")
original_state_dict = safetensors.torch.load_file(ckpt_path)
return original_state_dict
def convert_sd3_controlnet_checkpoint_to_diffusers(original_state_dict):
converted_state_dict = {}
# Direct mappings for controlnet blocks
for i in range(19): # 19 controlnet blocks
converted_state_dict[f"controlnet_blocks.{i}.weight"] = original_state_dict[f"controlnet_blocks.{i}.weight"]
converted_state_dict[f"controlnet_blocks.{i}.bias"] = original_state_dict[f"controlnet_blocks.{i}.bias"]
# Positional embeddings
converted_state_dict["pos_embed_input.proj.weight"] = original_state_dict["pos_embed_input.proj.weight"]
converted_state_dict["pos_embed_input.proj.bias"] = original_state_dict["pos_embed_input.proj.bias"]
# Time and text embeddings
time_text_mappings = {
"time_text_embed.timestep_embedder.linear_1.weight": "time_text_embed.timestep_embedder.linear_1.weight",
"time_text_embed.timestep_embedder.linear_1.bias": "time_text_embed.timestep_embedder.linear_1.bias",
"time_text_embed.timestep_embedder.linear_2.weight": "time_text_embed.timestep_embedder.linear_2.weight",
"time_text_embed.timestep_embedder.linear_2.bias": "time_text_embed.timestep_embedder.linear_2.bias",
"time_text_embed.text_embedder.linear_1.weight": "time_text_embed.text_embedder.linear_1.weight",
"time_text_embed.text_embedder.linear_1.bias": "time_text_embed.text_embedder.linear_1.bias",
"time_text_embed.text_embedder.linear_2.weight": "time_text_embed.text_embedder.linear_2.weight",
"time_text_embed.text_embedder.linear_2.bias": "time_text_embed.text_embedder.linear_2.bias",
}
for new_key, old_key in time_text_mappings.items():
if old_key in original_state_dict:
converted_state_dict[new_key] = original_state_dict[old_key]
# Transformer blocks
for i in range(19):
# Split QKV into separate Q, K, V
qkv_weight = original_state_dict[f"transformer_blocks.{i}.attn.qkv.weight"]
qkv_bias = original_state_dict[f"transformer_blocks.{i}.attn.qkv.bias"]
q, k, v = torch.chunk(qkv_weight, 3, dim=0)
q_bias, k_bias, v_bias = torch.chunk(qkv_bias, 3, dim=0)
block_mappings = {
f"transformer_blocks.{i}.attn.to_q.weight": q,
f"transformer_blocks.{i}.attn.to_q.bias": q_bias,
f"transformer_blocks.{i}.attn.to_k.weight": k,
f"transformer_blocks.{i}.attn.to_k.bias": k_bias,
f"transformer_blocks.{i}.attn.to_v.weight": v,
f"transformer_blocks.{i}.attn.to_v.bias": v_bias,
# Output projections
f"transformer_blocks.{i}.attn.to_out.0.weight": original_state_dict[
f"transformer_blocks.{i}.attn.proj.weight"
],
f"transformer_blocks.{i}.attn.to_out.0.bias": original_state_dict[
f"transformer_blocks.{i}.attn.proj.bias"
],
# Feed forward
f"transformer_blocks.{i}.ff.net.0.proj.weight": original_state_dict[
f"transformer_blocks.{i}.mlp.fc1.weight"
],
f"transformer_blocks.{i}.ff.net.0.proj.bias": original_state_dict[f"transformer_blocks.{i}.mlp.fc1.bias"],
f"transformer_blocks.{i}.ff.net.2.weight": original_state_dict[f"transformer_blocks.{i}.mlp.fc2.weight"],
f"transformer_blocks.{i}.ff.net.2.bias": original_state_dict[f"transformer_blocks.{i}.mlp.fc2.bias"],
# Norms
f"transformer_blocks.{i}.norm1.linear.weight": original_state_dict[
f"transformer_blocks.{i}.adaLN_modulation.1.weight"
],
f"transformer_blocks.{i}.norm1.linear.bias": original_state_dict[
f"transformer_blocks.{i}.adaLN_modulation.1.bias"
],
}
converted_state_dict.update(block_mappings)
return converted_state_dict
def main(args):
original_ckpt = load_original_checkpoint(args)
original_dtype = next(iter(original_ckpt.values())).dtype
# Initialize dtype with fp32 as default
if args.dtype == "fp16":
dtype = torch.float16
elif args.dtype == "bf16":
dtype = torch.bfloat16
elif args.dtype == "fp32":
dtype = torch.float32
else:
raise ValueError(f"Unsupported dtype: {args.dtype}. Must be one of: fp16, bf16, fp32")
if dtype != original_dtype:
print(
f"Converting checkpoint from {original_dtype} to {dtype}. This can lead to unexpected results, proceed with caution."
)
converted_controlnet_state_dict = convert_sd3_controlnet_checkpoint_to_diffusers(original_ckpt)
controlnet = SD3ControlNetModel(
patch_size=2,
in_channels=16,
num_layers=19,
attention_head_dim=64,
num_attention_heads=38,
joint_attention_dim=None,
caption_projection_dim=2048,
pooled_projection_dim=2048,
out_channels=16,
pos_embed_max_size=None,
pos_embed_type=None,
use_pos_embed=False,
force_zeros_for_pooled_projection=False,
)
controlnet.load_state_dict(converted_controlnet_state_dict, strict=True)
print(f"Saving SD3 ControlNet in Diffusers format in {args.output_path}.")
controlnet.to(dtype).save_pretrained(args.output_path)
if __name__ == "__main__":
main(args)

View File

@@ -27,6 +27,7 @@ from ..attention_processor import Attention, AttentionProcessor, FusedJointAttnP
from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..transformers.transformer_sd3 import SD3SingleTransformerBlock
from .controlnet import BaseOutput, zero_module
@@ -58,40 +59,60 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
extra_conditioning_channels: int = 0,
dual_attention_layers: Tuple[int, ...] = (),
qk_norm: Optional[str] = None,
pos_embed_type: Optional[str] = "sincos",
use_pos_embed: bool = True,
force_zeros_for_pooled_projection: bool = True,
):
super().__init__()
default_out_channels = in_channels
self.out_channels = out_channels if out_channels is not None else default_out_channels
self.inner_dim = num_attention_heads * attention_head_dim
self.pos_embed = PatchEmbed(
height=sample_size,
width=sample_size,
patch_size=patch_size,
in_channels=in_channels,
embed_dim=self.inner_dim,
pos_embed_max_size=pos_embed_max_size,
)
if use_pos_embed:
self.pos_embed = PatchEmbed(
height=sample_size,
width=sample_size,
patch_size=patch_size,
in_channels=in_channels,
embed_dim=self.inner_dim,
pos_embed_max_size=pos_embed_max_size,
pos_embed_type=pos_embed_type,
)
else:
self.pos_embed = None
self.time_text_embed = CombinedTimestepTextProjEmbeddings(
embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
)
self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)
if joint_attention_dim is not None:
self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)
# `attention_head_dim` is doubled to account for the mixing.
# It needs to crafted when we get the actual checkpoints.
self.transformer_blocks = nn.ModuleList(
[
JointTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
context_pre_only=False,
qk_norm=qk_norm,
use_dual_attention=True if i in dual_attention_layers else False,
)
for i in range(num_layers)
]
)
# `attention_head_dim` is doubled to account for the mixing.
# It needs to crafted when we get the actual checkpoints.
self.transformer_blocks = nn.ModuleList(
[
JointTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
context_pre_only=False,
qk_norm=qk_norm,
use_dual_attention=True if i in dual_attention_layers else False,
)
for i in range(num_layers)
]
)
else:
self.context_embedder = None
self.transformer_blocks = nn.ModuleList(
[
SD3SingleTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
)
for _ in range(num_layers)
]
)
# controlnet_blocks
self.controlnet_blocks = nn.ModuleList([])
@@ -318,9 +339,27 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too.
if self.pos_embed is not None and hidden_states.ndim != 4:
raise ValueError("hidden_states must be 4D when pos_embed is used")
# SD3.5 8b controlnet does not have a `pos_embed`,
# it use the `pos_embed` from the transformer to process input before passing to controlnet
elif self.pos_embed is None and hidden_states.ndim != 3:
raise ValueError("hidden_states must be 3D when pos_embed is not used")
if self.context_embedder is not None and encoder_hidden_states is None:
raise ValueError("encoder_hidden_states must be provided when context_embedder is used")
# SD3.5 8b controlnet does not have a `context_embedder`, it does not use `encoder_hidden_states`
elif self.context_embedder is None and encoder_hidden_states is not None:
raise ValueError("encoder_hidden_states should not be provided when context_embedder is not used")
if self.pos_embed is not None:
hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too.
temb = self.time_text_embed(timestep, pooled_projections)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
if self.context_embedder is not None:
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
# add
hidden_states = hidden_states + self.pos_embed_input(controlnet_cond)
@@ -349,9 +388,13 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
)
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
)
if self.context_embedder is not None:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
)
else:
# SD3.5 8b controlnet use single transformer block, which does not use `encoder_hidden_states`
hidden_states = block(hidden_states, temb)
block_res_samples = block_res_samples + (hidden_states,)

View File

@@ -18,14 +18,21 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...models.attention import JointTransformerBlock
from ...models.attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
from ...models.attention import FeedForward, JointTransformerBlock
from ...models.attention_processor import (
Attention,
AttentionProcessor,
FusedJointAttnProcessor2_0,
JointAttnProcessor2_0,
)
from ...models.modeling_utils import ModelMixin
from ...models.normalization import AdaLayerNormContinuous
from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
@@ -33,6 +40,72 @@ from ..modeling_outputs import Transformer2DModelOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@maybe_allow_in_graph
class SD3SingleTransformerBlock(nn.Module):
r"""
A Single Transformer block as part of the MMDiT architecture, used in Stable Diffusion 3 ControlNet.
Reference: https://arxiv.org/abs/2403.03206
Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
):
super().__init__()
self.norm1 = AdaLayerNormZero(dim)
if hasattr(F, "scaled_dot_product_attention"):
processor = JointAttnProcessor2_0()
else:
raise ValueError(
"The current PyTorch version does not support the `scaled_dot_product_attention` function."
)
self.attn = Attention(
query_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
bias=True,
processor=processor,
eps=1e-6,
)
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor):
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
# Attention.
attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=None,
)
# Process attention outputs for the `hidden_states`.
attn_output = gate_msa.unsqueeze(1) * attn_output
hidden_states = hidden_states + attn_output
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
ff_output = self.ff(norm_hidden_states)
ff_output = gate_mlp.unsqueeze(1) * ff_output
hidden_states = hidden_states + ff_output
return hidden_states
class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
"""
The Transformer model introduced in Stable Diffusion 3.

View File

@@ -858,6 +858,12 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
controlnet_config = (
self.controlnet.config
if isinstance(self.controlnet, SD3ControlNetModel)
else self.controlnet.nets[0].config
)
# align format for control guidance
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
@@ -932,6 +938,11 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
# 3. Prepare control image
if controlnet_config.force_zeros_for_pooled_projection:
# instantx sd3 controlnet does not apply shift factor
vae_shift_factor = 0
else:
vae_shift_factor = self.vae.config.shift_factor
if isinstance(self.controlnet, SD3ControlNetModel):
control_image = self.prepare_image(
image=control_image,
@@ -947,8 +958,7 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
height, width = control_image.shape[-2:]
control_image = self.vae.encode(control_image).latent_dist.sample()
control_image = control_image * self.vae.config.scaling_factor
control_image = (control_image - vae_shift_factor) * self.vae.config.scaling_factor
elif isinstance(self.controlnet, SD3MultiControlNetModel):
control_images = []
@@ -966,7 +976,7 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
)
control_image_ = self.vae.encode(control_image_).latent_dist.sample()
control_image_ = control_image_ * self.vae.config.scaling_factor
control_image_ = (control_image_ - vae_shift_factor) * self.vae.config.scaling_factor
control_images.append(control_image_)
@@ -974,11 +984,6 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
else:
assert False
if controlnet_pooled_projections is None:
controlnet_pooled_projections = torch.zeros_like(pooled_prompt_embeds)
else:
controlnet_pooled_projections = controlnet_pooled_projections or pooled_prompt_embeds
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
@@ -1006,6 +1011,18 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
]
controlnet_keep.append(keeps[0] if isinstance(self.controlnet, SD3ControlNetModel) else keeps)
if controlnet_config.force_zeros_for_pooled_projection:
# instantx sd3 controlnet used zero pooled projection
controlnet_pooled_projections = torch.zeros_like(pooled_prompt_embeds)
else:
controlnet_pooled_projections = controlnet_pooled_projections or pooled_prompt_embeds
if controlnet_config.joint_attention_dim is not None:
controlnet_encoder_hidden_states = prompt_embeds
else:
# SD35 official 8b controlnet does not use encoder_hidden_states
controlnet_encoder_hidden_states = None
# 7. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
@@ -1025,11 +1042,17 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
controlnet_cond_scale = controlnet_cond_scale[0]
cond_scale = controlnet_cond_scale * controlnet_keep[i]
if controlnet_config.use_pos_embed is False:
# sd35 (offical) 8b controlnet
controlnet_model_input = self.transformer.pos_embed(latent_model_input)
else:
controlnet_model_input = latent_model_input
# controlnet(s) inference
control_block_samples = self.controlnet(
hidden_states=latent_model_input,
hidden_states=controlnet_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
encoder_hidden_states=controlnet_encoder_hidden_states,
pooled_projections=controlnet_pooled_projections,
joint_attention_kwargs=self.joint_attention_kwargs,
controlnet_cond=control_image,