Compare commits

...

8 Commits

Author SHA1 Message Date
Dhruv Nair
cfa7c0a93d Release: v0.27.0 2024-03-14 15:32:01 +00:00
Dhruv Nair
4974b84564 Update Cascade Tests (#7324)
* update

* update

* update
2024-03-14 20:51:22 +05:30
Linoy Tsaban
83062fb872 [Advanced DreamBooth LoRA SDXL] Support EDM-style training (follow up of #7126) (#7182)
* add edm style training

* style

* finish adding edm training feature

* import fix

* fix latents mean

* minor adjustments

* add edm to readme

* style

* fix autocast and scheduler config issues when using edm

* style

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-03-14 18:40:14 +05:30
Suraj Patil
b6d7e31d10 add edm schedulers in doc (#7319)
* add edm schedulers in doc

* add in toctree

* address reviewe comments
2024-03-14 11:52:25 +01:00
Anatoly Belikov
53e9aacc10 log loss per image (#7278)
* log loss per image

* add commandline param for per image loss logging

* style

* debug-loss -> debug_loss

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-03-14 11:41:43 +05:30
Dhruv Nair
41424466e3 [Tests] Fix incorrect constant in VAE scaling test. (#7301)
update
2024-03-14 10:24:01 +05:30
Sayak Paul
95de1981c9 add: pytest log installation (#7313) 2024-03-14 10:01:16 +05:30
Kenneth Gerald Hamilton
0b45b58867 update get_order_list if statement (#7309)
* update get_order_list if statement

* revery
2024-03-13 18:29:42 -10:00
46 changed files with 344 additions and 105 deletions

View File

@@ -65,6 +65,7 @@ jobs:
python -m uv pip install -e [quality,test]
python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers
python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate
python -m uv pip install pytest-reportlog
- name: Environment
run: |
@@ -150,6 +151,7 @@ jobs:
${CONDA_RUN} python -m uv pip install -e [quality,test]
${CONDA_RUN} python -m uv pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
${CONDA_RUN} python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate
${CONDA_RUN} python -m uv pip install pytest-reportlog
- name: Environment
shell: arch -arch arm64 bash {0}

View File

@@ -404,6 +404,10 @@
title: EulerAncestralDiscreteScheduler
- local: api/schedulers/euler
title: EulerDiscreteScheduler
- local: api/schedulers/edm_euler
title: EDMEulerScheduler
- local: api/schedulers/edm_multistep_dpm_solver
title: EDMDPMSolverMultistepScheduler
- local: api/schedulers/heun
title: HeunDiscreteScheduler
- local: api/schedulers/ipndm

View File

@@ -0,0 +1,22 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# EDMEulerScheduler
The Karras formulation of the Euler scheduler (Algorithm 2) from the [Elucidating the Design Space of Diffusion-Based Generative Models](https://huggingface.co/papers/2206.00364) paper by Karras et al. This is a fast scheduler which can often generate good outputs in 20-30 steps. The scheduler is based on the original [k-diffusion](https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L51) implementation by [Katherine Crowson](https://github.com/crowsonkb/).
## EDMEulerScheduler
[[autodoc]] EDMEulerScheduler
## EDMEulerSchedulerOutput
[[autodoc]] schedulers.scheduling_edm_euler.EDMEulerSchedulerOutput

View File

@@ -0,0 +1,24 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# EDMDPMSolverMultistepScheduler
`EDMDPMSolverMultistepScheduler` is a [Karras formulation](https://huggingface.co/papers/2206.00364) of `DPMSolverMultistep`, a multistep scheduler from [DPM-Solver: A Fast ODE Solver for Diffusion Probabilistic Model Sampling in Around 10 Steps](https://huggingface.co/papers/2206.00927) and [DPM-Solver++: Fast Solver for Guided Sampling of Diffusion Probabilistic Models](https://huggingface.co/papers/2211.01095) by Cheng Lu, Yuhao Zhou, Fan Bao, Jianfei Chen, Chongxuan Li, and Jun Zhu.
DPMSolver (and the improved version DPMSolver++) is a fast dedicated high-order solver for diffusion ODEs with convergence order guarantee. Empirically, DPMSolver sampling with only 20 steps can generate high-quality
samples, and it can generate quite good samples even in 10 steps.
## EDMDPMSolverMultistepScheduler
[[autodoc]] EDMDPMSolverMultistepScheduler
## SchedulerOutput
[[autodoc]] schedulers.scheduling_utils.SchedulerOutput

View File

@@ -259,6 +259,50 @@ pip install git+https://github.com/huggingface/peft.git
**Inference**
The inference is the same as if you train a regular LoRA 🤗
## Conducting EDM-style training
It's now possible to perform EDM-style training as proposed in [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364).
simply set:
```diff
+ --do_edm_style_training \
```
Other SDXL-like models that use the EDM formulation, such as [playgroundai/playground-v2.5-1024px-aesthetic](https://huggingface.co/playgroundai/playground-v2.5-1024px-aesthetic), can also be DreamBooth'd with the script. Below is an example command:
```bash
accelerate launch train_dreambooth_lora_sdxl_advanced.py \
--pretrained_model_name_or_path="playgroundai/playground-v2.5-1024px-aesthetic" \
--dataset_name="linoyts/3d_icon" \
--instance_prompt="3d icon in the style of TOK" \
--validation_prompt="a TOK icon of an astronaut riding a horse, in the style of TOK" \
--output_dir="3d-icon-SDXL-LoRA" \
--do_edm_style_training \
--caption_column="prompt" \
--mixed_precision="bf16" \
--resolution=1024 \
--train_batch_size=3 \
--repeats=1 \
--report_to="wandb"\
--gradient_accumulation_steps=1 \
--gradient_checkpointing \
--learning_rate=1.0 \
--text_encoder_lr=1.0 \
--optimizer="prodigy"\
--train_text_encoder_ti\
--train_text_encoder_ti_frac=0.5\
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--rank=8 \
--max_train_steps=1000 \
--checkpointing_steps=2000 \
--seed="0" \
--push_to_hub
```
> [!CAUTION]
> Min-SNR gamma is not supported with the EDM-style training yet. When training with the PlaygroundAI model, it's recommended to not pass any "variant".
### Tips and Tricks
Check out [these recommended practices](https://huggingface.co/blog/sdxl_lora_advanced_script#additional-good-practices)

View File

@@ -70,7 +70,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.27.0")
logger = get_logger(__name__)

View File

@@ -14,9 +14,11 @@
# See the License for the specific language governing permissions and
import argparse
import contextlib
import gc
import hashlib
import itertools
import json
import logging
import math
import os
@@ -37,7 +39,7 @@ import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
from huggingface_hub import create_repo, hf_hub_download, upload_folder
from packaging import version
from peft import LoraConfig, set_peft_model_state_dict
from peft.utils import get_peft_model_state_dict
@@ -55,6 +57,8 @@ from diffusers import (
AutoencoderKL,
DDPMScheduler,
DPMSolverMultistepScheduler,
EDMEulerScheduler,
EulerDiscreteScheduler,
StableDiffusionXLPipeline,
UNet2DConditionModel,
)
@@ -74,11 +78,25 @@ from diffusers.utils.torch_utils import is_compiled_module
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.27.0")
logger = get_logger(__name__)
def determine_scheduler_type(pretrained_model_name_or_path, revision):
model_index_filename = "model_index.json"
if os.path.isdir(pretrained_model_name_or_path):
model_index = os.path.join(pretrained_model_name_or_path, model_index_filename)
else:
model_index = hf_hub_download(
repo_id=pretrained_model_name_or_path, filename=model_index_filename, revision=revision
)
with open(model_index, "r") as f:
scheduler_type = json.load(f)["scheduler"][1]
return scheduler_type
def save_model_card(
repo_id: str,
use_dora: bool,
@@ -370,6 +388,11 @@ def parse_args(input_args=None):
" `args.validation_prompt` multiple times: `args.num_validation_images`."
),
)
parser.add_argument(
"--do_edm_style_training",
action="store_true",
help="Flag to conduct training using the EDM formulation as introduced in https://arxiv.org/abs/2206.00364.",
)
parser.add_argument(
"--with_prior_preservation",
default=False,
@@ -1117,6 +1140,8 @@ def main(args):
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
" Please use `huggingface-cli login` to authenticate with the Hub."
)
if args.do_edm_style_training and args.snr_gamma is not None:
raise ValueError("Min-SNR formulation is not supported when conducting EDM-style training.")
logging_dir = Path(args.output_dir, args.logging_dir)
@@ -1234,7 +1259,19 @@ def main(args):
)
# Load scheduler and models
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
scheduler_type = determine_scheduler_type(args.pretrained_model_name_or_path, args.revision)
if "EDM" in scheduler_type:
args.do_edm_style_training = True
noise_scheduler = EDMEulerScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
logger.info("Performing EDM-style training!")
elif args.do_edm_style_training:
noise_scheduler = EulerDiscreteScheduler.from_pretrained(
args.pretrained_model_name_or_path, subfolder="scheduler"
)
logger.info("Performing EDM-style training!")
else:
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder_one = text_encoder_cls_one.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
)
@@ -1252,7 +1289,12 @@ def main(args):
revision=args.revision,
variant=args.variant,
)
vae_scaling_factor = vae.config.scaling_factor
latents_mean = latents_std = None
if hasattr(vae.config, "latents_mean") and vae.config.latents_mean is not None:
latents_mean = torch.tensor(vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(vae.config, "latents_std") and vae.config.latents_std is not None:
latents_std = torch.tensor(vae.config.latents_std).view(1, 4, 1, 1)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
)
@@ -1790,6 +1832,19 @@ def main(args):
disable=not accelerator.is_local_main_process,
)
def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
# TODO: revisit other sampling algorithms
sigmas = noise_scheduler.sigmas.to(device=accelerator.device, dtype=dtype)
schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device)
timesteps = timesteps.to(accelerator.device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < n_dim:
sigma = sigma.unsqueeze(-1)
return sigma
if args.train_text_encoder:
num_train_epochs_text_encoder = int(args.train_text_encoder_frac * args.num_train_epochs)
elif args.train_text_encoder_ti: # args.train_text_encoder_ti
@@ -1841,9 +1896,15 @@ def main(args):
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
model_input = vae.encode(pixel_values).latent_dist.sample()
model_input = model_input * vae_scaling_factor
if args.pretrained_vae_model_name_or_path is None:
model_input = model_input.to(weight_dtype)
if latents_mean is None and latents_std is None:
model_input = model_input * vae.config.scaling_factor
if args.pretrained_vae_model_name_or_path is None:
model_input = model_input.to(weight_dtype)
else:
latents_mean = latents_mean.to(device=model_input.device, dtype=model_input.dtype)
latents_std = latents_std.to(device=model_input.device, dtype=model_input.dtype)
model_input = (model_input - latents_mean) * vae.config.scaling_factor / latents_std
model_input = model_input.to(dtype=weight_dtype)
# Sample noise that we'll add to the latents
noise = torch.randn_like(model_input)
@@ -1854,15 +1915,32 @@ def main(args):
)
bsz = model_input.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
)
timesteps = timesteps.long()
if not args.do_edm_style_training:
timesteps = torch.randint(
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
)
timesteps = timesteps.long()
else:
# in EDM formulation, the model is conditioned on the pre-conditioned noise levels
# instead of discrete timesteps, so here we sample indices to get the noise levels
# from `scheduler.timesteps`
indices = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,))
timesteps = noise_scheduler.timesteps[indices].to(device=model_input.device)
# Add noise to the model input according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
# For EDM-style training, we first obtain the sigmas based on the continuous timesteps.
# We then precondition the final model inputs based on these sigmas instead of the timesteps.
# Follow: Section 5 of https://arxiv.org/abs/2206.00364.
if args.do_edm_style_training:
sigmas = get_sigmas(timesteps, len(noisy_model_input.shape), noisy_model_input.dtype)
if "EDM" in scheduler_type:
inp_noisy_latents = noise_scheduler.precondition_inputs(noisy_model_input, sigmas)
else:
inp_noisy_latents = noisy_model_input / ((sigmas**2 + 1) ** 0.5)
# time ids
add_time_ids = torch.cat(
@@ -1888,7 +1966,7 @@ def main(args):
}
prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
model_pred = unet(
noisy_model_input,
inp_noisy_latents if args.do_edm_style_training else noisy_model_input,
timesteps,
prompt_embeds_input,
added_cond_kwargs=unet_added_conditions,
@@ -1906,14 +1984,42 @@ def main(args):
)
prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
model_pred = unet(
noisy_model_input, timesteps, prompt_embeds_input, added_cond_kwargs=unet_added_conditions
inp_noisy_latents if args.do_edm_style_training else noisy_model_input,
timesteps,
prompt_embeds_input,
added_cond_kwargs=unet_added_conditions,
).sample
weighting = None
if args.do_edm_style_training:
# Similar to the input preconditioning, the model predictions are also preconditioned
# on noised model inputs (before preconditioning) and the sigmas.
# Follow: Section 5 of https://arxiv.org/abs/2206.00364.
if "EDM" in scheduler_type:
model_pred = noise_scheduler.precondition_outputs(noisy_model_input, model_pred, sigmas)
else:
if noise_scheduler.config.prediction_type == "epsilon":
model_pred = model_pred * (-sigmas) + noisy_model_input
elif noise_scheduler.config.prediction_type == "v_prediction":
model_pred = model_pred * (-sigmas / (sigmas**2 + 1) ** 0.5) + (
noisy_model_input / (sigmas**2 + 1)
)
# We are not doing weighting here because it tends result in numerical problems.
# See: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
# There might be other alternatives for weighting as well:
# https://github.com/huggingface/diffusers/pull/7126#discussion_r1505404686
if "EDM" not in scheduler_type:
weighting = (sigmas**-2.0).float()
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
target = model_input if args.do_edm_style_training else noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(model_input, noise, timesteps)
target = (
model_input
if args.do_edm_style_training
else noise_scheduler.get_velocity(model_input, noise, timesteps)
)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
@@ -1923,10 +2029,28 @@ def main(args):
target, target_prior = torch.chunk(target, 2, dim=0)
# Compute prior loss
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
if weighting is not None:
prior_loss = torch.mean(
(weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(
target_prior.shape[0], -1
),
1,
)
prior_loss = prior_loss.mean()
else:
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
if args.snr_gamma is None:
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
if weighting is not None:
loss = torch.mean(
(weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(
target.shape[0], -1
),
1,
)
loss = loss.mean()
else:
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
else:
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
@@ -2049,17 +2173,18 @@ def main(args):
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
scheduler_args = {}
if "variance_type" in pipeline.scheduler.config:
variance_type = pipeline.scheduler.config.variance_type
if not args.do_edm_style_training:
if "variance_type" in pipeline.scheduler.config:
variance_type = pipeline.scheduler.config.variance_type
if variance_type in ["learned", "learned_range"]:
variance_type = "fixed_small"
if variance_type in ["learned", "learned_range"]:
variance_type = "fixed_small"
scheduler_args["variance_type"] = variance_type
scheduler_args["variance_type"] = variance_type
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
pipeline.scheduler.config, **scheduler_args
)
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
pipeline.scheduler.config, **scheduler_args
)
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)
@@ -2067,8 +2192,13 @@ def main(args):
# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
pipeline_args = {"prompt": args.validation_prompt}
inference_ctx = (
contextlib.nullcontext()
if "playground" in args.pretrained_model_name_or_path
else torch.cuda.amp.autocast()
)
with torch.cuda.amp.autocast():
with inference_ctx:
images = [
pipeline(**pipeline_args, generator=generator).images[0]
for _ in range(args.num_validation_images)
@@ -2144,15 +2274,18 @@ def main(args):
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
scheduler_args = {}
if "variance_type" in pipeline.scheduler.config:
variance_type = pipeline.scheduler.config.variance_type
if not args.do_edm_style_training:
if "variance_type" in pipeline.scheduler.config:
variance_type = pipeline.scheduler.config.variance_type
if variance_type in ["learned", "learned_range"]:
variance_type = "fixed_small"
if variance_type in ["learned", "learned_range"]:
variance_type = "fixed_small"
scheduler_args["variance_type"] = variance_type
scheduler_args["variance_type"] = variance_type
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
pipeline.scheduler.config, **scheduler_args
)
# load attention processors
pipeline.load_lora_weights(args.output_dir)

View File

@@ -40,8 +40,7 @@ from diffusers.utils import BaseOutput, check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.27.0")
class MarigoldDepthOutput(BaseOutput):
"""

View File

@@ -72,7 +72,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.27.0")
logger = get_logger(__name__)

View File

@@ -65,7 +65,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.27.0")
logger = get_logger(__name__)

View File

@@ -78,7 +78,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.27.0")
logger = get_logger(__name__)

View File

@@ -71,7 +71,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.27.0")
logger = get_logger(__name__)

View File

@@ -77,7 +77,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.27.0")
logger = get_logger(__name__)

View File

@@ -60,7 +60,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.27.0")
logger = get_logger(__name__)

View File

@@ -60,7 +60,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.27.0")
logger = logging.getLogger(__name__)

View File

@@ -61,7 +61,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.27.0")
logger = get_logger(__name__)

View File

@@ -63,7 +63,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.27.0")
logger = get_logger(__name__)

View File

@@ -63,7 +63,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.27.0")
logger = get_logger(__name__)

View File

@@ -35,7 +35,7 @@ from diffusers.utils import check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.27.0")
# Cache compiled models across invocations of this script.
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))

View File

@@ -70,7 +70,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.27.0")
logger = get_logger(__name__)

View File

@@ -75,7 +75,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.27.0")
logger = get_logger(__name__)

View File

@@ -53,7 +53,7 @@ from diffusers.utils.torch_utils import is_compiled_module
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.27.0")
logger = get_logger(__name__, log_level="INFO")

View File

@@ -59,7 +59,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.27.0")
logger = get_logger(__name__, log_level="INFO")

View File

@@ -52,7 +52,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.27.0")
logger = get_logger(__name__, log_level="INFO")

View File

@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.27.0")
logger = get_logger(__name__, log_level="INFO")

View File

@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.27.0")
logger = get_logger(__name__, log_level="INFO")

View File

@@ -51,7 +51,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.27.0")
logger = get_logger(__name__, log_level="INFO")

View File

@@ -60,7 +60,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.27.0")
logger = get_logger(__name__)

View File

@@ -56,7 +56,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.27.0")
logger = get_logger(__name__, log_level="INFO")

View File

@@ -49,7 +49,7 @@ from diffusers.utils import check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.27.0")
logger = logging.getLogger(__name__)

View File

@@ -52,7 +52,7 @@ from diffusers.utils.torch_utils import is_compiled_module
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.27.0")
logger = get_logger(__name__, log_level="INFO")

View File

@@ -64,7 +64,7 @@ from diffusers.utils.torch_utils import is_compiled_module
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.27.0")
logger = get_logger(__name__)
@@ -425,6 +425,11 @@ def parse_args(input_args=None):
default=4,
help=("The dimension of the LoRA update matrices."),
)
parser.add_argument(
"--debug_loss",
action="store_true",
help="debug loss for each image, if filenames are awailable in the dataset",
)
if input_args is not None:
args = parser.parse_args(input_args)
@@ -603,6 +608,7 @@ def main(args):
# Move unet, vae and text_encoder to device and cast to weight_dtype
# The VAE is in float32 to avoid NaN losses.
unet.to(accelerator.device, dtype=weight_dtype)
if args.pretrained_vae_model_name_or_path is None:
vae.to(accelerator.device, dtype=torch.float32)
else:
@@ -890,13 +896,17 @@ def main(args):
tokens_one, tokens_two = tokenize_captions(examples)
examples["input_ids_one"] = tokens_one
examples["input_ids_two"] = tokens_two
if args.debug_loss:
fnames = [os.path.basename(image.filename) for image in examples[image_column] if image.filename]
if fnames:
examples["filenames"] = fnames
return examples
with accelerator.main_process_first():
if args.max_train_samples is not None:
dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
# Set the training transforms
train_dataset = dataset["train"].with_transform(preprocess_train)
train_dataset = dataset["train"].with_transform(preprocess_train, output_all_columns=True)
def collate_fn(examples):
pixel_values = torch.stack([example["pixel_values"] for example in examples])
@@ -905,7 +915,7 @@ def main(args):
crop_top_lefts = [example["crop_top_lefts"] for example in examples]
input_ids_one = torch.stack([example["input_ids_one"] for example in examples])
input_ids_two = torch.stack([example["input_ids_two"] for example in examples])
return {
result = {
"pixel_values": pixel_values,
"input_ids_one": input_ids_one,
"input_ids_two": input_ids_two,
@@ -913,6 +923,11 @@ def main(args):
"crop_top_lefts": crop_top_lefts,
}
filenames = [example["filenames"] for example in examples if "filenames" in example]
if filenames:
result["filenames"] = filenames
return result
# DataLoaders creation:
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
@@ -1105,7 +1120,9 @@ def main(args):
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
loss = loss.mean()
if args.debug_loss and "filenames" in batch:
for fname in batch["filenames"]:
accelerator.log({"loss_for_" + fname: loss}, step=global_step)
# Gather the losses across all processes for logging (if we use distributed training).
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
train_loss += avg_loss.item() / args.gradient_accumulation_steps

View File

@@ -54,7 +54,7 @@ from diffusers.utils.torch_utils import is_compiled_module
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.27.0")
logger = get_logger(__name__)

View File

@@ -80,7 +80,7 @@ else:
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.27.0")
logger = get_logger(__name__)

View File

@@ -56,7 +56,7 @@ else:
# ------------------------------------------------------------------------------
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.27.0")
logger = logging.getLogger(__name__)

View File

@@ -76,7 +76,7 @@ else:
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.27.0")
logger = get_logger(__name__)

View File

@@ -29,7 +29,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.27.0")
logger = get_logger(__name__, log_level="INFO")

View File

@@ -50,7 +50,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.27.0")
logger = get_logger(__name__, log_level="INFO")

View File

@@ -51,7 +51,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
check_min_version("0.27.0")
logger = get_logger(__name__, log_level="INFO")

View File

@@ -249,7 +249,7 @@ version_range_max = max(sys.version_info[1], 10) + 1
setup(
name="diffusers",
version="0.27.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="0.27.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
description="State-of-the-art diffusion in PyTorch and JAX.",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",

View File

@@ -1,4 +1,4 @@
__version__ = "0.27.0.dev0"
__version__ = "0.27.0"
from typing import TYPE_CHECKING

View File

@@ -223,6 +223,8 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
"""
steps = num_inference_steps
order = self.config.solver_order
if order > 3:
raise ValueError("Order > 3 is not supported by this scheduler")
if self.config.lower_order_final:
if order == 3:
if steps % 3 == 0:

View File

@@ -829,7 +829,7 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
"https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors",
)
assert vae_default.config.scaling_factor == 0.18125
assert vae_default.config.scaling_factor == 0.18215
assert vae_default.config.sample_size == 512
assert vae_default.dtype == torch.float32

View File

@@ -50,9 +50,7 @@ class StableCascadeUNetModelSlowTests(unittest.TestCase):
gc.collect()
torch.cuda.empty_cache()
unet = StableCascadeUNet.from_pretrained(
"stabilityai/stable-cascade-prior", subfolder="prior", revision="refs/pr/2", variant="bf16"
)
unet = StableCascadeUNet.from_pretrained("stabilityai/stable-cascade-prior", subfolder="prior", variant="bf16")
unet_config = unet.config
del unet
gc.collect()
@@ -74,9 +72,7 @@ class StableCascadeUNetModelSlowTests(unittest.TestCase):
gc.collect()
torch.cuda.empty_cache()
unet = StableCascadeUNet.from_pretrained(
"stabilityai/stable-cascade", subfolder="decoder", revision="refs/pr/44", variant="bf16"
)
unet = StableCascadeUNet.from_pretrained("stabilityai/stable-cascade", subfolder="decoder", variant="bf16")
unet_config = unet.config
del unet
gc.collect()

View File

@@ -21,13 +21,13 @@ import torch
from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
from diffusers import DDPMWuerstchenScheduler, StableCascadeDecoderPipeline
from diffusers.image_processor import VaeImageProcessor
from diffusers.models import StableCascadeUNet
from diffusers.pipelines.wuerstchen import PaellaVQModel
from diffusers.utils.testing_utils import (
enable_full_determinism,
load_image,
load_numpy,
load_pt,
numpy_cosine_similarity_distance,
require_torch_gpu,
skip_mps,
slow,
@@ -258,7 +258,7 @@ class StableCascadeDecoderPipelineIntegrationTests(unittest.TestCase):
def test_stable_cascade_decoder(self):
pipe = StableCascadeDecoderPipeline.from_pretrained(
"diffusers/StableCascade-decoder", torch_dtype=torch.bfloat16
"stabilityai/stable-cascade", variant="bf16", torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()
pipe.set_progress_bar_config(disable=None)
@@ -271,18 +271,16 @@ class StableCascadeDecoderPipelineIntegrationTests(unittest.TestCase):
)
image = pipe(
prompt=prompt, image_embeddings=image_embedding, num_inference_steps=10, generator=generator
prompt=prompt,
image_embeddings=image_embedding,
output_type="np",
num_inference_steps=2,
generator=generator,
).images[0]
assert image.size == (1024, 1024)
expected_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_cascade/t2i.png"
assert image.shape == (1024, 1024, 3)
expected_image = load_numpy(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_cascade/stable_cascade_decoder_image.npy"
)
image_processor = VaeImageProcessor()
image_np = image_processor.pil_to_numpy(image)
expected_image_np = image_processor.pil_to_numpy(expected_image)
self.assertTrue(np.allclose(image_np, expected_image_np, atol=53e-2))
max_diff = numpy_cosine_similarity_distance(image.flatten(), expected_image.flatten())
assert max_diff < 1e-4

View File

@@ -29,7 +29,8 @@ from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProc
from diffusers.utils.import_utils import is_peft_available
from diffusers.utils.testing_utils import (
enable_full_determinism,
load_pt,
load_numpy,
numpy_cosine_similarity_distance,
require_peft_backend,
require_torch_gpu,
skip_mps,
@@ -319,7 +320,9 @@ class StableCascadePriorPipelineIntegrationTests(unittest.TestCase):
torch.cuda.empty_cache()
def test_stable_cascade_prior(self):
pipe = StableCascadePriorPipeline.from_pretrained("diffusers/StableCascade-prior", torch_dtype=torch.bfloat16)
pipe = StableCascadePriorPipeline.from_pretrained(
"stabilityai/stable-cascade-prior", variant="bf16", torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()
pipe.set_progress_bar_config(disable=None)
@@ -327,17 +330,12 @@ class StableCascadePriorPipelineIntegrationTests(unittest.TestCase):
generator = torch.Generator(device="cpu").manual_seed(0)
output = pipe(prompt, num_inference_steps=10, generator=generator)
output = pipe(prompt, num_inference_steps=2, output_type="np", generator=generator)
image_embedding = output.image_embeddings
expected_image_embedding = load_pt(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_cascade/image_embedding.pt"
expected_image_embedding = load_numpy(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_cascade/stable_cascade_prior_image_embeddings.npy"
)
assert image_embedding.shape == (1, 16, 24, 24)
self.assertTrue(
np.allclose(
image_embedding.cpu().float().numpy(), expected_image_embedding.cpu().float().numpy(), atol=5e-2
)
)
max_diff = numpy_cosine_similarity_distance(image_embedding.flatten(), expected_image_embedding.flatten())
assert max_diff < 1e-4