Compare commits

...

114 Commits

Author SHA1 Message Date
Dhruv Nair
56e8fca572 Merge branch 'main' into test-v 2023-11-27 13:36:38 +00:00
Dhruv Nair
c5941a26a4 Merge branch 'test-v' of https://github.com/huggingface/diffusers into test-v 2023-11-27 13:35:36 +00:00
Dhruv Nair
8bc42512fe remove post quant conv 2023-11-27 13:27:46 +00:00
patil-suraj
55b4d09080 fix upcasting 2023-11-27 14:11:26 +01:00
patil-suraj
c452d9c042 up 2023-11-27 13:59:30 +01:00
patil-suraj
ee9f7d2493 make added_time_ids is tensor 2023-11-27 13:55:02 +01:00
Dhruv Nair
8620851aa0 update forward pass for gradient checkpointing 2023-11-27 12:50:58 +00:00
patil-suraj
90d8e832f8 upcast vae 2023-11-27 13:50:10 +01:00
patil-suraj
18930e0b85 doc 2023-11-27 13:40:30 +01:00
patil-suraj
847bd0a479 fix copies 2023-11-27 13:23:31 +01:00
Dhruv Nair
3178b16b17 update 2023-11-27 11:37:52 +00:00
patil-suraj
a08ef009d1 use math for log 2023-11-27 12:16:02 +01:00
patil-suraj
804bdebe51 Merge branch 'test-v' of https://github.com/huggingface/diffusers into test-v 2023-11-27 12:01:11 +01:00
patil-suraj
a193e49dff use c_noise values for timesteps 2023-11-27 12:01:08 +01:00
Dhruv Nair
c9d1727613 clean up 2023-11-27 11:00:02 +00:00
Dhruv Nair
82cf60828f Merge branch 'test-v' of https://github.com/huggingface/diffusers into test-v 2023-11-27 10:50:12 +00:00
Dhruv Nair
26ed460265 clean up 2023-11-27 10:49:58 +00:00
Dhruv Nair
403a81c30d clean up temp decoder 2023-11-27 10:21:22 +00:00
patil-suraj
1b3cf2db5e Merge branch 'test-v' of https://github.com/huggingface/diffusers into test-v 2023-11-27 11:13:20 +01:00
patil-suraj
b8d84c4320 fix norm eps in TransformerSpatioTemporalModel 2023-11-27 11:13:18 +01:00
Dhruv Nair
3fbe123d84 make temb optional in Decoder mid block 2023-11-27 10:09:41 +00:00
Dhruv Nair
f7cf8c338c clean up 2023-11-27 09:53:56 +00:00
Dhruv Nair
ab8076f234 Merge branch 'test-v' of https://github.com/huggingface/diffusers into test-v 2023-11-27 09:50:00 +00:00
Dhruv Nair
7b6a0d48c6 add slow svd test 2023-11-27 09:45:00 +00:00
patil-suraj
6adae54046 clean TransformerSpatioTemporalModel 2023-11-27 10:34:44 +01:00
patil-suraj
af85fb1bc1 clean up unet 2023-11-27 10:03:40 +01:00
Dhruv Nair
760333d524 add unet tests 2023-11-27 08:12:02 +00:00
patil-suraj
f651c12ef8 don't scale image latents 2023-11-26 17:13:04 +01:00
patil-suraj
d614a33a09 use AutoencoderKLTemporalDecoder 2023-11-26 17:00:22 +01:00
patil-suraj
13b646edd3 remove hack 2023-11-26 16:59:21 +01:00
patil-suraj
cb49cbdd29 add pipeline and vae in init 2023-11-26 16:58:59 +01:00
patil-suraj
1ce8ff51e6 accept fps as arg 2023-11-26 16:20:22 +01:00
patil-suraj
fdd182f335 allow passing PIL to export_video 2023-11-26 16:19:25 +01:00
patil-suraj
2a46326c25 up 2023-11-26 16:07:24 +01:00
patil-suraj
e34e9d9a33 take guidance scale as input 2023-11-26 16:06:44 +01:00
patil-suraj
96af28f92b style 2023-11-26 16:01:32 +01:00
patil-suraj
6827a1dc6a add vae conversion 2023-11-26 15:42:27 +01:00
patil-suraj
c3bdeb8a4c skip_post_quant_conv 2023-11-26 13:07:50 +01:00
patil-suraj
cf70b9a0b4 fix missing activation in TemporalDecoder 2023-11-26 13:06:44 +01:00
patil-suraj
712b9950c5 fix guidance_scales dtype 2023-11-26 12:47:51 +01:00
patil-suraj
21148de853 fix typo 2023-11-26 12:45:01 +01:00
patil-suraj
d930977656 fix attention in MidBlockTemporalDecoder 2023-11-26 12:01:14 +01:00
patil-suraj
268ffea0e7 cast alpha to sample dtype 2023-11-26 11:15:28 +01:00
patil-suraj
8bcf43d52a fix num frames during split decoding 2023-11-26 11:10:42 +01:00
patil-suraj
b071aaa719 switch spatial to temporal for mixing in VAE 2023-11-26 10:51:53 +01:00
patil-suraj
5316fb5107 pass num frames in decode 2023-11-25 19:15:19 +01:00
patil-suraj
9af07d1d5c fix default values in vae 2023-11-25 19:09:47 +01:00
patil-suraj
d0017d9b70 allow using differnt eps in temporal block for video decoder 2023-11-25 19:02:57 +01:00
patil-suraj
0cf6c6b291 type image_latents same as image_embeddings 2023-11-25 16:20:01 +01:00
patil-suraj
df986274d6 fix dtype in TransformerSpatioTemporalModel 2023-11-25 16:17:45 +01:00
patil-suraj
7ddd14bd94 vae encode/decode in fp32 2023-11-25 16:16:01 +01:00
patil-suraj
4346ddd402 fix decode_latents 2023-11-25 14:33:25 +01:00
patil-suraj
9da55b381c pass decoding_t to decode_latents 2023-11-25 14:30:27 +01:00
patil-suraj
4d4469ee87 decode n frames at a time 2023-11-25 14:30:09 +01:00
patil-suraj
f9954a0e7b decode in float32 2023-11-25 14:02:23 +01:00
patil-suraj
e7798333c4 fix frame decodig 2023-11-25 14:01:01 +01:00
patil-suraj
efb1e5e1d8 make pipeline run 2023-11-24 21:30:31 +01:00
Dhruv Nair
beaaf18b2c Merge branch 'test-v' of https://github.com/huggingface/diffusers into test-v 2023-11-24 16:36:06 +00:00
Dhruv Nair
132fe97bf4 add temporal autoencoder 2023-11-24 16:35:41 +00:00
patil-suraj
2f35e8c94c fix norm eps in temporal transformers 2023-11-24 15:40:41 +01:00
patil-suraj
b336529573 add guidance scalings 2023-11-24 14:16:50 +01:00
patil-suraj
3e47d3c8ed adapt scheduler 2023-11-24 14:06:07 +01:00
patil-suraj
122a6bd390 begin pipeline 2023-11-24 13:36:57 +01:00
Dhruv Nair
37c428a79c Merge branch 'test-v' of https://github.com/huggingface/diffusers into test-v 2023-11-24 12:24:57 +00:00
Dhruv Nair
eefed8ab6b update up/mid blocks for decoder 2023-11-24 12:23:14 +00:00
Dhruv Nair
05eaec2d39 Merge branch 'test-v-old' into test-v 2023-11-24 12:19:29 +00:00
Dhruv Nair
e68424378f update vae 2023-11-24 12:19:11 +00:00
patil-suraj
24b5c4360c check for None 2023-11-24 11:53:50 +01:00
patil-suraj
0c4192b537 up 2023-11-24 11:51:40 +01:00
patil-suraj
dff26ce8af up 2023-11-24 11:50:02 +01:00
patil-suraj
9f22651c1f remove more unsed args 2023-11-24 11:48:58 +01:00
patil-suraj
d8c9e67aac remove unused arg 2023-11-24 11:38:34 +01:00
patil-suraj
6c28367b1a remove unused arg 2023-11-24 11:36:01 +01:00
patil-suraj
f9def2aeed add in init 2023-11-24 11:31:30 +01:00
patil-suraj
576fa1c7dc remove UNetMidBlockSpatioTemporal 2023-11-24 11:30:35 +01:00
patil-suraj
f1457b7e1d update conversion script 2023-11-24 11:24:42 +01:00
patil-suraj
1f34311eec rename model 2023-11-24 11:24:34 +01:00
patil-suraj
f976f5a31e Merge branch 'test-v' of https://github.com/huggingface/diffusers into test-v 2023-11-24 11:17:55 +01:00
patil-suraj
8e1851a16a Merge branch 'test-v' of https://github.com/huggingface/diffusers into test-v 2023-11-24 11:17:51 +01:00
patil-suraj
6c69c7a0d2 add blocks 2023-11-24 11:11:15 +01:00
Dhruv Nair
6481e9495f make temb optional 2023-11-24 10:10:09 +00:00
Dhruv Nair
8c3fd58c85 Merge branch 'test-v' of https://github.com/huggingface/diffusers into test-v 2023-11-24 09:51:43 +00:00
Dhruv Nair
9117547ee0 clean up 2023-11-24 09:51:29 +00:00
patil-suraj
af1e86af8d fix time_context dim 2023-11-24 10:47:44 +01:00
patil-suraj
29551f8e30 fix TransformerSpatioTemporalModel 2023-11-24 10:19:44 +01:00
patil-suraj
661033171b use TransformerSpatioTemporalModel 2023-11-24 10:16:22 +01:00
patil-suraj
20efe541c5 fix TemporalBasicTransformerBlock 2023-11-24 10:11:40 +01:00
patil-suraj
5a523e21c6 reuse TemporalBasicTransformerBlock 2023-11-24 10:04:22 +01:00
patil-suraj
b0fc4fd4cb fix SpatioTemporalResBlock 2023-11-24 10:01:09 +01:00
patil-suraj
678d19fa18 fix temb shape 2023-11-24 09:41:15 +01:00
patil-suraj
c8ec445964 style 2023-11-24 09:34:53 +01:00
patil-suraj
ffd9e26a65 use new blocks 2023-11-24 09:26:42 +01:00
patil-suraj
6f87490408 fix shapes in Alphablender and add time activation in res blcok 2023-11-24 08:57:28 +01:00
Dhruv Nair
9c9d46763b update 2023-11-24 07:12:50 +00:00
Dhruv Nair
47684dab43 update 2023-11-24 04:14:58 +00:00
Dhruv Nair
5218f46173 fix blocks 2023-11-23 14:32:18 +00:00
Dhruv Nair
8ee280773f add vae blocks 2023-11-23 14:28:07 +00:00
Dhruv Nair
85846f7450 add spatio temporal transformers 2023-11-23 13:02:34 +00:00
patil-suraj
28dee6e735 fix temb shape in TemporalResnetBlock 2023-11-23 13:52:48 +01:00
patil-suraj
165ed7c5d5 return sample in original shape 2023-11-23 13:52:40 +01:00
patil-suraj
d4cdfa33f5 make forward work 2023-11-23 13:35:52 +01:00
Dhruv Nair
1bd09b1489 Merge branch 'test-v' of https://github.com/huggingface/diffusers into test-v 2023-11-23 10:54:08 +00:00
Dhruv Nair
edf7121ec7 add new resnet blocks 2023-11-23 10:53:25 +00:00
patil-suraj
7b64d3a17b up 2023-11-23 10:48:59 +01:00
patil-suraj
c93606c93c fix model 2023-11-23 10:47:57 +01:00
patil-suraj
5df09ef355 add conversion script 2023-11-22 19:15:18 +01:00
patil-suraj
ac9473153c fix add_embedding 2023-11-22 19:04:10 +01:00
patil-suraj
ee9d7b8ecd fix time_pos_embed 2023-11-22 18:59:44 +01:00
patil-suraj
669824e5bb fix temporal res block 2023-11-22 17:44:56 +01:00
patil-suraj
45c9b56bf7 use TimestepEmbedding 2023-11-22 15:56:09 +01:00
patil-suraj
cad51d45d1 addition_time_embed_dim 2023-11-22 14:26:43 +01:00
patil-suraj
7de5d7c6fd add_embedding 2023-11-22 14:06:50 +01:00
patil-suraj
58883ee085 finish blocks 2023-11-22 13:42:10 +01:00
patil-suraj
2f5648177e begin model 2023-11-21 16:39:15 +01:00
28 changed files with 6011 additions and 149 deletions

View File

@@ -0,0 +1,730 @@
from diffusers.utils import is_accelerate_available, logging
if is_accelerate_available():
pass
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def create_unet_diffusers_config(original_config, image_size: int, controlnet=False):
"""
Creates a config for the diffusers based on the config of the LDM model.
"""
if controlnet:
unet_params = original_config.model.params.control_stage_config.params
else:
if "unet_config" in original_config.model.params and original_config.model.params.unet_config is not None:
unet_params = original_config.model.params.unet_config.params
else:
unet_params = original_config.model.params.network_config.params
vae_params = original_config.model.params.first_stage_config.params.encoder_config.params
block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
down_block_types = []
resolution = 1
for i in range(len(block_out_channels)):
block_type = (
"CrossAttnDownBlockSpatioTemporal"
if resolution in unet_params.attention_resolutions
else "DownBlockSpatioTemporal"
)
down_block_types.append(block_type)
if i != len(block_out_channels) - 1:
resolution *= 2
up_block_types = []
for i in range(len(block_out_channels)):
block_type = (
"CrossAttnUpBlockSpatioTemporal"
if resolution in unet_params.attention_resolutions
else "UpBlockSpatioTemporal"
)
up_block_types.append(block_type)
resolution //= 2
if unet_params.transformer_depth is not None:
transformer_layers_per_block = (
unet_params.transformer_depth
if isinstance(unet_params.transformer_depth, int)
else list(unet_params.transformer_depth)
)
else:
transformer_layers_per_block = 1
vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
head_dim = unet_params.num_heads if "num_heads" in unet_params else None
use_linear_projection = (
unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
)
if use_linear_projection:
# stable diffusion 2-base-512 and 2-768
if head_dim is None:
head_dim_mult = unet_params.model_channels // unet_params.num_head_channels
head_dim = [head_dim_mult * c for c in list(unet_params.channel_mult)]
class_embed_type = None
addition_embed_type = None
addition_time_embed_dim = None
projection_class_embeddings_input_dim = None
context_dim = None
if unet_params.context_dim is not None:
context_dim = (
unet_params.context_dim if isinstance(unet_params.context_dim, int) else unet_params.context_dim[0]
)
if "num_classes" in unet_params:
if unet_params.num_classes == "sequential":
addition_time_embed_dim = 256
assert "adm_in_channels" in unet_params
projection_class_embeddings_input_dim = unet_params.adm_in_channels
config = {
"sample_size": image_size // vae_scale_factor,
"in_channels": unet_params.in_channels,
"down_block_types": tuple(down_block_types),
"block_out_channels": tuple(block_out_channels),
"layers_per_block": unet_params.num_res_blocks,
"cross_attention_dim": context_dim,
"attention_head_dim": head_dim,
"use_linear_projection": use_linear_projection,
"class_embed_type": class_embed_type,
"addition_embed_type": addition_embed_type,
"addition_time_embed_dim": addition_time_embed_dim,
"projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
"transformer_layers_per_block": transformer_layers_per_block,
}
if "disable_self_attentions" in unet_params:
config["only_cross_attention"] = unet_params.disable_self_attentions
if "num_classes" in unet_params and isinstance(unet_params.num_classes, int):
config["num_class_embeds"] = unet_params.num_classes
if controlnet:
config["conditioning_channels"] = unet_params.hint_channels
else:
config["out_channels"] = unet_params.out_channels
config["up_block_types"] = tuple(up_block_types)
return config
def assign_to_checkpoint(
paths,
checkpoint,
old_checkpoint,
attention_paths_to_split=None,
additional_replacements=None,
config=None,
mid_block_suffix="",
):
"""
This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
attention layers, and takes into account additional replacements that may arise.
Assigns the weights to the new checkpoint.
"""
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
# Splits the attention layers into three variables.
if attention_paths_to_split is not None:
for path, path_map in attention_paths_to_split.items():
old_tensor = old_checkpoint[path]
channels = old_tensor.shape[0] // 3
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
query, key, value = old_tensor.split(channels // num_heads, dim=1)
checkpoint[path_map["query"]] = query.reshape(target_shape)
checkpoint[path_map["key"]] = key.reshape(target_shape)
checkpoint[path_map["value"]] = value.reshape(target_shape)
if mid_block_suffix is not None:
mid_block_suffix = f".{mid_block_suffix}"
else:
mid_block_suffix = ""
for path in paths:
new_path = path["new"]
# These have already been assigned
if attention_paths_to_split is not None and new_path in attention_paths_to_split:
continue
# Global renaming happens here
new_path = new_path.replace("middle_block.0", f"mid_block.resnets.0{mid_block_suffix}")
new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
new_path = new_path.replace("middle_block.2", f"mid_block.resnets.1{mid_block_suffix}")
if additional_replacements is not None:
for replacement in additional_replacements:
new_path = new_path.replace(replacement["old"], replacement["new"])
if new_path == "mid_block.resnets.0.spatial_res_block.norm1.weight":
print("yeyy")
# proj_attn.weight has to be converted from conv 1D to linear
is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path)
shape = old_checkpoint[path["old"]].shape
if is_attn_weight and len(shape) == 3:
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
elif is_attn_weight and len(shape) == 4:
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0]
else:
checkpoint[new_path] = old_checkpoint[path["old"]]
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
"""
Updates paths inside attentions to the new naming scheme (local renaming)
"""
mapping = []
for old_item in old_list:
new_item = old_item
# new_item = new_item.replace('norm.weight', 'group_norm.weight')
# new_item = new_item.replace('norm.bias', 'group_norm.bias')
# new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
# new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
# new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
new_item = new_item.replace("time_stack", "temporal_transformer_blocks")
new_item = new_item.replace("time_pos_embed.0.bias", "time_pos_embed.linear_1.bias")
new_item = new_item.replace("time_pos_embed.0.weight", "time_pos_embed.linear_1.weight")
new_item = new_item.replace("time_pos_embed.2.bias", "time_pos_embed.linear_2.bias")
new_item = new_item.replace("time_pos_embed.2.weight", "time_pos_embed.linear_2.weight")
mapping.append({"old": old_item, "new": new_item})
return mapping
def shave_segments(path, n_shave_prefix_segments=1):
"""
Removes segments. Positive values shave the first segments, negative shave the last segments.
"""
if n_shave_prefix_segments >= 0:
return ".".join(path.split(".")[n_shave_prefix_segments:])
else:
return ".".join(path.split(".")[:n_shave_prefix_segments])
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
"""
Updates paths inside resnets to the new naming scheme (local renaming)
"""
mapping = []
for old_item in old_list:
new_item = old_item.replace("in_layers.0", "norm1")
new_item = new_item.replace("in_layers.2", "conv1")
new_item = new_item.replace("out_layers.0", "norm2")
new_item = new_item.replace("out_layers.3", "conv2")
new_item = new_item.replace("emb_layers.1", "time_emb_proj")
new_item = new_item.replace("skip_connection", "conv_shortcut")
new_item = new_item.replace("time_stack.", "")
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
mapping.append({"old": old_item, "new": new_item})
return mapping
def convert_ldm_unet_checkpoint(
checkpoint, config, path=None, extract_ema=False, controlnet=False, skip_extract_state_dict=False
):
"""
Takes a state dict and a config, and returns a converted checkpoint.
"""
if skip_extract_state_dict:
unet_state_dict = checkpoint
else:
# extract state_dict for UNet
unet_state_dict = {}
keys = list(checkpoint.keys())
unet_key = "model.diffusion_model."
# at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
logger.warning(f"Checkpoint {path} has both EMA and non-EMA weights.")
logger.warning(
"In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
" weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
)
for key in keys:
if key.startswith("model.diffusion_model"):
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
else:
if sum(k.startswith("model_ema") for k in keys) > 100:
logger.warning(
"In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
" weights (usually better for inference), please make sure to add the `--extract_ema` flag."
)
for key in keys:
if key.startswith(unet_key):
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
new_checkpoint = {}
new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
if config["class_embed_type"] is None:
# No parameters to port
...
elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
else:
raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")
# if config["addition_embed_type"] == "text_time":
new_checkpoint["add_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
new_checkpoint["add_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
# Retrieves the keys for the input blocks only
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
input_blocks = {
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
for layer_id in range(num_input_blocks)
}
# Retrieves the keys for the middle blocks only
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
middle_blocks = {
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
for layer_id in range(num_middle_blocks)
}
# Retrieves the keys for the output blocks only
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
output_blocks = {
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
for layer_id in range(num_output_blocks)
}
for i in range(1, num_input_blocks):
block_id = (i - 1) // (config["layers_per_block"] + 1)
layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
spatial_resnets = [
key
for key in input_blocks[i]
if f"input_blocks.{i}.0" in key
and (
f"input_blocks.{i}.0.op" not in key
and f"input_blocks.{i}.0.time_stack" not in key
and f"input_blocks.{i}.0.time_mixer" not in key
)
]
temporal_resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0.time_stack" in key]
# import ipdb; ipdb.set_trace()
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
f"input_blocks.{i}.0.op.weight"
)
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
f"input_blocks.{i}.0.op.bias"
)
paths = renew_resnet_paths(spatial_resnets)
meta_path = {
"old": f"input_blocks.{i}.0",
"new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}.spatial_res_block",
}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
paths = renew_resnet_paths(temporal_resnets)
meta_path = {
"old": f"input_blocks.{i}.0",
"new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}.temporal_res_block",
}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
# TODO resnet time_mixer.mix_factor
if f"input_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict:
new_checkpoint[
f"down_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"
] = unet_state_dict[f"input_blocks.{i}.0.time_mixer.mix_factor"]
if len(attentions):
paths = renew_attention_paths(attentions)
meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
# import ipdb; ipdb.set_trace()
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
resnet_0 = middle_blocks[0]
attentions = middle_blocks[1]
resnet_1 = middle_blocks[2]
resnet_0_spatial = [key for key in resnet_0 if "time_stack" not in key and "time_mixer" not in key]
resnet_0_paths = renew_resnet_paths(resnet_0_spatial)
# import ipdb; ipdb.set_trace()
assign_to_checkpoint(
resnet_0_paths, new_checkpoint, unet_state_dict, config=config, mid_block_suffix="spatial_res_block"
)
resnet_0_temporal = [key for key in resnet_0 if "time_stack" in key and "time_mixer" not in key]
resnet_0_paths = renew_resnet_paths(resnet_0_temporal)
assign_to_checkpoint(
resnet_0_paths, new_checkpoint, unet_state_dict, config=config, mid_block_suffix="temporal_res_block"
)
resnet_1_spatial = [key for key in resnet_1 if "time_stack" not in key and "time_mixer" not in key]
resnet_1_paths = renew_resnet_paths(resnet_1_spatial)
assign_to_checkpoint(
resnet_1_paths, new_checkpoint, unet_state_dict, config=config, mid_block_suffix="spatial_res_block"
)
resnet_1_temporal = [key for key in resnet_1 if "time_stack" in key and "time_mixer" not in key]
resnet_1_paths = renew_resnet_paths(resnet_1_temporal)
assign_to_checkpoint(
resnet_1_paths, new_checkpoint, unet_state_dict, config=config, mid_block_suffix="temporal_res_block"
)
new_checkpoint["mid_block.resnets.0.time_mixer.mix_factor"] = unet_state_dict[
"middle_block.0.time_mixer.mix_factor"
]
new_checkpoint["mid_block.resnets.1.time_mixer.mix_factor"] = unet_state_dict[
"middle_block.2.time_mixer.mix_factor"
]
attentions_paths = renew_attention_paths(attentions)
meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
assign_to_checkpoint(
attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
for i in range(num_output_blocks):
block_id = i // (config["layers_per_block"] + 1)
layer_in_block_id = i % (config["layers_per_block"] + 1)
output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
output_block_list = {}
for layer in output_block_layers:
layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
if layer_id in output_block_list:
output_block_list[layer_id].append(layer_name)
else:
output_block_list[layer_id] = [layer_name]
if len(output_block_list) > 1:
spatial_resnets = [
key
for key in output_blocks[i]
if f"output_blocks.{i}.0" in key
and (f"output_blocks.{i}.0.time_stack" not in key and "time_mixer" not in key)
]
# import ipdb; ipdb.set_trace()
temporal_resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0.time_stack" in key]
paths = renew_resnet_paths(spatial_resnets)
meta_path = {
"old": f"output_blocks.{i}.0",
"new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}.spatial_res_block",
}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
paths = renew_resnet_paths(temporal_resnets)
meta_path = {
"old": f"output_blocks.{i}.0",
"new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}.temporal_res_block",
}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
if f"output_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict:
new_checkpoint[
f"up_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"
] = unet_state_dict[f"output_blocks.{i}.0.time_mixer.mix_factor"]
output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
if ["conv.bias", "conv.weight"] in output_block_list.values():
index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
f"output_blocks.{i}.{index}.conv.weight"
]
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
f"output_blocks.{i}.{index}.conv.bias"
]
# Clear attentions as they have been attributed above.
if len(attentions) == 2:
attentions = []
attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key and "conv" not in key]
if len(attentions):
paths = renew_attention_paths(attentions)
# import ipdb; ipdb.set_trace()
meta_path = {
"old": f"output_blocks.{i}.1",
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
else:
spatial_layers = [
layer for layer in output_block_layers if "time_stack" not in layer and "time_mixer" not in layer
]
resnet_0_paths = renew_resnet_paths(spatial_layers, n_shave_prefix_segments=1)
# import ipdb; ipdb.set_trace()
for path in resnet_0_paths:
old_path = ".".join(["output_blocks", str(i), path["old"]])
new_path = ".".join(
["up_blocks", str(block_id), "resnets", str(layer_in_block_id), "spatial_res_block", path["new"]]
)
new_checkpoint[new_path] = unet_state_dict[old_path]
temporal_layers = [
layer for layer in output_block_layers if "time_stack" in layer and "time_mixer" not in key
]
resnet_0_paths = renew_resnet_paths(temporal_layers, n_shave_prefix_segments=1)
# import ipdb; ipdb.set_trace()
for path in resnet_0_paths:
old_path = ".".join(["output_blocks", str(i), path["old"]])
new_path = ".".join(
["up_blocks", str(block_id), "resnets", str(layer_in_block_id), "temporal_res_block", path["new"]]
)
new_checkpoint[new_path] = unet_state_dict[old_path]
new_checkpoint["up_blocks.0.resnets.0.time_mixer.mix_factor"] = unet_state_dict[
f"output_blocks.{str(i)}.0.time_mixer.mix_factor"
]
return new_checkpoint
def conv_attn_to_linear(checkpoint):
keys = list(checkpoint.keys())
attn_keys = ["to_q.weight", "to_k.weight", "to_v.weight"]
for key in keys:
if ".".join(key.split(".")[-2:]) in attn_keys:
if checkpoint[key].ndim > 2:
checkpoint[key] = checkpoint[key][:, :, 0, 0]
elif "proj_attn.weight" in key:
if checkpoint[key].ndim > 2:
checkpoint[key] = checkpoint[key][:, :, 0]
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0, is_temporal=False):
"""
Updates paths inside resnets to the new naming scheme (local renaming)
"""
mapping = []
for old_item in old_list:
new_item = old_item
# Temporal resnet
new_item = old_item.replace("in_layers.0", "norm1")
new_item = new_item.replace("in_layers.2", "conv1")
new_item = new_item.replace("out_layers.0", "norm2")
new_item = new_item.replace("out_layers.3", "conv2")
new_item = new_item.replace("skip_connection", "conv_shortcut")
new_item = new_item.replace("time_stack.", "temporal_res_block.")
# Spatial resnet
new_item = new_item.replace("conv1", "spatial_res_block.conv1")
new_item = new_item.replace("norm1", "spatial_res_block.norm1")
new_item = new_item.replace("conv2", "spatial_res_block.conv2")
new_item = new_item.replace("norm2", "spatial_res_block.norm2")
new_item = new_item.replace("nin_shortcut", "spatial_res_block.conv_shortcut")
new_item = new_item.replace("mix_factor", "spatial_res_block.time_mixer.mix_factor")
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
mapping.append({"old": old_item, "new": new_item})
return mapping
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
"""
Updates paths inside attentions to the new naming scheme (local renaming)
"""
mapping = []
for old_item in old_list:
new_item = old_item
new_item = new_item.replace("norm.weight", "group_norm.weight")
new_item = new_item.replace("norm.bias", "group_norm.bias")
new_item = new_item.replace("q.weight", "to_q.weight")
new_item = new_item.replace("q.bias", "to_q.bias")
new_item = new_item.replace("k.weight", "to_k.weight")
new_item = new_item.replace("k.bias", "to_k.bias")
new_item = new_item.replace("v.weight", "to_v.weight")
new_item = new_item.replace("v.bias", "to_v.bias")
new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
mapping.append({"old": old_item, "new": new_item})
return mapping
def convert_ldm_vae_checkpoint(checkpoint, config):
# extract state dict for VAE
vae_state_dict = {}
keys = list(checkpoint.keys())
vae_key = "first_stage_model." if any(k.startswith("first_stage_model.") for k in keys) else ""
for key in keys:
if key.startswith(vae_key):
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
new_checkpoint = {}
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
new_checkpoint["decoder.time_conv_out.weight"] = vae_state_dict["decoder.time_mix_conv.weight"]
new_checkpoint["decoder.time_conv_out.bias"] = vae_state_dict["decoder.time_mix_conv.bias"]
# new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
# new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
# new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
# new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
# Retrieves the keys for the encoder down blocks only
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
down_blocks = {
layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
}
# Retrieves the keys for the decoder up blocks only
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
up_blocks = {
layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
}
for i in range(num_down_blocks):
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
f"encoder.down.{i}.downsample.conv.weight"
)
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
f"encoder.down.{i}.downsample.conv.bias"
)
paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
num_mid_res_blocks = 2
for i in range(1, num_mid_res_blocks + 1):
resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
paths = renew_vae_attention_paths(mid_attentions)
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
conv_attn_to_linear(new_checkpoint)
for i in range(num_up_blocks):
block_id = num_up_blocks - 1 - i
resnets = [
key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
]
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
f"decoder.up.{block_id}.upsample.conv.weight"
]
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
f"decoder.up.{block_id}.upsample.conv.bias"
]
paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
num_mid_res_blocks = 2
for i in range(1, num_mid_res_blocks + 1):
resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
paths = renew_vae_attention_paths(mid_attentions)
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
conv_attn_to_linear(new_checkpoint)
return new_checkpoint

View File

@@ -76,6 +76,7 @@ else:
[
"AsymmetricAutoencoderKL",
"AutoencoderKL",
"AutoencoderKLTemporalDecoder",
"AutoencoderTiny",
"ConsistencyDecoderVAE",
"ControlNetModel",
@@ -92,6 +93,7 @@ else:
"UNet2DModel",
"UNet3DConditionModel",
"UNetMotionModel",
"UNetSpatioTemporalConditionModel",
"VQModel",
]
)
@@ -267,6 +269,7 @@ else:
"StableDiffusionPix2PixZeroPipeline",
"StableDiffusionSAGPipeline",
"StableDiffusionUpscalePipeline",
"StableDiffusionVideoPipeline",
"StableDiffusionXLAdapterPipeline",
"StableDiffusionXLControlNetImg2ImgPipeline",
"StableDiffusionXLControlNetInpaintPipeline",
@@ -446,6 +449,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .models import (
AsymmetricAutoencoderKL,
AutoencoderKL,
AutoencoderKLTemporalDecoder,
AutoencoderTiny,
ConsistencyDecoderVAE,
ControlNetModel,
@@ -462,6 +466,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
UNet2DModel,
UNet3DConditionModel,
UNetMotionModel,
UNetSpatioTemporalConditionModel,
VQModel,
)
from .optimization import (
@@ -616,6 +621,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionPix2PixZeroPipeline,
StableDiffusionSAGPipeline,
StableDiffusionUpscalePipeline,
StableDiffusionVideoPipeline,
StableDiffusionXLAdapterPipeline,
StableDiffusionXLControlNetImg2ImgPipeline,
StableDiffusionXLControlNetInpaintPipeline,

View File

@@ -14,7 +14,12 @@
from typing import TYPE_CHECKING
from ..utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, is_flax_available, is_torch_available
from ..utils import (
DIFFUSERS_SLOW_IMPORT,
_LazyModule,
is_flax_available,
is_torch_available,
)
_import_structure = {}
@@ -23,6 +28,7 @@ if is_torch_available():
_import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
_import_structure["autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
_import_structure["autoencoder_kl"] = ["AutoencoderKL"]
_import_structure["autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
_import_structure["autoencoder_tiny"] = ["AutoencoderTiny"]
_import_structure["consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
_import_structure["controlnet"] = ["ControlNetModel"]
@@ -38,6 +44,7 @@ if is_torch_available():
_import_structure["unet_3d_condition"] = ["UNet3DConditionModel"]
_import_structure["unet_kandi3"] = ["Kandinsky3UNet"]
_import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
_import_structure["unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"]
_import_structure["vq_model"] = ["VQModel"]
if is_flax_available():
@@ -51,6 +58,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .adapter import MultiAdapter, T2IAdapter
from .autoencoder_asym_kl import AsymmetricAutoencoderKL
from .autoencoder_kl import AutoencoderKL
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
from .autoencoder_tiny import AutoencoderTiny
from .consistency_decoder_vae import ConsistencyDecoderVAE
from .controlnet import ControlNetModel
@@ -66,6 +74,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .unet_3d_condition import UNet3DConditionModel
from .unet_kandi3 import Kandinsky3UNet
from .unet_motion_model import MotionAdapter, UNetMotionModel
from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
from .vq_model import VQModel
if is_flax_available():

View File

@@ -194,7 +194,12 @@ class BasicTransformerBlock(nn.Module):
if not self.use_ada_layer_norm_single:
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
self.ff = FeedForward(
dim,
dropout=dropout,
activation_fn=activation_fn,
final_dropout=final_dropout,
)
# 4. Fuser
if attention_type == "gated" or attention_type == "gated-text-image":
@@ -339,6 +344,181 @@ class BasicTransformerBlock(nn.Module):
return hidden_states
@maybe_allow_in_graph
class TemporalBasicTransformerBlock(nn.Module):
r"""
A basic Transformer block for video like data.
Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
attention_bias (:
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
upcast_attention (`bool`, *optional*):
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
Whether to use learnable elementwise affine parameters for normalization.
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
final_dropout (`bool` *optional*, defaults to False):
Whether to apply a final dropout after the last feed-forward layer.
attention_type (`str`, *optional*, defaults to `"default"`):
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
"""
def __init__(
self,
dim: int,
time_mix_inner_dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
norm_eps: float = 1e-5,
final_dropout: bool = False,
):
super().__init__()
self.is_res = dim == time_mix_inner_dim
self.norm_in = nn.LayerNorm(dim, eps=norm_eps)
# Define 3 blocks. Each block has its own normalization layer.
# 1. Self-Attn
self.norm_in = nn.LayerNorm(dim, eps=norm_eps)
self.ff_in = FeedForward(
dim,
dim_out=time_mix_inner_dim,
dropout=dropout,
activation_fn="geglu",
final_dropout=final_dropout,
)
self.norm1 = nn.LayerNorm(time_mix_inner_dim, eps=norm_eps)
self.attn1 = Attention(
query_dim=time_mix_inner_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
cross_attention_dim=None,
)
# 2. Cross-Attn
if cross_attention_dim is not None:
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
# the second cross attention block.
self.norm2 = nn.LayerNorm(time_mix_inner_dim, eps=norm_eps)
self.attn2 = Attention(
query_dim=time_mix_inner_dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
) # is self-attn if encoder_hidden_states is none
else:
self.norm2 = None
self.attn2 = None
# 3. Feed-forward
self.norm3 = nn.LayerNorm(time_mix_inner_dim, eps=norm_eps)
self.ff = FeedForward(
time_mix_inner_dim,
dropout=dropout,
activation_fn="geglu",
final_dropout=final_dropout,
)
# let chunk size default to None
self._chunk_size = None
self._chunk_dim = 0
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
# Sets chunk feed-forward
self._chunk_size = chunk_size
self._chunk_dim = dim
def forward(
self,
hidden_states: torch.FloatTensor,
num_frames: int,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
) -> torch.FloatTensor:
# Notice that normalization is always applied before the real computation in the following blocks.
# 0. Self-Attention
batch_size = hidden_states.shape[0]
batch_frames, seq_length, channels = hidden_states.shape
batch_size = batch_frames // num_frames
hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels)
hidden_states = hidden_states.permute(0, 2, 1, 3)
hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels)
residual = hidden_states
hidden_states = self.norm_in(hidden_states)
hidden_states = self.ff_in(hidden_states)
if self.is_res:
hidden_states = hidden_states + residual
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
norm_hidden_states = self.norm1(hidden_states)
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=None,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
hidden_states = attn_output + hidden_states
# 3. Cross-Attention
if self.attn2 is not None:
norm_hidden_states = self.norm2(hidden_states)
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
**cross_attention_kwargs,
)
hidden_states = attn_output + hidden_states
# 4. Feed-forward
norm_hidden_states = self.norm3(hidden_states)
if self._chunk_size is not None:
# "feed_forward_chunk_size" can be used to save memory
if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
raise ValueError(
f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
)
num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
ff_output = torch.cat(
[self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
dim=self._chunk_dim,
)
else:
ff_output = self.ff(norm_hidden_states)
if self.is_res:
hidden_states = ff_output + hidden_states
else:
hidden_states = ff_output
hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels)
hidden_states = hidden_states.permute(0, 2, 1, 3)
hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels)
return hidden_states
class FeedForward(nn.Module):
r"""
A feed-forward layer.

View File

@@ -0,0 +1,672 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Dict, Iterable, Optional, Tuple, Union
import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import FromOriginalVAEMixin
from ..utils import BaseOutput, is_torch_version
from ..utils.accelerate_utils import apply_forward_hook
from .attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
)
from .modeling_utils import ModelMixin
from .unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder
class TemporalDecoder(nn.Module):
def __init__(
self,
in_channels: int = 4,
out_channels: int = 3,
block_out_channels: Tuple[int, ...] = (
128,
256,
512,
512,
),
layers_per_block: int = 2,
norm_num_groups: int = 32,
act_fn: str = "silu",
norm_type: str = "group", # group, spatial
alpha: float = 0.0,
merge_strategy: str = "learned",
conv_out_kernel_size=(3, 1, 1),
):
super().__init__()
self.layers_per_block = layers_per_block
self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)
temb_channels = in_channels if norm_type == "spatial" else None
self.mid_block = MidBlockTemporalDecoder(
num_layers=self.layers_per_block,
in_channels=block_out_channels[-1],
out_channels=block_out_channels[-1],
attention_head_dim=block_out_channels[-1],
resnet_eps=1e-6,
temporal_resnet_eps=1e-5,
resnet_act_fn=act_fn,
norm_num_groups=norm_num_groups,
temb_channels=temb_channels,
resnet_time_scale_shift=norm_type,
merge_factor=alpha,
merge_strategy=merge_strategy,
)
# up
self.up_blocks = nn.ModuleList([])
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i in range(len(block_out_channels)):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
up_block = UpBlockTemporalDecoder(
num_layers=self.layers_per_block + 1,
in_channels=prev_output_channel,
out_channels=output_channel,
add_upsample=not is_final_block,
resnet_eps=1e-6,
temporal_resnet_eps=1e-5,
resnet_act_fn=act_fn,
norm_num_groups=norm_num_groups,
attention_head_dim=output_channel,
temb_channels=temb_channels,
resnet_time_scale_shift=norm_type,
merge_factor=alpha,
merge_strategy=merge_strategy,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
if isinstance(conv_out_kernel_size, Iterable):
padding = [int(k // 2) for k in conv_out_kernel_size]
else:
padding = int(conv_out_kernel_size // 2)
self.conv_act = nn.SiLU()
self.conv_out = torch.nn.Conv2d(
in_channels=block_out_channels[0],
out_channels=out_channels,
kernel_size=3,
padding=1,
)
self.time_conv_out = torch.nn.Conv3d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=conv_out_kernel_size,
padding=padding,
)
self.gradient_checkpointing = False
def forward(
self,
sample: torch.FloatTensor,
image_only_indicator: torch.FloatTensor,
num_frames: int = 1,
latent_embeds: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
r"""The forward method of the `Decoder` class."""
sample = self.conv_in(sample)
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if is_torch_version(">=", "1.11.0"):
# middle
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block),
sample,
image_only_indicator,
latent_embeds,
num_frames,
use_reentrant=False,
)
sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(up_block),
sample,
image_only_indicator,
latent_embeds,
num_frames,
use_reentrant=False,
)
else:
# middle
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block),
sample,
image_only_indicator,
latent_embeds,
num_frames,
)
sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(up_block),
sample,
image_only_indicator,
latent_embeds,
num_frames,
)
else:
# middle
sample = self.mid_block(
sample,
temb=latent_embeds,
num_frames=num_frames,
image_only_indicator=image_only_indicator,
)
sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
sample = up_block(
sample,
temb=latent_embeds,
num_frames=num_frames,
image_only_indicator=image_only_indicator,
)
# post-process
if latent_embeds is None:
sample = self.conv_norm_out(sample)
else:
sample = self.conv_norm_out(sample, latent_embeds)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
batch_frames, channels, height, width = sample.shape
batch_size = batch_frames // num_frames
sample = sample[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
sample = self.time_conv_out(sample)
sample = sample.permute(0, 2, 1, 3, 4).reshape(batch_frames, channels, height, width)
return sample
@dataclass
class AutoencoderKLOutput(BaseOutput):
"""
Output of AutoencoderKL encoding method.
Args:
latent_dist (`DiagonalGaussianDistribution`):
Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
`DiagonalGaussianDistribution` allows for sampling latents from the distribution.
"""
latent_dist: "DiagonalGaussianDistribution"
class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
for all models (such as downloading or saving).
Parameters:
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
Tuple of downsample block types.
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
Tuple of block output channels.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
sample_size (`int`, *optional*, defaults to `32`): Sample input size.
scaling_factor (`float`, *optional*, defaults to 0.18215):
The component-wise standard deviation of the trained latent space computed using the first batch of the
training set. This is used to scale the latent space to have unit variance when training the diffusion
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
/ scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
force_upcast (`bool`, *optional*, default to `True`):
If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
can be fine-tuned / trained to a lower range without loosing too much precision in which case
`force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
block_out_channels: Tuple[int] = (64,),
layers_per_block: int = 1,
act_fn: str = "silu",
latent_channels: int = 4,
norm_num_groups: int = 32,
sample_size: int = 32,
scaling_factor: float = 0.18215,
force_upcast: float = True,
):
super().__init__()
# pass init params to Encoder
self.encoder = Encoder(
in_channels=in_channels,
out_channels=latent_channels,
down_block_types=down_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
double_z=True,
)
# pass init params to Decoder
self.decoder = TemporalDecoder(
in_channels=latent_channels,
out_channels=out_channels,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
norm_num_groups=norm_num_groups,
act_fn=act_fn,
)
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
self.use_slicing = False
self.use_tiling = False
# only relevant if vae tiling is enabled
self.tile_sample_min_size = self.config.sample_size
sample_size = (
self.config.sample_size[0]
if isinstance(self.config.sample_size, (list, tuple))
else self.config.sample_size
)
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_overlap_factor = 0.25
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (Encoder, TemporalDecoder)):
module.gradient_checkpointing = value
def enable_tiling(self, use_tiling: bool = True):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.use_tiling = use_tiling
def disable_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.enable_tiling(False)
def enable_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
@property
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor, _remove_lora=_remove_lora)
else:
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
"""
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnAddedKVProcessor()
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnProcessor()
else:
raise ValueError(
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
self.set_attn_processor(processor, _remove_lora=True)
@apply_forward_hook
def encode(
self, x: torch.FloatTensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
"""
Encode a batch of images into latents.
Args:
x (`torch.FloatTensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
The latent representations of the encoded images. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
return self.tiled_encode(x, return_dict=return_dict)
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(
self, z: torch.FloatTensor, num_frames: int, return_dict: bool = True
) -> Union[DecoderOutput, torch.FloatTensor]:
if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
return self.tiled_decode(z, return_dict=return_dict)
batch_size = z.shape[0] // num_frames
# TODO: dont hardcode this
image_only_indicator = torch.zeros(batch_size, num_frames, dtype=z.dtype, device=z.device)
dec = self.decoder(z, num_frames=num_frames, image_only_indicator=image_only_indicator)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
@apply_forward_hook
def decode(
self,
z: torch.FloatTensor,
num_frames: int,
return_dict: bool = True,
generator=None,
) -> Union[DecoderOutput, torch.FloatTensor]:
"""
Decode a batch of images.
Args:
z (`torch.FloatTensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice, num_frames // 2).sample for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z, num_frames).sample
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[2], b.shape[2], blend_extent)
for y in range(blend_extent):
b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
return b
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
for x in range(blend_extent):
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
return b
def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
r"""Encode a batch of images using a tiled encoder.
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
output, but they should be much less noticeable.
Args:
x (`torch.FloatTensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
[`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
`tuple` is returned.
"""
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
row_limit = self.tile_latent_min_size - blend_extent
# Split the image into 512x512 tiles and encode them separately.
rows = []
for i in range(0, x.shape[2], overlap_size):
row = []
for j in range(0, x.shape[3], overlap_size):
tile = x[
:,
:,
i : i + self.tile_sample_min_size,
j : j + self.tile_sample_min_size,
]
tile = self.encoder(tile)
tile = self.quant_conv(tile)
row.append(tile)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=3))
moments = torch.cat(result_rows, dim=2)
posterior = DiagonalGaussianDistribution(moments)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
r"""
Decode a batch of images using a tiled decoder.
Args:
z (`torch.FloatTensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
row_limit = self.tile_sample_min_size - blend_extent
# Split z into overlapping 64x64 tiles and decode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows = []
for i in range(0, z.shape[2], overlap_size):
row = []
for j in range(0, z.shape[3], overlap_size):
tile = z[
:,
:,
i : i + self.tile_latent_min_size,
j : j + self.tile_latent_min_size,
]
tile = self.post_quant_conv(tile)
decoded = self.decoder(tile)
row.append(decoded)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=3))
dec = torch.cat(result_rows, dim=2)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
def forward(
self,
sample: torch.FloatTensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
num_frames: int = 1,
) -> Union[DecoderOutput, torch.FloatTensor]:
r"""
Args:
sample (`torch.FloatTensor`): Input sample.
sample_posterior (`bool`, *optional*, defaults to `False`):
Whether to sample from the posterior.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
"""
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z, num_frames=num_frames).sample
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)

View File

@@ -165,7 +165,10 @@ class Upsample2D(nn.Module):
self.Conv2d_0 = conv
def forward(
self, hidden_states: torch.FloatTensor, output_size: Optional[int] = None, scale: float = 1.0
self,
hidden_states: torch.FloatTensor,
output_size: Optional[int] = None,
scale: float = 1.0,
) -> torch.FloatTensor:
assert hidden_states.shape[1] == self.channels
@@ -379,7 +382,11 @@ class FirUpsample2D(nn.Module):
weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))
inverse_conv = F.conv_transpose2d(
hidden_states, weight, stride=stride, output_padding=output_padding, padding=0
hidden_states,
weight,
stride=stride,
output_padding=output_padding,
padding=0,
)
output = upfirdn2d_native(
@@ -530,7 +537,14 @@ class KDownsample2D(nn.Module):
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode)
weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
weight = inputs.new_zeros(
[
inputs.shape[1],
inputs.shape[1],
self.kernel.shape[0],
self.kernel.shape[1],
]
)
indices = torch.arange(inputs.shape[1], device=inputs.device)
kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
weight[indices, indices] = kernel
@@ -553,7 +567,14 @@ class KUpsample2D(nn.Module):
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode)
weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
weight = inputs.new_zeros(
[
inputs.shape[1],
inputs.shape[1],
self.kernel.shape[0],
self.kernel.shape[1],
]
)
indices = torch.arange(inputs.shape[1], device=inputs.device)
kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
weight[indices, indices] = kernel
@@ -690,11 +711,19 @@ class ResnetBlock2D(nn.Module):
self.conv_shortcut = None
if self.use_in_shortcut:
self.conv_shortcut = conv_cls(
in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
in_channels,
conv_2d_out_channels,
kernel_size=1,
stride=1,
padding=0,
bias=conv_shortcut_bias,
)
def forward(
self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor, scale: float = 1.0
self,
input_tensor: torch.FloatTensor,
temb: torch.FloatTensor,
scale: float = 1.0,
) -> torch.FloatTensor:
hidden_states = input_tensor
@@ -866,7 +895,10 @@ class ResidualTemporalBlock1D(nn.Module):
def upsample_2d(
hidden_states: torch.FloatTensor, kernel: Optional[torch.FloatTensor] = None, factor: int = 2, gain: float = 1
hidden_states: torch.FloatTensor,
kernel: Optional[torch.FloatTensor] = None,
factor: int = 2,
gain: float = 1,
) -> torch.FloatTensor:
r"""Upsample2D a batch of 2D images with the given filter.
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
@@ -910,7 +942,10 @@ def upsample_2d(
def downsample_2d(
hidden_states: torch.FloatTensor, kernel: Optional[torch.FloatTensor] = None, factor: int = 2, gain: float = 1
hidden_states: torch.FloatTensor,
kernel: Optional[torch.FloatTensor] = None,
factor: int = 2,
gain: float = 1,
) -> torch.FloatTensor:
r"""Downsample2D a batch of 2D images with the given filter.
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
@@ -946,13 +981,20 @@ def downsample_2d(
kernel = kernel * gain
pad_value = kernel.shape[0] - factor
output = upfirdn2d_native(
hidden_states, kernel.to(device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2)
hidden_states,
kernel.to(device=hidden_states.device),
down=factor,
pad=((pad_value + 1) // 2, pad_value // 2),
)
return output
def upfirdn2d_native(
tensor: torch.Tensor, kernel: torch.Tensor, up: int = 1, down: int = 1, pad: Tuple[int, int] = (0, 0)
tensor: torch.Tensor,
kernel: torch.Tensor,
up: int = 1,
down: int = 1,
pad: Tuple[int, int] = (0, 0),
) -> torch.Tensor:
up_x = up_y = up
down_x = down_y = down
@@ -1008,7 +1050,13 @@ class TemporalConvLayer(nn.Module):
dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
"""
def __init__(self, in_dim: int, out_dim: Optional[int] = None, dropout: float = 0.0, norm_num_groups: int = 32):
def __init__(
self,
in_dim: int,
out_dim: Optional[int] = None,
dropout: float = 0.0,
norm_num_groups: int = 32,
):
super().__init__()
out_dim = out_dim or in_dim
self.in_dim = in_dim
@@ -1016,7 +1064,9 @@ class TemporalConvLayer(nn.Module):
# conv layers
self.conv1 = nn.Sequential(
nn.GroupNorm(norm_num_groups, in_dim), nn.SiLU(), nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0))
nn.GroupNorm(norm_num_groups, in_dim),
nn.SiLU(),
nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)),
)
self.conv2 = nn.Sequential(
nn.GroupNorm(norm_num_groups, out_dim),
@@ -1058,3 +1108,295 @@ class TemporalConvLayer(nn.Module):
(hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:]
)
return hidden_states
class TemporalResnetBlock(nn.Module):
r"""
A Resnet block.
Parameters:
in_channels (`int`): The number of channels in the input.
out_channels (`int`, *optional*, default to be `None`):
The number of output channels for the first conv2d layer. If None, same as `in_channels`.
dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
groups_out (`int`, *optional*, default to None):
The number of groups to use for the second normalization layer. if set to None, same as `groups`.
eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or
"ada_group" for a stronger conditioning with scale and shift.
kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
[`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
use_in_shortcut (`bool`, *optional*, default to `True`):
If `True`, add a 1x1 nn.conv2d layer for skip-connection.
up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.
conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the
`conv_shortcut` output.
conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
If None, same as `out_channels`.
"""
def __init__(
self,
in_channels: int,
out_channels: Optional[int] = None,
conv_shortcut: bool = False,
dropout: float = 0.0,
temb_channels: int = 512,
groups: int = 32,
groups_out: Optional[int] = None,
eps: float = 1e-6,
non_linearity: str = "swish",
kernel_size: Optional[torch.FloatTensor] = (3, 1, 1),
output_scale_factor: float = 1.0,
use_in_shortcut: Optional[bool] = None,
conv_shortcut_bias: bool = True,
conv_2d_out_channels: Optional[int] = None,
):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.output_scale_factor = output_scale_factor
linear_cls = nn.Linear
conv_cls = nn.Conv3d
padding = [k // 2 for k in kernel_size]
if groups_out is None:
groups_out = groups
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
self.conv1 = conv_cls(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=1,
padding=padding,
)
if temb_channels is not None:
self.time_emb_proj = linear_cls(temb_channels, out_channels)
else:
self.time_emb_proj = None
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
self.dropout = torch.nn.Dropout(dropout)
conv_2d_out_channels = conv_2d_out_channels or out_channels
self.conv2 = conv_cls(
out_channels,
conv_2d_out_channels,
kernel_size=kernel_size,
stride=1,
padding=padding,
)
self.nonlinearity = get_activation(non_linearity)
self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut
self.conv_shortcut = None
if self.use_in_shortcut:
self.conv_shortcut = conv_cls(
in_channels,
conv_2d_out_channels,
kernel_size=1,
stride=1,
padding=0,
bias=conv_shortcut_bias,
)
def forward(self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
hidden_states = input_tensor
hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states)
if self.time_emb_proj is not None:
temb = self.nonlinearity(temb)
temb = self.time_emb_proj(temb)[:, :, :, None, None]
if temb is not None:
temb = temb.permute(0, 2, 1, 3, 4)
hidden_states = hidden_states + temb
hidden_states = self.norm2(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
input_tensor = self.conv_shortcut(input_tensor)
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
return output_tensor
# VideoResBlock
class SpatioTemporalResBlock(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: Optional[int] = None,
dropout: float = 0.0,
temb_channels: int = 512,
groups: int = 32,
pre_norm: bool = True,
eps: float = 1e-6,
temporal_eps: Optional[float] = None,
non_linearity: str = "swish",
time_embedding_norm: str = "default", # default, scale_shift, ada_group, spatial
output_scale_factor: float = 1.0,
kernel_size_3d: Optional[torch.FloatTensor] = (3, 1, 1),
merge_factor: float = 0.5,
merge_strategy="learned",
switch_spatial_to_temporal_mix: bool = False,
):
super().__init__()
self.spatial_res_block = ResnetBlock2D(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=eps,
groups=groups,
dropout=dropout,
time_embedding_norm=time_embedding_norm,
non_linearity=non_linearity,
output_scale_factor=output_scale_factor,
pre_norm=pre_norm,
)
self.temporal_res_block = TemporalResnetBlock(
in_channels=out_channels if out_channels is not None else in_channels,
out_channels=out_channels if out_channels is not None else in_channels,
temb_channels=temb_channels,
eps=temporal_eps if temporal_eps is not None else eps,
groups=groups,
dropout=dropout,
non_linearity=non_linearity,
output_scale_factor=output_scale_factor,
kernel_size=kernel_size_3d,
)
self.time_mixer = AlphaBlender(
alpha=merge_factor,
merge_strategy=merge_strategy,
switch_spatial_to_temporal_mix=switch_spatial_to_temporal_mix,
)
def forward(
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
num_frames: int = 1,
image_only_indicator: Optional[torch.Tensor] = None,
scale: float = 1.0,
):
hidden_states = self.spatial_res_block(hidden_states, temb, scale=scale)
batch_frames, channels, height, width = hidden_states.shape
batch_size = batch_frames // num_frames
hidden_states_mix = (
hidden_states[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
)
hidden_states = (
hidden_states[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
)
if temb is not None:
temb = temb.reshape(batch_size, num_frames, -1)
hidden_states = self.temporal_res_block(hidden_states, temb)
hidden_states = self.time_mixer(
x_spatial=hidden_states_mix,
x_temporal=hidden_states,
image_only_indicator=image_only_indicator,
)
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(batch_frames, channels, height, width)
return hidden_states
class AlphaBlender(nn.Module):
strategies = ["learned", "fixed", "learned_with_images"]
def __init__(
self,
alpha: float,
merge_strategy: str = "learned_with_images",
switch_spatial_to_temporal_mix: bool = False,
):
super().__init__()
self.merge_strategy = merge_strategy
self.switch_spatial_to_temporal_mix = switch_spatial_to_temporal_mix # For TemporalVAE
assert merge_strategy in self.strategies, f"merge_strategy needs to be in {self.strategies}"
if self.merge_strategy == "fixed":
self.register_buffer("mix_factor", torch.Tensor([alpha]))
elif self.merge_strategy == "learned" or self.merge_strategy == "learned_with_images":
self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha])))
else:
raise ValueError(f"Unknown merge strategy {self.merge_strategy}")
def get_alpha(self, image_only_indicator: torch.Tensor, ndims: int) -> torch.Tensor:
if self.merge_strategy == "fixed":
alpha = self.mix_factor
elif self.merge_strategy == "learned":
alpha = torch.sigmoid(self.mix_factor)
elif self.merge_strategy == "learned_with_images":
assert (
image_only_indicator is not None
), "Please provide image_only_indicator to use learned_with_images merge strategy"
alpha = torch.where(
image_only_indicator.bool(),
torch.ones(1, 1, device=image_only_indicator.device),
torch.sigmoid(self.mix_factor)[..., None],
)
# (batch, channel, frames, height, width)
if ndims == 5:
alpha = alpha[:, None, :, None, None]
# (batch*frames, height*width, channels)
elif ndims == 3:
alpha = alpha.reshape(-1)[:, None, None]
else:
raise ValueError(f"Unexpected ndims {ndims}. Dimensions should be 3 or 5")
else:
raise NotImplementedError
return alpha
def forward(
self,
x_spatial: torch.Tensor,
x_temporal: torch.Tensor,
image_only_indicator: Optional[torch.Tensor] = None,
) -> torch.Tensor:
alpha = self.get_alpha(image_only_indicator, x_spatial.ndim)
alpha = alpha.to(x_spatial.dtype)
if self.switch_spatial_to_temporal_mix:
alpha = 1.0 - alpha
x = alpha * x_spatial + (1.0 - alpha) * x_temporal
return x

View File

@@ -19,8 +19,10 @@ from torch import nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from .attention import BasicTransformerBlock
from .attention import BasicTransformerBlock, TemporalBasicTransformerBlock
from .embeddings import TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from .resnet import AlphaBlender
@dataclass
@@ -195,3 +197,229 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
return (output,)
return TransformerTemporalModelOutput(sample=output)
# VideoBlock
class TransformerSpatioTemporalModel(ModelMixin, ConfigMixin):
"""
A Transformer model for video-like data.
Parameters:
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
in_channels (`int`, *optional*):
The number of channels in the input and output (specify if the input is **continuous**).
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
attention_bias (`bool`, *optional*):
Configure if the `TransformerBlock` attention should contain a bias parameter.
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
This is fixed during training since it is used to learn a number of position embeddings.
activation_fn (`str`, *optional*, defaults to `"geglu"`):
Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported
activation functions.
norm_elementwise_affine (`bool`, *optional*):
Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization.
double_self_attention (`bool`, *optional*):
Configure if each `TransformerBlock` should contain two self-attention layers.
positional_embeddings: (`str`, *optional*):
The type of positional embeddings to apply to the sequence input before passing use.
num_positional_embeddings: (`int`, *optional*):
The maximum length of the sequence over which to apply positional embeddings.
"""
@register_to_config
def __init__(
self,
num_attention_heads: int = 16,
attention_head_dim: int = 88,
in_channels: int = 320,
out_channels: Optional[int] = None,
num_layers: int = 1,
dropout: float = 0.0,
norm_num_groups: int = 32,
cross_attention_dim: Optional[int] = None,
norm_eps: float = 1e-5,
merge_factor: float = 0.5,
merge_strategy: str = "learned_with_images",
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim
self.inner_dim = inner_dim
linear_cls = nn.Linear
# 2. Define input layers
self.in_channels = in_channels
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps)
self.proj_in = linear_cls(in_channels, inner_dim)
# 3. Define transformers blocks
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
norm_eps=norm_eps,
)
for d in range(num_layers)
]
)
time_mix_inner_dim = inner_dim
self.temporal_transformer_blocks = nn.ModuleList(
[
TemporalBasicTransformerBlock(
inner_dim,
time_mix_inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
norm_eps=norm_eps,
)
for _ in range(num_layers)
]
)
time_embed_dim = in_channels * 4
self.time_pos_embed = TimestepEmbedding(in_channels, time_embed_dim, out_dim=in_channels)
self.time_proj = Timesteps(in_channels, True, 0)
self.time_mixer = AlphaBlender(alpha=merge_factor, merge_strategy=merge_strategy)
# 4. Define output layers
self.out_channels = in_channels if out_channels is None else out_channels
# TODO: should use out_channels for continuous projections
self.proj_out = linear_cls(inner_dim, in_channels)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
num_frames: int,
encoder_hidden_states: Optional[torch.Tensor] = None,
timestep: Optional[torch.LongTensor] = None,
class_labels: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
image_only_indicator: Optional[torch.Tensor] = None,
return_dict: bool = True,
):
"""
Args:
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
Input hidden_states.
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
self-attention.
timestep ( `torch.LongTensor`, *optional*):
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
`AdaLayerZeroNorm`.
num_frames (`int`, *optional*, defaults to 1):
The number of frames to be processed per batch. This is used to reshape the hidden states.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
tuple.
Returns:
[`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
returned, otherwise a `tuple` where the first element is the sample tensor.
"""
assert (
encoder_hidden_states.ndim == 3
), f"n dims of spatial context should be 3 but are {encoder_hidden_states.ndim}"
# 1. Input
batch_frames, channel, height, width = hidden_states.shape
batch_size = batch_frames // num_frames
time_context = encoder_hidden_states
time_context_first_timestep = time_context[::num_frames]
time_context = time_context_first_timestep.repeat(height * width, 1, 1)
residual = hidden_states
hidden_states = self.norm(hidden_states)
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_frames, height * width, inner_dim)
hidden_states = self.proj_in(hidden_states)
num_frames_emb = torch.arange(num_frames, device=hidden_states.device)
num_frames_emb = num_frames_emb.repeat(batch_size, 1)
num_frames_emb = num_frames_emb.reshape(-1)
t_emb = self.time_proj(num_frames_emb)
# `Timesteps` does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=hidden_states.dtype)
emb = self.time_pos_embed(t_emb)
emb = emb[:, None, :]
# 2. Blocks
for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks):
if self.training and self.gradient_checkpointing:
hidden_states = torch.utils.checkpoint.checkpoint(
block,
hidden_states,
None,
encoder_hidden_states,
None,
timestep,
cross_attention_kwargs,
class_labels,
use_reentrant=False,
)
else:
hidden_states = block(
hidden_states,
attention_mask=None,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=None,
timestep=timestep,
cross_attention_kwargs=cross_attention_kwargs,
class_labels=class_labels,
)
hidden_states_mix = hidden_states
hidden_states_mix = hidden_states_mix + emb
hidden_states_mix = temporal_block(
hidden_states_mix,
num_frames=num_frames,
encoder_hidden_states=time_context,
cross_attention_kwargs=cross_attention_kwargs,
)
hidden_states = self.time_mixer(
x_spatial=hidden_states,
x_temporal=hidden_states_mix,
image_only_indicator=image_only_indicator,
)
# 3. Output
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(batch_frames, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
output = hidden_states + residual
if not return_dict:
return (output,)
return TransformerTemporalModelOutput(sample=output)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,859 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import UNet2DConditionLoadersMixin
from ..utils import (
USE_PEFT_BACKEND,
BaseOutput,
logging,
scale_lora_layers,
unscale_lora_layers,
)
from .activations import get_activation
from .attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
)
from .embeddings import TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from .unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block, get_up_block
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class UNetSpatioTemporalConditionOutput(BaseOutput):
"""
The output of [`UNetSpatioTemporalConditionModel`].
Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
"""
sample: torch.FloatTensor = None
class UNetSpatioTemporalConditionModel(
ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin
):
r"""
A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
shaped output.
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
for all models (such as downloading or saving).
Parameters:
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
Height and width of input/output sample.
in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
Whether to flip the sin to cos in the time embedding.
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
The tuple of downsample blocks to use.
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
`UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
The tuple of upsample blocks to use.
only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
Whether to include self-attention in the basic transformer blocks, see
[`~models.attention.BasicTransformerBlock`].
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
The tuple of output channels for each block.
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
If `None`, normalization and activation layers is skipped in post-processing.
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
The dimension of the cross attention features.
transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
encoder_hid_dim (`int`, *optional*, defaults to None):
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
dimension to `cross_attention_dim`.
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
num_attention_heads (`int`, *optional*):
The number of attention heads. If not defined, defaults to `attention_head_dim`
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
class_embed_type (`str`, *optional*, defaults to `None`):
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
addition_embed_type (`str`, *optional*, defaults to `None`):
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
"text". "text" will use the `TextTimeEmbedding` layer.
addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
Dimension for the timestep embeddings.
num_class_embeds (`int`, *optional*, defaults to `None`):
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
class conditioning with `class_embed_type` equal to `None`.
time_embedding_type (`str`, *optional*, defaults to `positional`):
The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
time_embedding_dim (`int`, *optional*, defaults to `None`):
An optional override for the dimension of the projected time embedding.
time_embedding_act_fn (`str`, *optional*, defaults to `None`):
Optional activation function to use only once on the time embeddings before they are passed to the rest of
the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
timestep_post_act (`str`, *optional*, defaults to `None`):
The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
time_cond_proj_dim (`int`, *optional*, defaults to `None`):
The dimension of `cond_proj` layer in the timestep embedding.
conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`,
*optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`,
*optional*): The dimension of the `class_labels` input when
`class_embed_type="projection"`. Required when `class_embed_type="projection"`.
class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
embeddings with the class embeddings.
mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
`only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
`only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
otherwise.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
sample_size: Optional[int] = None,
in_channels: int = 8,
out_channels: int = 4,
center_input_sample: bool = False,
flip_sin_to_cos: bool = True,
down_block_types: Tuple[str] = (
"CrossAttnDownBlockSpatioTemporal",
"CrossAttnDownBlockSpatioTemporal",
"CrossAttnDownBlockSpatioTemporal",
"DownBlockSpatioTemporal",
),
mid_block_type: Optional[str] = "UNetMidBlockSpatioTemporal",
up_block_types: Tuple[str] = (
"UpBlockSpatioTemporal",
"CrossAttnUpBlockSpatioTemporal",
"CrossAttnUpBlockSpatioTemporal",
"CrossAttnUpBlockSpatioTemporal",
),
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
projection_class_embeddings_input_dim: int = 768,
addition_time_embed_dim: int = 256,
layers_per_block: Union[int, Tuple[int]] = 2,
mid_block_scale_factor: float = 1,
dropout: float = 0.0,
act_fn: str = "silu",
norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5,
cross_attention_dim: Union[int, Tuple[int]] = 1024,
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
attention_head_dim: Union[int, Tuple[int]] = 8,
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
time_embedding_dim: Optional[int] = None,
conv_in_kernel: int = 3,
conv_out_kernel: int = 3,
kernel_size_3d: Optional[torch.FloatTensor] = (3, 1, 1),
merge_factor: float = 0.5,
merge_strategy: str = "learned_with_images",
):
super().__init__()
self.sample_size = sample_size
if num_attention_heads is not None:
raise ValueError(
"At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
)
# If `num_attention_heads` is not defined (which is the case for most models)
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
# The reason for this behavior is to correct for incorrectly named variables that were introduced
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
# which is why we correct for the naming here.
num_attention_heads = num_attention_heads or attention_head_dim
# Check inputs
if len(down_block_types) != len(up_block_types):
raise ValueError(
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
)
if len(block_out_channels) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
)
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(
down_block_types
):
raise ValueError(
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
)
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(
down_block_types
):
raise ValueError(
f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
)
if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(
down_block_types
):
raise ValueError(
f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
)
if not isinstance(layers_per_block, int) and len(layers_per_block) != len(
down_block_types
):
raise ValueError(
f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
)
# input
conv_in_padding = (conv_in_kernel - 1) // 2
self.conv_in = nn.Conv2d(
in_channels,
block_out_channels[0],
kernel_size=conv_in_kernel,
padding=conv_in_padding,
)
# time
time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
self.time_proj = Timesteps(
block_out_channels[0], flip_sin_to_cos, downscale_freq_shift=0
)
timestep_input_dim = block_out_channels[0]
self.time_embedding = TimestepEmbedding(
timestep_input_dim,
time_embed_dim,
act_fn=act_fn,
)
self.add_time_proj = Timesteps(
addition_time_embed_dim, flip_sin_to_cos, downscale_freq_shift=0
)
self.add_embedding = TimestepEmbedding(
projection_class_embeddings_input_dim, time_embed_dim
)
self.down_blocks = nn.ModuleList([])
self.up_blocks = nn.ModuleList([])
if isinstance(num_attention_heads, int):
num_attention_heads = (num_attention_heads,) * len(down_block_types)
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types)
if isinstance(cross_attention_dim, int):
cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
if isinstance(layers_per_block, int):
layers_per_block = [layers_per_block] * len(down_block_types)
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * len(
down_block_types
)
blocks_time_embed_dim = time_embed_dim
# down
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
down_block = get_down_block(
down_block_type,
num_layers=layers_per_block[i],
transformer_layers_per_block=transformer_layers_per_block[i],
in_channels=input_channel,
out_channels=output_channel,
temb_channels=blocks_time_embed_dim,
add_downsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim[i],
num_attention_heads=num_attention_heads[i],
downsample_padding=1,
dropout=dropout,
kernel_size_3d=kernel_size_3d,
merge_factor=merge_factor,
merge_strategy=merge_strategy,
)
self.down_blocks.append(down_block)
# mid
self.mid_block = UNetMidBlockSpatioTemporal(
block_out_channels[-1],
temb_channels=blocks_time_embed_dim,
transformer_layers_per_block=transformer_layers_per_block[-1],
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
cross_attention_dim=cross_attention_dim[-1],
num_attention_heads=num_attention_heads[-1],
resnet_groups=norm_num_groups,
dropout=dropout,
kernel_size_3d=kernel_size_3d,
merge_factor=merge_factor,
merge_strategy=merge_strategy,
)
# count how many layers upsample the images
self.num_upsamplers = 0
# up
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_num_attention_heads = list(reversed(num_attention_heads))
reversed_layers_per_block = list(reversed(layers_per_block))
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
reversed_transformer_layers_per_block = list(
reversed(transformer_layers_per_block)
)
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
is_final_block = i == len(block_out_channels) - 1
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
input_channel = reversed_block_out_channels[
min(i + 1, len(block_out_channels) - 1)
]
# add upsample block for all BUT final layer
if not is_final_block:
add_upsample = True
self.num_upsamplers += 1
else:
add_upsample = False
up_block = get_up_block(
up_block_type,
num_layers=reversed_layers_per_block[i] + 1,
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
temb_channels=blocks_time_embed_dim,
add_upsample=add_upsample,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resolution_idx=i,
resnet_groups=norm_num_groups,
cross_attention_dim=reversed_cross_attention_dim[i],
num_attention_heads=reversed_num_attention_heads[i],
dropout=dropout,
kernel_size_3d=kernel_size_3d,
merge_factor=merge_factor,
merge_strategy=merge_strategy,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# out
self.conv_norm_out = nn.GroupNorm(
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
)
self.conv_act = get_activation(act_fn)
conv_out_padding = (conv_out_kernel - 1) // 2
self.conv_out = nn.Conv2d(
block_out_channels[0],
out_channels,
kernel_size=conv_out_kernel,
padding=conv_out_padding,
)
@property
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(
name: str,
module: torch.nn.Module,
processors: Dict[str, AttentionProcessor],
):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor(
return_deprecated_lora=True
)
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
def set_attn_processor(
self,
processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]],
_remove_lora=False,
):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor, _remove_lora=_remove_lora)
else:
module.set_processor(
processor.pop(f"{name}.processor"), _remove_lora=_remove_lora
)
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
"""
if all(
proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS
for proc in self.attn_processors.values()
):
processor = AttnAddedKVProcessor()
elif all(
proc.__class__ in CROSS_ATTENTION_PROCESSORS
for proc in self.attn_processors.values()
):
processor = AttnProcessor()
else:
raise ValueError(
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
self.set_attn_processor(processor, _remove_lora=True)
def set_attention_slice(self, slice_size):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
Args:
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
must be a multiple of `slice_size`.
"""
sliceable_head_dims = []
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
if hasattr(module, "set_attention_slice"):
sliceable_head_dims.append(module.sliceable_head_dim)
for child in module.children():
fn_recursive_retrieve_sliceable_dims(child)
# retrieve number of attention layers
for module in self.children():
fn_recursive_retrieve_sliceable_dims(module)
num_sliceable_layers = len(sliceable_head_dims)
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = [dim // 2 for dim in sliceable_head_dims]
elif slice_size == "max":
# make smallest slice possible
slice_size = num_sliceable_layers * [1]
slice_size = (
num_sliceable_layers * [slice_size]
if not isinstance(slice_size, list)
else slice_size
)
if len(slice_size) != len(sliceable_head_dims):
raise ValueError(
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
)
for i in range(len(slice_size)):
size = slice_size[i]
dim = sliceable_head_dims[i]
if size is not None and size > dim:
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
# Recursively walk through all the children.
# Any children which exposes the set_attention_slice method
# gets the message
def fn_recursive_set_attention_slice(
module: torch.nn.Module, slice_size: List[int]
):
if hasattr(module, "set_attention_slice"):
module.set_attention_slice(slice_size.pop())
for child in module.children():
fn_recursive_set_attention_slice(child, slice_size)
reversed_slice_size = list(reversed(slice_size))
for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size)
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def enable_freeu(self, s1, s2, b1, b2):
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
The suffixes after the scaling factors represent the stage blocks where they are being applied.
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
Args:
s1 (`float`):
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
mitigate the "oversmoothing effect" in the enhanced denoising process.
s2 (`float`):
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
mitigate the "oversmoothing effect" in the enhanced denoising process.
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
"""
for i, upsample_block in enumerate(self.up_blocks):
setattr(upsample_block, "s1", s1)
setattr(upsample_block, "s2", s2)
setattr(upsample_block, "b1", b1)
setattr(upsample_block, "b2", b2)
def disable_freeu(self):
"""Disables the FreeU mechanism."""
freeu_keys = {"s1", "s2", "b1", "b2"}
for i, upsample_block in enumerate(self.up_blocks):
for k in freeu_keys:
if (
hasattr(upsample_block, k)
or getattr(upsample_block, k, None) is not None
):
setattr(upsample_block, k, None)
def forward(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
added_time_ids: torch.Tensor,
image_only_indicator: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[UNetSpatioTemporalConditionOutput, Tuple]:
r"""
The [`UNet2DConditionModel`] forward method.
Args:
sample (`torch.FloatTensor`):
The noisy input tensor with the following shape `(batch, channel, height, width)`.
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
encoder_hidden_states (`torch.FloatTensor`):
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
negative values to the attention scores corresponding to "discard" tokens.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
added_cond_kwargs: (`dict`):
A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
are passed along to the UNet blocks.
encoder_attention_mask (`torch.Tensor`):
A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
which adds large negative values to the attention scores corresponding to "discard" tokens.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
tuple.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
Returns:
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
a `tuple` is returned where the first element is the sample tensor.
"""
# By default samples have to be AT least a multiple of the overall upsampling factor.
# The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
# However, the upsampling interpolation output size can be forced to fit any upsampling size
# on the fly if necessary.
default_overall_up_factor = 2**self.num_upsamplers
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
forward_upsample_size = False
upsample_size = None
for dim in sample.shape[-2:]:
if dim % default_overall_up_factor != 0:
# Forward upsample size to force interpolation output size.
forward_upsample_size = True
break
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension
# expects mask of shape:
# [batch, key_tokens]
# adds singleton query_tokens dimension:
# [batch, 1, key_tokens]
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
if attention_mask is not None:
# assume that mask is expressed as:
# (1 = keep, 0 = discard)
# convert mask into a bias that can be added to attention scores:
# (keep = +0, discard = -10000.0)
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None:
encoder_attention_mask = (
1 - encoder_attention_mask.to(sample.dtype)
) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
# 0. center input if necessary
if self.config.center_input_sample:
sample = 2 * sample - 1.0
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
batch_size, num_frames = sample.shape[:2]
timesteps = timesteps.expand(batch_size)
t_emb = self.time_proj(timesteps)
# `Timesteps` does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=sample.dtype)
emb = self.time_embedding(t_emb)
time_embeds = self.add_time_proj(added_time_ids.flatten())
time_embeds = time_embeds.reshape((batch_size, -1))
time_embeds = time_embeds.to(emb.dtype)
aug_emb = self.add_embedding(time_embeds)
emb = emb + aug_emb
# Flatten the batch and frames dimensions
# sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
sample = sample.flatten(0, 1)
# Repeat the embeddings num_video_frames times
# emb: [batch, channels] -> [batch * frames, channels]
emb = emb.repeat_interleave(num_frames, dim=0)
# encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
encoder_hidden_states = encoder_hidden_states.repeat_interleave(
num_frames, dim=0
)
# 2. pre-process
sample = self.conv_in(sample)
# 3. down
lora_scale = (
cross_attention_kwargs.get("scale", 1.0)
if cross_attention_kwargs is not None
else 1.0
)
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
image_only_indicator = torch.zeros(
batch_size, num_frames, dtype=sample.dtype, device=sample.device
)
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
if (
hasattr(downsample_block, "has_cross_attention")
and downsample_block.has_cross_attention
):
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
encoder_attention_mask=encoder_attention_mask,
num_video_frames=num_frames,
image_only_indicator=image_only_indicator,
)
else:
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
scale=lora_scale,
num_video_frames=num_frames,
image_only_indicator=image_only_indicator,
)
down_block_res_samples += res_samples
# 4. mid
if self.mid_block is not None:
if (
hasattr(self.mid_block, "has_cross_attention")
and self.mid_block.has_cross_attention
):
sample = self.mid_block(
hidden_states=sample,
temb=emb,
num_video_frames=num_frames,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
encoder_attention_mask=encoder_attention_mask,
image_only_indicator=image_only_indicator,
)
else:
sample = self.mid_block(
sample,
temb=emb,
num_video_frames=num_frames,
image_only_indicator=image_only_indicator,
)
# 5. up
for i, upsample_block in enumerate(self.up_blocks):
is_final_block = i == len(self.up_blocks) - 1
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[
: -len(upsample_block.resnets)
]
# if we have not reached the final block and need to forward the
# upsample size, we do it here
if not is_final_block and forward_upsample_size:
upsample_size = down_block_res_samples[-1].shape[2:]
if (
hasattr(upsample_block, "has_cross_attention")
and upsample_block.has_cross_attention
):
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
upsample_size=upsample_size,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
num_video_frames=num_frames,
image_only_indicator=image_only_indicator,
)
else:
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
upsample_size=upsample_size,
scale=lora_scale,
num_video_frames=num_frames,
image_only_indicator=image_only_indicator,
)
# 6. post-process
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
# 7. Reshape back to original shape
sample = sample.reshape(batch_size, num_frames, *sample.shape[1:])
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (sample,)
return UNetSpatioTemporalConditionOutput(sample=sample)

View File

@@ -22,7 +22,12 @@ from ..utils import BaseOutput, is_torch_version
from ..utils.torch_utils import randn_tensor
from .activations import get_activation
from .attention_processor import SpatialNorm
from .unet_2d_blocks import AutoencoderTinyBlock, UNetMidBlock2D, get_down_block, get_up_block
from .unet_2d_blocks import (
AutoencoderTinyBlock,
UNetMidBlock2D,
get_down_block,
get_up_block,
)
@dataclass
@@ -122,11 +127,15 @@ class Encoder(nn.Module):
)
# out
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
self.conv_norm_out = nn.GroupNorm(
num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6
)
self.conv_act = nn.SiLU()
conv_out_channels = 2 * out_channels if double_z else out_channels
self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
self.conv_out = nn.Conv2d(
block_out_channels[-1], conv_out_channels, 3, padding=1
)
self.gradient_checkpointing = False
@@ -155,9 +164,13 @@ class Encoder(nn.Module):
)
else:
for down_block in self.down_blocks:
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(down_block), sample
)
# middle
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block), sample
)
else:
# down
@@ -267,14 +280,18 @@ class Decoder(nn.Module):
if norm_type == "spatial":
self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
else:
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
self.conv_norm_out = nn.GroupNorm(
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6
)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
self.gradient_checkpointing = False
def forward(
self, sample: torch.FloatTensor, latent_embeds: Optional[torch.FloatTensor] = None
self,
sample: torch.FloatTensor,
latent_embeds: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
r"""The forward method of the `Decoder` class."""
@@ -292,14 +309,20 @@ class Decoder(nn.Module):
if is_torch_version(">=", "1.11.0"):
# middle
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block), sample, latent_embeds, use_reentrant=False
create_custom_forward(self.mid_block),
sample,
latent_embeds,
use_reentrant=False,
)
sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(up_block), sample, latent_embeds, use_reentrant=False
create_custom_forward(up_block),
sample,
latent_embeds,
use_reentrant=False,
)
else:
# middle
@@ -310,7 +333,9 @@ class Decoder(nn.Module):
# up
for up_block in self.up_blocks:
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(up_block), sample, latent_embeds
)
else:
# middle
sample = self.mid_block(sample, latent_embeds)
@@ -350,7 +375,9 @@ class UpSample(nn.Module):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
self.deconv = nn.ConvTranspose2d(
in_channels, out_channels, kernel_size=4, stride=2, padding=1
)
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
r"""The forward method of the `UpSample` class."""
@@ -394,9 +421,13 @@ class MaskConditionEncoder(nn.Module):
for l in range(len(out_channels)):
out_ch_ = out_channels[l]
if l == 0 or l == 1:
layers.append(nn.Conv2d(in_ch_, out_ch_, kernel_size=3, stride=1, padding=1))
layers.append(
nn.Conv2d(in_ch_, out_ch_, kernel_size=3, stride=1, padding=1)
)
else:
layers.append(nn.Conv2d(in_ch_, out_ch_, kernel_size=4, stride=2, padding=1))
layers.append(
nn.Conv2d(in_ch_, out_ch_, kernel_size=4, stride=2, padding=1)
)
in_ch_ = out_ch_
self.layers = nn.Sequential(*layers)
@@ -511,7 +542,9 @@ class MaskConditionDecoder(nn.Module):
if norm_type == "spatial":
self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
else:
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
self.conv_norm_out = nn.GroupNorm(
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6
)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
@@ -540,7 +573,10 @@ class MaskConditionDecoder(nn.Module):
if is_torch_version(">=", "1.11.0"):
# middle
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block), sample, latent_embeds, use_reentrant=False
create_custom_forward(self.mid_block),
sample,
latent_embeds,
use_reentrant=False,
)
sample = sample.to(upscale_dtype)
@@ -548,17 +584,25 @@ class MaskConditionDecoder(nn.Module):
if image is not None and mask is not None:
masked_image = (1 - mask) * image
im_x = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.condition_encoder), masked_image, mask, use_reentrant=False
create_custom_forward(self.condition_encoder),
masked_image,
mask,
use_reentrant=False,
)
# up
for up_block in self.up_blocks:
if image is not None and mask is not None:
sample_ = im_x[str(tuple(sample.shape))]
mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
mask_ = nn.functional.interpolate(
mask, size=sample.shape[-2:], mode="nearest"
)
sample = sample * mask_ + sample_ * (1 - mask_)
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(up_block), sample, latent_embeds, use_reentrant=False
create_custom_forward(up_block),
sample,
latent_embeds,
use_reentrant=False,
)
if image is not None and mask is not None:
sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
@@ -573,16 +617,22 @@ class MaskConditionDecoder(nn.Module):
if image is not None and mask is not None:
masked_image = (1 - mask) * image
im_x = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.condition_encoder), masked_image, mask
create_custom_forward(self.condition_encoder),
masked_image,
mask,
)
# up
for up_block in self.up_blocks:
if image is not None and mask is not None:
sample_ = im_x[str(tuple(sample.shape))]
mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
mask_ = nn.functional.interpolate(
mask, size=sample.shape[-2:], mode="nearest"
)
sample = sample * mask_ + sample_ * (1 - mask_)
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(up_block), sample, latent_embeds
)
if image is not None and mask is not None:
sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
else:
@@ -599,7 +649,9 @@ class MaskConditionDecoder(nn.Module):
for up_block in self.up_blocks:
if image is not None and mask is not None:
sample_ = im_x[str(tuple(sample.shape))]
mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
mask_ = nn.functional.interpolate(
mask, size=sample.shape[-2:], mode="nearest"
)
sample = sample * mask_ + sample_ * (1 - mask_)
sample = up_block(sample, latent_embeds)
if image is not None and mask is not None:
@@ -671,7 +723,9 @@ class VectorQuantizer(nn.Module):
new = match.argmax(-1)
unknown = match.sum(2) < 1
if self.unknown_index == "random":
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(
device=new.device
)
else:
new[unknown] = self.unknown_index
return new.reshape(ishape)
@@ -686,13 +740,17 @@ class VectorQuantizer(nn.Module):
back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
return back.reshape(ishape)
def forward(self, z: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, Tuple]:
def forward(
self, z: torch.FloatTensor
) -> Tuple[torch.FloatTensor, torch.FloatTensor, Tuple]:
# reshape z -> (batch, height, width, channel) and flatten
z = z.permute(0, 2, 3, 1).contiguous()
z_flattened = z.view(-1, self.vq_embed_dim)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
min_encoding_indices = torch.argmin(torch.cdist(z_flattened, self.embedding.weight), dim=1)
min_encoding_indices = torch.argmin(
torch.cdist(z_flattened, self.embedding.weight), dim=1
)
z_q = self.embedding(min_encoding_indices).view(z.shape)
perplexity = None
@@ -700,9 +758,13 @@ class VectorQuantizer(nn.Module):
# compute loss for embedding
if not self.legacy:
loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean(
(z_q - z.detach()) ** 2
)
else:
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean(
(z_q - z.detach()) ** 2
)
# preserve gradients
z_q: torch.FloatTensor = z + (z_q - z).detach()
@@ -711,16 +773,22 @@ class VectorQuantizer(nn.Module):
z_q = z_q.permute(0, 3, 1, 2).contiguous()
if self.remap is not None:
min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis
min_encoding_indices = min_encoding_indices.reshape(
z.shape[0], -1
) # add batch axis
min_encoding_indices = self.remap_to_used(min_encoding_indices)
min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
if self.sane_index_shape:
min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])
min_encoding_indices = min_encoding_indices.reshape(
z_q.shape[0], z_q.shape[2], z_q.shape[3]
)
return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
def get_codebook_entry(self, indices: torch.LongTensor, shape: Tuple[int, ...]) -> torch.FloatTensor:
def get_codebook_entry(
self, indices: torch.LongTensor, shape: Tuple[int, ...]
) -> torch.FloatTensor:
# shape specifying (batch, height, width, channel)
if self.remap is not None:
indices = indices.reshape(shape[0], -1) # add batch axis
@@ -754,7 +822,10 @@ class DiagonalGaussianDistribution(object):
def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
# make sure sample is on the same device as the parameters and has same dtype
sample = randn_tensor(
self.mean.shape, generator=generator, device=self.parameters.device, dtype=self.parameters.dtype
self.mean.shape,
generator=generator,
device=self.parameters.device,
dtype=self.parameters.dtype,
)
x = self.mean + self.std * sample
return x
@@ -764,7 +835,10 @@ class DiagonalGaussianDistribution(object):
return torch.Tensor([0.0])
else:
if other is None:
return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
return 0.5 * torch.sum(
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
dim=[1, 2, 3],
)
else:
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var
@@ -775,11 +849,16 @@ class DiagonalGaussianDistribution(object):
dim=[1, 2, 3],
)
def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor:
def nll(
self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]
) -> torch.Tensor:
if self.deterministic:
return torch.Tensor([0.0])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)
return 0.5 * torch.sum(
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
dim=dims,
)
def mode(self) -> torch.Tensor:
return self.mean
@@ -818,14 +897,27 @@ class EncoderTiny(nn.Module):
num_channels = block_out_channels[i]
if i == 0:
layers.append(nn.Conv2d(in_channels, num_channels, kernel_size=3, padding=1))
layers.append(
nn.Conv2d(in_channels, num_channels, kernel_size=3, padding=1)
)
else:
layers.append(nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, stride=2, bias=False))
layers.append(
nn.Conv2d(
num_channels,
num_channels,
kernel_size=3,
padding=1,
stride=2,
bias=False,
)
)
for _ in range(num_block):
layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn))
layers.append(nn.Conv2d(block_out_channels[-1], out_channels, kernel_size=3, padding=1))
layers.append(
nn.Conv2d(block_out_channels[-1], out_channels, kernel_size=3, padding=1)
)
self.layers = nn.Sequential(*layers)
self.gradient_checkpointing = False
@@ -841,9 +933,13 @@ class EncoderTiny(nn.Module):
return custom_forward
if is_torch_version(">=", "1.11.0"):
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False)
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.layers), x, use_reentrant=False
)
else:
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.layers), x
)
else:
# scale image from [-1, 1] to [0, 1] to match TAESD convention
@@ -899,7 +995,15 @@ class DecoderTiny(nn.Module):
layers.append(nn.Upsample(scale_factor=upsampling_scaling_factor))
conv_out_channel = num_channels if not is_final_block else out_channels
layers.append(nn.Conv2d(num_channels, conv_out_channel, kernel_size=3, padding=1, bias=is_final_block))
layers.append(
nn.Conv2d(
num_channels,
conv_out_channel,
kernel_size=3,
padding=1,
bias=is_final_block,
)
)
self.layers = nn.Sequential(*layers)
self.gradient_checkpointing = False
@@ -918,9 +1022,13 @@ class DecoderTiny(nn.Module):
return custom_forward
if is_torch_version(">=", "1.11.0"):
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False)
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.layers), x, use_reentrant=False
)
else:
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.layers), x
)
else:
x = self.layers(x)

View File

@@ -145,6 +145,7 @@ else:
"StableDiffusionPix2PixZeroPipeline",
"StableDiffusionSAGPipeline",
"StableDiffusionUpscalePipeline",
"StableDiffusionVideoPipeline",
"StableUnCLIPImg2ImgPipeline",
"StableUnCLIPPipeline",
]
@@ -372,6 +373,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionPix2PixZeroPipeline,
StableDiffusionSAGPipeline,
StableDiffusionUpscalePipeline,
StableDiffusionVideoPipeline,
StableUnCLIPImg2ImgPipeline,
StableUnCLIPPipeline,
)

View File

@@ -47,6 +47,7 @@ else:
_import_structure["pipeline_stable_diffusion_paradigms"] = ["StableDiffusionParadigmsPipeline"]
_import_structure["pipeline_stable_diffusion_sag"] = ["StableDiffusionSAGPipeline"]
_import_structure["pipeline_stable_diffusion_upscale"] = ["StableDiffusionUpscalePipeline"]
_import_structure["pipeline_stable_diffusion_video"] = ["StableDiffusionVideoPipeline"]
_import_structure["pipeline_stable_unclip"] = ["StableUnCLIPPipeline"]
_import_structure["pipeline_stable_unclip_img2img"] = ["StableUnCLIPImg2ImgPipeline"]
_import_structure["safety_checker"] = ["StableDiffusionSafetyChecker"]
@@ -151,6 +152,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .pipeline_stable_diffusion_paradigms import StableDiffusionParadigmsPipeline
from .pipeline_stable_diffusion_sag import StableDiffusionSAGPipeline
from .pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline
from .pipeline_stable_diffusion_video import StableDiffusionVideoPipeline
from .pipeline_stable_unclip import StableUnCLIPPipeline
from .pipeline_stable_unclip_img2img import StableUnCLIPImg2ImgPipeline
from .safety_checker import StableDiffusionSafetyChecker

View File

@@ -0,0 +1,588 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from dataclasses import dataclass
from typing import Callable, List, Optional, Union
import numpy as np
import PIL.Image
import torch
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
from ...models.attention_processor import (
AttnProcessor2_0,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
)
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import BaseOutput, logging
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def _append_dims(x, target_dims):
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
dims_to_append = target_dims - x.ndim
if dims_to_append < 0:
raise ValueError(
f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
)
return x[(...,) + (None,) * dims_to_append]
def tensor2vid(video: torch.Tensor, processor, output_type="np"):
# Based on:
# https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
batch_size, channels, num_frames, height, width = video.shape
outputs = []
for batch_idx in range(batch_size):
batch_vid = video[batch_idx].permute(1, 0, 2, 3)
batch_output = processor.postprocess(batch_vid, output_type)
outputs.append(batch_output)
return outputs
@dataclass
class StableDiffusionVideoPipelineOutput(BaseOutput):
r"""
Output class for zero-shot text-to-video pipeline.
Args:
frames (`[List[PIL.Image.Image]`, `np.ndarray`]):
List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
num_channels)`.
"""
frames: Union[List[PIL.Image.Image], np.ndarray]
class StableDiffusionVideoPipeline(DiffusionPipeline):
r"""
Pipeline to generate video from an input image using Stable Video Diffusion.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
image_encoder ([`~transformers.CLIPVisionModelWithProjection`]):
Frozen CLIP image-encoder ([laion/CLIP-ViT-H-14-laion2B-s32B-b79K](https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K)).
unet ([`UNetSpatioTemporalConditionModel`]):
A `UNetSpatioTemporalConditionModel` to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images.
"""
model_cpu_offload_seq = "image_encoder->unet->vae"
def __init__(
self,
vae: AutoencoderKLTemporalDecoder,
image_encoder: CLIPVisionModelWithProjection,
unet: UNetSpatioTemporalConditionModel,
scheduler: KarrasDiffusionSchedulers,
feature_extractor: CLIPImageProcessor,
):
super().__init__()
self.register_modules(
vae=vae,
image_encoder=image_encoder,
unet=unet,
scheduler=scheduler,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
def _encode_image(
self, image, device, num_videos_per_prompt, do_classifier_free_guidance
):
dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(
images=image, return_tensors="pt"
).pixel_values
image = image.to(device=device, dtype=dtype)
image_embeddings = self.image_encoder(image).image_embeds
image_embeddings = image_embeddings.unsqueeze(1)
# duplicate image embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = image_embeddings.shape
image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1)
image_embeddings = image_embeddings.view(
bs_embed * num_videos_per_prompt, seq_len, -1
)
if do_classifier_free_guidance:
negative_image_embeddings = torch.zeros_like(image_embeddings)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
image_embeddings = torch.cat([negative_image_embeddings, image_embeddings])
return image_embeddings
def _encode_vae_image(
self,
image: torch.Tensor,
device,
num_videos_per_prompt,
do_classifier_free_guidance,
):
image = image.to(device=device)
image_latents = self.vae.encode(image).latent_dist.mode()
if do_classifier_free_guidance:
negative_image_latents = torch.zeros_like(image_latents)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
image_latents = torch.cat([negative_image_latents, image_latents])
# duplicate image_latents for each generation per prompt, using mps friendly method
image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1)
return image_latents
def _get_add_time_ids(
self,
fps_id,
motion_bucket_id,
cond_aug,
dtype,
batch_size,
num_videos_per_prompt,
do_classifier_free_guidance,
):
add_time_ids = [fps_id, motion_bucket_id, cond_aug]
passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(
add_time_ids
)
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
if expected_add_embed_dim != passed_add_embed_dim:
raise ValueError(
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
)
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1)
if do_classifier_free_guidance:
add_time_ids = torch.cat([add_time_ids, add_time_ids])
return add_time_ids
def decode_latents(self, latents, num_frames, decoding_t=14):
# [batch, frames, channels, height, width] -> [batch*frames, channels, height, width]
latents = latents.flatten(0, 1)
latents = 1 / self.vae.config.scaling_factor * latents
# decode decoding_t frames at a time to avoid OOM
frames = []
for i in range(0, latents.shape[0], decoding_t):
num_frames_in = latents[i : i + decoding_t].shape[0]
frame = self.vae.decode(latents[i : i + decoding_t], num_frames_in).sample
frames.append(frame)
frames = torch.cat(frames, dim=0)
# [batch*frames, channels, height, width] -> [batch, channels, frames, height, width]
frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(
0, 2, 1, 3, 4
)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
frames = frames.float()
return frames
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(
inspect.signature(self.scheduler.step).parameters.keys()
)
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(
inspect.signature(self.scheduler.step).parameters.keys()
)
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def check_inputs(self, image, height, width, callback_steps):
if (
not isinstance(image, torch.Tensor)
and not isinstance(image, PIL.Image.Image)
and not isinstance(image, list)
):
raise ValueError(
"`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
f" {type(image)}"
)
if height % 8 != 0 or width % 8 != 0:
raise ValueError(
f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
)
if (callback_steps is None) or (
callback_steps is not None
and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
def prepare_latents(
self,
batch_size,
num_frames,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
):
shape = (
batch_size,
num_frames,
num_channels_latents // 2,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(
shape, generator=generator, device=device, dtype=dtype
)
else:
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
The suffixes after the scaling factors represent the stages where they are being applied.
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
Args:
s1 (`float`):
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
mitigate "oversmoothing effect" in the enhanced denoising process.
s2 (`float`):
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
mitigate "oversmoothing effect" in the enhanced denoising process.
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
"""
if not hasattr(self, "unet"):
raise ValueError("The pipeline must have `unet` for using FreeU.")
self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu
def disable_freeu(self):
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
@torch.no_grad()
def __call__(
self,
image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
height: int = 576,
width: int = 1024,
num_frames: int = 14,
num_inference_steps: int = 50,
min_guidance_scale: float = 1.0,
max_guidance_scale: float = 2.5,
fps_id: int = 6,
motion_bucket_id: int = 127,
cond_aug: int = 0.02,
decoding_t: int = 14,
num_videos_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
):
r"""
The call function to the pipeline for generation.
Args:
image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
Image or images to guide image generation. If you provide a tensor, it needs to be compatible with
[`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json).
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
The width in pixels of the generated image.
num_frames (`int`, *optional*, defaults to 14):
The number of video frames to generate.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference. This parameter is modulated by `strength`.
min_guidance_scale (`float`, *optional*, defaults to 1.0):
The minimum guidance scale. Used for the classifier free guidance with first frame.
max_guidance_scale (`float`, *optional*, defaults to 2.5):
The maximum guidance scale. Used for the classifier free guidance with last frame.
fps_id (`int`, *optional*, defaults to 6):
The frame rate ID. Used as conditioning for the generation.
motion_bucket_id (`int`, *optional*, defaults to 127):
The motion bucket ID. Used as conditioning for the generation. The higher the number the more motion will be in the video.
cond_aug (`int`, *optional*, defaults to 0.02):
The amount of noise added to the init image, the higher it is the less the video will look like the init image. Increase it for more motion.
decoding_t (`int`, *optional*, defaults to 14):
The number of frames to decode at a time. This is used to avoid OOM errors.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor is generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that calls every `callback_steps` steps during inference. The function is called with the
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function is called. If not specified, the callback is called at
every step.
Returns:
[`~pipelines.stable_diffusion.StableDiffusionVideoPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionVideoPipelineOutput`] is returned,
otherwise a `tuple` is returned where the first element is a list of list with the generated frames.
Examples:
```py
from diffusers import StableDiffusionVideoPipeline
from diffusers.utils import load_image, export_to_video
pipe = StableDiffusionVideoPipeline.from_pretrained("stabilityai/stable-video-diffusion-img2vid", torch_dtype=torch.float16, variant="fp16")
pipe.to("cuda")
image = load_image("https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200")
image = image.resize((1024, 576))
frames = pipe(image, num_frames=14, decoding_t=8).frames[0]
export_to_video(frames, "generated.mp4", fps=7)
```
"""
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(image, height, width, callback_steps)
# 2. Define call parameters
if isinstance(image, PIL.Image.Image):
batch_size = 1
elif isinstance(image, list):
batch_size = len(image)
else:
batch_size = image.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = max_guidance_scale > 1.0
# 3. Encode input image
image_embeddings = self._encode_image(
image, device, num_videos_per_prompt, do_classifier_free_guidance
)
# 4. Encode input image using VAE
image = self.image_processor.preprocess(image, height=height, width=width)
image = image + cond_aug * torch.randn_like(image)
needs_upcasting = (
self.vae.dtype == torch.float16 and self.vae.config.force_upcast
)
if needs_upcasting:
self.vae.to(dtype=torch.float32)
image_latents = self._encode_vae_image(
image, device, num_videos_per_prompt, do_classifier_free_guidance
)
image_latents = image_latents.to(image_embeddings.dtype)
# cast back to fp16 if needed
if needs_upcasting:
self.vae.to(dtype=torch.float16)
# Repeat the image latents for each frame so we can concatenate them with the noise
# image_latents [batch, channels, height, width] ->[batch, num_frames, channels, height, width]
image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)
# 5. Get Added Time IDs
added_time_ids = self._get_add_time_ids(
fps_id,
motion_bucket_id,
cond_aug,
image_embeddings.dtype,
batch_size,
num_videos_per_prompt,
do_classifier_free_guidance,
)
added_time_ids = added_time_ids.to(device)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
num_frames,
num_channels_latents,
height,
width,
image_embeddings.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Prepare guidance scale
guidance_scale = torch.linspace(
min_guidance_scale, max_guidance_scale, num_frames
).unsqueeze(0)
guidance_scale = guidance_scale.to(device, latents.dtype)
guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1)
guidance_scale = _append_dims(guidance_scale, latents.ndim)
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = (
torch.cat([latents] * 2) if do_classifier_free_guidance else latents
)
latent_model_input = self.scheduler.scale_model_input(
latent_model_input, t
)
# Concatenate image_latents over channels dimention
latent_model_input = torch.cat(
[latent_model_input, image_latents], dim=2
)
# predict the noise residual
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=image_embeddings,
added_time_ids=added_time_ids,
).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_cond - noise_pred_uncond
)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(
noise_pred, t, latents, **extra_step_kwargs
).prev_sample
# call the callback, if provided
if i == len(timesteps) - 1 or (
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
frames = self.decode_latents(latents, num_frames, decoding_t)
# cast back to fp16 if needed
if needs_upcasting:
self.vae.to(dtype=torch.float16)
if not output_type == "latent":
frames = tensor2vid(frames, self.image_processor, output_type=output_type)
else:
frames = latents
self.maybe_free_model_hooks()
if not return_dict:
return frames
return StableDiffusionVideoPipelineOutput(frames=frames)

View File

@@ -323,8 +323,20 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
sigma_min: float = in_sigmas[-1].item()
sigma_max: float = in_sigmas[0].item()
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)

View File

@@ -358,8 +358,20 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
sigma_min: float = in_sigmas[-1].item()
sigma_max: float = in_sigmas[0].item()
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)

View File

@@ -358,8 +358,20 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
sigma_min: float = in_sigmas[-1].item()
sigma_max: float = in_sigmas[0].item()
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)

View File

@@ -357,8 +357,20 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
sigma_min: float = in_sigmas[-1].item()
sigma_max: float = in_sigmas[0].item()
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)

View File

@@ -144,7 +144,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
prediction_type: str = "epsilon",
interpolation_type: str = "linear",
use_karras_sigmas: Optional[bool] = False,
sigma_min: Optional[float] = None,
sigma_max: Optional[float] = None,
timestep_spacing: str = "linspace",
timestep_type: str = "discrete", # can be "discrete" or "continuous"
steps_offset: int = 0,
):
if trained_betas is not None:
@@ -268,6 +271,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
# when timestep_type is continuous, we need to convert the timesteps to continuous values using c_noise
if self.config.timestep_type == "continuous":
timesteps = np.array([self.get_scalings(sigma)[-1] for sigma in sigmas])
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
self.sigmas = torch.from_numpy(sigmas).to(device=device)
@@ -301,8 +308,20 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
sigma_min: float = in_sigmas[-1].item()
sigma_max: float = in_sigmas[0].item()
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)
@@ -311,6 +330,33 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas
def get_scalings(self, sigma=None):
"""
Get the scalings for the current timestep.
Returns:
`torch.FloatTensor`:
The scaling factors for the current timestep.
"""
if sigma is None:
sigma = self.sigmas[self.step_index]
if self.config.prediction_type == "epsilon":
c_skip = torch.ones_like(sigma, device=sigma.device)
c_out = -sigma
c_in = 1 / (sigma**2 + 1.0) ** 0.5
c_noise = sigma.clone()
elif self.config.prediction_type == "v_prediction":
c_skip = 1.0 / (sigma**2 + 1.0)
c_out = -sigma / (sigma**2 + 1.0) ** 0.5
c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
c_noise = 0.25 * math.log(sigma)
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
)
return c_skip, c_out, c_in, c_noise
def _init_step_index(self, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
@@ -391,6 +437,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
self._init_step_index(timestep)
sigma = self.sigmas[self.step_index]
c_skip, c_out, _, _ = self.get_scalings()
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
@@ -412,8 +459,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
elif self.config.prediction_type == "epsilon":
pred_original_sample = sample - sigma_hat * model_output
elif self.config.prediction_type == "v_prediction":
# * c_out + input * c_skip
pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
pred_original_sample = model_output * c_out + sample * c_skip
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"

View File

@@ -303,8 +303,20 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
sigma_min: float = in_sigmas[-1].item()
sigma_max: float = in_sigmas[0].item()
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)

View File

@@ -324,8 +324,20 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
sigma_min: float = in_sigmas[-1].item()
sigma_max: float = in_sigmas[0].item()
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)

View File

@@ -335,8 +335,20 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
sigma_min: float = in_sigmas[-1].item()
sigma_max: float = in_sigmas[0].item()
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)

View File

@@ -337,8 +337,20 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
sigma_min: float = in_sigmas[-1].item()
sigma_max: float = in_sigmas[0].item()
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)

View File

@@ -32,6 +32,21 @@ class AutoencoderKL(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class AutoencoderKLTemporalDecoder(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class AutoencoderTiny(metaclass=DummyObject):
_backends = ["torch"]
@@ -272,6 +287,21 @@ class UNetMotionModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class UNetSpatioTemporalConditionModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class VQModel(metaclass=DummyObject):
_backends = ["torch"]

View File

@@ -1022,6 +1022,21 @@ class StableDiffusionUpscalePipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class StableDiffusionVideoPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class StableDiffusionXLAdapterPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

@@ -3,7 +3,7 @@ import random
import struct
import tempfile
from contextlib import contextmanager
from typing import List
from typing import List, Union
import numpy as np
import PIL.Image
@@ -115,7 +115,9 @@ def export_to_obj(mesh, output_obj_path: str = None):
f.writelines("\n".join(combined_data))
def export_to_video(video_frames: List[np.ndarray], output_video_path: str = None) -> str:
def export_to_video(
video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 8
) -> str:
if is_opencv_available():
import cv2
else:
@@ -123,9 +125,12 @@ def export_to_video(video_frames: List[np.ndarray], output_video_path: str = Non
if output_video_path is None:
output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name
if isinstance(video_frames[0], PIL.Image.Image):
video_frames = [np.array(frame) for frame in video_frames]
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
h, w, c = video_frames[0].shape
video_writer = cv2.VideoWriter(output_video_path, fourcc, fps=8, frameSize=(w, h))
video_writer = cv2.VideoWriter(output_video_path, fourcc, fps=fps, frameSize=(w, h))
for i in range(len(video_frames)):
img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR)
video_writer.write(img)

View File

@@ -0,0 +1,354 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import unittest
import torch
from diffusers import UNetSpatioTemporalConditionModel
from diffusers.utils import logging
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import (
enable_full_determinism,
floats_tensor,
torch_all_close,
torch_device,
)
from .test_modeling_common import ModelTesterMixin, UNetTesterMixin
logger = logging.get_logger(__name__)
enable_full_determinism()
class UNetSpatioTemporalConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = UNetSpatioTemporalConditionModel
main_input_name = "sample"
@property
def dummy_input(self):
batch_size = 4
num_channels = 4
sizes = (32, 32)
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device)
encoder_hidden_states = floats_tensor((batch_size, 4, 32)).to(torch_device)
return {
"sample": noise,
"timestep": time_step,
"encoder_hidden_states": encoder_hidden_states,
}
@property
def input_shape(self):
return (4, 32, 32)
@property
def output_shape(self):
return (4, 32, 32)
@property
def fps_id(self):
return 6
@property
def motion_bucket_id(self):
return 127
@property
def cond_aug(self):
return 0.02
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"block_out_channels": (32, 64),
"down_block_types": (
"CrossAttnDownBlockSpatioTemporal",
"DownBlockSpatioTemporal",
),
"up_block_types": (
"UpBlockSpatioTemporal",
"CrossAttnUpBlockSpatioTemporal",
),
"cross_attention_dim": 32,
"attention_head_dim": 8,
"out_channels": 4,
"in_channels": 4,
"layers_per_block": 2,
"sample_size": 32,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_enable_works(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.enable_xformers_memory_efficient_attention()
assert (
model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
== "XFormersAttnProcessor"
), "xformers is not enabled"
@unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
def test_gradient_checkpointing(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
assert not model.is_gradient_checkpointing and model.training
out = model(**inputs_dict).sample
# run the backwards pass on the model. For backwards pass, for simplicity purpose,
# we won't calculate the loss and rather backprop on out.sum()
model.zero_grad()
labels = torch.randn_like(out)
loss = (out - labels).mean()
loss.backward()
# re-instantiate the model now enabling gradient checkpointing
model_2 = self.model_class(**init_dict)
# clone model
model_2.load_state_dict(model.state_dict())
model_2.to(torch_device)
model_2.enable_gradient_checkpointing()
assert model_2.is_gradient_checkpointing and model_2.training
out_2 = model_2(**inputs_dict).sample
# run the backwards pass on the model. For backwards pass, for simplicity purpose,
# we won't calculate the loss and rather backprop on out.sum()
model_2.zero_grad()
loss_2 = (out_2 - labels).mean()
loss_2.backward()
# compare the output and parameters gradients
self.assertTrue((loss - loss_2).abs() < 1e-5)
named_params = dict(model.named_parameters())
named_params_2 = dict(model_2.named_parameters())
for name, param in named_params.items():
self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
def test_model_with_attention_head_dim_tuple(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["attention_head_dim"] = (8, 16)
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
output = model(**inputs_dict)
if isinstance(output, dict):
output = output.sample
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_model_with_use_linear_projection(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["use_linear_projection"] = True
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
output = model(**inputs_dict)
if isinstance(output, dict):
output = output.sample
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_model_with_cross_attention_dim_tuple(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["cross_attention_dim"] = (32, 32)
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
output = model(**inputs_dict)
if isinstance(output, dict):
output = output.sample
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_model_with_simple_projection(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
batch_size, _, _, sample_size = inputs_dict["sample"].shape
init_dict["class_embed_type"] = "simple_projection"
init_dict["projection_class_embeddings_input_dim"] = sample_size
inputs_dict["class_labels"] = floats_tensor((batch_size, sample_size)).to(torch_device)
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
output = model(**inputs_dict)
if isinstance(output, dict):
output = output.sample
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_model_with_class_embeddings_concat(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
batch_size, _, _, sample_size = inputs_dict["sample"].shape
init_dict["class_embed_type"] = "simple_projection"
init_dict["projection_class_embeddings_input_dim"] = sample_size
init_dict["class_embeddings_concat"] = True
inputs_dict["class_labels"] = floats_tensor((batch_size, sample_size)).to(torch_device)
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
output = model(**inputs_dict)
if isinstance(output, dict):
output = output.sample
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_model_attention_slicing(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["attention_head_dim"] = (8, 16)
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
model.set_attention_slice("auto")
with torch.no_grad():
output = model(**inputs_dict)
assert output is not None
model.set_attention_slice("max")
with torch.no_grad():
output = model(**inputs_dict)
assert output is not None
model.set_attention_slice(2)
with torch.no_grad():
output = model(**inputs_dict)
assert output is not None
def test_model_sliceable_head_dim(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["attention_head_dim"] = (8, 16)
model = self.model_class(**init_dict)
def check_sliceable_dim_attr(module: torch.nn.Module):
if hasattr(module, "set_attention_slice"):
assert isinstance(module.sliceable_head_dim, int)
for child in module.children():
check_sliceable_dim_attr(child)
# retrieve number of attention layers
for module in model.children():
check_sliceable_dim_attr(module)
def test_gradient_checkpointing_is_applied(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["attention_head_dim"] = (8, 16)
model_class_copy = copy.copy(self.model_class)
modules_with_gc_enabled = {}
# now monkey patch the following function:
# def _set_gradient_checkpointing(self, module, value=False):
# if hasattr(module, "gradient_checkpointing"):
# module.gradient_checkpointing = value
def _set_gradient_checkpointing_new(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
modules_with_gc_enabled[module.__class__.__name__] = True
model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new
model = model_class_copy(**init_dict)
model.enable_gradient_checkpointing()
EXPECTED_SET = {
"CrossAttnUpBlock2D",
"CrossAttnDownBlock2D",
"UNetMidBlock2DCrossAttn",
"UpBlock2D",
"Transformer2DModel",
"DownBlock2D",
}
assert set(modules_with_gc_enabled.keys()) == EXPECTED_SET
assert all(modules_with_gc_enabled.values()), "All modules should be enabled"
def test_pickle(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["attention_head_dim"] = (8, 16)
model = self.model_class(**init_dict)
model.to(torch_device)
with torch.no_grad():
sample = model(**inputs_dict).sample
sample_copy = copy.copy(sample)
assert (sample - sample_copy).abs().max() < 1e-4

View File

@@ -23,6 +23,7 @@ from parameterized import parameterized
from diffusers import (
AsymmetricAutoencoderKL,
AutoencoderKL,
AutoencoderKLTemporalDecoder,
AutoencoderTiny,
ConsistencyDecoderVAE,
StableDiffusionPipeline,
@@ -195,10 +196,16 @@ class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
named_params = dict(model.named_parameters())
named_params_2 = dict(model_2.named_parameters())
for name, param in named_params.items():
self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
self.assertTrue(
torch_all_close(
param.grad.data, named_params_2[name].grad.data, atol=5e-5
)
)
def test_from_pretrained_hub(self):
model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True)
model, loading_info = AutoencoderKL.from_pretrained(
"fusing/autoencoder-kl-dummy", output_loading_info=True
)
self.assertIsNotNone(model)
self.assertEqual(len(loading_info["missing_keys"]), 0)
@@ -248,17 +255,39 @@ class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
)
elif torch_device == "cpu":
expected_output_slice = torch.tensor(
[-0.1352, 0.0878, 0.0419, -0.0818, -0.1069, 0.0688, -0.1458, -0.4446, -0.0026]
[
-0.1352,
0.0878,
0.0419,
-0.0818,
-0.1069,
0.0688,
-0.1458,
-0.4446,
-0.0026,
]
)
else:
expected_output_slice = torch.tensor(
[-0.2421, 0.4642, 0.2507, -0.0438, 0.0682, 0.3160, -0.2018, -0.0727, 0.2485]
[
-0.2421,
0.4642,
0.2507,
-0.0438,
0.0682,
0.3160,
-0.2018,
-0.0727,
0.2485,
]
)
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
class AsymmetricAutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class AsymmetricAutoencoderKLTests(
ModelTesterMixin, UNetTesterMixin, unittest.TestCase
):
model_class = AsymmetricAutoencoderKL
main_input_name = "sample"
base_precision = 1e-2
@@ -336,7 +365,9 @@ class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase):
generator = torch.Generator("cpu")
if seed is not None:
generator.manual_seed(0)
image = randn_tensor((4, 3, 32, 32), generator=generator, device=torch.device(torch_device))
image = randn_tensor(
(4, 3, 32, 32), generator=generator, device=torch.device(torch_device)
)
return {"sample": image, "generator": generator}
@@ -364,6 +395,98 @@ class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase):
...
class AutoncoderKLTemporalDecoderFastTests(ModelTesterMixin, unittest.TestCase):
model_class = AutoencoderKLTemporalDecoder
main_input_name = "sample"
base_precision = 1e-2
@property
def dummy_input(self):
batch_size = 3
num_channels = 3
sizes = (32, 32)
image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
num_frames = 3
return {"sample": image, "num_frames": num_frames}
@property
def input_shape(self):
return (3, 32, 32)
@property
def output_shape(self):
return (3, 32, 32)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"block_out_channels": [8, 16],
"in_channels": 3,
"out_channels": 3,
"down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
"latent_channels": 4,
"norm_num_groups": 4,
"layers_per_block": 2,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_forward_signature(self):
pass
def test_training(self):
pass
@unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
def test_gradient_checkpointing(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
assert not model.is_gradient_checkpointing and model.training
out = model(**inputs_dict).sample
# run the backwards pass on the model. For backwards pass, for simplicity purpose,
# we won't calculate the loss and rather backprop on out.sum()
model.zero_grad()
labels = torch.randn_like(out)
loss = (out - labels).mean()
loss.backward()
# re-instantiate the model now enabling gradient checkpointing
model_2 = self.model_class(**init_dict)
# clone model
model_2.load_state_dict(model.state_dict())
model_2.to(torch_device)
model_2.enable_gradient_checkpointing()
assert model_2.is_gradient_checkpointing and model_2.training
out_2 = model_2(**inputs_dict).sample
# run the backwards pass on the model. For backwards pass, for simplicity purpose,
# we won't calculate the loss and rather backprop on out.sum()
model_2.zero_grad()
loss_2 = (out_2 - labels).mean()
loss_2.backward()
# compare the output and parameters gradients
self.assertTrue((loss - loss_2).abs() < 1e-5)
named_params = dict(model.named_parameters())
named_params_2 = dict(model_2.named_parameters())
for name, param in named_params.items():
if "post_quant_conv" in name:
continue
self.assertTrue(
torch_all_close(
param.grad.data, named_params_2[name].grad.data, atol=5e-5
)
)
@slow
class AutoencoderTinyIntegrationTests(unittest.TestCase):
def tearDown(self):
@@ -377,10 +500,16 @@ class AutoencoderTinyIntegrationTests(unittest.TestCase):
def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
dtype = torch.float16 if fp16 else torch.float32
image = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
image = (
torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape)))
.to(torch_device)
.to(dtype)
)
return image
def get_sd_vae_model(self, model_id="hf-internal-testing/taesd-diffusers", fp16=False):
def get_sd_vae_model(
self, model_id="hf-internal-testing/taesd-diffusers", fp16=False
):
torch_dtype = torch.float16 if fp16 else torch.float32
model = AutoencoderTiny.from_pretrained(model_id, torch_dtype=torch_dtype)
@@ -414,7 +543,9 @@ class AutoencoderTinyIntegrationTests(unittest.TestCase):
assert sample.shape == image.shape
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
expected_output_slice = torch.tensor([0.0093, 0.6385, -0.1274, 0.1631, -0.1762, 0.5232, -0.3108, -0.0382])
expected_output_slice = torch.tensor(
[0.0093, 0.6385, -0.1274, 0.1631, -0.1762, 0.5232, -0.3108, -0.0382]
)
assert torch_all_close(output_slice, expected_output_slice, atol=3e-3)
@@ -454,7 +585,11 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
dtype = torch.float16 if fp16 else torch.float32
image = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
image = (
torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape)))
.to(torch_device)
.to(dtype)
)
return image
def get_sd_vae_model(self, model_id="CompVis/stable-diffusion-v1-4", fp16=False):
@@ -503,7 +638,9 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
assert sample.shape == image.shape
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
expected_output_slice = torch.tensor(expected_slice_mps if torch_device == "mps" else expected_slice)
expected_output_slice = torch.tensor(
expected_slice_mps if torch_device == "mps" else expected_slice
)
assert torch_all_close(output_slice, expected_output_slice, atol=3e-3)
@@ -557,7 +694,9 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
assert sample.shape == image.shape
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
expected_output_slice = torch.tensor(expected_slice_mps if torch_device == "mps" else expected_slice)
expected_output_slice = torch.tensor(
expected_slice_mps if torch_device == "mps" else expected_slice
)
assert torch_all_close(output_slice, expected_output_slice, atol=3e-3)
@@ -609,7 +748,10 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
@parameterized.expand([(13,), (16,), (27,)])
@require_torch_gpu
@unittest.skipIf(not is_xformers_available(), reason="xformers is not required when using PyTorch 2.0.")
@unittest.skipIf(
not is_xformers_available(),
reason="xformers is not required when using PyTorch 2.0.",
)
def test_stable_diffusion_decode_xformers_vs_2_0_fp16(self, seed):
model = self.get_sd_vae_model(fp16=True)
encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64), fp16=True)
@@ -627,7 +769,10 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
@parameterized.expand([(13,), (16,), (37,)])
@require_torch_gpu
@unittest.skipIf(not is_xformers_available(), reason="xformers is not required when using PyTorch 2.0.")
@unittest.skipIf(
not is_xformers_available(),
reason="xformers is not required when using PyTorch 2.0.",
)
def test_stable_diffusion_decode_xformers_vs_2_0(self, seed):
model = self.get_sd_vae_model()
encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))
@@ -660,7 +805,9 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
dist = model.encode(image).latent_dist
sample = dist.sample(generator=generator)
assert list(sample.shape) == [image.shape[0], 4] + [i // 8 for i in image.shape[2:]]
assert list(sample.shape) == [image.shape[0], 4] + [
i // 8 for i in image.shape[2:]
]
output_slice = sample[0, -1, -3:, -3:].flatten().cpu()
expected_output_slice = torch.tensor(expected_slice)
@@ -701,10 +848,16 @@ class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
dtype = torch.float16 if fp16 else torch.float32
image = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
image = (
torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape)))
.to(torch_device)
.to(dtype)
)
return image
def get_sd_vae_model(self, model_id="cross-attention/asymmetric-autoencoder-kl-x-1-5", fp16=False):
def get_sd_vae_model(
self, model_id="cross-attention/asymmetric-autoencoder-kl-x-1-5", fp16=False
):
revision = "main"
torch_dtype = torch.float32
@@ -749,7 +902,9 @@ class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
assert sample.shape == image.shape
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
expected_output_slice = torch.tensor(expected_slice_mps if torch_device == "mps" else expected_slice)
expected_output_slice = torch.tensor(
expected_slice_mps if torch_device == "mps" else expected_slice
)
assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)
@@ -779,7 +934,9 @@ class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
assert sample.shape == image.shape
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
expected_output_slice = torch.tensor(expected_slice_mps if torch_device == "mps" else expected_slice)
expected_output_slice = torch.tensor(
expected_slice_mps if torch_device == "mps" else expected_slice
)
assert torch_all_close(output_slice, expected_output_slice, atol=3e-3)
@@ -808,7 +965,10 @@ class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
@parameterized.expand([(13,), (16,), (37,)])
@require_torch_gpu
@unittest.skipIf(not is_xformers_available(), reason="xformers is not required when using PyTorch 2.0.")
@unittest.skipIf(
not is_xformers_available(),
reason="xformers is not required when using PyTorch 2.0.",
)
def test_stable_diffusion_decode_xformers_vs_2_0(self, seed):
model = self.get_sd_vae_model()
encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))
@@ -841,7 +1001,9 @@ class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
dist = model.encode(image).latent_dist
sample = dist.sample(generator=generator)
assert list(sample.shape) == [image.shape[0], 4] + [i // 8 for i in image.shape[2:]]
assert list(sample.shape) == [image.shape[0], 4] + [
i // 8 for i in image.shape[2:]
]
output_slice = sample[0, -1, -3:, -3:].flatten().cpu()
expected_output_slice = torch.tensor(expected_slice)
@@ -860,37 +1022,52 @@ class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
@torch.no_grad()
def test_encode_decode(self):
vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder") # TODO - update
vae = ConsistencyDecoderVAE.from_pretrained(
"openai/consistency-decoder"
) # TODO - update
vae.to(torch_device)
image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/img2img/sketch-mountains-input.jpg"
).resize((256, 256))
image = torch.from_numpy(np.array(image).transpose(2, 0, 1).astype(np.float32) / 127.5 - 1)[
None, :, :, :
].cuda()
image = torch.from_numpy(
np.array(image).transpose(2, 0, 1).astype(np.float32) / 127.5 - 1
)[None, :, :, :].cuda()
latent = vae.encode(image).latent_dist.mean
sample = vae.decode(latent, generator=torch.Generator("cpu").manual_seed(0)).sample
sample = vae.decode(
latent, generator=torch.Generator("cpu").manual_seed(0)
).sample
actual_output = sample[0, :2, :2, :2].flatten().cpu()
expected_output = torch.tensor([-0.0141, -0.0014, 0.0115, 0.0086, 0.1051, 0.1053, 0.1031, 0.1024])
expected_output = torch.tensor(
[-0.0141, -0.0014, 0.0115, 0.0086, 0.1051, 0.1053, 0.1031, 0.1024]
)
assert torch_all_close(actual_output, expected_output, atol=5e-3)
def test_sd(self):
vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder") # TODO - update
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", vae=vae, safety_checker=None)
vae = ConsistencyDecoderVAE.from_pretrained(
"openai/consistency-decoder"
) # TODO - update
pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", vae=vae, safety_checker=None
)
pipe.to(torch_device)
out = pipe(
"horse", num_inference_steps=2, output_type="pt", generator=torch.Generator("cpu").manual_seed(0)
"horse",
num_inference_steps=2,
output_type="pt",
generator=torch.Generator("cpu").manual_seed(0),
).images[0]
actual_output = out[:2, :2, :2].flatten().cpu()
expected_output = torch.tensor([0.7686, 0.8228, 0.6489, 0.7455, 0.8661, 0.8797, 0.8241, 0.8759])
expected_output = torch.tensor(
[0.7686, 0.8228, 0.6489, 0.7455, 0.8661, 0.8797, 0.8241, 0.8759]
)
assert torch_all_close(actual_output, expected_output, atol=5e-3)
@@ -905,18 +1082,23 @@ class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
"/img2img/sketch-mountains-input.jpg"
).resize((256, 256))
image = (
torch.from_numpy(np.array(image).transpose(2, 0, 1).astype(np.float32) / 127.5 - 1)[None, :, :, :]
torch.from_numpy(
np.array(image).transpose(2, 0, 1).astype(np.float32) / 127.5 - 1
)[None, :, :, :]
.half()
.cuda()
)
latent = vae.encode(image).latent_dist.mean
sample = vae.decode(latent, generator=torch.Generator("cpu").manual_seed(0)).sample
sample = vae.decode(
latent, generator=torch.Generator("cpu").manual_seed(0)
).sample
actual_output = sample[0, :2, :2, :2].flatten().cpu()
expected_output = torch.tensor(
[-0.0111, -0.0125, -0.0017, -0.0007, 0.1257, 0.1465, 0.1450, 0.1471], dtype=torch.float16
[-0.0111, -0.0125, -0.0017, -0.0007, 0.1257, 0.1465, 0.1450, 0.1471],
dtype=torch.float16,
)
assert torch_all_close(actual_output, expected_output, atol=5e-3)
@@ -926,17 +1108,24 @@ class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
"openai/consistency-decoder", torch_dtype=torch.float16
) # TODO - update
pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, vae=vae, safety_checker=None
"runwayml/stable-diffusion-v1-5",
torch_dtype=torch.float16,
vae=vae,
safety_checker=None,
)
pipe.to(torch_device)
out = pipe(
"horse", num_inference_steps=2, output_type="pt", generator=torch.Generator("cpu").manual_seed(0)
"horse",
num_inference_steps=2,
output_type="pt",
generator=torch.Generator("cpu").manual_seed(0),
).images[0]
actual_output = out[:2, :2, :2].flatten().cpu()
expected_output = torch.tensor(
[0.0000, 0.0249, 0.0000, 0.0000, 0.1709, 0.2773, 0.0471, 0.1035], dtype=torch.float16
[0.0000, 0.0249, 0.0000, 0.0000, 0.1709, 0.2773, 0.0471, 0.1035],
dtype=torch.float16,
)
assert torch_all_close(actual_output, expected_output, atol=5e-3)

View File

@@ -0,0 +1,261 @@
import gc
import random
import unittest
import numpy as np
import torch
from transformers import (
CLIPImageProcessor,
CLIPTextConfig,
CLIPTextModel,
CLIPTokenizer,
)
import diffusers
from diffusers import (
AutoencoderKLTemporalDecoder,
DDIMScheduler,
StableDiffusionVideoPipeline,
UNetSpatioTemporalConditionModel,
)
from diffusers.utils import load_image, logging
from diffusers.utils.testing_utils import (
floats_tensor,
numpy_cosine_similarity_distance,
require_torch_gpu,
slow,
torch_device,
)
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
def to_np(tensor):
if isinstance(tensor, torch.Tensor):
tensor = tensor.detach().cpu().numpy()
return tensor
class StableDiffusionVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = StableDiffusionVideoPipeline
params = TEXT_TO_IMAGE_PARAMS
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
required_optional_params = frozenset(
[
"num_inference_steps",
"generator",
"latents",
"return_dict",
"callback",
"callback_steps",
]
)
def get_dummy_components(self):
torch.manual_seed(0)
unet = UNetSpatioTemporalConditionModel(
block_out_channels=(32, 64),
layers_per_block=2,
sample_size=32,
in_channels=4,
out_channels=4,
down_block_types=(
"CrossAttnDownBlockSpatioTemporal",
"DownBlockSpatioTemporal",
),
up_block_types=("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal"),
cross_attention_dim=32,
norm_num_groups=2,
)
scheduler = DDIMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="linear",
clip_sample=False,
)
torch.manual_seed(0)
vae = AutoencoderKLTemporalDecoder(
block_out_channels=[32, 64],
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
latent_channels=4,
)
torch.manual_seed(0)
text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
)
text_encoder = CLIPTextModel(text_encoder_config)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
feature_extractor = CLIPImageProcessor(crop_size=32, size=32)
components = {
"unet": unet,
"scheduler": scheduler,
"vae": vae,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"feature_extractor": feature_extractor,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
image = image / 2 + 0.5
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"generator": generator,
"image": image,
"num_inference_steps": 2,
"guidance_scale": 7.5,
"output_type": "pt",
}
return inputs
@unittest.skip("Attention slicing is not enabled in this pipeline")
def test_attention_slicing_forward_pass(self):
pass
def test_inference_batch_single_identical(
self,
batch_size=2,
expected_max_diff=1e-4,
additional_params_copy_to_batched_inputs=["num_inference_steps"],
):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
for components in pipe.components.values():
if hasattr(components, "set_default_attn_processor"):
components.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
# Reset generator in case it is has been used in self.get_dummy_inputs
inputs["generator"] = self.get_generator(0)
logger = logging.get_logger(pipe.__module__)
logger.setLevel(level=diffusers.logging.FATAL)
# batchify inputs
batched_inputs = {}
batched_inputs.update(inputs)
for name in self.batch_params:
if name not in inputs:
continue
value = inputs[name]
if name == "prompt":
len_prompt = len(value)
batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
batched_inputs[name][-1] = 100 * "very long"
else:
batched_inputs[name] = batch_size * [value]
if "generator" in inputs:
batched_inputs["generator"] = [self.get_generator(i) for i in range(batch_size)]
if "batch_size" in inputs:
batched_inputs["batch_size"] = batch_size
for arg in additional_params_copy_to_batched_inputs:
batched_inputs[arg] = inputs[arg]
output = pipe(**inputs)
output_batch = pipe(**batched_inputs)
assert output_batch[0].shape[0] == batch_size
max_diff = np.abs(to_np(output_batch[0][0]) - to_np(output[0][0])).max()
assert max_diff < expected_max_diff
@unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices")
def test_to_device(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
pipe.to("cpu")
model_devices = [
component.device.type for component in pipe.components.values() if hasattr(component, "device")
]
self.assertTrue(all(device == "cpu" for device in model_devices))
output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]
self.assertTrue(np.isnan(output_cpu).sum() == 0)
pipe.to("cuda")
model_devices = [
component.device.type for component in pipe.components.values() if hasattr(component, "device")
]
self.assertTrue(all(device == "cuda" for device in model_devices))
output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0]
self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0)
def test_to_dtype(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
pipe.to(torch_dtype=torch.float16)
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
@slow
@require_torch_gpu
class StableDiffusionVideoPipelineSlowTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_sd_video(self):
pipe = StableDiffusionVideoPipeline.from_pretrained("diffusers/svd-test")
pipe = pipe.to(torch_device)
pipe.enable_model_cpu_offload()
pipe.set_progress_bar_config(disable=None)
image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat_6.png?download=true"
)
generator = torch.Generator("cpu").manual_seed(0)
num_frames = 3
output = pipe(
image=image,
num_frames=num_frames,
generator=generator,
num_inference_steps=3,
output_type="np",
)
image = output.frames[0]
assert image.shape == (num_frames, 576, 1024, 3)
image_slice = image[0, -3:, -3:, -1]
expected_slice = np.array([0.8592, 0.8645, 0.8499, 0.8722, 0.8769, 0.8421, 0.8557, 0.8528, 0.8285])
assert numpy_cosine_similarity_distance(image_slice.flatten(), expected_slice.flatten()) < 1e-3