Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-11 06:54:32 +08:00)
Compare commits
114 Commits
add-sharde...test-v-upd
Commit SHA1s:

56e8fca572, c5941a26a4, 8bc42512fe, 55b4d09080, c452d9c042, ee9f7d2493, 8620851aa0, 90d8e832f8, 18930e0b85, 847bd0a479,
3178b16b17, a08ef009d1, 804bdebe51, a193e49dff, c9d1727613, 82cf60828f, 26ed460265, 403a81c30d, 1b3cf2db5e, b8d84c4320,
3fbe123d84, f7cf8c338c, ab8076f234, 7b6a0d48c6, 6adae54046, af85fb1bc1, 760333d524, f651c12ef8, d614a33a09, 13b646edd3,
cb49cbdd29, 1ce8ff51e6, fdd182f335, 2a46326c25, e34e9d9a33, 96af28f92b, 6827a1dc6a, c3bdeb8a4c, cf70b9a0b4, 712b9950c5,
21148de853, d930977656, 268ffea0e7, 8bcf43d52a, b071aaa719, 5316fb5107, 9af07d1d5c, d0017d9b70, 0cf6c6b291, df986274d6,
7ddd14bd94, 4346ddd402, 9da55b381c, 4d4469ee87, f9954a0e7b, e7798333c4, efb1e5e1d8, beaaf18b2c, 132fe97bf4, 2f35e8c94c,
b336529573, 3e47d3c8ed, 122a6bd390, 37c428a79c, eefed8ab6b, 05eaec2d39, e68424378f, 24b5c4360c, 0c4192b537, dff26ce8af,
9f22651c1f, d8c9e67aac, 6c28367b1a, f9def2aeed, 576fa1c7dc, f1457b7e1d, 1f34311eec, f976f5a31e, 8e1851a16a, 6c69c7a0d2,
6481e9495f, 8c3fd58c85, 9117547ee0, af1e86af8d, 29551f8e30, 661033171b, 20efe541c5, 5a523e21c6, b0fc4fd4cb, 678d19fa18,
c8ec445964, ffd9e26a65, 6f87490408, 9c9d46763b, 47684dab43, 5218f46173, 8ee280773f, 85846f7450, 28dee6e735, 165ed7c5d5,
d4cdfa33f5, 1bd09b1489, edf7121ec7, 7b64d3a17b, c93606c93c, 5df09ef355, ac9473153c, ee9d7b8ecd, 669824e5bb, 45c9b56bf7,
cad51d45d1, 7de5d7c6fd, 58883ee085, 2f5648177e
scripts/convert_svd_to_diffusers.py (new file, 730 lines)
@@ -0,0 +1,730 @@
from diffusers.utils import is_accelerate_available, logging


if is_accelerate_available():
    pass

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def create_unet_diffusers_config(original_config, image_size: int, controlnet=False):
    """
    Creates a config for the diffusers based on the config of the LDM model.
    """
    if controlnet:
        unet_params = original_config.model.params.control_stage_config.params
    else:
        if "unet_config" in original_config.model.params and original_config.model.params.unet_config is not None:
            unet_params = original_config.model.params.unet_config.params
        else:
            unet_params = original_config.model.params.network_config.params

    vae_params = original_config.model.params.first_stage_config.params.encoder_config.params

    block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]

    down_block_types = []
    resolution = 1
    for i in range(len(block_out_channels)):
        block_type = (
            "CrossAttnDownBlockSpatioTemporal"
            if resolution in unet_params.attention_resolutions
            else "DownBlockSpatioTemporal"
        )
        down_block_types.append(block_type)
        if i != len(block_out_channels) - 1:
            resolution *= 2

    up_block_types = []
    for i in range(len(block_out_channels)):
        block_type = (
            "CrossAttnUpBlockSpatioTemporal"
            if resolution in unet_params.attention_resolutions
            else "UpBlockSpatioTemporal"
        )
        up_block_types.append(block_type)
        resolution //= 2

    if unet_params.transformer_depth is not None:
        transformer_layers_per_block = (
            unet_params.transformer_depth
            if isinstance(unet_params.transformer_depth, int)
            else list(unet_params.transformer_depth)
        )
    else:
        transformer_layers_per_block = 1

    vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)

    head_dim = unet_params.num_heads if "num_heads" in unet_params else None
    use_linear_projection = (
        unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
    )
    if use_linear_projection:
        # stable diffusion 2-base-512 and 2-768
        if head_dim is None:
            head_dim_mult = unet_params.model_channels // unet_params.num_head_channels
            head_dim = [head_dim_mult * c for c in list(unet_params.channel_mult)]

    class_embed_type = None
    addition_embed_type = None
    addition_time_embed_dim = None
    projection_class_embeddings_input_dim = None
    context_dim = None

    if unet_params.context_dim is not None:
        context_dim = (
            unet_params.context_dim if isinstance(unet_params.context_dim, int) else unet_params.context_dim[0]
        )

    if "num_classes" in unet_params:
        if unet_params.num_classes == "sequential":
            addition_time_embed_dim = 256
            assert "adm_in_channels" in unet_params
            projection_class_embeddings_input_dim = unet_params.adm_in_channels

    config = {
        "sample_size": image_size // vae_scale_factor,
        "in_channels": unet_params.in_channels,
        "down_block_types": tuple(down_block_types),
        "block_out_channels": tuple(block_out_channels),
        "layers_per_block": unet_params.num_res_blocks,
        "cross_attention_dim": context_dim,
        "attention_head_dim": head_dim,
        "use_linear_projection": use_linear_projection,
        "class_embed_type": class_embed_type,
        "addition_embed_type": addition_embed_type,
        "addition_time_embed_dim": addition_time_embed_dim,
        "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
        "transformer_layers_per_block": transformer_layers_per_block,
    }

    if "disable_self_attentions" in unet_params:
        config["only_cross_attention"] = unet_params.disable_self_attentions

    if "num_classes" in unet_params and isinstance(unet_params.num_classes, int):
        config["num_class_embeds"] = unet_params.num_classes

    if controlnet:
        config["conditioning_channels"] = unet_params.hint_channels
    else:
        config["out_channels"] = unet_params.out_channels
        config["up_block_types"] = tuple(up_block_types)

    return config


def assign_to_checkpoint(
    paths,
    checkpoint,
    old_checkpoint,
    attention_paths_to_split=None,
    additional_replacements=None,
    config=None,
    mid_block_suffix="",
):
    """
    This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
    attention layers, and takes into account additional replacements that may arise.

    Assigns the weights to the new checkpoint.
    """
    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

    # Splits the attention layers into three variables.
    if attention_paths_to_split is not None:
        for path, path_map in attention_paths_to_split.items():
            old_tensor = old_checkpoint[path]
            channels = old_tensor.shape[0] // 3

            target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)

            num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3

            old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
            query, key, value = old_tensor.split(channels // num_heads, dim=1)

            checkpoint[path_map["query"]] = query.reshape(target_shape)
            checkpoint[path_map["key"]] = key.reshape(target_shape)
            checkpoint[path_map["value"]] = value.reshape(target_shape)

    if mid_block_suffix is not None:
        mid_block_suffix = f".{mid_block_suffix}"
    else:
        mid_block_suffix = ""

    for path in paths:
        new_path = path["new"]

        # These have already been assigned
        if attention_paths_to_split is not None and new_path in attention_paths_to_split:
            continue

        # Global renaming happens here
        new_path = new_path.replace("middle_block.0", f"mid_block.resnets.0{mid_block_suffix}")
        new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
        new_path = new_path.replace("middle_block.2", f"mid_block.resnets.1{mid_block_suffix}")

        if additional_replacements is not None:
            for replacement in additional_replacements:
                new_path = new_path.replace(replacement["old"], replacement["new"])

        # proj_attn.weight has to be converted from conv 1D to linear
        is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path)
        shape = old_checkpoint[path["old"]].shape
        if is_attn_weight and len(shape) == 3:
            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
        elif is_attn_weight and len(shape) == 4:
            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0]
        else:
            checkpoint[new_path] = old_checkpoint[path["old"]]


def renew_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside attentions to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        # new_item = new_item.replace('norm.weight', 'group_norm.weight')
        # new_item = new_item.replace('norm.bias', 'group_norm.bias')

        # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
        # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')

        # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
        new_item = new_item.replace("time_stack", "temporal_transformer_blocks")

        new_item = new_item.replace("time_pos_embed.0.bias", "time_pos_embed.linear_1.bias")
        new_item = new_item.replace("time_pos_embed.0.weight", "time_pos_embed.linear_1.weight")
        new_item = new_item.replace("time_pos_embed.2.bias", "time_pos_embed.linear_2.bias")
        new_item = new_item.replace("time_pos_embed.2.weight", "time_pos_embed.linear_2.weight")

        mapping.append({"old": old_item, "new": new_item})

    return mapping


def shave_segments(path, n_shave_prefix_segments=1):
    """
    Removes segments. Positive values shave the first segments, negative shave the last segments.
    """
    if n_shave_prefix_segments >= 0:
        return ".".join(path.split(".")[n_shave_prefix_segments:])
    else:
        return ".".join(path.split(".")[:n_shave_prefix_segments])


def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item.replace("in_layers.0", "norm1")
        new_item = new_item.replace("in_layers.2", "conv1")

        new_item = new_item.replace("out_layers.0", "norm2")
        new_item = new_item.replace("out_layers.3", "conv2")

        new_item = new_item.replace("emb_layers.1", "time_emb_proj")
        new_item = new_item.replace("skip_connection", "conv_shortcut")

        new_item = new_item.replace("time_stack.", "")

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping


def convert_ldm_unet_checkpoint(
    checkpoint, config, path=None, extract_ema=False, controlnet=False, skip_extract_state_dict=False
):
    """
    Takes a state dict and a config, and returns a converted checkpoint.
    """

    if skip_extract_state_dict:
        unet_state_dict = checkpoint
    else:
        # extract state_dict for UNet
        unet_state_dict = {}
        keys = list(checkpoint.keys())

        unet_key = "model.diffusion_model."

        # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
        if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
            logger.warning(f"Checkpoint {path} has both EMA and non-EMA weights.")
            logger.warning(
                "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
                " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
            )
            for key in keys:
                if key.startswith("model.diffusion_model"):
                    flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
                    unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
        else:
            if sum(k.startswith("model_ema") for k in keys) > 100:
                logger.warning(
                    "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
                    " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
                )

            for key in keys:
                if key.startswith(unet_key):
                    unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)

    new_checkpoint = {}

    new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
    new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
    new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
    new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]

    if config["class_embed_type"] is None:
        # No parameters to port
        ...
    elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
        new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
        new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
        new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
        new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
    else:
        raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")

    # if config["addition_embed_type"] == "text_time":
    new_checkpoint["add_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
    new_checkpoint["add_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
    new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
    new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]

    new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
    new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]

    new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
    new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
    new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
    new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]

    # Retrieves the keys for the input blocks only
    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
    input_blocks = {
        layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
        for layer_id in range(num_input_blocks)
    }

    # Retrieves the keys for the middle blocks only
    num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
    middle_blocks = {
        layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
        for layer_id in range(num_middle_blocks)
    }

    # Retrieves the keys for the output blocks only
    num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
    output_blocks = {
        layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
        for layer_id in range(num_output_blocks)
    }

    for i in range(1, num_input_blocks):
        block_id = (i - 1) // (config["layers_per_block"] + 1)
        layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)

        spatial_resnets = [
            key
            for key in input_blocks[i]
            if f"input_blocks.{i}.0" in key
            and (
                f"input_blocks.{i}.0.op" not in key
                and f"input_blocks.{i}.0.time_stack" not in key
                and f"input_blocks.{i}.0.time_mixer" not in key
            )
        ]
        temporal_resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0.time_stack" in key]
        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]

        if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.weight"
            )
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.bias"
            )

        paths = renew_resnet_paths(spatial_resnets)
        meta_path = {
            "old": f"input_blocks.{i}.0",
            "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}.spatial_res_block",
        }
        assign_to_checkpoint(
            paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
        )

        paths = renew_resnet_paths(temporal_resnets)
        meta_path = {
            "old": f"input_blocks.{i}.0",
            "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}.temporal_res_block",
        }
        assign_to_checkpoint(
            paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
        )

        # TODO resnet time_mixer.mix_factor
        if f"input_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict:
            new_checkpoint[
                f"down_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"
            ] = unet_state_dict[f"input_blocks.{i}.0.time_mixer.mix_factor"]

        if len(attentions):
            paths = renew_attention_paths(attentions)
            meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

    resnet_0 = middle_blocks[0]
    attentions = middle_blocks[1]
    resnet_1 = middle_blocks[2]

    resnet_0_spatial = [key for key in resnet_0 if "time_stack" not in key and "time_mixer" not in key]
    resnet_0_paths = renew_resnet_paths(resnet_0_spatial)
    assign_to_checkpoint(
        resnet_0_paths, new_checkpoint, unet_state_dict, config=config, mid_block_suffix="spatial_res_block"
    )

    resnet_0_temporal = [key for key in resnet_0 if "time_stack" in key and "time_mixer" not in key]
    resnet_0_paths = renew_resnet_paths(resnet_0_temporal)
    assign_to_checkpoint(
        resnet_0_paths, new_checkpoint, unet_state_dict, config=config, mid_block_suffix="temporal_res_block"
    )

    resnet_1_spatial = [key for key in resnet_1 if "time_stack" not in key and "time_mixer" not in key]
    resnet_1_paths = renew_resnet_paths(resnet_1_spatial)
    assign_to_checkpoint(
        resnet_1_paths, new_checkpoint, unet_state_dict, config=config, mid_block_suffix="spatial_res_block"
    )

    resnet_1_temporal = [key for key in resnet_1 if "time_stack" in key and "time_mixer" not in key]
    resnet_1_paths = renew_resnet_paths(resnet_1_temporal)
    assign_to_checkpoint(
        resnet_1_paths, new_checkpoint, unet_state_dict, config=config, mid_block_suffix="temporal_res_block"
    )

    new_checkpoint["mid_block.resnets.0.time_mixer.mix_factor"] = unet_state_dict[
        "middle_block.0.time_mixer.mix_factor"
    ]
    new_checkpoint["mid_block.resnets.1.time_mixer.mix_factor"] = unet_state_dict[
        "middle_block.2.time_mixer.mix_factor"
    ]

    attentions_paths = renew_attention_paths(attentions)
    meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(
        attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
    )

    for i in range(num_output_blocks):
        block_id = i // (config["layers_per_block"] + 1)
        layer_in_block_id = i % (config["layers_per_block"] + 1)
        output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
        output_block_list = {}

        for layer in output_block_layers:
            layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
            if layer_id in output_block_list:
                output_block_list[layer_id].append(layer_name)
            else:
                output_block_list[layer_id] = [layer_name]

        if len(output_block_list) > 1:
            spatial_resnets = [
                key
                for key in output_blocks[i]
                if f"output_blocks.{i}.0" in key
                and (f"output_blocks.{i}.0.time_stack" not in key and "time_mixer" not in key)
            ]

            temporal_resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0.time_stack" in key]

            paths = renew_resnet_paths(spatial_resnets)
            meta_path = {
                "old": f"output_blocks.{i}.0",
                "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}.spatial_res_block",
            }
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

            paths = renew_resnet_paths(temporal_resnets)
            meta_path = {
                "old": f"output_blocks.{i}.0",
                "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}.temporal_res_block",
            }
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

            if f"output_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict:
                new_checkpoint[
                    f"up_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"
                ] = unet_state_dict[f"output_blocks.{i}.0.time_mixer.mix_factor"]

            output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
            if ["conv.bias", "conv.weight"] in output_block_list.values():
                index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.weight"
                ]
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.bias"
                ]

                # Clear attentions as they have been attributed above.
                if len(attentions) == 2:
                    attentions = []

            attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key and "conv" not in key]
            if len(attentions):
                paths = renew_attention_paths(attentions)
                meta_path = {
                    "old": f"output_blocks.{i}.1",
                    "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
                }
                assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
                )
        else:
            spatial_layers = [
                layer for layer in output_block_layers if "time_stack" not in layer and "time_mixer" not in layer
            ]
            resnet_0_paths = renew_resnet_paths(spatial_layers, n_shave_prefix_segments=1)
            for path in resnet_0_paths:
                old_path = ".".join(["output_blocks", str(i), path["old"]])
                new_path = ".".join(
                    ["up_blocks", str(block_id), "resnets", str(layer_in_block_id), "spatial_res_block", path["new"]]
                )

                new_checkpoint[new_path] = unet_state_dict[old_path]

            temporal_layers = [
                layer for layer in output_block_layers if "time_stack" in layer and "time_mixer" not in layer
            ]
            resnet_0_paths = renew_resnet_paths(temporal_layers, n_shave_prefix_segments=1)
            for path in resnet_0_paths:
                old_path = ".".join(["output_blocks", str(i), path["old"]])
                new_path = ".".join(
                    ["up_blocks", str(block_id), "resnets", str(layer_in_block_id), "temporal_res_block", path["new"]]
                )

                new_checkpoint[new_path] = unet_state_dict[old_path]

            new_checkpoint["up_blocks.0.resnets.0.time_mixer.mix_factor"] = unet_state_dict[
                f"output_blocks.{str(i)}.0.time_mixer.mix_factor"
            ]

    return new_checkpoint


def conv_attn_to_linear(checkpoint):
    keys = list(checkpoint.keys())
    attn_keys = ["to_q.weight", "to_k.weight", "to_v.weight"]
    for key in keys:
        if ".".join(key.split(".")[-2:]) in attn_keys:
            if checkpoint[key].ndim > 2:
                checkpoint[key] = checkpoint[key][:, :, 0, 0]
        elif "proj_attn.weight" in key:
            if checkpoint[key].ndim > 2:
                checkpoint[key] = checkpoint[key][:, :, 0]


def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0, is_temporal=False):
    """
    Updates paths inside resnets to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        # Temporal resnet
        new_item = old_item.replace("in_layers.0", "norm1")
        new_item = new_item.replace("in_layers.2", "conv1")

        new_item = new_item.replace("out_layers.0", "norm2")
        new_item = new_item.replace("out_layers.3", "conv2")

        new_item = new_item.replace("skip_connection", "conv_shortcut")

        new_item = new_item.replace("time_stack.", "temporal_res_block.")

        # Spatial resnet
        new_item = new_item.replace("conv1", "spatial_res_block.conv1")
        new_item = new_item.replace("norm1", "spatial_res_block.norm1")

        new_item = new_item.replace("conv2", "spatial_res_block.conv2")
        new_item = new_item.replace("norm2", "spatial_res_block.norm2")

        new_item = new_item.replace("nin_shortcut", "spatial_res_block.conv_shortcut")

        new_item = new_item.replace("mix_factor", "spatial_res_block.time_mixer.mix_factor")

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping


def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside attentions to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        new_item = new_item.replace("norm.weight", "group_norm.weight")
        new_item = new_item.replace("norm.bias", "group_norm.bias")

        new_item = new_item.replace("q.weight", "to_q.weight")
        new_item = new_item.replace("q.bias", "to_q.bias")

        new_item = new_item.replace("k.weight", "to_k.weight")
        new_item = new_item.replace("k.bias", "to_k.bias")

        new_item = new_item.replace("v.weight", "to_v.weight")
        new_item = new_item.replace("v.bias", "to_v.bias")

        new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
        new_item = new_item.replace("proj_out.bias", "to_out.0.bias")

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping


def convert_ldm_vae_checkpoint(checkpoint, config):
    # extract state dict for VAE
    vae_state_dict = {}
    keys = list(checkpoint.keys())
    vae_key = "first_stage_model." if any(k.startswith("first_stage_model.") for k in keys) else ""
    for key in keys:
        if key.startswith(vae_key):
            vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)

    new_checkpoint = {}

    new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
    new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
    new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
    new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
    new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
    new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]

    new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
    new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
    new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
    new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
    new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
    new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
    new_checkpoint["decoder.time_conv_out.weight"] = vae_state_dict["decoder.time_mix_conv.weight"]
    new_checkpoint["decoder.time_conv_out.bias"] = vae_state_dict["decoder.time_mix_conv.bias"]

    # new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
    # new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
    # new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
    # new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]

    # Retrieves the keys for the encoder down blocks only
    num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
    down_blocks = {
        layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
    }

    # Retrieves the keys for the decoder up blocks only
    num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
    up_blocks = {
        layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
    }

    for i in range(num_down_blocks):
        resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]

        if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.weight"
            )
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.bias"
            )

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)

    for i in range(num_up_blocks):
        block_id = num_up_blocks - 1 - i

        resnets = [
            key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
        ]

        if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.weight"
            ]
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.bias"
            ]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)
    return new_checkpoint
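The mirror view ends the file here, so the script's argument parsing is not shown. As a rough sketch of how the helpers above fit together (the file names, the `OmegaConf`-based config loading, and the checkpoint layout are assumptions for illustration, not part of this diff; the calls would sit at the bottom of the script or import these helpers from it):

```python
# Hypothetical driver for the conversion helpers above; paths and checkpoint
# layout are assumed, not taken from the diff.
import torch
from omegaconf import OmegaConf

original_config = OmegaConf.load("svd_xt.yaml")             # assumed original SVD config
state_dict = torch.load("svd_xt.ckpt", map_location="cpu")  # assumed original weights
state_dict = state_dict.get("state_dict", state_dict)       # some checkpoints nest the weights

# 1. Derive a diffusers-style UNet config from the LDM config.
unet_config = create_unet_diffusers_config(original_config, image_size=768)

# 2. Rename and regroup the `model.diffusion_model.*` weights.
converted_unet = convert_ldm_unet_checkpoint(state_dict, unet_config, extract_ema=True)

# 3. The VAE weights go through the analogous helper; the `config` argument is
#    only used when splitting fused attention weights, which this path does not do.
converted_vae = convert_ldm_vae_checkpoint(state_dict, config=None)

# The resulting dicts are meant to be loaded into UNetSpatioTemporalConditionModel
# and AutoencoderKLTemporalDecoder instances via `load_state_dict`.
```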
@@ -76,6 +76,7 @@ else:
         [
             "AsymmetricAutoencoderKL",
             "AutoencoderKL",
+            "AutoencoderKLTemporalDecoder",
             "AutoencoderTiny",
             "ConsistencyDecoderVAE",
             "ControlNetModel",
@@ -92,6 +93,7 @@ else:
             "UNet2DModel",
             "UNet3DConditionModel",
             "UNetMotionModel",
+            "UNetSpatioTemporalConditionModel",
             "VQModel",
         ]
     )
@@ -267,6 +269,7 @@ else:
             "StableDiffusionPix2PixZeroPipeline",
             "StableDiffusionSAGPipeline",
             "StableDiffusionUpscalePipeline",
+            "StableDiffusionVideoPipeline",
             "StableDiffusionXLAdapterPipeline",
             "StableDiffusionXLControlNetImg2ImgPipeline",
             "StableDiffusionXLControlNetInpaintPipeline",
@@ -446,6 +449,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     from .models import (
         AsymmetricAutoencoderKL,
         AutoencoderKL,
+        AutoencoderKLTemporalDecoder,
         AutoencoderTiny,
         ConsistencyDecoderVAE,
         ControlNetModel,
@@ -462,6 +466,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         UNet2DModel,
         UNet3DConditionModel,
         UNetMotionModel,
+        UNetSpatioTemporalConditionModel,
         VQModel,
     )
     from .optimization import (
@@ -616,6 +621,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         StableDiffusionPix2PixZeroPipeline,
         StableDiffusionSAGPipeline,
         StableDiffusionUpscalePipeline,
+        StableDiffusionVideoPipeline,
         StableDiffusionXLAdapterPipeline,
         StableDiffusionXLControlNetImg2ImgPipeline,
         StableDiffusionXLControlNetInpaintPipeline,
@@ -14,7 +14,12 @@

 from typing import TYPE_CHECKING

-from ..utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, is_flax_available, is_torch_available
+from ..utils import (
+    DIFFUSERS_SLOW_IMPORT,
+    _LazyModule,
+    is_flax_available,
+    is_torch_available,
+)


 _import_structure = {}
@@ -23,6 +28,7 @@ if is_torch_available():
     _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
     _import_structure["autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
     _import_structure["autoencoder_kl"] = ["AutoencoderKL"]
+    _import_structure["autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
     _import_structure["autoencoder_tiny"] = ["AutoencoderTiny"]
     _import_structure["consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
     _import_structure["controlnet"] = ["ControlNetModel"]
@@ -38,6 +44,7 @@ if is_torch_available():
     _import_structure["unet_3d_condition"] = ["UNet3DConditionModel"]
     _import_structure["unet_kandi3"] = ["Kandinsky3UNet"]
     _import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
+    _import_structure["unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"]
     _import_structure["vq_model"] = ["VQModel"]

 if is_flax_available():
@@ -51,6 +58,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     from .adapter import MultiAdapter, T2IAdapter
     from .autoencoder_asym_kl import AsymmetricAutoencoderKL
     from .autoencoder_kl import AutoencoderKL
+    from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
     from .autoencoder_tiny import AutoencoderTiny
     from .consistency_decoder_vae import ConsistencyDecoderVAE
     from .controlnet import ControlNetModel
@@ -66,6 +74,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     from .unet_3d_condition import UNet3DConditionModel
     from .unet_kandi3 import Kandinsky3UNet
     from .unet_motion_model import MotionAdapter, UNetMotionModel
+    from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
     from .vq_model import VQModel

 if is_flax_available():
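After these registrations the new classes resolve through the library's lazy-import tables. A minimal smoke test, assuming a diffusers installation built from this branch:

```python
# Checks that the new public names resolve through the lazy-import tables edited above
# (assumes a diffusers build that contains this branch).
import diffusers

print(diffusers.AutoencoderKLTemporalDecoder)
print(diffusers.UNetSpatioTemporalConditionModel)

# The same classes are reachable from the models subpackage:
from diffusers.models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel  # noqa: F401

# The pipeline name registered in the top-level __init__ at this point in the branch
# is `StableDiffusionVideoPipeline`.
```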
@@ -194,7 +194,12 @@ class BasicTransformerBlock(nn.Module):
         if not self.use_ada_layer_norm_single:
             self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)

-        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
+        self.ff = FeedForward(
+            dim,
+            dropout=dropout,
+            activation_fn=activation_fn,
+            final_dropout=final_dropout,
+        )

         # 4. Fuser
         if attention_type == "gated" or attention_type == "gated-text-image":
@@ -339,6 +344,181 @@ class BasicTransformerBlock(nn.Module):
         return hidden_states


+@maybe_allow_in_graph
+class TemporalBasicTransformerBlock(nn.Module):
+    r"""
+    A basic Transformer block for video like data.
+
+    Parameters:
+        dim (`int`): The number of channels in the input and output.
+        num_attention_heads (`int`): The number of heads to use for multi-head attention.
+        attention_head_dim (`int`): The number of channels in each head.
+        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
+        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+        attention_bias (:
+            obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
+        upcast_attention (`bool`, *optional*):
+            Whether to upcast the attention computation to float32. This is useful for mixed precision training.
+        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
+            Whether to use learnable elementwise affine parameters for normalization.
+        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
+            The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
+        final_dropout (`bool` *optional*, defaults to False):
+            Whether to apply a final dropout after the last feed-forward layer.
+        attention_type (`str`, *optional*, defaults to `"default"`):
+            The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        time_mix_inner_dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        dropout=0.0,
+        cross_attention_dim: Optional[int] = None,
+        norm_eps: float = 1e-5,
+        final_dropout: bool = False,
+    ):
+        super().__init__()
+        self.is_res = dim == time_mix_inner_dim
+
+        self.norm_in = nn.LayerNorm(dim, eps=norm_eps)
+
+        # Define 3 blocks. Each block has its own normalization layer.
+        # 1. Self-Attn
+        self.norm_in = nn.LayerNorm(dim, eps=norm_eps)
+        self.ff_in = FeedForward(
+            dim,
+            dim_out=time_mix_inner_dim,
+            dropout=dropout,
+            activation_fn="geglu",
+            final_dropout=final_dropout,
+        )
+
+        self.norm1 = nn.LayerNorm(time_mix_inner_dim, eps=norm_eps)
+        self.attn1 = Attention(
+            query_dim=time_mix_inner_dim,
+            heads=num_attention_heads,
+            dim_head=attention_head_dim,
+            dropout=dropout,
+            cross_attention_dim=None,
+        )
+
+        # 2. Cross-Attn
+        if cross_attention_dim is not None:
+            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
+            # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
+            # the second cross attention block.
+            self.norm2 = nn.LayerNorm(time_mix_inner_dim, eps=norm_eps)
+            self.attn2 = Attention(
+                query_dim=time_mix_inner_dim,
+                cross_attention_dim=cross_attention_dim,
+                heads=num_attention_heads,
+                dim_head=attention_head_dim,
+                dropout=dropout,
+            )  # is self-attn if encoder_hidden_states is none
+        else:
+            self.norm2 = None
+            self.attn2 = None
+
+        # 3. Feed-forward
+        self.norm3 = nn.LayerNorm(time_mix_inner_dim, eps=norm_eps)
+        self.ff = FeedForward(
+            time_mix_inner_dim,
+            dropout=dropout,
+            activation_fn="geglu",
+            final_dropout=final_dropout,
+        )
+
+        # let chunk size default to None
+        self._chunk_size = None
+        self._chunk_dim = 0
+
+    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
+        # Sets chunk feed-forward
+        self._chunk_size = chunk_size
+        self._chunk_dim = dim
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        num_frames: int,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        cross_attention_kwargs: Dict[str, Any] = None,
+    ) -> torch.FloatTensor:
+        # Notice that normalization is always applied before the real computation in the following blocks.
+        # 0. Self-Attention
+        batch_size = hidden_states.shape[0]
+
+        batch_frames, seq_length, channels = hidden_states.shape
+        batch_size = batch_frames // num_frames
+
+        hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels)
+        hidden_states = hidden_states.permute(0, 2, 1, 3)
+        hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels)
+
+        residual = hidden_states
+        hidden_states = self.norm_in(hidden_states)
+        hidden_states = self.ff_in(hidden_states)
+        if self.is_res:
+            hidden_states = hidden_states + residual
+
+        cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
+
+        norm_hidden_states = self.norm1(hidden_states)
+        attn_output = self.attn1(
+            norm_hidden_states,
+            encoder_hidden_states=None,
+            attention_mask=attention_mask,
+            **cross_attention_kwargs,
+        )
+        hidden_states = attn_output + hidden_states
+
+        # 3. Cross-Attention
+        if self.attn2 is not None:
+            norm_hidden_states = self.norm2(hidden_states)
+            attn_output = self.attn2(
+                norm_hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                attention_mask=encoder_attention_mask,
+                **cross_attention_kwargs,
+            )
+            hidden_states = attn_output + hidden_states
+
+        # 4. Feed-forward
+        norm_hidden_states = self.norm3(hidden_states)
+
+        if self._chunk_size is not None:
+            # "feed_forward_chunk_size" can be used to save memory
+            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
+                raise ValueError(
+                    f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
+                )
+
+            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
+            ff_output = torch.cat(
+                [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
+                dim=self._chunk_dim,
+            )
+        else:
+            ff_output = self.ff(norm_hidden_states)
+
+        if self.is_res:
+            hidden_states = ff_output + hidden_states
+        else:
+            hidden_states = ff_output
+
+        hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels)
+        hidden_states = hidden_states.permute(0, 2, 1, 3)
+        hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels)
+
+        return hidden_states
+
+
 class FeedForward(nn.Module):
     r"""
     A feed-forward layer.
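`TemporalBasicTransformerBlock` attends over the frame axis: the caller folds frames into the batch dimension, and the block internally regroups the tensor to `(batch * seq_len, frames, channels)` before its self and cross attention. A shape sketch (dimensions and the import path are illustrative assumptions, and the context layout follows how the spatio-temporal transformer is expected to call the block):

```python
# Shape sketch for the temporal block added above; dims are illustrative and the
# import path assumes this branch of diffusers.
import torch
from diffusers.models.attention import TemporalBasicTransformerBlock

block = TemporalBasicTransformerBlock(
    dim=320,
    time_mix_inner_dim=320,   # equal to `dim`, so the residual path (is_res) is active
    num_attention_heads=8,
    attention_head_dim=40,
    cross_attention_dim=1024,
)

batch, frames, seq_len, channels = 2, 14, 64, 320
# Frames are folded into the batch dimension by the caller.
hidden_states = torch.randn(batch * frames, seq_len, channels)
# Cross-attention context is provided per spatial position, i.e. batch * seq_len rows (assumed).
encoder_hidden_states = torch.randn(batch * seq_len, 1, 1024)

out = block(hidden_states, num_frames=frames, encoder_hidden_states=encoder_hidden_states)
print(out.shape)  # torch.Size([28, 64, 320]), same layout as the input
```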
src/diffusers/models/autoencoder_kl_temporal_decoder.py (new file, 672 lines)
@@ -0,0 +1,672 @@
|
|||||||
|
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Dict, Iterable, Optional, Tuple, Union

import torch
import torch.nn as nn

from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import FromOriginalVAEMixin
from ..utils import BaseOutput, is_torch_version
from ..utils.accelerate_utils import apply_forward_hook
from .attention_processor import (
    ADDED_KV_ATTENTION_PROCESSORS,
    CROSS_ATTENTION_PROCESSORS,
    AttentionProcessor,
    AttnAddedKVProcessor,
    AttnProcessor,
)
from .modeling_utils import ModelMixin
from .unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder


class TemporalDecoder(nn.Module):
    def __init__(
        self,
        in_channels: int = 4,
        out_channels: int = 3,
        block_out_channels: Tuple[int, ...] = (
            128,
            256,
            512,
            512,
        ),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        norm_type: str = "group",  # group, spatial
        alpha: float = 0.0,
        merge_strategy: str = "learned",
        conv_out_kernel_size=(3, 1, 1),
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)
        temb_channels = in_channels if norm_type == "spatial" else None
        self.mid_block = MidBlockTemporalDecoder(
            num_layers=self.layers_per_block,
            in_channels=block_out_channels[-1],
            out_channels=block_out_channels[-1],
            attention_head_dim=block_out_channels[-1],
            resnet_eps=1e-6,
            temporal_resnet_eps=1e-5,
            resnet_act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            temb_channels=temb_channels,
            resnet_time_scale_shift=norm_type,
            merge_factor=alpha,
            merge_strategy=merge_strategy,
        )

        # up
        self.up_blocks = nn.ModuleList([])
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i in range(len(block_out_channels)):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]

            is_final_block = i == len(block_out_channels) - 1
            up_block = UpBlockTemporalDecoder(
                num_layers=self.layers_per_block + 1,
                in_channels=prev_output_channel,
                out_channels=output_channel,
                add_upsample=not is_final_block,
                resnet_eps=1e-6,
                temporal_resnet_eps=1e-5,
                resnet_act_fn=act_fn,
                norm_num_groups=norm_num_groups,
                attention_head_dim=output_channel,
                temb_channels=temb_channels,
                resnet_time_scale_shift=norm_type,
                merge_factor=alpha,
                merge_strategy=merge_strategy,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)

        if isinstance(conv_out_kernel_size, Iterable):
            padding = [int(k // 2) for k in conv_out_kernel_size]
        else:
            padding = int(conv_out_kernel_size // 2)

        self.conv_act = nn.SiLU()
        self.conv_out = torch.nn.Conv2d(
            in_channels=block_out_channels[0],
            out_channels=out_channels,
            kernel_size=3,
            padding=1,
        )
        self.time_conv_out = torch.nn.Conv3d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=conv_out_kernel_size,
            padding=padding,
        )

        self.gradient_checkpointing = False

    def forward(
        self,
        sample: torch.FloatTensor,
        image_only_indicator: torch.FloatTensor,
        num_frames: int = 1,
        latent_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        r"""The forward method of the `Decoder` class."""

        sample = self.conv_in(sample)

        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            if is_torch_version(">=", "1.11.0"):
                # middle
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.mid_block),
                    sample,
                    image_only_indicator,
                    latent_embeds,
                    num_frames,
                    use_reentrant=False,
                )
                sample = sample.to(upscale_dtype)

                # up
                for up_block in self.up_blocks:
                    sample = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(up_block),
                        sample,
                        image_only_indicator,
                        latent_embeds,
                        num_frames,
                        use_reentrant=False,
                    )
            else:
                # middle
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.mid_block),
                    sample,
                    image_only_indicator,
                    latent_embeds,
                    num_frames,
                )
                sample = sample.to(upscale_dtype)

                # up
                for up_block in self.up_blocks:
                    sample = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(up_block),
                        sample,
                        image_only_indicator,
                        latent_embeds,
                        num_frames,
                    )
        else:
            # middle
            sample = self.mid_block(
                sample,
                temb=latent_embeds,
                num_frames=num_frames,
                image_only_indicator=image_only_indicator,
            )
            sample = sample.to(upscale_dtype)

            # up
            for up_block in self.up_blocks:
                sample = up_block(
                    sample,
                    temb=latent_embeds,
                    num_frames=num_frames,
                    image_only_indicator=image_only_indicator,
                )

        # post-process
        if latent_embeds is None:
            sample = self.conv_norm_out(sample)
        else:
            sample = self.conv_norm_out(sample, latent_embeds)

        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        batch_frames, channels, height, width = sample.shape
        batch_size = batch_frames // num_frames
        sample = sample[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
        sample = self.time_conv_out(sample)

        sample = sample.permute(0, 2, 1, 3, 4).reshape(batch_frames, channels, height, width)

        return sample
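Not part of the diff: a minimal sketch of the frame-folding convention that TemporalDecoder.forward uses above. Frames travel through the 2D layers stacked along the batch axis and are only folded into an explicit time axis for the final Conv3d; the toy sizes below are assumptions for illustration.

import torch

batch_size, num_frames, channels, height, width = 2, 4, 3, 8, 8
batch_frames = batch_size * num_frames

# Frames stacked along the batch axis, as the 2D blocks see them.
sample = torch.randn(batch_frames, channels, height, width)

# Fold into (batch, channels, frames, height, width) for the temporal Conv3d.
sample_5d = sample[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
assert sample_5d.shape == (batch_size, channels, num_frames, height, width)

# Flatten back to (batch * frames, channels, height, width) afterwards; the round trip is lossless.
sample_back = sample_5d.permute(0, 2, 1, 3, 4).reshape(batch_frames, channels, height, width)
assert torch.equal(sample, sample_back)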
@dataclass
class AutoencoderKLOutput(BaseOutput):
    """
    Output of AutoencoderKL encoding method.

    Args:
        latent_dist (`DiagonalGaussianDistribution`):
            Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
            `DiagonalGaussianDistribution` allows for sampling latents from the distribution.
    """

    latent_dist: "DiagonalGaussianDistribution"


class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
    r"""
    A VAE model with KL loss for encoding images into latents and decoding latent representations into images.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
        out_channels (int, *optional*, defaults to 3): Number of channels in the output.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
            Tuple of downsample block types.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
            Tuple of upsample block types.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
            Tuple of block output channels.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
        sample_size (`int`, *optional*, defaults to `32`): Sample input size.
        scaling_factor (`float`, *optional*, defaults to 0.18215):
            The component-wise standard deviation of the trained latent space computed using the first batch of the
            training set. This is used to scale the latent space to have unit variance when training the diffusion
            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
        force_upcast (`bool`, *optional*, default to `True`):
            If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
            can be fine-tuned / trained to a lower range without loosing too much precision in which case
            `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
        block_out_channels: Tuple[int] = (64,),
        layers_per_block: int = 1,
        act_fn: str = "silu",
        latent_channels: int = 4,
        norm_num_groups: int = 32,
        sample_size: int = 32,
        scaling_factor: float = 0.18215,
        force_upcast: float = True,
    ):
        super().__init__()

        # pass init params to Encoder
        self.encoder = Encoder(
            in_channels=in_channels,
            out_channels=latent_channels,
            down_block_types=down_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            double_z=True,
        )

        # pass init params to Decoder
        self.decoder = TemporalDecoder(
            in_channels=latent_channels,
            out_channels=out_channels,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
            act_fn=act_fn,
        )

        self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)

        self.use_slicing = False
        self.use_tiling = False

        # only relevant if vae tiling is enabled
        self.tile_sample_min_size = self.config.sample_size
        sample_size = (
            self.config.sample_size[0]
            if isinstance(self.config.sample_size, (list, tuple))
            else self.config.sample_size
        )
        self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
        self.tile_overlap_factor = 0.25

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (Encoder, TemporalDecoder)):
            module.gradient_checkpointing = value

    def enable_tiling(self, use_tiling: bool = True):
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger images.
        """
        self.use_tiling = use_tiling

    def disable_tiling(self):
        r"""
        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
        decoding in one step.
        """
        self.enable_tiling(False)

    def enable_slicing(self):
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.use_slicing = True

    def disable_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
        decoding in one step.
        """
        self.use_slicing = False

    @property
    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(
        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
    ):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor, _remove_lora=_remove_lora)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            processor = AttnAddedKVProcessor()
        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            processor = AttnProcessor()
        else:
            raise ValueError(
                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
            )

        self.set_attn_processor(processor, _remove_lora=True)

    @apply_forward_hook
    def encode(
        self, x: torch.FloatTensor, return_dict: bool = True
    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
        """
        Encode a batch of images into latents.

        Args:
            x (`torch.FloatTensor`): Input batch of images.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

        Returns:
                The latent representations of the encoded images. If `return_dict` is True, a
                [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
        """
        if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
            return self.tiled_encode(x, return_dict=return_dict)

        if self.use_slicing and x.shape[0] > 1:
            encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
            h = torch.cat(encoded_slices)
        else:
            h = self.encoder(x)

        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)

        if not return_dict:
            return (posterior,)

        return AutoencoderKLOutput(latent_dist=posterior)

    def _decode(
        self, z: torch.FloatTensor, num_frames: int, return_dict: bool = True
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
            return self.tiled_decode(z, return_dict=return_dict)

        batch_size = z.shape[0] // num_frames
        # TODO: dont hardcode this
        image_only_indicator = torch.zeros(batch_size, num_frames, dtype=z.dtype, device=z.device)
        dec = self.decoder(z, num_frames=num_frames, image_only_indicator=image_only_indicator)

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

    @apply_forward_hook
    def decode(
        self,
        z: torch.FloatTensor,
        num_frames: int,
        return_dict: bool = True,
        generator=None,
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        """
        Decode a batch of images.

        Args:
            z (`torch.FloatTensor`): Input batch of latent vectors.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

        Returns:
            [`~models.vae.DecoderOutput`] or `tuple`:
                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                returned.

        """
        if self.use_slicing and z.shape[0] > 1:
            decoded_slices = [self._decode(z_slice, num_frames // 2).sample for z_slice in z.split(1)]
            decoded = torch.cat(decoded_slices)
        else:
            decoded = self._decode(z, num_frames).sample

        if not return_dict:
            return (decoded,)

        return DecoderOutput(sample=decoded)
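Not part of the diff: a hedged usage sketch of the class added above, built directly with its default config and random weights (no pretrained checkpoint is assumed). Frames are stacked along the batch axis for encode, and decode additionally needs num_frames so the temporal decoder can regroup them into clips.

import torch

vae = AutoencoderKLTemporalDecoder()          # default config: one block, 64 channels, latent_channels=4
vae.eval()

num_frames = 4
frames = torch.randn(num_frames, 3, 32, 32)   # (batch * frames, C, H, W); one clip of 4 frames

with torch.no_grad():
    latents = vae.encode(frames).latent_dist.sample()           # per-frame latents, 4 channels each
    video = vae.decode(latents, num_frames=num_frames).sample   # decoded frames, same (B*F, 3, H, W) layout

print(video.shape)  # with this toy config the spatial size is expected to match the input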
    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        blend_extent = min(a.shape[2], b.shape[2], blend_extent)
        for y in range(blend_extent):
            b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
        return b

    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
        for x in range(blend_extent):
            b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
        return b

    def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
        r"""Encode a batch of images using a tiled encoder.

        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
        different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
        tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
        output, but they should be much less noticeable.

        Args:
            x (`torch.FloatTensor`): Input batch of images.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

        Returns:
            [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
                If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
                `tuple` is returned.
        """
        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
        row_limit = self.tile_latent_min_size - blend_extent

        # Split the image into 512x512 tiles and encode them separately.
        rows = []
        for i in range(0, x.shape[2], overlap_size):
            row = []
            for j in range(0, x.shape[3], overlap_size):
                tile = x[
                    :,
                    :,
                    i : i + self.tile_sample_min_size,
                    j : j + self.tile_sample_min_size,
                ]
                tile = self.encoder(tile)
                tile = self.quant_conv(tile)
                row.append(tile)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=3))

        moments = torch.cat(result_rows, dim=2)
        posterior = DiagonalGaussianDistribution(moments)

        if not return_dict:
            return (posterior,)

        return AutoencoderKLOutput(latent_dist=posterior)

    def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        r"""
        Decode a batch of images using a tiled decoder.

        Args:
            z (`torch.FloatTensor`): Input batch of latent vectors.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

        Returns:
            [`~models.vae.DecoderOutput`] or `tuple`:
                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                returned.
        """
        overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
        row_limit = self.tile_sample_min_size - blend_extent

        # Split z into overlapping 64x64 tiles and decode them separately.
        # The tiles have an overlap to avoid seams between tiles.
        rows = []
        for i in range(0, z.shape[2], overlap_size):
            row = []
            for j in range(0, z.shape[3], overlap_size):
                tile = z[
                    :,
                    :,
                    i : i + self.tile_latent_min_size,
                    j : j + self.tile_latent_min_size,
                ]
                tile = self.post_quant_conv(tile)
                decoded = self.decoder(tile)
                row.append(decoded)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=3))

        dec = torch.cat(result_rows, dim=2)
        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

    def forward(
        self,
        sample: torch.FloatTensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: Optional[torch.Generator] = None,
        num_frames: int = 1,
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        r"""
        Args:
            sample (`torch.FloatTensor`): Input sample.
            sample_posterior (`bool`, *optional*, defaults to `False`):
                Whether to sample from the posterior.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
        """
        x = sample
        posterior = self.encode(x).latent_dist
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()

        dec = self.decode(z, num_frames=num_frames).sample

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)
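Not part of the diff: a small numeric check of the tiling arithmetic used by tiled_encode/tiled_decode above. With this file's defaults (sample_size=32, a single block) both tile sizes are 32; the 512/64 values below are assumptions chosen to match the "512x512 tiles" comment and the familiar SD-style 8x latent downsampling.

tile_sample_min_size = 512   # assumed pixel-space tile edge
tile_latent_min_size = 64    # assumed latent-space tile edge
tile_overlap_factor = 0.25

overlap_size = int(tile_sample_min_size * (1 - tile_overlap_factor))   # 384: stride between tile origins
blend_extent = int(tile_latent_min_size * tile_overlap_factor)         # 16: latent rows/cols blended across tiles
row_limit = tile_latent_min_size - blend_extent                        # 48: rows/cols kept from each tile
print(overlap_size, blend_extent, row_limit)                           # 384 16 48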
@@ -165,7 +165,10 @@ class Upsample2D(nn.Module):
             self.Conv2d_0 = conv
 
     def forward(
-        self, hidden_states: torch.FloatTensor, output_size: Optional[int] = None, scale: float = 1.0
+        self,
+        hidden_states: torch.FloatTensor,
+        output_size: Optional[int] = None,
+        scale: float = 1.0,
     ) -> torch.FloatTensor:
         assert hidden_states.shape[1] == self.channels
 
@@ -379,7 +382,11 @@ class FirUpsample2D(nn.Module):
             weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))
 
             inverse_conv = F.conv_transpose2d(
-                hidden_states, weight, stride=stride, output_padding=output_padding, padding=0
+                hidden_states,
+                weight,
+                stride=stride,
+                output_padding=output_padding,
+                padding=0,
             )
 
             output = upfirdn2d_native(
@@ -530,7 +537,14 @@ class KDownsample2D(nn.Module):
 
     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode)
-        weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
+        weight = inputs.new_zeros(
+            [
+                inputs.shape[1],
+                inputs.shape[1],
+                self.kernel.shape[0],
+                self.kernel.shape[1],
+            ]
+        )
         indices = torch.arange(inputs.shape[1], device=inputs.device)
         kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
         weight[indices, indices] = kernel
@@ -553,7 +567,14 @@ class KUpsample2D(nn.Module):
 
     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode)
-        weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
+        weight = inputs.new_zeros(
+            [
+                inputs.shape[1],
+                inputs.shape[1],
+                self.kernel.shape[0],
+                self.kernel.shape[1],
+            ]
+        )
         indices = torch.arange(inputs.shape[1], device=inputs.device)
         kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
         weight[indices, indices] = kernel
@@ -690,11 +711,19 @@ class ResnetBlock2D(nn.Module):
         self.conv_shortcut = None
         if self.use_in_shortcut:
             self.conv_shortcut = conv_cls(
-                in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
+                in_channels,
+                conv_2d_out_channels,
+                kernel_size=1,
+                stride=1,
+                padding=0,
+                bias=conv_shortcut_bias,
             )
 
     def forward(
-        self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor, scale: float = 1.0
+        self,
+        input_tensor: torch.FloatTensor,
+        temb: torch.FloatTensor,
+        scale: float = 1.0,
     ) -> torch.FloatTensor:
         hidden_states = input_tensor
 
@@ -866,7 +895,10 @@ class ResidualTemporalBlock1D(nn.Module):
 
 
 def upsample_2d(
-    hidden_states: torch.FloatTensor, kernel: Optional[torch.FloatTensor] = None, factor: int = 2, gain: float = 1
+    hidden_states: torch.FloatTensor,
+    kernel: Optional[torch.FloatTensor] = None,
+    factor: int = 2,
+    gain: float = 1,
 ) -> torch.FloatTensor:
     r"""Upsample2D a batch of 2D images with the given filter.
     Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
@@ -910,7 +942,10 @@ def upsample_2d(
 
 
 def downsample_2d(
-    hidden_states: torch.FloatTensor, kernel: Optional[torch.FloatTensor] = None, factor: int = 2, gain: float = 1
+    hidden_states: torch.FloatTensor,
+    kernel: Optional[torch.FloatTensor] = None,
+    factor: int = 2,
+    gain: float = 1,
 ) -> torch.FloatTensor:
     r"""Downsample2D a batch of 2D images with the given filter.
     Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
@@ -946,13 +981,20 @@ def downsample_2d(
     kernel = kernel * gain
     pad_value = kernel.shape[0] - factor
     output = upfirdn2d_native(
-        hidden_states, kernel.to(device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2)
+        hidden_states,
+        kernel.to(device=hidden_states.device),
+        down=factor,
+        pad=((pad_value + 1) // 2, pad_value // 2),
     )
     return output
 
 
 def upfirdn2d_native(
-    tensor: torch.Tensor, kernel: torch.Tensor, up: int = 1, down: int = 1, pad: Tuple[int, int] = (0, 0)
+    tensor: torch.Tensor,
+    kernel: torch.Tensor,
+    up: int = 1,
+    down: int = 1,
+    pad: Tuple[int, int] = (0, 0),
 ) -> torch.Tensor:
     up_x = up_y = up
     down_x = down_y = down
@@ -1008,7 +1050,13 @@ class TemporalConvLayer(nn.Module):
         dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
     """
 
-    def __init__(self, in_dim: int, out_dim: Optional[int] = None, dropout: float = 0.0, norm_num_groups: int = 32):
+    def __init__(
+        self,
+        in_dim: int,
+        out_dim: Optional[int] = None,
+        dropout: float = 0.0,
+        norm_num_groups: int = 32,
+    ):
         super().__init__()
         out_dim = out_dim or in_dim
         self.in_dim = in_dim
@@ -1016,7 +1064,9 @@ class TemporalConvLayer(nn.Module):
 
         # conv layers
        self.conv1 = nn.Sequential(
-            nn.GroupNorm(norm_num_groups, in_dim), nn.SiLU(), nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0))
+            nn.GroupNorm(norm_num_groups, in_dim),
+            nn.SiLU(),
+            nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)),
         )
         self.conv2 = nn.Sequential(
             nn.GroupNorm(norm_num_groups, out_dim),
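Not part of the diff: a quick numeric check of the padding arithmetic that the reformatted downsample_2d/upfirdn2d_native calls above leave unchanged. These helpers typically default to the FIR kernel [1, 3, 3, 1] when `kernel` is None, which is the assumption here.

kernel_len = 4   # length of the assumed [1, 3, 3, 1] kernel
factor = 2       # downsampling factor

pad_value = kernel_len - factor                  # 2
pad = ((pad_value + 1) // 2, pad_value // 2)     # (1, 1): one row/column of padding on each side
print(pad_value, pad)                            # 2 (1, 1)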
@@ -1058,3 +1108,295 @@ class TemporalConvLayer(nn.Module):
            (hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:]
        )
        return hidden_states


class TemporalResnetBlock(nn.Module):
    r"""
    A Resnet block.

    Parameters:
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, default to be `None`):
            The number of output channels for the first conv2d layer. If None, same as `in_channels`.
        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
        temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
        groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
        groups_out (`int`, *optional*, default to None):
            The number of groups to use for the second normalization layer. if set to None, same as `groups`.
        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
        non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
        time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
            By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or
            "ada_group" for a stronger conditioning with scale and shift.
        kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
            [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
        output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
        use_in_shortcut (`bool`, *optional*, default to `True`):
            If `True`, add a 1x1 nn.conv2d layer for skip-connection.
        up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
        down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.
        conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the
            `conv_shortcut` output.
        conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
            If None, same as `out_channels`.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: Optional[int] = None,
        conv_shortcut: bool = False,
        dropout: float = 0.0,
        temb_channels: int = 512,
        groups: int = 32,
        groups_out: Optional[int] = None,
        eps: float = 1e-6,
        non_linearity: str = "swish",
        kernel_size: Optional[torch.FloatTensor] = (3, 1, 1),
        output_scale_factor: float = 1.0,
        use_in_shortcut: Optional[bool] = None,
        conv_shortcut_bias: bool = True,
        conv_2d_out_channels: Optional[int] = None,
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.output_scale_factor = output_scale_factor

        linear_cls = nn.Linear
        conv_cls = nn.Conv3d

        padding = [k // 2 for k in kernel_size]

        if groups_out is None:
            groups_out = groups

        self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

        self.conv1 = conv_cls(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
        )

        if temb_channels is not None:
            self.time_emb_proj = linear_cls(temb_channels, out_channels)
        else:
            self.time_emb_proj = None

        self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)

        self.dropout = torch.nn.Dropout(dropout)
        conv_2d_out_channels = conv_2d_out_channels or out_channels
        self.conv2 = conv_cls(
            out_channels,
            conv_2d_out_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
        )

        self.nonlinearity = get_activation(non_linearity)

        self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = conv_cls(
                in_channels,
                conv_2d_out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=conv_shortcut_bias,
            )

    def forward(self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
        hidden_states = input_tensor

        hidden_states = self.norm1(hidden_states)

        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.conv1(hidden_states)
        if self.time_emb_proj is not None:
            temb = self.nonlinearity(temb)
            temb = self.time_emb_proj(temb)[:, :, :, None, None]

        if temb is not None:
            temb = temb.permute(0, 2, 1, 3, 4)
            hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states)

        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

        return output_tensor


# VideoResBlock
class SpatioTemporalResBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: Optional[int] = None,
        dropout: float = 0.0,
        temb_channels: int = 512,
        groups: int = 32,
        pre_norm: bool = True,
        eps: float = 1e-6,
        temporal_eps: Optional[float] = None,
        non_linearity: str = "swish",
        time_embedding_norm: str = "default",  # default, scale_shift, ada_group, spatial
        output_scale_factor: float = 1.0,
        kernel_size_3d: Optional[torch.FloatTensor] = (3, 1, 1),
        merge_factor: float = 0.5,
        merge_strategy="learned",
        switch_spatial_to_temporal_mix: bool = False,
    ):
        super().__init__()

        self.spatial_res_block = ResnetBlock2D(
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            eps=eps,
            groups=groups,
            dropout=dropout,
            time_embedding_norm=time_embedding_norm,
            non_linearity=non_linearity,
            output_scale_factor=output_scale_factor,
            pre_norm=pre_norm,
        )

        self.temporal_res_block = TemporalResnetBlock(
            in_channels=out_channels if out_channels is not None else in_channels,
            out_channels=out_channels if out_channels is not None else in_channels,
            temb_channels=temb_channels,
            eps=temporal_eps if temporal_eps is not None else eps,
            groups=groups,
            dropout=dropout,
            non_linearity=non_linearity,
            output_scale_factor=output_scale_factor,
            kernel_size=kernel_size_3d,
        )

        self.time_mixer = AlphaBlender(
            alpha=merge_factor,
            merge_strategy=merge_strategy,
            switch_spatial_to_temporal_mix=switch_spatial_to_temporal_mix,
        )

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        num_frames: int = 1,
        image_only_indicator: Optional[torch.Tensor] = None,
        scale: float = 1.0,
    ):
        hidden_states = self.spatial_res_block(hidden_states, temb, scale=scale)

        batch_frames, channels, height, width = hidden_states.shape
        batch_size = batch_frames // num_frames

        hidden_states_mix = (
            hidden_states[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
        )
        hidden_states = (
            hidden_states[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
        )

        if temb is not None:
            temb = temb.reshape(batch_size, num_frames, -1)

        hidden_states = self.temporal_res_block(hidden_states, temb)
        hidden_states = self.time_mixer(
            x_spatial=hidden_states_mix,
            x_temporal=hidden_states,
            image_only_indicator=image_only_indicator,
        )

        hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(batch_frames, channels, height, width)
        return hidden_states


class AlphaBlender(nn.Module):
    strategies = ["learned", "fixed", "learned_with_images"]

    def __init__(
        self,
        alpha: float,
        merge_strategy: str = "learned_with_images",
        switch_spatial_to_temporal_mix: bool = False,
    ):
        super().__init__()
        self.merge_strategy = merge_strategy
        self.switch_spatial_to_temporal_mix = switch_spatial_to_temporal_mix  # For TemporalVAE

        assert merge_strategy in self.strategies, f"merge_strategy needs to be in {self.strategies}"

        if self.merge_strategy == "fixed":
            self.register_buffer("mix_factor", torch.Tensor([alpha]))
        elif self.merge_strategy == "learned" or self.merge_strategy == "learned_with_images":
            self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha])))
        else:
            raise ValueError(f"Unknown merge strategy {self.merge_strategy}")

    def get_alpha(self, image_only_indicator: torch.Tensor, ndims: int) -> torch.Tensor:
        if self.merge_strategy == "fixed":
            alpha = self.mix_factor

        elif self.merge_strategy == "learned":
            alpha = torch.sigmoid(self.mix_factor)

        elif self.merge_strategy == "learned_with_images":
            assert (
                image_only_indicator is not None
            ), "Please provide image_only_indicator to use learned_with_images merge strategy"
            alpha = torch.where(
                image_only_indicator.bool(),
                torch.ones(1, 1, device=image_only_indicator.device),
                torch.sigmoid(self.mix_factor)[..., None],
            )

            # (batch, channel, frames, height, width)
            if ndims == 5:
                alpha = alpha[:, None, :, None, None]
            # (batch*frames, height*width, channels)
            elif ndims == 3:
                alpha = alpha.reshape(-1)[:, None, None]
            else:
                raise ValueError(f"Unexpected ndims {ndims}. Dimensions should be 3 or 5")

        else:
            raise NotImplementedError

        return alpha

    def forward(
        self,
        x_spatial: torch.Tensor,
        x_temporal: torch.Tensor,
        image_only_indicator: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        alpha = self.get_alpha(image_only_indicator, x_spatial.ndim)
        alpha = alpha.to(x_spatial.dtype)

        if self.switch_spatial_to_temporal_mix:
            alpha = 1.0 - alpha

        x = alpha * x_spatial + (1.0 - alpha) * x_temporal
        return x
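Not part of the diff: a toy sketch of what the AlphaBlender added above computes with the "learned" strategy. The tensor sizes are assumptions; the point is that the output is a convex mix alpha * x_spatial + (1 - alpha) * x_temporal with alpha = sigmoid(mix_factor).

import torch

# Using the AlphaBlender class added in the hunk above.
blender = AlphaBlender(alpha=0.5, merge_strategy="learned")

x_spatial = torch.ones(2, 8, 16)    # (batch*frames, height*width, channels)
x_temporal = torch.zeros(2, 8, 16)

# sigmoid(0.5) is roughly 0.62, so every element of the mix should be roughly 0.62.
mixed = blender(x_spatial=x_spatial, x_temporal=x_temporal)
print(mixed.mean())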
@@ -19,8 +19,10 @@ from torch import nn
|
|||||||
|
|
||||||
from ..configuration_utils import ConfigMixin, register_to_config
|
from ..configuration_utils import ConfigMixin, register_to_config
|
||||||
from ..utils import BaseOutput
|
from ..utils import BaseOutput
|
||||||
from .attention import BasicTransformerBlock
|
from .attention import BasicTransformerBlock, TemporalBasicTransformerBlock
|
||||||
|
from .embeddings import TimestepEmbedding, Timesteps
|
||||||
from .modeling_utils import ModelMixin
|
from .modeling_utils import ModelMixin
|
||||||
|
from .resnet import AlphaBlender
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -195,3 +197,229 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
|
|||||||
return (output,)
|
return (output,)
|
||||||
|
|
||||||
return TransformerTemporalModelOutput(sample=output)
|
return TransformerTemporalModelOutput(sample=output)
|
||||||
|
|
||||||
|
|
||||||
|
# VideoBlock
|
||||||
|
class TransformerSpatioTemporalModel(ModelMixin, ConfigMixin):
|
||||||
|
"""
|
||||||
|
A Transformer model for video-like data.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
||||||
|
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
||||||
|
in_channels (`int`, *optional*):
|
||||||
|
The number of channels in the input and output (specify if the input is **continuous**).
|
||||||
|
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
||||||
|
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||||
|
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
||||||
|
attention_bias (`bool`, *optional*):
|
||||||
|
Configure if the `TransformerBlock` attention should contain a bias parameter.
|
||||||
|
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
|
||||||
|
This is fixed during training since it is used to learn a number of position embeddings.
|
||||||
|
activation_fn (`str`, *optional*, defaults to `"geglu"`):
|
||||||
|
Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported
|
||||||
|
activation functions.
|
||||||
|
norm_elementwise_affine (`bool`, *optional*):
|
||||||
|
Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization.
|
||||||
|
double_self_attention (`bool`, *optional*):
|
||||||
|
Configure if each `TransformerBlock` should contain two self-attention layers.
|
||||||
|
positional_embeddings: (`str`, *optional*):
|
||||||
|
The type of positional embeddings to apply to the sequence input before passing use.
|
||||||
|
num_positional_embeddings: (`int`, *optional*):
|
||||||
|
The maximum length of the sequence over which to apply positional embeddings.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@register_to_config
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_attention_heads: int = 16,
|
||||||
|
attention_head_dim: int = 88,
|
||||||
|
in_channels: int = 320,
|
||||||
|
out_channels: Optional[int] = None,
|
||||||
|
num_layers: int = 1,
|
||||||
|
dropout: float = 0.0,
|
||||||
|
norm_num_groups: int = 32,
|
||||||
|
cross_attention_dim: Optional[int] = None,
|
||||||
|
norm_eps: float = 1e-5,
|
||||||
|
merge_factor: float = 0.5,
|
||||||
|
merge_strategy: str = "learned_with_images",
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.attention_head_dim = attention_head_dim
|
||||||
|
|
||||||
|
inner_dim = num_attention_heads * attention_head_dim
|
||||||
|
self.inner_dim = inner_dim
|
||||||
|
|
||||||
|
linear_cls = nn.Linear
|
||||||
|
|
||||||
|
# 2. Define input layers
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps)
|
||||||
|
self.proj_in = linear_cls(in_channels, inner_dim)
|
||||||
|
|
||||||
|
# 3. Define transformers blocks
|
||||||
|
self.transformer_blocks = nn.ModuleList(
|
||||||
|
[
|
||||||
|
BasicTransformerBlock(
|
||||||
|
inner_dim,
|
||||||
|
num_attention_heads,
|
||||||
|
attention_head_dim,
|
||||||
|
dropout=dropout,
|
||||||
|
cross_attention_dim=cross_attention_dim,
|
||||||
|
norm_eps=norm_eps,
|
||||||
|
)
|
||||||
|
for d in range(num_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
time_mix_inner_dim = inner_dim
|
||||||
|
self.temporal_transformer_blocks = nn.ModuleList(
|
||||||
|
[
|
||||||
|
TemporalBasicTransformerBlock(
|
||||||
|
inner_dim,
|
||||||
|
time_mix_inner_dim,
|
||||||
|
num_attention_heads,
|
||||||
|
attention_head_dim,
|
||||||
|
dropout=dropout,
|
||||||
|
cross_attention_dim=cross_attention_dim,
|
||||||
|
norm_eps=norm_eps,
|
||||||
|
)
|
||||||
|
for _ in range(num_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
time_embed_dim = in_channels * 4
|
||||||
|
self.time_pos_embed = TimestepEmbedding(in_channels, time_embed_dim, out_dim=in_channels)
|
||||||
|
self.time_proj = Timesteps(in_channels, True, 0)
|
||||||
|
self.time_mixer = AlphaBlender(alpha=merge_factor, merge_strategy=merge_strategy)
|
||||||
|
|
||||||
|
# 4. Define output layers
|
||||||
|
self.out_channels = in_channels if out_channels is None else out_channels
|
||||||
|
# TODO: should use out_channels for continuous projections
|
||||||
|
self.proj_out = linear_cls(inner_dim, in_channels)
|
||||||
|
|
||||||
|
self.gradient_checkpointing = False
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
num_frames: int,
|
||||||
|
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||||
|
timestep: Optional[torch.LongTensor] = None,
|
||||||
|
class_labels: Optional[torch.LongTensor] = None,
|
||||||
|
cross_attention_kwargs: Dict[str, Any] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
image_only_indicator: Optional[torch.Tensor] = None,
|
||||||
|
return_dict: bool = True,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
|
||||||
|
Input hidden_states.
|
||||||
|
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
|
||||||
|
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
||||||
|
self-attention.
|
||||||
|
timestep ( `torch.LongTensor`, *optional*):
|
||||||
|
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
|
||||||
|
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
|
||||||
|
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
|
||||||
|
`AdaLayerZeroNorm`.
|
||||||
|
num_frames (`int`, *optional*, defaults to 1):
|
||||||
|
The number of frames to be processed per batch. This is used to reshape the hidden states.
|
||||||
|
cross_attention_kwargs (`dict`, *optional*):
|
||||||
|
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||||
|
`self.processor` in
|
||||||
|
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||||
|
return_dict (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
||||||
|
tuple.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
|
||||||
|
If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
|
||||||
|
returned, otherwise a `tuple` where the first element is the sample tensor.
|
||||||
|
"""
|
||||||
|
assert (
|
||||||
|
encoder_hidden_states.ndim == 3
|
||||||
|
), f"n dims of spatial context should be 3 but are {encoder_hidden_states.ndim}"

        # 1. Input
        batch_frames, channel, height, width = hidden_states.shape
        batch_size = batch_frames // num_frames

        time_context = encoder_hidden_states
        time_context_first_timestep = time_context[::num_frames]
        time_context = time_context_first_timestep.repeat(height * width, 1, 1)

        residual = hidden_states

        hidden_states = self.norm(hidden_states)
        inner_dim = hidden_states.shape[1]
        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_frames, height * width, inner_dim)
        hidden_states = self.proj_in(hidden_states)

        num_frames_emb = torch.arange(num_frames, device=hidden_states.device)
        num_frames_emb = num_frames_emb.repeat(batch_size, 1)
        num_frames_emb = num_frames_emb.reshape(-1)
        t_emb = self.time_proj(num_frames_emb)

        # `Timesteps` does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=hidden_states.dtype)

        emb = self.time_pos_embed(t_emb)
        emb = emb[:, None, :]

        # 2. Blocks
        for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks):
            if self.training and self.gradient_checkpointing:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    block,
                    hidden_states,
                    None,
                    encoder_hidden_states,
                    None,
                    timestep,
                    cross_attention_kwargs,
                    class_labels,
                    use_reentrant=False,
                )
            else:
                hidden_states = block(
                    hidden_states,
                    attention_mask=None,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=None,
                    timestep=timestep,
                    cross_attention_kwargs=cross_attention_kwargs,
                    class_labels=class_labels,
                )

            hidden_states_mix = hidden_states
            hidden_states_mix = hidden_states_mix + emb

            hidden_states_mix = temporal_block(
                hidden_states_mix,
                num_frames=num_frames,
                encoder_hidden_states=time_context,
                cross_attention_kwargs=cross_attention_kwargs,
            )
            hidden_states = self.time_mixer(
                x_spatial=hidden_states,
                x_temporal=hidden_states_mix,
                image_only_indicator=image_only_indicator,
            )

        # 3. Output
        hidden_states = self.proj_out(hidden_states)
        hidden_states = hidden_states.reshape(batch_frames, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()

        output = hidden_states + residual

        if not return_dict:
            return (output,)

        return TransformerTemporalModelOutput(sample=output)
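
A minimal, self-contained sketch of the reshape bookkeeping used in this forward pass, with purely illustrative sizes; it mirrors only the tensor plumbing, not the attention blocks themselves.

import torch

batch_size, num_frames, channels, height, width = 2, 4, 8, 16, 16
hidden_states = torch.randn(batch_size * num_frames, channels, height, width)
encoder_hidden_states = torch.randn(batch_size * num_frames, 1, 1024)

# spatial blocks operate on (batch * frames, height * width, channels)
spatial = hidden_states.permute(0, 2, 3, 1).reshape(batch_size * num_frames, height * width, channels)

# the temporal context keeps only the first frame of each clip and is repeated once per spatial location
time_context = encoder_hidden_states[::num_frames].repeat(height * width, 1, 1)

print(spatial.shape)       # torch.Size([8, 256, 8])
print(time_context.shape)  # torch.Size([512, 1, 1024])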
src/diffusers/models/unet_spatio_temporal_condition.py (new file, 859 lines)
@@ -0,0 +1,859 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn

from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import UNet2DConditionLoadersMixin
from ..utils import (
    USE_PEFT_BACKEND,
    BaseOutput,
    logging,
    scale_lora_layers,
    unscale_lora_layers,
)
from .activations import get_activation
from .attention_processor import (
    ADDED_KV_ATTENTION_PROCESSORS,
    CROSS_ATTENTION_PROCESSORS,
    AttentionProcessor,
    AttnAddedKVProcessor,
    AttnProcessor,
)
from .embeddings import TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from .unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block, get_up_block


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class UNetSpatioTemporalConditionOutput(BaseOutput):
    """
    The output of [`UNetSpatioTemporalConditionModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
    """

    sample: torch.FloatTensor = None


class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
    r"""
    A conditional spatio-temporal UNet model that takes noisy video frames, conditional state, and a timestep and
    returns a sample-shaped output.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
|
||||||
|
Height and width of input/output sample.
|
||||||
|
in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
|
||||||
|
out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
|
||||||
|
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
|
||||||
|
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to flip the sin to cos in the time embedding.
|
||||||
|
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
|
||||||
|
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
||||||
|
The tuple of downsample blocks to use.
|
||||||
|
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
|
||||||
|
Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
|
||||||
|
`UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
|
||||||
|
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
|
||||||
|
The tuple of upsample blocks to use.
|
||||||
|
only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
|
||||||
|
Whether to include self-attention in the basic transformer blocks, see
|
||||||
|
[`~models.attention.BasicTransformerBlock`].
|
||||||
|
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
|
||||||
|
The tuple of output channels for each block.
|
||||||
|
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
|
||||||
|
downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
|
||||||
|
mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
|
||||||
|
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||||
|
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
||||||
|
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
|
||||||
|
If `None`, normalization and activation layers is skipped in post-processing.
|
||||||
|
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
|
||||||
|
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
|
||||||
|
The dimension of the cross attention features.
|
||||||
|
transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
|
||||||
|
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
||||||
|
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
||||||
|
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
||||||
|
reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
|
||||||
|
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
|
||||||
|
blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
|
||||||
|
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
||||||
|
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
||||||
|
encoder_hid_dim (`int`, *optional*, defaults to None):
|
||||||
|
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
|
||||||
|
dimension to `cross_attention_dim`.
|
||||||
|
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
|
||||||
|
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
|
||||||
|
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
|
||||||
|
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
|
||||||
|
num_attention_heads (`int`, *optional*):
|
||||||
|
The number of attention heads. If not defined, defaults to `attention_head_dim`
|
||||||
|
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
|
||||||
|
for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
|
||||||
|
class_embed_type (`str`, *optional*, defaults to `None`):
|
||||||
|
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
|
||||||
|
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
|
||||||
|
addition_embed_type (`str`, *optional*, defaults to `None`):
|
||||||
|
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
|
||||||
|
"text". "text" will use the `TextTimeEmbedding` layer.
|
||||||
|
addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
|
||||||
|
Dimension for the timestep embeddings.
|
||||||
|
num_class_embeds (`int`, *optional*, defaults to `None`):
|
||||||
|
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
||||||
|
class conditioning with `class_embed_type` equal to `None`.
|
||||||
|
time_embedding_type (`str`, *optional*, defaults to `positional`):
|
||||||
|
The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
|
||||||
|
time_embedding_dim (`int`, *optional*, defaults to `None`):
|
||||||
|
An optional override for the dimension of the projected time embedding.
|
||||||
|
time_embedding_act_fn (`str`, *optional*, defaults to `None`):
|
||||||
|
Optional activation function to use only once on the time embeddings before they are passed to the rest of
|
||||||
|
the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
|
||||||
|
timestep_post_act (`str`, *optional*, defaults to `None`):
|
||||||
|
The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
|
||||||
|
time_cond_proj_dim (`int`, *optional*, defaults to `None`):
|
||||||
|
The dimension of `cond_proj` layer in the timestep embedding.
|
||||||
|
conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`,
|
||||||
|
*optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`,
|
||||||
|
*optional*): The dimension of the `class_labels` input when
|
||||||
|
`class_embed_type="projection"`. Required when `class_embed_type="projection"`.
|
||||||
|
class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
|
||||||
|
embeddings with the class embeddings.
|
||||||
|
mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
|
||||||
|
Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
|
||||||
|
`only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
|
||||||
|
`only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
|
||||||
|
otherwise.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_supports_gradient_checkpointing = True
|
||||||
|
|
||||||
|
@register_to_config
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
sample_size: Optional[int] = None,
|
||||||
|
in_channels: int = 8,
|
||||||
|
out_channels: int = 4,
|
||||||
|
center_input_sample: bool = False,
|
||||||
|
flip_sin_to_cos: bool = True,
|
||||||
|
down_block_types: Tuple[str] = (
|
||||||
|
"CrossAttnDownBlockSpatioTemporal",
|
||||||
|
"CrossAttnDownBlockSpatioTemporal",
|
||||||
|
"CrossAttnDownBlockSpatioTemporal",
|
||||||
|
"DownBlockSpatioTemporal",
|
||||||
|
),
|
||||||
|
mid_block_type: Optional[str] = "UNetMidBlockSpatioTemporal",
|
||||||
|
up_block_types: Tuple[str] = (
|
||||||
|
"UpBlockSpatioTemporal",
|
||||||
|
"CrossAttnUpBlockSpatioTemporal",
|
||||||
|
"CrossAttnUpBlockSpatioTemporal",
|
||||||
|
"CrossAttnUpBlockSpatioTemporal",
|
||||||
|
),
|
||||||
|
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
||||||
|
projection_class_embeddings_input_dim: int = 768,
|
||||||
|
addition_time_embed_dim: int = 256,
|
||||||
|
layers_per_block: Union[int, Tuple[int]] = 2,
|
||||||
|
mid_block_scale_factor: float = 1,
|
||||||
|
dropout: float = 0.0,
|
||||||
|
act_fn: str = "silu",
|
||||||
|
norm_num_groups: Optional[int] = 32,
|
||||||
|
norm_eps: float = 1e-5,
|
||||||
|
cross_attention_dim: Union[int, Tuple[int]] = 1024,
|
||||||
|
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
|
||||||
|
attention_head_dim: Union[int, Tuple[int]] = 8,
|
||||||
|
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
|
||||||
|
time_embedding_dim: Optional[int] = None,
|
||||||
|
conv_in_kernel: int = 3,
|
||||||
|
conv_out_kernel: int = 3,
|
||||||
|
kernel_size_3d: Optional[torch.FloatTensor] = (3, 1, 1),
|
||||||
|
merge_factor: float = 0.5,
|
||||||
|
merge_strategy: str = "learned_with_images",
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.sample_size = sample_size
|
||||||
|
|
||||||
|
if num_attention_heads is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
|
||||||
|
)
|
||||||
|
|
||||||
|
# If `num_attention_heads` is not defined (which is the case for most models)
|
||||||
|
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
||||||
|
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
||||||
|
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
||||||
|
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
||||||
|
# which is why we correct for the naming here.
|
||||||
|
num_attention_heads = num_attention_heads or attention_head_dim
|
||||||
|
|
||||||
|
# Check inputs
|
||||||
|
if len(down_block_types) != len(up_block_types):
|
||||||
|
raise ValueError(
|
||||||
|
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(block_out_channels) != len(down_block_types):
|
||||||
|
raise ValueError(
|
||||||
|
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
||||||
|
)
|
||||||
|
|
||||||
|
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(
|
||||||
|
down_block_types
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
||||||
|
)
|
||||||
|
|
||||||
|
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(
|
||||||
|
down_block_types
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(
|
||||||
|
down_block_types
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
|
||||||
|
)
|
||||||
|
|
||||||
|
if not isinstance(layers_per_block, int) and len(layers_per_block) != len(
|
||||||
|
down_block_types
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
|
||||||
|
)
|
||||||
|
|
||||||
|
# input
|
||||||
|
conv_in_padding = (conv_in_kernel - 1) // 2
|
||||||
|
self.conv_in = nn.Conv2d(
|
||||||
|
in_channels,
|
||||||
|
block_out_channels[0],
|
||||||
|
kernel_size=conv_in_kernel,
|
||||||
|
padding=conv_in_padding,
|
||||||
|
)
|
||||||
|
|
||||||
|
# time
|
||||||
|
time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
|
||||||
|
|
||||||
|
self.time_proj = Timesteps(
|
||||||
|
block_out_channels[0], flip_sin_to_cos, downscale_freq_shift=0
|
||||||
|
)
|
||||||
|
timestep_input_dim = block_out_channels[0]
|
||||||
|
|
||||||
|
self.time_embedding = TimestepEmbedding(
|
||||||
|
timestep_input_dim,
|
||||||
|
time_embed_dim,
|
||||||
|
act_fn=act_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.add_time_proj = Timesteps(
|
||||||
|
addition_time_embed_dim, flip_sin_to_cos, downscale_freq_shift=0
|
||||||
|
)
|
||||||
|
self.add_embedding = TimestepEmbedding(
|
||||||
|
projection_class_embeddings_input_dim, time_embed_dim
|
||||||
|
)
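        # Shape sketch for the additional-time-id path above, assuming the defaults
        # (addition_time_embed_dim=256, projection_class_embeddings_input_dim=768, i.e. three ids):
        #   added_time_ids (batch, 3) --flatten--> (batch*3,) --add_time_proj--> (batch*3, 256)
        #   --reshape--> (batch, 768) --add_embedding--> (batch, time_embed_dim=1280)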
|
||||||
|
|
||||||
|
self.down_blocks = nn.ModuleList([])
|
||||||
|
self.up_blocks = nn.ModuleList([])
|
||||||
|
|
||||||
|
if isinstance(num_attention_heads, int):
|
||||||
|
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
||||||
|
|
||||||
|
if isinstance(attention_head_dim, int):
|
||||||
|
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
||||||
|
|
||||||
|
if isinstance(cross_attention_dim, int):
|
||||||
|
cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
|
||||||
|
|
||||||
|
if isinstance(layers_per_block, int):
|
||||||
|
layers_per_block = [layers_per_block] * len(down_block_types)
|
||||||
|
|
||||||
|
if isinstance(transformer_layers_per_block, int):
|
||||||
|
transformer_layers_per_block = [transformer_layers_per_block] * len(
|
||||||
|
down_block_types
|
||||||
|
)
|
||||||
|
|
||||||
|
blocks_time_embed_dim = time_embed_dim
|
||||||
|
|
||||||
|
# down
|
||||||
|
output_channel = block_out_channels[0]
|
||||||
|
for i, down_block_type in enumerate(down_block_types):
|
||||||
|
input_channel = output_channel
|
||||||
|
output_channel = block_out_channels[i]
|
||||||
|
is_final_block = i == len(block_out_channels) - 1
|
||||||
|
|
||||||
|
down_block = get_down_block(
|
||||||
|
down_block_type,
|
||||||
|
num_layers=layers_per_block[i],
|
||||||
|
transformer_layers_per_block=transformer_layers_per_block[i],
|
||||||
|
in_channels=input_channel,
|
||||||
|
out_channels=output_channel,
|
||||||
|
temb_channels=blocks_time_embed_dim,
|
||||||
|
add_downsample=not is_final_block,
|
||||||
|
resnet_eps=norm_eps,
|
||||||
|
resnet_act_fn=act_fn,
|
||||||
|
resnet_groups=norm_num_groups,
|
||||||
|
cross_attention_dim=cross_attention_dim[i],
|
||||||
|
num_attention_heads=num_attention_heads[i],
|
||||||
|
downsample_padding=1,
|
||||||
|
dropout=dropout,
|
||||||
|
kernel_size_3d=kernel_size_3d,
|
||||||
|
merge_factor=merge_factor,
|
||||||
|
merge_strategy=merge_strategy,
|
||||||
|
)
|
||||||
|
self.down_blocks.append(down_block)
|
||||||
|
|
||||||
|
# mid
|
||||||
|
self.mid_block = UNetMidBlockSpatioTemporal(
|
||||||
|
block_out_channels[-1],
|
||||||
|
temb_channels=blocks_time_embed_dim,
|
||||||
|
transformer_layers_per_block=transformer_layers_per_block[-1],
|
||||||
|
resnet_eps=norm_eps,
|
||||||
|
resnet_act_fn=act_fn,
|
||||||
|
cross_attention_dim=cross_attention_dim[-1],
|
||||||
|
num_attention_heads=num_attention_heads[-1],
|
||||||
|
resnet_groups=norm_num_groups,
|
||||||
|
dropout=dropout,
|
||||||
|
kernel_size_3d=kernel_size_3d,
|
||||||
|
merge_factor=merge_factor,
|
||||||
|
merge_strategy=merge_strategy,
|
||||||
|
)
|
||||||
|
|
||||||
|
# count how many layers upsample the images
|
||||||
|
self.num_upsamplers = 0
|
||||||
|
|
||||||
|
# up
|
||||||
|
reversed_block_out_channels = list(reversed(block_out_channels))
|
||||||
|
reversed_num_attention_heads = list(reversed(num_attention_heads))
|
||||||
|
reversed_layers_per_block = list(reversed(layers_per_block))
|
||||||
|
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
|
||||||
|
reversed_transformer_layers_per_block = list(
|
||||||
|
reversed(transformer_layers_per_block)
|
||||||
|
)
|
||||||
|
|
||||||
|
output_channel = reversed_block_out_channels[0]
|
||||||
|
for i, up_block_type in enumerate(up_block_types):
|
||||||
|
is_final_block = i == len(block_out_channels) - 1
|
||||||
|
|
||||||
|
prev_output_channel = output_channel
|
||||||
|
output_channel = reversed_block_out_channels[i]
|
||||||
|
input_channel = reversed_block_out_channels[
|
||||||
|
min(i + 1, len(block_out_channels) - 1)
|
||||||
|
]
|
||||||
|
|
||||||
|
# add upsample block for all BUT final layer
|
||||||
|
if not is_final_block:
|
||||||
|
add_upsample = True
|
||||||
|
self.num_upsamplers += 1
|
||||||
|
else:
|
||||||
|
add_upsample = False
|
||||||
|
|
||||||
|
up_block = get_up_block(
|
||||||
|
up_block_type,
|
||||||
|
num_layers=reversed_layers_per_block[i] + 1,
|
||||||
|
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
|
||||||
|
in_channels=input_channel,
|
||||||
|
out_channels=output_channel,
|
||||||
|
prev_output_channel=prev_output_channel,
|
||||||
|
temb_channels=blocks_time_embed_dim,
|
||||||
|
add_upsample=add_upsample,
|
||||||
|
resnet_eps=norm_eps,
|
||||||
|
resnet_act_fn=act_fn,
|
||||||
|
resolution_idx=i,
|
||||||
|
resnet_groups=norm_num_groups,
|
||||||
|
cross_attention_dim=reversed_cross_attention_dim[i],
|
||||||
|
num_attention_heads=reversed_num_attention_heads[i],
|
||||||
|
dropout=dropout,
|
||||||
|
kernel_size_3d=kernel_size_3d,
|
||||||
|
merge_factor=merge_factor,
|
||||||
|
merge_strategy=merge_strategy,
|
||||||
|
)
|
||||||
|
self.up_blocks.append(up_block)
|
||||||
|
prev_output_channel = output_channel
|
||||||
|
|
||||||
|
# out
|
||||||
|
self.conv_norm_out = nn.GroupNorm(
|
||||||
|
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
|
||||||
|
)
|
||||||
|
self.conv_act = get_activation(act_fn)
|
||||||
|
|
||||||
|
conv_out_padding = (conv_out_kernel - 1) // 2
|
||||||
|
self.conv_out = nn.Conv2d(
|
||||||
|
block_out_channels[0],
|
||||||
|
out_channels,
|
||||||
|
kernel_size=conv_out_kernel,
|
||||||
|
padding=conv_out_padding,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||||
|
r"""
|
||||||
|
Returns:
|
||||||
|
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by their weight names.
        """
|
||||||
|
# set recursively
|
||||||
|
processors = {}
|
||||||
|
|
||||||
|
def fn_recursive_add_processors(
|
||||||
|
name: str,
|
||||||
|
module: torch.nn.Module,
|
||||||
|
processors: Dict[str, AttentionProcessor],
|
||||||
|
):
|
||||||
|
if hasattr(module, "get_processor"):
|
||||||
|
processors[f"{name}.processor"] = module.get_processor(
|
||||||
|
return_deprecated_lora=True
|
||||||
|
)
|
||||||
|
|
||||||
|
for sub_name, child in module.named_children():
|
||||||
|
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||||
|
|
||||||
|
return processors
|
||||||
|
|
||||||
|
for name, module in self.named_children():
|
||||||
|
fn_recursive_add_processors(name, module, processors)
|
||||||
|
|
||||||
|
return processors
|
||||||
|
|
||||||
|
def set_attn_processor(
|
||||||
|
self,
|
||||||
|
processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]],
|
||||||
|
_remove_lora=False,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Sets the attention processor to use to compute attention.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||||
|
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||||
|
for **all** `Attention` layers.
|
||||||
|
|
||||||
|
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||||
|
processor. This is strongly recommended when setting trainable attention processors.
|
||||||
|
|
||||||
|
"""
|
||||||
|
count = len(self.attn_processors.keys())
|
||||||
|
|
||||||
|
if isinstance(processor, dict) and len(processor) != count:
|
||||||
|
raise ValueError(
|
||||||
|
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||||
|
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||||
|
)
|
||||||
|
|
||||||
|
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||||
|
if hasattr(module, "set_processor"):
|
||||||
|
if not isinstance(processor, dict):
|
||||||
|
module.set_processor(processor, _remove_lora=_remove_lora)
|
||||||
|
else:
|
||||||
|
module.set_processor(
|
||||||
|
processor.pop(f"{name}.processor"), _remove_lora=_remove_lora
|
||||||
|
)
|
||||||
|
|
||||||
|
for sub_name, child in module.named_children():
|
||||||
|
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||||
|
|
||||||
|
for name, module in self.named_children():
|
||||||
|
fn_recursive_attn_processor(name, module, processor)
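
A sketch of how these accessors are typically used, assuming `unet` is an instance of this model (`AttnProcessor2_0` is the standard scaled-dot-product processor shipped with diffusers):

from diffusers.models.attention_processor import AttnProcessor2_0

procs = unet.attn_processors  # dict keyed by weight name, e.g. "down_blocks.0.attentions.0....processor"
unet.set_attn_processor({name: AttnProcessor2_0() for name in procs})  # per-layer dict
unet.set_attn_processor(AttnProcessor2_0())  # or a single processor applied to every attention layer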
|
||||||
|
|
||||||
|
def set_default_attn_processor(self):
|
||||||
|
"""
|
||||||
|
Disables custom attention processors and sets the default attention implementation.
|
||||||
|
"""
|
||||||
|
if all(
|
||||||
|
proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS
|
||||||
|
for proc in self.attn_processors.values()
|
||||||
|
):
|
||||||
|
processor = AttnAddedKVProcessor()
|
||||||
|
elif all(
|
||||||
|
proc.__class__ in CROSS_ATTENTION_PROCESSORS
|
||||||
|
for proc in self.attn_processors.values()
|
||||||
|
):
|
||||||
|
processor = AttnProcessor()
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.set_attn_processor(processor, _remove_lora=True)
|
||||||
|
|
||||||
|
def set_attention_slice(self, slice_size):
|
||||||
|
r"""
|
||||||
|
Enable sliced attention computation.
|
||||||
|
|
||||||
|
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
|
||||||
|
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
||||||
|
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
|
||||||
|
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
|
||||||
|
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
||||||
|
must be a multiple of `slice_size`.
|
||||||
|
"""
|
||||||
|
sliceable_head_dims = []
|
||||||
|
|
||||||
|
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
|
||||||
|
if hasattr(module, "set_attention_slice"):
|
||||||
|
sliceable_head_dims.append(module.sliceable_head_dim)
|
||||||
|
|
||||||
|
for child in module.children():
|
||||||
|
fn_recursive_retrieve_sliceable_dims(child)
|
||||||
|
|
||||||
|
# retrieve number of attention layers
|
||||||
|
for module in self.children():
|
||||||
|
fn_recursive_retrieve_sliceable_dims(module)
|
||||||
|
|
||||||
|
num_sliceable_layers = len(sliceable_head_dims)
|
||||||
|
|
||||||
|
if slice_size == "auto":
|
||||||
|
# half the attention head size is usually a good trade-off between
|
||||||
|
# speed and memory
|
||||||
|
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
||||||
|
elif slice_size == "max":
|
||||||
|
# make smallest slice possible
|
||||||
|
slice_size = num_sliceable_layers * [1]
|
||||||
|
|
||||||
|
slice_size = (
|
||||||
|
num_sliceable_layers * [slice_size]
|
||||||
|
if not isinstance(slice_size, list)
|
||||||
|
else slice_size
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(slice_size) != len(sliceable_head_dims):
|
||||||
|
raise ValueError(
|
||||||
|
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
||||||
|
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
||||||
|
)
|
||||||
|
|
||||||
|
for i in range(len(slice_size)):
|
||||||
|
size = slice_size[i]
|
||||||
|
dim = sliceable_head_dims[i]
|
||||||
|
if size is not None and size > dim:
|
||||||
|
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
||||||
|
|
||||||
|
# Recursively walk through all the children.
|
||||||
|
# Any children which exposes the set_attention_slice method
|
||||||
|
# gets the message
|
||||||
|
def fn_recursive_set_attention_slice(
|
||||||
|
module: torch.nn.Module, slice_size: List[int]
|
||||||
|
):
|
||||||
|
if hasattr(module, "set_attention_slice"):
|
||||||
|
module.set_attention_slice(slice_size.pop())
|
||||||
|
|
||||||
|
for child in module.children():
|
||||||
|
fn_recursive_set_attention_slice(child, slice_size)
|
||||||
|
|
||||||
|
reversed_slice_size = list(reversed(slice_size))
|
||||||
|
for module in self.children():
|
||||||
|
fn_recursive_set_attention_slice(module, reversed_slice_size)
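
For example (a sketch, assuming `unet` is an instance of this model):

unet.set_attention_slice("auto")  # each sliceable head dimension is halved, i.e. two slices per layer
unet.set_attention_slice("max")   # slice size 1 everywhere: lowest memory, slowest
unet.set_attention_slice(2)       # fixed slice size of 2; attention_head_dim must be a multiple of it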
|
||||||
|
|
||||||
|
def _set_gradient_checkpointing(self, module, value=False):
|
||||||
|
if hasattr(module, "gradient_checkpointing"):
|
||||||
|
module.gradient_checkpointing = value
|
||||||
|
|
||||||
|
def enable_freeu(self, s1, s2, b1, b2):
|
||||||
|
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
|
||||||
|
|
||||||
|
The suffixes after the scaling factors represent the stage blocks where they are being applied.
|
||||||
|
|
||||||
|
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
|
||||||
|
are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
s1 (`float`):
|
||||||
|
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
|
||||||
|
mitigate the "oversmoothing effect" in the enhanced denoising process.
|
||||||
|
s2 (`float`):
|
||||||
|
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
|
||||||
|
mitigate the "oversmoothing effect" in the enhanced denoising process.
|
||||||
|
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
|
||||||
|
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
|
||||||
|
"""
|
||||||
|
for i, upsample_block in enumerate(self.up_blocks):
|
||||||
|
setattr(upsample_block, "s1", s1)
|
||||||
|
setattr(upsample_block, "s2", s2)
|
||||||
|
setattr(upsample_block, "b1", b1)
|
||||||
|
setattr(upsample_block, "b2", b2)
|
||||||
|
|
||||||
|
def disable_freeu(self):
|
||||||
|
"""Disables the FreeU mechanism."""
|
||||||
|
freeu_keys = {"s1", "s2", "b1", "b2"}
|
||||||
|
for i, upsample_block in enumerate(self.up_blocks):
|
||||||
|
for k in freeu_keys:
|
||||||
|
if (
|
||||||
|
hasattr(upsample_block, k)
|
||||||
|
or getattr(upsample_block, k, None) is not None
|
||||||
|
):
|
||||||
|
setattr(upsample_block, k, None)
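
For example (the scaling factors below are purely illustrative; see the FreeU repository for values tuned per pipeline):

unet.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
# ... run inference ...
unet.disable_freeu()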
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
sample: torch.FloatTensor,
|
||||||
|
timestep: Union[torch.Tensor, float, int],
|
||||||
|
encoder_hidden_states: torch.Tensor,
|
||||||
|
added_time_ids: torch.Tensor,
|
||||||
|
image_only_indicator: Optional[torch.Tensor] = None,
|
||||||
|
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
return_dict: bool = True,
|
||||||
|
) -> Union[UNetSpatioTemporalConditionOutput, Tuple]:
|
||||||
|
r"""
|
||||||
|
The [`UNet2DConditionModel`] forward method.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sample (`torch.FloatTensor`):
|
||||||
|
The noisy input tensor with the following shape `(batch, channel, height, width)`.
|
||||||
|
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
|
||||||
|
encoder_hidden_states (`torch.FloatTensor`):
|
||||||
|
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
|
||||||
|
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
||||||
|
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
||||||
|
negative values to the attention scores corresponding to "discard" tokens.
|
||||||
|
cross_attention_kwargs (`dict`, *optional*):
|
||||||
|
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||||
|
`self.processor` in
|
||||||
|
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||||
|
added_cond_kwargs: (`dict`):
|
||||||
|
A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
|
||||||
|
are passed along to the UNet blocks.
|
||||||
|
encoder_attention_mask (`torch.Tensor`):
|
||||||
|
A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
|
||||||
|
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
|
||||||
|
which adds large negative values to the attention scores corresponding to "discard" tokens.
|
||||||
|
return_dict (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
||||||
|
tuple.
|
||||||
|
cross_attention_kwargs (`dict`, *optional*):
|
||||||
|
A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
|
||||||
|
Returns:
|
||||||
|
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
||||||
|
If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
|
||||||
|
a `tuple` is returned where the first element is the sample tensor.
|
||||||
|
"""
|
||||||
|
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
||||||
|
# The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
|
||||||
|
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
||||||
|
# on the fly if necessary.
|
||||||
|
default_overall_up_factor = 2**self.num_upsamplers
|
||||||
|
|
||||||
|
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
||||||
|
forward_upsample_size = False
|
||||||
|
upsample_size = None
|
||||||
|
|
||||||
|
for dim in sample.shape[-2:]:
|
||||||
|
if dim % default_overall_up_factor != 0:
|
||||||
|
# Forward upsample size to force interpolation output size.
|
||||||
|
forward_upsample_size = True
|
||||||
|
break
|
||||||
|
|
||||||
|
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension
|
||||||
|
# expects mask of shape:
|
||||||
|
# [batch, key_tokens]
|
||||||
|
# adds singleton query_tokens dimension:
|
||||||
|
# [batch, 1, key_tokens]
|
||||||
|
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
||||||
|
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
||||||
|
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
||||||
|
if attention_mask is not None:
|
||||||
|
# assume that mask is expressed as:
|
||||||
|
# (1 = keep, 0 = discard)
|
||||||
|
# convert mask into a bias that can be added to attention scores:
|
||||||
|
# (keep = +0, discard = -10000.0)
|
||||||
|
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
||||||
|
attention_mask = attention_mask.unsqueeze(1)
|
||||||
|
|
||||||
|
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
||||||
|
if encoder_attention_mask is not None:
|
||||||
|
encoder_attention_mask = (
|
||||||
|
1 - encoder_attention_mask.to(sample.dtype)
|
||||||
|
) * -10000.0
|
||||||
|
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
||||||
|
|
||||||
|
# 0. center input if necessary
|
||||||
|
if self.config.center_input_sample:
|
||||||
|
sample = 2 * sample - 1.0
|
||||||
|
|
||||||
|
# 1. time
|
||||||
|
timesteps = timestep
|
||||||
|
if not torch.is_tensor(timesteps):
|
||||||
|
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
||||||
|
# This would be a good case for the `match` statement (Python 3.10+)
|
||||||
|
is_mps = sample.device.type == "mps"
|
||||||
|
if isinstance(timestep, float):
|
||||||
|
dtype = torch.float32 if is_mps else torch.float64
|
||||||
|
else:
|
||||||
|
dtype = torch.int32 if is_mps else torch.int64
|
||||||
|
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
||||||
|
elif len(timesteps.shape) == 0:
|
||||||
|
timesteps = timesteps[None].to(sample.device)
|
||||||
|
|
||||||
|
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||||
|
batch_size, num_frames = sample.shape[:2]
|
||||||
|
timesteps = timesteps.expand(batch_size)
|
||||||
|
|
||||||
|
t_emb = self.time_proj(timesteps)
|
||||||
|
|
||||||
|
# `Timesteps` does not contain any weights and will always return f32 tensors
|
||||||
|
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||||
|
# there might be better ways to encapsulate this.
|
||||||
|
t_emb = t_emb.to(dtype=sample.dtype)
|
||||||
|
|
||||||
|
emb = self.time_embedding(t_emb)
|
||||||
|
|
||||||
|
time_embeds = self.add_time_proj(added_time_ids.flatten())
|
||||||
|
time_embeds = time_embeds.reshape((batch_size, -1))
|
||||||
|
time_embeds = time_embeds.to(emb.dtype)
|
||||||
|
aug_emb = self.add_embedding(time_embeds)
|
||||||
|
emb = emb + aug_emb
|
||||||
|
|
||||||
|
# Flatten the batch and frames dimensions
|
||||||
|
# sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
|
||||||
|
sample = sample.flatten(0, 1)
|
||||||
|
# Repeat the embeddings num_video_frames times
|
||||||
|
# emb: [batch, channels] -> [batch * frames, channels]
|
||||||
|
emb = emb.repeat_interleave(num_frames, dim=0)
|
||||||
|
# encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
|
||||||
|
encoder_hidden_states = encoder_hidden_states.repeat_interleave(
|
||||||
|
num_frames, dim=0
|
||||||
|
)
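        # Shapes at this point (illustrative, B = batch, F = num_frames):
        #   sample                (B * F, in_channels, H, W)
        #   emb                   (B * F, time_embed_dim)
        #   encoder_hidden_states (B * F, 1, cross_attention_dim)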
|
||||||
|
|
||||||
|
# 2. pre-process
|
||||||
|
sample = self.conv_in(sample)
|
||||||
|
|
||||||
|
# 3. down
|
||||||
|
lora_scale = (
|
||||||
|
cross_attention_kwargs.get("scale", 1.0)
|
||||||
|
if cross_attention_kwargs is not None
|
||||||
|
else 1.0
|
||||||
|
)
|
||||||
|
if USE_PEFT_BACKEND:
|
||||||
|
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||||
|
scale_lora_layers(self, lora_scale)
|
||||||
|
|
||||||
|
image_only_indicator = torch.zeros(
|
||||||
|
batch_size, num_frames, dtype=sample.dtype, device=sample.device
|
||||||
|
)
|
||||||
|
|
||||||
|
down_block_res_samples = (sample,)
|
||||||
|
for downsample_block in self.down_blocks:
|
||||||
|
if (
|
||||||
|
hasattr(downsample_block, "has_cross_attention")
|
||||||
|
and downsample_block.has_cross_attention
|
||||||
|
):
|
||||||
|
sample, res_samples = downsample_block(
|
||||||
|
hidden_states=sample,
|
||||||
|
temb=emb,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
cross_attention_kwargs=cross_attention_kwargs,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
num_video_frames=num_frames,
|
||||||
|
image_only_indicator=image_only_indicator,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sample, res_samples = downsample_block(
|
||||||
|
hidden_states=sample,
|
||||||
|
temb=emb,
|
||||||
|
scale=lora_scale,
|
||||||
|
num_video_frames=num_frames,
|
||||||
|
image_only_indicator=image_only_indicator,
|
||||||
|
)
|
||||||
|
|
||||||
|
down_block_res_samples += res_samples
|
||||||
|
|
||||||
|
# 4. mid
|
||||||
|
if self.mid_block is not None:
|
||||||
|
if (
|
||||||
|
hasattr(self.mid_block, "has_cross_attention")
|
||||||
|
and self.mid_block.has_cross_attention
|
||||||
|
):
|
||||||
|
sample = self.mid_block(
|
||||||
|
hidden_states=sample,
|
||||||
|
temb=emb,
|
||||||
|
num_video_frames=num_frames,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
cross_attention_kwargs=cross_attention_kwargs,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
image_only_indicator=image_only_indicator,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sample = self.mid_block(
|
||||||
|
sample,
|
||||||
|
temb=emb,
|
||||||
|
num_video_frames=num_frames,
|
||||||
|
image_only_indicator=image_only_indicator,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 5. up
|
||||||
|
for i, upsample_block in enumerate(self.up_blocks):
|
||||||
|
is_final_block = i == len(self.up_blocks) - 1
|
||||||
|
|
||||||
|
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
||||||
|
down_block_res_samples = down_block_res_samples[
|
||||||
|
: -len(upsample_block.resnets)
|
||||||
|
]
|
||||||
|
|
||||||
|
# if we have not reached the final block and need to forward the
|
||||||
|
# upsample size, we do it here
|
||||||
|
if not is_final_block and forward_upsample_size:
|
||||||
|
upsample_size = down_block_res_samples[-1].shape[2:]
|
||||||
|
|
||||||
|
if (
|
||||||
|
hasattr(upsample_block, "has_cross_attention")
|
||||||
|
and upsample_block.has_cross_attention
|
||||||
|
):
|
||||||
|
sample = upsample_block(
|
||||||
|
hidden_states=sample,
|
||||||
|
temb=emb,
|
||||||
|
res_hidden_states_tuple=res_samples,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
cross_attention_kwargs=cross_attention_kwargs,
|
||||||
|
upsample_size=upsample_size,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
num_video_frames=num_frames,
|
||||||
|
image_only_indicator=image_only_indicator,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sample = upsample_block(
|
||||||
|
hidden_states=sample,
|
||||||
|
temb=emb,
|
||||||
|
res_hidden_states_tuple=res_samples,
|
||||||
|
upsample_size=upsample_size,
|
||||||
|
scale=lora_scale,
|
||||||
|
num_video_frames=num_frames,
|
||||||
|
image_only_indicator=image_only_indicator,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 6. post-process
|
||||||
|
sample = self.conv_norm_out(sample)
|
||||||
|
sample = self.conv_act(sample)
|
||||||
|
sample = self.conv_out(sample)
|
||||||
|
|
||||||
|
# 7. Reshape back to original shape
|
||||||
|
sample = sample.reshape(batch_size, num_frames, *sample.shape[1:])
|
||||||
|
|
||||||
|
if USE_PEFT_BACKEND:
|
||||||
|
# remove `lora_scale` from each PEFT layer
|
||||||
|
unscale_lora_layers(self, lora_scale)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return (sample,)
|
||||||
|
|
||||||
|
return UNetSpatioTemporalConditionOutput(sample=sample)
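
A sketch of a full forward call with dummy inputs, assuming `unet` is the model above with its default configuration (`in_channels=8`, `cross_attention_dim=1024`, and three additional time ids since 3 * 256 equals `projection_class_embeddings_input_dim`):

import torch

batch_size, num_frames, height, width = 1, 4, 64, 64
sample = torch.randn(batch_size, num_frames, 8, height, width)
timestep = torch.tensor([10])
encoder_hidden_states = torch.randn(batch_size, 1, 1024)
added_time_ids = torch.zeros(batch_size, 3)

out = unet(sample, timestep, encoder_hidden_states, added_time_ids).sample
print(out.shape)  # torch.Size([1, 4, 4, 64, 64])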
@@ -22,7 +22,12 @@ from ..utils import BaseOutput, is_torch_version
 from ..utils.torch_utils import randn_tensor
 from .activations import get_activation
 from .attention_processor import SpatialNorm
-from .unet_2d_blocks import AutoencoderTinyBlock, UNetMidBlock2D, get_down_block, get_up_block
+from .unet_2d_blocks import (
+    AutoencoderTinyBlock,
+    UNetMidBlock2D,
+    get_down_block,
+    get_up_block,
+)


 @dataclass
@@ -122,11 +127,15 @@ class Encoder(nn.Module):
         )

         # out
-        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
+        self.conv_norm_out = nn.GroupNorm(
+            num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6
+        )
         self.conv_act = nn.SiLU()

         conv_out_channels = 2 * out_channels if double_z else out_channels
-        self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
+        self.conv_out = nn.Conv2d(
+            block_out_channels[-1], conv_out_channels, 3, padding=1
+        )
         self.gradient_checkpointing = False

@@ -155,9 +164,13 @@ class Encoder(nn.Module):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
for down_block in self.down_blocks:
|
for down_block in self.down_blocks:
|
||||||
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
|
sample = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(down_block), sample
|
||||||
|
)
|
||||||
# middle
|
# middle
|
||||||
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
|
sample = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(self.mid_block), sample
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# down
|
# down
|
||||||
@@ -267,14 +280,18 @@ class Decoder(nn.Module):
|
|||||||
if norm_type == "spatial":
|
if norm_type == "spatial":
|
||||||
self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
|
self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
|
||||||
else:
|
else:
|
||||||
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
|
self.conv_norm_out = nn.GroupNorm(
|
||||||
|
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6
|
||||||
|
)
|
||||||
self.conv_act = nn.SiLU()
|
self.conv_act = nn.SiLU()
|
||||||
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
|
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
|
||||||
|
|
||||||
self.gradient_checkpointing = False
|
self.gradient_checkpointing = False
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, sample: torch.FloatTensor, latent_embeds: Optional[torch.FloatTensor] = None
|
self,
|
||||||
|
sample: torch.FloatTensor,
|
||||||
|
latent_embeds: Optional[torch.FloatTensor] = None,
|
||||||
) -> torch.FloatTensor:
|
) -> torch.FloatTensor:
|
||||||
r"""The forward method of the `Decoder` class."""
|
r"""The forward method of the `Decoder` class."""
|
||||||
|
|
||||||
@@ -292,14 +309,20 @@ class Decoder(nn.Module):
|
|||||||
if is_torch_version(">=", "1.11.0"):
|
if is_torch_version(">=", "1.11.0"):
|
||||||
# middle
|
# middle
|
||||||
sample = torch.utils.checkpoint.checkpoint(
|
sample = torch.utils.checkpoint.checkpoint(
|
||||||
create_custom_forward(self.mid_block), sample, latent_embeds, use_reentrant=False
|
create_custom_forward(self.mid_block),
|
||||||
|
sample,
|
||||||
|
latent_embeds,
|
||||||
|
use_reentrant=False,
|
||||||
)
|
)
|
||||||
sample = sample.to(upscale_dtype)
|
sample = sample.to(upscale_dtype)
|
||||||
|
|
||||||
# up
|
# up
|
||||||
for up_block in self.up_blocks:
|
for up_block in self.up_blocks:
|
||||||
sample = torch.utils.checkpoint.checkpoint(
|
sample = torch.utils.checkpoint.checkpoint(
|
||||||
create_custom_forward(up_block), sample, latent_embeds, use_reentrant=False
|
create_custom_forward(up_block),
|
||||||
|
sample,
|
||||||
|
latent_embeds,
|
||||||
|
use_reentrant=False,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# middle
|
# middle
|
||||||
@@ -310,7 +333,9 @@ class Decoder(nn.Module):
|
|||||||
|
|
||||||
# up
|
# up
|
||||||
for up_block in self.up_blocks:
|
for up_block in self.up_blocks:
|
||||||
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
|
sample = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(up_block), sample, latent_embeds
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# middle
|
# middle
|
||||||
sample = self.mid_block(sample, latent_embeds)
|
sample = self.mid_block(sample, latent_embeds)
|
||||||
@@ -350,7 +375,9 @@ class UpSample(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
self.out_channels = out_channels
|
self.out_channels = out_channels
|
||||||
self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
|
self.deconv = nn.ConvTranspose2d(
|
||||||
|
in_channels, out_channels, kernel_size=4, stride=2, padding=1
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
|
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
r"""The forward method of the `UpSample` class."""
|
r"""The forward method of the `UpSample` class."""
|
||||||
@@ -394,9 +421,13 @@ class MaskConditionEncoder(nn.Module):
|
|||||||
for l in range(len(out_channels)):
|
for l in range(len(out_channels)):
|
||||||
out_ch_ = out_channels[l]
|
out_ch_ = out_channels[l]
|
||||||
if l == 0 or l == 1:
|
if l == 0 or l == 1:
|
||||||
layers.append(nn.Conv2d(in_ch_, out_ch_, kernel_size=3, stride=1, padding=1))
|
layers.append(
|
||||||
|
nn.Conv2d(in_ch_, out_ch_, kernel_size=3, stride=1, padding=1)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
layers.append(nn.Conv2d(in_ch_, out_ch_, kernel_size=4, stride=2, padding=1))
|
layers.append(
|
||||||
|
nn.Conv2d(in_ch_, out_ch_, kernel_size=4, stride=2, padding=1)
|
||||||
|
)
|
||||||
in_ch_ = out_ch_
|
in_ch_ = out_ch_
|
||||||
|
|
||||||
self.layers = nn.Sequential(*layers)
|
self.layers = nn.Sequential(*layers)
|
||||||
@@ -511,7 +542,9 @@ class MaskConditionDecoder(nn.Module):
|
|||||||
if norm_type == "spatial":
|
if norm_type == "spatial":
|
||||||
self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
|
self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
|
||||||
else:
|
else:
|
||||||
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
|
self.conv_norm_out = nn.GroupNorm(
|
||||||
|
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6
|
||||||
|
)
|
||||||
self.conv_act = nn.SiLU()
|
self.conv_act = nn.SiLU()
|
||||||
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
|
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
|
||||||
|
|
||||||
@@ -540,7 +573,10 @@ class MaskConditionDecoder(nn.Module):
|
|||||||
if is_torch_version(">=", "1.11.0"):
|
if is_torch_version(">=", "1.11.0"):
|
||||||
# middle
|
# middle
|
||||||
sample = torch.utils.checkpoint.checkpoint(
|
sample = torch.utils.checkpoint.checkpoint(
|
||||||
create_custom_forward(self.mid_block), sample, latent_embeds, use_reentrant=False
|
create_custom_forward(self.mid_block),
|
||||||
|
sample,
|
||||||
|
latent_embeds,
|
||||||
|
use_reentrant=False,
|
||||||
)
|
)
|
||||||
sample = sample.to(upscale_dtype)
|
sample = sample.to(upscale_dtype)
|
||||||
|
|
||||||
@@ -548,17 +584,25 @@ class MaskConditionDecoder(nn.Module):
|
|||||||
if image is not None and mask is not None:
|
if image is not None and mask is not None:
|
||||||
masked_image = (1 - mask) * image
|
masked_image = (1 - mask) * image
|
||||||
im_x = torch.utils.checkpoint.checkpoint(
|
im_x = torch.utils.checkpoint.checkpoint(
|
||||||
create_custom_forward(self.condition_encoder), masked_image, mask, use_reentrant=False
|
create_custom_forward(self.condition_encoder),
|
||||||
|
masked_image,
|
||||||
|
mask,
|
||||||
|
use_reentrant=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# up
|
# up
|
||||||
for up_block in self.up_blocks:
|
for up_block in self.up_blocks:
|
||||||
if image is not None and mask is not None:
|
if image is not None and mask is not None:
|
||||||
sample_ = im_x[str(tuple(sample.shape))]
|
sample_ = im_x[str(tuple(sample.shape))]
|
||||||
mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
|
mask_ = nn.functional.interpolate(
|
||||||
|
mask, size=sample.shape[-2:], mode="nearest"
|
||||||
|
)
|
||||||
sample = sample * mask_ + sample_ * (1 - mask_)
|
sample = sample * mask_ + sample_ * (1 - mask_)
|
||||||
sample = torch.utils.checkpoint.checkpoint(
|
sample = torch.utils.checkpoint.checkpoint(
|
||||||
create_custom_forward(up_block), sample, latent_embeds, use_reentrant=False
|
create_custom_forward(up_block),
|
||||||
|
sample,
|
||||||
|
latent_embeds,
|
||||||
|
use_reentrant=False,
|
||||||
)
|
)
|
||||||
if image is not None and mask is not None:
|
if image is not None and mask is not None:
|
||||||
sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
|
sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
|
||||||
@@ -573,16 +617,22 @@ class MaskConditionDecoder(nn.Module):
|
|||||||
if image is not None and mask is not None:
|
if image is not None and mask is not None:
|
||||||
masked_image = (1 - mask) * image
|
masked_image = (1 - mask) * image
|
||||||
im_x = torch.utils.checkpoint.checkpoint(
|
im_x = torch.utils.checkpoint.checkpoint(
|
||||||
create_custom_forward(self.condition_encoder), masked_image, mask
|
create_custom_forward(self.condition_encoder),
|
||||||
|
masked_image,
|
||||||
|
mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
# up
|
# up
|
||||||
for up_block in self.up_blocks:
|
for up_block in self.up_blocks:
|
||||||
if image is not None and mask is not None:
|
if image is not None and mask is not None:
|
||||||
sample_ = im_x[str(tuple(sample.shape))]
|
sample_ = im_x[str(tuple(sample.shape))]
|
||||||
mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
|
mask_ = nn.functional.interpolate(
|
||||||
|
mask, size=sample.shape[-2:], mode="nearest"
|
||||||
|
)
|
||||||
sample = sample * mask_ + sample_ * (1 - mask_)
|
sample = sample * mask_ + sample_ * (1 - mask_)
|
||||||
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
|
sample = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(up_block), sample, latent_embeds
|
||||||
|
)
|
||||||
if image is not None and mask is not None:
|
if image is not None and mask is not None:
|
||||||
sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
|
sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
|
||||||
else:
|
else:
|
||||||
@@ -599,7 +649,9 @@ class MaskConditionDecoder(nn.Module):
|
|||||||
for up_block in self.up_blocks:
|
for up_block in self.up_blocks:
|
||||||
if image is not None and mask is not None:
|
if image is not None and mask is not None:
|
||||||
sample_ = im_x[str(tuple(sample.shape))]
|
sample_ = im_x[str(tuple(sample.shape))]
|
||||||
mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
|
mask_ = nn.functional.interpolate(
|
||||||
|
mask, size=sample.shape[-2:], mode="nearest"
|
||||||
|
)
|
||||||
sample = sample * mask_ + sample_ * (1 - mask_)
|
sample = sample * mask_ + sample_ * (1 - mask_)
|
||||||
sample = up_block(sample, latent_embeds)
|
sample = up_block(sample, latent_embeds)
|
||||||
if image is not None and mask is not None:
|
if image is not None and mask is not None:
|
||||||
@@ -671,7 +723,9 @@ class VectorQuantizer(nn.Module):
|
|||||||
new = match.argmax(-1)
|
new = match.argmax(-1)
|
||||||
unknown = match.sum(2) < 1
|
unknown = match.sum(2) < 1
|
||||||
if self.unknown_index == "random":
|
if self.unknown_index == "random":
|
||||||
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
|
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(
|
||||||
|
device=new.device
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
new[unknown] = self.unknown_index
|
new[unknown] = self.unknown_index
|
||||||
return new.reshape(ishape)
|
return new.reshape(ishape)
|
||||||
@@ -686,13 +740,17 @@ class VectorQuantizer(nn.Module):
|
|||||||
back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
|
back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
|
||||||
return back.reshape(ishape)
|
return back.reshape(ishape)
|
||||||
|
|
||||||
def forward(self, z: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, Tuple]:
|
def forward(
|
||||||
|
self, z: torch.FloatTensor
|
||||||
|
) -> Tuple[torch.FloatTensor, torch.FloatTensor, Tuple]:
|
||||||
# reshape z -> (batch, height, width, channel) and flatten
|
# reshape z -> (batch, height, width, channel) and flatten
|
||||||
z = z.permute(0, 2, 3, 1).contiguous()
|
z = z.permute(0, 2, 3, 1).contiguous()
|
||||||
z_flattened = z.view(-1, self.vq_embed_dim)
|
z_flattened = z.view(-1, self.vq_embed_dim)
|
||||||
|
|
||||||
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
||||||
min_encoding_indices = torch.argmin(torch.cdist(z_flattened, self.embedding.weight), dim=1)
|
min_encoding_indices = torch.argmin(
|
||||||
|
torch.cdist(z_flattened, self.embedding.weight), dim=1
|
||||||
|
)
|
||||||
|
|
||||||
z_q = self.embedding(min_encoding_indices).view(z.shape)
|
z_q = self.embedding(min_encoding_indices).view(z.shape)
|
||||||
perplexity = None
|
perplexity = None
|
||||||
@@ -700,9 +758,13 @@ class VectorQuantizer(nn.Module):
|
|||||||
|
|
||||||
# compute loss for embedding
|
# compute loss for embedding
|
||||||
if not self.legacy:
|
if not self.legacy:
|
||||||
loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
|
loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean(
|
||||||
|
(z_q - z.detach()) ** 2
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
|
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean(
|
||||||
|
(z_q - z.detach()) ** 2
|
||||||
|
)
|
||||||
|
|
||||||
# preserve gradients
|
# preserve gradients
|
||||||
z_q: torch.FloatTensor = z + (z_q - z).detach()
|
z_q: torch.FloatTensor = z + (z_q - z).detach()
|
||||||
@@ -711,16 +773,22 @@ class VectorQuantizer(nn.Module):
|
|||||||
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
||||||
|
|
||||||
if self.remap is not None:
|
if self.remap is not None:
|
||||||
min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis
|
min_encoding_indices = min_encoding_indices.reshape(
|
||||||
|
z.shape[0], -1
|
||||||
|
) # add batch axis
|
||||||
min_encoding_indices = self.remap_to_used(min_encoding_indices)
|
min_encoding_indices = self.remap_to_used(min_encoding_indices)
|
||||||
min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
|
min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
|
||||||
|
|
||||||
if self.sane_index_shape:
|
if self.sane_index_shape:
|
||||||
min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])
|
min_encoding_indices = min_encoding_indices.reshape(
|
||||||
|
z_q.shape[0], z_q.shape[2], z_q.shape[3]
|
||||||
|
)
|
||||||
|
|
||||||
return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
|
return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
|
||||||
|
|
||||||
def get_codebook_entry(self, indices: torch.LongTensor, shape: Tuple[int, ...]) -> torch.FloatTensor:
|
def get_codebook_entry(
|
||||||
|
self, indices: torch.LongTensor, shape: Tuple[int, ...]
|
||||||
|
) -> torch.FloatTensor:
|
||||||
# shape specifying (batch, height, width, channel)
|
# shape specifying (batch, height, width, channel)
|
||||||
if self.remap is not None:
|
if self.remap is not None:
|
||||||
indices = indices.reshape(shape[0], -1) # add batch axis
|
indices = indices.reshape(shape[0], -1) # add batch axis
|
||||||
@@ -754,7 +822,10 @@ class DiagonalGaussianDistribution(object):
|
|||||||
def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
|
def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
|
||||||
# make sure sample is on the same device as the parameters and has same dtype
|
# make sure sample is on the same device as the parameters and has same dtype
|
||||||
sample = randn_tensor(
|
sample = randn_tensor(
|
||||||
self.mean.shape, generator=generator, device=self.parameters.device, dtype=self.parameters.dtype
|
self.mean.shape,
|
||||||
|
generator=generator,
|
||||||
|
device=self.parameters.device,
|
||||||
|
dtype=self.parameters.dtype,
|
||||||
)
|
)
|
||||||
x = self.mean + self.std * sample
|
x = self.mean + self.std * sample
|
||||||
return x
|
return x
|
||||||
@@ -764,7 +835,10 @@ class DiagonalGaussianDistribution(object):
|
|||||||
return torch.Tensor([0.0])
|
return torch.Tensor([0.0])
|
||||||
else:
|
else:
|
||||||
if other is None:
|
if other is None:
|
||||||
return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
|
return 0.5 * torch.sum(
|
||||||
|
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
|
||||||
|
dim=[1, 2, 3],
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return 0.5 * torch.sum(
|
return 0.5 * torch.sum(
|
||||||
torch.pow(self.mean - other.mean, 2) / other.var
|
torch.pow(self.mean - other.mean, 2) / other.var
|
||||||
@@ -775,11 +849,16 @@ class DiagonalGaussianDistribution(object):
|
|||||||
dim=[1, 2, 3],
|
dim=[1, 2, 3],
|
||||||
)
|
)
|
||||||
|
|
||||||
def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor:
|
def nll(
|
||||||
|
self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]
|
||||||
|
) -> torch.Tensor:
|
||||||
if self.deterministic:
|
if self.deterministic:
|
||||||
return torch.Tensor([0.0])
|
return torch.Tensor([0.0])
|
||||||
logtwopi = np.log(2.0 * np.pi)
|
logtwopi = np.log(2.0 * np.pi)
|
||||||
return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)
|
return 0.5 * torch.sum(
|
||||||
|
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
|
||||||
|
dim=dims,
|
||||||
|
)
|
||||||
|
|
||||||
def mode(self) -> torch.Tensor:
|
def mode(self) -> torch.Tensor:
|
||||||
return self.mean
|
return self.mean
|
||||||
@@ -818,14 +897,27 @@ class EncoderTiny(nn.Module):
|
|||||||
num_channels = block_out_channels[i]
|
num_channels = block_out_channels[i]
|
||||||
|
|
||||||
if i == 0:
|
if i == 0:
|
||||||
layers.append(nn.Conv2d(in_channels, num_channels, kernel_size=3, padding=1))
|
layers.append(
|
||||||
|
nn.Conv2d(in_channels, num_channels, kernel_size=3, padding=1)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
layers.append(nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, stride=2, bias=False))
|
layers.append(
|
||||||
|
nn.Conv2d(
|
||||||
|
num_channels,
|
||||||
|
num_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
padding=1,
|
||||||
|
stride=2,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
for _ in range(num_block):
|
for _ in range(num_block):
|
||||||
layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn))
|
layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn))
|
||||||
|
|
||||||
layers.append(nn.Conv2d(block_out_channels[-1], out_channels, kernel_size=3, padding=1))
|
layers.append(
|
||||||
|
nn.Conv2d(block_out_channels[-1], out_channels, kernel_size=3, padding=1)
|
||||||
|
)
|
||||||
|
|
||||||
self.layers = nn.Sequential(*layers)
|
self.layers = nn.Sequential(*layers)
|
||||||
self.gradient_checkpointing = False
|
self.gradient_checkpointing = False
|
||||||
@@ -841,9 +933,13 @@ class EncoderTiny(nn.Module):
|
|||||||
return custom_forward
|
return custom_forward
|
||||||
|
|
||||||
if is_torch_version(">=", "1.11.0"):
|
if is_torch_version(">=", "1.11.0"):
|
||||||
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False)
|
x = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(self.layers), x, use_reentrant=False
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)
|
x = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(self.layers), x
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# scale image from [-1, 1] to [0, 1] to match TAESD convention
|
# scale image from [-1, 1] to [0, 1] to match TAESD convention
|
||||||
@@ -899,7 +995,15 @@ class DecoderTiny(nn.Module):
|
|||||||
layers.append(nn.Upsample(scale_factor=upsampling_scaling_factor))
|
layers.append(nn.Upsample(scale_factor=upsampling_scaling_factor))
|
||||||
|
|
||||||
conv_out_channel = num_channels if not is_final_block else out_channels
|
conv_out_channel = num_channels if not is_final_block else out_channels
|
||||||
layers.append(nn.Conv2d(num_channels, conv_out_channel, kernel_size=3, padding=1, bias=is_final_block))
|
layers.append(
|
||||||
|
nn.Conv2d(
|
||||||
|
num_channels,
|
||||||
|
conv_out_channel,
|
||||||
|
kernel_size=3,
|
||||||
|
padding=1,
|
||||||
|
bias=is_final_block,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
self.layers = nn.Sequential(*layers)
|
self.layers = nn.Sequential(*layers)
|
||||||
self.gradient_checkpointing = False
|
self.gradient_checkpointing = False
|
||||||
@@ -918,9 +1022,13 @@ class DecoderTiny(nn.Module):
|
|||||||
return custom_forward
|
return custom_forward
|
||||||
|
|
||||||
if is_torch_version(">=", "1.11.0"):
|
if is_torch_version(">=", "1.11.0"):
|
||||||
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False)
|
x = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(self.layers), x, use_reentrant=False
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)
|
x = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(self.layers), x
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
x = self.layers(x)
|
x = self.layers(x)
|
||||||
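The vae.py hunks above only re-wrap long `torch.utils.checkpoint.checkpoint` calls; no behavior changes. For reference, a minimal sketch of the checkpointing pattern they touch (the module and shapes below are illustrative, not taken from the diff):

```py
import torch
import torch.nn as nn


def create_custom_forward(module):
    # checkpoint() expects a callable, so extra positional args are forwarded through a closure
    def custom_forward(*inputs):
        return module(*inputs)

    return custom_forward


block = nn.Linear(8, 8)  # stand-in for an up_block / condition_encoder
x = torch.randn(2, 8, requires_grad=True)
# use_reentrant=False mirrors the is_torch_version(">=", "1.11.0") branch in the diff
y = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False)
y.sum().backward()
```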
@@ -145,6 +145,7 @@ else:
             "StableDiffusionPix2PixZeroPipeline",
             "StableDiffusionSAGPipeline",
             "StableDiffusionUpscalePipeline",
+            "StableDiffusionVideoPipeline",
             "StableUnCLIPImg2ImgPipeline",
             "StableUnCLIPPipeline",
         ]
@@ -372,6 +373,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
             StableDiffusionPix2PixZeroPipeline,
             StableDiffusionSAGPipeline,
             StableDiffusionUpscalePipeline,
+            StableDiffusionVideoPipeline,
             StableUnCLIPImg2ImgPipeline,
             StableUnCLIPPipeline,
         )
@@ -47,6 +47,7 @@ else:
     _import_structure["pipeline_stable_diffusion_paradigms"] = ["StableDiffusionParadigmsPipeline"]
     _import_structure["pipeline_stable_diffusion_sag"] = ["StableDiffusionSAGPipeline"]
     _import_structure["pipeline_stable_diffusion_upscale"] = ["StableDiffusionUpscalePipeline"]
+    _import_structure["pipeline_stable_diffusion_video"] = ["StableDiffusionVideoPipeline"]
     _import_structure["pipeline_stable_unclip"] = ["StableUnCLIPPipeline"]
     _import_structure["pipeline_stable_unclip_img2img"] = ["StableUnCLIPImg2ImgPipeline"]
     _import_structure["safety_checker"] = ["StableDiffusionSafetyChecker"]
@@ -151,6 +152,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    from .pipeline_stable_diffusion_paradigms import StableDiffusionParadigmsPipeline
    from .pipeline_stable_diffusion_sag import StableDiffusionSAGPipeline
    from .pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline
+   from .pipeline_stable_diffusion_video import StableDiffusionVideoPipeline
    from .pipeline_stable_unclip import StableUnCLIPPipeline
    from .pipeline_stable_unclip_img2img import StableUnCLIPImg2ImgPipeline
    from .safety_checker import StableDiffusionSafetyChecker
|||||||
@@ -0,0 +1,588 @@
|
|||||||
|
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Callable, List, Optional, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import PIL.Image
|
||||||
|
import torch
|
||||||
|
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||||
|
|
||||||
|
from ...image_processor import VaeImageProcessor
|
||||||
|
from ...models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
|
||||||
|
from ...models.attention_processor import (
|
||||||
|
AttnProcessor2_0,
|
||||||
|
LoRAAttnProcessor2_0,
|
||||||
|
LoRAXFormersAttnProcessor,
|
||||||
|
XFormersAttnProcessor,
|
||||||
|
)
|
||||||
|
from ...schedulers import KarrasDiffusionSchedulers
|
||||||
|
from ...utils import BaseOutput, logging
|
||||||
|
from ...utils.torch_utils import randn_tensor
|
||||||
|
from ..pipeline_utils import DiffusionPipeline
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
def _append_dims(x, target_dims):
|
||||||
|
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
|
||||||
|
dims_to_append = target_dims - x.ndim
|
||||||
|
if dims_to_append < 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
|
||||||
|
)
|
||||||
|
return x[(...,) + (None,) * dims_to_append]
|
||||||
|
|
||||||
|
|
||||||
|
def tensor2vid(video: torch.Tensor, processor, output_type="np"):
|
||||||
|
# Based on:
|
||||||
|
# https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
|
||||||
|
|
||||||
|
batch_size, channels, num_frames, height, width = video.shape
|
||||||
|
outputs = []
|
||||||
|
for batch_idx in range(batch_size):
|
||||||
|
batch_vid = video[batch_idx].permute(1, 0, 2, 3)
|
||||||
|
batch_output = processor.postprocess(batch_vid, output_type)
|
||||||
|
|
||||||
|
outputs.append(batch_output)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StableDiffusionVideoPipelineOutput(BaseOutput):
|
||||||
|
r"""
|
||||||
|
Output class for zero-shot text-to-video pipeline.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
frames (`[List[PIL.Image.Image]`, `np.ndarray`]):
|
||||||
|
List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
|
||||||
|
num_channels)`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
frames: Union[List[PIL.Image.Image], np.ndarray]
|
||||||
|
|
||||||
|
|
||||||
|
class StableDiffusionVideoPipeline(DiffusionPipeline):
|
||||||
|
r"""
|
||||||
|
Pipeline to generate video from an input image using Stable Video Diffusion.
|
||||||
|
|
||||||
|
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||||
|
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vae ([`AutoencoderKL`]):
|
||||||
|
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
|
||||||
|
image_encoder ([`~transformers.CLIPVisionModelWithProjection`]):
|
||||||
|
Frozen CLIP image-encoder ([laion/CLIP-ViT-H-14-laion2B-s32B-b79K](https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K)).
|
||||||
|
unet ([`UNetSpatioTemporalConditionModel`]):
|
||||||
|
A `UNetSpatioTemporalConditionModel` to denoise the encoded image latents.
|
||||||
|
scheduler ([`SchedulerMixin`]):
|
||||||
|
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||||
|
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||||
|
feature_extractor ([`~transformers.CLIPImageProcessor`]):
|
||||||
|
A `CLIPImageProcessor` to extract features from generated images.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_cpu_offload_seq = "image_encoder->unet->vae"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vae: AutoencoderKLTemporalDecoder,
|
||||||
|
image_encoder: CLIPVisionModelWithProjection,
|
||||||
|
unet: UNetSpatioTemporalConditionModel,
|
||||||
|
scheduler: KarrasDiffusionSchedulers,
|
||||||
|
feature_extractor: CLIPImageProcessor,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.register_modules(
|
||||||
|
vae=vae,
|
||||||
|
image_encoder=image_encoder,
|
||||||
|
unet=unet,
|
||||||
|
scheduler=scheduler,
|
||||||
|
feature_extractor=feature_extractor,
|
||||||
|
)
|
||||||
|
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||||
|
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||||
|
|
||||||
|
def _encode_image(
|
||||||
|
self, image, device, num_videos_per_prompt, do_classifier_free_guidance
|
||||||
|
):
|
||||||
|
dtype = next(self.image_encoder.parameters()).dtype
|
||||||
|
|
||||||
|
if not isinstance(image, torch.Tensor):
|
||||||
|
image = self.feature_extractor(
|
||||||
|
images=image, return_tensors="pt"
|
||||||
|
).pixel_values
|
||||||
|
|
||||||
|
image = image.to(device=device, dtype=dtype)
|
||||||
|
image_embeddings = self.image_encoder(image).image_embeds
|
||||||
|
image_embeddings = image_embeddings.unsqueeze(1)
|
||||||
|
|
||||||
|
# duplicate image embeddings for each generation per prompt, using mps friendly method
|
||||||
|
bs_embed, seq_len, _ = image_embeddings.shape
|
||||||
|
image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1)
|
||||||
|
image_embeddings = image_embeddings.view(
|
||||||
|
bs_embed * num_videos_per_prompt, seq_len, -1
|
||||||
|
)
|
||||||
|
|
||||||
|
if do_classifier_free_guidance:
|
||||||
|
negative_image_embeddings = torch.zeros_like(image_embeddings)
|
||||||
|
|
||||||
|
# For classifier free guidance, we need to do two forward passes.
|
||||||
|
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||||
|
# to avoid doing two forward passes
|
||||||
|
image_embeddings = torch.cat([negative_image_embeddings, image_embeddings])
|
||||||
|
|
||||||
|
return image_embeddings
|
||||||
|
|
||||||
|
def _encode_vae_image(
|
||||||
|
self,
|
||||||
|
image: torch.Tensor,
|
||||||
|
device,
|
||||||
|
num_videos_per_prompt,
|
||||||
|
do_classifier_free_guidance,
|
||||||
|
):
|
||||||
|
image = image.to(device=device)
|
||||||
|
image_latents = self.vae.encode(image).latent_dist.mode()
|
||||||
|
|
||||||
|
if do_classifier_free_guidance:
|
||||||
|
negative_image_latents = torch.zeros_like(image_latents)
|
||||||
|
|
||||||
|
# For classifier free guidance, we need to do two forward passes.
|
||||||
|
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||||
|
# to avoid doing two forward passes
|
||||||
|
image_latents = torch.cat([negative_image_latents, image_latents])
|
||||||
|
|
||||||
|
# duplicate image_latents for each generation per prompt, using mps friendly method
|
||||||
|
image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1)
|
||||||
|
|
||||||
|
return image_latents
|
||||||
|
|
||||||
|
def _get_add_time_ids(
|
||||||
|
self,
|
||||||
|
fps_id,
|
||||||
|
motion_bucket_id,
|
||||||
|
cond_aug,
|
||||||
|
dtype,
|
||||||
|
batch_size,
|
||||||
|
num_videos_per_prompt,
|
||||||
|
do_classifier_free_guidance,
|
||||||
|
):
|
||||||
|
add_time_ids = [fps_id, motion_bucket_id, cond_aug]
|
||||||
|
|
||||||
|
passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(
|
||||||
|
add_time_ids
|
||||||
|
)
|
||||||
|
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
|
||||||
|
|
||||||
|
if expected_add_embed_dim != passed_add_embed_dim:
|
||||||
|
raise ValueError(
|
||||||
|
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
|
||||||
|
)
|
||||||
|
|
||||||
|
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
|
||||||
|
add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1)
|
||||||
|
|
||||||
|
if do_classifier_free_guidance:
|
||||||
|
add_time_ids = torch.cat([add_time_ids, add_time_ids])
|
||||||
|
|
||||||
|
return add_time_ids
|
||||||
|
|
||||||
|
def decode_latents(self, latents, num_frames, decoding_t=14):
|
||||||
|
# [batch, frames, channels, height, width] -> [batch*frames, channels, height, width]
|
||||||
|
latents = latents.flatten(0, 1)
|
||||||
|
|
||||||
|
latents = 1 / self.vae.config.scaling_factor * latents
|
||||||
|
|
||||||
|
# decode decoding_t frames at a time to avoid OOM
|
||||||
|
frames = []
|
||||||
|
for i in range(0, latents.shape[0], decoding_t):
|
||||||
|
num_frames_in = latents[i : i + decoding_t].shape[0]
|
||||||
|
frame = self.vae.decode(latents[i : i + decoding_t], num_frames_in).sample
|
||||||
|
frames.append(frame)
|
||||||
|
frames = torch.cat(frames, dim=0)
|
||||||
|
|
||||||
|
# [batch*frames, channels, height, width] -> [batch, channels, frames, height, width]
|
||||||
|
frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(
|
||||||
|
0, 2, 1, 3, 4
|
||||||
|
)
|
||||||
|
|
||||||
|
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
||||||
|
frames = frames.float()
|
||||||
|
return frames
|
||||||
|
|
||||||
|
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||||
|
def prepare_extra_step_kwargs(self, generator, eta):
|
||||||
|
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||||
|
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||||
|
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||||
|
# and should be between [0, 1]
|
||||||
|
|
||||||
|
accepts_eta = "eta" in set(
|
||||||
|
inspect.signature(self.scheduler.step).parameters.keys()
|
||||||
|
)
|
||||||
|
extra_step_kwargs = {}
|
||||||
|
if accepts_eta:
|
||||||
|
extra_step_kwargs["eta"] = eta
|
||||||
|
|
||||||
|
# check if the scheduler accepts generator
|
||||||
|
accepts_generator = "generator" in set(
|
||||||
|
inspect.signature(self.scheduler.step).parameters.keys()
|
||||||
|
)
|
||||||
|
if accepts_generator:
|
||||||
|
extra_step_kwargs["generator"] = generator
|
||||||
|
return extra_step_kwargs
|
||||||
|
|
||||||
|
def check_inputs(self, image, height, width, callback_steps):
|
||||||
|
if (
|
||||||
|
not isinstance(image, torch.Tensor)
|
||||||
|
and not isinstance(image, PIL.Image.Image)
|
||||||
|
and not isinstance(image, list)
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
|
||||||
|
f" {type(image)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if height % 8 != 0 or width % 8 != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
|
||||||
|
)
|
||||||
|
|
||||||
|
if (callback_steps is None) or (
|
||||||
|
callback_steps is not None
|
||||||
|
and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||||
|
f" {type(callback_steps)}."
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare_latents(
|
||||||
|
self,
|
||||||
|
batch_size,
|
||||||
|
num_frames,
|
||||||
|
num_channels_latents,
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
dtype,
|
||||||
|
device,
|
||||||
|
generator,
|
||||||
|
latents=None,
|
||||||
|
):
|
||||||
|
shape = (
|
||||||
|
batch_size,
|
||||||
|
num_frames,
|
||||||
|
num_channels_latents // 2,
|
||||||
|
height // self.vae_scale_factor,
|
||||||
|
width // self.vae_scale_factor,
|
||||||
|
)
|
||||||
|
if isinstance(generator, list) and len(generator) != batch_size:
|
||||||
|
raise ValueError(
|
||||||
|
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||||
|
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||||
|
)
|
||||||
|
|
||||||
|
if latents is None:
|
||||||
|
latents = randn_tensor(
|
||||||
|
shape, generator=generator, device=device, dtype=dtype
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
latents = latents.to(device)
|
||||||
|
|
||||||
|
# scale the initial noise by the standard deviation required by the scheduler
|
||||||
|
latents = latents * self.scheduler.init_noise_sigma
|
||||||
|
return latents
|
||||||
|
|
||||||
|
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu
|
||||||
|
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
|
||||||
|
r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
|
||||||
|
|
||||||
|
The suffixes after the scaling factors represent the stages where they are being applied.
|
||||||
|
|
||||||
|
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
|
||||||
|
that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
s1 (`float`):
|
||||||
|
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
|
||||||
|
mitigate "oversmoothing effect" in the enhanced denoising process.
|
||||||
|
s2 (`float`):
|
||||||
|
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
|
||||||
|
mitigate "oversmoothing effect" in the enhanced denoising process.
|
||||||
|
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
|
||||||
|
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
|
||||||
|
"""
|
||||||
|
if not hasattr(self, "unet"):
|
||||||
|
raise ValueError("The pipeline must have `unet` for using FreeU.")
|
||||||
|
self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
|
||||||
|
|
||||||
|
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu
|
||||||
|
def disable_freeu(self):
|
||||||
|
"""Disables the FreeU mechanism if enabled."""
|
||||||
|
self.unet.disable_freeu()
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
|
||||||
|
height: int = 576,
|
||||||
|
width: int = 1024,
|
||||||
|
num_frames: int = 14,
|
||||||
|
num_inference_steps: int = 50,
|
||||||
|
min_guidance_scale: float = 1.0,
|
||||||
|
max_guidance_scale: float = 2.5,
|
||||||
|
fps_id: int = 6,
|
||||||
|
motion_bucket_id: int = 127,
|
||||||
|
cond_aug: int = 0.02,
|
||||||
|
decoding_t: int = 14,
|
||||||
|
num_videos_per_prompt: Optional[int] = 1,
|
||||||
|
eta: float = 0.0,
|
||||||
|
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||||
|
latents: Optional[torch.FloatTensor] = None,
|
||||||
|
output_type: Optional[str] = "pil",
|
||||||
|
return_dict: bool = True,
|
||||||
|
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||||
|
callback_steps: int = 1,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
The call function to the pipeline for generation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
|
||||||
|
Image or images to guide image generation. If you provide a tensor, it needs to be compatible with
|
||||||
|
[`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json).
|
||||||
|
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
||||||
|
The height in pixels of the generated image.
|
||||||
|
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
||||||
|
The width in pixels of the generated image.
|
||||||
|
num_frames (`int`, *optional*, defaults to 14):
|
||||||
|
The number of video frames to generate.
|
||||||
|
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||||
|
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||||
|
expense of slower inference. This parameter is modulated by `strength`.
|
||||||
|
min_guidance_scale (`float`, *optional*, defaults to 1.0):
|
||||||
|
The minimum guidance scale. Used for the classifier free guidance with first frame.
|
||||||
|
max_guidance_scale (`float`, *optional*, defaults to 2.5):
|
||||||
|
The maximum guidance scale. Used for the classifier free guidance with last frame.
|
||||||
|
fps_id (`int`, *optional*, defaults to 6):
|
||||||
|
The frame rate ID. Used as conditioning for the generation.
|
||||||
|
motion_bucket_id (`int`, *optional*, defaults to 127):
|
||||||
|
The motion bucket ID. Used as conditioning for the generation. The higher the number the more motion will be in the video.
|
||||||
|
cond_aug (`int`, *optional*, defaults to 0.02):
|
||||||
|
The amount of noise added to the init image, the higher it is the less the video will look like the init image. Increase it for more motion.
|
||||||
|
decoding_t (`int`, *optional*, defaults to 14):
|
||||||
|
The number of frames to decode at a time. This is used to avoid OOM errors.
|
||||||
|
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||||
|
The number of images to generate per prompt.
|
||||||
|
eta (`float`, *optional*, defaults to 0.0):
|
||||||
|
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
|
||||||
|
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
||||||
|
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||||
|
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
||||||
|
generation deterministic.
|
||||||
|
latents (`torch.FloatTensor`, *optional*):
|
||||||
|
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
||||||
|
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||||
|
tensor is generated by sampling using the supplied random `generator`.
|
||||||
|
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||||
|
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
||||||
|
return_dict (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
||||||
|
plain tuple.
|
||||||
|
callback (`Callable`, *optional*):
|
||||||
|
A function that calls every `callback_steps` steps during inference. The function is called with the
|
||||||
|
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||||
|
callback_steps (`int`, *optional*, defaults to 1):
|
||||||
|
The frequency at which the `callback` function is called. If not specified, the callback is called at
|
||||||
|
every step.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[`~pipelines.stable_diffusion.StableDiffusionVideoPipelineOutput`] or `tuple`:
|
||||||
|
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionVideoPipelineOutput`] is returned,
|
||||||
|
otherwise a `tuple` is returned where the first element is a list of list with the generated frames.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
```py
|
||||||
|
from diffusers import StableDiffusionVideoPipeline
|
||||||
|
from diffusers.utils import load_image, export_to_video
|
||||||
|
|
||||||
|
pipe = StableDiffusionVideoPipeline.from_pretrained("stabilityai/stable-video-diffusion-img2vid", torch_dtype=torch.float16, variant="fp16")
|
||||||
|
pipe.to("cuda")
|
||||||
|
|
||||||
|
image = load_image("https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200")
|
||||||
|
image = image.resize((1024, 576))
|
||||||
|
|
||||||
|
frames = pipe(image, num_frames=14, decoding_t=8).frames[0]
|
||||||
|
export_to_video(frames, "generated.mp4", fps=7)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
# 0. Default height and width to unet
|
||||||
|
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
||||||
|
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
||||||
|
|
||||||
|
# 1. Check inputs. Raise error if not correct
|
||||||
|
self.check_inputs(image, height, width, callback_steps)
|
||||||
|
|
||||||
|
# 2. Define call parameters
|
||||||
|
if isinstance(image, PIL.Image.Image):
|
||||||
|
batch_size = 1
|
||||||
|
elif isinstance(image, list):
|
||||||
|
batch_size = len(image)
|
||||||
|
else:
|
||||||
|
batch_size = image.shape[0]
|
||||||
|
device = self._execution_device
|
||||||
|
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||||
|
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||||
|
# corresponds to doing no classifier free guidance.
|
||||||
|
do_classifier_free_guidance = max_guidance_scale > 1.0
|
||||||
|
|
||||||
|
# 3. Encode input image
|
||||||
|
image_embeddings = self._encode_image(
|
||||||
|
image, device, num_videos_per_prompt, do_classifier_free_guidance
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4. Encode input image using VAE
|
||||||
|
image = self.image_processor.preprocess(image, height=height, width=width)
|
||||||
|
image = image + cond_aug * torch.randn_like(image)
|
||||||
|
|
||||||
|
needs_upcasting = (
|
||||||
|
self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
||||||
|
)
|
||||||
|
if needs_upcasting:
|
||||||
|
self.vae.to(dtype=torch.float32)
|
||||||
|
|
||||||
|
image_latents = self._encode_vae_image(
|
||||||
|
image, device, num_videos_per_prompt, do_classifier_free_guidance
|
||||||
|
)
|
||||||
|
image_latents = image_latents.to(image_embeddings.dtype)
|
||||||
|
|
||||||
|
# cast back to fp16 if needed
|
||||||
|
if needs_upcasting:
|
||||||
|
self.vae.to(dtype=torch.float16)
|
||||||
|
|
||||||
|
# Repeat the image latents for each frame so we can concatenate them with the noise
|
||||||
|
# image_latents [batch, channels, height, width] ->[batch, num_frames, channels, height, width]
|
||||||
|
image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)
|
||||||
|
|
||||||
|
# 5. Get Added Time IDs
|
||||||
|
added_time_ids = self._get_add_time_ids(
|
||||||
|
fps_id,
|
||||||
|
motion_bucket_id,
|
||||||
|
cond_aug,
|
||||||
|
image_embeddings.dtype,
|
||||||
|
batch_size,
|
||||||
|
num_videos_per_prompt,
|
||||||
|
do_classifier_free_guidance,
|
||||||
|
)
|
||||||
|
added_time_ids = added_time_ids.to(device)
|
||||||
|
|
||||||
|
# 4. Prepare timesteps
|
||||||
|
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||||
|
timesteps = self.scheduler.timesteps
|
||||||
|
|
||||||
|
# 5. Prepare latent variables
|
||||||
|
num_channels_latents = self.unet.config.in_channels
|
||||||
|
latents = self.prepare_latents(
|
||||||
|
batch_size * num_videos_per_prompt,
|
||||||
|
num_frames,
|
||||||
|
num_channels_latents,
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
image_embeddings.dtype,
|
||||||
|
device,
|
||||||
|
generator,
|
||||||
|
latents,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||||
|
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||||
|
|
||||||
|
# 7. Prepare guidance scale
|
||||||
|
guidance_scale = torch.linspace(
|
||||||
|
min_guidance_scale, max_guidance_scale, num_frames
|
||||||
|
).unsqueeze(0)
|
||||||
|
guidance_scale = guidance_scale.to(device, latents.dtype)
|
||||||
|
guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1)
|
||||||
|
guidance_scale = _append_dims(guidance_scale, latents.ndim)
|
||||||
|
|
||||||
|
# 8. Denoising loop
|
||||||
|
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||||
|
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||||
|
for i, t in enumerate(timesteps):
|
||||||
|
# expand the latents if we are doing classifier free guidance
|
||||||
|
latent_model_input = (
|
||||||
|
torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||||
|
)
|
||||||
|
latent_model_input = self.scheduler.scale_model_input(
|
||||||
|
latent_model_input, t
|
||||||
|
)
|
||||||
|
|
||||||
|
# Concatenate image_latents over channels dimention
|
||||||
|
latent_model_input = torch.cat(
|
||||||
|
[latent_model_input, image_latents], dim=2
|
||||||
|
)
|
||||||
|
|
||||||
|
# predict the noise residual
|
||||||
|
noise_pred = self.unet(
|
||||||
|
latent_model_input,
|
||||||
|
t,
|
||||||
|
encoder_hidden_states=image_embeddings,
|
||||||
|
added_time_ids=added_time_ids,
|
||||||
|
).sample
|
||||||
|
|
||||||
|
# perform guidance
|
||||||
|
if do_classifier_free_guidance:
|
||||||
|
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
|
||||||
|
noise_pred = noise_pred_uncond + guidance_scale * (
|
||||||
|
noise_pred_cond - noise_pred_uncond
|
||||||
|
)
|
||||||
|
|
||||||
|
# compute the previous noisy sample x_t -> x_t-1
|
||||||
|
latents = self.scheduler.step(
|
||||||
|
noise_pred, t, latents, **extra_step_kwargs
|
||||||
|
).prev_sample
|
||||||
|
|
||||||
|
# call the callback, if provided
|
||||||
|
if i == len(timesteps) - 1 or (
|
||||||
|
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
|
||||||
|
):
|
||||||
|
progress_bar.update()
|
||||||
|
if callback is not None and i % callback_steps == 0:
|
||||||
|
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||||
|
callback(step_idx, t, latents)
|
||||||
|
|
||||||
|
frames = self.decode_latents(latents, num_frames, decoding_t)
|
||||||
|
|
||||||
|
# cast back to fp16 if needed
|
||||||
|
if needs_upcasting:
|
||||||
|
self.vae.to(dtype=torch.float16)
|
||||||
|
|
||||||
|
if not output_type == "latent":
|
||||||
|
frames = tensor2vid(frames, self.image_processor, output_type=output_type)
|
||||||
|
else:
|
||||||
|
frames = latents
|
||||||
|
|
||||||
|
self.maybe_free_model_hooks()
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return frames
|
||||||
|
|
||||||
|
return StableDiffusionVideoPipelineOutput(frames=frames)
|
||||||
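The new pipeline above ramps classifier-free guidance linearly from `min_guidance_scale` to `max_guidance_scale` across frames and broadcasts it over the latent dimensions with `_append_dims`. A minimal sketch of that step with assumed shapes (not the pipeline's exact code):

```py
import torch

min_guidance_scale, max_guidance_scale, num_frames = 1.0, 2.5, 14
latents = torch.randn(1, num_frames, 4, 72, 128)  # (batch, frames, channels, h, w), illustrative

guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0)
# _append_dims: add trailing singleton dims until the tensor matches latents.ndim
guidance_scale = guidance_scale[(...,) + (None,) * (latents.ndim - guidance_scale.ndim)]

noise_pred_uncond = torch.randn_like(latents)
noise_pred_cond = torch.randn_like(latents)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
```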
@@ -323,8 +323,20 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
     def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
         """Constructs the noise schedule of Karras et al. (2022)."""

-        sigma_min: float = in_sigmas[-1].item()
-        sigma_max: float = in_sigmas[0].item()
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

         rho = 7.0  # 7.0 is the value used in the paper
         ramp = np.linspace(0, 1, num_inference_steps)
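This same `_convert_to_karras` change is repeated below for the other schedulers: the sigma bounds now come from the scheduler config when present and fall back to the passed-in sigmas otherwise. For reference, a standalone sketch of the Karras et al. (2022) schedule itself (the default bounds here are illustrative):

```py
import numpy as np


def karras_sigmas(num_inference_steps, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    # interpolate linearly in sigma**(1/rho) space, then map back; rho=7.0 as in the paper
    ramp = np.linspace(0, 1, num_inference_steps)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho


sigmas = karras_sigmas(25)  # descending from sigma_max to sigma_min
```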
@@ -358,8 +358,20 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
     def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
         """Constructs the noise schedule of Karras et al. (2022)."""

-        sigma_min: float = in_sigmas[-1].item()
-        sigma_max: float = in_sigmas[0].item()
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

         rho = 7.0  # 7.0 is the value used in the paper
         ramp = np.linspace(0, 1, num_inference_steps)
@@ -358,8 +358,20 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
     def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
         """Constructs the noise schedule of Karras et al. (2022)."""

-        sigma_min: float = in_sigmas[-1].item()
-        sigma_max: float = in_sigmas[0].item()
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

         rho = 7.0  # 7.0 is the value used in the paper
         ramp = np.linspace(0, 1, num_inference_steps)
@@ -357,8 +357,20 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
     def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
         """Constructs the noise schedule of Karras et al. (2022)."""

-        sigma_min: float = in_sigmas[-1].item()
-        sigma_max: float = in_sigmas[0].item()
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

         rho = 7.0  # 7.0 is the value used in the paper
         ramp = np.linspace(0, 1, num_inference_steps)
@@ -144,7 +144,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         prediction_type: str = "epsilon",
         interpolation_type: str = "linear",
         use_karras_sigmas: Optional[bool] = False,
+        sigma_min: Optional[float] = None,
+        sigma_max: Optional[float] = None,
         timestep_spacing: str = "linspace",
+        timestep_type: str = "discrete",  # can be "discrete" or "continuous"
         steps_offset: int = 0,
     ):
         if trained_betas is not None:
@@ -268,6 +271,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
             sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])

+        # when timestep_type is continuous, we need to convert the timesteps to continuous values using c_noise
+        if self.config.timestep_type == "continuous":
+            timesteps = np.array([self.get_scalings(sigma)[-1] for sigma in sigmas])
+
         sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
         self.sigmas = torch.from_numpy(sigmas).to(device=device)

@@ -301,8 +308,20 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
     def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
         """Constructs the noise schedule of Karras et al. (2022)."""

-        sigma_min: float = in_sigmas[-1].item()
-        sigma_max: float = in_sigmas[0].item()
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

         rho = 7.0  # 7.0 is the value used in the paper
         ramp = np.linspace(0, 1, num_inference_steps)
@@ -311,6 +330,33 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
         return sigmas

+    def get_scalings(self, sigma=None):
+        """
+        Get the scalings for the current timestep.
+        Returns:
+            `torch.FloatTensor`:
+                The scaling factors for the current timestep.
+        """
+        if sigma is None:
+            sigma = self.sigmas[self.step_index]
+
+        if self.config.prediction_type == "epsilon":
+            c_skip = torch.ones_like(sigma, device=sigma.device)
+            c_out = -sigma
+            c_in = 1 / (sigma**2 + 1.0) ** 0.5
+            c_noise = sigma.clone()
+        elif self.config.prediction_type == "v_prediction":
+            c_skip = 1.0 / (sigma**2 + 1.0)
+            c_out = -sigma / (sigma**2 + 1.0) ** 0.5
+            c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
+            c_noise = 0.25 * math.log(sigma)
+        else:
+            raise ValueError(
+                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
+            )
+
+        return c_skip, c_out, c_in, c_noise
+
     def _init_step_index(self, timestep):
         if isinstance(timestep, torch.Tensor):
             timestep = timestep.to(self.timesteps.device)
@@ -391,6 +437,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
             self._init_step_index(timestep)

         sigma = self.sigmas[self.step_index]
+        c_skip, c_out, _, _ = self.get_scalings()

         gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0

@@ -412,8 +459,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         elif self.config.prediction_type == "epsilon":
             pred_original_sample = sample - sigma_hat * model_output
         elif self.config.prediction_type == "v_prediction":
-            # * c_out + input * c_skip
-            pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
+            pred_original_sample = model_output * c_out + sample * c_skip
         else:
             raise ValueError(
                 f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
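The EulerDiscreteScheduler change above routes `pred_original_sample` through EDM-style scalings instead of an inlined formula. A small numeric sketch of what `get_scalings` returns for v-prediction (the sigma value is illustrative):

```py
import math

sigma = 1.5  # illustrative noise level

c_skip = 1.0 / (sigma**2 + 1.0)
c_out = -sigma / (sigma**2 + 1.0) ** 0.5
c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
c_noise = 0.25 * math.log(sigma)

# step() now computes pred_original_sample = model_output * c_out + sample * c_skip,
# which matches the previously inlined v-prediction expression.
print(c_skip, c_out, c_in, c_noise)
```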
@@ -303,8 +303,20 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
     def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
         """Constructs the noise schedule of Karras et al. (2022)."""

-        sigma_min: float = in_sigmas[-1].item()
-        sigma_max: float = in_sigmas[0].item()
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

         rho = 7.0  # 7.0 is the value used in the paper
         ramp = np.linspace(0, 1, num_inference_steps)
@@ -324,8 +324,20 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
     def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
         """Constructs the noise schedule of Karras et al. (2022)."""

-        sigma_min: float = in_sigmas[-1].item()
-        sigma_max: float = in_sigmas[0].item()
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

         rho = 7.0  # 7.0 is the value used in the paper
         ramp = np.linspace(0, 1, num_inference_steps)
@@ -335,8 +335,20 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
     def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
         """Constructs the noise schedule of Karras et al. (2022)."""

-        sigma_min: float = in_sigmas[-1].item()
-        sigma_max: float = in_sigmas[0].item()
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

         rho = 7.0  # 7.0 is the value used in the paper
         ramp = np.linspace(0, 1, num_inference_steps)
@@ -337,8 +337,20 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
     def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
         """Constructs the noise schedule of Karras et al. (2022)."""

-        sigma_min: float = in_sigmas[-1].item()
-        sigma_max: float = in_sigmas[0].item()
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

         rho = 7.0  # 7.0 is the value used in the paper
         ramp = np.linspace(0, 1, num_inference_steps)
@@ -32,6 +32,21 @@ class AutoencoderKL(metaclass=DummyObject):
|
|||||||
requires_backends(cls, ["torch"])
|
requires_backends(cls, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class AutoencoderKLTemporalDecoder(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class AutoencoderTiny(metaclass=DummyObject):
|
class AutoencoderTiny(metaclass=DummyObject):
|
||||||
_backends = ["torch"]
|
_backends = ["torch"]
|
||||||
|
|
||||||
@@ -272,6 +287,21 @@ class UNetMotionModel(metaclass=DummyObject):
|
|||||||
requires_backends(cls, ["torch"])
|
requires_backends(cls, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class UNetSpatioTemporalConditionModel(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class VQModel(metaclass=DummyObject):
|
class VQModel(metaclass=DummyObject):
|
||||||
_backends = ["torch"]
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
|||||||
@@ -1022,6 +1022,21 @@ class StableDiffusionUpscalePipeline(metaclass=DummyObject):
|
|||||||
requires_backends(cls, ["torch", "transformers"])
|
requires_backends(cls, ["torch", "transformers"])
|
||||||
|
|
||||||
|
|
||||||
|
class StableDiffusionVideoPipeline(metaclass=DummyObject):
|
||||||
|
_backends = ["torch", "transformers"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch", "transformers"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "transformers"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "transformers"])
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionXLAdapterPipeline(metaclass=DummyObject):
|
class StableDiffusionXLAdapterPipeline(metaclass=DummyObject):
|
||||||
_backends = ["torch", "transformers"]
|
_backends = ["torch", "transformers"]
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import random
|
|||||||
import struct
|
import struct
|
||||||
import tempfile
|
import tempfile
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import List
|
from typing import List, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import PIL.Image
|
import PIL.Image
|
||||||
@@ -115,7 +115,9 @@ def export_to_obj(mesh, output_obj_path: str = None):
|
|||||||
f.writelines("\n".join(combined_data))
|
f.writelines("\n".join(combined_data))
|
||||||
|
|
||||||
|
|
||||||
def export_to_video(video_frames: List[np.ndarray], output_video_path: str = None) -> str:
|
def export_to_video(
|
||||||
|
video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 8
|
||||||
|
) -> str:
|
||||||
if is_opencv_available():
|
if is_opencv_available():
|
||||||
import cv2
|
import cv2
|
||||||
else:
|
else:
|
||||||
@@ -123,9 +125,12 @@ def export_to_video(video_frames: List[np.ndarray], output_video_path: str = Non
|
|||||||
if output_video_path is None:
|
if output_video_path is None:
|
||||||
output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name
|
output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name
|
||||||
|
|
||||||
|
if isinstance(video_frames[0], PIL.Image.Image):
|
||||||
|
video_frames = [np.array(frame) for frame in video_frames]
|
||||||
|
|
||||||
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
|
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
|
||||||
h, w, c = video_frames[0].shape
|
h, w, c = video_frames[0].shape
|
||||||
video_writer = cv2.VideoWriter(output_video_path, fourcc, fps=8, frameSize=(w, h))
|
video_writer = cv2.VideoWriter(output_video_path, fourcc, fps=fps, frameSize=(w, h))
|
||||||
for i in range(len(video_frames)):
|
for i in range(len(video_frames)):
|
||||||
img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR)
|
img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR)
|
||||||
video_writer.write(img)
|
video_writer.write(img)
|
||||||
|
|||||||
354
tests/models/test_models_unet_spatiotemporal.py
Normal file
354
tests/models/test_models_unet_spatiotemporal.py
Normal file
@@ -0,0 +1,354 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2023 HuggingFace Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from diffusers import UNetSpatioTemporalConditionModel
|
||||||
|
from diffusers.utils import logging
|
||||||
|
from diffusers.utils.import_utils import is_xformers_available
|
||||||
|
from diffusers.utils.testing_utils import (
|
||||||
|
enable_full_determinism,
|
||||||
|
floats_tensor,
|
||||||
|
torch_all_close,
|
||||||
|
torch_device,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
enable_full_determinism()
|
||||||
|
|
||||||
|
|
||||||
|
class UNetSpatioTemporalConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||||
|
model_class = UNetSpatioTemporalConditionModel
|
||||||
|
main_input_name = "sample"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dummy_input(self):
|
||||||
|
batch_size = 4
|
||||||
|
num_channels = 4
|
||||||
|
sizes = (32, 32)
|
||||||
|
|
||||||
|
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
|
||||||
|
time_step = torch.tensor([10]).to(torch_device)
|
||||||
|
encoder_hidden_states = floats_tensor((batch_size, 4, 32)).to(torch_device)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"sample": noise,
|
||||||
|
"timestep": time_step,
|
||||||
|
"encoder_hidden_states": encoder_hidden_states,
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_shape(self):
|
||||||
|
return (4, 32, 32)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_shape(self):
|
||||||
|
return (4, 32, 32)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def fps_id(self):
|
||||||
|
return 6
|
||||||
|
|
||||||
|
@property
|
||||||
|
def motion_bucket_id(self):
|
||||||
|
return 127
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cond_aug(self):
|
||||||
|
return 0.02
|
||||||
|
|
||||||
|
def prepare_init_args_and_inputs_for_common(self):
|
||||||
|
init_dict = {
|
||||||
|
"block_out_channels": (32, 64),
|
||||||
|
"down_block_types": (
|
||||||
|
"CrossAttnDownBlockSpatioTemporal",
|
||||||
|
"DownBlockSpatioTemporal",
|
||||||
|
),
|
||||||
|
"up_block_types": (
|
||||||
|
"UpBlockSpatioTemporal",
|
||||||
|
"CrossAttnUpBlockSpatioTemporal",
|
||||||
|
),
|
||||||
|
"cross_attention_dim": 32,
|
||||||
|
"attention_head_dim": 8,
|
||||||
|
"out_channels": 4,
|
||||||
|
"in_channels": 4,
|
||||||
|
"layers_per_block": 2,
|
||||||
|
"sample_size": 32,
|
||||||
|
}
|
||||||
|
inputs_dict = self.dummy_input
|
||||||
|
return init_dict, inputs_dict
|
||||||
|
|
||||||
|
@unittest.skipIf(
|
||||||
|
torch_device != "cuda" or not is_xformers_available(),
|
||||||
|
reason="XFormers attention is only available with CUDA and `xformers` installed",
|
||||||
|
)
|
||||||
|
def test_xformers_enable_works(self):
|
||||||
|
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||||
|
model = self.model_class(**init_dict)
|
||||||
|
|
||||||
|
model.enable_xformers_memory_efficient_attention()
|
||||||
|
|
||||||
|
assert (
|
||||||
|
model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
|
||||||
|
== "XFormersAttnProcessor"
|
||||||
|
), "xformers is not enabled"
|
||||||
|
|
||||||
|
@unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
|
||||||
|
def test_gradient_checkpointing(self):
|
||||||
|
# enable deterministic behavior for gradient checkpointing
|
||||||
|
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||||
|
model = self.model_class(**init_dict)
|
||||||
|
model.to(torch_device)
|
||||||
|
|
||||||
|
assert not model.is_gradient_checkpointing and model.training
|
||||||
|
|
||||||
|
out = model(**inputs_dict).sample
|
||||||
|
# run the backwards pass on the model. For backwards pass, for simplicity purpose,
|
||||||
|
# we won't calculate the loss and rather backprop on out.sum()
|
||||||
|
model.zero_grad()
|
||||||
|
|
||||||
|
labels = torch.randn_like(out)
|
||||||
|
loss = (out - labels).mean()
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
# re-instantiate the model now enabling gradient checkpointing
|
||||||
|
model_2 = self.model_class(**init_dict)
|
||||||
|
# clone model
|
||||||
|
model_2.load_state_dict(model.state_dict())
|
||||||
|
model_2.to(torch_device)
|
||||||
|
model_2.enable_gradient_checkpointing()
|
||||||
|
|
||||||
|
assert model_2.is_gradient_checkpointing and model_2.training
|
||||||
|
|
||||||
|
out_2 = model_2(**inputs_dict).sample
|
||||||
|
# run the backwards pass on the model. For backwards pass, for simplicity purpose,
|
||||||
|
# we won't calculate the loss and rather backprop on out.sum()
|
||||||
|
model_2.zero_grad()
|
||||||
|
loss_2 = (out_2 - labels).mean()
|
||||||
|
loss_2.backward()
|
||||||
|
|
||||||
|
# compare the output and parameters gradients
|
||||||
|
self.assertTrue((loss - loss_2).abs() < 1e-5)
|
||||||
|
named_params = dict(model.named_parameters())
|
||||||
|
named_params_2 = dict(model_2.named_parameters())
|
||||||
|
for name, param in named_params.items():
|
||||||
|
self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
|
||||||
|
|
||||||
|
def test_model_with_attention_head_dim_tuple(self):
|
||||||
|
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||||
|
|
||||||
|
init_dict["attention_head_dim"] = (8, 16)
|
||||||
|
|
||||||
|
model = self.model_class(**init_dict)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
output = model(**inputs_dict)
|
||||||
|
|
||||||
|
if isinstance(output, dict):
|
||||||
|
output = output.sample
|
||||||
|
|
||||||
|
self.assertIsNotNone(output)
|
||||||
|
expected_shape = inputs_dict["sample"].shape
|
||||||
|
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||||
|
|
||||||
|
def test_model_with_use_linear_projection(self):
|
||||||
|
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||||
|
|
||||||
|
init_dict["use_linear_projection"] = True
|
||||||
|
|
||||||
|
model = self.model_class(**init_dict)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
output = model(**inputs_dict)
|
||||||
|
|
||||||
|
if isinstance(output, dict):
|
||||||
|
output = output.sample
|
||||||
|
|
||||||
|
self.assertIsNotNone(output)
|
||||||
|
expected_shape = inputs_dict["sample"].shape
|
||||||
|
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||||
|
|
||||||
|
def test_model_with_cross_attention_dim_tuple(self):
|
||||||
|
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||||
|
|
||||||
|
init_dict["cross_attention_dim"] = (32, 32)
|
||||||
|
|
||||||
|
model = self.model_class(**init_dict)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
output = model(**inputs_dict)
|
||||||
|
|
||||||
|
if isinstance(output, dict):
|
||||||
|
output = output.sample
|
||||||
|
|
||||||
|
self.assertIsNotNone(output)
|
||||||
|
expected_shape = inputs_dict["sample"].shape
|
||||||
|
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||||
|
|
||||||
|
def test_model_with_simple_projection(self):
|
||||||
|
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||||
|
|
||||||
|
batch_size, _, _, sample_size = inputs_dict["sample"].shape
|
||||||
|
|
||||||
|
init_dict["class_embed_type"] = "simple_projection"
|
||||||
|
init_dict["projection_class_embeddings_input_dim"] = sample_size
|
||||||
|
|
||||||
|
inputs_dict["class_labels"] = floats_tensor((batch_size, sample_size)).to(torch_device)
|
||||||
|
|
||||||
|
model = self.model_class(**init_dict)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
output = model(**inputs_dict)
|
||||||
|
|
||||||
|
if isinstance(output, dict):
|
||||||
|
output = output.sample
|
||||||
|
|
||||||
|
self.assertIsNotNone(output)
|
||||||
|
expected_shape = inputs_dict["sample"].shape
|
||||||
|
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||||
|
|
||||||
|
def test_model_with_class_embeddings_concat(self):
|
||||||
|
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||||
|
|
||||||
|
batch_size, _, _, sample_size = inputs_dict["sample"].shape
|
||||||
|
|
||||||
|
init_dict["class_embed_type"] = "simple_projection"
|
||||||
|
init_dict["projection_class_embeddings_input_dim"] = sample_size
|
||||||
|
init_dict["class_embeddings_concat"] = True
|
||||||
|
|
||||||
|
inputs_dict["class_labels"] = floats_tensor((batch_size, sample_size)).to(torch_device)
|
||||||
|
|
||||||
|
model = self.model_class(**init_dict)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
output = model(**inputs_dict)
|
||||||
|
|
||||||
|
if isinstance(output, dict):
|
||||||
|
output = output.sample
|
||||||
|
|
||||||
|
self.assertIsNotNone(output)
|
||||||
|
expected_shape = inputs_dict["sample"].shape
|
||||||
|
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||||
|
|
||||||
|
def test_model_attention_slicing(self):
|
||||||
|
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||||
|
|
||||||
|
init_dict["attention_head_dim"] = (8, 16)
|
||||||
|
|
||||||
|
model = self.model_class(**init_dict)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
model.set_attention_slice("auto")
|
||||||
|
with torch.no_grad():
|
||||||
|
output = model(**inputs_dict)
|
||||||
|
assert output is not None
|
||||||
|
|
||||||
|
model.set_attention_slice("max")
|
||||||
|
with torch.no_grad():
|
||||||
|
output = model(**inputs_dict)
|
||||||
|
assert output is not None
|
||||||
|
|
||||||
|
model.set_attention_slice(2)
|
||||||
|
with torch.no_grad():
|
||||||
|
output = model(**inputs_dict)
|
||||||
|
assert output is not None
|
||||||
|
|
||||||
|
def test_model_sliceable_head_dim(self):
|
||||||
|
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||||
|
|
||||||
|
init_dict["attention_head_dim"] = (8, 16)
|
||||||
|
|
||||||
|
model = self.model_class(**init_dict)
|
||||||
|
|
||||||
|
def check_sliceable_dim_attr(module: torch.nn.Module):
|
||||||
|
if hasattr(module, "set_attention_slice"):
|
||||||
|
assert isinstance(module.sliceable_head_dim, int)
|
||||||
|
|
||||||
|
for child in module.children():
|
||||||
|
check_sliceable_dim_attr(child)
|
||||||
|
|
||||||
|
# retrieve number of attention layers
|
||||||
|
for module in model.children():
|
||||||
|
check_sliceable_dim_attr(module)
|
||||||
|
|
||||||
|
def test_gradient_checkpointing_is_applied(self):
|
||||||
|
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||||
|
|
||||||
|
init_dict["attention_head_dim"] = (8, 16)
|
||||||
|
|
||||||
|
model_class_copy = copy.copy(self.model_class)
|
||||||
|
|
||||||
|
modules_with_gc_enabled = {}
|
||||||
|
|
||||||
|
# now monkey patch the following function:
|
||||||
|
# def _set_gradient_checkpointing(self, module, value=False):
|
||||||
|
# if hasattr(module, "gradient_checkpointing"):
|
||||||
|
# module.gradient_checkpointing = value
|
||||||
|
|
||||||
|
def _set_gradient_checkpointing_new(self, module, value=False):
|
||||||
|
if hasattr(module, "gradient_checkpointing"):
|
||||||
|
module.gradient_checkpointing = value
|
||||||
|
modules_with_gc_enabled[module.__class__.__name__] = True
|
||||||
|
|
||||||
|
model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new
|
||||||
|
|
||||||
|
model = model_class_copy(**init_dict)
|
||||||
|
model.enable_gradient_checkpointing()
|
||||||
|
|
||||||
|
EXPECTED_SET = {
|
||||||
|
"CrossAttnUpBlock2D",
|
||||||
|
"CrossAttnDownBlock2D",
|
||||||
|
"UNetMidBlock2DCrossAttn",
|
||||||
|
"UpBlock2D",
|
||||||
|
"Transformer2DModel",
|
||||||
|
"DownBlock2D",
|
||||||
|
}
|
||||||
|
|
||||||
|
assert set(modules_with_gc_enabled.keys()) == EXPECTED_SET
|
||||||
|
assert all(modules_with_gc_enabled.values()), "All modules should be enabled"
|
||||||
|
|
||||||
|
def test_pickle(self):
|
||||||
|
# enable deterministic behavior for gradient checkpointing
|
||||||
|
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||||
|
|
||||||
|
init_dict["attention_head_dim"] = (8, 16)
|
||||||
|
|
||||||
|
model = self.model_class(**init_dict)
|
||||||
|
model.to(torch_device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
sample = model(**inputs_dict).sample
|
||||||
|
|
||||||
|
sample_copy = copy.copy(sample)
|
||||||
|
|
||||||
|
assert (sample - sample_copy).abs().max() < 1e-4
|
||||||
@@ -23,6 +23,7 @@ from parameterized import parameterized
|
|||||||
from diffusers import (
|
from diffusers import (
|
||||||
AsymmetricAutoencoderKL,
|
AsymmetricAutoencoderKL,
|
||||||
AutoencoderKL,
|
AutoencoderKL,
|
||||||
|
AutoencoderKLTemporalDecoder,
|
||||||
AutoencoderTiny,
|
AutoencoderTiny,
|
||||||
ConsistencyDecoderVAE,
|
ConsistencyDecoderVAE,
|
||||||
StableDiffusionPipeline,
|
StableDiffusionPipeline,
|
||||||
@@ -195,10 +196,16 @@ class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
|||||||
named_params = dict(model.named_parameters())
|
named_params = dict(model.named_parameters())
|
||||||
named_params_2 = dict(model_2.named_parameters())
|
named_params_2 = dict(model_2.named_parameters())
|
||||||
for name, param in named_params.items():
|
for name, param in named_params.items():
|
||||||
self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
|
self.assertTrue(
|
||||||
|
torch_all_close(
|
||||||
|
param.grad.data, named_params_2[name].grad.data, atol=5e-5
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def test_from_pretrained_hub(self):
|
def test_from_pretrained_hub(self):
|
||||||
model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True)
|
model, loading_info = AutoencoderKL.from_pretrained(
|
||||||
|
"fusing/autoencoder-kl-dummy", output_loading_info=True
|
||||||
|
)
|
||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
self.assertEqual(len(loading_info["missing_keys"]), 0)
|
self.assertEqual(len(loading_info["missing_keys"]), 0)
|
||||||
|
|
||||||
@@ -248,17 +255,39 @@ class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
|||||||
)
|
)
|
||||||
elif torch_device == "cpu":
|
elif torch_device == "cpu":
|
||||||
expected_output_slice = torch.tensor(
|
expected_output_slice = torch.tensor(
|
||||||
[-0.1352, 0.0878, 0.0419, -0.0818, -0.1069, 0.0688, -0.1458, -0.4446, -0.0026]
|
[
|
||||||
|
-0.1352,
|
||||||
|
0.0878,
|
||||||
|
0.0419,
|
||||||
|
-0.0818,
|
||||||
|
-0.1069,
|
||||||
|
0.0688,
|
||||||
|
-0.1458,
|
||||||
|
-0.4446,
|
||||||
|
-0.0026,
|
||||||
|
]
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
expected_output_slice = torch.tensor(
|
expected_output_slice = torch.tensor(
|
||||||
[-0.2421, 0.4642, 0.2507, -0.0438, 0.0682, 0.3160, -0.2018, -0.0727, 0.2485]
|
[
|
||||||
|
-0.2421,
|
||||||
|
0.4642,
|
||||||
|
0.2507,
|
||||||
|
-0.0438,
|
||||||
|
0.0682,
|
||||||
|
0.3160,
|
||||||
|
-0.2018,
|
||||||
|
-0.0727,
|
||||||
|
0.2485,
|
||||||
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
|
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
|
||||||
|
|
||||||
|
|
||||||
class AsymmetricAutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
class AsymmetricAutoencoderKLTests(
|
||||||
|
ModelTesterMixin, UNetTesterMixin, unittest.TestCase
|
||||||
|
):
|
||||||
model_class = AsymmetricAutoencoderKL
|
model_class = AsymmetricAutoencoderKL
|
||||||
main_input_name = "sample"
|
main_input_name = "sample"
|
||||||
base_precision = 1e-2
|
base_precision = 1e-2
|
||||||
@@ -336,7 +365,9 @@ class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase):
|
|||||||
generator = torch.Generator("cpu")
|
generator = torch.Generator("cpu")
|
||||||
if seed is not None:
|
if seed is not None:
|
||||||
generator.manual_seed(0)
|
generator.manual_seed(0)
|
||||||
image = randn_tensor((4, 3, 32, 32), generator=generator, device=torch.device(torch_device))
|
image = randn_tensor(
|
||||||
|
(4, 3, 32, 32), generator=generator, device=torch.device(torch_device)
|
||||||
|
)
|
||||||
|
|
||||||
return {"sample": image, "generator": generator}
|
return {"sample": image, "generator": generator}
|
||||||
|
|
||||||
@@ -364,6 +395,98 @@ class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase):
|
|||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class AutoncoderKLTemporalDecoderFastTests(ModelTesterMixin, unittest.TestCase):
|
||||||
|
model_class = AutoencoderKLTemporalDecoder
|
||||||
|
main_input_name = "sample"
|
||||||
|
base_precision = 1e-2
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dummy_input(self):
|
||||||
|
batch_size = 3
|
||||||
|
num_channels = 3
|
||||||
|
sizes = (32, 32)
|
||||||
|
|
||||||
|
image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
|
||||||
|
num_frames = 3
|
||||||
|
|
||||||
|
return {"sample": image, "num_frames": num_frames}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_shape(self):
|
||||||
|
return (3, 32, 32)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_shape(self):
|
||||||
|
return (3, 32, 32)
|
||||||
|
|
||||||
|
def prepare_init_args_and_inputs_for_common(self):
|
||||||
|
init_dict = {
|
||||||
|
"block_out_channels": [8, 16],
|
||||||
|
"in_channels": 3,
|
||||||
|
"out_channels": 3,
|
||||||
|
"down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||||
|
"latent_channels": 4,
|
||||||
|
"norm_num_groups": 4,
|
||||||
|
"layers_per_block": 2,
|
||||||
|
}
|
||||||
|
inputs_dict = self.dummy_input
|
||||||
|
return init_dict, inputs_dict
|
||||||
|
|
||||||
|
def test_forward_signature(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_training(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
|
||||||
|
def test_gradient_checkpointing(self):
|
||||||
|
# enable deterministic behavior for gradient checkpointing
|
||||||
|
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||||
|
model = self.model_class(**init_dict)
|
||||||
|
model.to(torch_device)
|
||||||
|
|
||||||
|
assert not model.is_gradient_checkpointing and model.training
|
||||||
|
|
||||||
|
out = model(**inputs_dict).sample
|
||||||
|
# run the backwards pass on the model. For backwards pass, for simplicity purpose,
|
||||||
|
# we won't calculate the loss and rather backprop on out.sum()
|
||||||
|
model.zero_grad()
|
||||||
|
|
||||||
|
labels = torch.randn_like(out)
|
||||||
|
loss = (out - labels).mean()
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
# re-instantiate the model now enabling gradient checkpointing
|
||||||
|
model_2 = self.model_class(**init_dict)
|
||||||
|
# clone model
|
||||||
|
model_2.load_state_dict(model.state_dict())
|
||||||
|
model_2.to(torch_device)
|
||||||
|
model_2.enable_gradient_checkpointing()
|
||||||
|
|
||||||
|
assert model_2.is_gradient_checkpointing and model_2.training
|
||||||
|
|
||||||
|
out_2 = model_2(**inputs_dict).sample
|
||||||
|
# run the backwards pass on the model. For backwards pass, for simplicity purpose,
|
||||||
|
# we won't calculate the loss and rather backprop on out.sum()
|
||||||
|
model_2.zero_grad()
|
||||||
|
loss_2 = (out_2 - labels).mean()
|
||||||
|
loss_2.backward()
|
||||||
|
|
||||||
|
# compare the output and parameters gradients
|
||||||
|
self.assertTrue((loss - loss_2).abs() < 1e-5)
|
||||||
|
named_params = dict(model.named_parameters())
|
||||||
|
named_params_2 = dict(model_2.named_parameters())
|
||||||
|
for name, param in named_params.items():
|
||||||
|
if "post_quant_conv" in name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
self.assertTrue(
|
||||||
|
torch_all_close(
|
||||||
|
param.grad.data, named_params_2[name].grad.data, atol=5e-5
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
class AutoencoderTinyIntegrationTests(unittest.TestCase):
|
class AutoencoderTinyIntegrationTests(unittest.TestCase):
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
@@ -377,10 +500,16 @@ class AutoencoderTinyIntegrationTests(unittest.TestCase):
|
|||||||
|
|
||||||
def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
|
def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
|
||||||
dtype = torch.float16 if fp16 else torch.float32
|
dtype = torch.float16 if fp16 else torch.float32
|
||||||
image = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
|
image = (
|
||||||
|
torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape)))
|
||||||
|
.to(torch_device)
|
||||||
|
.to(dtype)
|
||||||
|
)
|
||||||
return image
|
return image
|
||||||
|
|
||||||
def get_sd_vae_model(self, model_id="hf-internal-testing/taesd-diffusers", fp16=False):
|
def get_sd_vae_model(
|
||||||
|
self, model_id="hf-internal-testing/taesd-diffusers", fp16=False
|
||||||
|
):
|
||||||
torch_dtype = torch.float16 if fp16 else torch.float32
|
torch_dtype = torch.float16 if fp16 else torch.float32
|
||||||
|
|
||||||
model = AutoencoderTiny.from_pretrained(model_id, torch_dtype=torch_dtype)
|
model = AutoencoderTiny.from_pretrained(model_id, torch_dtype=torch_dtype)
|
||||||
@@ -414,7 +543,9 @@ class AutoencoderTinyIntegrationTests(unittest.TestCase):
|
|||||||
assert sample.shape == image.shape
|
assert sample.shape == image.shape
|
||||||
|
|
||||||
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
|
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
|
||||||
expected_output_slice = torch.tensor([0.0093, 0.6385, -0.1274, 0.1631, -0.1762, 0.5232, -0.3108, -0.0382])
|
expected_output_slice = torch.tensor(
|
||||||
|
[0.0093, 0.6385, -0.1274, 0.1631, -0.1762, 0.5232, -0.3108, -0.0382]
|
||||||
|
)
|
||||||
|
|
||||||
assert torch_all_close(output_slice, expected_output_slice, atol=3e-3)
|
assert torch_all_close(output_slice, expected_output_slice, atol=3e-3)
|
||||||
|
|
||||||
@@ -454,7 +585,11 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
|
|||||||
|
|
||||||
def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
|
def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
|
||||||
dtype = torch.float16 if fp16 else torch.float32
|
dtype = torch.float16 if fp16 else torch.float32
|
||||||
image = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
|
image = (
|
||||||
|
torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape)))
|
||||||
|
.to(torch_device)
|
||||||
|
.to(dtype)
|
||||||
|
)
|
||||||
return image
|
return image
|
||||||
|
|
||||||
def get_sd_vae_model(self, model_id="CompVis/stable-diffusion-v1-4", fp16=False):
|
def get_sd_vae_model(self, model_id="CompVis/stable-diffusion-v1-4", fp16=False):
|
||||||
@@ -503,7 +638,9 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
|
|||||||
assert sample.shape == image.shape
|
assert sample.shape == image.shape
|
||||||
|
|
||||||
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
|
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
|
||||||
expected_output_slice = torch.tensor(expected_slice_mps if torch_device == "mps" else expected_slice)
|
expected_output_slice = torch.tensor(
|
||||||
|
expected_slice_mps if torch_device == "mps" else expected_slice
|
||||||
|
)
|
||||||
|
|
||||||
assert torch_all_close(output_slice, expected_output_slice, atol=3e-3)
|
assert torch_all_close(output_slice, expected_output_slice, atol=3e-3)
|
||||||
|
|
||||||
@@ -557,7 +694,9 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
|
|||||||
assert sample.shape == image.shape
|
assert sample.shape == image.shape
|
||||||
|
|
||||||
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
|
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
|
||||||
expected_output_slice = torch.tensor(expected_slice_mps if torch_device == "mps" else expected_slice)
|
expected_output_slice = torch.tensor(
|
||||||
|
expected_slice_mps if torch_device == "mps" else expected_slice
|
||||||
|
)
|
||||||
|
|
||||||
assert torch_all_close(output_slice, expected_output_slice, atol=3e-3)
|
assert torch_all_close(output_slice, expected_output_slice, atol=3e-3)
|
||||||
|
|
||||||
@@ -609,7 +748,10 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
|
|||||||
|
|
||||||
@parameterized.expand([(13,), (16,), (27,)])
|
@parameterized.expand([(13,), (16,), (27,)])
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@unittest.skipIf(not is_xformers_available(), reason="xformers is not required when using PyTorch 2.0.")
|
@unittest.skipIf(
|
||||||
|
not is_xformers_available(),
|
||||||
|
reason="xformers is not required when using PyTorch 2.0.",
|
||||||
|
)
|
||||||
def test_stable_diffusion_decode_xformers_vs_2_0_fp16(self, seed):
|
def test_stable_diffusion_decode_xformers_vs_2_0_fp16(self, seed):
|
||||||
model = self.get_sd_vae_model(fp16=True)
|
model = self.get_sd_vae_model(fp16=True)
|
||||||
encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64), fp16=True)
|
encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64), fp16=True)
|
||||||
@@ -627,7 +769,10 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
|
|||||||
|
|
||||||
@parameterized.expand([(13,), (16,), (37,)])
|
@parameterized.expand([(13,), (16,), (37,)])
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@unittest.skipIf(not is_xformers_available(), reason="xformers is not required when using PyTorch 2.0.")
|
@unittest.skipIf(
|
||||||
|
not is_xformers_available(),
|
||||||
|
reason="xformers is not required when using PyTorch 2.0.",
|
||||||
|
)
|
||||||
def test_stable_diffusion_decode_xformers_vs_2_0(self, seed):
|
def test_stable_diffusion_decode_xformers_vs_2_0(self, seed):
|
||||||
model = self.get_sd_vae_model()
|
model = self.get_sd_vae_model()
|
||||||
encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))
|
encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))
|
||||||
@@ -660,7 +805,9 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
|
|||||||
dist = model.encode(image).latent_dist
|
dist = model.encode(image).latent_dist
|
||||||
sample = dist.sample(generator=generator)
|
sample = dist.sample(generator=generator)
|
||||||
|
|
||||||
assert list(sample.shape) == [image.shape[0], 4] + [i // 8 for i in image.shape[2:]]
|
assert list(sample.shape) == [image.shape[0], 4] + [
|
||||||
|
i // 8 for i in image.shape[2:]
|
||||||
|
]
|
||||||
|
|
||||||
output_slice = sample[0, -1, -3:, -3:].flatten().cpu()
|
output_slice = sample[0, -1, -3:, -3:].flatten().cpu()
|
||||||
expected_output_slice = torch.tensor(expected_slice)
|
expected_output_slice = torch.tensor(expected_slice)
|
||||||
@@ -701,10 +848,16 @@ class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
|
|||||||
|
|
||||||
def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
|
def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
|
||||||
dtype = torch.float16 if fp16 else torch.float32
|
dtype = torch.float16 if fp16 else torch.float32
|
||||||
image = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
|
image = (
|
||||||
|
torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape)))
|
||||||
|
.to(torch_device)
|
||||||
|
.to(dtype)
|
||||||
|
)
|
||||||
return image
|
return image
|
||||||
|
|
||||||
def get_sd_vae_model(self, model_id="cross-attention/asymmetric-autoencoder-kl-x-1-5", fp16=False):
|
def get_sd_vae_model(
|
||||||
|
self, model_id="cross-attention/asymmetric-autoencoder-kl-x-1-5", fp16=False
|
||||||
|
):
|
||||||
revision = "main"
|
revision = "main"
|
||||||
torch_dtype = torch.float32
|
torch_dtype = torch.float32
|
||||||
|
|
||||||
@@ -749,7 +902,9 @@ class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
|
|||||||
assert sample.shape == image.shape
|
assert sample.shape == image.shape
|
||||||
|
|
||||||
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
|
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
|
||||||
expected_output_slice = torch.tensor(expected_slice_mps if torch_device == "mps" else expected_slice)
|
expected_output_slice = torch.tensor(
|
||||||
|
expected_slice_mps if torch_device == "mps" else expected_slice
|
||||||
|
)
|
||||||
|
|
||||||
assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)
|
assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)
|
||||||
|
|
||||||
@@ -779,7 +934,9 @@ class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
|
|||||||
assert sample.shape == image.shape
|
assert sample.shape == image.shape
|
||||||
|
|
||||||
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
|
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
|
||||||
expected_output_slice = torch.tensor(expected_slice_mps if torch_device == "mps" else expected_slice)
|
expected_output_slice = torch.tensor(
|
||||||
|
expected_slice_mps if torch_device == "mps" else expected_slice
|
||||||
|
)
|
||||||
|
|
||||||
assert torch_all_close(output_slice, expected_output_slice, atol=3e-3)
|
assert torch_all_close(output_slice, expected_output_slice, atol=3e-3)
|
||||||
|
|
||||||
@@ -808,7 +965,10 @@ class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
|
|||||||
|
|
||||||
@parameterized.expand([(13,), (16,), (37,)])
|
@parameterized.expand([(13,), (16,), (37,)])
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@unittest.skipIf(not is_xformers_available(), reason="xformers is not required when using PyTorch 2.0.")
|
@unittest.skipIf(
|
||||||
|
not is_xformers_available(),
|
||||||
|
reason="xformers is not required when using PyTorch 2.0.",
|
||||||
|
)
|
||||||
def test_stable_diffusion_decode_xformers_vs_2_0(self, seed):
|
def test_stable_diffusion_decode_xformers_vs_2_0(self, seed):
|
||||||
model = self.get_sd_vae_model()
|
model = self.get_sd_vae_model()
|
||||||
encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))
|
encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))
|
||||||
@@ -841,7 +1001,9 @@ class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
|
|||||||
dist = model.encode(image).latent_dist
|
dist = model.encode(image).latent_dist
|
||||||
sample = dist.sample(generator=generator)
|
sample = dist.sample(generator=generator)
|
||||||
|
|
||||||
assert list(sample.shape) == [image.shape[0], 4] + [i // 8 for i in image.shape[2:]]
|
assert list(sample.shape) == [image.shape[0], 4] + [
|
||||||
|
i // 8 for i in image.shape[2:]
|
||||||
|
]
|
||||||
|
|
||||||
output_slice = sample[0, -1, -3:, -3:].flatten().cpu()
|
output_slice = sample[0, -1, -3:, -3:].flatten().cpu()
|
||||||
expected_output_slice = torch.tensor(expected_slice)
|
expected_output_slice = torch.tensor(expected_slice)
|
||||||
@@ -860,37 +1022,52 @@ class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def test_encode_decode(self):
|
def test_encode_decode(self):
|
||||||
vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder") # TODO - update
|
vae = ConsistencyDecoderVAE.from_pretrained(
|
||||||
|
"openai/consistency-decoder"
|
||||||
|
) # TODO - update
|
||||||
vae.to(torch_device)
|
vae.to(torch_device)
|
||||||
|
|
||||||
image = load_image(
|
image = load_image(
|
||||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||||
"/img2img/sketch-mountains-input.jpg"
|
"/img2img/sketch-mountains-input.jpg"
|
||||||
).resize((256, 256))
|
).resize((256, 256))
|
||||||
image = torch.from_numpy(np.array(image).transpose(2, 0, 1).astype(np.float32) / 127.5 - 1)[
|
image = torch.from_numpy(
|
||||||
None, :, :, :
|
np.array(image).transpose(2, 0, 1).astype(np.float32) / 127.5 - 1
|
||||||
].cuda()
|
)[None, :, :, :].cuda()
|
||||||
|
|
||||||
latent = vae.encode(image).latent_dist.mean
|
latent = vae.encode(image).latent_dist.mean
|
||||||
|
|
||||||
sample = vae.decode(latent, generator=torch.Generator("cpu").manual_seed(0)).sample
|
sample = vae.decode(
|
||||||
|
latent, generator=torch.Generator("cpu").manual_seed(0)
|
||||||
|
).sample
|
||||||
|
|
||||||
actual_output = sample[0, :2, :2, :2].flatten().cpu()
|
actual_output = sample[0, :2, :2, :2].flatten().cpu()
|
||||||
expected_output = torch.tensor([-0.0141, -0.0014, 0.0115, 0.0086, 0.1051, 0.1053, 0.1031, 0.1024])
|
expected_output = torch.tensor(
|
||||||
|
[-0.0141, -0.0014, 0.0115, 0.0086, 0.1051, 0.1053, 0.1031, 0.1024]
|
||||||
|
)
|
||||||
|
|
||||||
assert torch_all_close(actual_output, expected_output, atol=5e-3)
|
assert torch_all_close(actual_output, expected_output, atol=5e-3)
|
||||||
|
|
||||||
def test_sd(self):
|
def test_sd(self):
|
||||||
vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder") # TODO - update
|
vae = ConsistencyDecoderVAE.from_pretrained(
|
||||||
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", vae=vae, safety_checker=None)
|
"openai/consistency-decoder"
|
||||||
|
) # TODO - update
|
||||||
|
pipe = StableDiffusionPipeline.from_pretrained(
|
||||||
|
"runwayml/stable-diffusion-v1-5", vae=vae, safety_checker=None
|
||||||
|
)
|
||||||
pipe.to(torch_device)
|
pipe.to(torch_device)
|
||||||
|
|
||||||
out = pipe(
|
out = pipe(
|
||||||
"horse", num_inference_steps=2, output_type="pt", generator=torch.Generator("cpu").manual_seed(0)
|
"horse",
|
||||||
|
num_inference_steps=2,
|
||||||
|
output_type="pt",
|
||||||
|
generator=torch.Generator("cpu").manual_seed(0),
|
||||||
).images[0]
|
).images[0]
|
||||||
|
|
||||||
actual_output = out[:2, :2, :2].flatten().cpu()
|
actual_output = out[:2, :2, :2].flatten().cpu()
|
||||||
expected_output = torch.tensor([0.7686, 0.8228, 0.6489, 0.7455, 0.8661, 0.8797, 0.8241, 0.8759])
|
expected_output = torch.tensor(
|
||||||
|
[0.7686, 0.8228, 0.6489, 0.7455, 0.8661, 0.8797, 0.8241, 0.8759]
|
||||||
|
)
|
||||||
|
|
||||||
assert torch_all_close(actual_output, expected_output, atol=5e-3)
|
assert torch_all_close(actual_output, expected_output, atol=5e-3)
|
||||||
|
|
||||||
@@ -905,18 +1082,23 @@ class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
|
|||||||
"/img2img/sketch-mountains-input.jpg"
|
"/img2img/sketch-mountains-input.jpg"
|
||||||
).resize((256, 256))
|
).resize((256, 256))
|
||||||
image = (
|
image = (
|
||||||
torch.from_numpy(np.array(image).transpose(2, 0, 1).astype(np.float32) / 127.5 - 1)[None, :, :, :]
|
torch.from_numpy(
|
||||||
|
np.array(image).transpose(2, 0, 1).astype(np.float32) / 127.5 - 1
|
||||||
|
)[None, :, :, :]
|
||||||
.half()
|
.half()
|
||||||
.cuda()
|
.cuda()
|
||||||
)
|
)
|
||||||
|
|
||||||
latent = vae.encode(image).latent_dist.mean
|
latent = vae.encode(image).latent_dist.mean
|
||||||
|
|
||||||
sample = vae.decode(latent, generator=torch.Generator("cpu").manual_seed(0)).sample
|
sample = vae.decode(
|
||||||
|
latent, generator=torch.Generator("cpu").manual_seed(0)
|
||||||
|
).sample
|
||||||
|
|
||||||
actual_output = sample[0, :2, :2, :2].flatten().cpu()
|
actual_output = sample[0, :2, :2, :2].flatten().cpu()
|
||||||
expected_output = torch.tensor(
|
expected_output = torch.tensor(
|
||||||
[-0.0111, -0.0125, -0.0017, -0.0007, 0.1257, 0.1465, 0.1450, 0.1471], dtype=torch.float16
|
[-0.0111, -0.0125, -0.0017, -0.0007, 0.1257, 0.1465, 0.1450, 0.1471],
|
||||||
|
dtype=torch.float16,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert torch_all_close(actual_output, expected_output, atol=5e-3)
|
assert torch_all_close(actual_output, expected_output, atol=5e-3)
|
||||||
@@ -926,17 +1108,24 @@ class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
|
|||||||
"openai/consistency-decoder", torch_dtype=torch.float16
|
"openai/consistency-decoder", torch_dtype=torch.float16
|
||||||
) # TODO - update
|
) # TODO - update
|
||||||
pipe = StableDiffusionPipeline.from_pretrained(
|
pipe = StableDiffusionPipeline.from_pretrained(
|
||||||
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, vae=vae, safety_checker=None
|
"runwayml/stable-diffusion-v1-5",
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
vae=vae,
|
||||||
|
safety_checker=None,
|
||||||
)
|
)
|
||||||
pipe.to(torch_device)
|
pipe.to(torch_device)
|
||||||
|
|
||||||
out = pipe(
|
out = pipe(
|
||||||
"horse", num_inference_steps=2, output_type="pt", generator=torch.Generator("cpu").manual_seed(0)
|
"horse",
|
||||||
|
num_inference_steps=2,
|
||||||
|
output_type="pt",
|
||||||
|
generator=torch.Generator("cpu").manual_seed(0),
|
||||||
).images[0]
|
).images[0]
|
||||||
|
|
||||||
actual_output = out[:2, :2, :2].flatten().cpu()
|
actual_output = out[:2, :2, :2].flatten().cpu()
|
||||||
expected_output = torch.tensor(
|
expected_output = torch.tensor(
|
||||||
[0.0000, 0.0249, 0.0000, 0.0000, 0.1709, 0.2773, 0.0471, 0.1035], dtype=torch.float16
|
[0.0000, 0.0249, 0.0000, 0.0000, 0.1709, 0.2773, 0.0471, 0.1035],
|
||||||
|
dtype=torch.float16,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert torch_all_close(actual_output, expected_output, atol=5e-3)
|
assert torch_all_close(actual_output, expected_output, atol=5e-3)
|
||||||
|
|||||||
261
tests/pipelines/stable_diffusion/test_stable_diffusion_video.py
Normal file
261
tests/pipelines/stable_diffusion/test_stable_diffusion_video.py
Normal file
@@ -0,0 +1,261 @@
|
|||||||
|
import gc
|
||||||
|
import random
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from transformers import (
|
||||||
|
CLIPImageProcessor,
|
||||||
|
CLIPTextConfig,
|
||||||
|
CLIPTextModel,
|
||||||
|
CLIPTokenizer,
|
||||||
|
)
|
||||||
|
|
||||||
|
import diffusers
|
||||||
|
from diffusers import (
|
||||||
|
AutoencoderKLTemporalDecoder,
|
||||||
|
DDIMScheduler,
|
||||||
|
StableDiffusionVideoPipeline,
|
||||||
|
UNetSpatioTemporalConditionModel,
|
||||||
|
)
|
||||||
|
from diffusers.utils import load_image, logging
|
||||||
|
from diffusers.utils.testing_utils import (
|
||||||
|
floats_tensor,
|
||||||
|
numpy_cosine_similarity_distance,
|
||||||
|
require_torch_gpu,
|
||||||
|
slow,
|
||||||
|
torch_device,
|
||||||
|
)
|
||||||
|
|
||||||
|
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||||
|
from ..test_pipelines_common import PipelineTesterMixin
|
||||||
|
|
||||||
|
|
||||||
|
def to_np(tensor):
|
||||||
|
if isinstance(tensor, torch.Tensor):
|
||||||
|
tensor = tensor.detach().cpu().numpy()
|
||||||
|
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
|
class StableDiffusionVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||||
|
pipeline_class = StableDiffusionVideoPipeline
|
||||||
|
params = TEXT_TO_IMAGE_PARAMS
|
||||||
|
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||||
|
required_optional_params = frozenset(
|
||||||
|
[
|
||||||
|
"num_inference_steps",
|
||||||
|
"generator",
|
||||||
|
"latents",
|
||||||
|
"return_dict",
|
||||||
|
"callback",
|
||||||
|
"callback_steps",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_dummy_components(self):
|
||||||
|
torch.manual_seed(0)
|
||||||
|
unet = UNetSpatioTemporalConditionModel(
|
||||||
|
block_out_channels=(32, 64),
|
||||||
|
layers_per_block=2,
|
||||||
|
sample_size=32,
|
||||||
|
in_channels=4,
|
||||||
|
out_channels=4,
|
||||||
|
down_block_types=(
|
||||||
|
"CrossAttnDownBlockSpatioTemporal",
|
||||||
|
"DownBlockSpatioTemporal",
|
||||||
|
),
|
||||||
|
up_block_types=("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal"),
|
||||||
|
cross_attention_dim=32,
|
||||||
|
norm_num_groups=2,
|
||||||
|
)
|
||||||
|
scheduler = DDIMScheduler(
|
||||||
|
beta_start=0.00085,
|
||||||
|
beta_end=0.012,
|
||||||
|
beta_schedule="linear",
|
||||||
|
clip_sample=False,
|
||||||
|
)
|
||||||
|
torch.manual_seed(0)
|
||||||
|
vae = AutoencoderKLTemporalDecoder(
|
||||||
|
block_out_channels=[32, 64],
|
||||||
|
in_channels=3,
|
||||||
|
out_channels=3,
|
||||||
|
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||||
|
latent_channels=4,
|
||||||
|
)
|
||||||
|
torch.manual_seed(0)
|
||||||
|
text_encoder_config = CLIPTextConfig(
|
||||||
|
bos_token_id=0,
|
||||||
|
eos_token_id=2,
|
||||||
|
hidden_size=32,
|
||||||
|
intermediate_size=37,
|
||||||
|
layer_norm_eps=1e-05,
|
||||||
|
num_attention_heads=4,
|
||||||
|
num_hidden_layers=5,
|
||||||
|
pad_token_id=1,
|
||||||
|
vocab_size=1000,
|
||||||
|
)
|
||||||
|
text_encoder = CLIPTextModel(text_encoder_config)
|
||||||
|
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||||
|
feature_extractor = CLIPImageProcessor(crop_size=32, size=32)
|
||||||
|
|
||||||
|
components = {
|
||||||
|
"unet": unet,
|
||||||
|
"scheduler": scheduler,
|
||||||
|
"vae": vae,
|
||||||
|
"text_encoder": text_encoder,
|
||||||
|
"tokenizer": tokenizer,
|
||||||
|
"feature_extractor": feature_extractor,
|
||||||
|
}
|
||||||
|
return components
|
||||||
|
|
||||||
|
def get_dummy_inputs(self, device, seed=0):
|
||||||
|
if str(device).startswith("mps"):
|
||||||
|
generator = torch.manual_seed(seed)
|
||||||
|
else:
|
||||||
|
generator = torch.Generator(device=device).manual_seed(seed)
|
||||||
|
|
||||||
|
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
|
||||||
|
image = image / 2 + 0.5
|
||||||
|
|
||||||
|
inputs = {
|
||||||
|
"prompt": "A painting of a squirrel eating a burger",
|
||||||
|
"generator": generator,
|
||||||
|
"image": image,
|
||||||
|
"num_inference_steps": 2,
|
||||||
|
"guidance_scale": 7.5,
|
||||||
|
"output_type": "pt",
|
||||||
|
}
|
||||||
|
return inputs
|
||||||
|
|
||||||
|
@unittest.skip("Attention slicing is not enabled in this pipeline")
|
||||||
|
def test_attention_slicing_forward_pass(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_inference_batch_single_identical(
|
||||||
|
self,
|
||||||
|
batch_size=2,
|
||||||
|
expected_max_diff=1e-4,
|
||||||
|
additional_params_copy_to_batched_inputs=["num_inference_steps"],
|
||||||
|
):
|
||||||
|
components = self.get_dummy_components()
|
||||||
|
pipe = self.pipeline_class(**components)
|
||||||
|
for components in pipe.components.values():
|
||||||
|
if hasattr(components, "set_default_attn_processor"):
|
||||||
|
components.set_default_attn_processor()
|
||||||
|
|
||||||
|
pipe.to(torch_device)
|
||||||
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
inputs = self.get_dummy_inputs(torch_device)
|
||||||
|
# Reset generator in case it is has been used in self.get_dummy_inputs
|
||||||
|
inputs["generator"] = self.get_generator(0)
|
||||||
|
|
||||||
|
logger = logging.get_logger(pipe.__module__)
|
||||||
|
logger.setLevel(level=diffusers.logging.FATAL)
|
||||||
|
|
||||||
|
# batchify inputs
|
||||||
|
batched_inputs = {}
|
||||||
|
batched_inputs.update(inputs)
|
||||||
|
|
||||||
|
for name in self.batch_params:
|
||||||
|
if name not in inputs:
|
||||||
|
continue
|
||||||
|
|
||||||
|
value = inputs[name]
|
||||||
|
if name == "prompt":
|
||||||
|
len_prompt = len(value)
|
||||||
|
batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
|
||||||
|
batched_inputs[name][-1] = 100 * "very long"
|
||||||
|
|
||||||
|
else:
|
||||||
|
batched_inputs[name] = batch_size * [value]
|
||||||
|
|
||||||
|
if "generator" in inputs:
|
||||||
|
batched_inputs["generator"] = [self.get_generator(i) for i in range(batch_size)]
|
||||||
|
|
||||||
|
if "batch_size" in inputs:
|
||||||
|
batched_inputs["batch_size"] = batch_size
|
||||||
|
|
||||||
|
for arg in additional_params_copy_to_batched_inputs:
|
||||||
|
batched_inputs[arg] = inputs[arg]
|
||||||
|
|
||||||
|
output = pipe(**inputs)
|
||||||
|
output_batch = pipe(**batched_inputs)
|
||||||
|
|
||||||
|
assert output_batch[0].shape[0] == batch_size
|
||||||
|
|
||||||
|
max_diff = np.abs(to_np(output_batch[0][0]) - to_np(output[0][0])).max()
|
||||||
|
assert max_diff < expected_max_diff
|
||||||
|
|
||||||
|
@unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices")
|
||||||
|
def test_to_device(self):
|
||||||
|
components = self.get_dummy_components()
|
||||||
|
pipe = self.pipeline_class(**components)
|
||||||
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
|
pipe.to("cpu")
|
||||||
|
model_devices = [
|
||||||
|
component.device.type for component in pipe.components.values() if hasattr(component, "device")
|
||||||
|
]
|
||||||
|
self.assertTrue(all(device == "cpu" for device in model_devices))
|
||||||
|
|
||||||
|
output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]
|
||||||
|
self.assertTrue(np.isnan(output_cpu).sum() == 0)
|
||||||
|
|
||||||
|
pipe.to("cuda")
|
||||||
|
model_devices = [
|
||||||
|
component.device.type for component in pipe.components.values() if hasattr(component, "device")
|
||||||
|
]
|
||||||
|
self.assertTrue(all(device == "cuda" for device in model_devices))
|
||||||
|
|
||||||
|
output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0]
|
||||||
|
self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0)
|
||||||
|
|
||||||
|
def test_to_dtype(self):
|
||||||
|
components = self.get_dummy_components()
|
||||||
|
pipe = self.pipeline_class(**components)
|
||||||
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
|
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
|
||||||
|
self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
|
||||||
|
|
||||||
|
pipe.to(torch_dtype=torch.float16)
|
||||||
|
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
|
||||||
|
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
|
||||||
|
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_torch_gpu
|
||||||
|
class StableDiffusionVideoPipelineSlowTests(unittest.TestCase):
|
||||||
|
def tearDown(self):
|
||||||
|
# clean up the VRAM after each test
|
||||||
|
super().tearDown()
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
def test_sd_video(self):
|
||||||
|
pipe = StableDiffusionVideoPipeline.from_pretrained("diffusers/svd-test")
|
||||||
|
pipe = pipe.to(torch_device)
|
||||||
|
pipe.enable_model_cpu_offload()
|
||||||
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
image = load_image(
|
||||||
|
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat_6.png?download=true"
|
||||||
|
)
|
||||||
|
|
||||||
|
generator = torch.Generator("cpu").manual_seed(0)
|
||||||
|
num_frames = 3
|
||||||
|
|
||||||
|
output = pipe(
|
||||||
|
image=image,
|
||||||
|
num_frames=num_frames,
|
||||||
|
generator=generator,
|
||||||
|
num_inference_steps=3,
|
||||||
|
output_type="np",
|
||||||
|
)
|
||||||
|
|
||||||
|
image = output.frames[0]
|
||||||
|
assert image.shape == (num_frames, 576, 1024, 3)
|
||||||
|
|
||||||
|
image_slice = image[0, -3:, -3:, -1]
|
||||||
|
expected_slice = np.array([0.8592, 0.8645, 0.8499, 0.8722, 0.8769, 0.8421, 0.8557, 0.8528, 0.8285])
|
||||||
|
assert numpy_cosine_similarity_distance(image_slice.flatten(), expected_slice.flatten()) < 1e-3
|
||||||
Reference in New Issue
Block a user