Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-06 20:44:33 +08:00)

Compare commits: 3 commits, modular-di... → fix-sd3-co...
| Author | SHA1 | Date |
|---|---|---|
| | b56112db6e | |
| | f50de75b69 | |
| | 579bb5f418 | |
```diff
@@ -17,6 +17,7 @@ import argparse
+import contextlib
 import copy
 import functools
 import gc
 import logging
 import math
 import os
```
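The new `import contextlib` supports the `inference_ctx` pattern used later in `log_validation`: the final validation runs under a plain no-op context, while intermediate validations run under `torch.autocast`. A minimal sketch of just that pattern (the device selection and the matrix multiply are stand-ins for the real pipeline call, not part of the training script):

```python
import contextlib

import torch

device_type = "cuda" if torch.cuda.is_available() else "cpu"


def validate(is_final_validation: bool) -> torch.Tensor:
    # Final validation: no autocast, compute in whatever dtype the weights use.
    # Intermediate validations: mixed precision via autocast.
    inference_ctx = contextlib.nullcontext() if is_final_validation else torch.autocast(device_type)
    with inference_ctx:
        a = torch.randn(8, 8, device=device_type)
        b = torch.randn(8, 8, device=device_type)
        return a @ b  # stand-in for the diffusion pipeline call


print(validate(is_final_validation=False).dtype)  # e.g. torch.float16 under CUDA autocast
```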
```diff
@@ -52,6 +53,7 @@ from diffusers.optimization import get_scheduler
 from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, free_memory
 from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.testing_utils import backend_empty_cache
 from diffusers.utils.torch_utils import is_compiled_module
 
 
```
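The newly imported `backend_empty_cache` lets the validation code below clear the accelerator's allocator cache in a device-agnostic way; it dispatches on the device-type string (for example `"cuda"`). A minimal sketch of the release pattern the script uses between its two pipelines, with a small `Linear` module standing in for the prompt-encoding pipeline:

```python
import gc

import torch

from diffusers.utils.testing_utils import backend_empty_cache

device_type = "cuda" if torch.cuda.is_available() else "cpu"

# Stand-in for the text-encoding pipeline that is only needed to precompute embeddings.
text_stage = torch.nn.Linear(512, 512).to(device_type)

# ... produce whatever the next stage needs, then release this stage:
del text_stage
gc.collect()
backend_empty_cache(device_type)  # e.g. clears the CUDA caching allocator on "cuda"
```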
```diff
@@ -74,8 +76,9 @@ def log_validation(controlnet, args, accelerator, weight_dtype, step, is_final_v
 
     pipeline = StableDiffusion3ControlNetPipeline.from_pretrained(
         args.pretrained_model_name_or_path,
-        controlnet=controlnet,
+        controlnet=None,
         safety_checker=None,
+        transformer=None,
         revision=args.revision,
         variant=args.variant,
         torch_dtype=weight_dtype,
```
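With `controlnet=None` and `transformer=None`, this first `from_pretrained` call only keeps the tokenizers and text encoders, which is all that `encode_prompt` needs. A minimal sketch of that first stage in isolation (the model id, dtype, prompt, and device placement are illustrative assumptions, not taken from the script):

```python
import torch

from diffusers import StableDiffusion3ControlNetPipeline

# Placeholder model id; the training script uses `args.pretrained_model_name_or_path`.
MODEL_ID = "stabilityai/stable-diffusion-3-medium-diffusers"

# Stage one: load only what prompt encoding needs; skipping the transformer and
# the ControlNet keeps this pipeline's memory footprint small.
encode_pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
    MODEL_ID,
    controlnet=None,
    transformer=None,
    torch_dtype=torch.float16,
)
encode_pipe.to("cuda")  # assumes a CUDA device is available

with torch.no_grad():
    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = encode_pipe.encode_prompt(
        ["a photo of a cat"],  # placeholder validation prompt
        prompt_2=None,
        prompt_3=None,
    )
```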
```diff
@@ -102,18 +105,55 @@ def log_validation(controlnet, args, accelerator, weight_dtype, step, is_final_v
             "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`"
         )
 
+    with torch.no_grad():
+        (
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+        ) = pipeline.encode_prompt(
+            validation_prompts,
+            prompt_2=None,
+            prompt_3=None,
+        )
+
+    del pipeline
+    gc.collect()
+    backend_empty_cache(accelerator.device.type)
+
+    pipeline = StableDiffusion3ControlNetPipeline.from_pretrained(
+        args.pretrained_model_name_or_path,
+        controlnet=controlnet,
+        safety_checker=None,
+        text_encoder=None,
+        text_encoder_2=None,
+        text_encoder_3=None,
+        revision=args.revision,
+        variant=args.variant,
+        torch_dtype=weight_dtype,
+    )
+    pipeline.enable_model_cpu_offload(device=accelerator.device.type)
+    pipeline.set_progress_bar_config(disable=True)
+
     image_logs = []
     inference_ctx = contextlib.nullcontext() if is_final_validation else torch.autocast(accelerator.device.type)
 
-    for validation_prompt, validation_image in zip(validation_prompts, validation_images):
+    for i, validation_image in enumerate(validation_images):
         validation_image = Image.open(validation_image).convert("RGB")
+        validation_prompt = validation_prompts[i]
 
         images = []
 
         for _ in range(args.num_validation_images):
             with inference_ctx:
                 image = pipeline(
-                    validation_prompt, control_image=validation_image, num_inference_steps=20, generator=generator
+                    prompt_embeds=prompt_embeds[i].unsqueeze(0),
+                    negative_prompt_embeds=negative_prompt_embeds[i].unsqueeze(0),
+                    pooled_prompt_embeds=pooled_prompt_embeds[i].unsqueeze(0),
+                    negative_pooled_prompt_embeds=negative_pooled_prompt_embeds[i].unsqueeze(0),
+                    control_image=validation_image,
+                    num_inference_steps=20,
+                    generator=generator,
                 ).images[0]
 
             images.append(image)
```
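Taken together with the previous hunk, this rework splits validation into two stages so the text encoders and the SD3 transformer are never resident at the same time: the first pipeline only encodes the validation prompts, it is then freed, and a second pipeline without any text encoders runs the ControlNet-conditioned sampling from the cached embeddings. A minimal sketch of the second stage, continuing the hypothetical stage-one sketch above (it reuses `encode_pipe` and the returned embeddings; the ControlNet checkpoint, conditioning image path, and seed are placeholder assumptions, whereas the training script passes the ControlNet it is currently training):

```python
import gc

import torch
from PIL import Image

from diffusers import SD3ControlNetModel, StableDiffusion3ControlNetPipeline
from diffusers.utils.testing_utils import backend_empty_cache

# Free the stage-one pipeline before loading the denoising pipeline.
del encode_pipe
gc.collect()
backend_empty_cache("cuda")

# Example ControlNet checkpoint; the script uses the in-memory ControlNet being trained.
controlnet = SD3ControlNetModel.from_pretrained(
    "InstantX/SD3-Controlnet-Canny", torch_dtype=torch.float16
)

# Stage two: ControlNet and transformer are loaded, the three text encoders are not.
gen_pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    controlnet=controlnet,
    text_encoder=None,
    text_encoder_2=None,
    text_encoder_3=None,
    torch_dtype=torch.float16,
)
gen_pipe.enable_model_cpu_offload()
gen_pipe.set_progress_bar_config(disable=True)

control_image = Image.open("conditioning.png").convert("RGB")  # placeholder conditioning image
generator = torch.Generator(device="cpu").manual_seed(0)

# Index 0 selects the embeddings of the first validation prompt from stage one.
image = gen_pipe(
    prompt_embeds=prompt_embeds[0].unsqueeze(0),
    negative_prompt_embeds=negative_prompt_embeds[0].unsqueeze(0),
    pooled_prompt_embeds=pooled_prompt_embeds[0].unsqueeze(0),
    negative_pooled_prompt_embeds=negative_pooled_prompt_embeds[0].unsqueeze(0),
    control_image=control_image,
    num_inference_steps=20,
    generator=generator,
).images[0]
image.save("validation_0.png")
```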
```diff
@@ -655,6 +695,7 @@ def make_train_dataset(args, tokenizer_one, tokenizer_two, tokenizer_three, acce
         dataset = load_dataset(
             args.train_data_dir,
             cache_dir=args.cache_dir,
+            trust_remote_code=True,
         )
         # See more about loading custom images at
         # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
```
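Passing `trust_remote_code=True` opts in to running dataset-provided loading code; recent `datasets` releases require this opt-in to be explicit when a dataset relies on a loading script. A minimal sketch of the call outside the training script (the directory name is a placeholder, and `cache_dir=None` stands in for `args.cache_dir`):

```python
from datasets import load_dataset

# Hypothetical local folder laid out the way the training script expects
# (images plus a metadata file with captions and conditioning images).
TRAIN_DATA_DIR = "./my_controlnet_dataset"

dataset = load_dataset(
    TRAIN_DATA_DIR,
    cache_dir=None,          # the script passes `args.cache_dir` here
    trust_remote_code=True,  # explicit opt-in to dataset-provided loading code
)
print(dataset["train"].column_names)
```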