Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-06 20:44:33 +08:00)

Compare commits: controlnet...test-v-upd (114 commits)
(Commit list: 114 entries, SHAs only; the author, date, and message columns were not captured.)
scripts/convert_svd_to_diffusers.py (new file, 730 lines added)
@@ -0,0 +1,730 @@
from diffusers.utils import is_accelerate_available, logging


if is_accelerate_available():
    pass


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def create_unet_diffusers_config(original_config, image_size: int, controlnet=False):
    """
    Creates a config for the diffusers based on the config of the LDM model.
    """
    if controlnet:
        unet_params = original_config.model.params.control_stage_config.params
    else:
        if "unet_config" in original_config.model.params and original_config.model.params.unet_config is not None:
            unet_params = original_config.model.params.unet_config.params
        else:
            unet_params = original_config.model.params.network_config.params

    vae_params = original_config.model.params.first_stage_config.params.encoder_config.params

    block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]

    down_block_types = []
    resolution = 1
    for i in range(len(block_out_channels)):
        block_type = (
            "CrossAttnDownBlockSpatioTemporal"
            if resolution in unet_params.attention_resolutions
            else "DownBlockSpatioTemporal"
        )
        down_block_types.append(block_type)
        if i != len(block_out_channels) - 1:
            resolution *= 2

    up_block_types = []
    for i in range(len(block_out_channels)):
        block_type = (
            "CrossAttnUpBlockSpatioTemporal"
            if resolution in unet_params.attention_resolutions
            else "UpBlockSpatioTemporal"
        )
        up_block_types.append(block_type)
        resolution //= 2

    if unet_params.transformer_depth is not None:
        transformer_layers_per_block = (
            unet_params.transformer_depth
            if isinstance(unet_params.transformer_depth, int)
            else list(unet_params.transformer_depth)
        )
    else:
        transformer_layers_per_block = 1

    vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)

    head_dim = unet_params.num_heads if "num_heads" in unet_params else None
    use_linear_projection = (
        unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
    )
    if use_linear_projection:
        # stable diffusion 2-base-512 and 2-768
        if head_dim is None:
            head_dim_mult = unet_params.model_channels // unet_params.num_head_channels
            head_dim = [head_dim_mult * c for c in list(unet_params.channel_mult)]

    class_embed_type = None
    addition_embed_type = None
    addition_time_embed_dim = None
    projection_class_embeddings_input_dim = None
    context_dim = None

    if unet_params.context_dim is not None:
        context_dim = (
            unet_params.context_dim if isinstance(unet_params.context_dim, int) else unet_params.context_dim[0]
        )

    if "num_classes" in unet_params:
        if unet_params.num_classes == "sequential":
            addition_time_embed_dim = 256
            assert "adm_in_channels" in unet_params
            projection_class_embeddings_input_dim = unet_params.adm_in_channels

    config = {
        "sample_size": image_size // vae_scale_factor,
        "in_channels": unet_params.in_channels,
        "down_block_types": tuple(down_block_types),
        "block_out_channels": tuple(block_out_channels),
        "layers_per_block": unet_params.num_res_blocks,
        "cross_attention_dim": context_dim,
        "attention_head_dim": head_dim,
        "use_linear_projection": use_linear_projection,
        "class_embed_type": class_embed_type,
        "addition_embed_type": addition_embed_type,
        "addition_time_embed_dim": addition_time_embed_dim,
        "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
        "transformer_layers_per_block": transformer_layers_per_block,
    }

    if "disable_self_attentions" in unet_params:
        config["only_cross_attention"] = unet_params.disable_self_attentions

    if "num_classes" in unet_params and isinstance(unet_params.num_classes, int):
        config["num_class_embeds"] = unet_params.num_classes

    if controlnet:
        config["conditioning_channels"] = unet_params.hint_channels
    else:
        config["out_channels"] = unet_params.out_channels
        config["up_block_types"] = tuple(up_block_types)

    return config

def assign_to_checkpoint(
    paths,
    checkpoint,
    old_checkpoint,
    attention_paths_to_split=None,
    additional_replacements=None,
    config=None,
    mid_block_suffix="",
):
    """
    This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
    attention layers, and takes into account additional replacements that may arise.

    Assigns the weights to the new checkpoint.
    """
    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

    # Splits the attention layers into three variables.
    if attention_paths_to_split is not None:
        for path, path_map in attention_paths_to_split.items():
            old_tensor = old_checkpoint[path]
            channels = old_tensor.shape[0] // 3

            target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)

            num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3

            old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
            query, key, value = old_tensor.split(channels // num_heads, dim=1)

            checkpoint[path_map["query"]] = query.reshape(target_shape)
            checkpoint[path_map["key"]] = key.reshape(target_shape)
            checkpoint[path_map["value"]] = value.reshape(target_shape)

    if mid_block_suffix is not None:
        mid_block_suffix = f".{mid_block_suffix}"
    else:
        mid_block_suffix = ""

    for path in paths:
        new_path = path["new"]

        # These have already been assigned
        if attention_paths_to_split is not None and new_path in attention_paths_to_split:
            continue

        # Global renaming happens here
        new_path = new_path.replace("middle_block.0", f"mid_block.resnets.0{mid_block_suffix}")
        new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
        new_path = new_path.replace("middle_block.2", f"mid_block.resnets.1{mid_block_suffix}")

        if additional_replacements is not None:
            for replacement in additional_replacements:
                new_path = new_path.replace(replacement["old"], replacement["new"])

        # proj_attn.weight has to be converted from conv 1D to linear
        is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path)
        shape = old_checkpoint[path["old"]].shape
        if is_attn_weight and len(shape) == 3:
            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
        elif is_attn_weight and len(shape) == 4:
            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0]
        else:
            checkpoint[new_path] = old_checkpoint[path["old"]]

def renew_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside attentions to the new naming scheme (local renaming).
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        new_item = new_item.replace("time_stack", "temporal_transformer_blocks")

        new_item = new_item.replace("time_pos_embed.0.bias", "time_pos_embed.linear_1.bias")
        new_item = new_item.replace("time_pos_embed.0.weight", "time_pos_embed.linear_1.weight")
        new_item = new_item.replace("time_pos_embed.2.bias", "time_pos_embed.linear_2.bias")
        new_item = new_item.replace("time_pos_embed.2.weight", "time_pos_embed.linear_2.weight")

        mapping.append({"old": old_item, "new": new_item})

    return mapping

def shave_segments(path, n_shave_prefix_segments=1):
    """
    Removes segments. Positive values shave the first segments, negative shave the last segments.
    """
    if n_shave_prefix_segments >= 0:
        return ".".join(path.split(".")[n_shave_prefix_segments:])
    else:
        return ".".join(path.split(".")[:n_shave_prefix_segments])

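# Illustrative example (added for clarity; not part of the original script):
#   shave_segments("output_blocks.2.1.norm.weight", 2)  -> "1.norm.weight"
#   shave_segments("output_blocks.2.1.norm.weight", -1) -> "output_blocks.2.1.norm"
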
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming).
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item.replace("in_layers.0", "norm1")
        new_item = new_item.replace("in_layers.2", "conv1")

        new_item = new_item.replace("out_layers.0", "norm2")
        new_item = new_item.replace("out_layers.3", "conv2")

        new_item = new_item.replace("emb_layers.1", "time_emb_proj")
        new_item = new_item.replace("skip_connection", "conv_shortcut")

        new_item = new_item.replace("time_stack.", "")

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping

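# Illustrative example of the two-stage renaming (added for clarity; not part of the original script):
#   renew_resnet_paths(["input_blocks.1.0.in_layers.0.weight"]) yields
#       {"old": "input_blocks.1.0.in_layers.0.weight", "new": "input_blocks.1.0.norm1.weight"}
#   and assign_to_checkpoint(...) with the additional replacement
#       {"old": "input_blocks.1.0", "new": "down_blocks.0.resnets.0.spatial_res_block"}
#   then stores the tensor under "down_blocks.0.resnets.0.spatial_res_block.norm1.weight".
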
def convert_ldm_unet_checkpoint(
    checkpoint, config, path=None, extract_ema=False, controlnet=False, skip_extract_state_dict=False
):
    """
    Takes a state dict and a config, and returns a converted checkpoint.
    """

    if skip_extract_state_dict:
        unet_state_dict = checkpoint
    else:
        # extract state_dict for UNet
        unet_state_dict = {}
        keys = list(checkpoint.keys())

        unet_key = "model.diffusion_model."

        # at least 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
        if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
            logger.warning(f"Checkpoint {path} has both EMA and non-EMA weights.")
            logger.warning(
                "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
                " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
            )
            for key in keys:
                if key.startswith("model.diffusion_model"):
                    flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
                    unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
        else:
            if sum(k.startswith("model_ema") for k in keys) > 100:
                logger.warning(
                    "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
                    " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
                )

            for key in keys:
                if key.startswith(unet_key):
                    unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)

    new_checkpoint = {}

    new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
    new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
    new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
    new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]

    if config["class_embed_type"] is None:
        # No parameters to port
        ...
    elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
        new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
        new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
        new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
        new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
    else:
        raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")

    # if config["addition_embed_type"] == "text_time":
    new_checkpoint["add_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
    new_checkpoint["add_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
    new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
    new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]

    new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
    new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]

    new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
    new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
    new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
    new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]

    # Retrieves the keys for the input blocks only
    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
    input_blocks = {
        layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
        for layer_id in range(num_input_blocks)
    }

    # Retrieves the keys for the middle blocks only
    num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
    middle_blocks = {
        layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
        for layer_id in range(num_middle_blocks)
    }

    # Retrieves the keys for the output blocks only
    num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
    output_blocks = {
        layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
        for layer_id in range(num_output_blocks)
    }

    for i in range(1, num_input_blocks):
        block_id = (i - 1) // (config["layers_per_block"] + 1)
        layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)

        spatial_resnets = [
            key
            for key in input_blocks[i]
            if f"input_blocks.{i}.0" in key
            and (
                f"input_blocks.{i}.0.op" not in key
                and f"input_blocks.{i}.0.time_stack" not in key
                and f"input_blocks.{i}.0.time_mixer" not in key
            )
        ]
        temporal_resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0.time_stack" in key]
        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]

        if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.weight"
            )
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.bias"
            )

        paths = renew_resnet_paths(spatial_resnets)
        meta_path = {
            "old": f"input_blocks.{i}.0",
            "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}.spatial_res_block",
        }
        assign_to_checkpoint(
            paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
        )

        paths = renew_resnet_paths(temporal_resnets)
        meta_path = {
            "old": f"input_blocks.{i}.0",
            "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}.temporal_res_block",
        }
        assign_to_checkpoint(
            paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
        )

        # TODO resnet time_mixer.mix_factor
        if f"input_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict:
            new_checkpoint[
                f"down_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"
            ] = unet_state_dict[f"input_blocks.{i}.0.time_mixer.mix_factor"]

        if len(attentions):
            paths = renew_attention_paths(attentions)
            meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

    resnet_0 = middle_blocks[0]
    attentions = middle_blocks[1]
    resnet_1 = middle_blocks[2]

    resnet_0_spatial = [key for key in resnet_0 if "time_stack" not in key and "time_mixer" not in key]
    resnet_0_paths = renew_resnet_paths(resnet_0_spatial)
    assign_to_checkpoint(
        resnet_0_paths, new_checkpoint, unet_state_dict, config=config, mid_block_suffix="spatial_res_block"
    )

    resnet_0_temporal = [key for key in resnet_0 if "time_stack" in key and "time_mixer" not in key]
    resnet_0_paths = renew_resnet_paths(resnet_0_temporal)
    assign_to_checkpoint(
        resnet_0_paths, new_checkpoint, unet_state_dict, config=config, mid_block_suffix="temporal_res_block"
    )

    resnet_1_spatial = [key for key in resnet_1 if "time_stack" not in key and "time_mixer" not in key]
    resnet_1_paths = renew_resnet_paths(resnet_1_spatial)
    assign_to_checkpoint(
        resnet_1_paths, new_checkpoint, unet_state_dict, config=config, mid_block_suffix="spatial_res_block"
    )

    resnet_1_temporal = [key for key in resnet_1 if "time_stack" in key and "time_mixer" not in key]
    resnet_1_paths = renew_resnet_paths(resnet_1_temporal)
    assign_to_checkpoint(
        resnet_1_paths, new_checkpoint, unet_state_dict, config=config, mid_block_suffix="temporal_res_block"
    )

    new_checkpoint["mid_block.resnets.0.time_mixer.mix_factor"] = unet_state_dict[
        "middle_block.0.time_mixer.mix_factor"
    ]
    new_checkpoint["mid_block.resnets.1.time_mixer.mix_factor"] = unet_state_dict[
        "middle_block.2.time_mixer.mix_factor"
    ]

    attentions_paths = renew_attention_paths(attentions)
    meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(
        attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
    )

    for i in range(num_output_blocks):
        block_id = i // (config["layers_per_block"] + 1)
        layer_in_block_id = i % (config["layers_per_block"] + 1)
        output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
        output_block_list = {}

        for layer in output_block_layers:
            layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
            if layer_id in output_block_list:
                output_block_list[layer_id].append(layer_name)
            else:
                output_block_list[layer_id] = [layer_name]

        if len(output_block_list) > 1:
            spatial_resnets = [
                key
                for key in output_blocks[i]
                if f"output_blocks.{i}.0" in key
                and (f"output_blocks.{i}.0.time_stack" not in key and "time_mixer" not in key)
            ]

            temporal_resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0.time_stack" in key]

            paths = renew_resnet_paths(spatial_resnets)
            meta_path = {
                "old": f"output_blocks.{i}.0",
                "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}.spatial_res_block",
            }
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

            paths = renew_resnet_paths(temporal_resnets)
            meta_path = {
                "old": f"output_blocks.{i}.0",
                "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}.temporal_res_block",
            }
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

            if f"output_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict:
                new_checkpoint[
                    f"up_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"
                ] = unet_state_dict[f"output_blocks.{i}.0.time_mixer.mix_factor"]

            output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
            if ["conv.bias", "conv.weight"] in output_block_list.values():
                index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.weight"
                ]
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.bias"
                ]

                # Clear attentions as they have been attributed above.
                if len(attentions) == 2:
                    attentions = []

            attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key and "conv" not in key]
            if len(attentions):
                paths = renew_attention_paths(attentions)
                meta_path = {
                    "old": f"output_blocks.{i}.1",
                    "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
                }
                assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
                )
        else:
            spatial_layers = [
                layer for layer in output_block_layers if "time_stack" not in layer and "time_mixer" not in layer
            ]
            resnet_0_paths = renew_resnet_paths(spatial_layers, n_shave_prefix_segments=1)
            for path in resnet_0_paths:
                old_path = ".".join(["output_blocks", str(i), path["old"]])
                new_path = ".".join(
                    ["up_blocks", str(block_id), "resnets", str(layer_in_block_id), "spatial_res_block", path["new"]]
                )

                new_checkpoint[new_path] = unet_state_dict[old_path]

            temporal_layers = [
                layer for layer in output_block_layers if "time_stack" in layer and "time_mixer" not in layer
            ]
            resnet_0_paths = renew_resnet_paths(temporal_layers, n_shave_prefix_segments=1)
            for path in resnet_0_paths:
                old_path = ".".join(["output_blocks", str(i), path["old"]])
                new_path = ".".join(
                    ["up_blocks", str(block_id), "resnets", str(layer_in_block_id), "temporal_res_block", path["new"]]
                )

                new_checkpoint[new_path] = unet_state_dict[old_path]

            new_checkpoint["up_blocks.0.resnets.0.time_mixer.mix_factor"] = unet_state_dict[
                f"output_blocks.{str(i)}.0.time_mixer.mix_factor"
            ]

    return new_checkpoint

def conv_attn_to_linear(checkpoint):
    keys = list(checkpoint.keys())
    attn_keys = ["to_q.weight", "to_k.weight", "to_v.weight"]
    for key in keys:
        if ".".join(key.split(".")[-2:]) in attn_keys:
            if checkpoint[key].ndim > 2:
                checkpoint[key] = checkpoint[key][:, :, 0, 0]
        elif "proj_attn.weight" in key:
            if checkpoint[key].ndim > 2:
                checkpoint[key] = checkpoint[key][:, :, 0]

def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0, is_temporal=False):
    """
    Updates paths inside resnets to the new naming scheme (local renaming).
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        # Temporal resnet
        new_item = old_item.replace("in_layers.0", "norm1")
        new_item = new_item.replace("in_layers.2", "conv1")

        new_item = new_item.replace("out_layers.0", "norm2")
        new_item = new_item.replace("out_layers.3", "conv2")

        new_item = new_item.replace("skip_connection", "conv_shortcut")

        new_item = new_item.replace("time_stack.", "temporal_res_block.")

        # Spatial resnet
        new_item = new_item.replace("conv1", "spatial_res_block.conv1")
        new_item = new_item.replace("norm1", "spatial_res_block.norm1")

        new_item = new_item.replace("conv2", "spatial_res_block.conv2")
        new_item = new_item.replace("norm2", "spatial_res_block.norm2")

        new_item = new_item.replace("nin_shortcut", "spatial_res_block.conv_shortcut")

        new_item = new_item.replace("mix_factor", "spatial_res_block.time_mixer.mix_factor")

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping

def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside attentions to the new naming scheme (local renaming).
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        new_item = new_item.replace("norm.weight", "group_norm.weight")
        new_item = new_item.replace("norm.bias", "group_norm.bias")

        new_item = new_item.replace("q.weight", "to_q.weight")
        new_item = new_item.replace("q.bias", "to_q.bias")

        new_item = new_item.replace("k.weight", "to_k.weight")
        new_item = new_item.replace("k.bias", "to_k.bias")

        new_item = new_item.replace("v.weight", "to_v.weight")
        new_item = new_item.replace("v.bias", "to_v.bias")

        new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
        new_item = new_item.replace("proj_out.bias", "to_out.0.bias")

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping

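# Illustrative example of the VAE attention renaming (added for clarity; not part of the original script):
#   renew_vae_attention_paths(["encoder.mid.attn_1.q.weight"]) maps the key to "encoder.mid.attn_1.to_q.weight";
#   assign_to_checkpoint(...) with {"old": "mid.attn_1", "new": "mid_block.attentions.0"} then stores it as
#   "encoder.mid_block.attentions.0.to_q.weight", and conv_attn_to_linear() squeezes the 1x1 conv kernel
#   down to a linear weight.
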
def convert_ldm_vae_checkpoint(checkpoint, config):
    # extract state dict for VAE
    vae_state_dict = {}
    keys = list(checkpoint.keys())
    vae_key = "first_stage_model." if any(k.startswith("first_stage_model.") for k in keys) else ""
    for key in keys:
        if key.startswith(vae_key):
            vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)

    new_checkpoint = {}

    new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
    new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
    new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
    new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
    new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
    new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]

    new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
    new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
    new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
    new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
    new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
    new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
    new_checkpoint["decoder.time_conv_out.weight"] = vae_state_dict["decoder.time_mix_conv.weight"]
    new_checkpoint["decoder.time_conv_out.bias"] = vae_state_dict["decoder.time_mix_conv.bias"]

    # new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
    # new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
    # new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
    # new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]

    # Retrieves the keys for the encoder down blocks only
    num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
    down_blocks = {
        layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
    }

    # Retrieves the keys for the decoder up blocks only
    num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
    up_blocks = {
        layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
    }

    for i in range(num_down_blocks):
        resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]

        if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.weight"
            )
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.bias"
            )

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)

    for i in range(num_up_blocks):
        block_id = num_up_blocks - 1 - i

        resnets = [
            key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
        ]

        if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.weight"
            ]
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.bias"
            ]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)
    return new_checkpoint

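The rest of this 730-line script (argument parsing and pipeline assembly) is not part of this excerpt. As a rough orientation only, the helpers above compose along the lines of the following sketch; the file names, the OmegaConf loading, and the image size are illustrative assumptions, not taken from the diff.

import torch
from omegaconf import OmegaConf

# Hypothetical inputs -- placeholders, not from the diff.
original_config = OmegaConf.load("svd.yaml")
checkpoint = torch.load("svd.ckpt", map_location="cpu")["state_dict"]

# Build a diffusers-style UNet config from the LDM config, then rename the UNet weights.
unet_config = create_unet_diffusers_config(original_config, image_size=768)
converted_unet = convert_ldm_unet_checkpoint(checkpoint, unet_config, path="svd.ckpt", extract_ema=True)

# Rename the VAE weights. assign_to_checkpoint only reads `config` when splitting fused
# attention weights, which does not happen here, so a placeholder is enough for the sketch.
converted_vae = convert_ldm_vae_checkpoint(checkpoint, config=None)

# The converted dicts are plain {new_key: tensor} mappings, intended for load_state_dict on the
# corresponding diffusers modules (UNetSpatioTemporalConditionModel / AutoencoderKLTemporalDecoder).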
@@ -76,6 +76,7 @@ else:
        [
            "AsymmetricAutoencoderKL",
            "AutoencoderKL",
            "AutoencoderKLTemporalDecoder",
            "AutoencoderTiny",
            "ConsistencyDecoderVAE",
            "ControlNetModel",
@@ -92,6 +93,7 @@ else:
            "UNet2DModel",
            "UNet3DConditionModel",
            "UNetMotionModel",
            "UNetSpatioTemporalConditionModel",
            "VQModel",
        ]
    )
@@ -267,6 +269,7 @@ else:
            "StableDiffusionPix2PixZeroPipeline",
            "StableDiffusionSAGPipeline",
            "StableDiffusionUpscalePipeline",
            "StableDiffusionVideoPipeline",
            "StableDiffusionXLAdapterPipeline",
            "StableDiffusionXLControlNetImg2ImgPipeline",
            "StableDiffusionXLControlNetInpaintPipeline",
@@ -446,6 +449,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    from .models import (
        AsymmetricAutoencoderKL,
        AutoencoderKL,
        AutoencoderKLTemporalDecoder,
        AutoencoderTiny,
        ConsistencyDecoderVAE,
        ControlNetModel,
@@ -462,6 +466,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        UNet2DModel,
        UNet3DConditionModel,
        UNetMotionModel,
        UNetSpatioTemporalConditionModel,
        VQModel,
    )
    from .optimization import (
@@ -616,6 +621,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        StableDiffusionPix2PixZeroPipeline,
        StableDiffusionSAGPipeline,
        StableDiffusionUpscalePipeline,
        StableDiffusionVideoPipeline,
        StableDiffusionXLAdapterPipeline,
        StableDiffusionXLControlNetImg2ImgPipeline,
        StableDiffusionXLControlNetInpaintPipeline,

@@ -14,7 +14,12 @@

from typing import TYPE_CHECKING

from ..utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, is_flax_available, is_torch_available
from ..utils import (
    DIFFUSERS_SLOW_IMPORT,
    _LazyModule,
    is_flax_available,
    is_torch_available,
)


_import_structure = {}

@@ -23,6 +28,7 @@ if is_torch_available():
    _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
    _import_structure["autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
    _import_structure["autoencoder_kl"] = ["AutoencoderKL"]
    _import_structure["autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
    _import_structure["autoencoder_tiny"] = ["AutoencoderTiny"]
    _import_structure["consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
    _import_structure["controlnet"] = ["ControlNetModel"]
@@ -38,6 +44,7 @@ if is_torch_available():
    _import_structure["unet_3d_condition"] = ["UNet3DConditionModel"]
    _import_structure["unet_kandi3"] = ["Kandinsky3UNet"]
    _import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
    _import_structure["unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"]
    _import_structure["vq_model"] = ["VQModel"]

if is_flax_available():
@@ -51,6 +58,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    from .adapter import MultiAdapter, T2IAdapter
    from .autoencoder_asym_kl import AsymmetricAutoencoderKL
    from .autoencoder_kl import AutoencoderKL
    from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
    from .autoencoder_tiny import AutoencoderTiny
    from .consistency_decoder_vae import ConsistencyDecoderVAE
    from .controlnet import ControlNetModel
@@ -66,6 +74,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    from .unet_3d_condition import UNet3DConditionModel
    from .unet_kandi3 import Kandinsky3UNet
    from .unet_motion_model import MotionAdapter, UNetMotionModel
    from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
    from .vq_model import VQModel

if is_flax_available():

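Together, these `__init__` hunks register the new SVD components (`AutoencoderKLTemporalDecoder`, `UNetSpatioTemporalConditionModel`, and `StableDiffusionVideoPipeline`) in both the lazy `_import_structure` mapping and the `TYPE_CHECKING` import branch, which is what makes them importable from the package root. A minimal check, assuming this branch is installed:

from diffusers import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel

# Both names resolve through the lazy-import machinery registered above.
print(AutoencoderKLTemporalDecoder.__name__, UNetSpatioTemporalConditionModel.__name__)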
@@ -194,7 +194,12 @@ class BasicTransformerBlock(nn.Module):
        if not self.use_ada_layer_norm_single:
            self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)

        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
        )

        # 4. Fuser
        if attention_type == "gated" or attention_type == "gated-text-image":
@@ -339,6 +344,181 @@ class BasicTransformerBlock(nn.Module):
        return hidden_states

@maybe_allow_in_graph
class TemporalBasicTransformerBlock(nn.Module):
    r"""
    A basic Transformer block for video-like data.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Configure if the attentions should contain a bias parameter.
        upcast_attention (`bool`, *optional*):
            Whether to upcast the attention computation to float32. This is useful for mixed precision training.
        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
            Whether to use learnable elementwise affine parameters for normalization.
        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
            The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
        final_dropout (`bool`, *optional*, defaults to `False`):
            Whether to apply a final dropout after the last feed-forward layer.
        attention_type (`str`, *optional*, defaults to `"default"`):
            The type of attention to use. Can be `"default"`, `"gated"` or `"gated-text-image"`.
    """

    def __init__(
        self,
        dim: int,
        time_mix_inner_dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        norm_eps: float = 1e-5,
        final_dropout: bool = False,
    ):
        super().__init__()
        self.is_res = dim == time_mix_inner_dim

        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        self.norm_in = nn.LayerNorm(dim, eps=norm_eps)
        self.ff_in = FeedForward(
            dim,
            dim_out=time_mix_inner_dim,
            dropout=dropout,
            activation_fn="geglu",
            final_dropout=final_dropout,
        )

        self.norm1 = nn.LayerNorm(time_mix_inner_dim, eps=norm_eps)
        self.attn1 = Attention(
            query_dim=time_mix_inner_dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            cross_attention_dim=None,
        )

        # 2. Cross-Attn
        if cross_attention_dim is not None:
            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
            # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
            # the second cross attention block.
            self.norm2 = nn.LayerNorm(time_mix_inner_dim, eps=norm_eps)
            self.attn2 = Attention(
                query_dim=time_mix_inner_dim,
                cross_attention_dim=cross_attention_dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
            )  # is self-attn if encoder_hidden_states is none
        else:
            self.norm2 = None
            self.attn2 = None

        # 3. Feed-forward
        self.norm3 = nn.LayerNorm(time_mix_inner_dim, eps=norm_eps)
        self.ff = FeedForward(
            time_mix_inner_dim,
            dropout=dropout,
            activation_fn="geglu",
            final_dropout=final_dropout,
        )

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        num_frames: int,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
    ) -> torch.FloatTensor:
        # Notice that normalization is always applied before the real computation in the following blocks.
        # 0. Self-Attention
        batch_frames, seq_length, channels = hidden_states.shape
        batch_size = batch_frames // num_frames

        hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels)
        hidden_states = hidden_states.permute(0, 2, 1, 3)
        hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels)

        residual = hidden_states
        hidden_states = self.norm_in(hidden_states)
        hidden_states = self.ff_in(hidden_states)
        if self.is_res:
            hidden_states = hidden_states + residual

        cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}

        norm_hidden_states = self.norm1(hidden_states)
        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=None,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )
        hidden_states = attn_output + hidden_states

        # 3. Cross-Attention
        if self.attn2 is not None:
            norm_hidden_states = self.norm2(hidden_states)
            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states

        # 4. Feed-forward
        norm_hidden_states = self.norm3(hidden_states)

        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
                raise ValueError(
                    f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
                )

            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
            ff_output = torch.cat(
                [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
                dim=self._chunk_dim,
            )
        else:
            ff_output = self.ff(norm_hidden_states)

        if self.is_res:
            hidden_states = ff_output + hidden_states
        else:
            hidden_states = ff_output

        hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels)
        hidden_states = hidden_states.permute(0, 2, 1, 3)
        hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels)

        return hidden_states


class FeedForward(nn.Module):
    r"""
    A feed-forward layer.

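In `TemporalBasicTransformerBlock.forward`, attention runs across frames: the input of shape `(batch * num_frames, seq_len, channels)` is permuted to `(batch * seq_len, num_frames, channels)`, processed, and permuted back. A small self-contained sketch of that round trip (illustrative only; shapes and names are assumptions, not part of the diff):

import torch

batch, num_frames, seq_len, channels = 2, 8, 16, 32
x = torch.randn(batch * num_frames, seq_len, channels)

# (B*F, S, C) -> (B*S, F, C): each spatial token now attends over the frame axis.
h = x.reshape(batch, num_frames, seq_len, channels).permute(0, 2, 1, 3)
h = h.reshape(batch * seq_len, num_frames, channels)

# ... temporal attention / feed-forward would run here ...

# (B*S, F, C) -> (B*F, S, C): restore the original layout.
h = h.reshape(batch, seq_len, num_frames, channels).permute(0, 2, 1, 3)
h = h.reshape(batch * num_frames, seq_len, channels)

assert torch.equal(h, x)  # the reshape/permute round trip is lossless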
src/diffusers/models/autoencoder_kl_temporal_decoder.py (new file, 672 lines added)
@@ -0,0 +1,672 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Dict, Iterable, Optional, Tuple, Union

import torch
import torch.nn as nn

from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import FromOriginalVAEMixin
from ..utils import BaseOutput, is_torch_version
from ..utils.accelerate_utils import apply_forward_hook
from .attention_processor import (
    ADDED_KV_ATTENTION_PROCESSORS,
    CROSS_ATTENTION_PROCESSORS,
    AttentionProcessor,
    AttnAddedKVProcessor,
    AttnProcessor,
)
from .modeling_utils import ModelMixin
from .unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder

class TemporalDecoder(nn.Module):
    def __init__(
        self,
        in_channels: int = 4,
        out_channels: int = 3,
        block_out_channels: Tuple[int, ...] = (
            128,
            256,
            512,
            512,
        ),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        norm_type: str = "group",  # group, spatial
        alpha: float = 0.0,
        merge_strategy: str = "learned",
        conv_out_kernel_size=(3, 1, 1),
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)
        temb_channels = in_channels if norm_type == "spatial" else None
        self.mid_block = MidBlockTemporalDecoder(
            num_layers=self.layers_per_block,
            in_channels=block_out_channels[-1],
            out_channels=block_out_channels[-1],
            attention_head_dim=block_out_channels[-1],
            resnet_eps=1e-6,
            temporal_resnet_eps=1e-5,
            resnet_act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            temb_channels=temb_channels,
            resnet_time_scale_shift=norm_type,
            merge_factor=alpha,
            merge_strategy=merge_strategy,
        )

        # up
        self.up_blocks = nn.ModuleList([])
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i in range(len(block_out_channels)):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]

            is_final_block = i == len(block_out_channels) - 1
            up_block = UpBlockTemporalDecoder(
                num_layers=self.layers_per_block + 1,
                in_channels=prev_output_channel,
                out_channels=output_channel,
                add_upsample=not is_final_block,
                resnet_eps=1e-6,
                temporal_resnet_eps=1e-5,
                resnet_act_fn=act_fn,
                norm_num_groups=norm_num_groups,
                attention_head_dim=output_channel,
                temb_channels=temb_channels,
                resnet_time_scale_shift=norm_type,
                merge_factor=alpha,
                merge_strategy=merge_strategy,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)

        if isinstance(conv_out_kernel_size, Iterable):
            padding = [int(k // 2) for k in conv_out_kernel_size]
        else:
            padding = int(conv_out_kernel_size // 2)

        self.conv_act = nn.SiLU()
        self.conv_out = torch.nn.Conv2d(
            in_channels=block_out_channels[0],
            out_channels=out_channels,
            kernel_size=3,
            padding=1,
        )
        self.time_conv_out = torch.nn.Conv3d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=conv_out_kernel_size,
            padding=padding,
        )

        self.gradient_checkpointing = False

    def forward(
        self,
        sample: torch.FloatTensor,
        image_only_indicator: torch.FloatTensor,
        num_frames: int = 1,
        latent_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        r"""The forward method of the `Decoder` class."""

        sample = self.conv_in(sample)

        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            if is_torch_version(">=", "1.11.0"):
                # middle
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.mid_block),
                    sample,
                    image_only_indicator,
                    latent_embeds,
                    num_frames,
                    use_reentrant=False,
                )
                sample = sample.to(upscale_dtype)

                # up
                for up_block in self.up_blocks:
                    sample = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(up_block),
                        sample,
                        image_only_indicator,
                        latent_embeds,
                        num_frames,
                        use_reentrant=False,
                    )
            else:
                # middle
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.mid_block),
                    sample,
                    image_only_indicator,
                    latent_embeds,
                    num_frames,
                )
                sample = sample.to(upscale_dtype)

                # up
                for up_block in self.up_blocks:
                    sample = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(up_block),
                        sample,
                        image_only_indicator,
                        latent_embeds,
                        num_frames,
                    )
        else:
            # middle
            sample = self.mid_block(
                sample,
                temb=latent_embeds,
                num_frames=num_frames,
                image_only_indicator=image_only_indicator,
            )
            sample = sample.to(upscale_dtype)

            # up
            for up_block in self.up_blocks:
                sample = up_block(
                    sample,
                    temb=latent_embeds,
                    num_frames=num_frames,
                    image_only_indicator=image_only_indicator,
                )

        # post-process
        if latent_embeds is None:
            sample = self.conv_norm_out(sample)
        else:
            sample = self.conv_norm_out(sample, latent_embeds)

        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        batch_frames, channels, height, width = sample.shape
        batch_size = batch_frames // num_frames
        sample = sample[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
        sample = self.time_conv_out(sample)

        sample = sample.permute(0, 2, 1, 3, 4).reshape(batch_frames, channels, height, width)

        return sample

@dataclass
class AutoencoderKLOutput(BaseOutput):
    """
    Output of AutoencoderKL encoding method.

    Args:
        latent_dist (`DiagonalGaussianDistribution`):
            Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
            `DiagonalGaussianDistribution` allows for sampling latents from the distribution.
    """

    latent_dist: "DiagonalGaussianDistribution"

class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
    r"""
    A VAE model with KL loss for encoding images into latents and decoding latent representations into images.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
        out_channels (int, *optional*, defaults to 3): Number of channels in the output.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
            Tuple of downsample block types.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
            Tuple of upsample block types.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
            Tuple of block output channels.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
        sample_size (`int`, *optional*, defaults to `32`): Sample input size.
        scaling_factor (`float`, *optional*, defaults to 0.18215):
            The component-wise standard deviation of the trained latent space, computed using the first batch of the
            training set. This is used to scale the latent space to have unit variance when training the diffusion
            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
            diffusion model. When decoding, the latents are scaled back to the original scale with the formula
            `z = 1 / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution
            Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
        force_upcast (`bool`, *optional*, defaults to `True`):
            If enabled, the VAE is run in float32 for high-resolution pipelines such as SD-XL. The VAE can be
            fine-tuned / trained to a lower range without losing too much precision, in which case `force_upcast`
            can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
        block_out_channels: Tuple[int] = (64,),
        layers_per_block: int = 1,
        act_fn: str = "silu",
        latent_channels: int = 4,
        norm_num_groups: int = 32,
        sample_size: int = 32,
        scaling_factor: float = 0.18215,
        force_upcast: float = True,
    ):
        super().__init__()

        # pass init params to Encoder
        self.encoder = Encoder(
            in_channels=in_channels,
            out_channels=latent_channels,
            down_block_types=down_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            double_z=True,
        )

        # pass init params to Decoder
        self.decoder = TemporalDecoder(
            in_channels=latent_channels,
            out_channels=out_channels,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
            act_fn=act_fn,
        )

        self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)

        self.use_slicing = False
        self.use_tiling = False

        # only relevant if vae tiling is enabled
        self.tile_sample_min_size = self.config.sample_size
        sample_size = (
            self.config.sample_size[0]
            if isinstance(self.config.sample_size, (list, tuple))
            else self.config.sample_size
        )
        self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
        self.tile_overlap_factor = 0.25

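The tile sizes set at the end of `__init__` relate a pixel-space tile to its latent-space counterpart through the encoder's downsampling factor. A small worked example, using the default values above; the numbers are illustrative, real SVD checkpoints use larger configurations:

# Illustrative arithmetic only; values mirror the defaults above, not a real checkpoint.
sample_size = 32
block_out_channels = (64,)  # a single block -> no extra downsampling stages
tile_overlap_factor = 0.25

downsample_factor = 2 ** (len(block_out_channels) - 1)       # 1 here
tile_sample_min_size = sample_size                           # 32 pixels per tile side
tile_latent_min_size = int(sample_size / downsample_factor)  # 32 latent cells here

# With a 4-block VAE (block_out_channels of length 4) and sample_size=512:
#   downsample_factor = 2 ** 3 = 8, so tile_latent_min_size = 512 // 8 = 64,
# which is why tiled_decode below speaks of 64x64 latent tiles.
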
    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (Encoder, TemporalDecoder)):
            module.gradient_checkpointing = value

    def enable_tiling(self, use_tiling: bool = True):
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to
        allow processing larger images.
        """
        self.use_tiling = use_tiling

    def disable_tiling(self):
        r"""
        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
        decoding in one step.
        """
        self.enable_tiling(False)

    def enable_slicing(self):
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.use_slicing = True

    def disable_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
        decoding in one step.
        """
        self.use_slicing = False

    @property
    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by their weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(
        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
    ):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the
                processor for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor, _remove_lora=_remove_lora)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            processor = AttnAddedKVProcessor()
        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            processor = AttnProcessor()
        else:
            raise ValueError(
                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
            )

        self.set_attn_processor(processor, _remove_lora=True)

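A brief usage sketch for the three members above: inspecting the processor map, replacing every processor with a single instance, and falling back to the defaults. `AttnProcessor` is the class already imported by this module; instantiating the model with its default config is only an assumption made for the sketch, not a tested recipe.

from diffusers.models.attention_processor import AttnProcessor

# assumed: any already-instantiated AutoencoderKLTemporalDecoder works here
vae = AutoencoderKLTemporalDecoder()

# `attn_processors` maps "path.to.layer.processor" -> processor instance
for name, proc in vae.attn_processors.items():
    print(name, type(proc).__name__)

# set one processor instance for every attention layer ...
vae.set_attn_processor(AttnProcessor())

# ... or let the model pick its library default
vae.set_default_attn_processor()
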
    @apply_forward_hook
    def encode(
        self, x: torch.FloatTensor, return_dict: bool = True
    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
        """
        Encode a batch of images into latents.

        Args:
            x (`torch.FloatTensor`): Input batch of images.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

        Returns:
            The latent representations of the encoded images. If `return_dict` is True, a
            [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
        """
        if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
            return self.tiled_encode(x, return_dict=return_dict)

        if self.use_slicing and x.shape[0] > 1:
            encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
            h = torch.cat(encoded_slices)
        else:
            h = self.encoder(x)

        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)

        if not return_dict:
            return (posterior,)

        return AutoencoderKLOutput(latent_dist=posterior)

    def _decode(
        self, z: torch.FloatTensor, num_frames: int, return_dict: bool = True
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
            return self.tiled_decode(z, return_dict=return_dict)

        batch_size = z.shape[0] // num_frames
        # TODO: don't hardcode this
        image_only_indicator = torch.zeros(batch_size, num_frames, dtype=z.dtype, device=z.device)
        dec = self.decoder(z, num_frames=num_frames, image_only_indicator=image_only_indicator)

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

    @apply_forward_hook
    def decode(
        self,
        z: torch.FloatTensor,
        num_frames: int,
        return_dict: bool = True,
        generator=None,
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        """
        Decode a batch of images.

        Args:
            z (`torch.FloatTensor`): Input batch of latent vectors.
            num_frames (`int`): The number of frames contained in each sample of the batch.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

        Returns:
            [`~models.vae.DecoderOutput`] or `tuple`:
                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                returned.

        """
        if self.use_slicing and z.shape[0] > 1:
            decoded_slices = [self._decode(z_slice, num_frames // 2).sample for z_slice in z.split(1)]
            decoded = torch.cat(decoded_slices)
        else:
            decoded = self._decode(z, num_frames).sample

        if not return_dict:
            return (decoded,)

        return DecoderOutput(sample=decoded)

    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        blend_extent = min(a.shape[2], b.shape[2], blend_extent)
        for y in range(blend_extent):
            b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
        return b

    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
        for x in range(blend_extent):
            b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
        return b

    def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
        r"""Encode a batch of images using a tiled encoder.

        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding
        is different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts,
        the tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in
        the output, but they should be much less noticeable.

        Args:
            x (`torch.FloatTensor`): Input batch of images.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

        Returns:
            [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
                If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a
                plain `tuple` is returned.
        """
        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
        row_limit = self.tile_latent_min_size - blend_extent

        # Split the image into 512x512 tiles and encode them separately.
        rows = []
        for i in range(0, x.shape[2], overlap_size):
            row = []
            for j in range(0, x.shape[3], overlap_size):
                tile = x[
                    :,
                    :,
                    i : i + self.tile_sample_min_size,
                    j : j + self.tile_sample_min_size,
                ]
                tile = self.encoder(tile)
                tile = self.quant_conv(tile)
                row.append(tile)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=3))

        moments = torch.cat(result_rows, dim=2)
        posterior = DiagonalGaussianDistribution(moments)

        if not return_dict:
            return (posterior,)

        return AutoencoderKLOutput(latent_dist=posterior)

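`blend_v` / `blend_h` above cross-fade the overlapping strip of two neighbouring tiles with a linear ramp. A minimal, self-contained sketch of the same arithmetic on toy tensors, not tied to the VAE:

import torch

blend_extent = 4
# two 1x1x1x8 "tiles": left tile all ones, right tile all zeros
a = torch.ones(1, 1, 1, 8)
b = torch.zeros(1, 1, 1, 8)

# same update as blend_h: column x of `b` mixes the right edge of `a` into `b`
for x in range(blend_extent):
    b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)

print(b[0, 0, 0])  # tensor([1.00, 0.75, 0.50, 0.25, 0., 0., 0., 0.]) -> a smooth seam
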
    def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        r"""
        Decode a batch of images using a tiled decoder.

        Args:
            z (`torch.FloatTensor`): Input batch of latent vectors.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

        Returns:
            [`~models.vae.DecoderOutput`] or `tuple`:
                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                returned.
        """
        overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
        row_limit = self.tile_sample_min_size - blend_extent

        # Split z into overlapping 64x64 tiles and decode them separately.
        # The tiles have an overlap to avoid seams between tiles.
        rows = []
        for i in range(0, z.shape[2], overlap_size):
            row = []
            for j in range(0, z.shape[3], overlap_size):
                tile = z[
                    :,
                    :,
                    i : i + self.tile_latent_min_size,
                    j : j + self.tile_latent_min_size,
                ]
                tile = self.post_quant_conv(tile)
                decoded = self.decoder(tile)
                row.append(decoded)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=3))

        dec = torch.cat(result_rows, dim=2)
        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

    def forward(
        self,
        sample: torch.FloatTensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: Optional[torch.Generator] = None,
        num_frames: int = 1,
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        r"""
        Args:
            sample (`torch.FloatTensor`): Input sample.
            sample_posterior (`bool`, *optional*, defaults to `False`):
                Whether to sample from the posterior.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
        """
        x = sample
        posterior = self.encode(x).latent_dist
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()

        dec = self.decode(z, num_frames=num_frames).sample

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

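A hedged end-to-end sketch of how this temporal VAE is typically driven: encode frames stacked along the batch axis, scale by `scaling_factor` for the diffusion model, then undo the scaling and decode with `num_frames`. Shapes and the `0.18215` value follow the defaults documented above; a real SVD pipeline wires this up for you, so treat the snippet as illustrative rather than a tested recipe.

import torch

vae = AutoencoderKLTemporalDecoder()  # default config from this file; real checkpoints are larger
vae.eval()

batch_size, num_frames = 1, 4
frames = torch.randn(batch_size * num_frames, 3, 64, 64)  # (B*T, C, H, W), frames stacked on the batch axis

with torch.no_grad():
    posterior = vae.encode(frames).latent_dist
    latents = posterior.sample() * vae.config.scaling_factor   # z = z * scaling_factor
    decoded = vae.decode(
        latents / vae.config.scaling_factor,                   # z = 1 / scaling_factor * z
        num_frames=num_frames,
    ).sample

print(decoded.shape)  # (B*T, 3, 64, 64) -- same flattened layout as the input
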
@@ -165,7 +165,10 @@ class Upsample2D(nn.Module):
            self.Conv2d_0 = conv

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        output_size: Optional[int] = None,
        scale: float = 1.0,
    ) -> torch.FloatTensor:
        assert hidden_states.shape[1] == self.channels

@@ -379,7 +382,11 @@ class FirUpsample2D(nn.Module):
            weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))

            inverse_conv = F.conv_transpose2d(
                hidden_states,
                weight,
                stride=stride,
                output_padding=output_padding,
                padding=0,
            )

            output = upfirdn2d_native(

@@ -530,7 +537,14 @@ class KDownsample2D(nn.Module):

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode)
        weight = inputs.new_zeros(
            [
                inputs.shape[1],
                inputs.shape[1],
                self.kernel.shape[0],
                self.kernel.shape[1],
            ]
        )
        indices = torch.arange(inputs.shape[1], device=inputs.device)
        kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
        weight[indices, indices] = kernel

@@ -553,7 +567,14 @@ class KUpsample2D(nn.Module):

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode)
        weight = inputs.new_zeros(
            [
                inputs.shape[1],
                inputs.shape[1],
                self.kernel.shape[0],
                self.kernel.shape[1],
            ]
        )
        indices = torch.arange(inputs.shape[1], device=inputs.device)
        kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
        weight[indices, indices] = kernel

@@ -690,11 +711,19 @@ class ResnetBlock2D(nn.Module):
        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = conv_cls(
                in_channels,
                conv_2d_out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=conv_shortcut_bias,
            )

    def forward(
        self,
        input_tensor: torch.FloatTensor,
        temb: torch.FloatTensor,
        scale: float = 1.0,
    ) -> torch.FloatTensor:
        hidden_states = input_tensor

@@ -866,7 +895,10 @@ class ResidualTemporalBlock1D(nn.Module):


def upsample_2d(
    hidden_states: torch.FloatTensor,
    kernel: Optional[torch.FloatTensor] = None,
    factor: int = 2,
    gain: float = 1,
) -> torch.FloatTensor:
    r"""Upsample2D a batch of 2D images with the given filter.
    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given

@@ -910,7 +942,10 @@ def upsample_2d(


def downsample_2d(
    hidden_states: torch.FloatTensor,
    kernel: Optional[torch.FloatTensor] = None,
    factor: int = 2,
    gain: float = 1,
) -> torch.FloatTensor:
    r"""Downsample2D a batch of 2D images with the given filter.
    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the

@@ -946,13 +981,20 @@ def downsample_2d(
    kernel = kernel * gain
    pad_value = kernel.shape[0] - factor
    output = upfirdn2d_native(
        hidden_states,
        kernel.to(device=hidden_states.device),
        down=factor,
        pad=((pad_value + 1) // 2, pad_value // 2),
    )
    return output


def upfirdn2d_native(
    tensor: torch.Tensor,
    kernel: torch.Tensor,
    up: int = 1,
    down: int = 1,
    pad: Tuple[int, int] = (0, 0),
) -> torch.Tensor:
    up_x = up_y = up
    down_x = down_y = down

@@ -1008,7 +1050,13 @@ class TemporalConvLayer(nn.Module):
        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: Optional[int] = None,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
    ):
        super().__init__()
        out_dim = out_dim or in_dim
        self.in_dim = in_dim

@@ -1016,7 +1064,9 @@ class TemporalConvLayer(nn.Module):

        # conv layers
        self.conv1 = nn.Sequential(
            nn.GroupNorm(norm_num_groups, in_dim),
            nn.SiLU(),
            nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)),
        )
        self.conv2 = nn.Sequential(
            nn.GroupNorm(norm_num_groups, out_dim),

@@ -1058,3 +1108,295 @@ class TemporalConvLayer(nn.Module):
            (hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:]
        )
        return hidden_states

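The `KDownsample2D` / `KUpsample2D` hunks above only reflow the `new_zeros` call, but the construction is worth unpacking: it builds a per-channel FIR filter by writing the same 2D kernel onto the diagonal of an otherwise zero `(C, C, kH, kW)` weight, so each channel is filtered independently. A small standalone sketch of that trick with a toy kernel, independent of the diffusers classes:

import torch

channels = 3
kernel = torch.tensor([[1.0, 2.0, 1.0]])        # toy 1x3 FIR kernel
inputs = torch.randn(1, channels, 8, 8)

weight = inputs.new_zeros([channels, channels, kernel.shape[0], kernel.shape[1]])
indices = torch.arange(channels)
weight[indices, indices] = kernel.to(weight)    # kernel only on the channel diagonal

out = torch.nn.functional.conv2d(inputs, weight, padding=(0, 1))
print(out.shape)  # torch.Size([1, 3, 8, 8]) -- each channel convolved with the same kernel
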
class TemporalResnetBlock(nn.Module):
|
||||
r"""
|
||||
A Resnet block.
|
||||
|
||||
Parameters:
|
||||
in_channels (`int`): The number of channels in the input.
|
||||
out_channels (`int`, *optional*, default to be `None`):
|
||||
The number of output channels for the first conv2d layer. If None, same as `in_channels`.
|
||||
dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
|
||||
temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
|
||||
groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
|
||||
groups_out (`int`, *optional*, default to None):
|
||||
The number of groups to use for the second normalization layer. if set to None, same as `groups`.
|
||||
eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
|
||||
non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
|
||||
time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
|
||||
By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or
|
||||
"ada_group" for a stronger conditioning with scale and shift.
|
||||
kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
|
||||
[`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
|
||||
output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
|
||||
use_in_shortcut (`bool`, *optional*, default to `True`):
|
||||
If `True`, add a 1x1 nn.conv2d layer for skip-connection.
|
||||
up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
|
||||
down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.
|
||||
conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the
|
||||
`conv_shortcut` output.
|
||||
conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
|
||||
If None, same as `out_channels`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: Optional[int] = None,
|
||||
conv_shortcut: bool = False,
|
||||
dropout: float = 0.0,
|
||||
temb_channels: int = 512,
|
||||
groups: int = 32,
|
||||
groups_out: Optional[int] = None,
|
||||
eps: float = 1e-6,
|
||||
non_linearity: str = "swish",
|
||||
kernel_size: Optional[torch.FloatTensor] = (3, 1, 1),
|
||||
output_scale_factor: float = 1.0,
|
||||
use_in_shortcut: Optional[bool] = None,
|
||||
conv_shortcut_bias: bool = True,
|
||||
conv_2d_out_channels: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
self.output_scale_factor = output_scale_factor
|
||||
|
||||
linear_cls = nn.Linear
|
||||
conv_cls = nn.Conv3d
|
||||
|
||||
padding = [k // 2 for k in kernel_size]
|
||||
|
||||
if groups_out is None:
|
||||
groups_out = groups
|
||||
|
||||
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
|
||||
|
||||
self.conv1 = conv_cls(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=1,
|
||||
padding=padding,
|
||||
)
|
||||
|
||||
if temb_channels is not None:
|
||||
self.time_emb_proj = linear_cls(temb_channels, out_channels)
|
||||
else:
|
||||
self.time_emb_proj = None
|
||||
|
||||
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
|
||||
|
||||
self.dropout = torch.nn.Dropout(dropout)
|
||||
conv_2d_out_channels = conv_2d_out_channels or out_channels
|
||||
self.conv2 = conv_cls(
|
||||
out_channels,
|
||||
conv_2d_out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=1,
|
||||
padding=padding,
|
||||
)
|
||||
|
||||
self.nonlinearity = get_activation(non_linearity)
|
||||
|
||||
self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut
|
||||
|
||||
self.conv_shortcut = None
|
||||
if self.use_in_shortcut:
|
||||
self.conv_shortcut = conv_cls(
|
||||
in_channels,
|
||||
conv_2d_out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias=conv_shortcut_bias,
|
||||
)
|
||||
|
||||
def forward(self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
|
||||
hidden_states = input_tensor
|
||||
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
|
||||
hidden_states = self.conv1(hidden_states)
|
||||
if self.time_emb_proj is not None:
|
||||
temb = self.nonlinearity(temb)
|
||||
temb = self.time_emb_proj(temb)[:, :, :, None, None]
|
||||
|
||||
if temb is not None:
|
||||
temb = temb.permute(0, 2, 1, 3, 4)
|
||||
hidden_states = hidden_states + temb
|
||||
|
||||
hidden_states = self.norm2(hidden_states)
|
||||
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
|
||||
if self.conv_shortcut is not None:
|
||||
input_tensor = self.conv_shortcut(input_tensor)
|
||||
|
||||
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
|
||||
|
||||
return output_tensor
|
||||
|
||||
|
||||
# VideoResBlock
|
||||
class SpatioTemporalResBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: Optional[int] = None,
|
||||
dropout: float = 0.0,
|
||||
temb_channels: int = 512,
|
||||
groups: int = 32,
|
||||
pre_norm: bool = True,
|
||||
eps: float = 1e-6,
|
||||
temporal_eps: Optional[float] = None,
|
||||
non_linearity: str = "swish",
|
||||
time_embedding_norm: str = "default", # default, scale_shift, ada_group, spatial
|
||||
output_scale_factor: float = 1.0,
|
||||
kernel_size_3d: Optional[torch.FloatTensor] = (3, 1, 1),
|
||||
merge_factor: float = 0.5,
|
||||
merge_strategy="learned",
|
||||
switch_spatial_to_temporal_mix: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.spatial_res_block = ResnetBlock2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=eps,
|
||||
groups=groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=time_embedding_norm,
|
||||
non_linearity=non_linearity,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=pre_norm,
|
||||
)
|
||||
|
||||
self.temporal_res_block = TemporalResnetBlock(
|
||||
in_channels=out_channels if out_channels is not None else in_channels,
|
||||
out_channels=out_channels if out_channels is not None else in_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=temporal_eps if temporal_eps is not None else eps,
|
||||
groups=groups,
|
||||
dropout=dropout,
|
||||
non_linearity=non_linearity,
|
||||
output_scale_factor=output_scale_factor,
|
||||
kernel_size=kernel_size_3d,
|
||||
)
|
||||
|
||||
self.time_mixer = AlphaBlender(
|
||||
alpha=merge_factor,
|
||||
merge_strategy=merge_strategy,
|
||||
switch_spatial_to_temporal_mix=switch_spatial_to_temporal_mix,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
num_frames: int = 1,
|
||||
image_only_indicator: Optional[torch.Tensor] = None,
|
||||
scale: float = 1.0,
|
||||
):
|
||||
hidden_states = self.spatial_res_block(hidden_states, temb, scale=scale)
|
||||
|
||||
batch_frames, channels, height, width = hidden_states.shape
|
||||
batch_size = batch_frames // num_frames
|
||||
|
||||
hidden_states_mix = (
|
||||
hidden_states[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
|
||||
)
|
||||
hidden_states = (
|
||||
hidden_states[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
|
||||
)
|
||||
|
||||
if temb is not None:
|
||||
temb = temb.reshape(batch_size, num_frames, -1)
|
||||
|
||||
hidden_states = self.temporal_res_block(hidden_states, temb)
|
||||
hidden_states = self.time_mixer(
|
||||
x_spatial=hidden_states_mix,
|
||||
x_temporal=hidden_states,
|
||||
image_only_indicator=image_only_indicator,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(batch_frames, channels, height, width)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AlphaBlender(nn.Module):
    strategies = ["learned", "fixed", "learned_with_images"]

    def __init__(
        self,
        alpha: float,
        merge_strategy: str = "learned_with_images",
        switch_spatial_to_temporal_mix: bool = False,
    ):
        super().__init__()
        self.merge_strategy = merge_strategy
        self.switch_spatial_to_temporal_mix = switch_spatial_to_temporal_mix  # For TemporalVAE

        assert merge_strategy in self.strategies, f"merge_strategy needs to be in {self.strategies}"

        if self.merge_strategy == "fixed":
            self.register_buffer("mix_factor", torch.Tensor([alpha]))
        elif self.merge_strategy == "learned" or self.merge_strategy == "learned_with_images":
            self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha])))
        else:
            raise ValueError(f"Unknown merge strategy {self.merge_strategy}")

    def get_alpha(self, image_only_indicator: torch.Tensor, ndims: int) -> torch.Tensor:
        if self.merge_strategy == "fixed":
            alpha = self.mix_factor

        elif self.merge_strategy == "learned":
            alpha = torch.sigmoid(self.mix_factor)

        elif self.merge_strategy == "learned_with_images":
            assert (
                image_only_indicator is not None
            ), "Please provide image_only_indicator to use learned_with_images merge strategy"
            alpha = torch.where(
                image_only_indicator.bool(),
                torch.ones(1, 1, device=image_only_indicator.device),
                torch.sigmoid(self.mix_factor)[..., None],
            )

            # (batch, channel, frames, height, width)
            if ndims == 5:
                alpha = alpha[:, None, :, None, None]
            # (batch*frames, height*width, channels)
            elif ndims == 3:
                alpha = alpha.reshape(-1)[:, None, None]
            else:
                raise ValueError(f"Unexpected ndims {ndims}. Dimensions should be 3 or 5")

        else:
            raise NotImplementedError

        return alpha

    def forward(
        self,
        x_spatial: torch.Tensor,
        x_temporal: torch.Tensor,
        image_only_indicator: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        alpha = self.get_alpha(image_only_indicator, x_spatial.ndim)
        alpha = alpha.to(x_spatial.dtype)

        if self.switch_spatial_to_temporal_mix:
            alpha = 1.0 - alpha

        x = alpha * x_spatial + (1.0 - alpha) * x_temporal
        return x

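`AlphaBlender` mixes the spatial and temporal branches as `x = alpha * x_spatial + (1 - alpha) * x_temporal`, with `alpha = sigmoid(mix_factor)` when the factor is learned and `alpha` forced to 1 wherever `image_only_indicator` marks an image-only entry. A minimal sketch of that behaviour on 3D `(batch*frames, seq, channels)` tensors; the values are chosen only for illustration:

import torch

blender = AlphaBlender(alpha=0.5, merge_strategy="learned_with_images")

batch_size, num_frames, seq_len, channels = 1, 4, 16, 8
x_spatial = torch.randn(batch_size * num_frames, seq_len, channels)
x_temporal = torch.randn_like(x_spatial)

# 0 = video frame (blend spatial and temporal), 1 = image-only (keep the spatial branch)
image_only_indicator = torch.zeros(batch_size, num_frames)
image_only_indicator[:, 0] = 1

mixed = blender(x_spatial, x_temporal, image_only_indicator=image_only_indicator)
print(mixed.shape)  # same (batch*frames, seq, channels) layout as the inputs
# For the first frame alpha is forced to 1, so its output equals the spatial branch exactly:
assert torch.allclose(mixed[:1], x_spatial[:1])
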
@@ -19,8 +19,10 @@ from torch import nn

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from .attention import BasicTransformerBlock, TemporalBasicTransformerBlock
from .embeddings import TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from .resnet import AlphaBlender


@dataclass
@@ -195,3 +197,229 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
|
||||
return (output,)
|
||||
|
||||
return TransformerTemporalModelOutput(sample=output)
|
||||
|
||||
|
||||
# VideoBlock
|
||||
class TransformerSpatioTemporalModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
A Transformer model for video-like data.
|
||||
|
||||
Parameters:
|
||||
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
||||
in_channels (`int`, *optional*):
|
||||
The number of channels in the input and output (specify if the input is **continuous**).
|
||||
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
||||
attention_bias (`bool`, *optional*):
|
||||
Configure if the `TransformerBlock` attention should contain a bias parameter.
|
||||
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
|
||||
This is fixed during training since it is used to learn a number of position embeddings.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`):
|
||||
Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported
|
||||
activation functions.
|
||||
norm_elementwise_affine (`bool`, *optional*):
|
||||
Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization.
|
||||
double_self_attention (`bool`, *optional*):
|
||||
Configure if each `TransformerBlock` should contain two self-attention layers.
|
||||
positional_embeddings: (`str`, *optional*):
|
||||
The type of positional embeddings to apply to the sequence input before use.
|
||||
num_positional_embeddings: (`int`, *optional*):
|
||||
The maximum length of the sequence over which to apply positional embeddings.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 16,
|
||||
attention_head_dim: int = 88,
|
||||
in_channels: int = 320,
|
||||
out_channels: Optional[int] = None,
|
||||
num_layers: int = 1,
|
||||
dropout: float = 0.0,
|
||||
norm_num_groups: int = 32,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
norm_eps: float = 1e-5,
|
||||
merge_factor: float = 0.5,
|
||||
merge_strategy: str = "learned_with_images",
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
self.inner_dim = inner_dim
|
||||
|
||||
linear_cls = nn.Linear
|
||||
|
||||
# 2. Define input layers
|
||||
self.in_channels = in_channels
|
||||
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps)
|
||||
self.proj_in = linear_cls(in_channels, inner_dim)
|
||||
|
||||
# 3. Define transformers blocks
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
dropout=dropout,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_eps=norm_eps,
|
||||
)
|
||||
for d in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
time_mix_inner_dim = inner_dim
|
||||
self.temporal_transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
TemporalBasicTransformerBlock(
|
||||
inner_dim,
|
||||
time_mix_inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
dropout=dropout,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_eps=norm_eps,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
time_embed_dim = in_channels * 4
|
||||
self.time_pos_embed = TimestepEmbedding(in_channels, time_embed_dim, out_dim=in_channels)
|
||||
self.time_proj = Timesteps(in_channels, True, 0)
|
||||
self.time_mixer = AlphaBlender(alpha=merge_factor, merge_strategy=merge_strategy)
|
||||
|
||||
# 4. Define output layers
|
||||
self.out_channels = in_channels if out_channels is None else out_channels
|
||||
# TODO: should use out_channels for continuous projections
|
||||
self.proj_out = linear_cls(inner_dim, in_channels)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
num_frames: int,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
class_labels: Optional[torch.LongTensor] = None,
|
||||
cross_attention_kwargs: Dict[str, Any] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
image_only_indicator: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
|
||||
Input hidden_states.
|
||||
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
|
||||
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
||||
self-attention.
|
||||
timestep ( `torch.LongTensor`, *optional*):
|
||||
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
|
||||
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
|
||||
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
|
||||
`AdaLayerZeroNorm`.
|
||||
num_frames (`int`, *optional*, defaults to 1):
|
||||
The number of frames to be processed per batch. This is used to reshape the hidden states.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
|
||||
If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
|
||||
returned, otherwise a `tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
assert (
|
||||
encoder_hidden_states.ndim == 3
|
||||
), f"n dims of spatial context should be 3 but are {encoder_hidden_states.ndim}"
|
||||
|
||||
# 1. Input
|
||||
batch_frames, channel, height, width = hidden_states.shape
|
||||
batch_size = batch_frames // num_frames
|
||||
|
||||
time_context = encoder_hidden_states
|
||||
time_context_first_timestep = time_context[::num_frames]
|
||||
time_context = time_context_first_timestep.repeat(height * width, 1, 1)
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
inner_dim = hidden_states.shape[1]
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_frames, height * width, inner_dim)
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
|
||||
num_frames_emb = torch.arange(num_frames, device=hidden_states.device)
|
||||
num_frames_emb = num_frames_emb.repeat(batch_size, 1)
|
||||
num_frames_emb = num_frames_emb.reshape(-1)
|
||||
t_emb = self.time_proj(num_frames_emb)
|
||||
|
||||
# `Timesteps` does not contain any weights and will always return f32 tensors
|
||||
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||
# there might be better ways to encapsulate this.
|
||||
t_emb = t_emb.to(dtype=hidden_states.dtype)
|
||||
|
||||
emb = self.time_pos_embed(t_emb)
|
||||
emb = emb[:, None, :]
|
||||
|
||||
# 2. Blocks
|
||||
for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
block,
|
||||
hidden_states,
|
||||
None,
|
||||
encoder_hidden_states,
|
||||
None,
|
||||
timestep,
|
||||
cross_attention_kwargs,
|
||||
class_labels,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=None,
|
||||
timestep=timestep,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
class_labels=class_labels,
|
||||
)
|
||||
|
||||
hidden_states_mix = hidden_states
|
||||
hidden_states_mix = hidden_states_mix + emb
|
||||
|
||||
hidden_states_mix = temporal_block(
|
||||
hidden_states_mix,
|
||||
num_frames=num_frames,
|
||||
encoder_hidden_states=time_context,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
)
|
||||
hidden_states = self.time_mixer(
|
||||
x_spatial=hidden_states,
|
||||
x_temporal=hidden_states_mix,
|
||||
image_only_indicator=image_only_indicator,
|
||||
)
|
||||
|
||||
# 3. Output
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = hidden_states.reshape(batch_frames, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
output = hidden_states + residual
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return TransformerTemporalModelOutput(sample=output)
|
||||
|
||||
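In `TransformerSpatioTemporalModel.forward` above, the cross-attention context for the temporal blocks is taken from the first frame of every video (`encoder_hidden_states[::num_frames]`) and then repeated once per spatial position. A small sketch of that indexing on dummy tensors; the shapes are illustrative only:

import torch

batch_size, num_frames, seq_len, dim = 2, 5, 1, 1024
height, width = 8, 8

# encoder states arrive stacked per frame: (batch*frames, seq, dim)
encoder_hidden_states = torch.randn(batch_size * num_frames, seq_len, dim)

# keep only the entry belonging to the first frame of each video ...
time_context_first_timestep = encoder_hidden_states[::num_frames]   # (batch, seq, dim)

# ... and repeat it for every spatial location handled by the temporal blocks
time_context = time_context_first_timestep.repeat(height * width, 1, 1)

print(time_context.shape)  # (batch * height * width, seq, dim)
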
File diff suppressed because it is too large
859  src/diffusers/models/unet_spatio_temporal_condition.py  Normal file
@@ -0,0 +1,859 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..loaders import UNet2DConditionLoadersMixin
|
||||
from ..utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
BaseOutput,
|
||||
logging,
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from .activations import get_activation
|
||||
from .attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
CROSS_ATTENTION_PROCESSORS,
|
||||
AttentionProcessor,
|
||||
AttnAddedKVProcessor,
|
||||
AttnProcessor,
|
||||
)
|
||||
from .embeddings import TimestepEmbedding, Timesteps
|
||||
from .modeling_utils import ModelMixin
|
||||
from .unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block, get_up_block
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@dataclass
|
||||
class UNetSpatioTemporalConditionOutput(BaseOutput):
|
||||
"""
|
||||
The output of [`UNetSpatioTemporalConditionModel`].
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||
The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
|
||||
"""
|
||||
|
||||
sample: torch.FloatTensor = None
|
||||
|
||||
|
||||
class UNetSpatioTemporalConditionModel(
|
||||
ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin
|
||||
):
|
||||
r"""
|
||||
A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
|
||||
shaped output.
|
||||
|
||||
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
|
||||
for all models (such as downloading or saving).
|
||||
|
||||
Parameters:
|
||||
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
|
||||
Height and width of input/output sample.
|
||||
in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
|
||||
out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
|
||||
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
|
||||
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
|
||||
Whether to flip the sin to cos in the time embedding.
|
||||
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
||||
The tuple of downsample blocks to use.
|
||||
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
|
||||
Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
|
||||
`UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
|
||||
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
|
||||
The tuple of upsample blocks to use.
|
||||
only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
|
||||
Whether to include self-attention in the basic transformer blocks, see
|
||||
[`~models.attention.BasicTransformerBlock`].
|
||||
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
|
||||
The tuple of output channels for each block.
|
||||
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
|
||||
downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
|
||||
mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
||||
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
|
||||
If `None`, normalization and activation layers is skipped in post-processing.
|
||||
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
|
||||
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
|
||||
The dimension of the cross attention features.
|
||||
transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
|
||||
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
||||
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
||||
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
||||
reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
|
||||
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
|
||||
blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
|
||||
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
||||
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
||||
encoder_hid_dim (`int`, *optional*, defaults to None):
|
||||
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
|
||||
dimension to `cross_attention_dim`.
|
||||
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
|
||||
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
|
||||
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
|
||||
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
|
||||
num_attention_heads (`int`, *optional*):
|
||||
The number of attention heads. If not defined, defaults to `attention_head_dim`
|
||||
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
|
||||
for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
|
||||
class_embed_type (`str`, *optional*, defaults to `None`):
|
||||
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
|
||||
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
|
||||
addition_embed_type (`str`, *optional*, defaults to `None`):
|
||||
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
|
||||
"text". "text" will use the `TextTimeEmbedding` layer.
|
||||
addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
|
||||
Dimension for the timestep embeddings.
|
||||
num_class_embeds (`int`, *optional*, defaults to `None`):
|
||||
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
||||
class conditioning with `class_embed_type` equal to `None`.
|
||||
time_embedding_type (`str`, *optional*, defaults to `positional`):
|
||||
The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
|
||||
time_embedding_dim (`int`, *optional*, defaults to `None`):
|
||||
An optional override for the dimension of the projected time embedding.
|
||||
time_embedding_act_fn (`str`, *optional*, defaults to `None`):
|
||||
Optional activation function to use only once on the time embeddings before they are passed to the rest of
|
||||
the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
|
||||
timestep_post_act (`str`, *optional*, defaults to `None`):
|
||||
The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
|
||||
time_cond_proj_dim (`int`, *optional*, defaults to `None`):
|
||||
The dimension of `cond_proj` layer in the timestep embedding.
|
||||
conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
projection_class_embeddings_input_dim (`int`, *optional*):
The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
`class_embed_type="projection"`.
|
||||
class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
|
||||
embeddings with the class embeddings.
|
||||
mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
|
||||
Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
|
||||
`only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
|
||||
`only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
|
||||
otherwise.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
sample_size: Optional[int] = None,
|
||||
in_channels: int = 8,
|
||||
out_channels: int = 4,
|
||||
center_input_sample: bool = False,
|
||||
flip_sin_to_cos: bool = True,
|
||||
down_block_types: Tuple[str] = (
|
||||
"CrossAttnDownBlockSpatioTemporal",
|
||||
"CrossAttnDownBlockSpatioTemporal",
|
||||
"CrossAttnDownBlockSpatioTemporal",
|
||||
"DownBlockSpatioTemporal",
|
||||
),
|
||||
mid_block_type: Optional[str] = "UNetMidBlockSpatioTemporal",
|
||||
up_block_types: Tuple[str] = (
|
||||
"UpBlockSpatioTemporal",
|
||||
"CrossAttnUpBlockSpatioTemporal",
|
||||
"CrossAttnUpBlockSpatioTemporal",
|
||||
"CrossAttnUpBlockSpatioTemporal",
|
||||
),
|
||||
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
||||
projection_class_embeddings_input_dim: int = 768,
|
||||
addition_time_embed_dim: int = 256,
|
||||
layers_per_block: Union[int, Tuple[int]] = 2,
|
||||
mid_block_scale_factor: float = 1,
|
||||
dropout: float = 0.0,
|
||||
act_fn: str = "silu",
|
||||
norm_num_groups: Optional[int] = 32,
|
||||
norm_eps: float = 1e-5,
|
||||
cross_attention_dim: Union[int, Tuple[int]] = 1024,
|
||||
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
|
||||
attention_head_dim: Union[int, Tuple[int]] = 8,
|
||||
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
|
||||
time_embedding_dim: Optional[int] = None,
|
||||
conv_in_kernel: int = 3,
|
||||
conv_out_kernel: int = 3,
|
||||
kernel_size_3d: Optional[torch.FloatTensor] = (3, 1, 1),
|
||||
merge_factor: float = 0.5,
|
||||
merge_strategy: str = "learned_with_images",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.sample_size = sample_size
|
||||
|
||||
if num_attention_heads is not None:
|
||||
raise ValueError(
|
||||
"At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
|
||||
)
|
||||
|
||||
# If `num_attention_heads` is not defined (which is the case for most models)
|
||||
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
||||
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
||||
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
||||
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
||||
# which is why we correct for the naming here.
|
||||
num_attention_heads = num_attention_heads or attention_head_dim
|
||||
|
||||
# Check inputs
|
||||
if len(down_block_types) != len(up_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
|
||||
)
|
||||
|
||||
if len(block_out_channels) != len(down_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(
|
||||
down_block_types
|
||||
):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(
|
||||
down_block_types
|
||||
):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(
|
||||
down_block_types
|
||||
):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
if not isinstance(layers_per_block, int) and len(layers_per_block) != len(
|
||||
down_block_types
|
||||
):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
# input
|
||||
conv_in_padding = (conv_in_kernel - 1) // 2
|
||||
self.conv_in = nn.Conv2d(
|
||||
in_channels,
|
||||
block_out_channels[0],
|
||||
kernel_size=conv_in_kernel,
|
||||
padding=conv_in_padding,
|
||||
)
|
||||
|
||||
# time
|
||||
time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
|
||||
|
||||
self.time_proj = Timesteps(
|
||||
block_out_channels[0], flip_sin_to_cos, downscale_freq_shift=0
|
||||
)
|
||||
timestep_input_dim = block_out_channels[0]
|
||||
|
||||
self.time_embedding = TimestepEmbedding(
|
||||
timestep_input_dim,
|
||||
time_embed_dim,
|
||||
act_fn=act_fn,
|
||||
)
|
||||
|
||||
self.add_time_proj = Timesteps(
|
||||
addition_time_embed_dim, flip_sin_to_cos, downscale_freq_shift=0
|
||||
)
|
||||
self.add_embedding = TimestepEmbedding(
|
||||
projection_class_embeddings_input_dim, time_embed_dim
|
||||
)
|
||||
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
|
||||
if isinstance(num_attention_heads, int):
|
||||
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
||||
|
||||
if isinstance(attention_head_dim, int):
|
||||
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
||||
|
||||
if isinstance(cross_attention_dim, int):
|
||||
cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
|
||||
|
||||
if isinstance(layers_per_block, int):
|
||||
layers_per_block = [layers_per_block] * len(down_block_types)
|
||||
|
||||
if isinstance(transformer_layers_per_block, int):
|
||||
transformer_layers_per_block = [transformer_layers_per_block] * len(
|
||||
down_block_types
|
||||
)
|
||||
|
||||
blocks_time_embed_dim = time_embed_dim
|
||||
|
||||
# down
|
||||
output_channel = block_out_channels[0]
|
||||
for i, down_block_type in enumerate(down_block_types):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i]
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
down_block = get_down_block(
|
||||
down_block_type,
|
||||
num_layers=layers_per_block[i],
|
||||
transformer_layers_per_block=transformer_layers_per_block[i],
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
temb_channels=blocks_time_embed_dim,
|
||||
add_downsample=not is_final_block,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
cross_attention_dim=cross_attention_dim[i],
|
||||
num_attention_heads=num_attention_heads[i],
|
||||
downsample_padding=1,
|
||||
dropout=dropout,
|
||||
kernel_size_3d=kernel_size_3d,
|
||||
merge_factor=merge_factor,
|
||||
merge_strategy=merge_strategy,
|
||||
)
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
# mid
|
||||
self.mid_block = UNetMidBlockSpatioTemporal(
|
||||
block_out_channels[-1],
|
||||
temb_channels=blocks_time_embed_dim,
|
||||
transformer_layers_per_block=transformer_layers_per_block[-1],
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
cross_attention_dim=cross_attention_dim[-1],
|
||||
num_attention_heads=num_attention_heads[-1],
|
||||
resnet_groups=norm_num_groups,
|
||||
dropout=dropout,
|
||||
kernel_size_3d=kernel_size_3d,
|
||||
merge_factor=merge_factor,
|
||||
merge_strategy=merge_strategy,
|
||||
)
|
||||
|
||||
# count how many layers upsample the images
|
||||
self.num_upsamplers = 0
|
||||
|
||||
# up
|
||||
reversed_block_out_channels = list(reversed(block_out_channels))
|
||||
reversed_num_attention_heads = list(reversed(num_attention_heads))
|
||||
reversed_layers_per_block = list(reversed(layers_per_block))
|
||||
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
|
||||
reversed_transformer_layers_per_block = list(
|
||||
reversed(transformer_layers_per_block)
|
||||
)
|
||||
|
||||
output_channel = reversed_block_out_channels[0]
|
||||
for i, up_block_type in enumerate(up_block_types):
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
prev_output_channel = output_channel
|
||||
output_channel = reversed_block_out_channels[i]
|
||||
input_channel = reversed_block_out_channels[
|
||||
min(i + 1, len(block_out_channels) - 1)
|
||||
]
|
||||
|
||||
# add upsample block for all BUT final layer
|
||||
if not is_final_block:
|
||||
add_upsample = True
|
||||
self.num_upsamplers += 1
|
||||
else:
|
||||
add_upsample = False
|
||||
|
||||
up_block = get_up_block(
|
||||
up_block_type,
|
||||
num_layers=reversed_layers_per_block[i] + 1,
|
||||
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
prev_output_channel=prev_output_channel,
|
||||
temb_channels=blocks_time_embed_dim,
|
||||
add_upsample=add_upsample,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
resolution_idx=i,
|
||||
resnet_groups=norm_num_groups,
|
||||
cross_attention_dim=reversed_cross_attention_dim[i],
|
||||
num_attention_heads=reversed_num_attention_heads[i],
|
||||
dropout=dropout,
|
||||
kernel_size_3d=kernel_size_3d,
|
||||
merge_factor=merge_factor,
|
||||
merge_strategy=merge_strategy,
|
||||
)
|
||||
self.up_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
|
||||
# out
|
||||
self.conv_norm_out = nn.GroupNorm(
|
||||
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
|
||||
)
|
||||
self.conv_act = get_activation(act_fn)
|
||||
|
||||
conv_out_padding = (conv_out_kernel - 1) // 2
|
||||
self.conv_out = nn.Conv2d(
|
||||
block_out_channels[0],
|
||||
out_channels,
|
||||
kernel_size=conv_out_kernel,
|
||||
padding=conv_out_padding,
|
||||
)
|
||||
|
||||
@property
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model,
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(
|
||||
name: str,
|
||||
module: torch.nn.Module,
|
||||
processors: Dict[str, AttentionProcessor],
|
||||
):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor(
|
||||
return_deprecated_lora=True
|
||||
)
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
def set_attn_processor(
|
||||
self,
|
||||
processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]],
|
||||
_remove_lora=False,
|
||||
):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor, _remove_lora=_remove_lora)
|
||||
else:
|
||||
module.set_processor(
|
||||
processor.pop(f"{name}.processor"), _remove_lora=_remove_lora
|
||||
)
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
def set_default_attn_processor(self):
|
||||
"""
|
||||
Disables custom attention processors and sets the default attention implementation.
|
||||
"""
|
||||
if all(
|
||||
proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS
|
||||
for proc in self.attn_processors.values()
|
||||
):
|
||||
processor = AttnAddedKVProcessor()
|
||||
elif all(
|
||||
proc.__class__ in CROSS_ATTENTION_PROCESSORS
|
||||
for proc in self.attn_processors.values()
|
||||
):
|
||||
processor = AttnProcessor()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
||||
)
|
||||
|
||||
self.set_attn_processor(processor, _remove_lora=True)
|
||||
|
||||
def set_attention_slice(self, slice_size):
|
||||
r"""
|
||||
Enable sliced attention computation.
|
||||
|
||||
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
|
||||
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
|
||||
|
||||
Args:
|
||||
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
||||
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
|
||||
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
|
||||
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
||||
must be a multiple of `slice_size`.
|
||||
"""
|
||||
sliceable_head_dims = []
|
||||
|
||||
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
|
||||
if hasattr(module, "set_attention_slice"):
|
||||
sliceable_head_dims.append(module.sliceable_head_dim)
|
||||
|
||||
for child in module.children():
|
||||
fn_recursive_retrieve_sliceable_dims(child)
|
||||
|
||||
# retrieve number of attention layers
|
||||
for module in self.children():
|
||||
fn_recursive_retrieve_sliceable_dims(module)
|
||||
|
||||
num_sliceable_layers = len(sliceable_head_dims)
|
||||
|
||||
if slice_size == "auto":
|
||||
# half the attention head size is usually a good trade-off between
|
||||
# speed and memory
|
||||
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
||||
elif slice_size == "max":
|
||||
# make smallest slice possible
|
||||
slice_size = num_sliceable_layers * [1]
|
||||
|
||||
slice_size = (
|
||||
num_sliceable_layers * [slice_size]
|
||||
if not isinstance(slice_size, list)
|
||||
else slice_size
|
||||
)
|
||||
|
||||
if len(slice_size) != len(sliceable_head_dims):
|
||||
raise ValueError(
|
||||
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
||||
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
||||
)
|
||||
|
||||
for i in range(len(slice_size)):
|
||||
size = slice_size[i]
|
||||
dim = sliceable_head_dims[i]
|
||||
if size is not None and size > dim:
|
||||
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
||||
|
||||
# Recursively walk through all the children.
|
||||
# Any children which exposes the set_attention_slice method
|
||||
# gets the message
|
||||
def fn_recursive_set_attention_slice(
|
||||
module: torch.nn.Module, slice_size: List[int]
|
||||
):
|
||||
if hasattr(module, "set_attention_slice"):
|
||||
module.set_attention_slice(slice_size.pop())
|
||||
|
||||
for child in module.children():
|
||||
fn_recursive_set_attention_slice(child, slice_size)
|
||||
|
||||
reversed_slice_size = list(reversed(slice_size))
|
||||
for module in self.children():
|
||||
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
||||
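# Editor's note (added, not in the original source): a minimal usage sketch of the slicing API
# documented above, assuming a constructed `model` instance of this class:
#
#     model.set_attention_slice("auto")  # halve each sliceable head dimension
#     model.set_attention_slice("max")   # slice size 1 everywhere, lowest memory
#     model.set_attention_slice(2)       # explicit slice size; `attention_head_dim` must be a multiple of it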
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def enable_freeu(self, s1, s2, b1, b2):
|
||||
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
|
||||
|
||||
The suffixes after the scaling factors represent the stage blocks where they are being applied.
|
||||
|
||||
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
|
||||
are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
|
||||
|
||||
Args:
|
||||
s1 (`float`):
|
||||
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
|
||||
mitigate the "oversmoothing effect" in the enhanced denoising process.
|
||||
s2 (`float`):
|
||||
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
|
||||
mitigate the "oversmoothing effect" in the enhanced denoising process.
|
||||
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
|
||||
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
|
||||
"""
|
||||
for i, upsample_block in enumerate(self.up_blocks):
|
||||
setattr(upsample_block, "s1", s1)
|
||||
setattr(upsample_block, "s2", s2)
|
||||
setattr(upsample_block, "b1", b1)
|
||||
setattr(upsample_block, "b2", b2)
|
||||
|
||||
def disable_freeu(self):
|
||||
"""Disables the FreeU mechanism."""
|
||||
freeu_keys = {"s1", "s2", "b1", "b2"}
|
||||
for i, upsample_block in enumerate(self.up_blocks):
|
||||
for k in freeu_keys:
|
||||
if (
|
||||
hasattr(upsample_block, k)
|
||||
or getattr(upsample_block, k, None) is not None
|
||||
):
|
||||
setattr(upsample_block, k, None)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
added_time_ids: torch.Tensor,
|
||||
image_only_indicator: Optional[torch.Tensor] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[UNetSpatioTemporalConditionOutput, Tuple]:
|
||||
r"""
|
||||
The [`UNet2DConditionModel`] forward method.
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor`):
|
||||
The noisy input tensor with the following shape `(batch, channel, height, width)`.
|
||||
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
|
||||
encoder_hidden_states (`torch.FloatTensor`):
|
||||
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
||||
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
||||
negative values to the attention scores corresponding to "discard" tokens.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
added_cond_kwargs: (`dict`):
|
||||
A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
|
||||
are passed along to the UNet blocks.
|
||||
encoder_attention_mask (`torch.Tensor`):
|
||||
A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
|
||||
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
|
||||
which adds large negative values to the attention scores corresponding to "discard" tokens.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
||||
tuple.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
|
||||
Returns:
|
||||
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
||||
If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
|
||||
a `tuple` is returned where the first element is the sample tensor.
|
||||
"""
|
||||
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
||||
# The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
|
||||
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
||||
# on the fly if necessary.
|
||||
default_overall_up_factor = 2**self.num_upsamplers
|
||||
|
||||
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
||||
forward_upsample_size = False
|
||||
upsample_size = None
|
||||
|
||||
for dim in sample.shape[-2:]:
|
||||
if dim % default_overall_up_factor != 0:
|
||||
# Forward upsample size to force interpolation output size.
|
||||
forward_upsample_size = True
|
||||
break
|
||||
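# Editor's note (added, based on the default config above): with the default four up blocks,
# `self.num_upsamplers` is 3, so `default_overall_up_factor` is 2**3 = 8. Latent inputs whose
# height and width are multiples of 8 (e.g. 64x64) skip the explicit `upsample_size` forwarding;
# any other size only flips `forward_upsample_size` to True.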
|
||||
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension
|
||||
# expects mask of shape:
|
||||
# [batch, key_tokens]
|
||||
# adds singleton query_tokens dimension:
|
||||
# [batch, 1, key_tokens]
|
||||
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
||||
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
||||
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
||||
if attention_mask is not None:
|
||||
# assume that mask is expressed as:
|
||||
# (1 = keep, 0 = discard)
|
||||
# convert mask into a bias that can be added to attention scores:
|
||||
# (keep = +0, discard = -10000.0)
|
||||
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
||||
attention_mask = attention_mask.unsqueeze(1)
|
||||
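# Editor's note (added illustration): a mask row [1, 1, 0] becomes the additive bias
# [0.0, 0.0, -10000.0] with shape [batch, 1, key_tokens], which broadcasts over the
# query dimension of the attention scores computed by the processors.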
|
||||
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
||||
if encoder_attention_mask is not None:
|
||||
encoder_attention_mask = (
|
||||
1 - encoder_attention_mask.to(sample.dtype)
|
||||
) * -10000.0
|
||||
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
||||
|
||||
# 0. center input if necessary
|
||||
if self.config.center_input_sample:
|
||||
sample = 2 * sample - 1.0
|
||||
|
||||
# 1. time
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
||||
# This would be a good case for the `match` statement (Python 3.10+)
|
||||
is_mps = sample.device.type == "mps"
|
||||
if isinstance(timestep, float):
|
||||
dtype = torch.float32 if is_mps else torch.float64
|
||||
else:
|
||||
dtype = torch.int32 if is_mps else torch.int64
|
||||
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
||||
elif len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
batch_size, num_frames = sample.shape[:2]
|
||||
timesteps = timesteps.expand(batch_size)
|
||||
|
||||
t_emb = self.time_proj(timesteps)
|
||||
|
||||
# `Timesteps` does not contain any weights and will always return f32 tensors
|
||||
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||
# there might be better ways to encapsulate this.
|
||||
t_emb = t_emb.to(dtype=sample.dtype)
|
||||
|
||||
emb = self.time_embedding(t_emb)
|
||||
|
||||
time_embeds = self.add_time_proj(added_time_ids.flatten())
|
||||
time_embeds = time_embeds.reshape((batch_size, -1))
|
||||
time_embeds = time_embeds.to(emb.dtype)
|
||||
aug_emb = self.add_embedding(time_embeds)
|
||||
emb = emb + aug_emb
|
||||
|
||||
# Flatten the batch and frames dimensions
|
||||
# sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
|
||||
sample = sample.flatten(0, 1)
|
||||
# Repeat the embeddings num_video_frames times
|
||||
# emb: [batch, channels] -> [batch * frames, channels]
|
||||
emb = emb.repeat_interleave(num_frames, dim=0)
|
||||
# encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
|
||||
encoder_hidden_states = encoder_hidden_states.repeat_interleave(
|
||||
num_frames, dim=0
|
||||
)
|
||||
|
||||
# 2. pre-process
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
# 3. down
|
||||
lora_scale = (
|
||||
cross_attention_kwargs.get("scale", 1.0)
|
||||
if cross_attention_kwargs is not None
|
||||
else 1.0
|
||||
)
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
|
||||
image_only_indicator = torch.zeros(
|
||||
batch_size, num_frames, dtype=sample.dtype, device=sample.device
|
||||
)
|
||||
|
||||
down_block_res_samples = (sample,)
|
||||
for downsample_block in self.down_blocks:
|
||||
if (
|
||||
hasattr(downsample_block, "has_cross_attention")
|
||||
and downsample_block.has_cross_attention
|
||||
):
|
||||
sample, res_samples = downsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
num_video_frames=num_frames,
|
||||
image_only_indicator=image_only_indicator,
|
||||
)
|
||||
else:
|
||||
sample, res_samples = downsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
scale=lora_scale,
|
||||
num_video_frames=num_frames,
|
||||
image_only_indicator=image_only_indicator,
|
||||
)
|
||||
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
# 4. mid
|
||||
if self.mid_block is not None:
|
||||
if (
|
||||
hasattr(self.mid_block, "has_cross_attention")
|
||||
and self.mid_block.has_cross_attention
|
||||
):
|
||||
sample = self.mid_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
num_video_frames=num_frames,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
image_only_indicator=image_only_indicator,
|
||||
)
|
||||
else:
|
||||
sample = self.mid_block(
|
||||
sample,
|
||||
temb=emb,
|
||||
num_video_frames=num_frames,
|
||||
image_only_indicator=image_only_indicator,
|
||||
)
|
||||
|
||||
# 5. up
|
||||
for i, upsample_block in enumerate(self.up_blocks):
|
||||
is_final_block = i == len(self.up_blocks) - 1
|
||||
|
||||
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
||||
down_block_res_samples = down_block_res_samples[
|
||||
: -len(upsample_block.resnets)
|
||||
]
|
||||
|
||||
# if we have not reached the final block and need to forward the
|
||||
# upsample size, we do it here
|
||||
if not is_final_block and forward_upsample_size:
|
||||
upsample_size = down_block_res_samples[-1].shape[2:]
|
||||
|
||||
if (
|
||||
hasattr(upsample_block, "has_cross_attention")
|
||||
and upsample_block.has_cross_attention
|
||||
):
|
||||
sample = upsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
res_hidden_states_tuple=res_samples,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
upsample_size=upsample_size,
|
||||
attention_mask=attention_mask,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
num_video_frames=num_frames,
|
||||
image_only_indicator=image_only_indicator,
|
||||
)
|
||||
else:
|
||||
sample = upsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
res_hidden_states_tuple=res_samples,
|
||||
upsample_size=upsample_size,
|
||||
scale=lora_scale,
|
||||
num_video_frames=num_frames,
|
||||
image_only_indicator=image_only_indicator,
|
||||
)
|
||||
|
||||
# 6. post-process
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample)
|
||||
|
||||
# 7. Reshape back to original shape
|
||||
sample = sample.reshape(batch_size, num_frames, *sample.shape[1:])
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (sample,)
|
||||
|
||||
return UNetSpatioTemporalConditionOutput(sample=sample)
|
||||
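A minimal usage sketch of the spatio-temporal UNet defined above (added for illustration; the
`diffusers` import path and the tensor shapes are assumptions derived from the default config shown
in this file, not part of the diff):

import torch

from diffusers import UNetSpatioTemporalConditionModel  # assumed export path

# Defaults above: in_channels=8, out_channels=4, cross_attention_dim=1024,
# addition_time_embed_dim=256 and projection_class_embeddings_input_dim=768 (= 3 ids * 256).
unet = UNetSpatioTemporalConditionModel()

batch, frames = 1, 14
sample = torch.randn(batch, frames, 8, 64, 64)          # (batch, num_frames, channels, height, width)
encoder_hidden_states = torch.randn(batch, 1, 1024)     # one CLIP image token per video
added_time_ids = torch.tensor([[6.0, 127.0, 0.02]])     # fps_id, motion_bucket_id, cond_aug

with torch.no_grad():
    out = unet(
        sample,
        timestep=10,
        encoder_hidden_states=encoder_hidden_states,
        added_time_ids=added_time_ids,
    ).sample

print(out.shape)  # expected: torch.Size([1, 14, 4, 64, 64])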
@@ -22,7 +22,12 @@ from ..utils import BaseOutput, is_torch_version
|
||||
from ..utils.torch_utils import randn_tensor
|
||||
from .activations import get_activation
|
||||
from .attention_processor import SpatialNorm
|
||||
from .unet_2d_blocks import AutoencoderTinyBlock, UNetMidBlock2D, get_down_block, get_up_block
|
||||
from .unet_2d_blocks import (
|
||||
AutoencoderTinyBlock,
|
||||
UNetMidBlock2D,
|
||||
get_down_block,
|
||||
get_up_block,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -122,11 +127,15 @@ class Encoder(nn.Module):
|
||||
)
|
||||
|
||||
# out
|
||||
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
|
||||
self.conv_norm_out = nn.GroupNorm(
|
||||
num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6
|
||||
)
|
||||
self.conv_act = nn.SiLU()
|
||||
|
||||
conv_out_channels = 2 * out_channels if double_z else out_channels
|
||||
self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
|
||||
self.conv_out = nn.Conv2d(
|
||||
block_out_channels[-1], conv_out_channels, 3, padding=1
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@@ -155,9 +164,13 @@ class Encoder(nn.Module):
|
||||
)
|
||||
else:
|
||||
for down_block in self.down_blocks:
|
||||
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
|
||||
sample = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(down_block), sample
|
||||
)
|
||||
# middle
|
||||
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
|
||||
sample = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.mid_block), sample
|
||||
)
|
||||
|
||||
else:
|
||||
# down
|
||||
@@ -267,14 +280,18 @@ class Decoder(nn.Module):
|
||||
if norm_type == "spatial":
|
||||
self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
|
||||
else:
|
||||
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
|
||||
self.conv_norm_out = nn.GroupNorm(
|
||||
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6
|
||||
)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self, sample: torch.FloatTensor, latent_embeds: Optional[torch.FloatTensor] = None
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
latent_embeds: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
r"""The forward method of the `Decoder` class."""
|
||||
|
||||
@@ -292,14 +309,20 @@ class Decoder(nn.Module):
|
||||
if is_torch_version(">=", "1.11.0"):
|
||||
# middle
|
||||
sample = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.mid_block), sample, latent_embeds, use_reentrant=False
|
||||
create_custom_forward(self.mid_block),
|
||||
sample,
|
||||
latent_embeds,
|
||||
use_reentrant=False,
|
||||
)
|
||||
sample = sample.to(upscale_dtype)
|
||||
|
||||
# up
|
||||
for up_block in self.up_blocks:
|
||||
sample = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(up_block), sample, latent_embeds, use_reentrant=False
|
||||
create_custom_forward(up_block),
|
||||
sample,
|
||||
latent_embeds,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
# middle
|
||||
@@ -310,7 +333,9 @@ class Decoder(nn.Module):
|
||||
|
||||
# up
|
||||
for up_block in self.up_blocks:
|
||||
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
|
||||
sample = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(up_block), sample, latent_embeds
|
||||
)
|
||||
else:
|
||||
# middle
|
||||
sample = self.mid_block(sample, latent_embeds)
|
||||
@@ -350,7 +375,9 @@ class UpSample(nn.Module):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
|
||||
self.deconv = nn.ConvTranspose2d(
|
||||
in_channels, out_channels, kernel_size=4, stride=2, padding=1
|
||||
)
|
||||
|
||||
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
|
||||
r"""The forward method of the `UpSample` class."""
|
||||
@@ -394,9 +421,13 @@ class MaskConditionEncoder(nn.Module):
|
||||
for l in range(len(out_channels)):
|
||||
out_ch_ = out_channels[l]
|
||||
if l == 0 or l == 1:
|
||||
layers.append(nn.Conv2d(in_ch_, out_ch_, kernel_size=3, stride=1, padding=1))
|
||||
layers.append(
|
||||
nn.Conv2d(in_ch_, out_ch_, kernel_size=3, stride=1, padding=1)
|
||||
)
|
||||
else:
|
||||
layers.append(nn.Conv2d(in_ch_, out_ch_, kernel_size=4, stride=2, padding=1))
|
||||
layers.append(
|
||||
nn.Conv2d(in_ch_, out_ch_, kernel_size=4, stride=2, padding=1)
|
||||
)
|
||||
in_ch_ = out_ch_
|
||||
|
||||
self.layers = nn.Sequential(*layers)
|
||||
@@ -511,7 +542,9 @@ class MaskConditionDecoder(nn.Module):
|
||||
if norm_type == "spatial":
|
||||
self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
|
||||
else:
|
||||
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
|
||||
self.conv_norm_out = nn.GroupNorm(
|
||||
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6
|
||||
)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
|
||||
|
||||
@@ -540,7 +573,10 @@ class MaskConditionDecoder(nn.Module):
|
||||
if is_torch_version(">=", "1.11.0"):
|
||||
# middle
|
||||
sample = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.mid_block), sample, latent_embeds, use_reentrant=False
|
||||
create_custom_forward(self.mid_block),
|
||||
sample,
|
||||
latent_embeds,
|
||||
use_reentrant=False,
|
||||
)
|
||||
sample = sample.to(upscale_dtype)
|
||||
|
||||
@@ -548,17 +584,25 @@ class MaskConditionDecoder(nn.Module):
|
||||
if image is not None and mask is not None:
|
||||
masked_image = (1 - mask) * image
|
||||
im_x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.condition_encoder), masked_image, mask, use_reentrant=False
|
||||
create_custom_forward(self.condition_encoder),
|
||||
masked_image,
|
||||
mask,
|
||||
use_reentrant=False,
|
||||
)
|
||||
|
||||
# up
|
||||
for up_block in self.up_blocks:
|
||||
if image is not None and mask is not None:
|
||||
sample_ = im_x[str(tuple(sample.shape))]
|
||||
mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
|
||||
mask_ = nn.functional.interpolate(
|
||||
mask, size=sample.shape[-2:], mode="nearest"
|
||||
)
|
||||
sample = sample * mask_ + sample_ * (1 - mask_)
|
||||
sample = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(up_block), sample, latent_embeds, use_reentrant=False
|
||||
create_custom_forward(up_block),
|
||||
sample,
|
||||
latent_embeds,
|
||||
use_reentrant=False,
|
||||
)
|
||||
if image is not None and mask is not None:
|
||||
sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
|
||||
@@ -573,16 +617,22 @@ class MaskConditionDecoder(nn.Module):
|
||||
if image is not None and mask is not None:
|
||||
masked_image = (1 - mask) * image
|
||||
im_x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.condition_encoder), masked_image, mask
|
||||
create_custom_forward(self.condition_encoder),
|
||||
masked_image,
|
||||
mask,
|
||||
)
|
||||
|
||||
# up
|
||||
for up_block in self.up_blocks:
|
||||
if image is not None and mask is not None:
|
||||
sample_ = im_x[str(tuple(sample.shape))]
|
||||
mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
|
||||
mask_ = nn.functional.interpolate(
|
||||
mask, size=sample.shape[-2:], mode="nearest"
|
||||
)
|
||||
sample = sample * mask_ + sample_ * (1 - mask_)
|
||||
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
|
||||
sample = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(up_block), sample, latent_embeds
|
||||
)
|
||||
if image is not None and mask is not None:
|
||||
sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
|
||||
else:
|
||||
@@ -599,7 +649,9 @@ class MaskConditionDecoder(nn.Module):
|
||||
for up_block in self.up_blocks:
|
||||
if image is not None and mask is not None:
|
||||
sample_ = im_x[str(tuple(sample.shape))]
|
||||
mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
|
||||
mask_ = nn.functional.interpolate(
|
||||
mask, size=sample.shape[-2:], mode="nearest"
|
||||
)
|
||||
sample = sample * mask_ + sample_ * (1 - mask_)
|
||||
sample = up_block(sample, latent_embeds)
|
||||
if image is not None and mask is not None:
|
||||
@@ -671,7 +723,9 @@ class VectorQuantizer(nn.Module):
|
||||
new = match.argmax(-1)
|
||||
unknown = match.sum(2) < 1
|
||||
if self.unknown_index == "random":
|
||||
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
|
||||
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(
|
||||
device=new.device
|
||||
)
|
||||
else:
|
||||
new[unknown] = self.unknown_index
|
||||
return new.reshape(ishape)
|
||||
@@ -686,13 +740,17 @@ class VectorQuantizer(nn.Module):
|
||||
back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
|
||||
return back.reshape(ishape)
|
||||
|
||||
def forward(self, z: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, Tuple]:
|
||||
def forward(
|
||||
self, z: torch.FloatTensor
|
||||
) -> Tuple[torch.FloatTensor, torch.FloatTensor, Tuple]:
|
||||
# reshape z -> (batch, height, width, channel) and flatten
|
||||
z = z.permute(0, 2, 3, 1).contiguous()
|
||||
z_flattened = z.view(-1, self.vq_embed_dim)
|
||||
|
||||
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
||||
min_encoding_indices = torch.argmin(torch.cdist(z_flattened, self.embedding.weight), dim=1)
|
||||
min_encoding_indices = torch.argmin(
|
||||
torch.cdist(z_flattened, self.embedding.weight), dim=1
|
||||
)
|
||||
|
||||
z_q = self.embedding(min_encoding_indices).view(z.shape)
|
||||
perplexity = None
|
||||
@@ -700,9 +758,13 @@ class VectorQuantizer(nn.Module):
|
||||
|
||||
# compute loss for embedding
|
||||
if not self.legacy:
|
||||
loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
|
||||
loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean(
|
||||
(z_q - z.detach()) ** 2
|
||||
)
|
||||
else:
|
||||
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
|
||||
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean(
|
||||
(z_q - z.detach()) ** 2
|
||||
)
|
||||
|
||||
# preserve gradients
|
||||
z_q: torch.FloatTensor = z + (z_q - z).detach()
|
||||
@@ -711,16 +773,22 @@ class VectorQuantizer(nn.Module):
|
||||
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
if self.remap is not None:
|
||||
min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis
|
||||
min_encoding_indices = min_encoding_indices.reshape(
|
||||
z.shape[0], -1
|
||||
) # add batch axis
|
||||
min_encoding_indices = self.remap_to_used(min_encoding_indices)
|
||||
min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
|
||||
|
||||
if self.sane_index_shape:
|
||||
min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])
|
||||
min_encoding_indices = min_encoding_indices.reshape(
|
||||
z_q.shape[0], z_q.shape[2], z_q.shape[3]
|
||||
)
|
||||
|
||||
return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
|
||||
|
||||
def get_codebook_entry(self, indices: torch.LongTensor, shape: Tuple[int, ...]) -> torch.FloatTensor:
|
||||
def get_codebook_entry(
|
||||
self, indices: torch.LongTensor, shape: Tuple[int, ...]
|
||||
) -> torch.FloatTensor:
|
||||
# shape specifying (batch, height, width, channel)
|
||||
if self.remap is not None:
|
||||
indices = indices.reshape(shape[0], -1) # add batch axis
|
||||
@@ -754,7 +822,10 @@ class DiagonalGaussianDistribution(object):
|
||||
def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
|
||||
# make sure sample is on the same device as the parameters and has same dtype
|
||||
sample = randn_tensor(
|
||||
self.mean.shape, generator=generator, device=self.parameters.device, dtype=self.parameters.dtype
|
||||
self.mean.shape,
|
||||
generator=generator,
|
||||
device=self.parameters.device,
|
||||
dtype=self.parameters.dtype,
|
||||
)
|
||||
x = self.mean + self.std * sample
|
||||
return x
|
||||
@@ -764,7 +835,10 @@ class DiagonalGaussianDistribution(object):
|
||||
return torch.Tensor([0.0])
|
||||
else:
|
||||
if other is None:
|
||||
return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
|
||||
return 0.5 * torch.sum(
|
||||
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
|
||||
dim=[1, 2, 3],
|
||||
)
|
||||
else:
|
||||
return 0.5 * torch.sum(
|
||||
torch.pow(self.mean - other.mean, 2) / other.var
|
||||
@@ -775,11 +849,16 @@ class DiagonalGaussianDistribution(object):
|
||||
dim=[1, 2, 3],
|
||||
)
|
||||
|
||||
def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor:
|
||||
def nll(
|
||||
self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]
|
||||
) -> torch.Tensor:
|
||||
if self.deterministic:
|
||||
return torch.Tensor([0.0])
|
||||
logtwopi = np.log(2.0 * np.pi)
|
||||
return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)
|
||||
return 0.5 * torch.sum(
|
||||
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
|
||||
dim=dims,
|
||||
)
|
||||
|
||||
def mode(self) -> torch.Tensor:
|
||||
return self.mean
|
||||
@@ -818,14 +897,27 @@ class EncoderTiny(nn.Module):
|
||||
num_channels = block_out_channels[i]
|
||||
|
||||
if i == 0:
|
||||
layers.append(nn.Conv2d(in_channels, num_channels, kernel_size=3, padding=1))
|
||||
layers.append(
|
||||
nn.Conv2d(in_channels, num_channels, kernel_size=3, padding=1)
|
||||
)
|
||||
else:
|
||||
layers.append(nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, stride=2, bias=False))
|
||||
layers.append(
|
||||
nn.Conv2d(
|
||||
num_channels,
|
||||
num_channels,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
stride=2,
|
||||
bias=False,
|
||||
)
|
||||
)
|
||||
|
||||
for _ in range(num_block):
|
||||
layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn))
|
||||
|
||||
layers.append(nn.Conv2d(block_out_channels[-1], out_channels, kernel_size=3, padding=1))
|
||||
layers.append(
|
||||
nn.Conv2d(block_out_channels[-1], out_channels, kernel_size=3, padding=1)
|
||||
)
|
||||
|
||||
self.layers = nn.Sequential(*layers)
|
||||
self.gradient_checkpointing = False
|
||||
@@ -841,9 +933,13 @@ class EncoderTiny(nn.Module):
|
||||
return custom_forward
|
||||
|
||||
if is_torch_version(">=", "1.11.0"):
|
||||
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False)
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.layers), x, use_reentrant=False
|
||||
)
|
||||
else:
|
||||
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.layers), x
|
||||
)
|
||||
|
||||
else:
|
||||
# scale image from [-1, 1] to [0, 1] to match TAESD convention
|
||||
@@ -899,7 +995,15 @@ class DecoderTiny(nn.Module):
|
||||
layers.append(nn.Upsample(scale_factor=upsampling_scaling_factor))
|
||||
|
||||
conv_out_channel = num_channels if not is_final_block else out_channels
|
||||
layers.append(nn.Conv2d(num_channels, conv_out_channel, kernel_size=3, padding=1, bias=is_final_block))
|
||||
layers.append(
|
||||
nn.Conv2d(
|
||||
num_channels,
|
||||
conv_out_channel,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
bias=is_final_block,
|
||||
)
|
||||
)
|
||||
|
||||
self.layers = nn.Sequential(*layers)
|
||||
self.gradient_checkpointing = False
|
||||
@@ -918,9 +1022,13 @@ class DecoderTiny(nn.Module):
|
||||
return custom_forward
|
||||
|
||||
if is_torch_version(">=", "1.11.0"):
|
||||
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False)
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.layers), x, use_reentrant=False
|
||||
)
|
||||
else:
|
||||
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.layers), x
|
||||
)
|
||||
|
||||
else:
|
||||
x = self.layers(x)
|
||||
|
||||
@@ -145,6 +145,7 @@ else:
|
||||
"StableDiffusionPix2PixZeroPipeline",
|
||||
"StableDiffusionSAGPipeline",
|
||||
"StableDiffusionUpscalePipeline",
|
||||
"StableDiffusionVideoPipeline",
|
||||
"StableUnCLIPImg2ImgPipeline",
|
||||
"StableUnCLIPPipeline",
|
||||
]
|
||||
@@ -372,6 +373,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
StableDiffusionPix2PixZeroPipeline,
|
||||
StableDiffusionSAGPipeline,
|
||||
StableDiffusionUpscalePipeline,
|
||||
StableDiffusionVideoPipeline,
|
||||
StableUnCLIPImg2ImgPipeline,
|
||||
StableUnCLIPPipeline,
|
||||
)
|
||||
|
||||
@@ -47,6 +47,7 @@ else:
|
||||
_import_structure["pipeline_stable_diffusion_paradigms"] = ["StableDiffusionParadigmsPipeline"]
|
||||
_import_structure["pipeline_stable_diffusion_sag"] = ["StableDiffusionSAGPipeline"]
|
||||
_import_structure["pipeline_stable_diffusion_upscale"] = ["StableDiffusionUpscalePipeline"]
|
||||
_import_structure["pipeline_stable_diffusion_video"] = ["StableDiffusionVideoPipeline"]
|
||||
_import_structure["pipeline_stable_unclip"] = ["StableUnCLIPPipeline"]
|
||||
_import_structure["pipeline_stable_unclip_img2img"] = ["StableUnCLIPImg2ImgPipeline"]
|
||||
_import_structure["safety_checker"] = ["StableDiffusionSafetyChecker"]
|
||||
@@ -151,6 +152,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .pipeline_stable_diffusion_paradigms import StableDiffusionParadigmsPipeline
|
||||
from .pipeline_stable_diffusion_sag import StableDiffusionSAGPipeline
|
||||
from .pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline
|
||||
from .pipeline_stable_diffusion_video import StableDiffusionVideoPipeline
|
||||
from .pipeline_stable_unclip import StableUnCLIPPipeline
|
||||
from .pipeline_stable_unclip_img2img import StableUnCLIPImg2ImgPipeline
|
||||
from .safety_checker import StableDiffusionSafetyChecker
|
||||
|
||||
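The `__init__` registrations above make the new pipeline importable from the package root. A minimal
sketch of what that enables (the checkpoint identifier below is a placeholder, not taken from this
diff):

import torch

from diffusers import StableDiffusionVideoPipeline

pipe = StableDiffusionVideoPipeline.from_pretrained(
    "path/to/stable-video-diffusion-checkpoint",  # placeholder path, substitute a real checkpoint
    torch_dtype=torch.float16,
)
pipe.to("cuda")

# image = PIL.Image.open("conditioning_frame.png")
# frames = pipe(image, num_frames=14, decoding_t=7).frames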
@@ -0,0 +1,588 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
|
||||
from ...models.attention_processor import (
|
||||
AttnProcessor2_0,
|
||||
LoRAAttnProcessor2_0,
|
||||
LoRAXFormersAttnProcessor,
|
||||
XFormersAttnProcessor,
|
||||
)
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import BaseOutput, logging
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def _append_dims(x, target_dims):
|
||||
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
|
||||
dims_to_append = target_dims - x.ndim
|
||||
if dims_to_append < 0:
|
||||
raise ValueError(
|
||||
f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
|
||||
)
|
||||
return x[(...,) + (None,) * dims_to_append]
|
||||
|
||||
|
||||
def tensor2vid(video: torch.Tensor, processor, output_type="np"):
|
||||
# Based on:
|
||||
# https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
|
||||
|
||||
batch_size, channels, num_frames, height, width = video.shape
|
||||
outputs = []
|
||||
for batch_idx in range(batch_size):
|
||||
batch_vid = video[batch_idx].permute(1, 0, 2, 3)
|
||||
batch_output = processor.postprocess(batch_vid, output_type)
|
||||
|
||||
outputs.append(batch_output)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
@dataclass
|
||||
class StableDiffusionVideoPipelineOutput(BaseOutput):
|
||||
r"""
|
||||
Output class for the Stable Diffusion image-to-video pipeline.
|
||||
|
||||
Args:
|
||||
frames (`List[PIL.Image.Image]` or `np.ndarray`):
List of denoised video frames for each input in the batch, returned either as PIL images or as a
NumPy array.
|
||||
"""
|
||||
|
||||
frames: Union[List[PIL.Image.Image], np.ndarray]
|
||||
|
||||
|
||||
class StableDiffusionVideoPipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline to generate video from an input image using Stable Video Diffusion.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKLTemporalDecoder`]):
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
|
||||
image_encoder ([`~transformers.CLIPVisionModelWithProjection`]):
|
||||
Frozen CLIP image-encoder ([laion/CLIP-ViT-H-14-laion2B-s32B-b79K](https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K)).
|
||||
unet ([`UNetSpatioTemporalConditionModel`]):
|
||||
A `UNetSpatioTemporalConditionModel` to denoise the encoded image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
feature_extractor ([`~transformers.CLIPImageProcessor`]):
|
||||
A `CLIPImageProcessor` to preprocess the conditioning image for the image encoder.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "image_encoder->unet->vae"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKLTemporalDecoder,
|
||||
image_encoder: CLIPVisionModelWithProjection,
|
||||
unet: UNetSpatioTemporalConditionModel,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
image_encoder=image_encoder,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
|
||||
def _encode_image(
|
||||
self, image, device, num_videos_per_prompt, do_classifier_free_guidance
|
||||
):
|
||||
dtype = next(self.image_encoder.parameters()).dtype
|
||||
|
||||
if not isinstance(image, torch.Tensor):
|
||||
image = self.feature_extractor(
|
||||
images=image, return_tensors="pt"
|
||||
).pixel_values
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
image_embeddings = self.image_encoder(image).image_embeds
|
||||
image_embeddings = image_embeddings.unsqueeze(1)
|
||||
|
||||
# duplicate image embeddings for each generation per prompt, using mps friendly method
|
||||
bs_embed, seq_len, _ = image_embeddings.shape
|
||||
image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1)
|
||||
image_embeddings = image_embeddings.view(
|
||||
bs_embed * num_videos_per_prompt, seq_len, -1
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
negative_image_embeddings = torch.zeros_like(image_embeddings)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
image_embeddings = torch.cat([negative_image_embeddings, image_embeddings])
|
||||
|
||||
return image_embeddings
|
||||
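# Editor's note (added): with classifier-free guidance enabled, the tensor returned above has
# 2 * batch_size * num_videos_per_prompt rows -- the zero (unconditional) image embeddings come
# first, followed by the conditional ones, so a single UNet forward pass covers both branches.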
|
||||
def _encode_vae_image(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
device,
|
||||
num_videos_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
):
|
||||
image = image.to(device=device)
|
||||
image_latents = self.vae.encode(image).latent_dist.mode()
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
negative_image_latents = torch.zeros_like(image_latents)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
image_latents = torch.cat([negative_image_latents, image_latents])
|
||||
|
||||
# duplicate image_latents for each generation per prompt, using mps friendly method
|
||||
image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1)
|
||||
|
||||
return image_latents
|
||||
|
||||
def _get_add_time_ids(
|
||||
self,
|
||||
fps_id,
|
||||
motion_bucket_id,
|
||||
cond_aug,
|
||||
dtype,
|
||||
batch_size,
|
||||
num_videos_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
):
|
||||
add_time_ids = [fps_id, motion_bucket_id, cond_aug]
|
||||
|
||||
passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(
|
||||
add_time_ids
|
||||
)
|
||||
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
|
||||
|
||||
if expected_add_embed_dim != passed_add_embed_dim:
|
||||
raise ValueError(
|
||||
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
|
||||
)
|
||||
|
||||
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
|
||||
add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
add_time_ids = torch.cat([add_time_ids, add_time_ids])
|
||||
|
||||
return add_time_ids
|
||||
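# Editor's note (added worked example): with the UNet defaults (addition_time_embed_dim=256,
# projection_class_embeddings_input_dim=768), the three ids [fps_id, motion_bucket_id, cond_aug]
# give passed_add_embed_dim = 3 * 256 = 768, which equals add_embedding.linear_1.in_features,
# so the consistency check above passes.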
|
||||
def decode_latents(self, latents, num_frames, decoding_t=14):
|
||||
# [batch, frames, channels, height, width] -> [batch*frames, channels, height, width]
|
||||
latents = latents.flatten(0, 1)
|
||||
|
||||
latents = 1 / self.vae.config.scaling_factor * latents
|
||||
|
||||
# decode decoding_t frames at a time to avoid OOM
|
||||
frames = []
|
||||
for i in range(0, latents.shape[0], decoding_t):
|
||||
num_frames_in = latents[i : i + decoding_t].shape[0]
|
||||
frame = self.vae.decode(latents[i : i + decoding_t], num_frames_in).sample
|
||||
frames.append(frame)
|
||||
frames = torch.cat(frames, dim=0)
|
||||
|
||||
# [batch*frames, channels, height, width] -> [batch, channels, frames, height, width]
|
||||
frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(
|
||||
0, 2, 1, 3, 4
|
||||
)
|
||||
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
||||
frames = frames.float()
|
||||
return frames
|
||||
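# Editor's note (added): `decoding_t` bounds peak memory during VAE decoding. For example, with
# num_frames=14 and decoding_t=7 the loop above decodes two chunks of 7 frames instead of all
# 14 latent frames at once.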
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(
|
||||
inspect.signature(self.scheduler.step).parameters.keys()
|
||||
)
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(
|
||||
inspect.signature(self.scheduler.step).parameters.keys()
|
||||
)
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
def check_inputs(self, image, height, width, callback_steps):
|
||||
if (
|
||||
not isinstance(image, torch.Tensor)
|
||||
and not isinstance(image, PIL.Image.Image)
|
||||
and not isinstance(image, list)
|
||||
):
|
||||
raise ValueError(
|
||||
"`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
|
||||
f" {type(image)}"
|
||||
)
|
||||
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(
|
||||
f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
|
||||
)
|
||||
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None
|
||||
and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
num_frames,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
):
|
||||
shape = (
|
||||
batch_size,
|
||||
num_frames,
|
||||
num_channels_latents // 2,
|
||||
height // self.vae_scale_factor,
|
||||
width // self.vae_scale_factor,
|
||||
)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(
|
||||
shape, generator=generator, device=device, dtype=dtype
|
||||
)
|
||||
else:
|
||||
latents = latents.to(device)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu
|
||||
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
|
||||
r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
|
||||
|
||||
The suffixes after the scaling factors represent the stages where they are being applied.
|
||||
|
||||
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
|
||||
that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
|
||||
|
||||
Args:
|
||||
s1 (`float`):
|
||||
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
|
||||
mitigate "oversmoothing effect" in the enhanced denoising process.
|
||||
s2 (`float`):
|
||||
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
|
||||
mitigate "oversmoothing effect" in the enhanced denoising process.
|
||||
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
|
||||
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
|
||||
"""
|
||||
if not hasattr(self, "unet"):
|
||||
raise ValueError("The pipeline must have `unet` for using FreeU.")
|
||||
self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu
|
||||
def disable_freeu(self):
|
||||
"""Disables the FreeU mechanism if enabled."""
|
||||
self.unet.disable_freeu()
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
|
||||
height: int = 576,
|
||||
width: int = 1024,
|
||||
num_frames: int = 14,
|
||||
num_inference_steps: int = 50,
|
||||
min_guidance_scale: float = 1.0,
|
||||
max_guidance_scale: float = 2.5,
|
||||
fps_id: int = 6,
|
||||
motion_bucket_id: int = 127,
|
||||
cond_aug: float = 0.02,
|
||||
decoding_t: int = 14,
|
||||
num_videos_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation.
|
||||
|
||||
Args:
|
||||
image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
|
||||
Image or images to guide image generation. If you provide a tensor, it needs to be compatible with
|
||||
[`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json).
|
||||
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
||||
The width in pixels of the generated image.
|
||||
num_frames (`int`, *optional*, defaults to 14):
|
||||
The number of video frames to generate.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
min_guidance_scale (`float`, *optional*, defaults to 1.0):
|
||||
The minimum guidance scale. Used for classifier-free guidance on the first frame; guidance increases linearly from this value across the frames.
|
||||
max_guidance_scale (`float`, *optional*, defaults to 2.5):
|
||||
The maximum guidance scale. Used for classifier-free guidance on the last frame.
|
||||
fps_id (`int`, *optional*, defaults to 6):
|
||||
The frame rate ID. Used as conditioning for the generation.
|
||||
motion_bucket_id (`int`, *optional*, defaults to 127):
|
||||
The motion bucket ID. Used as conditioning for the generation. The higher the number, the more motion in the generated video.
|
||||
cond_aug (`float`, *optional*, defaults to 0.02):
|
||||
The amount of noise added to the init image. The higher the value, the less the video resembles the init image; increase it for more motion.
|
||||
decoding_t (`int`, *optional*, defaults to 14):
|
||||
The number of frames to decode at a time. This is used to avoid OOM errors.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of videos to generate per input image.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
|
||||
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
||||
generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor is generated by sampling using the supplied random `generator`.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
||||
plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that is called every `callback_steps` steps during inference. The function is called with the
|
||||
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function is called. If not specified, the callback is called at
|
||||
every step.
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionVideoPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionVideoPipelineOutput`] is returned,
|
||||
otherwise a `tuple` is returned where the first element is a list of lists with the generated frames.
|
||||
|
||||
Examples:
|
||||
|
||||
```py
import torch

from diffusers import StableDiffusionVideoPipeline
from diffusers.utils import load_image, export_to_video

pipe = StableDiffusionVideoPipeline.from_pretrained("stabilityai/stable-video-diffusion-img2vid", torch_dtype=torch.float16, variant="fp16")
pipe.to("cuda")

image = load_image("https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200")
image = image.resize((1024, 576))

frames = pipe(image, num_frames=14, decoding_t=8).frames[0]
export_to_video(frames, "generated.mp4", fps=7)
```
|
||||
"""
|
||||
# 0. Default height and width to unet
|
||||
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
||||
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(image, height, width, callback_steps)
|
||||
|
||||
# 2. Define call parameters
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
batch_size = 1
|
||||
elif isinstance(image, list):
|
||||
batch_size = len(image)
|
||||
else:
|
||||
batch_size = image.shape[0]
|
||||
device = self._execution_device
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = max_guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input image
|
||||
image_embeddings = self._encode_image(
|
||||
image, device, num_videos_per_prompt, do_classifier_free_guidance
|
||||
)
|
||||
|
||||
# 4. Encode input image using VAE
|
||||
image = self.image_processor.preprocess(image, height=height, width=width)
|
||||
image = image + cond_aug * torch.randn_like(image)
|
||||
|
||||
needs_upcasting = (
|
||||
self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
||||
)
|
||||
if needs_upcasting:
|
||||
self.vae.to(dtype=torch.float32)
|
||||
|
||||
image_latents = self._encode_vae_image(
|
||||
image, device, num_videos_per_prompt, do_classifier_free_guidance
|
||||
)
|
||||
image_latents = image_latents.to(image_embeddings.dtype)
|
||||
|
||||
# cast back to fp16 if needed
|
||||
if needs_upcasting:
|
||||
self.vae.to(dtype=torch.float16)
|
||||
|
||||
# Repeat the image latents for each frame so we can concatenate them with the noise
|
||||
# image_latents [batch, channels, height, width] -> [batch, num_frames, channels, height, width]
|
||||
image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)
|
||||
|
||||
# 5. Get Added Time IDs
|
||||
added_time_ids = self._get_add_time_ids(
|
||||
fps_id,
|
||||
motion_bucket_id,
|
||||
cond_aug,
|
||||
image_embeddings.dtype,
|
||||
batch_size,
|
||||
num_videos_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
)
|
||||
added_time_ids = added_time_ids.to(device)
|
||||
|
||||
# 6. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 7. Prepare latent variables
|
||||
num_channels_latents = self.unet.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_videos_per_prompt,
|
||||
num_frames,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
image_embeddings.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 9. Prepare guidance scale
|
||||
guidance_scale = torch.linspace(
|
||||
min_guidance_scale, max_guidance_scale, num_frames
|
||||
).unsqueeze(0)
|
||||
guidance_scale = guidance_scale.to(device, latents.dtype)
|
||||
guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1)
|
||||
guidance_scale = _append_dims(guidance_scale, latents.ndim)
|
||||
|
||||
# 10. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = (
|
||||
torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
)
|
||||
latent_model_input = self.scheduler.scale_model_input(
|
||||
latent_model_input, t
|
||||
)
|
||||
|
||||
# Concatenate image_latents over the channels dimension
|
||||
latent_model_input = torch.cat(
|
||||
[latent_model_input, image_latents], dim=2
|
||||
)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=image_embeddings,
|
||||
added_time_ids=added_time_ids,
|
||||
).sample
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (
|
||||
noise_pred_cond - noise_pred_uncond
|
||||
)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(
|
||||
noise_pred, t, latents, **extra_step_kwargs
|
||||
).prev_sample
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or (
|
||||
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
|
||||
):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, t, latents)
|
||||
|
||||
frames = self.decode_latents(latents, num_frames, decoding_t)
|
||||
|
||||
# cast back to fp16 if needed
|
||||
if needs_upcasting:
|
||||
self.vae.to(dtype=torch.float16)
|
||||
|
||||
if not output_type == "latent":
|
||||
frames = tensor2vid(frames, self.image_processor, output_type=output_type)
|
||||
else:
|
||||
frames = latents
|
||||
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return frames
|
||||
|
||||
return StableDiffusionVideoPipelineOutput(frames=frames)
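The guidance weighting in step 9 above is frame-wise rather than a single scalar: the scale ramps linearly from `min_guidance_scale` on the first frame to `max_guidance_scale` on the last one and is then broadcast to the latent shape. Below is a minimal, standalone sketch of that behavior; shapes and values are illustrative only, and the broadcasting is done directly instead of through the pipeline's `_append_dims` helper.

```py
import torch

num_frames = 14
min_guidance_scale, max_guidance_scale = 1.0, 2.5

# one guidance weight per frame, increasing linearly from the first frame to the last
guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames)

# broadcast to the latent layout [batch, frames, channels, height, width]
latents = torch.randn(1, num_frames, 4, 72, 128)
guidance_scale = guidance_scale.view(1, num_frames, 1, 1, 1)

noise_pred_uncond = torch.randn_like(latents)
noise_pred_cond = torch.randn_like(latents)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
print(noise_pred.shape)  # torch.Size([1, 14, 4, 72, 128])
```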
@@ -323,8 +323,20 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
|
||||
"""Constructs the noise schedule of Karras et al. (2022)."""
|
||||
|
||||
sigma_min: float = in_sigmas[-1].item()
|
||||
sigma_max: float = in_sigmas[0].item()
|
||||
# Hack to make sure that other schedulers which copy this function don't break
|
||||
# TODO: Add this logic to the other schedulers
|
||||
if hasattr(self.config, "sigma_min"):
|
||||
sigma_min = self.config.sigma_min
|
||||
else:
|
||||
sigma_min = None
|
||||
|
||||
if hasattr(self.config, "sigma_max"):
|
||||
sigma_max = self.config.sigma_max
|
||||
else:
|
||||
sigma_max = None
|
||||
|
||||
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
|
||||
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
|
||||
|
||||
rho = 7.0 # 7.0 is the value used in the paper
|
||||
ramp = np.linspace(0, 1, num_inference_steps)
|
||||
|
||||
@@ -358,8 +358,20 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
|
||||
"""Constructs the noise schedule of Karras et al. (2022)."""
|
||||
|
||||
sigma_min: float = in_sigmas[-1].item()
|
||||
sigma_max: float = in_sigmas[0].item()
|
||||
# Hack to make sure that other schedulers which copy this function don't break
|
||||
# TODO: Add this logic to the other schedulers
|
||||
if hasattr(self.config, "sigma_min"):
|
||||
sigma_min = self.config.sigma_min
|
||||
else:
|
||||
sigma_min = None
|
||||
|
||||
if hasattr(self.config, "sigma_max"):
|
||||
sigma_max = self.config.sigma_max
|
||||
else:
|
||||
sigma_max = None
|
||||
|
||||
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
|
||||
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
|
||||
|
||||
rho = 7.0 # 7.0 is the value used in the paper
|
||||
ramp = np.linspace(0, 1, num_inference_steps)
|
||||
|
||||
@@ -358,8 +358,20 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
|
||||
"""Constructs the noise schedule of Karras et al. (2022)."""
|
||||
|
||||
sigma_min: float = in_sigmas[-1].item()
|
||||
sigma_max: float = in_sigmas[0].item()
|
||||
# Hack to make sure that other schedulers which copy this function don't break
|
||||
# TODO: Add this logic to the other schedulers
|
||||
if hasattr(self.config, "sigma_min"):
|
||||
sigma_min = self.config.sigma_min
|
||||
else:
|
||||
sigma_min = None
|
||||
|
||||
if hasattr(self.config, "sigma_max"):
|
||||
sigma_max = self.config.sigma_max
|
||||
else:
|
||||
sigma_max = None
|
||||
|
||||
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
|
||||
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
|
||||
|
||||
rho = 7.0 # 7.0 is the value used in the paper
|
||||
ramp = np.linspace(0, 1, num_inference_steps)
|
||||
|
||||
@@ -357,8 +357,20 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
|
||||
"""Constructs the noise schedule of Karras et al. (2022)."""
|
||||
|
||||
sigma_min: float = in_sigmas[-1].item()
|
||||
sigma_max: float = in_sigmas[0].item()
|
||||
# Hack to make sure that other schedulers which copy this function don't break
|
||||
# TODO: Add this logic to the other schedulers
|
||||
if hasattr(self.config, "sigma_min"):
|
||||
sigma_min = self.config.sigma_min
|
||||
else:
|
||||
sigma_min = None
|
||||
|
||||
if hasattr(self.config, "sigma_max"):
|
||||
sigma_max = self.config.sigma_max
|
||||
else:
|
||||
sigma_max = None
|
||||
|
||||
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
|
||||
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
|
||||
|
||||
rho = 7.0 # 7.0 is the value used in the paper
|
||||
ramp = np.linspace(0, 1, num_inference_steps)
|
||||
|
||||
@@ -144,7 +144,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
prediction_type: str = "epsilon",
|
||||
interpolation_type: str = "linear",
|
||||
use_karras_sigmas: Optional[bool] = False,
|
||||
sigma_min: Optional[float] = None,
|
||||
sigma_max: Optional[float] = None,
|
||||
timestep_spacing: str = "linspace",
|
||||
timestep_type: str = "discrete", # can be "discrete" or "continuous"
|
||||
steps_offset: int = 0,
|
||||
):
|
||||
if trained_betas is not None:
|
||||
@@ -268,6 +271,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
|
||||
|
||||
# when timestep_type is continuous, we need to convert the timesteps to continuous values using c_noise
|
||||
if self.config.timestep_type == "continuous":
|
||||
timesteps = np.array([self.get_scalings(sigma)[-1] for sigma in sigmas])
|
||||
|
||||
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
|
||||
self.sigmas = torch.from_numpy(sigmas).to(device=device)
|
||||
|
||||
@@ -301,8 +308,20 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
|
||||
"""Constructs the noise schedule of Karras et al. (2022)."""
|
||||
|
||||
sigma_min: float = in_sigmas[-1].item()
|
||||
sigma_max: float = in_sigmas[0].item()
|
||||
# Hack to make sure that other schedulers which copy this function don't break
|
||||
# TODO: Add this logic to the other schedulers
|
||||
if hasattr(self.config, "sigma_min"):
|
||||
sigma_min = self.config.sigma_min
|
||||
else:
|
||||
sigma_min = None
|
||||
|
||||
if hasattr(self.config, "sigma_max"):
|
||||
sigma_max = self.config.sigma_max
|
||||
else:
|
||||
sigma_max = None
|
||||
|
||||
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
|
||||
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
|
||||
|
||||
rho = 7.0 # 7.0 is the value used in the paper
|
||||
ramp = np.linspace(0, 1, num_inference_steps)
|
||||
@@ -311,6 +330,33 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
||||
return sigmas
|
||||
|
||||
def get_scalings(self, sigma=None):
|
||||
"""
|
||||
Get the scalings for the current timestep.
|
||||
Returns:
`Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]`:
The (c_skip, c_out, c_in, c_noise) scaling factors for the current timestep.
|
||||
"""
|
||||
if sigma is None:
|
||||
sigma = self.sigmas[self.step_index]
|
||||
|
||||
if self.config.prediction_type == "epsilon":
|
||||
c_skip = torch.ones_like(sigma, device=sigma.device)
|
||||
c_out = -sigma
|
||||
c_in = 1 / (sigma**2 + 1.0) ** 0.5
|
||||
c_noise = sigma.clone()
|
||||
elif self.config.prediction_type == "v_prediction":
|
||||
c_skip = 1.0 / (sigma**2 + 1.0)
|
||||
c_out = -sigma / (sigma**2 + 1.0) ** 0.5
|
||||
c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
|
||||
c_noise = 0.25 * math.log(sigma)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
|
||||
)
|
||||
|
||||
return c_skip, c_out, c_in, c_noise
|
||||
|
||||
def _init_step_index(self, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
@@ -391,6 +437,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self._init_step_index(timestep)
|
||||
|
||||
sigma = self.sigmas[self.step_index]
|
||||
c_skip, c_out, _, _ = self.get_scalings()
|
||||
|
||||
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
|
||||
|
||||
@@ -412,8 +459,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
elif self.config.prediction_type == "epsilon":
|
||||
pred_original_sample = sample - sigma_hat * model_output
|
||||
elif self.config.prediction_type == "v_prediction":
|
||||
# * c_out + input * c_skip
|
||||
pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
|
||||
pred_original_sample = model_output * c_out + sample * c_skip
|
||||
else:
|
||||
raise ValueError(
|
||||
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
|
||||
|
||||
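For `v_prediction`, the rewritten branch above is algebraically equivalent to the inline expression it replaces, since `get_scalings` returns `c_skip = 1 / (sigma**2 + 1)` and `c_out = -sigma / (sigma**2 + 1) ** 0.5`. A quick standalone numeric check with illustrative values:

```py
import torch

sigma = torch.tensor(2.0)
model_output = torch.randn(4)
sample = torch.randn(4)

c_skip = 1.0 / (sigma**2 + 1.0)
c_out = -sigma / (sigma**2 + 1.0) ** 0.5

old = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
new = model_output * c_out + sample * c_skip
assert torch.allclose(old, new)
```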
@@ -303,8 +303,20 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
|
||||
"""Constructs the noise schedule of Karras et al. (2022)."""
|
||||
|
||||
sigma_min: float = in_sigmas[-1].item()
|
||||
sigma_max: float = in_sigmas[0].item()
|
||||
# Hack to make sure that other schedulers which copy this function don't break
|
||||
# TODO: Add this logic to the other schedulers
|
||||
if hasattr(self.config, "sigma_min"):
|
||||
sigma_min = self.config.sigma_min
|
||||
else:
|
||||
sigma_min = None
|
||||
|
||||
if hasattr(self.config, "sigma_max"):
|
||||
sigma_max = self.config.sigma_max
|
||||
else:
|
||||
sigma_max = None
|
||||
|
||||
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
|
||||
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
|
||||
|
||||
rho = 7.0 # 7.0 is the value used in the paper
|
||||
ramp = np.linspace(0, 1, num_inference_steps)
|
||||
|
||||
@@ -324,8 +324,20 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
|
||||
"""Constructs the noise schedule of Karras et al. (2022)."""
|
||||
|
||||
sigma_min: float = in_sigmas[-1].item()
|
||||
sigma_max: float = in_sigmas[0].item()
|
||||
# Hack to make sure that other schedulers which copy this function don't break
|
||||
# TODO: Add this logic to the other schedulers
|
||||
if hasattr(self.config, "sigma_min"):
|
||||
sigma_min = self.config.sigma_min
|
||||
else:
|
||||
sigma_min = None
|
||||
|
||||
if hasattr(self.config, "sigma_max"):
|
||||
sigma_max = self.config.sigma_max
|
||||
else:
|
||||
sigma_max = None
|
||||
|
||||
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
|
||||
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
|
||||
|
||||
rho = 7.0 # 7.0 is the value used in the paper
|
||||
ramp = np.linspace(0, 1, num_inference_steps)
|
||||
|
||||
@@ -335,8 +335,20 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
|
||||
"""Constructs the noise schedule of Karras et al. (2022)."""
|
||||
|
||||
sigma_min: float = in_sigmas[-1].item()
|
||||
sigma_max: float = in_sigmas[0].item()
|
||||
# Hack to make sure that other schedulers which copy this function don't break
|
||||
# TODO: Add this logic to the other schedulers
|
||||
if hasattr(self.config, "sigma_min"):
|
||||
sigma_min = self.config.sigma_min
|
||||
else:
|
||||
sigma_min = None
|
||||
|
||||
if hasattr(self.config, "sigma_max"):
|
||||
sigma_max = self.config.sigma_max
|
||||
else:
|
||||
sigma_max = None
|
||||
|
||||
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
|
||||
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
|
||||
|
||||
rho = 7.0 # 7.0 is the value used in the paper
|
||||
ramp = np.linspace(0, 1, num_inference_steps)
|
||||
|
||||
@@ -337,8 +337,20 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
|
||||
"""Constructs the noise schedule of Karras et al. (2022)."""
|
||||
|
||||
sigma_min: float = in_sigmas[-1].item()
|
||||
sigma_max: float = in_sigmas[0].item()
|
||||
# Hack to make sure that other schedulers which copy this function don't break
|
||||
# TODO: Add this logic to the other schedulers
|
||||
if hasattr(self.config, "sigma_min"):
|
||||
sigma_min = self.config.sigma_min
|
||||
else:
|
||||
sigma_min = None
|
||||
|
||||
if hasattr(self.config, "sigma_max"):
|
||||
sigma_max = self.config.sigma_max
|
||||
else:
|
||||
sigma_max = None
|
||||
|
||||
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
|
||||
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
|
||||
|
||||
rho = 7.0 # 7.0 is the value used in the paper
|
||||
ramp = np.linspace(0, 1, num_inference_steps)
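All of the `_convert_to_karras` hunks above make the same change: if the scheduler config defines `sigma_min` / `sigma_max`, those values take precedence over the ones inferred from `in_sigmas`. The following is a standalone sketch of the Karras et al. (2022) schedule the function builds; the values are illustrative and not tied to any particular scheduler.

```py
import numpy as np

def karras_sigmas(sigma_min: float, sigma_max: float, num_inference_steps: int, rho: float = 7.0) -> np.ndarray:
    # interpolate in sigma**(1/rho) space, then map back; rho=7.0 is the value used in the paper
    ramp = np.linspace(0, 1, num_inference_steps)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

print(karras_sigmas(0.002, 80.0, 5))  # descends from sigma_max (80.0) to sigma_min (0.002)
```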
@@ -32,6 +32,21 @@ class AutoencoderKL(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoencoderKLTemporalDecoder(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoencoderTiny(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -272,6 +287,21 @@ class UNetMotionModel(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class UNetSpatioTemporalConditionModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class VQModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -1022,6 +1022,21 @@ class StableDiffusionUpscalePipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class StableDiffusionVideoPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class StableDiffusionXLAdapterPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import random
|
||||
import struct
|
||||
import tempfile
|
||||
from contextlib import contextmanager
|
||||
from typing import List
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
@@ -115,7 +115,9 @@ def export_to_obj(mesh, output_obj_path: str = None):
|
||||
f.writelines("\n".join(combined_data))
|
||||
|
||||
|
||||
def export_to_video(video_frames: List[np.ndarray], output_video_path: str = None) -> str:
|
||||
def export_to_video(
|
||||
video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 8
|
||||
) -> str:
|
||||
if is_opencv_available():
|
||||
import cv2
|
||||
else:
|
||||
@@ -123,9 +125,12 @@ def export_to_video(video_frames: List[np.ndarray], output_video_path: str = Non
|
||||
if output_video_path is None:
|
||||
output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name
|
||||
|
||||
if isinstance(video_frames[0], PIL.Image.Image):
|
||||
video_frames = [np.array(frame) for frame in video_frames]
|
||||
|
||||
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
|
||||
h, w, c = video_frames[0].shape
|
||||
video_writer = cv2.VideoWriter(output_video_path, fourcc, fps=8, frameSize=(w, h))
|
||||
video_writer = cv2.VideoWriter(output_video_path, fourcc, fps=fps, frameSize=(w, h))
|
||||
for i in range(len(video_frames)):
|
||||
img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR)
|
||||
video_writer.write(img)
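With the change above, `export_to_video` accepts PIL frames directly and exposes the frame rate. A small usage sketch, assuming `opencv-python` is installed and the helper has the updated signature from this diff:

```py
import numpy as np
import PIL.Image

from diffusers.utils import export_to_video

# 14 dummy 1024x576 frames; in practice these come from the video pipeline
frames = [PIL.Image.fromarray(np.zeros((576, 1024, 3), dtype=np.uint8)) for _ in range(14)]
path = export_to_video(frames, "generated.mp4", fps=7)
print(path)
```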
354 tests/models/test_models_unet_spatiotemporal.py (new file)
@@ -0,0 +1,354 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import UNetSpatioTemporalConditionModel
|
||||
from diffusers.utils import logging
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
torch_all_close,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from .test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class UNetSpatioTemporalConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = UNetSpatioTemporalConditionModel
|
||||
main_input_name = "sample"
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 4
|
||||
num_channels = 4
|
||||
sizes = (32, 32)
|
||||
|
||||
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
|
||||
time_step = torch.tensor([10]).to(torch_device)
|
||||
encoder_hidden_states = floats_tensor((batch_size, 4, 32)).to(torch_device)
|
||||
|
||||
return {
|
||||
"sample": noise,
|
||||
"timestep": time_step,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (4, 32, 32)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (4, 32, 32)
|
||||
|
||||
@property
|
||||
def fps_id(self):
|
||||
return 6
|
||||
|
||||
@property
|
||||
def motion_bucket_id(self):
|
||||
return 127
|
||||
|
||||
@property
|
||||
def cond_aug(self):
|
||||
return 0.02
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"block_out_channels": (32, 64),
|
||||
"down_block_types": (
|
||||
"CrossAttnDownBlockSpatioTemporal",
|
||||
"DownBlockSpatioTemporal",
|
||||
),
|
||||
"up_block_types": (
|
||||
"UpBlockSpatioTemporal",
|
||||
"CrossAttnUpBlockSpatioTemporal",
|
||||
),
|
||||
"cross_attention_dim": 32,
|
||||
"attention_head_dim": 8,
|
||||
"out_channels": 4,
|
||||
"in_channels": 4,
|
||||
"layers_per_block": 2,
|
||||
"sample_size": 32,
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_xformers_available(),
|
||||
reason="XFormers attention is only available with CUDA and `xformers` installed",
|
||||
)
|
||||
def test_xformers_enable_works(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
|
||||
assert (
|
||||
model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
|
||||
== "XFormersAttnProcessor"
|
||||
), "xformers is not enabled"
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
|
||||
def test_gradient_checkpointing(self):
|
||||
# enable deterministic behavior for gradient checkpointing
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
assert not model.is_gradient_checkpointing and model.training
|
||||
|
||||
out = model(**inputs_dict).sample
|
||||
# run the backwards pass on the model. For backwards pass, for simplicity purpose,
|
||||
# we won't calculate the loss and rather backprop on out.sum()
|
||||
model.zero_grad()
|
||||
|
||||
labels = torch.randn_like(out)
|
||||
loss = (out - labels).mean()
|
||||
loss.backward()
|
||||
|
||||
# re-instantiate the model now enabling gradient checkpointing
|
||||
model_2 = self.model_class(**init_dict)
|
||||
# clone model
|
||||
model_2.load_state_dict(model.state_dict())
|
||||
model_2.to(torch_device)
|
||||
model_2.enable_gradient_checkpointing()
|
||||
|
||||
assert model_2.is_gradient_checkpointing and model_2.training
|
||||
|
||||
out_2 = model_2(**inputs_dict).sample
|
||||
# run the backwards pass on the model. For backwards pass, for simplicity purpose,
|
||||
# we won't calculate the loss and rather backprop on out.sum()
|
||||
model_2.zero_grad()
|
||||
loss_2 = (out_2 - labels).mean()
|
||||
loss_2.backward()
|
||||
|
||||
# compare the output and parameters gradients
|
||||
self.assertTrue((loss - loss_2).abs() < 1e-5)
|
||||
named_params = dict(model.named_parameters())
|
||||
named_params_2 = dict(model_2.named_parameters())
|
||||
for name, param in named_params.items():
|
||||
self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
|
||||
|
||||
def test_model_with_attention_head_dim_tuple(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)
|
||||
|
||||
if isinstance(output, dict):
|
||||
output = output.sample
|
||||
|
||||
self.assertIsNotNone(output)
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
|
||||
def test_model_with_use_linear_projection(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["use_linear_projection"] = True
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)
|
||||
|
||||
if isinstance(output, dict):
|
||||
output = output.sample
|
||||
|
||||
self.assertIsNotNone(output)
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
|
||||
def test_model_with_cross_attention_dim_tuple(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["cross_attention_dim"] = (32, 32)
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)
|
||||
|
||||
if isinstance(output, dict):
|
||||
output = output.sample
|
||||
|
||||
self.assertIsNotNone(output)
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
|
||||
def test_model_with_simple_projection(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
batch_size, _, _, sample_size = inputs_dict["sample"].shape
|
||||
|
||||
init_dict["class_embed_type"] = "simple_projection"
|
||||
init_dict["projection_class_embeddings_input_dim"] = sample_size
|
||||
|
||||
inputs_dict["class_labels"] = floats_tensor((batch_size, sample_size)).to(torch_device)
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)
|
||||
|
||||
if isinstance(output, dict):
|
||||
output = output.sample
|
||||
|
||||
self.assertIsNotNone(output)
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
|
||||
def test_model_with_class_embeddings_concat(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
batch_size, _, _, sample_size = inputs_dict["sample"].shape
|
||||
|
||||
init_dict["class_embed_type"] = "simple_projection"
|
||||
init_dict["projection_class_embeddings_input_dim"] = sample_size
|
||||
init_dict["class_embeddings_concat"] = True
|
||||
|
||||
inputs_dict["class_labels"] = floats_tensor((batch_size, sample_size)).to(torch_device)
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)
|
||||
|
||||
if isinstance(output, dict):
|
||||
output = output.sample
|
||||
|
||||
self.assertIsNotNone(output)
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
|
||||
def test_model_attention_slicing(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
model.set_attention_slice("auto")
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)
|
||||
assert output is not None
|
||||
|
||||
model.set_attention_slice("max")
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)
|
||||
assert output is not None
|
||||
|
||||
model.set_attention_slice(2)
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)
|
||||
assert output is not None
|
||||
|
||||
def test_model_sliceable_head_dim(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
def check_sliceable_dim_attr(module: torch.nn.Module):
|
||||
if hasattr(module, "set_attention_slice"):
|
||||
assert isinstance(module.sliceable_head_dim, int)
|
||||
|
||||
for child in module.children():
|
||||
check_sliceable_dim_attr(child)
|
||||
|
||||
# retrieve number of attention layers
|
||||
for module in model.children():
|
||||
check_sliceable_dim_attr(module)
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
|
||||
model_class_copy = copy.copy(self.model_class)
|
||||
|
||||
modules_with_gc_enabled = {}
|
||||
|
||||
# now monkey patch the following function:
|
||||
# def _set_gradient_checkpointing(self, module, value=False):
|
||||
# if hasattr(module, "gradient_checkpointing"):
|
||||
# module.gradient_checkpointing = value
|
||||
|
||||
def _set_gradient_checkpointing_new(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
modules_with_gc_enabled[module.__class__.__name__] = True
|
||||
|
||||
model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new
|
||||
|
||||
model = model_class_copy(**init_dict)
|
||||
model.enable_gradient_checkpointing()
|
||||
|
||||
EXPECTED_SET = {
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"UNetMidBlock2DCrossAttn",
|
||||
"UpBlock2D",
|
||||
"Transformer2DModel",
|
||||
"DownBlock2D",
|
||||
}
|
||||
|
||||
assert set(modules_with_gc_enabled.keys()) == EXPECTED_SET
|
||||
assert all(modules_with_gc_enabled.values()), "All modules should be enabled"
|
||||
|
||||
def test_pickle(self):
|
||||
# enable deterministic behavior for gradient checkpointing
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
sample = model(**inputs_dict).sample
|
||||
|
||||
sample_copy = copy.copy(sample)
|
||||
|
||||
assert (sample - sample_copy).abs().max() < 1e-4
|
||||
@@ -23,6 +23,7 @@ from parameterized import parameterized
|
||||
from diffusers import (
|
||||
AsymmetricAutoencoderKL,
|
||||
AutoencoderKL,
|
||||
AutoencoderKLTemporalDecoder,
|
||||
AutoencoderTiny,
|
||||
ConsistencyDecoderVAE,
|
||||
StableDiffusionPipeline,
|
||||
@@ -195,10 +196,16 @@ class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
named_params = dict(model.named_parameters())
|
||||
named_params_2 = dict(model_2.named_parameters())
|
||||
for name, param in named_params.items():
|
||||
self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
|
||||
self.assertTrue(
|
||||
torch_all_close(
|
||||
param.grad.data, named_params_2[name].grad.data, atol=5e-5
|
||||
)
|
||||
)
|
||||
|
||||
def test_from_pretrained_hub(self):
|
||||
model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True)
|
||||
model, loading_info = AutoencoderKL.from_pretrained(
|
||||
"fusing/autoencoder-kl-dummy", output_loading_info=True
|
||||
)
|
||||
self.assertIsNotNone(model)
|
||||
self.assertEqual(len(loading_info["missing_keys"]), 0)
|
||||
|
||||
@@ -248,17 +255,39 @@ class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
)
|
||||
elif torch_device == "cpu":
|
||||
expected_output_slice = torch.tensor(
|
||||
[-0.1352, 0.0878, 0.0419, -0.0818, -0.1069, 0.0688, -0.1458, -0.4446, -0.0026]
|
||||
[
|
||||
-0.1352,
|
||||
0.0878,
|
||||
0.0419,
|
||||
-0.0818,
|
||||
-0.1069,
|
||||
0.0688,
|
||||
-0.1458,
|
||||
-0.4446,
|
||||
-0.0026,
|
||||
]
|
||||
)
|
||||
else:
|
||||
expected_output_slice = torch.tensor(
|
||||
[-0.2421, 0.4642, 0.2507, -0.0438, 0.0682, 0.3160, -0.2018, -0.0727, 0.2485]
|
||||
[
|
||||
-0.2421,
|
||||
0.4642,
|
||||
0.2507,
|
||||
-0.0438,
|
||||
0.0682,
|
||||
0.3160,
|
||||
-0.2018,
|
||||
-0.0727,
|
||||
0.2485,
|
||||
]
|
||||
)
|
||||
|
||||
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
|
||||
|
||||
|
||||
class AsymmetricAutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
class AsymmetricAutoencoderKLTests(
|
||||
ModelTesterMixin, UNetTesterMixin, unittest.TestCase
|
||||
):
|
||||
model_class = AsymmetricAutoencoderKL
|
||||
main_input_name = "sample"
|
||||
base_precision = 1e-2
|
||||
@@ -336,7 +365,9 @@ class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase):
|
||||
generator = torch.Generator("cpu")
|
||||
if seed is not None:
|
||||
generator.manual_seed(0)
|
||||
image = randn_tensor((4, 3, 32, 32), generator=generator, device=torch.device(torch_device))
|
||||
image = randn_tensor(
|
||||
(4, 3, 32, 32), generator=generator, device=torch.device(torch_device)
|
||||
)
|
||||
|
||||
return {"sample": image, "generator": generator}
|
||||
|
||||
@@ -364,6 +395,98 @@ class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase):
|
||||
...
|
||||
|
||||
|
||||
class AutoencoderKLTemporalDecoderFastTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = AutoencoderKLTemporalDecoder
|
||||
main_input_name = "sample"
|
||||
base_precision = 1e-2
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 3
|
||||
num_channels = 3
|
||||
sizes = (32, 32)
|
||||
|
||||
image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
|
||||
num_frames = 3
|
||||
|
||||
return {"sample": image, "num_frames": num_frames}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"block_out_channels": [8, 16],
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
"latent_channels": 4,
|
||||
"norm_num_groups": 4,
|
||||
"layers_per_block": 2,
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_forward_signature(self):
|
||||
pass
|
||||
|
||||
def test_training(self):
|
||||
pass
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
|
||||
def test_gradient_checkpointing(self):
|
||||
# enable deterministic behavior for gradient checkpointing
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
assert not model.is_gradient_checkpointing and model.training
|
||||
|
||||
out = model(**inputs_dict).sample
|
||||
# run the backwards pass on the model. For backwards pass, for simplicity purpose,
|
||||
# we won't calculate the loss and rather backprop on out.sum()
|
||||
model.zero_grad()
|
||||
|
||||
labels = torch.randn_like(out)
|
||||
loss = (out - labels).mean()
|
||||
loss.backward()
|
||||
|
||||
# re-instantiate the model now enabling gradient checkpointing
|
||||
model_2 = self.model_class(**init_dict)
|
||||
# clone model
|
||||
model_2.load_state_dict(model.state_dict())
|
||||
model_2.to(torch_device)
|
||||
model_2.enable_gradient_checkpointing()
|
||||
|
||||
assert model_2.is_gradient_checkpointing and model_2.training
|
||||
|
||||
out_2 = model_2(**inputs_dict).sample
|
||||
# run the backwards pass on the model. For backwards pass, for simplicity purpose,
|
||||
# we won't calculate the loss and rather backprop on out.sum()
|
||||
model_2.zero_grad()
|
||||
loss_2 = (out_2 - labels).mean()
|
||||
loss_2.backward()
|
||||
|
||||
# compare the output and parameters gradients
|
||||
self.assertTrue((loss - loss_2).abs() < 1e-5)
|
||||
named_params = dict(model.named_parameters())
|
||||
named_params_2 = dict(model_2.named_parameters())
|
||||
for name, param in named_params.items():
|
||||
if "post_quant_conv" in name:
|
||||
continue
|
||||
|
||||
self.assertTrue(
|
||||
torch_all_close(
|
||||
param.grad.data, named_params_2[name].grad.data, atol=5e-5
|
||||
)
|
||||
)
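A hedged sketch of what the fast tests above exercise: building `AutoencoderKLTemporalDecoder` from the same tiny `init_dict` and running a forward pass with `num_frames`. The configuration mirrors the test and is not a realistic model size.

```py
import torch

from diffusers import AutoencoderKLTemporalDecoder

vae = AutoencoderKLTemporalDecoder(
    block_out_channels=[8, 16],
    in_channels=3,
    out_channels=3,
    down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
    latent_channels=4,
    norm_num_groups=4,
    layers_per_block=2,
)

sample = torch.randn(3, 3, 32, 32)  # [batch * frames, channels, height, width]
out = vae(sample, num_frames=3).sample
print(out.shape)  # expected to match the input shape, as asserted by ModelTesterMixin
```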
@slow
|
||||
class AutoencoderTinyIntegrationTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
@@ -377,10 +500,16 @@ class AutoencoderTinyIntegrationTests(unittest.TestCase):
|
||||
|
||||
def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
|
||||
dtype = torch.float16 if fp16 else torch.float32
|
||||
image = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
|
||||
image = (
|
||||
torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape)))
|
||||
.to(torch_device)
|
||||
.to(dtype)
|
||||
)
|
||||
return image
|
||||
|
||||
def get_sd_vae_model(self, model_id="hf-internal-testing/taesd-diffusers", fp16=False):
|
||||
def get_sd_vae_model(
|
||||
self, model_id="hf-internal-testing/taesd-diffusers", fp16=False
|
||||
):
|
||||
torch_dtype = torch.float16 if fp16 else torch.float32
|
||||
|
||||
model = AutoencoderTiny.from_pretrained(model_id, torch_dtype=torch_dtype)
|
||||
@@ -414,7 +543,9 @@ class AutoencoderTinyIntegrationTests(unittest.TestCase):
|
||||
assert sample.shape == image.shape
|
||||
|
||||
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
|
||||
expected_output_slice = torch.tensor([0.0093, 0.6385, -0.1274, 0.1631, -0.1762, 0.5232, -0.3108, -0.0382])
|
||||
expected_output_slice = torch.tensor(
|
||||
[0.0093, 0.6385, -0.1274, 0.1631, -0.1762, 0.5232, -0.3108, -0.0382]
|
||||
)
|
||||
|
||||
assert torch_all_close(output_slice, expected_output_slice, atol=3e-3)
|
||||
|
||||
@@ -454,7 +585,11 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
|
||||
|
||||
def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
|
||||
dtype = torch.float16 if fp16 else torch.float32
|
||||
image = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
|
||||
image = (
|
||||
torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape)))
|
||||
.to(torch_device)
|
||||
.to(dtype)
|
||||
)
|
||||
return image
|
||||
|
||||
def get_sd_vae_model(self, model_id="CompVis/stable-diffusion-v1-4", fp16=False):
|
||||
@@ -503,7 +638,9 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
|
||||
assert sample.shape == image.shape
|
||||
|
||||
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
|
||||
expected_output_slice = torch.tensor(expected_slice_mps if torch_device == "mps" else expected_slice)
|
||||
expected_output_slice = torch.tensor(
|
||||
expected_slice_mps if torch_device == "mps" else expected_slice
|
||||
)
|
||||
|
||||
assert torch_all_close(output_slice, expected_output_slice, atol=3e-3)
|
||||
|
||||
@@ -557,7 +694,9 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
|
||||
assert sample.shape == image.shape
|
||||
|
||||
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
|
||||
expected_output_slice = torch.tensor(expected_slice_mps if torch_device == "mps" else expected_slice)
|
||||
expected_output_slice = torch.tensor(
|
||||
expected_slice_mps if torch_device == "mps" else expected_slice
|
||||
)
|
||||
|
||||
assert torch_all_close(output_slice, expected_output_slice, atol=3e-3)
|
||||
|
||||
@@ -609,7 +748,10 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
|
||||
|
||||
@parameterized.expand([(13,), (16,), (27,)])
|
||||
@require_torch_gpu
|
||||
@unittest.skipIf(not is_xformers_available(), reason="xformers is not required when using PyTorch 2.0.")
|
||||
@unittest.skipIf(
|
||||
not is_xformers_available(),
|
||||
reason="xformers is not required when using PyTorch 2.0.",
|
||||
)
|
||||
def test_stable_diffusion_decode_xformers_vs_2_0_fp16(self, seed):
|
||||
model = self.get_sd_vae_model(fp16=True)
|
||||
encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64), fp16=True)
|
||||
@@ -627,7 +769,10 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
|
||||
|
||||
@parameterized.expand([(13,), (16,), (37,)])
|
||||
@require_torch_gpu
|
||||
@unittest.skipIf(not is_xformers_available(), reason="xformers is not required when using PyTorch 2.0.")
|
||||
@unittest.skipIf(
|
||||
not is_xformers_available(),
|
||||
reason="xformers is not required when using PyTorch 2.0.",
|
||||
)
|
||||
def test_stable_diffusion_decode_xformers_vs_2_0(self, seed):
|
||||
model = self.get_sd_vae_model()
|
||||
encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))
|
||||
@@ -660,7 +805,9 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
|
||||
dist = model.encode(image).latent_dist
|
||||
sample = dist.sample(generator=generator)
|
||||
|
||||
assert list(sample.shape) == [image.shape[0], 4] + [i // 8 for i in image.shape[2:]]
|
||||
assert list(sample.shape) == [image.shape[0], 4] + [
|
||||
i // 8 for i in image.shape[2:]
|
||||
]
|
||||
|
||||
output_slice = sample[0, -1, -3:, -3:].flatten().cpu()
|
||||
expected_output_slice = torch.tensor(expected_slice)
|
||||
@@ -701,10 +848,16 @@ class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
 
     def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
         dtype = torch.float16 if fp16 else torch.float32
-        image = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
+        image = (
+            torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape)))
+            .to(torch_device)
+            .to(dtype)
+        )
         return image
 
-    def get_sd_vae_model(self, model_id="cross-attention/asymmetric-autoencoder-kl-x-1-5", fp16=False):
+    def get_sd_vae_model(
+        self, model_id="cross-attention/asymmetric-autoencoder-kl-x-1-5", fp16=False
+    ):
         revision = "main"
         torch_dtype = torch.float32
 
@@ -749,7 +902,9 @@ class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
         assert sample.shape == image.shape
 
         output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
-        expected_output_slice = torch.tensor(expected_slice_mps if torch_device == "mps" else expected_slice)
+        expected_output_slice = torch.tensor(
+            expected_slice_mps if torch_device == "mps" else expected_slice
+        )
 
         assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)
 
@@ -779,7 +934,9 @@ class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
         assert sample.shape == image.shape
 
         output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
-        expected_output_slice = torch.tensor(expected_slice_mps if torch_device == "mps" else expected_slice)
+        expected_output_slice = torch.tensor(
+            expected_slice_mps if torch_device == "mps" else expected_slice
+        )
 
         assert torch_all_close(output_slice, expected_output_slice, atol=3e-3)
 
@@ -808,7 +965,10 @@ class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
 
     @parameterized.expand([(13,), (16,), (37,)])
     @require_torch_gpu
-    @unittest.skipIf(not is_xformers_available(), reason="xformers is not required when using PyTorch 2.0.")
+    @unittest.skipIf(
+        not is_xformers_available(),
+        reason="xformers is not required when using PyTorch 2.0.",
+    )
     def test_stable_diffusion_decode_xformers_vs_2_0(self, seed):
         model = self.get_sd_vae_model()
         encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))
@@ -841,7 +1001,9 @@ class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
         dist = model.encode(image).latent_dist
         sample = dist.sample(generator=generator)
 
-        assert list(sample.shape) == [image.shape[0], 4] + [i // 8 for i in image.shape[2:]]
+        assert list(sample.shape) == [image.shape[0], 4] + [
+            i // 8 for i in image.shape[2:]
+        ]
 
         output_slice = sample[0, -1, -3:, -3:].flatten().cpu()
         expected_output_slice = torch.tensor(expected_slice)
@@ -860,37 +1022,52 @@ class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
 
     @torch.no_grad()
     def test_encode_decode(self):
-        vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder")  # TODO - update
+        vae = ConsistencyDecoderVAE.from_pretrained(
+            "openai/consistency-decoder"
+        )  # TODO - update
         vae.to(torch_device)
 
         image = load_image(
             "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
             "/img2img/sketch-mountains-input.jpg"
         ).resize((256, 256))
-        image = torch.from_numpy(np.array(image).transpose(2, 0, 1).astype(np.float32) / 127.5 - 1)[
-            None, :, :, :
-        ].cuda()
+        image = torch.from_numpy(
+            np.array(image).transpose(2, 0, 1).astype(np.float32) / 127.5 - 1
+        )[None, :, :, :].cuda()
 
         latent = vae.encode(image).latent_dist.mean
 
-        sample = vae.decode(latent, generator=torch.Generator("cpu").manual_seed(0)).sample
+        sample = vae.decode(
+            latent, generator=torch.Generator("cpu").manual_seed(0)
+        ).sample
 
         actual_output = sample[0, :2, :2, :2].flatten().cpu()
-        expected_output = torch.tensor([-0.0141, -0.0014, 0.0115, 0.0086, 0.1051, 0.1053, 0.1031, 0.1024])
+        expected_output = torch.tensor(
+            [-0.0141, -0.0014, 0.0115, 0.0086, 0.1051, 0.1053, 0.1031, 0.1024]
+        )
 
         assert torch_all_close(actual_output, expected_output, atol=5e-3)
 
     def test_sd(self):
-        vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder")  # TODO - update
-        pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", vae=vae, safety_checker=None)
+        vae = ConsistencyDecoderVAE.from_pretrained(
+            "openai/consistency-decoder"
+        )  # TODO - update
+        pipe = StableDiffusionPipeline.from_pretrained(
+            "runwayml/stable-diffusion-v1-5", vae=vae, safety_checker=None
+        )
        pipe.to(torch_device)
 
         out = pipe(
-            "horse", num_inference_steps=2, output_type="pt", generator=torch.Generator("cpu").manual_seed(0)
+            "horse",
+            num_inference_steps=2,
+            output_type="pt",
+            generator=torch.Generator("cpu").manual_seed(0),
         ).images[0]
 
         actual_output = out[:2, :2, :2].flatten().cpu()
-        expected_output = torch.tensor([0.7686, 0.8228, 0.6489, 0.7455, 0.8661, 0.8797, 0.8241, 0.8759])
+        expected_output = torch.tensor(
+            [0.7686, 0.8228, 0.6489, 0.7455, 0.8661, 0.8797, 0.8241, 0.8759]
+        )
 
         assert torch_all_close(actual_output, expected_output, atol=5e-3)
 
@@ -905,18 +1082,23 @@ class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
             "/img2img/sketch-mountains-input.jpg"
         ).resize((256, 256))
         image = (
-            torch.from_numpy(np.array(image).transpose(2, 0, 1).astype(np.float32) / 127.5 - 1)[None, :, :, :]
+            torch.from_numpy(
+                np.array(image).transpose(2, 0, 1).astype(np.float32) / 127.5 - 1
+            )[None, :, :, :]
             .half()
             .cuda()
         )
 
         latent = vae.encode(image).latent_dist.mean
 
-        sample = vae.decode(latent, generator=torch.Generator("cpu").manual_seed(0)).sample
+        sample = vae.decode(
+            latent, generator=torch.Generator("cpu").manual_seed(0)
+        ).sample
 
         actual_output = sample[0, :2, :2, :2].flatten().cpu()
         expected_output = torch.tensor(
-            [-0.0111, -0.0125, -0.0017, -0.0007, 0.1257, 0.1465, 0.1450, 0.1471], dtype=torch.float16
+            [-0.0111, -0.0125, -0.0017, -0.0007, 0.1257, 0.1465, 0.1450, 0.1471],
+            dtype=torch.float16,
         )
 
         assert torch_all_close(actual_output, expected_output, atol=5e-3)
@@ -926,17 +1108,24 @@ class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
             "openai/consistency-decoder", torch_dtype=torch.float16
         )  # TODO - update
         pipe = StableDiffusionPipeline.from_pretrained(
-            "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, vae=vae, safety_checker=None
+            "runwayml/stable-diffusion-v1-5",
+            torch_dtype=torch.float16,
+            vae=vae,
+            safety_checker=None,
         )
         pipe.to(torch_device)
 
         out = pipe(
-            "horse", num_inference_steps=2, output_type="pt", generator=torch.Generator("cpu").manual_seed(0)
+            "horse",
+            num_inference_steps=2,
+            output_type="pt",
+            generator=torch.Generator("cpu").manual_seed(0),
         ).images[0]
 
         actual_output = out[:2, :2, :2].flatten().cpu()
         expected_output = torch.tensor(
-            [0.0000, 0.0249, 0.0000, 0.0000, 0.1709, 0.2773, 0.0471, 0.1035], dtype=torch.float16
+            [0.0000, 0.0249, 0.0000, 0.0000, 0.1709, 0.2773, 0.0471, 0.1035],
+            dtype=torch.float16,
        )
 
         assert torch_all_close(actual_output, expected_output, atol=5e-3)
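The ConsistencyDecoderVAE integration tests above all reduce to one wiring pattern: load the consistency decoder and pass it in as the `vae` of a `StableDiffusionPipeline`. Below is a minimal standalone sketch of that pattern, using only the model ids and arguments that appear in the tests; note that the "openai/consistency-decoder" id is still marked "TODO - update" in the diff, and the step count here is raised above the 2 steps used in the smoke test.

import torch

from diffusers import ConsistencyDecoderVAE, StableDiffusionPipeline

# Checkpoint ids are copied from the tests above; the decoder id is provisional there.
vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16)
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    vae=vae,
    safety_checker=None,
).to("cuda")

# A more realistic step count than the 2-step smoke test; the seed is fixed for reproducibility.
image = pipe(
    "horse",
    num_inference_steps=25,
    generator=torch.Generator("cpu").manual_seed(0),
).images[0]
image.save("horse.png")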
261
tests/pipelines/stable_diffusion/test_stable_diffusion_video.py
Normal file
@@ -0,0 +1,261 @@
import gc
import random
import unittest

import numpy as np
import torch
from transformers import (
    CLIPImageProcessor,
    CLIPTextConfig,
    CLIPTextModel,
    CLIPTokenizer,
)

import diffusers
from diffusers import (
    AutoencoderKLTemporalDecoder,
    DDIMScheduler,
    StableDiffusionVideoPipeline,
    UNetSpatioTemporalConditionModel,
)
from diffusers.utils import load_image, logging
from diffusers.utils.testing_utils import (
    floats_tensor,
    numpy_cosine_similarity_distance,
    require_torch_gpu,
    slow,
    torch_device,
)

from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin


def to_np(tensor):
    if isinstance(tensor, torch.Tensor):
        tensor = tensor.detach().cpu().numpy()

    return tensor


class StableDiffusionVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = StableDiffusionVideoPipeline
    params = TEXT_TO_IMAGE_PARAMS
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
    required_optional_params = frozenset(
        [
            "num_inference_steps",
            "generator",
            "latents",
            "return_dict",
            "callback",
            "callback_steps",
        ]
    )

    def get_dummy_components(self):
        torch.manual_seed(0)
        unet = UNetSpatioTemporalConditionModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=4,
            out_channels=4,
            down_block_types=(
                "CrossAttnDownBlockSpatioTemporal",
                "DownBlockSpatioTemporal",
            ),
            up_block_types=("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal"),
            cross_attention_dim=32,
            norm_num_groups=2,
        )
        scheduler = DDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="linear",
            clip_sample=False,
        )
        torch.manual_seed(0)
        vae = AutoencoderKLTemporalDecoder(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            latent_channels=4,
        )
        torch.manual_seed(0)
        text_encoder_config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
        )
        text_encoder = CLIPTextModel(text_encoder_config)
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
        feature_extractor = CLIPImageProcessor(crop_size=32, size=32)

        components = {
            "unet": unet,
            "scheduler": scheduler,
            "vae": vae,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "feature_extractor": feature_extractor,
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)

        image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
        image = image / 2 + 0.5

        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "generator": generator,
            "image": image,
            "num_inference_steps": 2,
            "guidance_scale": 7.5,
            "output_type": "pt",
        }
        return inputs

    @unittest.skip("Attention slicing is not enabled in this pipeline")
    def test_attention_slicing_forward_pass(self):
        pass
    def test_inference_batch_single_identical(
        self,
        batch_size=2,
        expected_max_diff=1e-4,
        additional_params_copy_to_batched_inputs=["num_inference_steps"],
    ):
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        for components in pipe.components.values():
            if hasattr(components, "set_default_attn_processor"):
                components.set_default_attn_processor()

        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        inputs = self.get_dummy_inputs(torch_device)
        # Reset the generator in case it has been used in self.get_dummy_inputs
        inputs["generator"] = self.get_generator(0)

        logger = logging.get_logger(pipe.__module__)
        logger.setLevel(level=diffusers.logging.FATAL)

        # batchify inputs
        batched_inputs = {}
        batched_inputs.update(inputs)

        for name in self.batch_params:
            if name not in inputs:
                continue

            value = inputs[name]
            if name == "prompt":
                len_prompt = len(value)
                batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
                batched_inputs[name][-1] = 100 * "very long"

            else:
                batched_inputs[name] = batch_size * [value]

        if "generator" in inputs:
            batched_inputs["generator"] = [self.get_generator(i) for i in range(batch_size)]

        if "batch_size" in inputs:
            batched_inputs["batch_size"] = batch_size

        for arg in additional_params_copy_to_batched_inputs:
            batched_inputs[arg] = inputs[arg]

        output = pipe(**inputs)
        output_batch = pipe(**batched_inputs)

        assert output_batch[0].shape[0] == batch_size

        max_diff = np.abs(to_np(output_batch[0][0]) - to_np(output[0][0])).max()
        assert max_diff < expected_max_diff

    @unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices")
    def test_to_device(self):
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.set_progress_bar_config(disable=None)

        pipe.to("cpu")
        model_devices = [
            component.device.type for component in pipe.components.values() if hasattr(component, "device")
        ]
        self.assertTrue(all(device == "cpu" for device in model_devices))

        output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]
        self.assertTrue(np.isnan(output_cpu).sum() == 0)

        pipe.to("cuda")
        model_devices = [
            component.device.type for component in pipe.components.values() if hasattr(component, "device")
        ]
        self.assertTrue(all(device == "cuda" for device in model_devices))

        output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0]
        self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0)

    def test_to_dtype(self):
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.set_progress_bar_config(disable=None)

        model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
        self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))

        pipe.to(torch_dtype=torch.float16)
        model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
        self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))


@slow
@require_torch_gpu
class StableDiffusionVideoPipelineSlowTests(unittest.TestCase):
    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def test_sd_video(self):
        pipe = StableDiffusionVideoPipeline.from_pretrained("diffusers/svd-test")
        pipe = pipe.to(torch_device)
        pipe.enable_model_cpu_offload()
        pipe.set_progress_bar_config(disable=None)
        image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat_6.png?download=true"
        )

        generator = torch.Generator("cpu").manual_seed(0)
        num_frames = 3

        output = pipe(
            image=image,
            num_frames=num_frames,
            generator=generator,
            num_inference_steps=3,
            output_type="np",
        )

        image = output.frames[0]
        assert image.shape == (num_frames, 576, 1024, 3)

        image_slice = image[0, -3:, -3:, -1]
        expected_slice = np.array([0.8592, 0.8645, 0.8499, 0.8722, 0.8769, 0.8421, 0.8557, 0.8528, 0.8285])
        assert numpy_cosine_similarity_distance(image_slice.flatten(), expected_slice.flatten()) < 1e-3
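The slow test above doubles as end-to-end documentation for the new pipeline. Below is a stripped-down sketch of the same call pattern, kept deliberately close to the test: the "diffusers/svd-test" checkpoint id, the argument names, and the expected (num_frames, 576, 1024, 3) output shape are all taken from the test rather than from a published API reference.

import torch

from diffusers import StableDiffusionVideoPipeline
from diffusers.utils import load_image

pipe = StableDiffusionVideoPipeline.from_pretrained("diffusers/svd-test")
pipe.enable_model_cpu_offload()  # as in the test: offload sub-models between steps to keep VRAM usage low

image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat_6.png?download=true"
)

output = pipe(
    image=image,
    num_frames=3,
    num_inference_steps=3,  # the test uses very few steps; raise this for real generations
    generator=torch.Generator("cpu").manual_seed(0),
    output_type="np",
)

frames = output.frames[0]
print(frames.shape)  # the test asserts (num_frames, 576, 1024, 3)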