mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-25 13:54:45 +08:00
Compare commits
14 Commits
test-runne
...
single-fil
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3c699800cd | ||
|
|
04bafcbbc2 | ||
|
|
738df86a7b | ||
|
|
158ed3f28a | ||
|
|
7081a25618 | ||
|
|
848f9fe6ce | ||
|
|
8a692739c0 | ||
|
|
5aa31bd674 | ||
|
|
88aa7f6ebf | ||
|
|
ad310af0d6 | ||
|
|
fc1dbf5dd9 | ||
|
|
36e8fbc2cc | ||
|
|
d603ccb614 | ||
|
|
fd0f469568 |
38
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
38
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
@@ -66,32 +66,32 @@ body:
|
||||
Questions on DiffusionPipeline (Saving, Loading, From pretrained, ...):
|
||||
|
||||
Questions on pipelines:
|
||||
- Stable Diffusion @yiyixuxu @DN6 @sayakpaul @patrickvonplaten
|
||||
- Stable Diffusion XL @yiyixuxu @sayakpaul @DN6 @patrickvonplaten
|
||||
- Kandinsky @yiyixuxu @patrickvonplaten
|
||||
- ControlNet @sayakpaul @yiyixuxu @DN6 @patrickvonplaten
|
||||
- T2I Adapter @sayakpaul @yiyixuxu @DN6 @patrickvonplaten
|
||||
- IF @DN6 @patrickvonplaten
|
||||
- Text-to-Video / Video-to-Video @DN6 @sayakpaul @patrickvonplaten
|
||||
- Wuerstchen @DN6 @patrickvonplaten
|
||||
- Stable Diffusion @yiyixuxu @DN6 @sayakpaul
|
||||
- Stable Diffusion XL @yiyixuxu @sayakpaul @DN6
|
||||
- Kandinsky @yiyixuxu
|
||||
- ControlNet @sayakpaul @yiyixuxu @DN6
|
||||
- T2I Adapter @sayakpaul @yiyixuxu @DN6
|
||||
- IF @DN6
|
||||
- Text-to-Video / Video-to-Video @DN6 @sayakpaul
|
||||
- Wuerstchen @DN6
|
||||
- Other: @yiyixuxu @DN6
|
||||
|
||||
Questions on models:
|
||||
- UNet @DN6 @yiyixuxu @sayakpaul @patrickvonplaten
|
||||
- VAE @sayakpaul @DN6 @yiyixuxu @patrickvonplaten
|
||||
- Transformers/Attention @DN6 @yiyixuxu @sayakpaul @DN6 @patrickvonplaten
|
||||
- UNet @DN6 @yiyixuxu @sayakpaul
|
||||
- VAE @sayakpaul @DN6 @yiyixuxu
|
||||
- Transformers/Attention @DN6 @yiyixuxu @sayakpaul @DN6
|
||||
|
||||
Questions on Schedulers: @yiyixuxu @patrickvonplaten
|
||||
Questions on Schedulers: @yiyixuxu
|
||||
|
||||
Questions on LoRA: @sayakpaul @patrickvonplaten
|
||||
Questions on LoRA: @sayakpaul
|
||||
|
||||
Questions on Textual Inversion: @sayakpaul @patrickvonplaten
|
||||
Questions on Textual Inversion: @sayakpaul
|
||||
|
||||
Questions on Training:
|
||||
- DreamBooth @sayakpaul @patrickvonplaten
|
||||
- Text-to-Image Fine-tuning @sayakpaul @patrickvonplaten
|
||||
- Textual Inversion @sayakpaul @patrickvonplaten
|
||||
- ControlNet @sayakpaul @patrickvonplaten
|
||||
- DreamBooth @sayakpaul
|
||||
- Text-to-Image Fine-tuning @sayakpaul
|
||||
- Textual Inversion @sayakpaul
|
||||
- ControlNet @sayakpaul
|
||||
|
||||
Questions on Tests: @DN6 @sayakpaul @yiyixuxu
|
||||
|
||||
@@ -99,7 +99,7 @@ body:
|
||||
|
||||
Questions on JAX- and MPS-related things: @pcuenca
|
||||
|
||||
Questions on audio pipelines: @DN6 @patrickvonplaten
|
||||
Questions on audio pipelines: @DN6
|
||||
|
||||
|
||||
|
||||
|
||||
10
.github/PULL_REQUEST_TEMPLATE.md
vendored
10
.github/PULL_REQUEST_TEMPLATE.md
vendored
@@ -38,13 +38,13 @@ members/contributors who may be interested in your PR.
|
||||
|
||||
Core library:
|
||||
|
||||
- Schedulers: @yiyixuxu and @patrickvonplaten
|
||||
- Pipelines: @patrickvonplaten and @sayakpaul
|
||||
- Training examples: @sayakpaul and @patrickvonplaten
|
||||
- Docs: @stevhliu and @yiyixuxu
|
||||
- Schedulers: @yiyixuxu
|
||||
- Pipelines: @sayakpaul @yiyixuxu @DN6
|
||||
- Training examples: @sayakpaul
|
||||
- Docs: @stevhliu and @sayakpaul
|
||||
- JAX and MPS: @pcuenca
|
||||
- Audio: @sanchit-gandhi
|
||||
- General functionalities: @patrickvonplaten and @sayakpaul
|
||||
- General functionalities: @sayakpaul @yiyixuxu @DN6
|
||||
|
||||
Integrations:
|
||||
|
||||
|
||||
@@ -23,13 +23,13 @@ ENV PATH="/opt/venv/bin:$PATH"
|
||||
|
||||
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
|
||||
# follow the instructions here: https://cloud.google.com/tpu/docs/run-in-container#train_a_jax_model_in_a_docker_container
|
||||
RUN python3 -m pip install --no-cache-dir --upgrade pip && \
|
||||
python3 -m pip install --upgrade --no-cache-dir \
|
||||
RUN python3 -m pip install --no-cache-dir --upgrade pip uv && \
|
||||
python3 -m uv pip install --upgrade --no-cache-dir \
|
||||
clu \
|
||||
"jax[cpu]>=0.2.16,!=0.3.2" \
|
||||
"flax>=0.4.1" \
|
||||
"jaxlib>=0.1.65" && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
python3 -m uv pip install --no-cache-dir \
|
||||
accelerate \
|
||||
datasets \
|
||||
hf-doc-builder \
|
||||
|
||||
@@ -23,15 +23,15 @@ ENV PATH="/opt/venv/bin:$PATH"
|
||||
|
||||
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
|
||||
# follow the instructions here: https://cloud.google.com/tpu/docs/run-in-container#train_a_jax_model_in_a_docker_container
|
||||
RUN python3 -m pip install --no-cache-dir --upgrade pip && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
RUN python3 -m pip install --no-cache-dir --upgrade pip uv && \
|
||||
python3 -m uv pip install --no-cache-dir \
|
||||
"jax[tpu]>=0.2.16,!=0.3.2" \
|
||||
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html && \
|
||||
python3 -m pip install --upgrade --no-cache-dir \
|
||||
python3 -m uv pip install --upgrade --no-cache-dir \
|
||||
clu \
|
||||
"flax>=0.4.1" \
|
||||
"jaxlib>=0.1.65" && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
python3 -m uv pip install --no-cache-dir \
|
||||
accelerate \
|
||||
datasets \
|
||||
hf-doc-builder \
|
||||
|
||||
@@ -22,14 +22,14 @@ RUN python3 -m venv /opt/venv
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
|
||||
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
|
||||
RUN python3 -m pip install --no-cache-dir --upgrade pip && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
RUN python3 -m pip install --no-cache-dir --upgrade pip uv && \
|
||||
python3 -m uv pip install --no-cache-dir \
|
||||
torch==2.1.2 \
|
||||
torchvision==0.16.2 \
|
||||
torchaudio==2.1.2 \
|
||||
onnxruntime \
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
python3 -m uv pip install --no-cache-dir \
|
||||
accelerate \
|
||||
datasets \
|
||||
hf-doc-builder \
|
||||
|
||||
@@ -22,14 +22,14 @@ RUN python3 -m venv /opt/venv
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
|
||||
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
|
||||
RUN python3 -m pip install --no-cache-dir --upgrade pip && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
RUN python3 -m pip install --no-cache-dir --upgrade pip uv && \
|
||||
python3 -m uv pip install --no-cache-dir \
|
||||
torch==2.1.2 \
|
||||
torchvision==0.16.2 \
|
||||
torchaudio==2.1.2 \
|
||||
"onnxruntime-gpu>=1.13.1" \
|
||||
--extra-index-url https://download.pytorch.org/whl/cu117 && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
python3 -m uv pip install --no-cache-dir \
|
||||
accelerate \
|
||||
datasets \
|
||||
hf-doc-builder \
|
||||
|
||||
@@ -24,8 +24,8 @@ RUN python3.9 -m venv /opt/venv
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
|
||||
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
|
||||
RUN python3.9 -m pip install --no-cache-dir --upgrade pip && \
|
||||
python3.9 -m pip install --no-cache-dir \
|
||||
RUN python3.9 -m pip install --no-cache-dir --upgrade pip uv && \
|
||||
python3.9 -m uv pip install --no-cache-dir \
|
||||
torch \
|
||||
torchvision \
|
||||
torchaudio \
|
||||
|
||||
@@ -23,14 +23,14 @@ RUN python3 -m venv /opt/venv
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
|
||||
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
|
||||
RUN python3 -m pip install --no-cache-dir --upgrade pip && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
RUN python3 -m pip install --no-cache-dir --upgrade pip uv && \
|
||||
python3 -m uv pip install --no-cache-dir \
|
||||
torch \
|
||||
torchvision \
|
||||
torchaudio \
|
||||
invisible_watermark \
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
python3 -m uv pip install --no-cache-dir \
|
||||
accelerate \
|
||||
datasets \
|
||||
hf-doc-builder \
|
||||
|
||||
@@ -23,8 +23,8 @@ RUN python3 -m venv /opt/venv
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
|
||||
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
|
||||
RUN python3 -m pip install --no-cache-dir --upgrade pip && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
RUN python3 -m pip install --no-cache-dir --upgrade pip uv && \
|
||||
python3 -m uv pip install --no-cache-dir \
|
||||
torch \
|
||||
torchvision \
|
||||
torchaudio \
|
||||
|
||||
@@ -23,13 +23,13 @@ RUN python3 -m venv /opt/venv
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
|
||||
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
|
||||
RUN python3 -m pip install --no-cache-dir --upgrade pip && \
|
||||
RUN python3 -m pip install --no-cache-dir --upgrade pip uv && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
torch \
|
||||
torchvision \
|
||||
torchaudio \
|
||||
invisible_watermark && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
python3 -m uv pip install --no-cache-dir \
|
||||
accelerate \
|
||||
datasets \
|
||||
hf-doc-builder \
|
||||
|
||||
@@ -113,7 +113,7 @@ pipe.enable_xformers_memory_efficient_attention()
|
||||
# memory optimization.
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
control_image = load_image("./conditioning_image_1.png")
|
||||
control_image = load_image("./conditioning_image_1.png").resize((1024, 1024))
|
||||
prompt = "pale golden rod circle with old lace background"
|
||||
|
||||
# generate image
|
||||
@@ -128,4 +128,14 @@ image.save("./output.png")
|
||||
|
||||
### Specifying a better VAE
|
||||
|
||||
SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely `--pretrained_vae_model_name_or_path` that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
|
||||
SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely `--pretrained_vae_model_name_or_path` that lets you specify the location of an alternative VAE (such as [`madebyollin/sdxl-vae-fp16-fix`](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
|
||||
|
||||
If you're using this VAE during training, you need to ensure you're using it during inference too. You do so by:
|
||||
|
||||
```diff
|
||||
+ vae = AutoencoderKL.from_pretrained(vae_path_or_repo_id, torch_dtype=torch.float16)
|
||||
controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
|
||||
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
|
||||
base_model_path, controlnet=controlnet, torch_dtype=torch.float16,
|
||||
+ vae=vae,
|
||||
)
|
||||
|
||||
@@ -14,6 +14,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
import gc
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
@@ -74,10 +76,15 @@ def image_grid(imgs, rows, cols):
|
||||
return grid
|
||||
|
||||
|
||||
def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, accelerator, weight_dtype, step):
|
||||
def log_validation(
|
||||
vae, text_encoder, tokenizer, unet, controlnet, args, accelerator, weight_dtype, step, is_final_validation=False
|
||||
):
|
||||
logger.info("Running validation... ")
|
||||
|
||||
controlnet = accelerator.unwrap_model(controlnet)
|
||||
if not is_final_validation:
|
||||
controlnet = accelerator.unwrap_model(controlnet)
|
||||
else:
|
||||
controlnet = ControlNetModel.from_pretrained(args.output_dir, torch_dtype=weight_dtype)
|
||||
|
||||
pipeline = StableDiffusionControlNetPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
@@ -118,6 +125,7 @@ def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, acceler
|
||||
)
|
||||
|
||||
image_logs = []
|
||||
inference_ctx = contextlib.nullcontext() if is_final_validation else torch.autocast("cuda")
|
||||
|
||||
for validation_prompt, validation_image in zip(validation_prompts, validation_images):
|
||||
validation_image = Image.open(validation_image).convert("RGB")
|
||||
@@ -125,7 +133,7 @@ def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, acceler
|
||||
images = []
|
||||
|
||||
for _ in range(args.num_validation_images):
|
||||
with torch.autocast("cuda"):
|
||||
with inference_ctx:
|
||||
image = pipeline(
|
||||
validation_prompt, validation_image, num_inference_steps=20, generator=generator
|
||||
).images[0]
|
||||
@@ -136,6 +144,7 @@ def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, acceler
|
||||
{"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt}
|
||||
)
|
||||
|
||||
tracker_key = "test" if is_final_validation else "validation"
|
||||
for tracker in accelerator.trackers:
|
||||
if tracker.name == "tensorboard":
|
||||
for log in image_logs:
|
||||
@@ -167,10 +176,14 @@ def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, acceler
|
||||
image = wandb.Image(image, caption=validation_prompt)
|
||||
formatted_images.append(image)
|
||||
|
||||
tracker.log({"validation": formatted_images})
|
||||
tracker.log({tracker_key: formatted_images})
|
||||
else:
|
||||
logger.warn(f"image logging not implemented for {tracker.name}")
|
||||
|
||||
del pipeline
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return image_logs
|
||||
|
||||
|
||||
@@ -197,7 +210,7 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st
|
||||
def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None):
|
||||
img_str = ""
|
||||
if image_logs is not None:
|
||||
img_str = "You can find some example images below.\n"
|
||||
img_str = "You can find some example images below.\n\n"
|
||||
for i, log in enumerate(image_logs):
|
||||
images = log["images"]
|
||||
validation_prompt = log["validation_prompt"]
|
||||
@@ -1131,6 +1144,22 @@ def main(args):
|
||||
controlnet = unwrap_model(controlnet)
|
||||
controlnet.save_pretrained(args.output_dir)
|
||||
|
||||
# Run a final round of validation.
|
||||
image_logs = None
|
||||
if args.validation_prompt is not None:
|
||||
image_logs = log_validation(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
controlnet=None,
|
||||
args=args,
|
||||
accelerator=accelerator,
|
||||
weight_dtype=weight_dtype,
|
||||
step=global_step,
|
||||
is_final_validation=True,
|
||||
)
|
||||
|
||||
if args.push_to_hub:
|
||||
save_model_card(
|
||||
repo_id,
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
import functools
|
||||
import gc
|
||||
import logging
|
||||
@@ -65,20 +66,38 @@ check_min_version("0.27.0.dev0")
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step):
|
||||
def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step, is_final_validation=False):
|
||||
logger.info("Running validation... ")
|
||||
|
||||
controlnet = accelerator.unwrap_model(controlnet)
|
||||
if not is_final_validation:
|
||||
controlnet = accelerator.unwrap_model(controlnet)
|
||||
pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
vae=vae,
|
||||
unet=unet,
|
||||
controlnet=controlnet,
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
else:
|
||||
controlnet = ControlNetModel.from_pretrained(args.output_dir, torch_dtype=weight_dtype)
|
||||
if args.pretrained_vae_model_name_or_path is not None:
|
||||
vae = AutoencoderKL.from_pretrained(args.pretrained_vae_model_name_or_path, torch_dtype=weight_dtype)
|
||||
else:
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="vae", torch_dtype=weight_dtype
|
||||
)
|
||||
|
||||
pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
vae=vae,
|
||||
controlnet=controlnet,
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
|
||||
pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
vae=vae,
|
||||
unet=unet,
|
||||
controlnet=controlnet,
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
@@ -106,6 +125,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step)
|
||||
)
|
||||
|
||||
image_logs = []
|
||||
inference_ctx = contextlib.nullcontext() if is_final_validation else torch.autocast("cuda")
|
||||
|
||||
for validation_prompt, validation_image in zip(validation_prompts, validation_images):
|
||||
validation_image = Image.open(validation_image).convert("RGB")
|
||||
@@ -114,7 +134,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step)
|
||||
images = []
|
||||
|
||||
for _ in range(args.num_validation_images):
|
||||
with torch.autocast("cuda"):
|
||||
with inference_ctx:
|
||||
image = pipeline(
|
||||
prompt=validation_prompt, image=validation_image, num_inference_steps=20, generator=generator
|
||||
).images[0]
|
||||
@@ -124,6 +144,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step)
|
||||
{"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt}
|
||||
)
|
||||
|
||||
tracker_key = "test" if is_final_validation else "validation"
|
||||
for tracker in accelerator.trackers:
|
||||
if tracker.name == "tensorboard":
|
||||
for log in image_logs:
|
||||
@@ -155,7 +176,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step)
|
||||
image = wandb.Image(image, caption=validation_prompt)
|
||||
formatted_images.append(image)
|
||||
|
||||
tracker.log({"validation": formatted_images})
|
||||
tracker.log({tracker_key: formatted_images})
|
||||
else:
|
||||
logger.warn(f"image logging not implemented for {tracker.name}")
|
||||
|
||||
@@ -189,7 +210,7 @@ def import_model_class_from_model_name_or_path(
|
||||
def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None):
|
||||
img_str = ""
|
||||
if image_logs is not None:
|
||||
img_str = "You can find some example images below.\n"
|
||||
img_str = "You can find some example images below.\n\n"
|
||||
for i, log in enumerate(image_logs):
|
||||
images = log["images"]
|
||||
validation_prompt = log["validation_prompt"]
|
||||
@@ -1228,7 +1249,13 @@ def main(args):
|
||||
|
||||
if args.validation_prompt is not None and global_step % args.validation_steps == 0:
|
||||
image_logs = log_validation(
|
||||
vae, unet, controlnet, args, accelerator, weight_dtype, global_step
|
||||
vae=vae,
|
||||
unet=unet,
|
||||
controlnet=controlnet,
|
||||
args=args,
|
||||
accelerator=accelerator,
|
||||
weight_dtype=weight_dtype,
|
||||
step=global_step,
|
||||
)
|
||||
|
||||
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
||||
@@ -1244,6 +1271,21 @@ def main(args):
|
||||
controlnet = unwrap_model(controlnet)
|
||||
controlnet.save_pretrained(args.output_dir)
|
||||
|
||||
# Run a final round of validation.
|
||||
# Setting `vae`, `unet`, and `controlnet` to None to load automatically from `args.output_dir`.
|
||||
image_logs = None
|
||||
if args.validation_prompt is not None:
|
||||
image_logs = log_validation(
|
||||
vae=None,
|
||||
unet=None,
|
||||
controlnet=None,
|
||||
args=args,
|
||||
accelerator=accelerator,
|
||||
weight_dtype=weight_dtype,
|
||||
step=global_step,
|
||||
is_final_validation=True,
|
||||
)
|
||||
|
||||
if args.push_to_hub:
|
||||
save_model_card(
|
||||
repo_id,
|
||||
|
||||
@@ -951,6 +951,9 @@ def main(args):
|
||||
unet, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
if args.use_ema:
|
||||
ema_unet.to(accelerator.device)
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if overrode_max_train_steps:
|
||||
@@ -1126,6 +1129,8 @@ def main(args):
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
if args.use_ema:
|
||||
ema_unet.step(unet.parameters())
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
accelerator.log({"train_loss": train_loss}, step=global_step)
|
||||
|
||||
@@ -546,6 +546,8 @@ class TextualInversionDataset(Dataset):
|
||||
|
||||
example["original_size"] = (image.height, image.width)
|
||||
|
||||
image = image.resize((self.size, self.size), resample=self.interpolation)
|
||||
|
||||
if self.center_crop:
|
||||
y1 = max(0, int(round((image.height - self.size) / 2.0)))
|
||||
x1 = max(0, int(round((image.width - self.size) / 2.0)))
|
||||
@@ -576,7 +578,6 @@ class TextualInversionDataset(Dataset):
|
||||
img = np.array(image).astype(np.uint8)
|
||||
|
||||
image = Image.fromarray(img)
|
||||
image = image.resize((self.size, self.size), resample=self.interpolation)
|
||||
|
||||
image = self.flip_transform(image)
|
||||
image = np.array(image).astype(np.uint8)
|
||||
|
||||
@@ -4,6 +4,7 @@ import math
|
||||
import os
|
||||
from copy import deepcopy
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from audio_diffusion.models import DiffusionAttnUnet1D
|
||||
from diffusion import sampling
|
||||
@@ -73,9 +74,14 @@ class DiffusionUncond(nn.Module):
|
||||
|
||||
def download(model_name):
|
||||
url = MODELS_MAP[model_name]["url"]
|
||||
os.system(f"wget {url} ./")
|
||||
r = requests.get(url, stream=True)
|
||||
|
||||
return f"./{model_name}.ckpt"
|
||||
local_filename = f"./{model_name}.ckpt"
|
||||
with open(local_filename, "wb") as fp:
|
||||
for chunk in r.iter_content(chunk_size=8192):
|
||||
fp.write(chunk)
|
||||
|
||||
return local_filename
|
||||
|
||||
|
||||
DOWN_NUM_TO_LAYER = {
|
||||
|
||||
@@ -106,6 +106,10 @@ class LoraLoaderMixin:
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
# if a dict is passed, copy it instead of modifying it inplace
|
||||
if isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
|
||||
|
||||
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
||||
state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
||||
|
||||
@@ -1229,6 +1233,10 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
|
||||
# it here explicitly to be able to tell that it's coming from an SDXL
|
||||
# pipeline.
|
||||
|
||||
# if a dict is passed, copy it instead of modifying it inplace
|
||||
if isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
|
||||
|
||||
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
||||
state_dict, network_alphas = self.lora_state_dict(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
|
||||
@@ -361,16 +361,19 @@ class LoRACompatibleConv(nn.Conv2d):
|
||||
self.w_down = None
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
|
||||
if self.lora_layer is None:
|
||||
# make sure to the functional Conv2D function as otherwise torch.compile's graph will break
|
||||
# see: https://github.com/huggingface/diffusers/pull/4315
|
||||
return F.conv2d(
|
||||
hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
|
||||
)
|
||||
if self.padding_mode != "zeros":
|
||||
hidden_states = F.pad(hidden_states, self._reversed_padding_repeated_twice, mode=self.padding_mode)
|
||||
padding = (0, 0)
|
||||
else:
|
||||
padding = self.padding
|
||||
|
||||
original_outputs = F.conv2d(
|
||||
hidden_states, self.weight, self.bias, self.stride, padding, self.dilation, self.groups
|
||||
)
|
||||
|
||||
if self.lora_layer is None:
|
||||
return original_outputs
|
||||
else:
|
||||
original_outputs = F.conv2d(
|
||||
hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
|
||||
)
|
||||
return original_outputs + (scale * self.lora_layer(hidden_states))
|
||||
|
||||
|
||||
|
||||
@@ -436,7 +436,6 @@ def load_sub_model(
|
||||
variant: str,
|
||||
low_cpu_mem_usage: bool,
|
||||
cached_folder: Union[str, os.PathLike],
|
||||
revision: str = None,
|
||||
):
|
||||
"""Helper method to load the module `name` from `library_name` and `class_name`"""
|
||||
# retrieve class candidates
|
||||
@@ -504,6 +503,7 @@ def load_sub_model(
|
||||
loading_kwargs["offload_folder"] = offload_folder
|
||||
loading_kwargs["offload_state_dict"] = offload_state_dict
|
||||
loading_kwargs["variant"] = model_variants.pop(name, None)
|
||||
|
||||
if from_flax:
|
||||
loading_kwargs["from_flax"] = True
|
||||
|
||||
@@ -1280,7 +1280,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
variant=variant,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
cached_folder=cached_folder,
|
||||
revision=revision,
|
||||
)
|
||||
logger.info(
|
||||
f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."
|
||||
|
||||
@@ -26,11 +26,13 @@ import torch.nn as nn
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.repocard import RepoCard
|
||||
from packaging import version
|
||||
from safetensors.torch import load_file
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
AutoPipelineForImage2Image,
|
||||
AutoPipelineForText2Image,
|
||||
ControlNetModel,
|
||||
DDIMScheduler,
|
||||
DiffusionPipeline,
|
||||
@@ -1177,6 +1179,24 @@ class PeftLoraLoaderMixinTests:
|
||||
# Just makes sure it works..
|
||||
_ = pipe(**inputs, generator=torch.manual_seed(0)).images
|
||||
|
||||
def test_modify_padding_mode(self):
|
||||
def set_pad_mode(network, mode="circular"):
|
||||
for _, module in network.named_modules():
|
||||
if isinstance(module, torch.nn.Conv2d):
|
||||
module.padding_mode = mode
|
||||
|
||||
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
|
||||
components, _, _ = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(self.torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
_pad_mode = "circular"
|
||||
set_pad_mode(pipe.vae, _pad_mode)
|
||||
set_pad_mode(pipe.unet, _pad_mode)
|
||||
|
||||
_, _, inputs = self.get_dummy_inputs()
|
||||
_ = pipe(**inputs).images
|
||||
|
||||
|
||||
class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
|
||||
pipeline_class = StableDiffusionPipeline
|
||||
@@ -1727,6 +1747,40 @@ class LoraIntegrationTests(PeftLoraLoaderMixinTests, unittest.TestCase):
|
||||
self.assertTrue(np.allclose(lora_images, lora_images_again, atol=1e-3))
|
||||
release_memory(pipe)
|
||||
|
||||
def test_not_empty_state_dict(self):
|
||||
# Makes sure https://github.com/huggingface/diffusers/issues/7054 does not happen again
|
||||
pipe = AutoPipelineForText2Image.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
|
||||
|
||||
cached_file = hf_hub_download("hf-internal-testing/lcm-lora-test-sd-v1-5", "test_lora.safetensors")
|
||||
lcm_lora = load_file(cached_file)
|
||||
|
||||
pipe.load_lora_weights(lcm_lora, adapter_name="lcm")
|
||||
self.assertTrue(lcm_lora != {})
|
||||
release_memory(pipe)
|
||||
|
||||
def test_load_unload_load_state_dict(self):
|
||||
# Makes sure https://github.com/huggingface/diffusers/issues/7054 does not happen again
|
||||
pipe = AutoPipelineForText2Image.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
|
||||
|
||||
cached_file = hf_hub_download("hf-internal-testing/lcm-lora-test-sd-v1-5", "test_lora.safetensors")
|
||||
lcm_lora = load_file(cached_file)
|
||||
previous_state_dict = lcm_lora.copy()
|
||||
|
||||
pipe.load_lora_weights(lcm_lora, adapter_name="lcm")
|
||||
self.assertDictEqual(lcm_lora, previous_state_dict)
|
||||
|
||||
pipe.unload_lora_weights()
|
||||
pipe.load_lora_weights(lcm_lora, adapter_name="lcm")
|
||||
self.assertDictEqual(lcm_lora, previous_state_dict)
|
||||
|
||||
release_memory(pipe)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
|
||||
@@ -810,6 +810,43 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
|
||||
|
||||
assert torch_all_close(output_slice_1, output_slice_2, atol=3e-3)
|
||||
|
||||
def test_single_file_component_configs(self):
|
||||
vae_single_file = AutoencoderKL.from_single_file(
|
||||
"https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors"
|
||||
)
|
||||
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
|
||||
|
||||
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values"]
|
||||
for param_name, param_value in vae_single_file.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert (
|
||||
vae.config[param_name] == param_value
|
||||
), f"{param_name} differs between single file loading and pretrained loading"
|
||||
|
||||
def test_single_file_arguments(self):
|
||||
vae_default = AutoencoderKL.from_single_file(
|
||||
"https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors",
|
||||
)
|
||||
|
||||
assert vae_default.config.scaling_factor == 0.18125
|
||||
assert vae_default.config.sample_size == 512
|
||||
assert vae_default.dtype == torch.float32
|
||||
|
||||
scaling_factor = 2.0
|
||||
image_size = 256
|
||||
torch_dtype = torch.float16
|
||||
|
||||
vae = AutoencoderKL.from_single_file(
|
||||
"https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors",
|
||||
image_size=image_size,
|
||||
scaling_factor=scaling_factor,
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
assert vae.config.scaling_factor == scaling_factor
|
||||
assert vae.config.sample_size == image_size
|
||||
assert vae.dtype == torch_dtype
|
||||
|
||||
|
||||
@slow
|
||||
class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
|
||||
|
||||
@@ -1072,6 +1072,44 @@ class ControlNetPipelineSlowTests(unittest.TestCase):
|
||||
max_diff = numpy_cosine_similarity_distance(output_sf.flatten(), output.flatten())
|
||||
assert max_diff < 1e-3
|
||||
|
||||
def test_single_file_component_configs(self):
|
||||
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny", variant="fp16")
|
||||
pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", variant="fp16", safety_checker=None, controlnet=controlnet
|
||||
)
|
||||
|
||||
controlnet_single_file = ControlNetModel.from_single_file(
|
||||
"https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
|
||||
)
|
||||
single_file_pipe = StableDiffusionControlNetPipeline.from_single_file(
|
||||
"https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors",
|
||||
safety_checker=None,
|
||||
controlnet=controlnet_single_file,
|
||||
scheduler_type="pndm",
|
||||
)
|
||||
|
||||
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "architectures", "_use_default_values"]
|
||||
for param_name, param_value in single_file_pipe.controlnet.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert (
|
||||
pipe.controlnet.config[param_name] == param_value
|
||||
), f"{param_name} differs between single file loading and pretrained loading"
|
||||
|
||||
for param_name, param_value in single_file_pipe.unet.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert (
|
||||
pipe.unet.config[param_name] == param_value
|
||||
), f"{param_name} differs between single file loading and pretrained loading"
|
||||
|
||||
for param_name, param_value in single_file_pipe.vae.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert (
|
||||
pipe.vae.config[param_name] == param_value
|
||||
), f"{param_name} differs between single file loading and pretrained loading"
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
|
||||
@@ -863,6 +863,49 @@ class ControlNetSDXLPipelineSlowTests(unittest.TestCase):
|
||||
max_diff = numpy_cosine_similarity_distance(images[0].flatten(), single_file_images[0].flatten())
|
||||
assert max_diff < 5e-2
|
||||
|
||||
def test_single_file_component_configs(self):
|
||||
controlnet = ControlNetModel.from_pretrained(
|
||||
"diffusers/controlnet-depth-sdxl-1.0", torch_dtype=torch.float16, variant="fp16"
|
||||
)
|
||||
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
variant="fp16",
|
||||
controlnet=controlnet,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
single_file_url = (
|
||||
"https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"
|
||||
)
|
||||
single_file_pipe = StableDiffusionXLControlNetPipeline.from_single_file(
|
||||
single_file_url, controlnet=controlnet, torch_dtype=torch.float16
|
||||
)
|
||||
|
||||
for param_name, param_value in single_file_pipe.text_encoder.config.to_dict().items():
|
||||
if param_name in ["torch_dtype", "architectures", "_name_or_path"]:
|
||||
continue
|
||||
assert pipe.text_encoder.config.to_dict()[param_name] == param_value
|
||||
|
||||
for param_name, param_value in single_file_pipe.text_encoder_2.config.to_dict().items():
|
||||
if param_name in ["torch_dtype", "architectures", "_name_or_path"]:
|
||||
continue
|
||||
assert pipe.text_encoder_2.config.to_dict()[param_name] == param_value
|
||||
|
||||
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "architectures", "_use_default_values"]
|
||||
for param_name, param_value in single_file_pipe.unet.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert (
|
||||
pipe.unet.config[param_name] == param_value
|
||||
), f"{param_name} differs between single file loading and pretrained loading"
|
||||
|
||||
for param_name, param_value in single_file_pipe.vae.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert (
|
||||
pipe.vae.config[param_name] == param_value
|
||||
), f"{param_name} differs between single file loading and pretrained loading"
|
||||
|
||||
|
||||
class StableDiffusionSSD1BControlNetPipelineFastTests(StableDiffusionXLControlNetPipelineFastTests):
|
||||
def test_controlnet_sdxl_guess(self):
|
||||
|
||||
@@ -1295,6 +1295,39 @@ class StableDiffusionPipelineCkptTests(unittest.TestCase):
|
||||
|
||||
assert max_diff < 1e-3
|
||||
|
||||
def test_single_file_component_configs(self):
|
||||
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
|
||||
|
||||
ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt"
|
||||
single_file_pipe = StableDiffusionPipeline.from_single_file(ckpt_path, load_safety_checker=True)
|
||||
|
||||
for param_name, param_value in single_file_pipe.text_encoder.config.to_dict().items():
|
||||
if param_name in ["torch_dtype", "architectures", "_name_or_path"]:
|
||||
continue
|
||||
assert pipe.text_encoder.config.to_dict()[param_name] == param_value
|
||||
|
||||
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "architectures", "_use_default_values"]
|
||||
for param_name, param_value in single_file_pipe.unet.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert (
|
||||
pipe.unet.config[param_name] == param_value
|
||||
), f"{param_name} differs between single file loading and pretrained loading"
|
||||
|
||||
for param_name, param_value in single_file_pipe.vae.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert (
|
||||
pipe.vae.config[param_name] == param_value
|
||||
), f"{param_name} differs between single file loading and pretrained loading"
|
||||
|
||||
for param_name, param_value in single_file_pipe.safety_checker.config.to_dict().items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert (
|
||||
pipe.safety_checker.config.to_dict()[param_name] == param_value
|
||||
), f"{param_name} differs between single file loading and pretrained loading"
|
||||
|
||||
|
||||
@nightly
|
||||
@require_torch_gpu
|
||||
|
||||
@@ -785,6 +785,39 @@ class StableDiffusionInpaintPipelineSlowTests(unittest.TestCase):
|
||||
|
||||
assert max_diff < 1e-4
|
||||
|
||||
def test_single_file_component_configs(self):
|
||||
pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting", variant="fp16")
|
||||
|
||||
ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-inpainting/blob/main/sd-v1-5-inpainting.ckpt"
|
||||
single_file_pipe = StableDiffusionInpaintPipeline.from_single_file(ckpt_path, load_safety_checker=True)
|
||||
|
||||
for param_name, param_value in single_file_pipe.text_encoder.config.to_dict().items():
|
||||
if param_name in ["torch_dtype", "architectures", "_name_or_path"]:
|
||||
continue
|
||||
assert pipe.text_encoder.config.to_dict()[param_name] == param_value
|
||||
|
||||
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "architectures", "_use_default_values"]
|
||||
for param_name, param_value in single_file_pipe.unet.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert (
|
||||
pipe.unet.config[param_name] == param_value
|
||||
), f"{param_name} is differs between single file loading and pretrained loading"
|
||||
|
||||
for param_name, param_value in single_file_pipe.vae.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert (
|
||||
pipe.vae.config[param_name] == param_value
|
||||
), f"{param_name} is differs between single file loading and pretrained loading"
|
||||
|
||||
for param_name, param_value in single_file_pipe.safety_checker.config.to_dict().items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert (
|
||||
pipe.safety_checker.config.to_dict()[param_name] == param_value
|
||||
), f"{param_name} is differs between single file loading and pretrained loading"
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
|
||||
@@ -513,3 +513,40 @@ class StableDiffusionUpscalePipelineIntegrationTests(unittest.TestCase):
|
||||
assert (
|
||||
numpy_cosine_similarity_distance(image_from_pretrained.flatten(), image_from_single_file.flatten()) < 1e-3
|
||||
)
|
||||
|
||||
def test_single_file_component_configs(self):
|
||||
pipe = StableDiffusionUpscalePipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-x4-upscaler", variant="fp16"
|
||||
)
|
||||
|
||||
ckpt_path = (
|
||||
"https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler/blob/main/x4-upscaler-ema.safetensors"
|
||||
)
|
||||
single_file_pipe = StableDiffusionUpscalePipeline.from_single_file(ckpt_path, load_safety_checker=True)
|
||||
|
||||
for param_name, param_value in single_file_pipe.text_encoder.config.to_dict().items():
|
||||
if param_name in ["torch_dtype", "architectures", "_name_or_path"]:
|
||||
continue
|
||||
assert pipe.text_encoder.config.to_dict()[param_name] == param_value
|
||||
|
||||
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "architectures", "_use_default_values"]
|
||||
for param_name, param_value in single_file_pipe.unet.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert (
|
||||
pipe.unet.config[param_name] == param_value
|
||||
), f"{param_name} differs between single file loading and pretrained loading"
|
||||
|
||||
for param_name, param_value in single_file_pipe.vae.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert (
|
||||
pipe.vae.config[param_name] == param_value
|
||||
), f"{param_name} differs between single file loading and pretrained loading"
|
||||
|
||||
for param_name, param_value in single_file_pipe.safety_checker.config.to_dict().items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert (
|
||||
pipe.safety_checker.config.to_dict()[param_name] == param_value
|
||||
), f"{param_name} differs between single file loading and pretrained loading"
|
||||
|
||||
@@ -1091,3 +1091,39 @@ class StableDiffusionXLPipelineIntegrationTests(unittest.TestCase):
|
||||
max_diff = numpy_cosine_similarity_distance(image.flatten(), image_ckpt.flatten())
|
||||
|
||||
assert max_diff < 6e-3
|
||||
|
||||
def test_single_file_component_configs(self):
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
||||
)
|
||||
ckpt_path = (
|
||||
"https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"
|
||||
)
|
||||
single_file_pipe = StableDiffusionXLPipeline.from_single_file(
|
||||
ckpt_path, variant="fp16", torch_dtype=torch.float16
|
||||
)
|
||||
|
||||
for param_name, param_value in single_file_pipe.text_encoder.config.to_dict().items():
|
||||
if param_name in ["torch_dtype", "architectures", "_name_or_path"]:
|
||||
continue
|
||||
assert pipe.text_encoder.config.to_dict()[param_name] == param_value
|
||||
|
||||
for param_name, param_value in single_file_pipe.text_encoder_2.config.to_dict().items():
|
||||
if param_name in ["torch_dtype", "architectures", "_name_or_path"]:
|
||||
continue
|
||||
assert pipe.text_encoder_2.config.to_dict()[param_name] == param_value
|
||||
|
||||
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "architectures", "_use_default_values"]
|
||||
for param_name, param_value in single_file_pipe.unet.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert (
|
||||
pipe.unet.config[param_name] == param_value
|
||||
), f"{param_name} is differs between single file loading and pretrained loading"
|
||||
|
||||
for param_name, param_value in single_file_pipe.vae.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert (
|
||||
pipe.vae.config[param_name] == param_value
|
||||
), f"{param_name} is differs between single file loading and pretrained loading"
|
||||
|
||||
@@ -816,3 +816,35 @@ class StableDiffusionXLImg2ImgIntegrationTests(unittest.TestCase):
|
||||
max_diff = numpy_cosine_similarity_distance(image.flatten(), image_single_file.flatten())
|
||||
|
||||
assert max_diff < 5e-2
|
||||
|
||||
def test_single_file_component_configs(self):
|
||||
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-refiner-1.0",
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16",
|
||||
)
|
||||
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/blob/main/sd_xl_refiner_1.0.safetensors"
|
||||
single_file_pipe = StableDiffusionXLImg2ImgPipeline.from_single_file(ckpt_path, torch_dtype=torch.float16)
|
||||
|
||||
assert pipe.text_encoder is None
|
||||
assert single_file_pipe.text_encoder is None
|
||||
|
||||
for param_name, param_value in single_file_pipe.text_encoder_2.config.to_dict().items():
|
||||
if param_name in ["torch_dtype", "architectures", "_name_or_path"]:
|
||||
continue
|
||||
assert pipe.text_encoder_2.config.to_dict()[param_name] == param_value
|
||||
|
||||
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "architectures", "_use_default_values"]
|
||||
for param_name, param_value in single_file_pipe.unet.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert (
|
||||
pipe.unet.config[param_name] == param_value
|
||||
), f"{param_name} differs between single file loading and pretrained loading"
|
||||
|
||||
for param_name, param_value in single_file_pipe.vae.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert (
|
||||
pipe.vae.config[param_name] == param_value
|
||||
), f"{param_name} differs between single file loading and pretrained loading"
|
||||
|
||||
Reference in New Issue
Block a user