Mirror of https://github.com/huggingface/diffusers.git, synced 2025-12-09 14:04:37 +08:00

Compare commits: fix-bnb-te ... add-stocha (12 commits)
| Author | SHA1 | Date |
|---|---|---|
| | ff1012f8cf | |
| | 25bc77d8f8 | |
| | 9c35a89921 | |
| | 32d9aef997 | |
| | 9edc5beddc | |
| | f87956e9cf | |
| | 690adb5bd9 | |
| | 5873377a66 | |
| | 5a2e0f715c | |
| | ef47726e2d | |
| | 0021bfa1e1 | |
| | bbd0c161b5 | |
@@ -133,6 +133,60 @@ output = pipe(

export_to_video(output, "wan-i2v.mp4", fps=16)
```

### First and Last Frame Interpolation

```python
import numpy as np
import torch
import torchvision.transforms.functional as TF
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
from diffusers.utils import export_to_video, load_image
from transformers import CLIPVisionModel


model_id = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"
image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float32)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanImageToVideoPipeline.from_pretrained(
    model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
)
pipe.to("cuda")

first_frame = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png")
last_frame = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png")


def aspect_ratio_resize(image, pipe, max_area=720 * 1280):
    aspect_ratio = image.height / image.width
    mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
    height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
    width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
    image = image.resize((width, height))
    return image, height, width


def center_crop_resize(image, height, width):
    # Calculate resize ratio to match first frame dimensions
    resize_ratio = max(width / image.width, height / image.height)

    # Resize the image
    width = round(image.width * resize_ratio)
    height = round(image.height * resize_ratio)
    size = [width, height]
    image = TF.center_crop(image, size)

    return image, height, width


first_frame, height, width = aspect_ratio_resize(first_frame, pipe)
if last_frame.size != first_frame.size:
    last_frame, _, _ = center_crop_resize(last_frame, height, width)

prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."

output = pipe(
    image=first_frame, last_image=last_frame, prompt=prompt, height=height, width=width, guidance_scale=5.5
).frames[0]
export_to_video(output, "output.mp4", fps=16)
```

### Video to Video Generation

```python
@@ -33,7 +33,6 @@ from diffusers import DiffusionPipeline
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import (
    FromSingleFileMixin,
-   StableDiffusionLoraLoaderMixin,
    StableDiffusionXLLoraLoaderMixin,
    TextualInversionLoaderMixin,
)
@@ -300,7 +299,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):


class StableDiffusionXLControlNetAdapterInpaintPipeline(
-   DiffusionPipeline, StableDiffusionMixin, FromSingleFileMixin, StableDiffusionLoraLoaderMixin
+   DiffusionPipeline, StableDiffusionMixin, FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin
):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion augmented with T2I-Adapter
@@ -39,6 +39,24 @@ TRANSFORMER_KEYS_RENAME_DICT = {
    "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
    "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
    "img_emb.proj.4": "condition_embedder.image_embedder.norm2",
    # for the FLF2V model
    "img_emb.emb_pos": "condition_embedder.image_embedder.pos_embed",
    # Add attention component mappings
    "self_attn.q": "attn1.to_q",
    "self_attn.k": "attn1.to_k",
    "self_attn.v": "attn1.to_v",
    "self_attn.o": "attn1.to_out.0",
    "self_attn.norm_q": "attn1.norm_q",
    "self_attn.norm_k": "attn1.norm_k",
    "cross_attn.q": "attn2.to_q",
    "cross_attn.k": "attn2.to_k",
    "cross_attn.v": "attn2.to_v",
    "cross_attn.o": "attn2.to_out.0",
    "cross_attn.norm_q": "attn2.norm_q",
    "cross_attn.norm_k": "attn2.norm_k",
    "attn2.to_k_img": "attn2.add_k_proj",
    "attn2.to_v_img": "attn2.add_v_proj",
    "attn2.norm_k_img": "attn2.norm_added_k",
}

TRANSFORMER_SPECIAL_KEYS_REMAP = {}
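The rename map above only lists target names; the conversion script applies this kind of map by substring replacement over every key of the original checkpoint. A minimal sketch of that idea, using a hypothetical subset of the map and a made-up key (this is an illustration, not the script itself):

```python
# Hypothetical subset of the rename map above, applied by plain substring replacement.
RENAME = {
    "img_emb.emb_pos": "condition_embedder.image_embedder.pos_embed",
    "self_attn.q": "attn1.to_q",
    "cross_attn.o": "attn2.to_out.0",
}


def rename_key(key: str) -> str:
    # e.g. "blocks.0.cross_attn.o.bias" -> "blocks.0.attn2.to_out.0.bias"
    for old, new in RENAME.items():
        key = key.replace(old, new)
    return key


assert rename_key("blocks.0.cross_attn.o.bias") == "blocks.0.attn2.to_out.0.bias"
```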
@@ -135,6 +153,28 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]:
                "text_dim": 4096,
            },
        }
    elif model_type == "Wan-FLF2V-14B-720P":
        config = {
            "model_id": "ypyp/Wan2.1-FLF2V-14B-720P",  # This is just a placeholder
            "diffusers_config": {
                "image_dim": 1280,
                "added_kv_proj_dim": 5120,
                "attention_head_dim": 128,
                "cross_attn_norm": True,
                "eps": 1e-06,
                "ffn_dim": 13824,
                "freq_dim": 256,
                "in_channels": 36,
                "num_attention_heads": 40,
                "num_layers": 40,
                "out_channels": 16,
                "patch_size": [1, 2, 2],
                "qk_norm": "rms_norm_across_heads",
                "text_dim": 4096,
                "rope_max_seq_len": 1024,
                "pos_embed_seq_len": 257 * 2,
            },
        }
    return config
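Two of the new values are easy to sanity-check against the rest of the changeset: the image K/V projection width must match the transformer's inner dimension, and the positional embedding covers the CLIP sequences of both conditioning frames. A small illustrative check, assuming 257 CLIP tokens per image as used elsewhere in this diff:

```python
config = {
    "num_attention_heads": 40,
    "attention_head_dim": 128,
    "added_kv_proj_dim": 5120,
    "pos_embed_seq_len": 257 * 2,
}

inner_dim = config["num_attention_heads"] * config["attention_head_dim"]
assert inner_dim == config["added_kv_proj_dim"]   # image K/V projections match the model width
assert config["pos_embed_seq_len"] == 2 * 257     # first + last frame, 257 CLIP tokens each
```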
@@ -393,11 +433,12 @@ if __name__ == "__main__":
    vae = convert_vae()
    text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl")
    tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
+   flow_shift = 16.0 if "FLF2V" in args.model_type else 3.0
    scheduler = UniPCMultistepScheduler(
-       prediction_type="flow_prediction", use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=3.0
+       prediction_type="flow_prediction", use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=flow_shift
    )

-   if "I2V" in args.model_type:
+   if "I2V" in args.model_type or "FLF2V" in args.model_type:
        image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16
        )
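The scheduler construction now picks its flow shift from the model type. A brief sketch of the resulting call for an FLF2V checkpoint (a standalone sketch; only the import is assumed beyond what the diff shows):

```python
from diffusers import UniPCMultistepScheduler

model_type = "Wan-FLF2V-14B-720P"
flow_shift = 16.0 if "FLF2V" in model_type else 3.0   # 720P FLF2V checkpoints use the larger shift
scheduler = UniPCMultistepScheduler(
    prediction_type="flow_prediction", use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=flow_shift
)
```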
@@ -918,5 +918,5 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
-           return (output, hidden_states_masks)
-       return Transformer2DModelOutput(sample=output, mask=hidden_states_masks)
+           return (output,)
+       return Transformer2DModelOutput(sample=output)
@@ -49,8 +49,10 @@ class WanAttnProcessor2_0:
    ) -> torch.Tensor:
        encoder_hidden_states_img = None
        if attn.add_k_proj is not None:
-           encoder_hidden_states_img = encoder_hidden_states[:, :257]
-           encoder_hidden_states = encoder_hidden_states[:, 257:]
+           # 512 is the context length of the text encoder, hardcoded for now
+           image_context_length = encoder_hidden_states.shape[1] - 512
+           encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
+           encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
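The new slicing lets the image portion be variable-length while assuming the text context is always padded to 512 tokens. A worked illustration of the arithmetic under that assumption (257 CLIP tokens per conditioning image is likewise an assumption taken from elsewhere in this changeset):

```python
text_len = 512                        # text context length assumed by the code above
for num_images in (1, 2):             # plain image-to-video vs. first + last frame
    image_tokens = 257 * num_images   # 257 CLIP tokens per conditioning image (assumption)
    total_len = image_tokens + text_len
    image_context_length = total_len - text_len
    assert image_context_length == image_tokens
```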
@@ -108,14 +110,23 @@ class WanAttnProcessor2_0:


class WanImageEmbedding(torch.nn.Module):
-   def __init__(self, in_features: int, out_features: int):
+   def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
        super().__init__()

        self.norm1 = FP32LayerNorm(in_features)
        self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu")
        self.norm2 = FP32LayerNorm(out_features)
+       if pos_embed_seq_len is not None:
+           self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_seq_len, in_features))
+       else:
+           self.pos_embed = None

    def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:
+       if self.pos_embed is not None:
+           batch_size, seq_len, embed_dim = encoder_hidden_states_image.shape
+           encoder_hidden_states_image = encoder_hidden_states_image.view(-1, 2 * seq_len, embed_dim)
+           encoder_hidden_states_image = encoder_hidden_states_image + self.pos_embed
+
        hidden_states = self.norm1(encoder_hidden_states_image)
        hidden_states = self.ff(hidden_states)
        hidden_states = self.norm2(hidden_states)
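The positional embedding path folds the two conditioning frames, which arrive stacked along the batch dimension, into one joint sequence before adding learned positions. A minimal shape sketch with random stand-in tensors (257 is the CLIP sequence length assumed elsewhere in this changeset):

```python
import torch

batch, seq_len, dim = 2, 257, 1280                        # first + last frame stacked on the batch dim
pos_embed = torch.zeros(1, 2 * seq_len, dim)              # learned positions over both frames

image_states = torch.randn(batch, seq_len, dim)
image_states = image_states.view(-1, 2 * seq_len, dim)    # -> (1, 514, 1280): one joint sequence
image_states = image_states + pos_embed
print(image_states.shape)                                 # torch.Size([1, 514, 1280])
```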
@@ -130,6 +141,7 @@ class WanTimeTextImageEmbedding(nn.Module):
        time_proj_dim: int,
        text_embed_dim: int,
        image_embed_dim: Optional[int] = None,
        pos_embed_seq_len: Optional[int] = None,
    ):
        super().__init__()
@@ -141,7 +153,7 @@ class WanTimeTextImageEmbedding(nn.Module):

        self.image_embedder = None
        if image_embed_dim is not None:
-           self.image_embedder = WanImageEmbedding(image_embed_dim, dim)
+           self.image_embedder = WanImageEmbedding(image_embed_dim, dim, pos_embed_seq_len=pos_embed_seq_len)

    def forward(
        self,
@@ -350,6 +362,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
        image_dim: Optional[int] = None,
        added_kv_proj_dim: Optional[int] = None,
        rope_max_seq_len: int = 1024,
        pos_embed_seq_len: Optional[int] = None,
    ) -> None:
        super().__init__()
@@ -368,6 +381,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
            time_proj_dim=inner_dim * 6,
            text_embed_dim=text_dim,
            image_embed_dim=image_dim,
            pos_embed_seq_len=pos_embed_seq_len,
        )

        # 3. Transformer blocks
@@ -404,6 +404,11 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
            if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
                return False

            _, _, is_loaded_in_8bit_bnb = _check_bnb_status(module)

            if is_loaded_in_8bit_bnb:
                return False

            return hasattr(module, "_hf_hook") and (
                isinstance(module._hf_hook, accelerate.hooks.AlignDevicesHook)
                or hasattr(module._hf_hook, "hooks")
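The intent of the new guard is that components already quantized to 8-bit with bitsandbytes are never treated as sequentially offloaded, since such modules cannot be moved between devices. A minimal sketch of that check in isolation, with a stand-in tuple for the status helper (the real `_check_bnb_status` is a diffusers internal that returns bnb quantization flags for a module):

```python
def module_is_offload_candidate(module, bnb_status=(False, False, False)):
    # bnb_status stands in for _check_bnb_status(module); the last flag mirrors
    # is_loaded_in_8bit_bnb from the diff above.
    _, _, is_loaded_in_8bit_bnb = bnb_status
    if is_loaded_in_8bit_bnb:
        # 8-bit bitsandbytes modules cannot be moved between devices,
        # so they must be excluded from device-placement hooks.
        return False
    return hasattr(module, "_hf_hook")
```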
@@ -380,6 +380,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
        device: Optional[torch.device] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        last_image: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
        latent_height = height // self.vae_scale_factor_spatial
@@ -398,9 +399,16 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
            latents = latents.to(device=device, dtype=dtype)

        image = image.unsqueeze(2)
-       video_condition = torch.cat(
-           [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
-       )
+       if last_image is None:
+           video_condition = torch.cat(
+               [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
+           )
+       else:
+           last_image = last_image.unsqueeze(2)
+           video_condition = torch.cat(
+               [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 2, height, width), last_image],
+               dim=2,
+           )
        video_condition = video_condition.to(device=device, dtype=dtype)

        latents_mean = (
@@ -424,7 +432,11 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
        latent_condition = (latent_condition - latents_mean) * latents_std

        mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
-       mask_lat_size[:, :, list(range(1, num_frames))] = 0
+       if last_image is None:
+           mask_lat_size[:, :, list(range(1, num_frames))] = 0
+       else:
+           mask_lat_size[:, :, list(range(1, num_frames - 1))] = 0
        first_frame_mask = mask_lat_size[:, :, 0:1]
        first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal)
        mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
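The effect of the branch above is that plain image-to-video keeps only the first frame in the conditioning mask, while the first/last-frame variant keeps both ends. A small illustrative sketch with reduced shapes (the temporal VAE scale factor of 4 is an assumption matching Wan's VAE):

```python
import torch

batch_size, num_frames, latent_height, latent_width = 1, 9, 2, 2
vae_scale_factor_temporal = 4   # assumption: Wan's temporal compression factor


def build_mask(last_image=None):
    mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
    if last_image is None:
        mask_lat_size[:, :, list(range(1, num_frames))] = 0        # keep frame 0 only
    else:
        mask_lat_size[:, :, list(range(1, num_frames - 1))] = 0    # keep frames 0 and num_frames - 1
    first_frame_mask = mask_lat_size[:, :, 0:1]
    first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=vae_scale_factor_temporal)
    return torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)


print(build_mask().sum().item(), build_mask(last_image=True).sum().item())  # 16.0 vs 20.0: FLF2V also marks the final frame
```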
@@ -476,6 +488,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        image_embeds: Optional[torch.Tensor] = None,
        last_image: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "np",
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -620,7 +633,10 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
            negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)

        if image_embeds is None:
-           image_embeds = self.encode_image(image, device)
+           if last_image is None:
+               image_embeds = self.encode_image(image, device)
+           else:
+               image_embeds = self.encode_image([image, last_image], device)
        image_embeds = image_embeds.repeat(batch_size, 1, 1)
        image_embeds = image_embeds.to(transformer_dtype)
@@ -631,6 +647,10 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
        # 5. Prepare latent variables
        num_channels_latents = self.vae.config.z_dim
        image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32)
        if last_image is not None:
            last_image = self.video_processor.preprocess(last_image, height=height, width=width).to(
                device, dtype=torch.float32
            )
        latents, condition = self.prepare_latents(
            image,
            batch_size * num_videos_per_prompt,
@@ -642,6 +662,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
            device,
            generator,
            latents,
            last_image,
        )

        # 6. Denoising loop
@@ -80,6 +80,8 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
            Whether to use beta sigmas for step sizes in the noise schedule during sampling.
        time_shift_type (`str`, defaults to "exponential"):
            The type of dynamic resolution-dependent timestep shifting to apply. Either "exponential" or "linear".
        stochastic_sampling (`bool`, defaults to False):
            Whether to use stochastic sampling.
    """

    _compatibles = []
@@ -101,6 +103,7 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
        use_exponential_sigmas: Optional[bool] = False,
        use_beta_sigmas: Optional[bool] = False,
        time_shift_type: str = "exponential",
        stochastic_sampling: bool = False,
    ):
        if self.config.use_beta_sigmas and not is_scipy_available():
            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -437,13 +440,25 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
            lower_mask = sigmas < per_token_sigmas[None] - 1e-6
            lower_sigmas = lower_mask * sigmas
            lower_sigmas, _ = lower_sigmas.max(dim=0)
-           dt = (per_token_sigmas - lower_sigmas)[..., None]
+
+           current_sigma = per_token_sigmas[..., None]
+           next_sigma = lower_sigmas[..., None]
+           dt = current_sigma - next_sigma
        else:
-           sigma = self.sigmas[self.step_index]
-           sigma_next = self.sigmas[self.step_index + 1]
+           sigma_idx = self.step_index
+           sigma = self.sigmas[sigma_idx]
+           sigma_next = self.sigmas[sigma_idx + 1]
+
+           current_sigma = sigma
+           next_sigma = sigma_next
            dt = sigma_next - sigma

-       prev_sample = sample + dt * model_output
+       if self.config.stochastic_sampling:
+           x0 = sample - current_sigma * model_output
+           noise = torch.randn_like(sample)
+           prev_sample = (1.0 - next_sigma) * x0 + next_sigma * noise
+       else:
+           prev_sample = sample + dt * model_output

        # upon completion increase step index by one
        self._step_index += 1
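In the deterministic branch the Euler step follows the predicted velocity for `dt`; the new stochastic branch instead reconstructs the clean-sample estimate and re-noises it to the next sigma with fresh Gaussian noise. A minimal numerical sketch of the two updates (the values are made up):

```python
import torch

sample = torch.tensor([1.0])
model_output = torch.tensor([0.5])       # flow-matching velocity prediction
current_sigma, next_sigma = 0.8, 0.6
dt = next_sigma - current_sigma

# deterministic Euler step
prev_deterministic = sample + dt * model_output

# stochastic step: estimate x0, then blend it with fresh noise at next_sigma
x0 = sample - current_sigma * model_output
noise = torch.randn_like(sample)
prev_stochastic = (1.0 - next_sigma) * x0 + next_sigma * noise
```

With this change, the behaviour is opt-in through the new config flag added above, e.g. `FlowMatchEulerDiscreteScheduler(stochastic_sampling=True)`.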
@@ -160,3 +160,90 @@ class WanImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    @unittest.skip("TODO: revisit failing as it requires a very high threshold to pass")
    def test_inference_batch_single_identical(self):
        pass


class WanFLFToVideoPipelineFastTests(WanImageToVideoPipelineFastTests):
    def get_dummy_components(self):
        torch.manual_seed(0)
        vae = AutoencoderKLWan(
            base_dim=3,
            z_dim=16,
            dim_mult=[1, 1, 1, 1],
            num_res_blocks=1,
            temperal_downsample=[False, True, True],
        )

        torch.manual_seed(0)
        # TODO: impl FlowDPMSolverMultistepScheduler
        scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
        text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")

        torch.manual_seed(0)
        transformer = WanTransformer3DModel(
            patch_size=(1, 2, 2),
            num_attention_heads=2,
            attention_head_dim=12,
            in_channels=36,
            out_channels=16,
            text_dim=32,
            freq_dim=256,
            ffn_dim=32,
            num_layers=2,
            cross_attn_norm=True,
            qk_norm="rms_norm_across_heads",
            rope_max_seq_len=32,
            image_dim=4,
            pos_embed_seq_len=2 * (4 * 4 + 1),
        )

        torch.manual_seed(0)
        image_encoder_config = CLIPVisionConfig(
            hidden_size=4,
            projection_dim=4,
            num_hidden_layers=2,
            num_attention_heads=2,
            image_size=4,
            intermediate_size=16,
            patch_size=1,
        )
        image_encoder = CLIPVisionModelWithProjection(image_encoder_config)

        torch.manual_seed(0)
        image_processor = CLIPImageProcessor(crop_size=4, size=4)

        components = {
            "transformer": transformer,
            "vae": vae,
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "image_encoder": image_encoder,
            "image_processor": image_processor,
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
        image_height = 16
        image_width = 16
        image = Image.new("RGB", (image_width, image_height))
        last_image = Image.new("RGB", (image_width, image_height))
        inputs = {
            "image": image,
            "last_image": last_image,
            "prompt": "dance monkey",
            "negative_prompt": "negative",
            "height": image_height,
            "width": image_width,
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 6.0,
            "num_frames": 9,
            "max_sequence_length": 16,
            "output_type": "pt",
        }
        return inputs