Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-07 21:14:44 +08:00)

Compare commits: disable-te...test-backe (11 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | 858dfd6411 |  |
|  | 6cb2178a91 |  |
|  | f90a5139a2 |  |
|  | a2bc2e14b9 |  |
|  | f427345ab1 |  |
|  | 6e221334cd |  |
|  | 53bc30dd45 |  |
|  | eacf5e34eb |  |
|  | 4c05f7856a |  |
|  | bbd3572044 |  |
|  | f948778322 |  |
.github/workflows/pr_test_fetcher.yml (vendored, 8 changes)
@@ -1,6 +1,12 @@
 name: Fast tests for PRs - Test Fetcher

-on: workflow_dispatch
+on:
+  pull_request:
+    branches:
+      - main
+  push:
+    branches:
+      - ci-*

 env:
   DIFFUSERS_IS_CI: yes
@@ -20,6 +20,9 @@ An attention processor is a class for applying different types of attention mechanisms.

 ## AttnProcessor2_0
 [[autodoc]] models.attention_processor.AttnProcessor2_0

+## FusedAttnProcessor2_0
+[[autodoc]] models.attention_processor.FusedAttnProcessor2_0
+
 ## LoRAAttnProcessor
 [[autodoc]] models.attention_processor.LoRAAttnProcessor
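Taken together with the `fuse_qkv_projections()` helpers added further down in this compare, the new processor can be exercised end to end. A minimal sketch, assuming an SDXL checkpoint, a CUDA device, and PyTorch 2.0 (the model id and prompt are illustrative):

```py
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Fuses q/k/v projections for self-attention (k/v for cross-attention) on the
# UNet and VAE, and installs FusedAttnProcessor2_0 on both.
pipe.fuse_qkv_projections()
image = pipe("an astronaut riding a horse").images[0]

# Restores the original, unfused attention processors.
pipe.unfuse_qkv_projections()
```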
@@ -123,16 +123,26 @@ def save_model_card(
     """

     trigger_str = f"You should use {instance_prompt} to trigger the image generation."
+    diffusers_imports_pivotal = ""
+    diffusers_example_pivotal = ""
     if train_text_encoder_ti:
         trigger_str = (
             "To trigger image generation of trained concept(or concepts) replace each concept identifier "
             "in you prompt with the new inserted tokens:\n"
         )
+        diffusers_imports_pivotal = """from huggingface_hub import hf_hub_download
+from safetensors.torch import load_file
+        """
+        diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id="{repo_id}", filename="embeddings.safetensors", repo_type="model")
+state_dict = load_file(embedding_path)
+pipeline.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
+pipeline.load_textual_inversion(state_dict["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
+        """
         if token_abstraction_dict:
             for key, value in token_abstraction_dict.items():
                 tokens = "".join(value)
                 trigger_str += f"""
-to trigger concept `{key}->` use `{tokens}` in your prompt \n
+to trigger concept `{key}` → use `{tokens}` in your prompt \n
 """

     yaml = f"""
@@ -172,7 +182,21 @@ Special VAE used for training: {vae_path}.

 {trigger_str}

-## Download model
+## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
+
+```py
+from diffusers import AutoPipelineForText2Image
+import torch
+{diffusers_imports_pivotal}
+pipeline = AutoPipelineForText2Image.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', torch_dtype=torch.float16).to('cuda')
+pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors')
+{diffusers_example_pivotal}
+image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0]
+```
+
+For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
+
+## Download model (use it with UIs such as AUTO1111, Comfy, SD.Next, Invoke)

 Weights for this model are available in Safetensors format.
@@ -791,6 +815,12 @@ class DreamBoothDataset(Dataset):
         instance_data_root,
         instance_prompt,
         class_prompt,
+        dataset_name,
+        dataset_config_name,
+        cache_dir,
+        image_column,
+        caption_column,
+        train_text_encoder_ti,
         class_data_root=None,
         class_num=None,
         token_abstraction_dict=None,  # token mapping for textual inversion
@@ -805,10 +835,10 @@ class DreamBoothDataset(Dataset):
         self.custom_instance_prompts = None
         self.class_prompt = class_prompt
         self.token_abstraction_dict = token_abstraction_dict

+        self.train_text_encoder_ti = train_text_encoder_ti
         # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,
         # we load the training data using load_dataset
-        if args.dataset_name is not None:
+        if dataset_name is not None:
             try:
                 from datasets import load_dataset
             except ImportError:
@@ -821,26 +851,25 @@ class DreamBoothDataset(Dataset):
             # See more about loading custom images at
             # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
             dataset = load_dataset(
-                args.dataset_name,
-                args.dataset_config_name,
-                cache_dir=args.cache_dir,
+                dataset_name,
+                dataset_config_name,
+                cache_dir=cache_dir,
             )
             # Preprocessing the datasets.
             column_names = dataset["train"].column_names

             # 6. Get the column names for input/target.
-            if args.image_column is None:
+            if image_column is None:
                 image_column = column_names[0]
                 logger.info(f"image column defaulting to {image_column}")
             else:
-                image_column = args.image_column
                 if image_column not in column_names:
                     raise ValueError(
-                        f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+                        f"`--image_column` value '{image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
                     )
             instance_images = dataset["train"][image_column]

-            if args.caption_column is None:
+            if caption_column is None:
                 logger.info(
                     "No caption column provided, defaulting to instance_prompt for all images. If your dataset "
                     "contains captions/prompts for the images, make sure to specify the "
@@ -848,11 +877,11 @@ class DreamBoothDataset(Dataset):
                 )
                 self.custom_instance_prompts = None
             else:
-                if args.caption_column not in column_names:
+                if caption_column not in column_names:
                     raise ValueError(
-                        f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+                        f"`--caption_column` value '{caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
                     )
-                custom_instance_prompts = dataset["train"][args.caption_column]
+                custom_instance_prompts = dataset["train"][caption_column]
                 # create final list of captions according to --repeats
                 self.custom_instance_prompts = []
                 for caption in custom_instance_prompts:
@@ -907,7 +936,7 @@ class DreamBoothDataset(Dataset):
         if self.custom_instance_prompts:
             caption = self.custom_instance_prompts[index % self.num_instance_images]
             if caption:
-                if args.train_text_encoder_ti:
+                if self.train_text_encoder_ti:
                     # replace instances of --token_abstraction in caption with the new tokens: "<si><si+1>" etc.
                     for token_abs, token_replacement in self.token_abstraction_dict.items():
                         caption = caption.replace(token_abs, "".join(token_replacement))
@@ -1093,10 +1122,10 @@ def main(args):
     if args.output_dir is not None:
         os.makedirs(args.output_dir, exist_ok=True)

+    model_id = args.hub_model_id or Path(args.output_dir).name
+    repo_id = None
     if args.push_to_hub:
-        repo_id = create_repo(
-            repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
-        ).repo_id
+        repo_id = create_repo(repo_id=model_id, exist_ok=True, token=args.hub_token).repo_id

     # Load the tokenizers
     tokenizer_one = AutoTokenizer.from_pretrained(
@@ -1464,6 +1493,12 @@ def main(args):
         instance_data_root=args.instance_data_dir,
         instance_prompt=args.instance_prompt,
         class_prompt=args.class_prompt,
+        dataset_name=args.dataset_name,
+        dataset_config_name=args.dataset_config_name,
+        cache_dir=args.cache_dir,
+        image_column=args.image_column,
+        train_text_encoder_ti=args.train_text_encoder_ti,
+        caption_column=args.caption_column,
         class_data_root=args.class_data_dir if args.with_prior_preservation else None,
         token_abstraction_dict=token_abstraction_dict if args.train_text_encoder_ti else None,
         class_num=args.num_class_images,
@@ -2004,23 +2039,23 @@ def main(args):
             }
         )

-        if args.push_to_hub:
-            if args.train_text_encoder_ti:
-                embedding_handler.save_embeddings(
-                    f"{args.output_dir}/embeddings.safetensors",
-                )
-            save_model_card(
-                repo_id,
-                images=images,
-                base_model=args.pretrained_model_name_or_path,
-                train_text_encoder=args.train_text_encoder,
-                train_text_encoder_ti=args.train_text_encoder_ti,
-                token_abstraction_dict=train_dataset.token_abstraction_dict,
-                instance_prompt=args.instance_prompt,
-                validation_prompt=args.validation_prompt,
-                repo_folder=args.output_dir,
-                vae_path=args.pretrained_vae_model_name_or_path,
+        if args.train_text_encoder_ti:
+            embedding_handler.save_embeddings(
+                f"{args.output_dir}/embeddings.safetensors",
+            )
+        save_model_card(
+            model_id if not args.push_to_hub else repo_id,
+            images=images,
+            base_model=args.pretrained_model_name_or_path,
+            train_text_encoder=args.train_text_encoder,
+            train_text_encoder_ti=args.train_text_encoder_ti,
+            token_abstraction_dict=train_dataset.token_abstraction_dict,
+            instance_prompt=args.instance_prompt,
+            validation_prompt=args.validation_prompt,
+            repo_folder=args.output_dir,
+            vae_path=args.pretrained_vae_model_name_or_path,
+        )
+        if args.push_to_hub:
             upload_folder(
                 repo_id=repo_id,
                 folder_path=args.output_dir,
@@ -2870,10 +2870,14 @@ The original repo can be found at [repo](https://github.com/PRIS-CV/DemoFusion).
 - `show_image` (`bool`, defaults to False):
   Determine whether to show intermediate results during generation.
 ```
-from pipeline_demofusion_sdxl import DemoFusionSDXLPipeline
+from diffusers import DiffusionPipeline

-model_ckpt = "stabilityai/stable-diffusion-xl-base-1.0"
-pipe = DemoFusionSDXLPipeline.from_pretrained(model_ckpt, torch_dtype=torch.float16)
+pipe = DiffusionPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0",
+    custom_pipeline="pipeline_demofusion_sdxl",
+    custom_revision="main",
+    torch_dtype=torch.float16,
+)
 pipe = pipe.to("cuda")

 prompt = "Envision a portrait of an elderly woman, her face a canvas of time, framed by a headscarf with muted tones of rust and cream. Her eyes, blue like faded denim. Her attire, simple yet dignified."
@@ -36,7 +36,9 @@ from diffusers.utils.torch_utils import randn_tensor

 if is_invisible_watermark_available():
-    from .watermark import StableDiffusionXLWatermarker
+    from diffusers.pipelines.stable_diffusion_xl.watermark import (
+        StableDiffusionXLWatermarker,
+    )


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
setup.py (2 changes)
@@ -118,7 +118,7 @@ _deps = [
     "pytest-timeout",
     "pytest-xdist",
     "python>=3.8.0",
-    "ruff>=0.1.5,<=0.2",
+    "ruff==0.1.5",
     "safetensors>=0.3.1",
     "sentencepiece>=0.1.91,!=0.1.92",
     "GitPython<3.1.19",
@@ -30,7 +30,7 @@ deps = {
     "pytest-timeout": "pytest-timeout",
     "pytest-xdist": "pytest-xdist",
     "python": "python>=3.8.0",
-    "ruff": "ruff>=0.1.5,<=0.2",
+    "ruff": "ruff==0.1.5",
     "safetensors": "safetensors>=0.3.1",
     "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
     "GitPython": "GitPython<3.1.19",
@@ -282,7 +282,7 @@ class FromSingleFileMixin:
         )

         if torch_dtype is not None:
-            pipe.to(torch_dtype=torch_dtype)
+            pipe.to(dtype=torch_dtype)

         return pipe
@@ -113,12 +113,14 @@ class Attention(nn.Module):
     ):
         super().__init__()
         self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+        self.query_dim = query_dim
         self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
         self.upcast_attention = upcast_attention
         self.upcast_softmax = upcast_softmax
         self.rescale_output_factor = rescale_output_factor
         self.residual_connection = residual_connection
         self.dropout = dropout
+        self.fused_projections = False
         self.out_dim = out_dim if out_dim is not None else query_dim

         # we make use of this private variable to know whether this class is loaded
@@ -180,6 +182,7 @@ class Attention(nn.Module):
         else:
             linear_cls = LoRACompatibleLinear

+        self.linear_cls = linear_cls
         self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)

         if not self.only_cross_attention:
@@ -692,6 +695,32 @@ class Attention(nn.Module):

         return encoder_hidden_states

+    @torch.no_grad()
+    def fuse_projections(self, fuse=True):
+        is_cross_attention = self.cross_attention_dim != self.query_dim
+        device = self.to_q.weight.data.device
+        dtype = self.to_q.weight.data.dtype
+
+        if not is_cross_attention:
+            # fetch weight matrices.
+            concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
+            in_features = concatenated_weights.shape[1]
+            out_features = concatenated_weights.shape[0]
+
+            # create a new single projection layer and copy over the weights.
+            self.to_qkv = self.linear_cls(in_features, out_features, bias=False, device=device, dtype=dtype)
+            self.to_qkv.weight.copy_(concatenated_weights)
+
+        else:
+            concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
+            in_features = concatenated_weights.shape[1]
+            out_features = concatenated_weights.shape[0]
+
+            self.to_kv = self.linear_cls(in_features, out_features, bias=False, device=device, dtype=dtype)
+            self.to_kv.weight.copy_(concatenated_weights)
+
+        self.fused_projections = fuse
+
+
 class AttnProcessor:
     r"""
@@ -1184,9 +1213,6 @@ class AttnProcessor2_0:
         scale: float = 1.0,
     ) -> torch.FloatTensor:
         residual = hidden_states
-
-        args = () if USE_PEFT_BACKEND else (scale,)
-
         if attn.spatial_norm is not None:
             hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -1253,6 +1279,103 @@ class AttnProcessor2_0:
         return hidden_states


+class FusedAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+    It uses fused projection layers. For self-attention modules, all projection matrices (i.e., query,
+    key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+
+    <Tip warning={true}>
+
+    This API is currently 🧪 experimental in nature and can change in future.
+
+    </Tip>
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "FusedAttnProcessor2_0 requires at least PyTorch 2.0, to use it. Please upgrade PyTorch to > 2.0."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        temb: Optional[torch.FloatTensor] = None,
+        scale: float = 1.0,
+    ) -> torch.FloatTensor:
+        residual = hidden_states
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        args = () if USE_PEFT_BACKEND else (scale,)
+        if encoder_hidden_states is None:
+            qkv = attn.to_qkv(hidden_states, *args)
+            split_size = qkv.shape[-1] // 3
+            query, key, value = torch.split(qkv, split_size, dim=-1)
+        else:
+            if attn.norm_cross:
+                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+            query = attn.to_q(hidden_states, *args)
+
+            kv = attn.to_kv(encoder_hidden_states, *args)
+            split_size = kv.shape[-1] // 2
+            key, value = torch.split(kv, split_size, dim=-1)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states, *args)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
 class CustomDiffusionXFormersAttnProcessor(nn.Module):
     r"""
     Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method.
@@ -2251,6 +2374,7 @@ CROSS_ATTENTION_PROCESSORS = (
 AttentionProcessor = Union[
     AttnProcessor,
     AttnProcessor2_0,
+    FusedAttnProcessor2_0,
     XFormersAttnProcessor,
     SlicedAttnProcessor,
     AttnAddedKVProcessor,
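The same machinery can also be driven at the module level. A small sketch, assuming PyTorch 2.0 and a self-attention `Attention` block (the dimensions below are illustrative):

```py
import torch
from diffusers.models.attention_processor import Attention, FusedAttnProcessor2_0

attn = Attention(query_dim=64, heads=8, dim_head=8)

# Concatenates the to_q/to_k/to_v weights into a single to_qkv projection
# and flags the module as fused.
attn.fuse_projections(fuse=True)
attn.set_processor(FusedAttnProcessor2_0())

hidden_states = torch.randn(1, 16, 64)
out = attn(hidden_states)  # (1, 16, 64), computed through the fused path
```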
@@ -22,6 +22,7 @@ from ..utils.accelerate_utils import apply_forward_hook
 from .attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
     CROSS_ATTENTION_PROCESSORS,
+    Attention,
     AttentionProcessor,
     AttnAddedKVProcessor,
     AttnProcessor,
@@ -448,3 +449,41 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
             return (dec,)

         return DecoderOutput(sample=dec)
+
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
+    def fuse_qkv_projections(self):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
+        key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+        """
+        self.original_attn_processors = None
+
+        for _, attn_processor in self.attn_processors.items():
+            if "Added" in str(attn_processor.__class__.__name__):
+                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+        self.original_attn_processors = self.attn_processors
+
+        for module in self.modules():
+            if isinstance(module, Attention):
+                module.fuse_projections(fuse=True)
+
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+    def unfuse_qkv_projections(self):
+        """Disables the fused QKV projection if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        """
+        if self.original_attn_processors is not None:
+            self.set_attn_processor(self.original_attn_processors)
@@ -25,6 +25,7 @@ from .activations import get_activation
 from .attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
     CROSS_ATTENTION_PROCESSORS,
+    Attention,
     AttentionProcessor,
     AttnAddedKVProcessor,
     AttnProcessor,
@@ -794,6 +795,42 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
                 setattr(upsample_block, k, None)

+    def fuse_qkv_projections(self):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
+        key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+        """
+        self.original_attn_processors = None
+
+        for _, attn_processor in self.attn_processors.items():
+            if "Added" in str(attn_processor.__class__.__name__):
+                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+        self.original_attn_processors = self.attn_processors
+
+        for module in self.modules():
+            if isinstance(module, Attention):
+                module.fuse_projections(fuse=True)
+
+    def unfuse_qkv_projections(self):
+        """Disables the fused QKV projection if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        """
+        if self.original_attn_processors is not None:
+            self.set_attn_processor(self.original_attn_processors)
+
     def forward(
         self,
         sample: torch.FloatTensor,
src/diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py (new executable file, 98 lines)
@@ -0,0 +1,98 @@
#!/usr/bin/env python3
import argparse
import fnmatch

from safetensors.torch import load_file

from diffusers import Kandinsky3UNet


MAPPING = {
    "to_time_embed.1": "time_embedding.linear_1",
    "to_time_embed.3": "time_embedding.linear_2",
    "in_layer": "conv_in",
    "out_layer.0": "conv_norm_out",
    "out_layer.2": "conv_out",
    "down_samples": "down_blocks",
    "up_samples": "up_blocks",
    "projection_lin": "encoder_hid_proj.projection_linear",
    "projection_ln": "encoder_hid_proj.projection_norm",
    "feature_pooling": "add_time_condition",
    "to_query": "to_q",
    "to_key": "to_k",
    "to_value": "to_v",
    "output_layer": "to_out.0",
    "self_attention_block": "attentions.0",
}

DYNAMIC_MAP = {
    "resnet_attn_blocks.*.0": "resnets_in.*",
    "resnet_attn_blocks.*.1": ("attentions.*", 1),
    "resnet_attn_blocks.*.2": "resnets_out.*",
}
# MAPPING = {}


def convert_state_dict(unet_state_dict):
    """
    Convert the state dict of a U-Net model to match the key format expected by Kandinsky3UNet model.
    Args:
        unet_model (torch.nn.Module): The original U-Net model.
        unet_kandi3_model (torch.nn.Module): The Kandinsky3UNet model to match keys with.

    Returns:
        OrderedDict: The converted state dictionary.
    """
    # Example of renaming logic (this will vary based on your model's architecture)
    converted_state_dict = {}
    for key in unet_state_dict:
        new_key = key
        for pattern, new_pattern in MAPPING.items():
            new_key = new_key.replace(pattern, new_pattern)

        for dyn_pattern, dyn_new_pattern in DYNAMIC_MAP.items():
            has_matched = False
            if fnmatch.fnmatch(new_key, f"*.{dyn_pattern}.*") and not has_matched:
                star = int(new_key.split(dyn_pattern.split(".")[0])[-1].split(".")[1])

                if isinstance(dyn_new_pattern, tuple):
                    new_star = star + dyn_new_pattern[-1]
                    dyn_new_pattern = dyn_new_pattern[0]
                else:
                    new_star = star

                pattern = dyn_pattern.replace("*", str(star))
                new_pattern = dyn_new_pattern.replace("*", str(new_star))

                new_key = new_key.replace(pattern, new_pattern)
                has_matched = True

        converted_state_dict[new_key] = unet_state_dict[key]

    return converted_state_dict


def main(model_path, output_path):
    # Load your original U-Net model
    unet_state_dict = load_file(model_path)

    # Initialize your Kandinsky3UNet model
    config = {}

    # Convert the state dict
    converted_state_dict = convert_state_dict(unet_state_dict)

    unet = Kandinsky3UNet(config)
    unet.load_state_dict(converted_state_dict)

    unet.save_pretrained(output_path)
    print(f"Converted model saved to {output_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert U-Net PyTorch model to Kandinsky3UNet format")
    parser.add_argument("--model_path", type=str, required=True, help="Path to the original U-Net PyTorch model")
    parser.add_argument("--output_path", type=str, required=True, help="Path to save the converted model")

    args = parser.parse_args()
    main(args.model_path, args.output_path)
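`convert_state_dict` is a pure key-renaming pass, so its effect is easy to see on a toy state dict (the key below is hypothetical but follows the MAPPING above):

```py
import torch

toy = {"to_time_embed.1.weight": torch.zeros(4, 4)}
print(list(convert_state_dict(toy)))
# ['time_embedding.linear_1.weight']
```

The script itself is driven from the command line via its `--model_path` and `--output_path` arguments.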
@@ -446,7 +446,7 @@ def convert_ldm_unet_checkpoint(
         new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]

     # Relevant to StableDiffusionUpscalePipeline
-    if "num_class_embeds" in config:
+    if "num_class_embeds" in config:
+        if (config["num_class_embeds"] is not None) and ("label_emb.weight" in unet_state_dict):
             new_checkpoint["class_embedding.weight"] = unet_state_dict["label_emb.weight"]

     new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
@@ -34,6 +34,7 @@ from ...loaders import (
 from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
 from ...models.attention_processor import (
     AttnProcessor2_0,
+    FusedAttnProcessor2_0,
     LoRAAttnProcessor2_0,
     LoRAXFormersAttnProcessor,
     XFormersAttnProcessor,
@@ -681,7 +682,6 @@ class StableDiffusionXLPipeline(
         add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
         return add_time_ids

-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
     def upcast_vae(self):
         dtype = self.vae.dtype
         self.vae.to(dtype=torch.float32)
@@ -692,6 +692,7 @@ class StableDiffusionXLPipeline(
                 XFormersAttnProcessor,
                 LoRAXFormersAttnProcessor,
                 LoRAAttnProcessor2_0,
+                FusedAttnProcessor2_0,
             ),
         )
         # if xformers or torch_2_0 is used attention block does not need
@@ -729,6 +730,65 @@ class StableDiffusionXLPipeline(
         """Disables the FreeU mechanism if enabled."""
         self.unet.disable_freeu()

+    def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
+        key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        Args:
+            unet (`bool`, defaults to `True`): To apply fusion on the UNet.
+            vae (`bool`, defaults to `True`): To apply fusion on the VAE.
+        """
+        self.fusing_unet = False
+        self.fusing_vae = False
+
+        if unet:
+            self.fusing_unet = True
+            self.unet.fuse_qkv_projections()
+            self.unet.set_attn_processor(FusedAttnProcessor2_0())
+
+        if vae:
+            if not isinstance(self.vae, AutoencoderKL):
+                raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")
+
+            self.fusing_vae = True
+            self.vae.fuse_qkv_projections()
+            self.vae.set_attn_processor(FusedAttnProcessor2_0())
+
+    def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
+        """Disable QKV projection fusion if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        Args:
+            unet (`bool`, defaults to `True`): To apply fusion on the UNet.
+            vae (`bool`, defaults to `True`): To apply fusion on the VAE.
+
+        """
+        if unet:
+            if not self.fusing_unet:
+                logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
+            else:
+                self.unet.unfuse_qkv_projections()
+                self.fusing_unet = False
+
+        if vae:
+            if not self.fusing_vae:
+                logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
+            else:
+                self.vae.unfuse_qkv_projections()
+                self.fusing_vae = False
+
     # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
     def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
         """
@@ -24,6 +24,7 @@ from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, Te
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.attention_processor import (
     AttnProcessor2_0,
+    FusedAttnProcessor2_0,
     LoRAAttnProcessor2_0,
     LoRAXFormersAttnProcessor,
     XFormersAttnProcessor,

@@ -610,6 +611,7 @@ class StableDiffusionXLInstructPix2PixPipeline(
                 XFormersAttnProcessor,
                 LoRAXFormersAttnProcessor,
                 LoRAAttnProcessor2_0,
+                FusedAttnProcessor2_0,
             ),
         )
         # if xformers or torch_2_0 is used attention block does not need
@@ -10,10 +10,10 @@ from diffusers.utils import deprecate
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...models import ModelMixin
 from ...models.activations import get_activation
-from ...models.attention import Attention
 from ...models.attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
     CROSS_ATTENTION_PROCESSORS,
+    Attention,
     AttentionProcessor,
     AttnAddedKVProcessor,
     AttnAddedKVProcessor2_0,
@@ -1000,6 +1000,42 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
                 setattr(upsample_block, k, None)

+    def fuse_qkv_projections(self):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
+        key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+        """
+        self.original_attn_processors = None
+
+        for _, attn_processor in self.attn_processors.items():
+            if "Added" in str(attn_processor.__class__.__name__):
+                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+        self.original_attn_processors = self.attn_processors
+
+        for module in self.modules():
+            if isinstance(module, Attention):
+                module.fuse_projections(fuse=True)
+
+    def unfuse_qkv_projections(self):
+        """Disables the fused QKV projection if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        """
+        if self.original_attn_processors is not None:
+            self.set_attn_processor(self.original_attn_processors)
+
     def forward(
         self,
         sample: torch.FloatTensor,
@@ -191,10 +191,11 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
     @property
     def init_noise_sigma(self):
         # standard deviation of the initial noise distribution
+        max_sigma = max(self.sigmas) if isinstance(self.sigmas, list) else self.sigmas.max()
         if self.config.timestep_spacing in ["linspace", "trailing"]:
-            return self.sigmas.max()
+            return max_sigma

-        return (self.sigmas.max() ** 2 + 1) ** 0.5
+        return (max_sigma**2 + 1) ** 0.5

     @property
     def step_index(self):
@@ -289,6 +290,8 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         self.timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(device=device)

         self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
+        if sigmas.device.type == "cuda":
+            self.sigmas = self.sigmas.tolist()
         self._step_index = None

     def _sigma_to_t(self, sigma, log_sigmas):
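With this change `init_noise_sigma` works whether `self.sigmas` is a tensor or, after the `set_timesteps` change above, a plain Python list. A small numeric sketch of the two branches (the sigma values are illustrative):

```py
sigmas = [14.6146, 7.0, 1.0, 0.0]  # list form, as produced on CUDA above
max_sigma = max(sigmas) if isinstance(sigmas, list) else sigmas.max()

init_linspace_or_trailing = max_sigma     # 14.6146: sigma_max itself
init_leading = (max_sigma**2 + 1) ** 0.5  # ~14.6488: sqrt(sigma_max**2 + 1)
```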
@@ -17,7 +17,7 @@ from contextlib import contextmanager
 from distutils.util import strtobool
 from io import BytesIO, StringIO
 from pathlib import Path
-from typing import List, Optional, Union
+from typing import Callable, Dict, List, Optional, Union

 import numpy as np
 import PIL.Image
@@ -58,6 +58,17 @@ USE_PEFT_BACKEND = _required_peft_version and _required_transformers_version
 if is_torch_available():
     import torch

+    # Set a backend environment variable for any extra module import required for a custom accelerator
+    if "DIFFUSERS_TEST_BACKEND" in os.environ:
+        backend = os.environ["DIFFUSERS_TEST_BACKEND"]
+        try:
+            _ = importlib.import_module(backend)
+        except ModuleNotFoundError as e:
+            raise ModuleNotFoundError(
+                f"Failed to import `DIFFUSERS_TEST_BACKEND` '{backend}'! This should be the name of an installed module \
+                    to enable a specified backend.):\n{e}"
+            ) from e
+
     if "DIFFUSERS_TEST_DEVICE" in os.environ:
         torch_device = os.environ["DIFFUSERS_TEST_DEVICE"]
         try:
@@ -210,6 +221,36 @@ def require_torch_gpu(test_case):
     )


+# These decorators are for accelerator-specific behaviours that are not GPU-specific
+def require_torch_accelerator(test_case):
+    """Decorator marking a test that requires an accelerator backend and PyTorch."""
+    return unittest.skipUnless(is_torch_available() and torch_device != "cpu", "test requires accelerator+PyTorch")(
+        test_case
+    )
+
+
+def require_torch_accelerator_with_fp16(test_case):
+    """Decorator marking a test that requires an accelerator with support for the FP16 data type."""
+    return unittest.skipUnless(_is_torch_fp16_available(torch_device), "test requires accelerator with fp16 support")(
+        test_case
+    )
+
+
+def require_torch_accelerator_with_fp64(test_case):
+    """Decorator marking a test that requires an accelerator with support for the FP64 data type."""
+    return unittest.skipUnless(_is_torch_fp64_available(torch_device), "test requires accelerator with fp64 support")(
+        test_case
+    )
+
+
+def require_torch_accelerator_with_training(test_case):
+    """Decorator marking a test that requires an accelerator with support for training."""
+    return unittest.skipUnless(
+        is_torch_available() and backend_supports_training(torch_device),
+        "test requires accelerator with training support",
+    )(test_case)
+
+
 def skip_mps(test_case):
     """Decorator marking a test to skip if torch_device is 'mps'"""
     return unittest.skipUnless(torch_device != "mps", "test requires non 'mps' device")(test_case)
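The new decorators compose like the existing `require_torch_gpu`. A minimal sketch with a hypothetical test case:

```py
import unittest

import torch

from diffusers.utils.testing_utils import (
    require_torch_accelerator,
    require_torch_accelerator_with_fp64,
    torch_device,
)


class ExampleTests(unittest.TestCase):
    @require_torch_accelerator
    def test_on_any_accelerator(self):
        x = torch.ones(2, 2).to(torch_device)
        assert x.sum().item() == 4.0

    @require_torch_accelerator_with_fp64
    def test_needs_fp64(self):
        x = torch.ones(2, 2, dtype=torch.float64).to(torch_device)
        assert x.dtype == torch.float64
```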
@@ -766,3 +807,139 @@ def disable_full_determinism():
     os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
     os.environ["CUBLAS_WORKSPACE_CONFIG"] = ""
     torch.use_deterministic_algorithms(False)
+
+
+# Utils for custom and alternative accelerator devices
+def _is_torch_fp16_available(device):
+    if not is_torch_available():
+        return False
+
+    import torch
+
+    device = torch.device(device)
+
+    try:
+        x = torch.zeros((2, 2), dtype=torch.float16).to(device)
+        _ = x @ x
+    except Exception as e:
+        if device.type == "cuda":
+            raise ValueError(
+                f"You have passed a device of type 'cuda' which should work with 'fp16', but 'cuda' does not seem to be correctly installed on your machine: {e}"
+            )
+
+        return False
+
+
+def _is_torch_fp64_available(device):
+    if not is_torch_available():
+        return False
+
+    import torch
+
+    try:
+        x = torch.zeros((2, 2), dtype=torch.float64).to(device)
+        _ = x @ x
+    except Exception as e:
+        if device.type == "cuda":
+            raise ValueError(
+                f"You have passed a device of type 'cuda' which should work with 'fp64', but 'cuda' does not seem to be correctly installed on your machine: {e}"
+            )
+
+        return False
+
+
+# Guard these lookups for when Torch is not used - alternative accelerator support is for PyTorch
+if is_torch_available():
+    # Behaviour flags
+    BACKEND_SUPPORTS_TRAINING = {"cuda": True, "cpu": True, "mps": False, "default": True}
+
+    # Function definitions
+    BACKEND_EMPTY_CACHE = {"cuda": torch.cuda.empty_cache, "cpu": None, "mps": None, "default": None}
+    BACKEND_DEVICE_COUNT = {"cuda": torch.cuda.device_count, "cpu": lambda: 0, "mps": lambda: 0, "default": 0}
+    BACKEND_MANUAL_SEED = {"cuda": torch.cuda.manual_seed, "cpu": torch.manual_seed, "default": torch.manual_seed}
+
+
+# This dispatches a defined function according to the accelerator from the function definitions.
+def _device_agnostic_dispatch(device: str, dispatch_table: Dict[str, Callable], *args, **kwargs):
+    if device not in dispatch_table:
+        return dispatch_table["default"](*args, **kwargs)
+
+    fn = dispatch_table[device]
+
+    # Some device agnostic functions return values. Need to guard against 'None' instead at
+    # user level
+    if fn is None:
+        return None
+
+    return fn(*args, **kwargs)
+
+
+# These are callables which automatically dispatch the function specific to the accelerator
+def backend_manual_seed(device: str, seed: int):
+    return _device_agnostic_dispatch(device, BACKEND_MANUAL_SEED, seed)
+
+
+def backend_empty_cache(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_EMPTY_CACHE)
+
+
+def backend_device_count(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_DEVICE_COUNT)
+
+
+# These are callables which return boolean behaviour flags and can be used to specify some
+# device agnostic alternative where the feature is unsupported.
+def backend_supports_training(device: str):
+    if not is_torch_available():
+        return False
+
+    if device not in BACKEND_SUPPORTS_TRAINING:
+        device = "default"
+
+    return BACKEND_SUPPORTS_TRAINING[device]
+
+
+# Guard for when Torch is not available
+if is_torch_available():
+    # Update device function dict mapping
+    def update_mapping_from_spec(device_fn_dict: Dict[str, Callable], attribute_name: str):
+        try:
+            # Try to import the function directly
+            spec_fn = getattr(device_spec_module, attribute_name)
+            device_fn_dict[torch_device] = spec_fn
+        except AttributeError as e:
+            # If the function doesn't exist, and there is no default, throw an error
+            if "default" not in device_fn_dict:
+                raise AttributeError(
+                    f"`{attribute_name}` not found in '{device_spec_path}' and no default fallback function found."
+                ) from e
+
+    if "DIFFUSERS_TEST_DEVICE_SPEC" in os.environ:
+        device_spec_path = os.environ["DIFFUSERS_TEST_DEVICE_SPEC"]
+        if not Path(device_spec_path).is_file():
+            raise ValueError(f"Specified path to device specification file is not found. Received {device_spec_path}")
+
+        try:
+            import_name = device_spec_path[: device_spec_path.index(".py")]
+        except ValueError as e:
+            raise ValueError(f"Provided device spec file is not a Python file! Received {device_spec_path}") from e
+
+        device_spec_module = importlib.import_module(import_name)
+
+        try:
+            device_name = device_spec_module.DEVICE_NAME
+        except AttributeError:
+            raise AttributeError("Device spec file did not contain `DEVICE_NAME`")
+
+        if "DIFFUSERS_TEST_DEVICE" in os.environ and torch_device != device_name:
+            msg = f"Mismatch between environment variable `DIFFUSERS_TEST_DEVICE` '{torch_device}' and device found in spec '{device_name}'\n"
+            msg += "Either unset `DIFFUSERS_TEST_DEVICE` or ensure it matches device spec name."
+            raise ValueError(msg)
+
+        torch_device = device_name
+
+        # Add one entry here for each `BACKEND_*` dictionary.
+        update_mapping_from_spec(BACKEND_MANUAL_SEED, "MANUAL_SEED_FN")
+        update_mapping_from_spec(BACKEND_EMPTY_CACHE, "EMPTY_CACHE_FN")
+        update_mapping_from_spec(BACKEND_DEVICE_COUNT, "DEVICE_COUNT_FN")
+        update_mapping_from_spec(BACKEND_SUPPORTS_TRAINING, "SUPPORTS_TRAINING")
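A third-party accelerator therefore only needs to ship a small spec module; a hedged sketch of what one could look like (the module name and `DEVICE_NAME` value are hypothetical — only `DEVICE_NAME` and the `*_FN`/`SUPPORTS_TRAINING` attributes are read by the code above):

```py
# my_device_spec.py — selected via DIFFUSERS_TEST_DEVICE_SPEC=my_device_spec.py
import torch

DEVICE_NAME = "mydevice"  # becomes torch_device for the whole test session

MANUAL_SEED_FN = torch.manual_seed  # hypothetical: seed the custom backend
EMPTY_CACHE_FN = None               # nothing to clear on this backend
DEVICE_COUNT_FN = lambda: 1         # report a single device
SUPPORTS_TRAINING = True            # backend_supports_training() flag
```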
@@ -1,98 +0,0 @@
(File removed from its previous location; its 98 lines are identical to the src/diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py added above — the script was moved, not modified.)
@@ -25,7 +25,11 @@ from diffusers.models.embeddings import get_timestep_embedding
 from diffusers.models.lora import LoRACompatibleLinear
 from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
 from diffusers.models.transformer_2d import Transformer2DModel
-from diffusers.utils.testing_utils import torch_device
+from diffusers.utils.testing_utils import (
+    backend_manual_seed,
+    require_torch_accelerator_with_fp64,
+    torch_device,
+)


 class EmbeddingsTests(unittest.TestCase):
@@ -315,8 +319,7 @@ class ResnetBlock2DTests(unittest.TestCase):
 class Transformer2DModelTests(unittest.TestCase):
     def test_spatial_transformer_default(self):
-        torch.manual_seed(0)
-        if torch.cuda.is_available():
-            torch.cuda.manual_seed_all(0)
+        backend_manual_seed(torch_device, 0)

         sample = torch.randn(1, 32, 64, 64).to(torch_device)
         spatial_transformer_block = Transformer2DModel(

@@ -339,8 +342,7 @@ class Transformer2DModelTests(unittest.TestCase):

     def test_spatial_transformer_cross_attention_dim(self):
-        torch.manual_seed(0)
-        if torch.cuda.is_available():
-            torch.cuda.manual_seed_all(0)
+        backend_manual_seed(torch_device, 0)

         sample = torch.randn(1, 64, 64, 64).to(torch_device)
         spatial_transformer_block = Transformer2DModel(

@@ -363,8 +365,7 @@ class Transformer2DModelTests(unittest.TestCase):

     def test_spatial_transformer_timestep(self):
-        torch.manual_seed(0)
-        if torch.cuda.is_available():
-            torch.cuda.manual_seed_all(0)
+        backend_manual_seed(torch_device, 0)

         num_embeds_ada_norm = 5

@@ -401,8 +402,7 @@ class Transformer2DModelTests(unittest.TestCase):

     def test_spatial_transformer_dropout(self):
-        torch.manual_seed(0)
-        if torch.cuda.is_available():
-            torch.cuda.manual_seed_all(0)
+        backend_manual_seed(torch_device, 0)

         sample = torch.randn(1, 32, 64, 64).to(torch_device)
         spatial_transformer_block = (

@@ -427,11 +427,10 @@ class Transformer2DModelTests(unittest.TestCase):
         )
         assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

-    @unittest.skipIf(torch_device == "mps", "MPS does not support float64")
+    @require_torch_accelerator_with_fp64
     def test_spatial_transformer_discrete(self):
-        torch.manual_seed(0)
-        if torch.cuda.is_available():
-            torch.cuda.manual_seed_all(0)
+        backend_manual_seed(torch_device, 0)

         num_embed = 5
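`backend_manual_seed` funnels these call sites through the `BACKEND_MANUAL_SEED` table, so the one-liner is roughly equivalent to the old two-step seeding on the built-in devices:

```py
from diffusers.utils.testing_utils import backend_manual_seed

backend_manual_seed("cpu", 0)   # dispatches to torch.manual_seed(0)
backend_manual_seed("cuda", 0)  # dispatches to torch.cuda.manual_seed(0)
# Unknown device strings fall back to the "default" entry (torch.manual_seed).
```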
@@ -35,6 +35,7 @@ from diffusers.utils.testing_utils import (
     CaptureLogger,
     require_python39_or_higher,
     require_torch_2,
+    require_torch_accelerator_with_training,
     require_torch_gpu,
     run_test_in_subprocess,
     torch_device,

@@ -536,7 +537,7 @@ class ModelTesterMixin:

         self.assertEqual(output_1.shape, output_2.shape)

-    @unittest.skipIf(torch_device == "mps", "Training is not supported in mps")
+    @require_torch_accelerator_with_training
     def test_training(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

@@ -553,7 +554,7 @@ class ModelTesterMixin:
         loss = torch.nn.functional.mse_loss(output, noise)
         loss.backward()

-    @unittest.skipIf(torch_device == "mps", "Training is not supported in mps")
+    @require_torch_accelerator_with_training
     def test_ema_training(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

@@ -624,7 +625,7 @@ class ModelTesterMixin:

         recursive_check(outputs_tuple, outputs_dict)

-    @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
+    @require_torch_accelerator_with_training
     def test_enable_disable_gradient_checkpointing(self):
         if not self.model_class._supports_gradient_checkpointing:
             return  # Skip test if model does not support gradient checkpointing
@@ -21,7 +21,14 @@ import torch
 from parameterized import parameterized

 from diffusers import PriorTransformer
-from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, slow, torch_all_close, torch_device
+from diffusers.utils.testing_utils import (
+    backend_empty_cache,
+    enable_full_determinism,
+    floats_tensor,
+    slow,
+    torch_all_close,
+    torch_device,
+)

 from .test_modeling_common import ModelTesterMixin

@@ -157,7 +164,7 @@ class PriorTransformerIntegrationTests(unittest.TestCase):
         # clean up the VRAM after each test
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     @parameterized.expand(
         [
@@ -18,7 +18,12 @@ import unittest
 import torch

 from diffusers import UNet1DModel
-from diffusers.utils.testing_utils import floats_tensor, slow, torch_device
+from diffusers.utils.testing_utils import (
+    backend_manual_seed,
+    floats_tensor,
+    slow,
+    torch_device,
+)

 from .test_modeling_common import ModelTesterMixin, UNetTesterMixin

@@ -103,8 +108,7 @@ class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     def test_output_pretrained(self):
         model = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-value-function-hor32", subfolder="unet")
-        torch.manual_seed(0)
-        if torch.cuda.is_available():
-            torch.cuda.manual_seed_all(0)
+        backend_manual_seed(torch_device, 0)

         num_features = model.config.in_channels
         seq_len = 16

@@ -244,8 +248,7 @@ class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
             "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function"
         )
-        torch.manual_seed(0)
-        if torch.cuda.is_available():
-            torch.cuda.manual_seed_all(0)
+        backend_manual_seed(torch_device, 0)

         num_features = value_function.config.in_channels
         seq_len = 14
@@ -24,6 +24,7 @@ from diffusers.utils import logging
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
     floats_tensor,
+    require_torch_accelerator,
     slow,
     torch_all_close,
     torch_device,

@@ -153,7 +154,7 @@ class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):

         assert image is not None, "Make sure output is not None"

-    @unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
+    @require_torch_accelerator
     def test_from_pretrained_accelerate(self):
         model, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
         model.to(torch_device)

@@ -161,7 +162,7 @@ class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):

         assert image is not None, "Make sure output is not None"

-    @unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
+    @require_torch_accelerator
     def test_from_pretrained_accelerate_wont_change_results(self):
         # by defautl model loading will use accelerate as `low_cpu_mem_usage=True`
         model_accelerate, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
@@ -30,10 +30,15 @@ from diffusers.models.embeddings import ImageProjection, Resampler
 from diffusers.utils import logging
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     enable_full_determinism,
     floats_tensor,
     load_hf_numpy,
+    require_torch_accelerator,
+    require_torch_accelerator_with_fp16,
+    require_torch_accelerator_with_training,
     require_torch_gpu,
+    skip_mps,
     slow,
     torch_all_close,
     torch_device,
@@ -280,7 +285,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
             == "XFormersAttnProcessor"
         ), "xformers is not enabled"

-    @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
+    @require_torch_accelerator_with_training
     def test_gradient_checkpointing(self):
         # enable deterministic behavior for gradient checkpointing
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -864,7 +869,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
         # clean up the VRAM after each test
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def get_latents(self, seed=0, shape=(4, 4, 64, 64), fp16=False):
         dtype = torch.float16 if fp16 else torch.float32
@@ -882,6 +887,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):

         return model

+    @require_torch_gpu
     def test_set_attention_slice_auto(self):
         torch.cuda.empty_cache()
         torch.cuda.reset_max_memory_allocated()

@@ -901,6 +907,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):

         assert mem_bytes < 5 * 10**9

+    @require_torch_gpu
     def test_set_attention_slice_max(self):
         torch.cuda.empty_cache()
         torch.cuda.reset_max_memory_allocated()

@@ -920,6 +927,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):

         assert mem_bytes < 5 * 10**9

+    @require_torch_gpu
     def test_set_attention_slice_int(self):
         torch.cuda.empty_cache()
         torch.cuda.reset_max_memory_allocated()

@@ -939,6 +947,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):

         assert mem_bytes < 5 * 10**9

+    @require_torch_gpu
     def test_set_attention_slice_list(self):
         torch.cuda.empty_cache()
         torch.cuda.reset_max_memory_allocated()
@@ -975,7 +984,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
            # fmt: on
        ]
    )
-    @require_torch_gpu
+    @require_torch_accelerator_with_fp16
    def test_compvis_sd_v1_4(self, seed, timestep, expected_slice):
        model = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4")
        latents = self.get_latents(seed)

@@ -1003,7 +1012,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
            # fmt: on
        ]
    )
-    @require_torch_gpu
+    @require_torch_accelerator_with_fp16
    def test_compvis_sd_v1_4_fp16(self, seed, timestep, expected_slice):
        model = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4", fp16=True)
        latents = self.get_latents(seed, fp16=True)

@@ -1031,7 +1040,8 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
            # fmt: on
        ]
    )
-    @require_torch_gpu
+    @require_torch_accelerator
+    @skip_mps
    def test_compvis_sd_v1_5(self, seed, timestep, expected_slice):
        model = self.get_unet_model(model_id="runwayml/stable-diffusion-v1-5")
        latents = self.get_latents(seed)

@@ -1059,7 +1069,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
            # fmt: on
        ]
    )
-    @require_torch_gpu
+    @require_torch_accelerator_with_fp16
    def test_compvis_sd_v1_5_fp16(self, seed, timestep, expected_slice):
        model = self.get_unet_model(model_id="runwayml/stable-diffusion-v1-5", fp16=True)
        latents = self.get_latents(seed, fp16=True)

@@ -1087,7 +1097,8 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
            # fmt: on
        ]
    )
-    @require_torch_gpu
+    @require_torch_accelerator
+    @skip_mps
    def test_compvis_sd_inpaint(self, seed, timestep, expected_slice):
        model = self.get_unet_model(model_id="runwayml/stable-diffusion-inpainting")
        latents = self.get_latents(seed, shape=(4, 9, 64, 64))

@@ -1115,7 +1126,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
            # fmt: on
        ]
    )
-    @require_torch_gpu
+    @require_torch_accelerator_with_fp16
    def test_compvis_sd_inpaint_fp16(self, seed, timestep, expected_slice):
        model = self.get_unet_model(model_id="runwayml/stable-diffusion-inpainting", fp16=True)
        latents = self.get_latents(seed, shape=(4, 9, 64, 64), fp16=True)

@@ -1143,7 +1154,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
            # fmt: on
        ]
    )
-    @require_torch_gpu
+    @require_torch_accelerator_with_fp16
    def test_stabilityai_sd_v2_fp16(self, seed, timestep, expected_slice):
        model = self.get_unet_model(model_id="stabilityai/stable-diffusion-2", fp16=True)
        latents = self.get_latents(seed, shape=(4, 4, 96, 96), fp16=True)
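The half-precision tests above now gate on `require_torch_accelerator_with_fp16` rather than a raw GPU check, while tests that still call CUDA-only memory APIs stay behind `@require_torch_gpu`. A speculative sketch of the fp16 gate; the capability check and the simplified `torch_device` resolution are assumptions, not the library's code:

```py
import unittest

import torch

# Simplified stand-in for the suite's device resolution.
torch_device = "cuda" if torch.cuda.is_available() else "cpu"


def _fp16_available(device: str) -> bool:
    # Assumption: CUDA and MPS run float16 kernels; plain CPU frequently does not.
    return device.startswith("cuda") or device.startswith("mps")


def require_torch_accelerator_with_fp16(test_case):
    return unittest.skipUnless(
        _fp16_available(torch_device), "test requires an accelerator with fp16 support"
    )(test_case)
```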
@@ -31,10 +31,15 @@ from diffusers import (
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.loading_utils import load_image
from diffusers.utils.testing_utils import (
+    backend_empty_cache,
    enable_full_determinism,
    floats_tensor,
    load_hf_numpy,
+    require_torch_accelerator,
+    require_torch_accelerator_with_fp16,
+    require_torch_accelerator_with_training,
    require_torch_gpu,
+    skip_mps,
    slow,
    torch_all_close,
    torch_device,
@@ -157,7 +162,7 @@ class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
    def test_training(self):
        pass

-    @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
+    @require_torch_accelerator_with_training
    def test_gradient_checkpointing(self):
        # enable deterministic behavior for gradient checkpointing
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -213,10 +218,12 @@ class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
        model = model.to(torch_device)
        model.eval()

-        if torch_device == "mps":
-            generator = torch.manual_seed(0)
+        # Keep generator on CPU for non-CUDA devices to compare outputs with CPU result tensors
+        generator_device = "cpu" if not torch_device.startswith("cuda") else "cuda"
+        if torch_device != "mps":
+            generator = torch.Generator(device=generator_device).manual_seed(0)
        else:
-            generator = torch.Generator(device=torch_device).manual_seed(0)
+            generator = torch.manual_seed(0)

        image = torch.randn(
            1,

@@ -247,7 +254,7 @@ class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
                -9.8644e-03,
            ]
        )
-        elif torch_device == "cpu":
+        elif generator_device == "cpu":
            expected_output_slice = torch.tensor(
                [
                    -0.1352,
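The rewritten generator setup pins the RNG to the CPU on every non-CUDA device, which is why the expected-slice branch now switches on `generator_device` instead of `torch_device`: slices precomputed from a CPU-seeded stream stay valid on MPS and other backends. A usage sketch of that rule, with `make_generator` as a hypothetical wrapper around the exact logic from the hunk:

```py
import torch


def make_generator(torch_device: str, seed: int = 0) -> torch.Generator:
    # Only CUDA keeps a device-local generator; every other backend draws
    # from a CPU-seeded stream so results match CPU reference tensors.
    generator_device = "cpu" if not torch_device.startswith("cuda") else "cuda"
    if torch_device != "mps":
        return torch.Generator(device=generator_device).manual_seed(seed)
    return torch.manual_seed(seed)


# The same seed yields the same latents on CPU and MPS, so expected output
# slices computed once on CPU remain valid across those backends.
noise = torch.randn(1, 4, 8, 8, generator=make_generator("cpu", seed=0))
```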
@@ -478,7 +485,7 @@ class AutoencoderTinyIntegrationTests(unittest.TestCase):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

    def get_file_format(self, seed, shape):
        return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
@@ -558,7 +565,7 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

    def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
        dtype = torch.float16 if fp16 else torch.float32

@@ -580,9 +587,10 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
        return model

    def get_generator(self, seed=0):
-        if torch_device == "mps":
-            return torch.manual_seed(seed)
-        return torch.Generator(device=torch_device).manual_seed(seed)
+        generator_device = "cpu" if not torch_device.startswith("cuda") else "cuda"
+        if torch_device != "mps":
+            return torch.Generator(device=generator_device).manual_seed(seed)
+        return torch.manual_seed(seed)

    @parameterized.expand(
        [
@@ -623,7 +631,7 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
            # fmt: on
        ]
    )
-    @require_torch_gpu
+    @require_torch_accelerator_with_fp16
    def test_stable_diffusion_fp16(self, seed, expected_slice):
        model = self.get_sd_vae_model(fp16=True)
        image = self.get_sd_image(seed, fp16=True)

@@ -677,7 +685,8 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
            # fmt: on
        ]
    )
-    @require_torch_gpu
+    @require_torch_accelerator
+    @skip_mps
    def test_stable_diffusion_decode(self, seed, expected_slice):
        model = self.get_sd_vae_model()
        encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))

@@ -700,7 +709,7 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
            # fmt: on
        ]
    )
-    @require_torch_gpu
+    @require_torch_accelerator_with_fp16
    def test_stable_diffusion_decode_fp16(self, seed, expected_slice):
        model = self.get_sd_vae_model(fp16=True)
        encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64), fp16=True)
@@ -811,7 +820,7 @@ class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

    def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
        dtype = torch.float16 if fp16 else torch.float32

@@ -832,9 +841,10 @@ class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
        return model

    def get_generator(self, seed=0):
-        if torch_device == "mps":
-            return torch.manual_seed(seed)
-        return torch.Generator(device=torch_device).manual_seed(seed)
+        generator_device = "cpu" if not torch_device.startswith("cuda") else "cuda"
+        if torch_device != "mps":
+            return torch.Generator(device=generator_device).manual_seed(seed)
+        return torch.manual_seed(seed)

    @parameterized.expand(
        [
@@ -905,7 +915,8 @@ class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
            # fmt: on
        ]
    )
-    @require_torch_gpu
+    @require_torch_accelerator
+    @skip_mps
    def test_stable_diffusion_decode(self, seed, expected_slice):
        model = self.get_sd_vae_model()
        encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))
@@ -18,7 +18,12 @@ import unittest

import torch

from diffusers import VQModel
-from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, torch_device
+from diffusers.utils.testing_utils import (
+    backend_manual_seed,
+    enable_full_determinism,
+    floats_tensor,
+    torch_device,
+)

from .test_modeling_common import ModelTesterMixin, UNetTesterMixin
@@ -80,8 +85,7 @@ class VQModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
        model.to(torch_device).eval()

-        torch.manual_seed(0)
-        if torch.cuda.is_available():
-            torch.cuda.manual_seed_all(0)
+        backend_manual_seed(torch_device, 0)

        image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
        image = image.to(torch_device)
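`backend_manual_seed(torch_device, 0)` collapses the CPU-plus-CUDA seeding dance into a single backend-aware call. A sketch of plausible semantics, with the MPS branch as an assumption (the real helper lives in `diffusers.utils.testing_utils`):

```py
import torch


def backend_manual_seed(device: str, seed: int) -> None:
    torch.manual_seed(seed)  # always seed the default (CPU) RNG
    if device.startswith("cuda"):
        torch.cuda.manual_seed_all(seed)  # cover every visible CUDA device
    elif device.startswith("mps"):
        torch.mps.manual_seed(seed)
```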
@@ -12,12 +12,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from typing import Tuple

import torch

-from diffusers.utils.testing_utils import floats_tensor, require_torch, torch_all_close, torch_device
+from diffusers.utils.testing_utils import (
+    floats_tensor,
+    require_torch,
+    require_torch_accelerator_with_training,
+    torch_all_close,
+    torch_device,
+)
from diffusers.utils.torch_utils import randn_tensor
@@ -104,7 +109,7 @@ class UNetBlockTesterMixin:
        expected_slice = torch.tensor(expected_slice).to(torch_device)
        assert torch_all_close(output_slice.flatten(), expected_slice, atol=5e-3)

-    @unittest.skipIf(torch_device == "mps", "Training is not supported in mps")
+    @require_torch_accelerator_with_training
    def test_training(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.block_class(**init_dict)
@@ -34,11 +34,14 @@ from diffusers import (
)
from diffusers.utils.testing_utils import (
    CaptureLogger,
+    backend_empty_cache,
    enable_full_determinism,
    load_numpy,
    nightly,
    numpy_cosine_similarity_distance,
+    require_torch_accelerator,
    require_torch_gpu,
+    skip_mps,
    slow,
    torch_device,
)
@@ -128,10 +131,12 @@ class StableDiffusion2PipelineFastTests(
        return components

    def get_dummy_inputs(self, device, seed=0):
-        if str(device).startswith("mps"):
-            generator = torch.manual_seed(seed)
+        generator_device = "cpu" if not device.startswith("cuda") else "cuda"
+        if not str(device).startswith("mps"):
+            generator = torch.Generator(device=generator_device).manual_seed(seed)
        else:
-            generator = torch.Generator(device=device).manual_seed(seed)
+            generator = torch.manual_seed(seed)

        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "generator": generator,
@@ -299,15 +304,21 @@ class StableDiffusion2PipelineFastTests(


@slow
-@require_torch_gpu
+@require_torch_accelerator
+@skip_mps
class StableDiffusion2PipelineSlowTests(unittest.TestCase):
    def tearDown(self):
        super().tearDown()
        gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

    def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
-        generator = torch.Generator(device=generator_device).manual_seed(seed)
+        _generator_device = "cpu" if not generator_device.startswith("cuda") else "cuda"
+        if not str(device).startswith("mps"):
+            generator = torch.Generator(device=_generator_device).manual_seed(seed)
+        else:
+            generator = torch.manual_seed(seed)

        latents = np.random.RandomState(seed).standard_normal((1, 4, 64, 64))
        latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
        inputs = {
@@ -361,6 +372,7 @@ class StableDiffusion2PipelineSlowTests(unittest.TestCase):
        expected_slice = np.array([0.10440, 0.13115, 0.11100, 0.10141, 0.11440, 0.07215, 0.11332, 0.09693, 0.10006])
        assert np.abs(image_slice - expected_slice).max() < 3e-3

+    @require_torch_gpu
    def test_stable_diffusion_attention_slicing(self):
        torch.cuda.reset_peak_memory_stats()
        pipe = StableDiffusionPipeline.from_pretrained(

@@ -432,6 +444,7 @@ class StableDiffusion2PipelineSlowTests(unittest.TestCase):
        assert callback_fn.has_been_called
        assert number_of_steps == inputs["num_inference_steps"]

+    @require_torch_gpu
    def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_allocated()

@@ -452,6 +465,7 @@ class StableDiffusion2PipelineSlowTests(unittest.TestCase):
        # make sure that less than 2.8 GB is allocated
        assert mem_bytes < 2.8 * 10**9

+    @require_torch_gpu
    def test_stable_diffusion_pipeline_with_model_offloading(self):
        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_allocated()
@@ -511,15 +525,21 @@ class StableDiffusion2PipelineSlowTests(unittest.TestCase):


@nightly
-@require_torch_gpu
+@require_torch_accelerator
+@skip_mps
class StableDiffusion2PipelineNightlyTests(unittest.TestCase):
    def tearDown(self):
        super().tearDown()
        gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

    def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
-        generator = torch.Generator(device=generator_device).manual_seed(seed)
+        _generator_device = "cpu" if not generator_device.startswith("cuda") else "cuda"
+        if not str(device).startswith("mps"):
+            generator = torch.Generator(device=_generator_device).manual_seed(seed)
+        else:
+            generator = torch.manual_seed(seed)

        latents = np.random.RandomState(seed).standard_normal((1, 4, 64, 64))
        latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
        inputs = {
@@ -938,6 +938,37 @@ class StableDiffusionXLPipelineFastTests(

        assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3

+    def test_stable_diffusion_xl_with_fused_qkv_projections(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionXLPipeline(**components)
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        image = sd_pipe(**inputs).images
+        original_image_slice = image[0, -3:, -3:, -1]
+
+        sd_pipe.fuse_qkv_projections()
+        inputs = self.get_dummy_inputs(device)
+        image = sd_pipe(**inputs).images
+        image_slice_fused = image[0, -3:, -3:, -1]
+
+        sd_pipe.unfuse_qkv_projections()
+        inputs = self.get_dummy_inputs(device)
+        image = sd_pipe(**inputs).images
+        image_slice_disabled = image[0, -3:, -3:, -1]
+
+        assert np.allclose(
+            original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2
+        ), "Fusion of QKV projections shouldn't affect the outputs."
+        assert np.allclose(
+            image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2
+        ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+        assert np.allclose(
+            original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
+        ), "Original outputs should match when fused QKV projections are disabled."


@slow
class StableDiffusionXLPipelineIntegrationTests(unittest.TestCase):
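The new test above checks that fusing the attention QKV projections is output-neutral and fully reversible. The same `fuse_qkv_projections`/`unfuse_qkv_projections` API is public on the pipeline, so it can also be toggled at inference time; a usage sketch, with the model id and prompt being illustrative:

```py
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

pipe.fuse_qkv_projections()  # fold the q/k/v linears into one projection per attention block
image = pipe("an astronaut riding a horse", num_inference_steps=25).images[0]
pipe.unfuse_qkv_projections()  # restore the original, unfused projections
```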