mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-07 13:04:15 +08:00
Compare commits
6 Commits
tests-memo
...
str-to-boo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
efbbbc38e4 | ||
|
|
9a34953823 | ||
|
|
e29f16cfaa | ||
|
|
f7dfcfd971 | ||
|
|
3c67864c5a | ||
|
|
363699044e |
121
examples/research_projects/diffusion_orpo/README.md
Normal file
121
examples/research_projects/diffusion_orpo/README.md
Normal file
@@ -0,0 +1,121 @@
|
||||
This project is an attempt to check if it's possible to apply [ORPO](https://arxiv.org/abs/2403.07691) to a text-conditioned diffusion model to align it on preference data WITHOUT a reference model. The implementation is based on https://github.com/huggingface/trl/pull/1435/.
|
||||
|
||||
> [!WARNING]
|
||||
> We assume that MSE in the diffusion formulation approximates the log-probs as required by ORPO (hat-tip to [@kashif](https://github.com/kashif) for the idea). So, please consider this to be extremely experimental.
|
||||
|
||||
## Training
|
||||
|
||||
Here's a training command you can use on a 40GB A100 to validate things on a [small preference
|
||||
dataset](https://hf.co/datasets/kashif/pickascore):
|
||||
|
||||
```bash
|
||||
accelerate launch train_diffusion_orpo_sdxl_lora.py \
|
||||
--pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 \
|
||||
--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
|
||||
--output_dir="diffusion-sdxl-orpo" \
|
||||
--mixed_precision="fp16" \
|
||||
--dataset_name=kashif/pickascore \
|
||||
--train_batch_size=8 \
|
||||
--gradient_accumulation_steps=2 \
|
||||
--gradient_checkpointing \
|
||||
--use_8bit_adam \
|
||||
--rank=8 \
|
||||
--learning_rate=1e-5 \
|
||||
--report_to="wandb" \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--max_train_steps=2000 \
|
||||
--checkpointing_steps=500 \
|
||||
--run_validation --validation_steps=50 \
|
||||
--seed="0" \
|
||||
--report_to="wandb" \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
We also provide a simple script to scale up the training on the [yuvalkirstain/pickapic_v2](https://huggingface.co/datasets/yuvalkirstain/pickapic_v2) dataset:
|
||||
|
||||
```bash
|
||||
accelerate launch --multi_gpu train_diffusion_orpo_sdxl_lora_wds.py \
|
||||
--pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 \
|
||||
--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
|
||||
--dataset_path="pipe:aws s3 cp s3://diffusion-preference-opt/{00000..00644}.tar -" \
|
||||
--output_dir="diffusion-sdxl-orpo-wds" \
|
||||
--mixed_precision="fp16" \
|
||||
--gradient_accumulation_steps=1 \
|
||||
--gradient_checkpointing \
|
||||
--use_8bit_adam \
|
||||
--rank=8 \
|
||||
--dataloader_num_workers=8 \
|
||||
--learning_rate=3e-5 \
|
||||
--report_to="wandb" \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--max_train_steps=50000 \
|
||||
--checkpointing_steps=2000 \
|
||||
--run_validation --validation_steps=500 \
|
||||
--seed="0" \
|
||||
--report_to="wandb" \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
We tested the above on a node of 8 H100s but it should also work on A100s. It requires the `webdataset` library for faster dataloading. Note that we kept the dataset shards on an S3 bucket but it should also be possible to have them stored locally.
|
||||
|
||||
You can use the code below to convert the original dataset into `webdataset` shards:
|
||||
|
||||
```python
|
||||
import os
|
||||
import io
|
||||
import ray
|
||||
import webdataset as wds
|
||||
from datasets import Dataset
|
||||
from PIL import Image
|
||||
|
||||
ray.init(num_cpus=8)
|
||||
|
||||
|
||||
def convert_to_image(im_bytes):
    """Decode encoded image bytes and return the result as an RGB PIL image."""
    buffer = io.BytesIO(im_bytes)
    decoded = Image.open(buffer)
    return decoded.convert("RGB")
|
||||
|
||||
def main():
    """Convert every ``.parquet`` file in ``dataset_path`` into a ``webdataset`` tar shard.

    One Ray task is launched per parquet file. Each task writes
    ``<shard_num>.tar`` into ``wds_shards_path``, where ``shard_num`` is the
    second ``-``-separated field of the parquet file name.
    """
    dataset_path = "/pickapic_v2/data"
    wds_shards_path = "/pickapic_v2_webdataset"

    # Make sure the output directory exists before any task tries to write into it.
    os.makedirs(wds_shards_path, exist_ok=True)

    # get all .parquet files in the dataset path
    dataset_files = [
        os.path.join(dataset_path, f)
        for f in os.listdir(dataset_path)
        if f.endswith(".parquet")
    ]

    @ray.remote
    def create_shard(path):
        # get basename of the file
        basename = os.path.basename(path)
        # get the shard number data-00123-of-01034.parquet -> 00123
        shard_num = basename.split("-")[1]
        dataset = Dataset.from_parquet(path)
        shard_path = os.path.join(wds_shards_path, f"{shard_num}.tar")

        # The context manager closes the tar file even when writing a row fails,
        # so a crashing task does not leave an open file handle behind.
        with wds.TarWriter(shard_path) as shard:
            for i, example in enumerate(dataset):
                wds_example = {
                    "__key__": str(i),
                    "original_prompt.txt": example["caption"],
                    "jpg_0.jpg": convert_to_image(example["jpg_0"]),
                    "jpg_1.jpg": convert_to_image(example["jpg_1"]),
                    "label_0.txt": str(example["label_0"]),
                    "label_1.txt": str(example["label_1"]),
                }
                shard.write(wds_example)

    # Block until every shard has been written.
    futures = [create_shard.remote(path) for path in dataset_files]
    ray.get(futures)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
```
|
||||
|
||||
## Inference
|
||||
|
||||
Refer to [sayakpaul/diffusion-sdxl-orpo](https://huggingface.co/sayakpaul/diffusion-sdxl-orpo) for an experimental checkpoint.
|
||||
@@ -0,0 +1,7 @@
|
||||
datasets
|
||||
accelerate
|
||||
transformers
|
||||
torchvision
|
||||
wandb
|
||||
peft
|
||||
webdataset
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
5
setup.py
5
setup.py
@@ -81,9 +81,8 @@ To create the package for PyPI.
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from distutils.core import Command
|
||||
|
||||
from setuptools import find_packages, setup
|
||||
from setuptools import Command, find_packages, setup
|
||||
|
||||
|
||||
# IMPORTANT:
|
||||
@@ -163,7 +162,7 @@ def deps_list(*pkgs):
|
||||
|
||||
class DepsTableUpdateCommand(Command):
|
||||
"""
|
||||
A custom distutils command that updates the dependency table.
|
||||
A custom command that updates the dependency table.
|
||||
usage: python setup.py deps_table_update
|
||||
"""
|
||||
|
||||
|
||||
@@ -792,7 +792,7 @@ class AnimateDiffPipeline(
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
|
||||
# 8. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
with self.progress_bar(total=self._num_timesteps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
|
||||
@@ -944,7 +944,7 @@ class AnimateDiffVideoToVideoPipeline(
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
|
||||
# 8. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
with self.progress_bar(total=self._num_timesteps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
|
||||
@@ -146,39 +146,40 @@ class FreeInitMixin:
|
||||
):
|
||||
if free_init_iteration == 0:
|
||||
self._free_init_initial_noise = latents.detach().clone()
|
||||
return latents, self.scheduler.timesteps
|
||||
else:
|
||||
latent_shape = latents.shape
|
||||
|
||||
latent_shape = latents.shape
|
||||
free_init_filter_shape = (1, *latent_shape[1:])
|
||||
free_init_freq_filter = self._get_free_init_freq_filter(
|
||||
shape=free_init_filter_shape,
|
||||
device=device,
|
||||
filter_type=self._free_init_method,
|
||||
order=self._free_init_order,
|
||||
spatial_stop_frequency=self._free_init_spatial_stop_frequency,
|
||||
temporal_stop_frequency=self._free_init_temporal_stop_frequency,
|
||||
)
|
||||
|
||||
free_init_filter_shape = (1, *latent_shape[1:])
|
||||
free_init_freq_filter = self._get_free_init_freq_filter(
|
||||
shape=free_init_filter_shape,
|
||||
device=device,
|
||||
filter_type=self._free_init_method,
|
||||
order=self._free_init_order,
|
||||
spatial_stop_frequency=self._free_init_spatial_stop_frequency,
|
||||
temporal_stop_frequency=self._free_init_temporal_stop_frequency,
|
||||
)
|
||||
current_diffuse_timestep = self.scheduler.config.num_train_timesteps - 1
|
||||
diffuse_timesteps = torch.full((latent_shape[0],), current_diffuse_timestep).long()
|
||||
|
||||
current_diffuse_timestep = self.scheduler.config.num_train_timesteps - 1
|
||||
diffuse_timesteps = torch.full((latent_shape[0],), current_diffuse_timestep).long()
|
||||
z_t = self.scheduler.add_noise(
|
||||
original_samples=latents, noise=self._free_init_initial_noise, timesteps=diffuse_timesteps.to(device)
|
||||
).to(dtype=torch.float32)
|
||||
|
||||
z_t = self.scheduler.add_noise(
|
||||
original_samples=latents, noise=self._free_init_initial_noise, timesteps=diffuse_timesteps.to(device)
|
||||
).to(dtype=torch.float32)
|
||||
|
||||
z_rand = randn_tensor(
|
||||
shape=latent_shape,
|
||||
generator=generator,
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
latents = self._apply_freq_filter(z_t, z_rand, low_pass_filter=free_init_freq_filter)
|
||||
latents = latents.to(dtype)
|
||||
z_rand = randn_tensor(
|
||||
shape=latent_shape,
|
||||
generator=generator,
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
latents = self._apply_freq_filter(z_t, z_rand, low_pass_filter=free_init_freq_filter)
|
||||
latents = latents.to(dtype)
|
||||
|
||||
# Coarse-to-Fine Sampling for faster inference (can lead to lower quality)
|
||||
if self._free_init_use_fast_sampling:
|
||||
num_inference_steps = int(num_inference_steps / self._free_init_num_iters * (free_init_iteration + 1))
|
||||
num_inference_steps = max(
|
||||
1, int(num_inference_steps / self._free_init_num_iters * (free_init_iteration + 1))
|
||||
)
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
|
||||
return latents, self.scheduler.timesteps
|
||||
|
||||
@@ -13,14 +13,12 @@
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
import torch.fft as fft
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
@@ -130,71 +128,6 @@ def prepare_mask_coef_by_statistics(num_frames: int, cond_frame: int, motion_sca
|
||||
return coef
|
||||
|
||||
|
||||
def _get_freeinit_freq_filter(
|
||||
shape: Tuple[int, ...],
|
||||
device: Union[str, torch.dtype],
|
||||
filter_type: str,
|
||||
order: float,
|
||||
spatial_stop_frequency: float,
|
||||
temporal_stop_frequency: float,
|
||||
) -> torch.Tensor:
|
||||
r"""Returns the FreeInit filter based on filter type and other input conditions."""
|
||||
|
||||
time, height, width = shape[-3], shape[-2], shape[-1]
|
||||
mask = torch.zeros(shape)
|
||||
|
||||
if spatial_stop_frequency == 0 or temporal_stop_frequency == 0:
|
||||
return mask
|
||||
|
||||
if filter_type == "butterworth":
|
||||
|
||||
def retrieve_mask(x):
|
||||
return 1 / (1 + (x / spatial_stop_frequency**2) ** order)
|
||||
elif filter_type == "gaussian":
|
||||
|
||||
def retrieve_mask(x):
|
||||
return math.exp(-1 / (2 * spatial_stop_frequency**2) * x)
|
||||
elif filter_type == "ideal":
|
||||
|
||||
def retrieve_mask(x):
|
||||
return 1 if x <= spatial_stop_frequency * 2 else 0
|
||||
else:
|
||||
raise NotImplementedError("`filter_type` must be one of gaussian, butterworth or ideal")
|
||||
|
||||
for t in range(time):
|
||||
for h in range(height):
|
||||
for w in range(width):
|
||||
d_square = (
|
||||
((spatial_stop_frequency / temporal_stop_frequency) * (2 * t / time - 1)) ** 2
|
||||
+ (2 * h / height - 1) ** 2
|
||||
+ (2 * w / width - 1) ** 2
|
||||
)
|
||||
mask[..., t, h, w] = retrieve_mask(d_square)
|
||||
|
||||
return mask.to(device)
|
||||
|
||||
|
||||
def _freq_mix_3d(x: torch.Tensor, noise: torch.Tensor, LPF: torch.Tensor) -> torch.Tensor:
|
||||
r"""Noise reinitialization."""
|
||||
# FFT
|
||||
x_freq = fft.fftn(x, dim=(-3, -2, -1))
|
||||
x_freq = fft.fftshift(x_freq, dim=(-3, -2, -1))
|
||||
noise_freq = fft.fftn(noise, dim=(-3, -2, -1))
|
||||
noise_freq = fft.fftshift(noise_freq, dim=(-3, -2, -1))
|
||||
|
||||
# frequency mix
|
||||
HPF = 1 - LPF
|
||||
x_freq_low = x_freq * LPF
|
||||
noise_freq_high = noise_freq * HPF
|
||||
x_freq_mixed = x_freq_low + noise_freq_high # mix in freq domain
|
||||
|
||||
# IFFT
|
||||
x_freq_mixed = fft.ifftshift(x_freq_mixed, dim=(-3, -2, -1))
|
||||
x_mixed = fft.ifftn(x_freq_mixed, dim=(-3, -2, -1)).real
|
||||
|
||||
return x_mixed
|
||||
|
||||
|
||||
@dataclass
|
||||
class PIAPipelineOutput(BaseOutput):
|
||||
r"""
|
||||
@@ -202,9 +135,9 @@ class PIAPipelineOutput(BaseOutput):
|
||||
|
||||
Args:
|
||||
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
|
||||
Nested list of length `batch_size` with denoised PIL image sequences of length `num_frames`,
|
||||
NumPy array of shape `(batch_size, num_frames, channels, height, width,
|
||||
Torch tensor of shape `(batch_size, num_frames, channels, height, width)`.
|
||||
Nested list of length `batch_size` with denoised PIL image sequences of length `num_frames`,
|
||||
NumPy array of shape `(batch_size, num_frames, channels, height, width,
|
||||
Torch tensor of shape `(batch_size, num_frames, channels, height, width)`.
|
||||
"""
|
||||
|
||||
frames: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]]]
|
||||
@@ -788,7 +721,8 @@ class PIAPipeline(
|
||||
The input image to be used for video generation.
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
|
||||
strength (`float`, *optional*, defaults to 1.0): Indicates extent to transform the reference `image`. Must be between 0 and 1.
|
||||
strength (`float`, *optional*, defaults to 1.0):
|
||||
Indicates extent to transform the reference `image`. Must be between 0 and 1.
|
||||
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
||||
The height in pixels of the generated video.
|
||||
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
||||
@@ -979,8 +913,10 @@ class PIAPipeline(
|
||||
latents, free_init_iter, num_inference_steps, device, latents.dtype, generator
|
||||
)
|
||||
|
||||
self._num_timesteps = len(timesteps)
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
|
||||
with self.progress_bar(total=self._num_timesteps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
@@ -59,6 +59,66 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
|
||||
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4

    The standard deviation of each tensor is taken over every dimension except the first. `noise_cfg` is scaled by
    the ratio of the two, and the scaled and unscaled versions are blended with weights `guidance_rescale` and
    `1 - guidance_rescale`.
    """
    text_reduce_dims = list(range(1, noise_pred_text.ndim))
    cfg_reduce_dims = list(range(1, noise_cfg.ndim))
    std_text = noise_pred_text.std(dim=text_reduce_dims, keepdim=True)
    std_cfg = noise_cfg.std(dim=cfg_reduce_dims, keepdim=True)

    # rescale the results from guidance (fixes overexposure)
    rescaled = noise_cfg * (std_text / std_cfg)

    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    **kwargs,
):
    """
    Configures the scheduler through its `set_timesteps` method and reads the resulting schedule back from
    `scheduler.timesteps`. Supports custom timesteps. Extra kwargs are forwarded to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used,
            `timesteps` must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
            timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
            must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.

    Raises:
        ValueError: If `timesteps` is given but the scheduler's `set_timesteps` has no `timesteps` parameter.
    """
    # Default spacing: the requested step count is handed back unchanged.
    if timesteps is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom schedule: only valid when the scheduler exposes a `timesteps` parameter.
    supported_params = inspect.signature(scheduler.set_timesteps).parameters
    if "timesteps" not in supported_params:
        raise ValueError(
            f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
            f" timestep schedules. Please check whether you are using the correct scheduler."
        )

    scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    schedule = scheduler.timesteps
    # The step count is derived from the schedule the scheduler actually produced.
    return schedule, len(schedule)
|
||||
|
||||
|
||||
class StableDiffusionPanoramaPipeline(
|
||||
DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin
|
||||
):
|
||||
@@ -97,6 +157,7 @@ class StableDiffusionPanoramaPipeline(
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
_optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -461,10 +522,23 @@ class StableDiffusionPanoramaPipeline(
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
return image
|
||||
|
||||
def decode_latents_with_padding(self, latents, padding=8):
|
||||
# Add padding to latents for circular inference
|
||||
# padding is the number of latents to add on each side
|
||||
# it would slightly increase the memory usage, but remove the boundary artifacts
|
||||
def decode_latents_with_padding(self, latents: torch.Tensor, padding: int = 8) -> torch.Tensor:
|
||||
"""
|
||||
Decode the given latents with padding for circular inference.
|
||||
|
||||
Args:
|
||||
latents (torch.Tensor): The input latents to decode.
|
||||
padding (int, optional): The number of latents to add on each side for padding. Defaults to 8.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The decoded image with padding removed.
|
||||
|
||||
Notes:
|
||||
- The padding is added to remove boundary artifacts and improve the output quality.
|
||||
- This would slightly increase the memory usage.
|
||||
- The padding pixels are then removed from the decoded image.
|
||||
|
||||
"""
|
||||
latents = 1 / self.vae.config.scaling_factor * latents
|
||||
latents_left = latents[..., :padding]
|
||||
latents_right = latents[..., -padding:]
|
||||
@@ -580,9 +654,62 @@ class StableDiffusionPanoramaPipeline(
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def get_views(self, panorama_height, panorama_width, window_size=64, stride=8, circular_padding=False):
|
||||
# Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113)
|
||||
# if panorama's height/width < window_size, num_blocks of height/width should return 1
|
||||
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
|
||||
def get_guidance_scale_embedding(
    self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
) -> torch.FloatTensor:
    """
    See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

    Builds sine/cosine embedding vectors from guidance scale values.

    Args:
        w (`torch.Tensor`):
            One-dimensional tensor of guidance scale values to embed.
        embedding_dim (`int`, *optional*, defaults to 512):
            Dimension of the embeddings to generate.
        dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
            Data type of the generated embeddings.

    Returns:
        `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
    """
    assert len(w.shape) == 1
    scaled_w = w * 1000.0

    half_dim = embedding_dim // 2
    # Frequencies decay geometrically from 1 towards 1 / 10000 across `half_dim` entries.
    log_step = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    frequencies = torch.exp(torch.arange(half_dim, dtype=dtype) * -log_step)

    angles = scaled_w.to(dtype)[:, None] * frequencies[None, :]
    emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)

    # An odd `embedding_dim` leaves one column short; fill it with zeros.
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1))

    assert emb.shape == (w.shape[0], embedding_dim)
    return emb
|
||||
|
||||
def get_views(
|
||||
self,
|
||||
panorama_height: int,
|
||||
panorama_width: int,
|
||||
window_size: int = 64,
|
||||
stride: int = 8,
|
||||
circular_padding: bool = False,
|
||||
) -> List[Tuple[int, int, int, int]]:
|
||||
"""
|
||||
Generates a list of views based on the given parameters.
|
||||
Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113).
|
||||
If panorama's height/width < window_size, num_blocks of height/width should return 1.
|
||||
|
||||
Args:
|
||||
panorama_height (int): The height of the panorama.
|
||||
panorama_width (int): The width of the panorama.
|
||||
window_size (int, optional): The size of the window. Defaults to 64.
|
||||
stride (int, optional): The stride value. Defaults to 8.
|
||||
circular_padding (bool, optional): Whether to apply circular padding. Defaults to False.
|
||||
|
||||
Returns:
|
||||
List[Tuple[int, int, int, int]]: A list of tuples representing the views. Each tuple contains
|
||||
four integers representing the start and end coordinates of the window in the panorama.
|
||||
|
||||
"""
|
||||
panorama_height /= 8
|
||||
panorama_width /= 8
|
||||
num_blocks_height = (panorama_height - window_size) // stride + 1 if panorama_height > window_size else 1
|
||||
@@ -600,6 +727,34 @@ class StableDiffusionPanoramaPipeline(
|
||||
views.append((h_start, h_end, w_start, w_end))
|
||||
return views
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def guidance_rescale(self):
|
||||
return self._guidance_rescale
|
||||
|
||||
@property
|
||||
def cross_attention_kwargs(self):
|
||||
return self._cross_attention_kwargs
|
||||
|
||||
@property
|
||||
def clip_skip(self):
|
||||
return self._clip_skip
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
@@ -608,6 +763,7 @@ class StableDiffusionPanoramaPipeline(
|
||||
height: Optional[int] = 512,
|
||||
width: Optional[int] = 2048,
|
||||
num_inference_steps: int = 50,
|
||||
timesteps: List[int] = None,
|
||||
guidance_scale: float = 7.5,
|
||||
view_batch_size: int = 1,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
@@ -621,11 +777,13 @@ class StableDiffusionPanoramaPipeline(
|
||||
ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
guidance_rescale: float = 0.0,
|
||||
circular_padding: bool = False,
|
||||
clip_skip: Optional[int] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
**kwargs: Any,
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation.
|
||||
@@ -641,6 +799,9 @@ class StableDiffusionPanoramaPipeline(
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
The timesteps at which to generate the images. If not specified, then the default
|
||||
timestep spacing strategy of the scheduler is used.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
A higher guidance scale value encourages the model to generate images closely linked to the text
|
||||
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
||||
@@ -680,16 +841,12 @@ class StableDiffusionPanoramaPipeline(
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
||||
plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that calls every `callback_steps` steps during inference. The function is called with the
|
||||
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function is called. If not specified, the callback is called at
|
||||
every step.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
||||
A rescaling factor for the guidance embeddings. A value of 0.0 means no rescaling is applied.
|
||||
circular_padding (`bool`, *optional*, defaults to `False`):
|
||||
If set to `True`, circular padding is applied to ensure there are no stitching artifacts. Circular
|
||||
padding allows the model to seamlessly generate a transition from the rightmost part of the image to
|
||||
@@ -697,6 +854,15 @@ class StableDiffusionPanoramaPipeline(
|
||||
clip_skip (`int`, *optional*):
|
||||
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
||||
the output of the pre-final layer will be used for computing the prompt embeddings.
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
@@ -706,6 +872,22 @@ class StableDiffusionPanoramaPipeline(
|
||||
second element is a list of `bool`s indicating whether the corresponding generated image contains
|
||||
"not-safe-for-work" (nsfw) content.
|
||||
"""
|
||||
callback = kwargs.pop("callback", None)
|
||||
callback_steps = kwargs.pop("callback_steps", None)
|
||||
|
||||
if callback is not None:
|
||||
deprecate(
|
||||
"callback",
|
||||
"1.0.0",
|
||||
"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
|
||||
)
|
||||
if callback_steps is not None:
|
||||
deprecate(
|
||||
"callback_steps",
|
||||
"1.0.0",
|
||||
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
|
||||
)
|
||||
|
||||
# 0. Default height and width to unet
|
||||
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
||||
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
||||
@@ -721,8 +903,15 @@ class StableDiffusionPanoramaPipeline(
|
||||
negative_prompt_embeds,
|
||||
ip_adapter_image,
|
||||
ip_adapter_image_embeds,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._guidance_rescale = guidance_rescale
|
||||
self._clip_skip = clip_skip
|
||||
self._cross_attention_kwargs = cross_attention_kwargs
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
@@ -768,8 +957,7 @@ class StableDiffusionPanoramaPipeline(
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
# 4. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.unet.config.in_channels
|
||||
@@ -802,12 +990,23 @@ class StableDiffusionPanoramaPipeline(
|
||||
else None
|
||||
)
|
||||
|
||||
# 7.2 Optionally get Guidance Scale Embedding
|
||||
timestep_cond = None
|
||||
if self.unet.config.time_cond_proj_dim is not None:
|
||||
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
|
||||
timestep_cond = self.get_guidance_scale_embedding(
|
||||
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
|
||||
).to(device=device, dtype=latents.dtype)
|
||||
|
||||
# 8. Denoising loop
|
||||
# Each denoising step also includes refinement of the latents with respect to the
|
||||
# views.
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
self._num_timesteps = len(timesteps)
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
count.zero_()
|
||||
value.zero_()
|
||||
|
||||
@@ -863,6 +1062,7 @@ class StableDiffusionPanoramaPipeline(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds_input,
|
||||
timestep_cond=timestep_cond,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
).sample
|
||||
@@ -872,6 +1072,12 @@ class StableDiffusionPanoramaPipeline(
|
||||
noise_pred_uncond, noise_pred_text = noise_pred[::2], noise_pred[1::2]
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
|
||||
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||
noise_pred = rescale_noise_cfg(
|
||||
noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale
|
||||
)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents_denoised_batch = self.scheduler.step(
|
||||
noise_pred, t, latents_for_view, **extra_step_kwargs
|
||||
@@ -901,6 +1107,16 @@ class StableDiffusionPanoramaPipeline(
|
||||
# take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
|
||||
latents = torch.where(count > 0, value / count, value)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
@@ -908,7 +1124,7 @@ class StableDiffusionPanoramaPipeline(
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, t, latents)
|
||||
|
||||
if not output_type == "latent":
|
||||
if output_type != "latent":
|
||||
if circular_padding:
|
||||
image = self.decode_latents_with_padding(latents)
|
||||
else:
|
||||
|
||||
@@ -14,7 +14,6 @@ import time
|
||||
import unittest
|
||||
import urllib.parse
|
||||
from contextlib import contextmanager
|
||||
from distutils.util import strtobool
|
||||
from io import BytesIO, StringIO
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
@@ -142,6 +141,22 @@ def get_tests_dir(append_path=None):
|
||||
return tests_dir
|
||||
|
||||
|
||||
# Taken from the following PR:
|
||||
# https://github.com/huggingface/accelerate/pull/1964
|
||||
def str_to_bool(value) -> int:
    """
    Converts a string representation of truth to `True` (1) or `False` (0).

    True values are `y`, `yes`, `t`, `true`, `on`, and `1`; False values are `n`, `no`, `f`, `false`, `off`, and `0`.
    Matching is case-insensitive.

    Args:
        value (`str`): The text to interpret.

    Returns:
        `int`: `1` for a true value, `0` for a false value.

    Raises:
        `ValueError`: If `value` is not one of the recognized spellings.
    """
    # Compare on a lowercased copy so the caller's original input stays available for the error message.
    normalized = value.lower()
    if normalized in ("y", "yes", "t", "true", "on", "1"):
        return 1
    if normalized in ("n", "no", "f", "false", "off", "0"):
        return 0
    # `!r` keeps empty or whitespace-only input visible in the message.
    raise ValueError(f"invalid truth value {value!r}")
|
||||
|
||||
|
||||
def parse_flag_from_env(key, default=False):
|
||||
try:
|
||||
value = os.environ[key]
|
||||
@@ -151,7 +166,7 @@ def parse_flag_from_env(key, default=False):
|
||||
else:
|
||||
# KEY is set, convert it to True or False.
|
||||
try:
|
||||
_value = strtobool(value)
|
||||
_value = str_to_bool(value)
|
||||
except ValueError:
|
||||
# More values are supported, but let's keep the message simple.
|
||||
raise ValueError(f"If set, {key} must be yes or no.")
|
||||
|
||||
@@ -15,12 +15,12 @@
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from distutils.util import strtobool
|
||||
|
||||
import pytest
|
||||
|
||||
from diffusers import __version__
|
||||
from diffusers.utils import deprecate
|
||||
from diffusers.utils.testing_utils import str_to_bool
|
||||
|
||||
|
||||
# Used to test the hub
|
||||
@@ -191,7 +191,7 @@ def parse_flag_from_env(key, default=False):
|
||||
else:
|
||||
# KEY is set, convert it to True or False.
|
||||
try:
|
||||
_value = strtobool(value)
|
||||
_value = str_to_bool(value)
|
||||
except ValueError:
|
||||
# More values are supported, but let's keep the message simple.
|
||||
raise ValueError(f"If set, {key} must be yes or no.")
|
||||
|
||||
@@ -85,6 +85,12 @@ class IFPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin, unittest.T
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class IFPipelineSlowTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# clean up the VRAM before each test
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
|
||||
@@ -94,6 +94,12 @@ class IFImg2ImgPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin, uni
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class IFImg2ImgPipelineSlowTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# clean up the VRAM before each test
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
@@ -109,6 +115,10 @@ class IFImg2ImgPipelineSlowTests(unittest.TestCase):
|
||||
pipe.unet.set_attn_processor(AttnAddedKVProcessor())
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
torch.cuda.reset_max_memory_allocated()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
image = floats_tensor((1, 3, 64, 64), rng=random.Random(0)).to(torch_device)
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
output = pipe(
|
||||
|
||||
@@ -92,6 +92,12 @@ class IFImg2ImgSuperResolutionPipelineFastTests(PipelineTesterMixin, IFPipelineT
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class IFImg2ImgSuperResolutionPipelineSlowTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# clean up the VRAM before each test
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
@@ -107,6 +113,10 @@ class IFImg2ImgSuperResolutionPipelineSlowTests(unittest.TestCase):
|
||||
pipe.unet.set_attn_processor(AttnAddedKVProcessor())
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
torch.cuda.reset_max_memory_allocated()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
|
||||
original_image = floats_tensor((1, 3, 256, 256), rng=random.Random(0)).to(torch_device)
|
||||
|
||||
@@ -92,6 +92,12 @@ class IFInpaintingPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin,
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class IFInpaintingPipelineSlowTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# clean up the VRAM before each test
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
@@ -105,7 +111,6 @@ class IFInpaintingPipelineSlowTests(unittest.TestCase):
|
||||
pipe.unet.set_attn_processor(AttnAddedKVProcessor())
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
# Super resolution test
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_max_memory_allocated()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
@@ -94,6 +94,12 @@ class IFInpaintingSuperResolutionPipelineFastTests(PipelineTesterMixin, IFPipeli
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class IFInpaintingSuperResolutionPipelineSlowTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# clean up the VRAM before each test
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
|
||||
@@ -87,6 +87,12 @@ class IFSuperResolutionPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMi
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class IFSuperResolutionPipelineSlowTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# clean up the VRAM before each test
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
|
||||
@@ -50,7 +50,14 @@ enable_full_determinism()
|
||||
class IPAdapterNightlyTestsMixin(unittest.TestCase):
|
||||
dtype = torch.float16
|
||||
|
||||
def setUp(self):
|
||||
# clean up the VRAM before each test
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
@@ -313,7 +320,7 @@ class IPAdapterSDXLIntegrationTests(IPAdapterNightlyTestsMixin):
|
||||
feature_extractor=feature_extractor,
|
||||
torch_dtype=self.dtype,
|
||||
)
|
||||
pipeline.to(torch_device)
|
||||
pipeline.enable_model_cpu_offload()
|
||||
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
|
||||
|
||||
inputs = self.get_dummy_inputs()
|
||||
@@ -373,7 +380,7 @@ class IPAdapterSDXLIntegrationTests(IPAdapterNightlyTestsMixin):
|
||||
feature_extractor=feature_extractor,
|
||||
torch_dtype=self.dtype,
|
||||
)
|
||||
pipeline.to(torch_device)
|
||||
pipeline.enable_model_cpu_offload()
|
||||
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
|
||||
|
||||
inputs = self.get_dummy_inputs(for_image_to_image=True)
|
||||
@@ -442,7 +449,7 @@ class IPAdapterSDXLIntegrationTests(IPAdapterNightlyTestsMixin):
|
||||
feature_extractor=feature_extractor,
|
||||
torch_dtype=self.dtype,
|
||||
)
|
||||
pipeline.to(torch_device)
|
||||
pipeline.enable_model_cpu_offload()
|
||||
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
|
||||
|
||||
inputs = self.get_dummy_inputs(for_inpainting=True)
|
||||
@@ -490,7 +497,7 @@ class IPAdapterSDXLIntegrationTests(IPAdapterNightlyTestsMixin):
|
||||
image_encoder=image_encoder,
|
||||
torch_dtype=self.dtype,
|
||||
)
|
||||
pipeline.to(torch_device)
|
||||
pipeline.enable_model_cpu_offload()
|
||||
pipeline.load_ip_adapter(
|
||||
"h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter-plus-face_sdxl_vit-h.safetensors"
|
||||
)
|
||||
@@ -518,7 +525,7 @@ class IPAdapterSDXLIntegrationTests(IPAdapterNightlyTestsMixin):
|
||||
image_encoder=image_encoder,
|
||||
torch_dtype=self.dtype,
|
||||
)
|
||||
pipeline.to(torch_device)
|
||||
pipeline.enable_model_cpu_offload()
|
||||
pipeline.load_ip_adapter(
|
||||
"h94/IP-Adapter", subfolder="sdxl_models", weight_name=["ip-adapter-plus-face_sdxl_vit-h.safetensors"] * 2
|
||||
)
|
||||
|
||||
@@ -275,6 +275,12 @@ class KandinskyPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class KandinskyPipelineIntegrationTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# clean up the VRAM before each test
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
|
||||
@@ -299,6 +299,12 @@ class KandinskyImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class KandinskyImg2ImgPipelineIntegrationTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# clean up the VRAM before each test
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
|
||||
@@ -297,6 +297,12 @@ class KandinskyInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
@nightly
|
||||
@require_torch_gpu
|
||||
class KandinskyInpaintPipelineIntegrationTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# clean up the VRAM before each test
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
|
||||
@@ -27,7 +27,6 @@ from diffusers.utils.testing_utils import (
|
||||
load_numpy,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
|
||||
@@ -223,6 +222,12 @@ class KandinskyV22PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class KandinskyV22PipelineIntegrationTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# clean up the VRAM before each test
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
@@ -238,12 +243,12 @@ class KandinskyV22PipelineIntegrationTests(unittest.TestCase):
|
||||
pipe_prior = KandinskyV22PriorPipeline.from_pretrained(
|
||||
"kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16
|
||||
)
|
||||
pipe_prior.to(torch_device)
|
||||
pipe_prior.enable_model_cpu_offload()
|
||||
|
||||
pipeline = KandinskyV22Pipeline.from_pretrained(
|
||||
"kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16
|
||||
)
|
||||
pipeline = pipeline.to(torch_device)
|
||||
pipeline = pipeline.enable_model_cpu_offload()
|
||||
pipeline.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "red cat, 4k photo"
|
||||
@@ -252,7 +257,7 @@ class KandinskyV22PipelineIntegrationTests(unittest.TestCase):
|
||||
image_emb, zero_image_emb = pipe_prior(
|
||||
prompt,
|
||||
generator=generator,
|
||||
num_inference_steps=5,
|
||||
num_inference_steps=3,
|
||||
negative_prompt="",
|
||||
).to_tuple()
|
||||
|
||||
@@ -261,7 +266,7 @@ class KandinskyV22PipelineIntegrationTests(unittest.TestCase):
|
||||
image_embeds=image_emb,
|
||||
negative_image_embeds=zero_image_emb,
|
||||
generator=generator,
|
||||
num_inference_steps=100,
|
||||
num_inference_steps=3,
|
||||
output_type="np",
|
||||
)
|
||||
|
||||
|
||||
@@ -34,7 +34,6 @@ from diffusers.utils.testing_utils import (
|
||||
load_numpy,
|
||||
nightly,
|
||||
require_torch_gpu,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
|
||||
@@ -228,6 +227,12 @@ class KandinskyV22ControlnetPipelineFastTests(PipelineTesterMixin, unittest.Test
|
||||
@nightly
|
||||
@require_torch_gpu
|
||||
class KandinskyV22ControlnetPipelineIntegrationTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# clean up the VRAM before each test
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
@@ -250,12 +255,12 @@ class KandinskyV22ControlnetPipelineIntegrationTests(unittest.TestCase):
|
||||
pipe_prior = KandinskyV22PriorPipeline.from_pretrained(
|
||||
"kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16
|
||||
)
|
||||
pipe_prior.to(torch_device)
|
||||
pipe_prior.enable_model_cpu_offload()
|
||||
|
||||
pipeline = KandinskyV22ControlnetPipeline.from_pretrained(
|
||||
"kandinsky-community/kandinsky-2-2-controlnet-depth", torch_dtype=torch.float16
|
||||
)
|
||||
pipeline = pipeline.to(torch_device)
|
||||
pipeline = pipeline.enable_model_cpu_offload()
|
||||
pipeline.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A robot, 4k photo"
|
||||
@@ -264,7 +269,7 @@ class KandinskyV22ControlnetPipelineIntegrationTests(unittest.TestCase):
|
||||
image_emb, zero_image_emb = pipe_prior(
|
||||
prompt,
|
||||
generator=generator,
|
||||
num_inference_steps=5,
|
||||
num_inference_steps=2,
|
||||
negative_prompt="",
|
||||
).to_tuple()
|
||||
|
||||
@@ -274,7 +279,7 @@ class KandinskyV22ControlnetPipelineIntegrationTests(unittest.TestCase):
|
||||
negative_image_embeds=zero_image_emb,
|
||||
hint=hint,
|
||||
generator=generator,
|
||||
num_inference_steps=100,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
)
|
||||
|
||||
|
||||
@@ -35,7 +35,6 @@ from diffusers.utils.testing_utils import (
|
||||
load_numpy,
|
||||
nightly,
|
||||
require_torch_gpu,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
|
||||
@@ -235,6 +234,12 @@ class KandinskyV22ControlnetImg2ImgPipelineFastTests(PipelineTesterMixin, unitte
|
||||
@nightly
|
||||
@require_torch_gpu
|
||||
class KandinskyV22ControlnetImg2ImgPipelineIntegrationTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# clean up the VRAM before each test
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
@@ -264,12 +269,12 @@ class KandinskyV22ControlnetImg2ImgPipelineIntegrationTests(unittest.TestCase):
|
||||
pipe_prior = KandinskyV22PriorEmb2EmbPipeline.from_pretrained(
|
||||
"kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16
|
||||
)
|
||||
pipe_prior.to(torch_device)
|
||||
pipe_prior.enable_model_cpu_offload()
|
||||
|
||||
pipeline = KandinskyV22ControlnetImg2ImgPipeline.from_pretrained(
|
||||
"kandinsky-community/kandinsky-2-2-controlnet-depth", torch_dtype=torch.float16
|
||||
)
|
||||
pipeline = pipeline.to(torch_device)
|
||||
pipeline = pipeline.enable_model_cpu_offload()
|
||||
|
||||
pipeline.set_progress_bar_config(disable=None)
|
||||
|
||||
@@ -281,6 +286,7 @@ class KandinskyV22ControlnetImg2ImgPipelineIntegrationTests(unittest.TestCase):
|
||||
strength=0.85,
|
||||
generator=generator,
|
||||
negative_prompt="",
|
||||
num_inference_steps=5,
|
||||
).to_tuple()
|
||||
|
||||
output = pipeline(
|
||||
@@ -289,7 +295,7 @@ class KandinskyV22ControlnetImg2ImgPipelineIntegrationTests(unittest.TestCase):
|
||||
negative_image_embeds=zero_image_emb,
|
||||
hint=hint,
|
||||
generator=generator,
|
||||
num_inference_steps=100,
|
||||
num_inference_steps=5,
|
||||
height=512,
|
||||
width=512,
|
||||
strength=0.5,
|
||||
|
||||
@@ -35,7 +35,6 @@ from diffusers.utils.testing_utils import (
|
||||
load_numpy,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
|
||||
@@ -240,6 +239,12 @@ class KandinskyV22Img2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCas
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class KandinskyV22Img2ImgPipelineIntegrationTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# clean up the VRAM before each test
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
@@ -260,12 +265,12 @@ class KandinskyV22Img2ImgPipelineIntegrationTests(unittest.TestCase):
|
||||
pipe_prior = KandinskyV22PriorPipeline.from_pretrained(
|
||||
"kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16
|
||||
)
|
||||
pipe_prior.to(torch_device)
|
||||
pipe_prior.enable_model_cpu_offload()
|
||||
|
||||
pipeline = KandinskyV22Img2ImgPipeline.from_pretrained(
|
||||
"kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16
|
||||
)
|
||||
pipeline = pipeline.to(torch_device)
|
||||
pipeline = pipeline.enable_model_cpu_offload()
|
||||
|
||||
pipeline.set_progress_bar_config(disable=None)
|
||||
|
||||
@@ -282,7 +287,7 @@ class KandinskyV22Img2ImgPipelineIntegrationTests(unittest.TestCase):
|
||||
image_embeds=image_emb,
|
||||
negative_image_embeds=zero_image_emb,
|
||||
generator=generator,
|
||||
num_inference_steps=100,
|
||||
num_inference_steps=5,
|
||||
height=768,
|
||||
width=768,
|
||||
strength=0.2,
|
||||
|
||||
@@ -293,6 +293,12 @@ class KandinskyV22InpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCas
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class KandinskyV22InpaintPipelineIntegrationTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# clean up the VRAM before each test
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
@@ -328,7 +334,7 @@ class KandinskyV22InpaintPipelineIntegrationTests(unittest.TestCase):
|
||||
image_emb, zero_image_emb = pipe_prior(
|
||||
prompt,
|
||||
generator=generator,
|
||||
num_inference_steps=5,
|
||||
num_inference_steps=2,
|
||||
negative_prompt="",
|
||||
).to_tuple()
|
||||
|
||||
@@ -338,7 +344,7 @@ class KandinskyV22InpaintPipelineIntegrationTests(unittest.TestCase):
|
||||
image_embeds=image_emb,
|
||||
negative_image_embeds=zero_image_emb,
|
||||
generator=generator,
|
||||
num_inference_steps=100,
|
||||
num_inference_steps=2,
|
||||
height=768,
|
||||
width=768,
|
||||
output_type="np",
|
||||
|
||||
@@ -169,6 +169,12 @@ class Kandinsky3PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class Kandinsky3PipelineIntegrationTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# clean up the VRAM before each test
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
@@ -186,7 +192,7 @@ class Kandinsky3PipelineIntegrationTests(unittest.TestCase):
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
|
||||
image = pipe(prompt, num_inference_steps=25, generator=generator).images[0]
|
||||
image = pipe(prompt, num_inference_steps=5, generator=generator).images[0]
|
||||
|
||||
assert image.size == (1024, 1024)
|
||||
|
||||
@@ -217,7 +223,7 @@ class Kandinsky3PipelineIntegrationTests(unittest.TestCase):
|
||||
image = image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
|
||||
prompt = "A painting of the inside of a subway train with tiny raccoons."
|
||||
|
||||
image = pipe(prompt, image=image, strength=0.75, num_inference_steps=25, generator=generator).images[0]
|
||||
image = pipe(prompt, image=image, strength=0.75, num_inference_steps=5, generator=generator).images[0]
|
||||
|
||||
assert image.size == (512, 512)
|
||||
|
||||
|
||||
@@ -187,6 +187,12 @@ class Kandinsky3Img2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase)
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class Kandinsky3Img2ImgPipelineIntegrationTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# clean up the VRAM before each test
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
@@ -209,7 +215,7 @@ class Kandinsky3Img2ImgPipelineIntegrationTests(unittest.TestCase):
|
||||
image = image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
|
||||
prompt = "A painting of the inside of a subway train with tiny raccoons."
|
||||
|
||||
image = pipe(prompt, image=image, strength=0.75, num_inference_steps=25, generator=generator).images[0]
|
||||
image = pipe(prompt, image=image, strength=0.75, num_inference_steps=5, generator=generator).images[0]
|
||||
|
||||
assert image.size == (512, 512)
|
||||
|
||||
|
||||
@@ -32,14 +32,16 @@ from diffusers import (
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, nightly, require_torch_gpu, skip_mps, torch_device
|
||||
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
|
||||
from ..test_pipelines_common import IPAdapterTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
@skip_mps
|
||||
class StableDiffusionPanoramaPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
class StableDiffusionPanoramaPipelineFastTests(
|
||||
IPAdapterTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase
|
||||
):
|
||||
pipeline_class = StableDiffusionPanoramaPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
@@ -96,7 +98,7 @@ class StableDiffusionPanoramaPipelineFastTests(PipelineLatentTesterMixin, Pipeli
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
generator = torch.manual_seed(seed)
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
inputs = {
|
||||
"prompt": "a photo of the dolomites",
|
||||
"generator": generator,
|
||||
|
||||
@@ -779,7 +779,14 @@ class StableDiffusionXLImg2ImgRefinerOnlyPipelineFastTests(
|
||||
|
||||
@slow
|
||||
class StableDiffusionXLImg2ImgIntegrationTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# clean up the VRAM before each test
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@@ -639,6 +639,12 @@ class PipelineTesterMixin:
|
||||
"`callback_cfg_params = TEXT_TO_IMAGE_CFG_PARAMS.union({'mask', 'masked_image_latents'})`"
|
||||
)
|
||||
|
||||
def setUp(self):
|
||||
# clean up the VRAM before each test
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test in case of CUDA runtime errors
|
||||
super().tearDown()
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
@@ -183,6 +184,18 @@ class TextToVideoSDPipelineFastTests(PipelineTesterMixin, SDFunctionTesterMixin,
|
||||
@skip_mps
|
||||
@require_torch_gpu
|
||||
class TextToVideoSDPipelineSlowTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# clean up the VRAM before each test
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_two_step_model(self):
|
||||
expected_video = load_numpy(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text-to-video/video_2step.npy"
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
@@ -26,6 +27,18 @@ from ..test_pipelines_common import assert_mean_pixel_difference
|
||||
@nightly
|
||||
@require_torch_gpu
|
||||
class TextToVideoZeroPipelineSlowTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# clean up the VRAM before each test
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_full_model(self):
|
||||
model_id = "runwayml/stable-diffusion-v1-5"
|
||||
pipe = TextToVideoZeroPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import contextlib
|
||||
import gc
|
||||
import inspect
|
||||
import io
|
||||
import re
|
||||
@@ -381,6 +382,18 @@ class TextToVideoZeroSDXLPipelineFastTests(PipelineTesterMixin, unittest.TestCas
|
||||
@nightly
|
||||
@require_torch_gpu
|
||||
class TextToVideoZeroSDXLPipelineSlowTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# clean up the VRAM before each test
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_full_model(self):
|
||||
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
|
||||
pipe = TextToVideoZeroSDXLPipeline.from_pretrained(
|
||||
|
||||
Reference in New Issue
Block a user