Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-02 09:51:06 +08:00.

Compare commits (1 commit): vid-pipe-o... ... fix_num_pr...

| Author | SHA1 | Date |
|---|---|---|
| | a99b0eb517 | |
@@ -33,9 +33,6 @@ model = AutoencoderKL.from_single_file(url)

## AutoencoderKL

[[autodoc]] AutoencoderKL
- decode
- encode
- all

## AutoencoderKLOutput
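The hunk above touches the `AutoencoderKL` API docs, whose autodoc entries cover `encode` and `decode`. As a rough illustration of those two methods (not part of the diff), a minimal sketch using a tiny randomly initialized VAE; a real use would load weights, for example with `from_pretrained` or the `from_single_file(url)` call shown in the hunk context:

```python
import torch
from diffusers import AutoencoderKL

# Tiny config purely for illustration; not a released checkpoint.
vae = AutoencoderKL(
    block_out_channels=(32, 64),
    down_block_types=("DownEncoderBlock2D", "DownEncoderBlock2D"),
    up_block_types=("UpDecoderBlock2D", "UpDecoderBlock2D"),
    latent_channels=4,
)

image = torch.randn(1, 3, 64, 64)                 # dummy image batch
latents = vae.encode(image).latent_dist.sample()  # encode -> posterior -> sampled latents, (1, 4, 32, 32)
reconstruction = vae.decode(latents).sample       # decode maps latents back to image space
```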
@@ -1279,7 +1279,7 @@ def main(args):

    for name, param in text_encoder_one.named_parameters():
        if "token_embedding" in name:
            # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
            param.data = param.to(dtype=torch.float32)
            param = param.to(dtype=torch.float32)
            param.requires_grad = True
            text_lora_parameters_one.append(param)
        else:

@@ -1288,7 +1288,7 @@ def main(args):

    for name, param in text_encoder_two.named_parameters():
        if "token_embedding" in name:
            # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
            param.data = param.to(dtype=torch.float32)
            param = param.to(dtype=torch.float32)
            param.requires_grad = True
            text_lora_parameters_two.append(param)
        else:
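For context on the two variants shown in these hunks (not part of the diff): assigning to `param.data` upcasts the tensor that the module and optimizer actually hold, whereas rebinding the local name `param` only creates a detached fp32 copy and leaves the registered parameter in fp16. A minimal sketch of the difference, using a toy embedding:

```python
import torch
from torch import nn

emb = nn.Embedding(4, 8).to(dtype=torch.float16)
param = dict(emb.named_parameters())["weight"]

# Rebinding the name leaves the module's parameter in fp16:
param = param.to(dtype=torch.float32)
print(emb.weight.dtype)  # torch.float16

# Assigning to .data upcasts the parameter stored in the module itself:
param = dict(emb.named_parameters())["weight"]
param.data = param.to(dtype=torch.float32)
print(emb.weight.dtype)  # torch.float32
```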
@@ -1725,19 +1725,19 @@ def main(args):

        num_train_epochs_text_encoder = int(args.train_text_encoder_frac * args.num_train_epochs)
    elif args.train_text_encoder_ti:  # args.train_text_encoder_ti
        num_train_epochs_text_encoder = int(args.train_text_encoder_ti_frac * args.num_train_epochs)
    # flag used for textual inversion
    pivoted = False

    for epoch in range(first_epoch, args.num_train_epochs):
        # if performing any kind of optimization of text_encoder params
        if args.train_text_encoder or args.train_text_encoder_ti:
            if epoch == num_train_epochs_text_encoder:
                print("PIVOT HALFWAY", epoch)
                # stopping optimization of text_encoder params
                # this flag is used to reset the optimizer to optimize only on unet params
                pivoted = True
                # re setting the optimizer to optimize only on unet params
                optimizer.param_groups[1]["lr"] = 0.0
                optimizer.param_groups[2]["lr"] = 0.0

            else:
                # still optimizing the text encoder
                # still optimizng the text encoder
                text_encoder_one.train()
                text_encoder_two.train()
                # set top parameter requires_grad = True for gradient checkpointing works
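The "pivot" logic above stops text-encoder training after a fraction of the epochs by zeroing the learning rate of the corresponding optimizer param groups. A standalone sketch of that mechanism (illustrative only; it assumes, as the script does, that group 0 holds UNet params and groups 1 and 2 hold the two text encoders):

```python
import torch

unet_params = [torch.nn.Parameter(torch.randn(4))]
te1_params = [torch.nn.Parameter(torch.randn(4))]
te2_params = [torch.nn.Parameter(torch.randn(4))]

optimizer = torch.optim.AdamW(
    [
        {"params": unet_params, "lr": 1e-4},
        {"params": te1_params, "lr": 5e-5},
        {"params": te2_params, "lr": 5e-5},
    ]
)

num_train_epochs = 10
train_text_encoder_frac = 0.5
num_train_epochs_text_encoder = int(train_text_encoder_frac * num_train_epochs)

for epoch in range(num_train_epochs):
    if epoch == num_train_epochs_text_encoder:
        # "pivot": from here on, only the UNet group keeps a non-zero learning rate
        optimizer.param_groups[1]["lr"] = 0.0
        optimizer.param_groups[2]["lr"] = 0.0
```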
@@ -1747,12 +1747,6 @@ def main(args):

        unet.train()
        for step, batch in enumerate(train_dataloader):
            if pivoted:
                # stopping optimization of text_encoder params
                # re setting the optimizer to optimize only on unet params
                optimizer.param_groups[1]["lr"] = 0.0
                optimizer.param_groups[2]["lr"] = 0.0

            with accelerator.accumulate(unet):
                prompts = batch["prompts"]
                # encode batch prompts when custom prompts are provided for each image -
@@ -1891,7 +1885,8 @@ def main(args):

                # every step, we reset the embeddings to the original embeddings.
                if args.train_text_encoder_ti:
                    embedding_handler.retract_embeddings()
                    for idx, text_encoder in enumerate(text_encoders):
                        embedding_handler.retract_embeddings()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
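The `retract_embeddings` call above restores the original token embeddings after every step so that only the newly inserted tokens actually drift during textual-inversion training. A rough sketch of that idea (a hypothetical helper written for illustration, not the script's `embedding_handler`):

```python
import torch


def retract_embeddings(embedding: torch.nn.Embedding, original_weight: torch.Tensor, train_ids: list[int]) -> None:
    # Restore every embedding row except the newly added (trained) token ids.
    index_no_updates = torch.ones(embedding.num_embeddings, dtype=torch.bool)
    index_no_updates[train_ids] = False
    with torch.no_grad():
        embedding.weight.data[index_no_updates] = original_weight[index_no_updates].to(
            dtype=embedding.weight.dtype
        )
```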
@@ -67,7 +67,10 @@ EXAMPLE_DOC_STRING = """

"""


def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
def tensor2vid(video: torch.Tensor, processor, output_type="np"):
    # Based on:
    # https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78

    batch_size, channels, num_frames, height, width = video.shape
    outputs = []
    for batch_idx in range(batch_size):

@@ -76,15 +79,6 @@ def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type:

        outputs.append(batch_output)

    if output_type == "np":
        outputs = np.stack(outputs)

    elif output_type == "pt":
        outputs = torch.stack(outputs)

    elif not output_type == "pil":
        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")

    return outputs
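For intuition about the reshaping in `tensor2vid` above (a standalone sketch that mimics only the shape handling, not the exact `VaeImageProcessor.postprocess` behavior): the input is a `(batch, channels, frames, height, width)` tensor, each batch element is permuted to `(frames, channels, height, width)`, and the `"np"` branch stacks the per-video results:

```python
import numpy as np
import torch

video = torch.rand(2, 3, 8, 32, 32)  # (batch, channels, frames, height, width)

outputs = []
for batch_idx in range(video.shape[0]):
    batch_vid = video[batch_idx].permute(1, 0, 2, 3)        # -> (frames, channels, height, width)
    batch_np = batch_vid.permute(0, 2, 3, 1).cpu().numpy()  # stand-in for processor.postprocess(..., "np")
    outputs.append(batch_np)

frames = np.stack(outputs)  # "np" output: (batch, frames, height, width, channels)
print(frames.shape)         # (2, 8, 32, 32, 3)
```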
@@ -811,7 +805,11 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap

            return AnimateDiffPipelineOutput(frames=latents)

        video_tensor = self.decode_latents(latents)
        video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)

        if output_type == "pt":
            video = video_tensor
        else:
            video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)

        if not return_dict:
            return (video,)
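For orientation (not part of the diff), a hedged sketch of calling the AnimateDiff pipeline whose return block is shown above. The checkpoint ids are examples taken from the AnimateDiff docs and are assumptions here, and which `output_type` values are accepted depends on the pipeline version:

```python
from diffusers import AnimateDiffPipeline, MotionAdapter

# Illustrative checkpoint ids: any AnimateDiff-compatible base model plus motion adapter.
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
pipe = AnimateDiffPipeline.from_pretrained("emilianJR/epiCRealism", motion_adapter=adapter)

# "latent" returns latents, "pt" the decoded tensor, "np"/"pil" the frames from tensor2vid.
output = pipe("a rocket launch", num_frames=8, num_inference_steps=2, output_type="np")
frames = output.frames
```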
@@ -40,8 +40,10 @@ def _append_dims(x, target_dims):

    return x[(...,) + (None,) * dims_to_append]


# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
def tensor2vid(video: torch.Tensor, processor, output_type="np"):
    # Based on:
    # https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78

    batch_size, channels, num_frames, height, width = video.shape
    outputs = []
    for batch_idx in range(batch_size):

@@ -51,13 +53,7 @@ def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type:

        outputs.append(batch_output)

    if output_type == "np":
        outputs = np.stack(outputs)

    elif output_type == "pt":
        outputs = torch.stack(outputs)

    elif not output_type == "pil":
        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
    return np.stack(outputs)

    return outputs
@@ -19,7 +19,6 @@ import numpy as np

import torch
from transformers import CLIPTextModel, CLIPTokenizer

from ...image_processor import VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet3DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
@@ -59,26 +58,22 @@ EXAMPLE_DOC_STRING = """

"""


# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
    batch_size, channels, num_frames, height, width = video.shape
    outputs = []
    for batch_idx in range(batch_size):
        batch_vid = video[batch_idx].permute(1, 0, 2, 3)
        batch_output = processor.postprocess(batch_vid, output_type)

        outputs.append(batch_output)

    if output_type == "np":
        outputs = np.stack(outputs)

    elif output_type == "pt":
        outputs = torch.stack(outputs)

    elif not output_type == "pil":
        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")

    return outputs
def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]:
    # This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
    # reshape to ncfhw
    mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1)
    std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1)
    # unnormalize back to [0,1]
    video = video.mul_(std).add_(mean)
    video.clamp_(0, 1)
    # prepare the final outputs
    i, c, f, h, w = video.shape
    images = video.permute(2, 3, 0, 4, 1).reshape(
        f, h, i * w, c
    )  # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c)
    images = images.unbind(dim=0)  # prepare a list of indvidual (consecutive frames)
    images = [(image.cpu().numpy() * 255).astype("uint8") for image in images]  # f h w c
    return images


class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
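The older ModelScope-style `tensor2vid` shown above maps model output from roughly `[-1, 1]` back to `[0, 1]` with `video.mul_(std).add_(mean)` where `mean = std = 0.5`, i.e. `x * 0.5 + 0.5`, and then scales frames to `uint8` in `[0, 255]`. A tiny numeric check of that arithmetic (illustrative only):

```python
import torch

x = torch.tensor([-1.0, 0.0, 1.0])
mean, std = 0.5, 0.5
unnormalized = x.mul(std).add(mean).clamp(0, 1)
print(unnormalized)                          # tensor([0.0000, 0.5000, 1.0000])
print((unnormalized * 255).to(torch.uint8))  # tensor([  0, 127, 255], dtype=torch.uint8)
```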
@@ -127,7 +122,6 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lora

            scheduler=scheduler,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
    def enable_vae_slicing(self):
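As a worked example of the `vae_scale_factor` formula above (the `block_out_channels` values are the standard SD VAE config, assumed here purely for illustration): four blocks give a spatial downsampling factor of 2 to the power of 3, which is what the `VaeImageProcessor` then uses for size rounding.

```python
block_out_channels = [128, 256, 512, 512]  # standard SD VAE config, for illustration
vae_scale_factor = 2 ** (len(block_out_channels) - 1)
print(vae_scale_factor)  # 8
```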
@@ -723,7 +717,11 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lora

            return TextToVideoSDPipelineOutput(frames=latents)

        video_tensor = self.decode_latents(latents)
        video = tensor2vid(video_tensor, self.image_processor, output_type)

        if output_type == "pt":
            video = video_tensor
        else:
            video = tensor2vid(video_tensor)

        # Offload all models
        self.maybe_free_model_hooks()
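The return block above chooses between the raw decoded tensor (`"pt"`) and the post-processed frames from `tensor2vid`. A hedged usage sketch of the pipeline it belongs to; the model id, prompt, and generator mirror the slow test later in this diff, everything else is illustrative and which `output_type` values are accepted depends on the pipeline version:

```python
import torch
from diffusers import TextToVideoSDPipeline

pipe = TextToVideoSDPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

generator = torch.Generator(device="cpu").manual_seed(0)
video_frames = pipe(
    "Spiderman is surfing", generator=generator, num_inference_steps=2, output_type="np"
).frames
```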
@@ -20,7 +20,6 @@ import PIL.Image

import torch
from transformers import CLIPTextModel, CLIPTokenizer

from ...image_processor import VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet3DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
@@ -94,26 +93,22 @@ def retrieve_latents(

        raise AttributeError("Could not access latents of provided encoder_output")


# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
    batch_size, channels, num_frames, height, width = video.shape
    outputs = []
    for batch_idx in range(batch_size):
        batch_vid = video[batch_idx].permute(1, 0, 2, 3)
        batch_output = processor.postprocess(batch_vid, output_type)

        outputs.append(batch_output)

    if output_type == "np":
        outputs = np.stack(outputs)

    elif output_type == "pt":
        outputs = torch.stack(outputs)

    elif not output_type == "pil":
        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")

    return outputs
def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]:
    # This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
    # reshape to ncfhw
    mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1)
    std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1)
    # unnormalize back to [0,1]
    video = video.mul_(std).add_(mean)
    video.clamp_(0, 1)
    # prepare the final outputs
    i, c, f, h, w = video.shape
    images = video.permute(2, 3, 0, 4, 1).reshape(
        f, h, i * w, c
    )  # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c)
    images = images.unbind(dim=0)  # prepare a list of indvidual (consecutive frames)
    images = [(image.cpu().numpy() * 255).astype("uint8") for image in images]  # f h w c
    return images


def preprocess_video(video):
@@ -203,7 +198,6 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor

            scheduler=scheduler,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
    def enable_vae_slicing(self):
@@ -818,11 +812,12 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor

        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.unet.to("cpu")

        if output_type == "latent":
            return TextToVideoSDPipelineOutput(frames=latents)

        video_tensor = self.decode_latents(latents)
        video = tensor2vid(video_tensor, self.image_processor, output_type)

        if output_type == "pt":
            video = video_tensor
        else:
            video = tensor2vid(video_tensor)

        # Offload all models
        self.maybe_free_model_hooks()
@@ -262,7 +262,7 @@ class AnimateDiffPipelineFastTests(PipelineTesterMixin, unittest.TestCase):

        sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_init)).sum()
        max_diff_disabled = np.abs(to_np(frames_normal) - to_np(frames_disable_free_init)).max()
        self.assertGreater(
            sum_enabled, 1e1, "Enabling of FreeInit should lead to results different from the default pipeline results"
            sum_enabled, 1e2, "Enabling of FreeInit should lead to results different from the default pipeline results"
        )
        self.assertLess(
            max_diff_disabled,
@@ -29,7 +29,6 @@ from diffusers.utils import is_xformers_available

from diffusers.utils.testing_utils import (
    enable_full_determinism,
    load_numpy,
    numpy_cosine_similarity_distance,
    require_torch_gpu,
    skip_mps,
    slow,
@@ -142,11 +141,10 @@ class TextToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase):

        inputs = self.get_dummy_inputs(device)
        inputs["output_type"] = "np"
        frames = sd_pipe(**inputs).frames
        image_slice = frames[0][-3:, -3:, -1]

        image_slice = frames[0][0][-3:, -3:, -1]

        assert frames[0][0].shape == (32, 32, 3)
        expected_slice = np.array([0.7537, 0.1752, 0.6157, 0.5508, 0.4240, 0.4110, 0.4838, 0.5648, 0.5094])
        assert frames[0].shape == (32, 32, 3)
        expected_slice = np.array([192.0, 44.0, 157.0, 140.0, 108.0, 104.0, 123.0, 144.0, 129.0])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
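The assertions above come from two variants of the fast test: one checks a per-frame array indexed as `frames[0][0]` against expected values in `[0, 1]`, the other a layout indexed as `frames[0]` against values on a 0 to 255 scale. A hedged sketch of the generic slice-comparison pattern these tests use (dummy data, not real pipeline output):

```python
import numpy as np

# Dummy stand-in for pipeline output: 1 video, 2 frames of 32x32 RGB in [0, 1].
frames = np.random.RandomState(0).rand(1, 2, 32, 32, 3).astype(np.float32)

# Corner slice of the first frame, last channel: the usual "slice" checked in fast tests.
image_slice = frames[0][0][-3:, -3:, -1]
assert frames[0][0].shape == (32, 32, 3)

expected_slice = image_slice.flatten()  # in a real test this is a hard-coded reference array
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
```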
@@ -185,7 +183,7 @@ class TextToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase):

class TextToVideoSDPipelineSlowTests(unittest.TestCase):
    def test_two_step_model(self):
        expected_video = load_numpy(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text-to-video/video_2step.npy"
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text_to_video/video_2step.npy"
        )

        pipe = TextToVideoSDPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b")
@@ -194,8 +192,10 @@ class TextToVideoSDPipelineSlowTests(unittest.TestCase):

        prompt = "Spiderman is surfing"
        generator = torch.Generator(device="cpu").manual_seed(0)

        video_frames = pipe(prompt, generator=generator, num_inference_steps=2, output_type="np").frames
        assert numpy_cosine_similarity_distance(expected_video.flatten(), video_frames.flatten()) < 1e-4
        video_frames = pipe(prompt, generator=generator, num_inference_steps=2, output_type="pt").frames
        video = video_frames.cpu().numpy()

        assert np.abs(expected_video - video).mean() < 5e-2

    def test_two_step_model_with_freeu(self):
        expected_video = []
@@ -207,9 +207,10 @@ class TextToVideoSDPipelineSlowTests(unittest.TestCase):

        generator = torch.Generator(device="cpu").manual_seed(0)

        pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
        video_frames = pipe(prompt, generator=generator, num_inference_steps=2, output_type="np").frames
        video = video_frames[0, 0, -3:, -3:, -1].flatten()
        video_frames = pipe(prompt, generator=generator, num_inference_steps=2, output_type="pt").frames
        video = video_frames.cpu().numpy()
        video = video[0, 0, -3:, -3:, -1].flatten()

        expected_video = [0.3643, 0.3455, 0.3831, 0.3923, 0.2978, 0.3247, 0.3278, 0.3201, 0.3475]
        expected_video = [-0.3102, -0.2477, -0.1772, -0.648, -0.6176, -0.5484, -0.0217, -0.056, -0.0177]

        assert np.abs(expected_video - video).mean() < 5e-2
@@ -157,10 +157,10 @@ class VideoToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase):

        inputs = self.get_dummy_inputs(device)
        inputs["output_type"] = "np"
        frames = sd_pipe(**inputs).frames
        image_slice = frames[0][0][-3:, -3:, -1]
        image_slice = frames[0][-3:, -3:, -1]

        assert frames[0][0].shape == (32, 32, 3)
        expected_slice = np.array([0.6391, 0.5350, 0.5202, 0.5521, 0.5453, 0.5393, 0.6652, 0.5270, 0.5185])
        assert frames[0].shape == (32, 32, 3)
        expected_slice = np.array([162.0, 136.0, 132.0, 140.0, 139.0, 137.0, 169.0, 134.0, 132.0])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@@ -214,11 +214,9 @@ class VideoToVideoSDPipelineSlowTests(unittest.TestCase):

        prompt = "Spiderman is surfing"

        generator = torch.Generator(device="cpu").manual_seed(0)
        video_frames = pipe(prompt, video=video, generator=generator, num_inference_steps=3, output_type="np").frames
        video_frames = pipe(prompt, video=video, generator=generator, num_inference_steps=3, output_type="pt").frames

        expected_array = np.array(
            [0.17114258, 0.13720703, 0.08886719, 0.14819336, 0.1730957, 0.24584961, 0.22021484, 0.35180664, 0.2607422]
        )
        output_array = video_frames[0, 0, :3, :3, 0].flatten()
        assert numpy_cosine_similarity_distance(expected_array, output_array) < 1e-3
        expected_array = np.array([-0.9770508, -0.8027344, -0.62646484, -0.8334961, -0.7573242])
        output_array = video_frames.cpu().numpy()[0, 0, 0, 0, -5:]

        assert numpy_cosine_similarity_distance(expected_array, output_array) < 1e-2
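The slow tests above compare outputs either by mean absolute difference or by `numpy_cosine_similarity_distance` from `diffusers.utils.testing_utils`. The latter is, to a close approximation, one minus the cosine similarity of the flattened arrays; a rough reimplementation for illustration (not the library's exact code):

```python
import numpy as np


def cosine_similarity_distance(a: np.ndarray, b: np.ndarray) -> float:
    # 1 - cosine similarity of the flattened arrays; 0 means identical direction.
    a = a.flatten().astype(np.float64)
    b = b.flatten().astype(np.float64)
    cos = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    return float(1.0 - cos)


x = np.array([0.1, 0.2, 0.3])
print(cosine_similarity_distance(x, x))        # ~0.0
print(cosine_similarity_distance(x, x * 2.0))  # still ~0.0: the measure is scale-invariant
```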