Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-12 23:44:30 +08:00)

Compare commits: hidream-im ... cog-test (97 Commits)
Commits (SHA1): 78c125f76f, 9e3b0fe107, 878f609aa5, 32da2e7673, 9a0b906518, b3428ad5f5, 70a54a8230, 6f4e60b58b, e4d65ccdd7, 22dcceb858,
9c6b8894ff, cf7369d418, 123ecef2b9, 511c9ef560, fb6130fe90, 2d9602cc96, 1b1b737acb, 311845fc77, 01c2dff338, 03580c07b9,
fd11c0fbee, 92c8c00756, 5781e017dd, 90aa8be534, 2f1b7870e2, 7360ea1d03, ba1855c07e, 1b1b26b65c, 312f7dc4fd, ba4223ac3b,
6988cc3a86, fa7fa9cced, 61c6da076a, c7ee165c4f, d99528be94, fd0831c52c, 477e12b235, b42b079213, 21509aa7f5, 65f6211f1f,
3def90523d, 551c884acd, ec53a30a0e, 71e7c82ae8, c33dd0213b, e12458e16c, 77558f31bf, 41da084fbe, 4c2e8870e6, fe6f5d6419,
d0b8db2b11, 351d1f009e, a31db5f952, 03ee7cd109, 712ddbeac6, 03c28eef5b, e05f83479c, bb4740ce29, 2956866ef4, 4498cfc98c,
a449ceb3ef, 45f7127ade, 73469f9562, d45d199b99, e67cc5ae47, 470815cefa, 5f183bfe27, c43a8f5b2b, 9f9d0cbb83, 2be7469821,
3ae9413966, ec9508c83b, 6bcafcbaa6, b3052807e5, 73b041e7a9, 1c661ce3d4, 8fe54bcd26, ee40f0e1ca, 0980f4dcd2, 71bcb1e1c5,
dfeb32975d, d83c1f8447, 21a0fc1b0d, 16967589d8, d963b1aaa4, e982881716, cb5348a0c2, aff72ec5dc, dc7e6e814f, a3d827fb8d,
84ff56eb90, 45cb1f92d3, 59e6669f6d, bb917755ee, bd6efd5fe4, c341786f3e, c8e5491be0
@@ -22,6 +22,7 @@ The [`~loaders.FromSingleFileMixin.from_single_file`] method allows you to load:
 
 ## Supported pipelines
 
+- [`CogVideoXPipeline`]
 - [`StableDiffusionPipeline`]
 - [`StableDiffusionImg2ImgPipeline`]
 - [`StableDiffusionInpaintPipeline`]
@@ -49,6 +50,7 @@ The [`~loaders.FromSingleFileMixin.from_single_file`] method allows you to load:
 
 - [`UNet2DConditionModel`]
 - [`StableCascadeUNet`]
 - [`AutoencoderKL`]
+- [`AutoencoderKLCogVideoX`]
 - [`ControlNetModel`]
 - [`SD3Transformer2DModel`]
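The first addition above lists [`CogVideoXPipeline`] as loadable through [`~loaders.FromSingleFileMixin.from_single_file`]. A minimal sketch of that call, assuming an original-format checkpoint file (the path below is a placeholder, not a real file):

```py
from diffusers import CogVideoXPipeline

# placeholder path to an original-format CogVideoX checkpoint
pipe = CogVideoXPipeline.from_single_file("path/to/cogvideox_checkpoint.safetensors")
```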
docs/source/en/api/models/autoencoderkl_cogvideox.md (new file, 69 lines)
@@ -0,0 +1,69 @@

<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# AutoencoderKLCogVideoX

The 3D variational autoencoder (VAE) model with KL loss used in CogVideoX.

## Loading from the original format

By default, [`AutoencoderKLCogVideoX`] should be loaded with [`~ModelMixin.from_pretrained`], but it can also be loaded from the original format using [`FromOriginalModelMixin.from_single_file`] as follows:

```py
from diffusers import AutoencoderKLCogVideoX

url = "THUDM/CogVideoX-2b"  # can also be a local file
model = AutoencoderKLCogVideoX.from_single_file(url)
```
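For the default [`~ModelMixin.from_pretrained`] path mentioned above, a minimal sketch, assuming the converted `THUDM/CogVideoX-2b` repository follows the standard pipeline layout with a `vae` subfolder:

```py
from diffusers import AutoencoderKLCogVideoX

# subfolder="vae" assumes the standard diffusers pipeline repository layout
vae = AutoencoderKLCogVideoX.from_pretrained("THUDM/CogVideoX-2b", subfolder="vae")
```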

## AutoencoderKLCogVideoX

[[autodoc]] AutoencoderKLCogVideoX
    - decode
    - encode
    - all

## CogVideoXSafeConv3d

[[autodoc]] CogVideoXSafeConv3d

## CogVideoXCausalConv3d

[[autodoc]] CogVideoXCausalConv3d

## CogVideoXSpatialNorm3D

[[autodoc]] CogVideoXSpatialNorm3D

## CogVideoXResnetBlock3D

[[autodoc]] CogVideoXResnetBlock3D

## CogVideoXDownBlock3D

[[autodoc]] CogVideoXDownBlock3D

## CogVideoXMidBlock3D

[[autodoc]] CogVideoXMidBlock3D

## CogVideoXUpBlock3D

[[autodoc]] CogVideoXUpBlock3D

## CogVideoXEncoder3D

[[autodoc]] CogVideoXEncoder3D

## CogVideoXDecoder3D

[[autodoc]] CogVideoXDecoder3D
docs/source/en/api/models/cogvideox_transformer3d.md (new file, 18 lines)
@@ -0,0 +1,18 @@

<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# CogVideoXTransformer3DModel

A Diffusion Transformer model for 3D data from [CogVideoX](https://github.com/THUDM/CogVideoX).

## CogVideoXTransformer3DModel

[[autodoc]] CogVideoXTransformer3DModel
docs/source/en/api/pipelines/cogvideox.md (new file, 79 lines)
@@ -0,0 +1,79 @@

<!--Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

## TODO: The paper is still being written.
-->

# CogVideoX

[TODO]() from Tsinghua University & ZhipuAI.

The abstract from the paper is:

The paper is still being written.

<Tip>

Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.

</Tip>

### Inference

Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion#torchcompile) to reduce the inference latency.

First, load the pipeline:

```python
import torch
from diffusers import CogVideoXPipeline

pipeline = CogVideoXPipeline.from_pretrained(
    "THUDM/CogVideoX-2b", torch_dtype=torch.float16
).to("cuda")
```

Then change the memory layout of the pipeline's `transformer` and `vae` components to `torch.channels_last`:

```python
pipeline.transformer.to(memory_format=torch.channels_last)
pipeline.vae.to(memory_format=torch.channels_last)
```

Finally, compile the components and run inference:

```python
pipeline.transformer = torch.compile(pipeline.transformer)
pipeline.vae.decode = torch.compile(pipeline.vae.decode)

# CogVideoX works very well with long and well-described prompts
prompt = "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance."
video = pipeline(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
```

The [benchmark](TODO: link) results on an 80GB A100 machine are:

```
Without torch.compile(): Average inference time: TODO seconds.
With torch.compile(): Average inference time: TODO seconds.
```
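The pipeline call above returns a list of frames. To write them to disk, one option is a short sketch using `export_to_video` from `diffusers.utils`; the frame rate below is an assumed value, not one specified by CogVideoX:

```python
from diffusers.utils import export_to_video

# `video` is the list of frames returned by the pipeline call above; fps=8 is an assumption
export_to_video(video, "cogvideox_panda.mp4", fps=8)
```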

## CogVideoXPipeline

[[autodoc]] CogVideoXPipeline
    - all
    - __call__

## CogVideoXPipelineOutput

[[autodoc]] pipelines.pipeline_cogvideo.pipeline_output.CogVideoXPipelineOutput
scripts/convert_cogvideox_to_diffusers.py (new file, 222 lines)
@@ -0,0 +1,222 @@

import argparse
from typing import Any, Dict

import torch
from transformers import T5EncoderModel, T5Tokenizer

from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel


def reassign_query_key_value_inplace(key: str, state_dict: Dict[str, Any]):
    to_q_key = key.replace("query_key_value", "to_q")
    to_k_key = key.replace("query_key_value", "to_k")
    to_v_key = key.replace("query_key_value", "to_v")
    to_q, to_k, to_v = torch.chunk(state_dict[key], chunks=3, dim=0)
    state_dict[to_q_key] = to_q
    state_dict[to_k_key] = to_k
    state_dict[to_v_key] = to_v
    state_dict.pop(key)


def reassign_query_key_layernorm_inplace(key: str, state_dict: Dict[str, Any]):
    layer_id, weight_or_bias = key.split(".")[-2:]

    if "query" in key:
        new_key = f"transformer_blocks.{layer_id}.attn1.norm_q.{weight_or_bias}"
    elif "key" in key:
        new_key = f"transformer_blocks.{layer_id}.attn1.norm_k.{weight_or_bias}"

    state_dict[new_key] = state_dict.pop(key)


def reassign_adaln_norm_inplace(key: str, state_dict: Dict[str, Any]):
    layer_id, _, weight_or_bias = key.split(".")[-3:]

    weights_or_biases = state_dict[key].chunk(12, dim=0)
    norm1_weights_or_biases = torch.cat(weights_or_biases[0:3] + weights_or_biases[6:9])
    norm2_weights_or_biases = torch.cat(weights_or_biases[3:6] + weights_or_biases[9:12])

    norm1_key = f"transformer_blocks.{layer_id}.norm1.linear.{weight_or_bias}"
    state_dict[norm1_key] = norm1_weights_or_biases

    norm2_key = f"transformer_blocks.{layer_id}.norm2.linear.{weight_or_bias}"
    state_dict[norm2_key] = norm2_weights_or_biases

    state_dict.pop(key)


def remove_keys_inplace(key: str, state_dict: Dict[str, Any]):
    state_dict.pop(key)


def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]):
    key_split = key.split(".")
    layer_index = int(key_split[2])
    replace_layer_index = 4 - 1 - layer_index

    key_split[1] = "up_blocks"
    key_split[2] = str(replace_layer_index)
    new_key = ".".join(key_split)

    state_dict[new_key] = state_dict.pop(key)


TRANSFORMER_KEYS_RENAME_DICT = {
    "transformer.final_layernorm": "norm_final",
    "transformer": "transformer_blocks",
    "attention": "attn1",
    "mlp": "ff.net",
    "dense_h_to_4h": "0.proj",
    "dense_4h_to_h": "2",
    ".layers": "",
    "dense": "to_out.0",
    "input_layernorm": "norm1.norm",
    "post_attn1_layernorm": "norm2.norm",
    "time_embed.0": "time_embedding.linear_1",
    "time_embed.2": "time_embedding.linear_2",
    "mixins.patch_embed": "patch_embed",
    "mixins.final_layer.norm_final": "norm_out.norm",
    "mixins.final_layer.linear": "proj_out",
    "mixins.final_layer.adaLN_modulation.1": "norm_out.linear",
}

TRANSFORMER_SPECIAL_KEYS_REMAP = {
    "query_key_value": reassign_query_key_value_inplace,
    "query_layernorm_list": reassign_query_key_layernorm_inplace,
    "key_layernorm_list": reassign_query_key_layernorm_inplace,
    "adaln_layer.adaLN_modulations": reassign_adaln_norm_inplace,
    "embed_tokens": remove_keys_inplace,
}

VAE_KEYS_RENAME_DICT = {
    "block.": "resnets.",
    "down.": "down_blocks.",
    "downsample": "downsamplers.0",
    "upsample": "upsamplers.0",
    "nin_shortcut": "conv_shortcut",
    "encoder.mid.block_1": "encoder.mid_block.resnets.0",
    "encoder.mid.block_2": "encoder.mid_block.resnets.1",
    "decoder.mid.block_1": "decoder.mid_block.resnets.0",
    "decoder.mid.block_2": "decoder.mid_block.resnets.1",
}

VAE_SPECIAL_KEYS_REMAP = {
    "loss": remove_keys_inplace,
    "up.": replace_up_keys_inplace,
}

TOKENIZER_MAX_LENGTH = 226


def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
    state_dict = saved_dict
    if "model" in saved_dict.keys():
        state_dict = state_dict["model"]
    if "module" in saved_dict.keys():
        state_dict = state_dict["module"]
    if "state_dict" in saved_dict.keys():
        state_dict = state_dict["state_dict"]
    return state_dict


def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
    state_dict[new_key] = state_dict.pop(old_key)


def convert_transformer(ckpt_path: str):
    PREFIX_KEY = "model.diffusion_model."

    original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
    transformer = CogVideoXTransformer3DModel()

    for key in list(original_state_dict.keys()):
        new_key = key[len(PREFIX_KEY) :]
        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_inplace(original_state_dict, key, new_key)

    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, original_state_dict)

    transformer.load_state_dict(original_state_dict, strict=True)
    return transformer


def convert_vae(ckpt_path: str):
    original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
    vae = AutoencoderKLCogVideoX()

    for key in list(original_state_dict.keys()):
        new_key = key[:]
        for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_inplace(original_state_dict, key, new_key)

    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, original_state_dict)

    vae.load_state_dict(original_state_dict, strict=True)
    return vae


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
    )
    parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
    parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
    parser.add_argument("--fp16", action="store_true", default=True, help="Whether to save the model weights in fp16")
    parser.add_argument(
        "--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving"
    )
    parser.add_argument(
        "--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()

    transformer = None
    vae = None

    if args.transformer_ckpt_path is not None:
        transformer = convert_transformer(args.transformer_ckpt_path)
    if args.vae_ckpt_path is not None:
        vae = convert_vae(args.vae_ckpt_path)

    text_encoder_id = "google/t5-v1_1-xxl"
    tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
    text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)

    scheduler = CogVideoXDDIMScheduler.from_config(
        {
            "snr_shift_scale": 3.0,
            "beta_end": 0.012,
            "beta_schedule": "scaled_linear",
            "beta_start": 0.00085,
            "clip_sample": False,
            "num_train_timesteps": 1000,
            "prediction_type": "v_prediction",
            "rescale_betas_zero_snr": True,
            "set_alpha_to_one": True,
            "timestep_spacing": "linspace",
        }
    )

    pipe = CogVideoXPipeline(
        tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
    )

    if args.fp16:
        pipe = pipe.to(dtype=torch.float16)

    pipe.save_pretrained(args.output_path, safe_serialization=True, push_to_hub=args.push_to_hub)
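To make the attention-weight remapping in the script concrete, here is a small, self-contained sketch of what `reassign_query_key_value_inplace` does to a fused projection (toy sizes, random values, not a real checkpoint):

```python
import torch

# Original CogVideoX checkpoints store one fused "query_key_value" projection of shape (3 * hidden, hidden)
hidden = 8
fused_weight = torch.randn(3 * hidden, hidden)

# The converter splits it into the separate to_q / to_k / to_v weights expected by diffusers
to_q, to_k, to_v = torch.chunk(fused_weight, chunks=3, dim=0)
print(to_q.shape, to_k.shape, to_v.shape)  # three tensors of shape (8, 8)
```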
@@ -78,9 +78,11 @@ else:
             "AsymmetricAutoencoderKL",
             "AuraFlowTransformer2DModel",
             "AutoencoderKL",
+            "AutoencoderKLCogVideoX",
             "AutoencoderKLTemporalDecoder",
             "AutoencoderOobleck",
             "AutoencoderTiny",
+            "CogVideoXTransformer3DModel",
             "ConsistencyDecoderVAE",
             "ControlNetModel",
             "ControlNetXSAdapter",
@@ -154,6 +156,8 @@ else:
         [
             "AmusedScheduler",
             "CMStochasticIterativeScheduler",
+            "CogVideoXDDIMScheduler",
+            "CogVideoXDPMScheduler",
             "DDIMInverseScheduler",
             "DDIMParallelScheduler",
             "DDIMScheduler",
@@ -249,6 +253,7 @@ else:
             "ChatGLMModel",
             "ChatGLMTokenizer",
             "CLIPImageProjection",
+            "CogVideoXPipeline",
             "CycleDiffusionPipeline",
             "FluxPipeline",
             "HunyuanDiTControlNetPipeline",
@@ -524,9 +529,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         AsymmetricAutoencoderKL,
         AuraFlowTransformer2DModel,
         AutoencoderKL,
+        AutoencoderKLCogVideoX,
         AutoencoderKLTemporalDecoder,
         AutoencoderOobleck,
         AutoencoderTiny,
+        CogVideoXTransformer3DModel,
         ConsistencyDecoderVAE,
         ControlNetModel,
         ControlNetXSAdapter,
@@ -597,6 +604,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     from .schedulers import (
         AmusedScheduler,
         CMStochasticIterativeScheduler,
+        CogVideoXDDIMScheduler,
+        CogVideoXDPMScheduler,
         DDIMInverseScheduler,
         DDIMParallelScheduler,
         DDIMScheduler,
@@ -673,6 +682,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         ChatGLMModel,
         ChatGLMTokenizer,
         CLIPImageProjection,
+        CogVideoXPipeline,
         CycleDiffusionPipeline,
         FluxPipeline,
         HunyuanDiTControlNetPipeline,
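Taken together, the export changes above make the new CogVideoX components importable from the package root:

```python
from diffusers import (
    AutoencoderKLCogVideoX,
    CogVideoXDDIMScheduler,
    CogVideoXDPMScheduler,
    CogVideoXPipeline,
    CogVideoXTransformer3DModel,
)
```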
@@ -28,6 +28,7 @@ if is_torch_available():
     _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
     _import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
     _import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
+    _import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
     _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
     _import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"]
     _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
@@ -41,6 +42,7 @@ if is_torch_available():
     _import_structure["embeddings"] = ["ImageProjection"]
     _import_structure["modeling_utils"] = ["ModelMixin"]
     _import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"]
+    _import_structure["transformers.cogvideox_transformer_3d"] = ["CogVideoXTransformer3DModel"]
     _import_structure["transformers.dit_transformer_2d"] = ["DiTTransformer2DModel"]
     _import_structure["transformers.dual_transformer_2d"] = ["DualTransformer2DModel"]
     _import_structure["transformers.hunyuan_transformer_2d"] = ["HunyuanDiT2DModel"]
@@ -77,6 +79,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     from .autoencoders import (
         AsymmetricAutoencoderKL,
         AutoencoderKL,
+        AutoencoderKLCogVideoX,
         AutoencoderKLTemporalDecoder,
         AutoencoderOobleck,
         AutoencoderTiny,
@@ -92,6 +95,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     from .modeling_utils import ModelMixin
     from .transformers import (
         AuraFlowTransformer2DModel,
+        CogVideoXTransformer3DModel,
         DiTTransformer2DModel,
         DualTransformer2DModel,
         FluxTransformer2DModel,
@@ -1,5 +1,6 @@
 from .autoencoder_asym_kl import AsymmetricAutoencoderKL
 from .autoencoder_kl import AutoencoderKL
+from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
 from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
 from .autoencoder_oobleck import AutoencoderOobleck
 from .autoencoder_tiny import AutoencoderTiny
src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py (new file, 1021 lines)
File diff suppressed because it is too large.
@@ -285,6 +285,74 @@ class KDownsample2D(nn.Module):
         return F.conv2d(inputs, weight, stride=2)
 
 
+class CogVideoXDownsample3D(nn.Module):
+    # Todo: Wait for paper release.
+    r"""
+    A 3D downsampling layer used in [CogVideoX]() by Tsinghua University & ZhipuAI
+
+    Args:
+        in_channels (`int`):
+            Number of channels in the input image.
+        out_channels (`int`):
+            Number of channels produced by the convolution.
+        kernel_size (`int`, defaults to `3`):
+            Size of the convolving kernel.
+        stride (`int`, defaults to `2`):
+            Stride of the convolution.
+        padding (`int`, defaults to `0`):
+            Padding added to all four sides of the input.
+        compress_time (`bool`, defaults to `False`):
+            Whether or not to compress the time dimension.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int = 3,
+        stride: int = 2,
+        padding: int = 0,
+        compress_time: bool = False,
+    ):
+        super().__init__()
+
+        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
+        self.compress_time = compress_time
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.compress_time:
+            batch_size, channels, frames, height, width = x.shape
+
+            # (batch_size, channels, frames, height, width) -> (batch_size, height, width, channels, frames) -> (batch_size * height * width, channels, frames)
+            x = x.permute(0, 3, 4, 1, 2).reshape(batch_size * height * width, channels, frames)
+
+            if x.shape[-1] % 2 == 1:
+                x_first, x_rest = x[..., 0], x[..., 1:]
+                if x_rest.shape[-1] > 0:
+                    # (batch_size * height * width, channels, frames - 1) -> (batch_size * height * width, channels, (frames - 1) // 2)
+                    x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2)
+
+                x = torch.cat([x_first[..., None], x_rest], dim=-1)
+                # (batch_size * height * width, channels, (frames // 2) + 1) -> (batch_size, height, width, channels, (frames // 2) + 1) -> (batch_size, channels, (frames // 2) + 1, height, width)
+                x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
+            else:
+                # (batch_size * height * width, channels, frames) -> (batch_size * height * width, channels, frames // 2)
+                x = F.avg_pool1d(x, kernel_size=2, stride=2)
+                # (batch_size * height * width, channels, frames // 2) -> (batch_size, height, width, channels, frames // 2) -> (batch_size, channels, frames // 2, height, width)
+                x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
+
+        # Pad the tensor
+        pad = (0, 1, 0, 1)
+        x = F.pad(x, pad, mode="constant", value=0)
+        batch_size, channels, frames, height, width = x.shape
+        # (batch_size, channels, frames, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size * frames, channels, height, width)
+        x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, channels, height, width)
+        x = self.conv(x)
+        # (batch_size * frames, channels, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size, channels, frames, height, width)
+        x = x.reshape(batch_size, frames, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4)
+        return x
+
+
 def downsample_2d(
     hidden_states: torch.Tensor,
     kernel: Optional[torch.Tensor] = None,
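A quick shape check for the block above; a sketch that assumes `CogVideoXDownsample3D` is importable from `diffusers.models.downsampling` once this hunk is applied:

```python
import torch
from diffusers.models.downsampling import CogVideoXDownsample3D  # module path as in this hunk

down = CogVideoXDownsample3D(in_channels=8, out_channels=8, compress_time=True)
x = torch.randn(1, 8, 9, 32, 32)  # (batch, channels, frames, height, width)
# An odd frame count keeps the first frame and average-pools the rest (9 -> 5 frames),
# and the strided conv halves the spatial resolution (32 -> 16).
print(down(x).shape)  # torch.Size([1, 8, 5, 16, 16])
```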
@@ -78,6 +78,53 @@ def get_timestep_embedding(
     return emb
 
 
+def get_3d_sincos_pos_embed(
+    embed_dim: int,
+    spatial_size: Union[int, Tuple[int, int]],
+    temporal_size: int,
+    spatial_interpolation_scale: float = 1.0,
+    temporal_interpolation_scale: float = 1.0,
+) -> np.ndarray:
+    r"""
+    Args:
+        embed_dim (`int`):
+        spatial_size (`int` or `Tuple[int, int]`):
+        temporal_size (`int`):
+        spatial_interpolation_scale (`float`, defaults to 1.0):
+        temporal_interpolation_scale (`float`, defaults to 1.0):
+    """
+    if embed_dim % 4 != 0:
+        raise ValueError("`embed_dim` must be divisible by 4")
+    if isinstance(spatial_size, int):
+        spatial_size = (spatial_size, spatial_size)
+
+    embed_dim_spatial = 3 * embed_dim // 4
+    embed_dim_temporal = embed_dim // 4
+
+    # 1. Spatial
+    grid_h = np.arange(spatial_size[1], dtype=np.float32) / spatial_interpolation_scale
+    grid_w = np.arange(spatial_size[0], dtype=np.float32) / spatial_interpolation_scale
+    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+    grid = np.stack(grid, axis=0)
+
+    grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]])
+    pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid)
+
+    # 2. Temporal
+    grid_t = np.arange(temporal_size, dtype=np.float32) / temporal_interpolation_scale
+    pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t)
+
+    # 3. Concat
+    pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
+    pos_embed_spatial = np.repeat(pos_embed_spatial, temporal_size, axis=0)  # [T, H*W, D // 4 * 3]
+
+    pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
+    pos_embed_temporal = np.repeat(pos_embed_temporal, spatial_size[0] * spatial_size[1], axis=1)  # [T, H*W, D // 4]
+
+    pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1)  # [T, H*W, D]
+    return pos_embed
+
+
 def get_2d_sincos_pos_embed(
     embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
 ):
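A small sanity check of the shape produced by the function above; the import path is an assumption that this hunk has been applied:

```python
from diffusers.models.embeddings import get_3d_sincos_pos_embed  # path as in this hunk

# Temporal and spatial sin/cos tables are concatenated along the channel axis.
pos = get_3d_sincos_pos_embed(embed_dim=64, spatial_size=(8, 6), temporal_size=5)
print(pos.shape)  # (5, 48, 64) == (temporal_size, height * width, embed_dim)
```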
@@ -287,6 +334,46 @@ class LuminaPatchEmbed(nn.Module):
         )
 
 
+class CogVideoXPatchEmbed(nn.Module):
+    def __init__(
+        self,
+        patch_size: int = 2,
+        in_channels: int = 16,
+        embed_dim: int = 1920,
+        text_embed_dim: int = 4096,
+        bias: bool = True,
+    ) -> None:
+        super().__init__()
+        self.patch_size = patch_size
+
+        self.proj = nn.Conv2d(
+            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
+        )
+        self.text_proj = nn.Linear(text_embed_dim, embed_dim)
+
+    def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
+        r"""
+        Args:
+            text_embeds (`torch.Tensor`):
+                Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
+            image_embeds (`torch.Tensor`):
+                Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
+        """
+        text_embeds = self.text_proj(text_embeds)
+
+        batch, num_frames, channels, height, width = image_embeds.shape
+        image_embeds = image_embeds.reshape(-1, channels, height, width)
+        image_embeds = self.proj(image_embeds)
+        image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
+        image_embeds = image_embeds.flatten(3).transpose(2, 3)  # [batch, num_frames, height x width, channels]
+        image_embeds = image_embeds.flatten(1, 2)  # [batch, num_frames x height x width, channels]
+
+        embeds = torch.cat(
+            [text_embeds, image_embeds], dim=1
+        ).contiguous()  # [batch, seq_length + num_frames x height x width, channels]
+        return embeds
+
+
 def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
     """
     RoPE for image tokens with 2d structure.
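To make the sequence layout of the patch embedding concrete, a sketch with toy sizes (the import path assumes this hunk is applied); text tokens are projected and prepended to the flattened video patches:

```python
import torch
from diffusers.models.embeddings import CogVideoXPatchEmbed  # path as in this hunk

patch_embed = CogVideoXPatchEmbed(patch_size=2, in_channels=16, embed_dim=64, text_embed_dim=32)
text_embeds = torch.randn(1, 8, 32)         # (batch, text_seq_len, text_embed_dim)
image_embeds = torch.randn(1, 3, 16, 8, 8)  # (batch, num_frames, latent_channels, height, width)
out = patch_embed(text_embeds, image_embeds)
print(out.shape)  # torch.Size([1, 56, 64]): 8 text tokens + 3 * (8 // 2) * (8 // 2) video patches
```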
@@ -37,16 +37,44 @@ class AdaLayerNorm(nn.Module):
         num_embeddings (`int`): The size of the embeddings dictionary.
     """
 
-    def __init__(self, embedding_dim: int, num_embeddings: int):
+    def __init__(
+        self,
+        embedding_dim: int,
+        num_embeddings: Optional[int] = None,
+        output_dim: Optional[int] = None,
+        norm_elementwise_affine: bool = False,
+        norm_eps: float = 1e-5,
+        chunk_dim: int = 0,
+    ):
         super().__init__()
-        self.emb = nn.Embedding(num_embeddings, embedding_dim)
-        self.silu = nn.SiLU()
-        self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
-        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
 
-    def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
-        emb = self.linear(self.silu(self.emb(timestep)))
-        scale, shift = torch.chunk(emb, 2)
+        self.chunk_dim = chunk_dim
+
+        if num_embeddings is not None:
+            self.emb = nn.Embedding(num_embeddings, embedding_dim)
+        else:
+            self.emb = None
+
+        output_dim = output_dim or embedding_dim * 2
+
+        self.silu = nn.SiLU()
+        self.linear = nn.Linear(embedding_dim, output_dim)
+        self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
+
+    def forward(
+        self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        if self.emb is not None:
+            temb = self.emb(timestep)
+
+        temb = self.linear(self.silu(temb))
+        if self.chunk_dim == 1:
+            shift, scale = temb.chunk(2, dim=1)
+            shift = shift[:, None, :]
+            scale = scale[:, None, :]
+        else:
+            scale, shift = temb.chunk(2, dim=0)
+
         x = self.norm(x) * (1 + scale) + shift
         return x
 
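A minimal sketch of the new `temb` / `chunk_dim=1` path added above, which is how the CogVideoX transformer uses this layer; the sizes and import path are illustrative:

```python
import torch
from diffusers.models.normalization import AdaLayerNorm  # module as modified in this hunk

norm = AdaLayerNorm(embedding_dim=512, output_dim=2 * 128, chunk_dim=1)
x = torch.randn(2, 17, 128)   # (batch, sequence, channels); channels must equal output_dim // 2
temb = torch.randn(2, 512)    # pooled conditioning, e.g. a time embedding
out = norm(x, temb=temb)      # shift/scale are chunked along dim=1 and broadcast over the sequence
print(out.shape)              # torch.Size([2, 17, 128])
```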
@@ -321,6 +349,30 @@ class LuminaLayerNormContinuous(nn.Module):
         return x
 
 
+class CogVideoXLayerNormZero(nn.Module):
+    def __init__(
+        self,
+        conditioning_dim: int,
+        embedding_dim: int,
+        elementwise_affine: bool = True,
+        eps: float = 1e-5,
+        bias: bool = True,
+    ) -> None:
+        super().__init__()
+
+        self.silu = nn.SiLU()
+        self.linear = nn.Linear(conditioning_dim, 6 * embedding_dim, bias=bias)
+        self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
+
+    def forward(
+        self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1)
+        hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
+        encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :]
+        return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :]
+
+
 if is_torch_version(">=", "2.1.0"):
     LayerNorm = nn.LayerNorm
 else:
@@ -3,6 +3,7 @@ from ...utils import is_torch_available
 
 if is_torch_available():
     from .auraflow_transformer_2d import AuraFlowTransformer2DModel
+    from .cogvideox_transformer_3d import CogVideoXTransformer3DModel
     from .dit_transformer_2d import DiTTransformer2DModel
     from .dual_transformer_2d import DualTransformer2DModel
     from .hunyuan_transformer_2d import HunyuanDiT2DModel
352
src/diffusers/models/transformers/cogvideox_transformer_3d.py
Normal file
352
src/diffusers/models/transformers/cogvideox_transformer_3d.py
Normal file
@@ -0,0 +1,352 @@
|
|||||||
|
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Any, Dict, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from ...configuration_utils import ConfigMixin, register_to_config
|
||||||
|
from ...utils import is_torch_version, logging
|
||||||
|
from ...utils.torch_utils import maybe_allow_in_graph
|
||||||
|
from ..attention import Attention, FeedForward
|
||||||
|
from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed
|
||||||
|
from ..modeling_outputs import Transformer2DModelOutput
|
||||||
|
from ..modeling_utils import ModelMixin
|
||||||
|
from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
@maybe_allow_in_graph
|
||||||
|
class CogVideoXBlock(nn.Module):
|
||||||
|
r"""
|
||||||
|
Transformer block used in CogVideoX model. TODO: add link to CogVideoX upon release
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
dim (`int`): The number of channels in the input and output.
|
||||||
|
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
||||||
|
attention_head_dim (`int`): The number of channels in each head.
|
||||||
|
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||||
|
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
|
||||||
|
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
||||||
|
num_embeds_ada_norm (:
|
||||||
|
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
|
||||||
|
attention_bias (:
|
||||||
|
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
|
||||||
|
only_cross_attention (`bool`, *optional*):
|
||||||
|
Whether to use only cross-attention layers. In this case two cross attention layers are used.
|
||||||
|
double_self_attention (`bool`, *optional*):
|
||||||
|
Whether to use two self-attention layers. In this case no cross attention layers are used.
|
||||||
|
upcast_attention (`bool`, *optional*):
|
||||||
|
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
|
||||||
|
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to use learnable elementwise affine parameters for normalization.
|
||||||
|
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
|
||||||
|
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
|
||||||
|
final_dropout (`bool` *optional*, defaults to False):
|
||||||
|
Whether to apply a final dropout after the last feed-forward layer.
|
||||||
|
attention_type (`str`, *optional*, defaults to `"default"`):
|
||||||
|
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
|
||||||
|
positional_embeddings (`str`, *optional*, defaults to `None`):
|
||||||
|
The type of positional embeddings to apply to.
|
||||||
|
num_positional_embeddings (`int`, *optional*, defaults to `None`):
|
||||||
|
The maximum number of positional embeddings to apply.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
num_attention_heads: int,
|
||||||
|
attention_head_dim: int,
|
||||||
|
time_embed_dim: int,
|
||||||
|
dropout: float = 0.0,
|
||||||
|
activation_fn: str = "gelu-approximate",
|
||||||
|
attention_bias: bool = False,
|
||||||
|
qk_norm: bool = True,
|
||||||
|
norm_elementwise_affine: bool = True,
|
||||||
|
norm_eps: float = 1e-5,
|
||||||
|
final_dropout: bool = True,
|
||||||
|
ff_inner_dim: Optional[int] = None,
|
||||||
|
ff_bias: bool = True,
|
||||||
|
attention_out_bias: bool = True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# 1. Self Attention
|
||||||
|
self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
|
||||||
|
|
||||||
|
self.attn1 = Attention(
|
||||||
|
query_dim=dim,
|
||||||
|
dim_head=attention_head_dim,
|
||||||
|
heads=num_attention_heads,
|
||||||
|
qk_norm="layer_norm" if qk_norm else None,
|
||||||
|
eps=1e-6,
|
||||||
|
bias=attention_bias,
|
||||||
|
out_bias=attention_out_bias,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. Feed Forward
|
||||||
|
self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
|
||||||
|
|
||||||
|
self.ff = FeedForward(
|
||||||
|
dim,
|
||||||
|
dropout=dropout,
|
||||||
|
activation_fn=activation_fn,
|
||||||
|
final_dropout=final_dropout,
|
||||||
|
inner_dim=ff_inner_dim,
|
||||||
|
bias=ff_bias,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
encoder_hidden_states: torch.Tensor,
|
||||||
|
temb: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
|
||||||
|
hidden_states, encoder_hidden_states, temb
|
||||||
|
)
|
||||||
|
|
||||||
|
# attention
|
||||||
|
text_length = norm_encoder_hidden_states.size(1)
|
||||||
|
|
||||||
|
# CogVideoX uses concatenated text + video embeddings with self-attention instead of using
|
||||||
|
# them in cross-attention individually
|
||||||
|
norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
|
||||||
|
attn_output = self.attn1(
|
||||||
|
hidden_states=norm_hidden_states,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = hidden_states + gate_msa * attn_output[:, text_length:]
|
||||||
|
encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_output[:, :text_length]
|
||||||
|
|
||||||
|
# norm & modulate
|
||||||
|
norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
|
||||||
|
hidden_states, encoder_hidden_states, temb
|
||||||
|
)
|
||||||
|
|
||||||
|
# feed-forward
|
||||||
|
norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
|
||||||
|
ff_output = self.ff(norm_hidden_states)
|
||||||
|
|
||||||
|
hidden_states = hidden_states + gate_ff * ff_output[:, text_length:]
|
||||||
|
encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_length]
|
||||||
|
return hidden_states, encoder_hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
|
||||||
|
"""
|
||||||
|
A Transformer model for video-like data in CogVideoX. TODO: add link to CogVideoX upon release
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
||||||
|
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
||||||
|
in_channels (`int`, *optional*):
|
||||||
|
The number of channels in the input.
|
||||||
|
out_channels (`int`, *optional*):
|
||||||
|
The number of channels in the output.
|
||||||
|
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
||||||
|
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||||
|
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
||||||
|
attention_bias (`bool`, *optional*):
|
||||||
|
Configure if the `TransformerBlocks` attention should contain a bias parameter.
|
||||||
|
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
|
||||||
|
This is fixed during training since it is used to learn a number of position embeddings.
|
||||||
|
patch_size (`int`, *optional*):
|
||||||
|
The size of the patches to use in the patch embedding layer.
|
||||||
|
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
|
||||||
|
num_embeds_ada_norm ( `int`, *optional*):
|
||||||
|
The number of diffusion steps used during training. Pass if at least one of the norm_layers is
|
||||||
|
`AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
|
||||||
|
added to the hidden states. During inference, you can denoise for up to but not more steps than
|
||||||
|
`num_embeds_ada_norm`.
|
||||||
|
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
|
||||||
|
The type of normalization to use. Options are `"layer_norm"` or `"ada_layer_norm"`.
|
||||||
|
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether or not to use elementwise affine in normalization layers.
|
||||||
|
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use in normalization layers.
|
||||||
|
caption_channels (`int`, *optional*):
|
||||||
|
The number of channels in the caption embeddings.
|
||||||
|
video_length (`int`, *optional*):
|
||||||
|
The number of frames in the video-like data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_supports_gradient_checkpointing = True
|
||||||
|
|
||||||
|
@register_to_config
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_attention_heads: int = 30,
|
||||||
|
attention_head_dim: int = 64,
|
||||||
|
in_channels: Optional[int] = 16,
|
||||||
|
out_channels: Optional[int] = 16,
|
||||||
|
flip_sin_to_cos: bool = True,
|
||||||
|
freq_shift: int = 0,
|
||||||
|
time_embed_dim: int = 512,
|
||||||
|
text_embed_dim: int = 4096,
|
||||||
|
num_layers: int = 30,
|
||||||
|
dropout: float = 0.0,
|
||||||
|
attention_bias: bool = True,
|
||||||
|
sample_width: int = 90,
|
||||||
|
sample_height: int = 60,
|
||||||
|
sample_frames: int = 49,
|
||||||
|
patch_size: int = 2,
|
||||||
|
temporal_compression_ratio: int = 4,
|
||||||
|
max_text_seq_length: int = 226,
|
||||||
|
activation_fn: str = "gelu-approximate",
|
||||||
|
timestep_activation_fn: str = "silu",
|
||||||
|
norm_elementwise_affine: bool = True,
|
||||||
|
norm_eps: float = 1e-5,
|
||||||
|
spatial_interpolation_scale: float = 1.875,
|
||||||
|
temporal_interpolation_scale: float = 1.0,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
inner_dim = num_attention_heads * attention_head_dim
|
||||||
|
|
||||||
|
post_patch_height = sample_height // patch_size
|
||||||
|
post_patch_width = sample_width // patch_size
|
||||||
|
post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1
|
||||||
|
self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames
|
||||||
|
|
||||||
|
# 1. Patch embedding
|
||||||
|
self.patch_embed = CogVideoXPatchEmbed(patch_size, in_channels, inner_dim, text_embed_dim, bias=True)
|
||||||
|
self.embedding_dropout = nn.Dropout(dropout)
|
||||||
|
|
||||||
|
# 2. 3D positional embeddings
|
||||||
|
spatial_pos_embedding = get_3d_sincos_pos_embed(
|
||||||
|
inner_dim,
|
||||||
|
(post_patch_width, post_patch_height),
|
||||||
|
post_time_compression_frames,
|
||||||
|
spatial_interpolation_scale,
|
||||||
|
temporal_interpolation_scale,
|
||||||
|
)
|
||||||
|
spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1)
|
||||||
|
pos_embedding = torch.zeros(1, max_text_seq_length + self.num_patches, inner_dim, requires_grad=False)
|
||||||
|
pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding)
|
||||||
|
self.register_buffer("pos_embedding", pos_embedding, persistent=False)
|
        # 3. Time embeddings
        self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
        self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)

        # 4. Define spatio-temporal transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
                CogVideoXBlock(
                    dim=inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)

        # 5. Output blocks
        self.norm_out = AdaLayerNorm(
            embedding_dim=time_embed_dim,
            output_dim=2 * inner_dim,
            norm_elementwise_affine=norm_elementwise_affine,
            norm_eps=norm_eps,
            chunk_dim=1,
        )
        self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        self.gradient_checkpointing = value

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep: Union[int, float, torch.LongTensor],
        timestep_cond: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        batch_size, num_frames, channels, height, width = hidden_states.shape

        # 1. Time embedding
        timesteps = timestep
        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=hidden_states.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)

        # 2. Patch embedding
        hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)

        # 3. Position embedding
        seq_length = height * width * num_frames // (self.config.patch_size**2)

        pos_embeds = self.pos_embedding[:, : self.config.max_text_seq_length + seq_length]
        hidden_states = hidden_states + pos_embeds
        hidden_states = self.embedding_dropout(hidden_states)

        encoder_hidden_states = hidden_states[:, : self.config.max_text_seq_length]
        hidden_states = hidden_states[:, self.config.max_text_seq_length :]

        # 5. Transformer blocks
        for i, block in enumerate(self.transformer_blocks):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    emb,
                    **ckpt_kwargs,
                )
            else:
                hidden_states, encoder_hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=emb,
                )

        hidden_states = self.norm_final(hidden_states)

        # 6. Final block
        hidden_states = self.norm_out(hidden_states, temb=emb)
        hidden_states = self.proj_out(hidden_states)

        # 7. Unpatchify
        p = self.config.patch_size
        output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
        output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)
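Editor's sketch (not part of the diff): a minimal, standalone shape check of the unpatchify step at the end of `forward` above; all dimensions are illustrative.

import torch

# Tokens laid out as (batch, frames * patches, patch_size * patch_size * channels),
# as produced by proj_out above, mapped back to (batch, frames, channels, height, width).
b, f, c, h, w, p = 1, 13, 16, 60, 90, 2
tokens = torch.randn(b, f * (h // p) * (w // p), p * p * c)
out = tokens.reshape(b, f, h // p, w // p, c, p, p)
out = out.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
assert out.shape == (b, f, c, h, w)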
@@ -348,6 +348,70 @@ class KUpsample2D(nn.Module):
        return F.conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1)


class CogVideoXUpsample3D(nn.Module):
    r"""
    A 3D Upsample layer used in CogVideoX by Tsinghua University & ZhipuAI  # Todo: Wait for paper release.

    Args:
        in_channels (`int`):
            Number of channels in the input image.
        out_channels (`int`):
            Number of channels produced by the convolution.
        kernel_size (`int`, defaults to `3`):
            Size of the convolving kernel.
        stride (`int`, defaults to `1`):
            Stride of the convolution.
        padding (`int`, defaults to `1`):
            Padding added to all four sides of the input.
        compress_time (`bool`, defaults to `False`):
            Whether or not to compress the time dimension.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 1,
        compress_time: bool = False,
    ) -> None:
        super().__init__()

        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        self.compress_time = compress_time

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        if self.compress_time:
            if inputs.shape[2] > 1 and inputs.shape[2] % 2 == 1:
                # split first frame
                x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:]

                x_first = F.interpolate(x_first, scale_factor=2.0)
                x_rest = F.interpolate(x_rest, scale_factor=2.0)
                x_first = x_first[:, :, None, :, :]
                inputs = torch.cat([x_first, x_rest], dim=2)
            elif inputs.shape[2] > 1:
                inputs = F.interpolate(inputs, scale_factor=2.0)
            else:
                inputs = inputs.squeeze(2)
                inputs = F.interpolate(inputs, scale_factor=2.0)
                inputs = inputs[:, :, None, :, :]
        else:
            # only interpolate 2D
            b, c, t, h, w = inputs.shape
            inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
            inputs = F.interpolate(inputs, scale_factor=2.0)
            inputs = inputs.reshape(b, t, c, *inputs.shape[2:]).permute(0, 2, 1, 3, 4)

        b, c, t, h, w = inputs.shape
        inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
        inputs = self.conv(inputs)
        inputs = inputs.reshape(b, t, *inputs.shape[1:]).permute(0, 2, 1, 3, 4)

        return inputs


def upfirdn2d_native(
    tensor: torch.Tensor,
    kernel: torch.Tensor,
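Editor's sketch (not part of the diff): the spatial-only branch of `CogVideoXUpsample3D.forward` above folds time into the batch dimension before a 2x nearest-neighbour upsample; the dimensions below are made up.

import torch
import torch.nn.functional as F

b, c, t, h, w = 1, 4, 3, 8, 8
x = torch.randn(b, c, t, h, w)
x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)  # merge batch and time
x = F.interpolate(x, scale_factor=2.0)                # 2x spatial upsample
x = x.reshape(b, t, c, *x.shape[2:]).permute(0, 2, 1, 3, 4)
assert x.shape == (b, c, t, 2 * h, 2 * w)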
@@ -131,6 +131,7 @@ else:
        "AudioLDM2UNet2DConditionModel",
    ]
    _import_structure["blip_diffusion"] = ["BlipDiffusionPipeline"]
    _import_structure["cogvideo"] = ["CogVideoXPipeline"]
    _import_structure["controlnet"].extend(
        [
            "BlipDiffusionControlNetPipeline",
@@ -438,6 +439,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        )
        from .aura_flow import AuraFlowPipeline
        from .blip_diffusion import BlipDiffusionPipeline
        from .cogvideo import CogVideoXPipeline
        from .controlnet import (
            BlipDiffusionControlNetPipeline,
            StableDiffusionControlNetImg2ImgPipeline,
48  src/diffusers/pipelines/cogvideo/__init__.py  Normal file
@@ -0,0 +1,48 @@
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
)


_dummy_objects = {}
_import_structure = {}


try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_cogvideox"] = ["CogVideoXPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()

    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        from .pipeline_cogvideox import CogVideoXPipeline

else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
686  src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py  Normal file
@@ -0,0 +1,686 @@
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import math
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
from transformers import T5EncoderModel, T5Tokenizer

from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
from ...utils import BaseOutput, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


EXAMPLE_DOC_STRING = """
    Examples:
        ```python
        >>> import torch
        >>> from diffusers import CogVideoXPipeline
        >>> from diffusers.utils import export_to_video

        >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.bfloat16).to("cuda")
        >>> prompt = (
        ...     "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
        ...     "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
        ...     "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
        ...     "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
        ...     "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
        ...     "atmosphere of this unique musical performance."
        ... )
        >>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=20).frames[0]
        >>> export_to_video(video, "output.mp4", fps=8)
        ```
"""


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps
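Editor's sketch (not part of the diff): a minimal usage of `retrieve_timesteps` above on its default path (no custom `timesteps` or `sigmas`); it assumes this branch is installed so that `CogVideoXDDIMScheduler` is importable.

from diffusers.pipelines.cogvideo.pipeline_cogvideox import retrieve_timesteps
from diffusers.schedulers import CogVideoXDDIMScheduler

scheduler = CogVideoXDDIMScheduler()
# Default path: simply calls scheduler.set_timesteps(50, device=...) and returns its timesteps.
timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=50, device="cpu")
print(num_inference_steps, timesteps[:3])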
@dataclass
class CogVideoXPipelineOutput(BaseOutput):
    r"""
    Output class for CogVideo pipelines.

    Args:
        frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
            List of video outputs - It can be a nested list of length `batch_size`, with each sub-list containing
            denoised PIL image sequences of length `num_frames`. It can also be a NumPy array or Torch tensor of shape
            `(batch_size, num_frames, channels, height, width)`.
    """

    frames: torch.Tensor


class CogVideoXPipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-video generation using CogVideoX.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKLCogVideoX`]):
            Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
        text_encoder ([`T5EncoderModel`]):
            Frozen text-encoder. CogVideoX uses
            [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
            [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
        tokenizer (`T5Tokenizer`):
            Tokenizer of class
            [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
        transformer ([`CogVideoXTransformer3DModel`]):
            A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
    """

    _optional_components = ["tokenizer", "text_encoder"]
    model_cpu_offload_seq = "text_encoder->transformer->vae"

    _callback_tensor_inputs = [
        "latents",
        "prompt_embeds",
        "negative_prompt_embeds",
    ]

    def __init__(
        self,
        tokenizer: T5Tokenizer,
        text_encoder: T5EncoderModel,
        vae: AutoencoderKLCogVideoX,
        transformer: CogVideoXTransformer3DModel,
        scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
    ):
        super().__init__()

        self.register_modules(
            tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
        )
        self.vae_scale_factor_spatial = (
            2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
        )
        self.vae_scale_factor_temporal = (
            self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
        )
        self.tokenizer_max_length = (
            self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 226
        )

        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
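Editor's sketch (not part of the diff): the arithmetic behind the scale factors derived in `__init__` above, assuming a VAE config with four `block_out_channels` entries and a temporal compression ratio of 4 (illustrative values, not read from a checkpoint).

block_out_channels = [128, 256, 256, 512]      # assumed VAE config
temporal_compression_ratio = 4                 # assumed VAE config
vae_scale_factor_spatial = 2 ** (len(block_out_channels) - 1)   # -> 8
vae_scale_factor_temporal = temporal_compression_ratio          # -> 4
print(vae_scale_factor_spatial, vae_scale_factor_temporal)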
    def _get_t5_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        num_videos_per_prompt: int = 1,
        max_sequence_length: int = 226,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because `max_sequence_length` is set to "
                f" {max_sequence_length} tokens: {removed_text}"
            )

        prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        _, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

        return prompt_embeds
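Editor's sketch (not part of the diff): how a prompt is padded or truncated to `max_sequence_length` (226) before the T5 encoder in `_get_t5_prompt_embeds` above; the tokenizer checkpoint name is only an example.

from transformers import T5Tokenizer

tok = T5Tokenizer.from_pretrained("google/t5-v1_1-xxl")  # example checkpoint
ids = tok(
    "a panda playing a guitar in a bamboo forest",
    padding="max_length",
    max_length=226,
    truncation=True,
    add_special_tokens=True,
    return_tensors="pt",
).input_ids
print(ids.shape)  # torch.Size([1, 226])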
    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        do_classifier_free_guidance: bool = True,
        num_videos_per_prompt: int = 1,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        max_sequence_length: int = 226,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the video generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                Whether to use classifier free guidance or not.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                Number of videos that should be generated per prompt.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            device: (`torch.device`, *optional*):
                torch device to place the resulting embeddings on
            dtype: (`torch.dtype`, *optional*):
                torch dtype
        """
        device = device or self._execution_device

        prompt = [prompt] if isinstance(prompt, str) else prompt
        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            prompt_embeds = self._get_t5_prompt_embeds(
                prompt=prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        if do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )

            negative_prompt_embeds = self._get_t5_prompt_embeds(
                prompt=negative_prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        return prompt_embeds, negative_prompt_embeds
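Editor's sketch (not part of the diff): a hypothetical way to call `encode_prompt` above directly, for instance to cache embeddings across generations; the checkpoint and device are illustrative.

import torch
from diffusers import CogVideoXPipeline

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.bfloat16).to("cuda")
prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
    prompt="a panda playing a guitar in a bamboo forest",
    negative_prompt="",
    do_classifier_free_guidance=True,
    num_videos_per_prompt=1,
)
print(prompt_embeds.shape, negative_prompt_embeds.shape)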
    def prepare_latents(
        self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
    ):
        shape = (
            batch_size,
            (num_frames - 1) // self.vae_scale_factor_temporal + 1,
            num_channels_latents,
            height // self.vae_scale_factor_spatial,
            width // self.vae_scale_factor_spatial,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents
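Editor's sketch (not part of the diff): the latent shape produced by `prepare_latents` above for the default call (height=480, width=720, 48 requested frames, so 49 after the `+ 1` in `__call__`), assuming 16 latent channels and the 8x/4x scale factors.

batch_size, num_channels_latents = 1, 16
num_frames, height, width = 49, 480, 720
vae_scale_factor_spatial, vae_scale_factor_temporal = 8, 4
shape = (
    batch_size,
    (num_frames - 1) // vae_scale_factor_temporal + 1,  # 13 latent frames
    num_channels_latents,
    height // vae_scale_factor_spatial,                 # 60
    width // vae_scale_factor_spatial,                  # 90
)
print(shape)  # (1, 13, 16, 60, 90)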
    def decode_latents(self, latents: torch.Tensor, num_seconds: int):
        latents = latents.permute(0, 2, 1, 3, 4)  # [batch_size, num_channels, num_frames, height, width]
        latents = 1 / self.vae.config.scaling_factor * latents

        frames = []
        for i in range(num_seconds):
            # Whether or not to clear fake context parallel cache
            fake_cp = i + 1 < num_seconds
            start_frame, end_frame = (0, 3) if i == 0 else (2 * i + 1, 2 * i + 3)

            current_frames = self.vae.decode(latents[:, :, start_frame:end_frame], fake_cp=fake_cp).sample
            frames.append(current_frames)

        frames = torch.cat(frames, dim=2)
        return frames
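Editor's sketch (not part of the diff): the latent-frame windows that `decode_latents` above feeds to the VAE, one chunk per second of output video (`num_seconds = num_frames // fps`).

num_seconds = 6
for i in range(num_seconds):
    start_frame, end_frame = (0, 3) if i == 0 else (2 * i + 1, 2 * i + 3)
    print(i, (start_frame, end_frame))
# -> (0, 3), (3, 5), (5, 7), (7, 9), (9, 11), (11, 13), covering the 13 latent frames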
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
    def check_inputs(
        self,
        prompt,
        height,
        width,
        negative_prompt,
        callback_on_step_end_tensor_inputs,
        prompt_embeds=None,
        negative_prompt_embeds=None,
    ):
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def interrupt(self):
        return self._interrupt
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: int = 480,
        width: int = 720,
        num_frames: int = 48,
        fps: int = 8,
        num_inference_steps: int = 50,
        timesteps: Optional[List[int]] = None,
        guidance_scale: float = 6,
        num_videos_per_prompt: int = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: str = "pil",
        return_dict: bool = True,
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        use_dynamic_cfg: bool = False,
    ) -> Union[CogVideoXPipelineOutput, Tuple]:
        """
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`
                instead.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the video generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            height (`int`, *optional*, defaults to `480`):
                The height in pixels of the generated video frames.
            width (`int`, *optional*, defaults to `720`):
                The width in pixels of the generated video frames.
            num_frames (`int`, defaults to `48`):
                Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
                contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
                num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that
                needs to be satisfied is that of divisibility mentioned above.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality video at the
                expense of slower inference.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            guidance_scale (`float`, *optional*, defaults to 6):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate videos that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                The number of videos to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated video. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] instead
                of a plain tuple.
            callback_on_step_end (`Callable`, *optional*):
                A function that is called at the end of each denoising step during inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.

        Examples:

        Returns:
            [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`:
            [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is a list with the generated frames.
        """
        assert (
            num_frames <= 48 and num_frames % fps == 0 and fps == 8
        ), f"The number of frames must be divisible by {fps=} and at most 48 frames (for now). Other values are not supported in CogVideoX."

        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
        width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
        num_videos_per_prompt = 1

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            height,
            width,
            negative_prompt,
            callback_on_step_end_tensor_inputs,
            prompt_embeds,
            negative_prompt_embeds,
        )
        self._guidance_scale = guidance_scale
        self._interrupt = False

        # 2. Default call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            negative_prompt,
            do_classifier_free_guidance,
            num_videos_per_prompt=num_videos_per_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            device=device,
        )
        if do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)

        # 4. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
        self._num_timesteps = len(timesteps)

        # 5. Prepare latents.
        latent_channels = self.transformer.config.in_channels
        num_frames += 1
        latents = self.prepare_latents(
            batch_size * num_videos_per_prompt,
            latent_channels,
            num_frames,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            # for DPM-solver++
            old_pred_original_sample = None
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latent_model_input.shape[0])

                # predict noise model_output
                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    encoder_hidden_states=prompt_embeds,
                    timestep=timestep,
                    return_dict=False,
                )[0]
                noise_pred = noise_pred.float()

                # perform guidance
                if use_dynamic_cfg:
                    self._guidance_scale = 1 + guidance_scale * (
                        (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
                    )
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                if not isinstance(self.scheduler, CogVideoXDPMScheduler):
                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
                else:
                    latents, old_pred_original_sample = self.scheduler.step(
                        noise_pred,
                        old_pred_original_sample,
                        t,
                        timesteps[i - 1] if i > 0 else None,
                        latents,
                        **extra_step_kwargs,
                        return_dict=False,
                    )
                latents = latents.to(prompt_embeds.dtype)

                # call the callback, if provided
                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        if not output_type == "latents":
            video = self.decode_latents(latents, num_frames // fps)
            video = self.video_processor.postprocess_video(video=video, output_type=output_type)
        else:
            video = latents

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (video,)

        return CogVideoXPipelineOutput(frames=video)
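Editor's sketch (not part of the diff): the classifier-free guidance combination used in the denoising loop above, on dummy tensors shaped like the default CogVideoX latents (sizes are illustrative).

import torch

noise_pred = torch.randn(2, 13, 16, 60, 90)  # [uncond; text] stacked along the batch dim
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
guidance_scale = 6.0
guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(guided.shape)  # torch.Size([1, 13, 16, 60, 90])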
@@ -43,12 +43,14 @@ else:
    _import_structure["scheduling_consistency_decoder"] = ["ConsistencyDecoderScheduler"]
    _import_structure["scheduling_consistency_models"] = ["CMStochasticIterativeScheduler"]
    _import_structure["scheduling_ddim"] = ["DDIMScheduler"]
    _import_structure["scheduling_ddim_cogvideox"] = ["CogVideoXDDIMScheduler"]
    _import_structure["scheduling_ddim_inverse"] = ["DDIMInverseScheduler"]
    _import_structure["scheduling_ddim_parallel"] = ["DDIMParallelScheduler"]
    _import_structure["scheduling_ddpm"] = ["DDPMScheduler"]
    _import_structure["scheduling_ddpm_parallel"] = ["DDPMParallelScheduler"]
    _import_structure["scheduling_ddpm_wuerstchen"] = ["DDPMWuerstchenScheduler"]
    _import_structure["scheduling_deis_multistep"] = ["DEISMultistepScheduler"]
    _import_structure["scheduling_dpm_cogvideox"] = ["CogVideoXDPMScheduler"]
    _import_structure["scheduling_dpmsolver_multistep"] = ["DPMSolverMultistepScheduler"]
    _import_structure["scheduling_dpmsolver_multistep_inverse"] = ["DPMSolverMultistepInverseScheduler"]
    _import_structure["scheduling_dpmsolver_singlestep"] = ["DPMSolverSinglestepScheduler"]
@@ -141,12 +143,14 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    from .scheduling_consistency_decoder import ConsistencyDecoderScheduler
    from .scheduling_consistency_models import CMStochasticIterativeScheduler
    from .scheduling_ddim import DDIMScheduler
    from .scheduling_ddim_cogvideox import CogVideoXDDIMScheduler
    from .scheduling_ddim_inverse import DDIMInverseScheduler
    from .scheduling_ddim_parallel import DDIMParallelScheduler
    from .scheduling_ddpm import DDPMScheduler
    from .scheduling_ddpm_parallel import DDPMParallelScheduler
    from .scheduling_ddpm_wuerstchen import DDPMWuerstchenScheduler
    from .scheduling_deis_multistep import DEISMultistepScheduler
    from .scheduling_dpm_cogvideox import CogVideoXDPMScheduler
    from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
    from .scheduling_dpmsolver_multistep_inverse import DPMSolverMultistepInverseScheduler
    from .scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler
450  src/diffusers/schedulers/scheduling_ddim_cogvideox.py  Normal file
@@ -0,0 +1,450 @@
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
# and https://github.com/hojonathanho/diffusion

import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin


@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->CogVideoXDDIM
class CogVideoXDDIMSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

    prev_sample: torch.Tensor
    pred_original_sample: Optional[torch.Tensor] = None


# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.


    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`

    Returns:
        betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)
def rescale_zero_terminal_snr(alphas_cumprod):
    """
    Rescales the cumulative alphas to have zero terminal SNR. Based on https://arxiv.org/pdf/2305.08891.pdf
    (Algorithm 1), adapted to operate on `alphas_cumprod` instead of betas.


    Args:
        alphas_cumprod (`torch.Tensor`):
            the cumulative product of (1 - beta) that the scheduler is being initialized with.

    Returns:
        `torch.Tensor`: rescaled cumulative alphas with zero terminal SNR
    """

    alphas_bar_sqrt = alphas_cumprod.sqrt()

    # Store old values.
    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

    # Shift so the last timestep is zero.
    alphas_bar_sqrt -= alphas_bar_sqrt_T

    # Scale so the first timestep is back to the old value.
    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

    # Square to recover alphas_bar
    alphas_bar = alphas_bar_sqrt**2  # Revert sqrt

    return alphas_bar
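Editor's sketch (not part of the diff): a quick numerical check of `rescale_zero_terminal_snr` above on a toy scaled-linear schedule; it assumes this branch is installed so the function is importable.

import torch
from diffusers.schedulers.scheduling_ddim_cogvideox import rescale_zero_terminal_snr

betas = torch.linspace(0.00085**0.5, 0.0120**0.5, 1000, dtype=torch.float64) ** 2
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
rescaled = rescale_zero_terminal_snr(alphas_cumprod)
print(rescaled[-1].item())                           # 0.0 -> zero terminal SNR
print(rescaled[0].item(), alphas_cumprod[0].item())  # first value is preserved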
class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin):
    """
    `CogVideoXDDIMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models
    (DDPMs) with non-Markovian guidance.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.00085):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to 0.0120):
            The final `beta` value.
        beta_schedule (`str`, defaults to `"scaled_linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        clip_sample (`bool`, defaults to `True`):
            Clip the predicted sample for numerical stability.
        clip_sample_range (`float`, defaults to 1.0):
            The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
        set_alpha_to_one (`bool`, defaults to `True`):
            Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
            there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
            otherwise it uses the alpha value at step 0.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps, as required by some model families.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper).
        thresholding (`bool`, defaults to `False`):
            Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
            as Stable Diffusion.
        dynamic_thresholding_ratio (`float`, defaults to 0.995):
            The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
        sample_max_value (`float`, defaults to 1.0):
            The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
        timestep_spacing (`str`, defaults to `"leading"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        rescale_betas_zero_snr (`bool`, defaults to `False`):
            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
            dark samples instead of limiting it to samples with medium brightness. Loosely related to
            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
        snr_shift_scale (`float`, defaults to 3.0):
            Scale of the SD3-style SNR shift applied to `alphas_cumprod` at initialization.
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.00085,
        beta_end: float = 0.0120,
        beta_schedule: str = "scaled_linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        clip_sample: bool = True,
        set_alpha_to_one: bool = True,
        steps_offset: int = 0,
        prediction_type: str = "epsilon",
        clip_sample_range: float = 1.0,
        sample_max_value: float = 1.0,
        timestep_spacing: str = "leading",
        rescale_betas_zero_snr: bool = False,
        snr_shift_scale: float = 3.0,
    ):
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float64) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # Modify: SNR shift following SD3
        self.alphas_cumprod = self.alphas_cumprod / (snr_shift_scale + (1 - snr_shift_scale) * self.alphas_cumprod)

        # Rescale for zero SNR
        if rescale_betas_zero_snr:
            self.alphas_cumprod = rescale_zero_terminal_snr(self.alphas_cumprod)

        # At every step in ddim, we are looking into the previous alphas_cumprod
        # For the final step, there is no previous alphas_cumprod because we are already at 0
        # `set_alpha_to_one` decides whether we set this parameter simply to one or
        # whether we use the final alpha of the "non-previous" one.
        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # setable values
        self.num_inference_steps = None
        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
|
    def _get_variance(self, timestep, prev_timestep):
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)

        return variance

    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.Tensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain.

        Returns:
            `torch.Tensor`:
                A scaled input sample.
        """
        return sample

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
        """

        if num_inference_steps > self.config.num_train_timesteps:
            raise ValueError(
                f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
                f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
                f" maximal {self.config.num_train_timesteps} timesteps."
            )

        self.num_inference_steps = num_inference_steps

        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
        if self.config.timestep_spacing == "linspace":
            timesteps = (
                np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
                .round()[::-1]
                .copy()
                .astype(np.int64)
            )
        elif self.config.timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
            timesteps -= 1
        else:
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'."
            )

        self.timesteps = torch.from_numpy(timesteps).to(device)
    def step(
        self,
        model_output: torch.Tensor,
        timestep: int,
        sample: torch.Tensor,
        eta: float = 0.0,
        use_clipped_model_output: bool = False,
        generator=None,
        variance_noise: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[CogVideoXDDIMSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            eta (`float`):
                The weight of noise for added noise in diffusion step.
            use_clipped_model_output (`bool`, defaults to `False`):
                If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
                because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
                clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
                `use_clipped_model_output` has no effect.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            variance_noise (`torch.Tensor`):
                Alternative to generating noise with `generator` by directly providing the noise for the variance
                itself. Useful for methods such as [`CycleDiffusion`].
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~schedulers.scheduling_ddim_cogvideox.CogVideoXDDIMSchedulerOutput`] or
                `tuple`.

        Returns:
            [`~schedulers.scheduling_ddim_cogvideox.CogVideoXDDIMSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_ddim_cogvideox.CogVideoXDDIMSchedulerOutput`] is
                returned, otherwise a tuple is returned where the first element is the sample tensor.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        # See formulas (12) and (16) of the DDIM paper https://arxiv.org/pdf/2010.02502.pdf
        # Ideally, read the DDIM paper in detail to understand the derivation.

        # Notation (<variable name> -> <name in paper>)
        # - pred_noise_t -> e_theta(x_t, t)
        # - pred_original_sample -> f_theta(x_t, t) or x_0
        # - std_dev_t -> sigma_t
        # - eta -> η
        # - pred_sample_direction -> "direction pointing to x_t"
        # - pred_prev_sample -> "x_t-1"

        # 1. get previous step value (=t-1)
        prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps

        # 2. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod

        beta_prod_t = 1 - alpha_prod_t

        # 3. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        # To make style tests pass, commented out `pred_epsilon` as it is an unused variable
        if self.config.prediction_type == "epsilon":
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
            # pred_epsilon = model_output
        elif self.config.prediction_type == "sample":
            pred_original_sample = model_output
            # pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
        elif self.config.prediction_type == "v_prediction":
            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
            # pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction`"
            )

        a_t = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** 0.5
        b_t = alpha_prod_t_prev**0.5 - alpha_prod_t**0.5 * a_t

        prev_sample = a_t * sample + b_t * pred_original_sample

        if not return_dict:
            return (prev_sample,)

        return CogVideoXDDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.IntTensor,
    ) -> torch.Tensor:
        # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
        # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
        # for the subsequent add_noise calls
        self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
        alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
        timesteps = timesteps.to(original_samples.device)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
    def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
        # Make sure alphas_cumprod and timestep have same device and dtype as sample
        self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
        alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
        timesteps = timesteps.to(sample.device)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(sample.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
        return velocity

    def __len__(self):
        return self.config.num_train_timesteps
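For orientation, a minimal usage sketch of the scheduler above in a bare denoising loop follows. It only exercises `set_timesteps` and `step`; the `denoise` callable standing in for the CogVideoX transformer, the latent shape, and the step count are illustrative assumptions, not values taken from this diff.

import torch
from diffusers import CogVideoXDDIMScheduler

scheduler = CogVideoXDDIMScheduler(rescale_betas_zero_snr=True)
scheduler.set_timesteps(num_inference_steps=50)

latents = torch.randn(1, 4, 8, 60, 90)  # placeholder latent shape, not from the diff

def denoise(x, t):
    # stand-in for the real transformer's noise prediction
    return torch.zeros_like(x)

for t in scheduler.timesteps:
    model_output = denoise(latents, t)
    # step() returns a CogVideoXDDIMSchedulerOutput; prev_sample feeds the next iteration
    latents = scheduler.step(model_output, t, latents).prev_sample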
481  src/diffusers/schedulers/scheduling_dpm_cogvideox.py  Normal file
@@ -0,0 +1,481 @@
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
# and https://github.com/hojonathanho/diffusion

import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from ..utils.torch_utils import randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin

@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->CogVideoX
class CogVideoXDPMSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

    prev_sample: torch.Tensor
    pred_original_sample: Optional[torch.Tensor] = None


# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
            prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
            Choose from `cosine` or `exp`

    Returns:
        betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)

# Copied from diffusers.schedulers.scheduling_ddim_cogvideox.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(alphas_cumprod):
    """
    Rescales the cumulative alpha products to have zero terminal SNR, based on
    https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1).

    Args:
        alphas_cumprod (`torch.Tensor`):
            the cumulative product of alphas that the scheduler is being initialized with.

    Returns:
        `torch.Tensor`: rescaled `alphas_cumprod` with zero terminal SNR
    """

    alphas_bar_sqrt = alphas_cumprod.sqrt()

    # Store old values.
    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

    # Shift so the last timestep is zero.
    alphas_bar_sqrt -= alphas_bar_sqrt_T

    # Scale so the first timestep is back to the old value.
    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

    # Convert alphas_bar_sqrt back to alphas_bar
    alphas_bar = alphas_bar_sqrt**2  # Revert sqrt

    return alphas_bar

class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin):
    """
    `CogVideoXDPMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models
    (DDPMs) with non-Markovian guidance.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.0001):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to 0.02):
            The final `beta` value.
        beta_schedule (`str`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        clip_sample (`bool`, defaults to `True`):
            Clip the predicted sample for numerical stability.
        clip_sample_range (`float`, defaults to 1.0):
            The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
        set_alpha_to_one (`bool`, defaults to `True`):
            Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
            there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
            otherwise it uses the alpha value at step 0.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps, as required by some model families.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper).
        thresholding (`bool`, defaults to `False`):
            Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
            as Stable Diffusion.
        dynamic_thresholding_ratio (`float`, defaults to 0.995):
            The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
        sample_max_value (`float`, defaults to 1.0):
            The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
        timestep_spacing (`str`, defaults to `"leading"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        rescale_betas_zero_snr (`bool`, defaults to `False`):
            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
            dark samples instead of limiting it to samples with medium brightness. Loosely related to
            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
    """
    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.00085,
        beta_end: float = 0.0120,
        beta_schedule: str = "scaled_linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        clip_sample: bool = True,
        set_alpha_to_one: bool = True,
        steps_offset: int = 0,
        prediction_type: str = "epsilon",
        clip_sample_range: float = 1.0,
        sample_max_value: float = 1.0,
        timestep_spacing: str = "leading",
        rescale_betas_zero_snr: bool = False,
        snr_shift_scale: float = 3.0,
    ):
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float64) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # Modify: SNR shift following SD3
        self.alphas_cumprod = self.alphas_cumprod / (snr_shift_scale + (1 - snr_shift_scale) * self.alphas_cumprod)

        # Rescale for zero SNR
        if rescale_betas_zero_snr:
            self.alphas_cumprod = rescale_zero_terminal_snr(self.alphas_cumprod)

        # At every step in ddim, we are looking into the previous alphas_cumprod
        # For the final step, there is no previous alphas_cumprod because we are already at 0
        # `set_alpha_to_one` decides whether we set this parameter simply to one or
        # whether we use the final alpha of the "non-previous" one.
        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # setable values
        self.num_inference_steps = None
        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.Tensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain.

        Returns:
            `torch.Tensor`:
                A scaled input sample.
        """
        return sample

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
        """

        if num_inference_steps > self.config.num_train_timesteps:
            raise ValueError(
                f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
                f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
                f" maximal {self.config.num_train_timesteps} timesteps."
            )

        self.num_inference_steps = num_inference_steps

        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
        if self.config.timestep_spacing == "linspace":
            timesteps = (
                np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
                .round()[::-1]
                .copy()
                .astype(np.int64)
            )
        elif self.config.timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
            timesteps -= 1
        else:
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'."
            )

        self.timesteps = torch.from_numpy(timesteps).to(device)
    def get_variables(self, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back=None):
        lamb = ((alpha_prod_t / (1 - alpha_prod_t)) ** 0.5).log()
        lamb_next = ((alpha_prod_t_prev / (1 - alpha_prod_t_prev)) ** 0.5).log()
        h = lamb_next - lamb

        if alpha_prod_t_back is not None:
            lamb_previous = ((alpha_prod_t_back / (1 - alpha_prod_t_back)) ** 0.5).log()
            h_last = lamb - lamb_previous
            r = h_last / h
            return h, r, lamb, lamb_next
        else:
            return h, None, lamb, lamb_next

    def get_mult(self, h, r, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back):
        mult1 = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** 0.5 * (-h).exp()
        mult2 = (-2 * h).expm1() * alpha_prod_t_prev**0.5

        if alpha_prod_t_back is not None:
            mult3 = 1 + 1 / (2 * r)
            mult4 = 1 / (2 * r)
            return mult1, mult2, mult3, mult4
        else:
            return mult1, mult2
    def step(
        self,
        model_output: torch.Tensor,
        old_pred_original_sample: torch.Tensor,
        timestep: int,
        timestep_back: int,
        sample: torch.Tensor,
        eta: float = 0.0,
        use_clipped_model_output: bool = False,
        generator=None,
        variance_noise: Optional[torch.Tensor] = None,
        return_dict: bool = False,
    ) -> Union[CogVideoXDPMSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            eta (`float`):
                The weight of noise for added noise in diffusion step.
            use_clipped_model_output (`bool`, defaults to `False`):
                If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
                because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
                clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
                `use_clipped_model_output` has no effect.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            variance_noise (`torch.Tensor`):
                Alternative to generating noise with `generator` by directly providing the noise for the variance
                itself. Useful for methods such as [`CycleDiffusion`].
            return_dict (`bool`, *optional*, defaults to `False`):
                Whether or not to return a [`~schedulers.scheduling_dpm_cogvideox.CogVideoXDPMSchedulerOutput`] or
                `tuple`.

        Returns:
            [`~schedulers.scheduling_dpm_cogvideox.CogVideoXDPMSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_dpm_cogvideox.CogVideoXDPMSchedulerOutput`] is
                returned, otherwise a tuple is returned where the first element is the sample tensor.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        # See formulas (12) and (16) of the DDIM paper https://arxiv.org/pdf/2010.02502.pdf
        # Ideally, read the DDIM paper in detail to understand the derivation.

        # Notation (<variable name> -> <name in paper>)
        # - pred_noise_t -> e_theta(x_t, t)
        # - pred_original_sample -> f_theta(x_t, t) or x_0
        # - std_dev_t -> sigma_t
        # - eta -> η
        # - pred_sample_direction -> "direction pointing to x_t"
        # - pred_prev_sample -> "x_t-1"

        # 1. get previous step value (=t-1)
        prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps

        # 2. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
        alpha_prod_t_back = self.alphas_cumprod[timestep_back] if timestep_back is not None else None

        beta_prod_t = 1 - alpha_prod_t

        # 3. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        # To make style tests pass, commented out `pred_epsilon` as it is an unused variable
        if self.config.prediction_type == "epsilon":
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
            # pred_epsilon = model_output
        elif self.config.prediction_type == "sample":
            pred_original_sample = model_output
            # pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
        elif self.config.prediction_type == "v_prediction":
            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
            # pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction`"
            )

        h, r, lamb, lamb_next = self.get_variables(alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back)
        mult = list(self.get_mult(h, r, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back))
        mult_noise = (1 - alpha_prod_t_prev) ** 0.5 * (1 - (-2 * h).exp()) ** 0.5

        noise = randn_tensor(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype)
        prev_sample = mult[0] * sample - mult[1] * pred_original_sample + mult_noise * noise

        if old_pred_original_sample is None or prev_timestep < 0:
            # Save a network evaluation if all noise levels are 0 or on the first step
            return prev_sample, pred_original_sample
        else:
            denoised_d = mult[2] * pred_original_sample - mult[3] * old_pred_original_sample
            noise = randn_tensor(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype)
            x_advanced = mult[0] * sample - mult[1] * denoised_d + mult_noise * noise

            prev_sample = x_advanced

        if not return_dict:
            return (prev_sample, pred_original_sample)

        return CogVideoXDPMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.IntTensor,
    ) -> torch.Tensor:
        # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
        # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
        # for the subsequent add_noise calls
        self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
        alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
        timesteps = timesteps.to(original_samples.device)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
    def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
        # Make sure alphas_cumprod and timestep have same device and dtype as sample
        self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
        alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
        timesteps = timesteps.to(sample.device)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(sample.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
        return velocity

    def __len__(self):
        return self.config.num_train_timesteps
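The multistep `step()` above differs from the DDIM variant in that it also consumes the previous step's `pred_original_sample` and a second timestep. A minimal sketch of how a caller might thread those values through a loop is shown below; the `denoise` stand-in and the latent shape are illustrative assumptions, not code taken from this diff.

import torch
from diffusers import CogVideoXDPMScheduler

scheduler = CogVideoXDPMScheduler()
scheduler.set_timesteps(num_inference_steps=50)
timesteps = scheduler.timesteps

latents = torch.randn(1, 4, 8, 60, 90)  # placeholder latent shape, not from the diff
old_pred_original_sample = None  # no history yet on the first step

def denoise(x, t):
    # stand-in for the real transformer's noise prediction
    return torch.zeros_like(x)

for i, t in enumerate(timesteps):
    timestep_back = timesteps[i - 1] if i > 0 else None
    model_output = denoise(latents, t)
    # with return_dict=False the scheduler returns (prev_sample, pred_original_sample),
    # and the x_0 prediction is fed back in on the next iteration
    latents, old_pred_original_sample = scheduler.step(
        model_output,
        old_pred_original_sample,
        t,
        timestep_back,
        latents,
        return_dict=False,
    )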
@@ -47,6 +47,21 @@ class AutoencoderKL(metaclass=DummyObject):
        requires_backends(cls, ["torch"])


class AutoencoderKLCogVideoX(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class AutoencoderKLTemporalDecoder(metaclass=DummyObject):
    _backends = ["torch"]
@@ -92,6 +107,21 @@ class AutoencoderTiny(metaclass=DummyObject):
        requires_backends(cls, ["torch"])


class CogVideoXTransformer3DModel(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class ConsistencyDecoderVAE(metaclass=DummyObject):
    _backends = ["torch"]
@@ -975,6 +1005,36 @@ class CMStochasticIterativeScheduler(metaclass=DummyObject):
        requires_backends(cls, ["torch"])


class CogVideoXDDIMScheduler(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class CogVideoXDPMScheduler(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class DDIMInverseScheduler(metaclass=DummyObject):
    _backends = ["torch"]
@@ -287,6 +287,21 @@ class CLIPImageProjection(metaclass=DummyObject):
        requires_backends(cls, ["torch", "transformers"])


class CogVideoXPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class CycleDiffusionPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]
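These placeholder classes exist so that importing diffusers keeps working when a backend is missing; only touching the object reports the missing dependency. A hedged illustration of the intended failure mode follows; the environment it describes (PyTorch not installed) is an assumption for the example, not something exercised by this diff.

# Illustrative only: when the "torch" backend is absent, the public name resolves
# to the DummyObject placeholder above, and using it raises an ImportError that
# tells the user which backend to install.
try:
    from diffusers import CogVideoXDDIMScheduler

    scheduler = CogVideoXDDIMScheduler()  # raises ImportError if torch is missing
except ImportError as err:
    print(err)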
0    tests/pipelines/cogvideox/__init__.py  Normal file
289  tests/pipelines/cogvideox/test_cogvideox.py  Normal file
@@ -0,0 +1,289 @@
# Copyright 2024 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import inspect
import tempfile
import unittest

import numpy as np
import torch
from transformers import AutoTokenizer, T5EncoderModel

from diffusers import AutoencoderKL, CogVideoXPipeline, CogVideoXTransformer3DModel, DDIMScheduler
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    numpy_cosine_similarity_distance,
    require_torch_gpu,
    slow,
    torch_device,
)

from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np


enable_full_determinism()

class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = CogVideoXPipeline
    params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS

    required_optional_params = PipelineTesterMixin.required_optional_params

    def get_dummy_components(self):
        torch.manual_seed(0)
        transformer = CogVideoXTransformer3DModel(
            sample_size=8,
            num_layers=1,
            patch_size=2,
            attention_head_dim=8,
            num_attention_heads=3,
            caption_channels=32,
            in_channels=4,
            cross_attention_dim=24,
            out_channels=8,
            attention_bias=True,
            activation_fn="gelu-approximate",
            num_embeds_ada_norm=1000,
            norm_type="ada_norm_single",
            norm_elementwise_affine=False,
            norm_eps=1e-6,
        )
        torch.manual_seed(0)
        vae = AutoencoderKL()

        scheduler = DDIMScheduler()
        text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")

        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")

        components = {
            "transformer": transformer.eval(),
            "vae": vae.eval(),
            "scheduler": scheduler,
            "text_encoder": text_encoder.eval(),
            "tokenizer": tokenizer,
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "negative_prompt": "low quality",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 5.0,
            "height": 8,
            "width": 8,
            "video_length": 1,
            "output_type": "pt",
            "clean_caption": False,
        }
        return inputs
    def test_inference(self):
        device = "cpu"

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        video = pipe(**inputs).frames
        generated_video = video[0]

        self.assertEqual(generated_video.shape, (1, 3, 8, 8))
        expected_video = torch.randn(1, 3, 8, 8)
        max_diff = np.abs(generated_video - expected_video).max()
        self.assertLessEqual(max_diff, 1e10)
    def test_callback_inputs(self):
        sig = inspect.signature(self.pipeline_class.__call__)
        has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
        has_callback_step_end = "callback_on_step_end" in sig.parameters

        if not (has_callback_tensor_inputs and has_callback_step_end):
            return

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        self.assertTrue(
            hasattr(pipe, "_callback_tensor_inputs"),
            f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
        )

        def callback_inputs_subset(pipe, i, t, callback_kwargs):
            # iterate over callback args
            for tensor_name, tensor_value in callback_kwargs.items():
                # check that we're only passing in allowed tensor inputs
                assert tensor_name in pipe._callback_tensor_inputs

            return callback_kwargs

        def callback_inputs_all(pipe, i, t, callback_kwargs):
            for tensor_name in pipe._callback_tensor_inputs:
                assert tensor_name in callback_kwargs

            # iterate over callback args
            for tensor_name, tensor_value in callback_kwargs.items():
                # check that we're only passing in allowed tensor inputs
                assert tensor_name in pipe._callback_tensor_inputs

            return callback_kwargs

        inputs = self.get_dummy_inputs(torch_device)

        # Test passing in a subset
        inputs["callback_on_step_end"] = callback_inputs_subset
        inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
        output = pipe(**inputs)[0]

        # Test passing in everything
        inputs["callback_on_step_end"] = callback_inputs_all
        inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
        output = pipe(**inputs)[0]

        def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
            is_last = i == (pipe.num_timesteps - 1)
            if is_last:
                callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
            return callback_kwargs

        inputs["callback_on_step_end"] = callback_inputs_change_tensor
        inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
        output = pipe(**inputs)[0]
        assert output.abs().sum() < 1e10
    def test_inference_batch_single_identical(self):
        self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3)

    def test_attention_slicing_forward_pass(self):
        pass

    def test_save_load_optional_components(self):
        if not hasattr(self.pipeline_class, "_optional_components"):
            return

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)

        for component in pipe.components.values():
            if hasattr(component, "set_default_attn_processor"):
                component.set_default_attn_processor()
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)

        prompt = inputs["prompt"]
        generator = inputs["generator"]

        (
            prompt_embeds,
            negative_prompt_embeds,
        ) = pipe.encode_prompt(prompt)

        # inputs with prompt converted to embeddings
        inputs = {
            "prompt_embeds": prompt_embeds,
            "negative_prompt": None,
            "negative_prompt_embeds": negative_prompt_embeds,
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 5.0,
            "height": 8,
            "width": 8,
            "video_length": 1,
            "mask_feature": False,
            "output_type": "pt",
            "clean_caption": False,
        }

        # set all optional components to None
        for optional_component in pipe._optional_components:
            setattr(pipe, optional_component, None)

        output = pipe(**inputs)[0]

        with tempfile.TemporaryDirectory() as tmpdir:
            pipe.save_pretrained(tmpdir, safe_serialization=False)
            pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
            pipe_loaded.to(torch_device)

            for component in pipe_loaded.components.values():
                if hasattr(component, "set_default_attn_processor"):
                    component.set_default_attn_processor()

            pipe_loaded.set_progress_bar_config(disable=None)

        for optional_component in pipe._optional_components:
            self.assertTrue(
                getattr(pipe_loaded, optional_component) is None,
                f"`{optional_component}` did not stay set to None after loading.",
            )

        output_loaded = pipe_loaded(**inputs)[0]

        max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
        self.assertLess(max_diff, 1.0)

@slow
@require_torch_gpu
class CogVideoXPipelineIntegrationTests(unittest.TestCase):
    prompt = "A painting of a squirrel eating a burger."

    def setUp(self):
        super().setUp()
        gc.collect()
        torch.cuda.empty_cache()

    def tearDown(self):
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def test_cogvideox(self):
        generator = torch.Generator("cpu").manual_seed(0)

        pipe = CogVideoXPipeline.from_pretrained("THUDM/cogvideox-2b", torch_dtype=torch.float16)
        pipe.enable_model_cpu_offload()
        prompt = self.prompt

        videos = pipe(
            prompt=prompt,
            height=512,
            width=512,
            generator=generator,
            num_inference_steps=2,
            clean_caption=False,
        ).frames

        video = videos[0]
        expected_video = torch.randn(1, 512, 512, 3).numpy()

        max_diff = numpy_cosine_similarity_distance(video.flatten(), expected_video)
        assert max_diff < 1e-3, f"Max diff is too high. got {video.flatten()}"
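The integration test's tolerance check relies on `numpy_cosine_similarity_distance` from the testing utilities, which compares two flattened arrays by cosine similarity rather than element-wise error. A hedged sketch of the idea, not the library's exact implementation, is:

import numpy as np

def cosine_distance_sketch(a, b):
    # 1 - cosine similarity between the flattened arrays; 0 means the arrays point
    # in the same direction, so small values indicate close agreement
    a, b = np.asarray(a).flatten(), np.asarray(b).flatten()
    return 1.0 - np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))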