Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-30 23:45:01 +08:00
Compare commits (3 commits)
modular-lo ... fix-sd3-co
| Author | SHA1 | Date |
|---|---|---|
| | b56112db6e | |
| | f50de75b69 | |
| | 579bb5f418 | |
```diff
@@ -17,6 +17,7 @@ import argparse
 import contextlib
 import copy
 import functools
+import gc
 import logging
 import math
 import os
```
```diff
@@ -52,6 +53,7 @@ from diffusers.optimization import get_scheduler
 from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, free_memory
 from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.testing_utils import backend_empty_cache
 from diffusers.utils.torch_utils import is_compiled_module
 
 
```
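The new `backend_empty_cache` import is used later in `log_validation`, together with `gc.collect()`, to release accelerator memory between two pipeline loads. A minimal sketch of that cleanup pattern (the helper name below is hypothetical, not part of the change):

```python
import gc

from diffusers.utils.testing_utils import backend_empty_cache


def release_accelerator_memory(device_type: str) -> None:
    # Callers drop their own large references first (e.g. `del pipeline`);
    # this then runs the garbage collector and clears the cache of whichever
    # backend `device_type` names ("cuda", "xpu", "mps", ...).
    gc.collect()
    backend_empty_cache(device_type)
```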
```diff
@@ -74,8 +76,9 @@ def log_validation(controlnet, args, accelerator, weight_dtype, step, is_final_v
 
     pipeline = StableDiffusion3ControlNetPipeline.from_pretrained(
         args.pretrained_model_name_or_path,
-        controlnet=controlnet,
+        controlnet=None,
         safety_checker=None,
+        transformer=None,
         revision=args.revision,
         variant=args.variant,
         torch_dtype=weight_dtype,
```
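As far as the hunk shows, this first `from_pretrained` call now deliberately skips the heavy components: with `controlnet=None` and `transformer=None` only the tokenizers, text encoders, and VAE are loaded, since this pipeline instance is used solely to pre-compute prompt embeddings in the next hunk.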
```diff
@@ -102,18 +105,55 @@ def log_validation(controlnet, args, accelerator, weight_dtype, step, is_final_v
             "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`"
         )
 
+    with torch.no_grad():
+        (
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+        ) = pipeline.encode_prompt(
+            validation_prompts,
+            prompt_2=None,
+            prompt_3=None,
+        )
+
+    del pipeline
+    gc.collect()
+    backend_empty_cache(accelerator.device.type)
+
+    pipeline = StableDiffusion3ControlNetPipeline.from_pretrained(
+        args.pretrained_model_name_or_path,
+        controlnet=controlnet,
+        safety_checker=None,
+        text_encoder=None,
+        text_encoder_2=None,
+        text_encoder_3=None,
+        revision=args.revision,
+        variant=args.variant,
+        torch_dtype=weight_dtype,
+    )
+    pipeline.enable_model_cpu_offload(device=accelerator.device.type)
+    pipeline.set_progress_bar_config(disable=True)
+
     image_logs = []
     inference_ctx = contextlib.nullcontext() if is_final_validation else torch.autocast(accelerator.device.type)
 
-    for validation_prompt, validation_image in zip(validation_prompts, validation_images):
+    for i, validation_image in enumerate(validation_images):
         validation_image = Image.open(validation_image).convert("RGB")
+        validation_prompt = validation_prompts[i]
 
         images = []
 
         for _ in range(args.num_validation_images):
             with inference_ctx:
                 image = pipeline(
-                    validation_prompt, control_image=validation_image, num_inference_steps=20, generator=generator
+                    prompt_embeds=prompt_embeds[i].unsqueeze(0),
+                    negative_prompt_embeds=negative_prompt_embeds[i].unsqueeze(0),
+                    pooled_prompt_embeds=pooled_prompt_embeds[i].unsqueeze(0),
+                    negative_pooled_prompt_embeds=negative_pooled_prompt_embeds[i].unsqueeze(0),
+                    control_image=validation_image,
+                    num_inference_steps=20,
+                    generator=generator,
+                ).images[0]
 
             images.append(image)
```
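Taken together with the previous hunk, validation now runs in two stages: encode the prompts with a text-encoder-only pipeline, free it, then reload a pipeline carrying only the transformer plus the ControlNet and drive it from the precomputed embeddings. Below is a condensed, self-contained sketch of that flow; the function name, `model_id`, the fp16 dtype, and the default `"cuda"` device are illustrative assumptions, not part of the change.

```python
import gc

import torch
from diffusers import StableDiffusion3ControlNetPipeline
from diffusers.utils.testing_utils import backend_empty_cache


def validate_with_precomputed_embeds(controlnet, control_image, prompt, model_id, device_type="cuda"):
    # Stage 1: load only the tokenizers/text encoders (no transformer, no
    # ControlNet) and pre-compute the prompt embeddings.
    pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
        model_id, controlnet=None, transformer=None, torch_dtype=torch.float16
    ).to(device_type)
    with torch.no_grad():
        embeds, neg_embeds, pooled, neg_pooled = pipe.encode_prompt(
            [prompt], prompt_2=None, prompt_3=None
        )

    # Free the text encoders before the denoising components are loaded.
    del pipe
    gc.collect()
    backend_empty_cache(device_type)

    # Stage 2: reload with the ControlNet under validation but without any
    # text encoder, and feed the pipeline the precomputed embeddings.
    pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
        model_id,
        controlnet=controlnet,
        text_encoder=None,
        text_encoder_2=None,
        text_encoder_3=None,
        torch_dtype=torch.float16,
    )
    pipe.enable_model_cpu_offload(device=device_type)
    return pipe(
        prompt_embeds=embeds[0].unsqueeze(0),
        negative_prompt_embeds=neg_embeds[0].unsqueeze(0),
        pooled_prompt_embeds=pooled[0].unsqueeze(0),
        negative_pooled_prompt_embeds=neg_pooled[0].unsqueeze(0),
        control_image=control_image,  # PIL image to condition the ControlNet on
        num_inference_steps=20,
    ).images[0]
```

The practical effect of the split is that SD3's three text encoders never sit in memory alongside the transformer and ControlNet during validation.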
```diff
@@ -655,6 +695,7 @@ def make_train_dataset(args, tokenizer_one, tokenizer_two, tokenizer_three, acce
         dataset = load_dataset(
             args.train_data_dir,
             cache_dir=args.cache_dir,
+            trust_remote_code=True,
         )
         # See more about loading custom images at
         # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
```
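The `trust_remote_code=True` argument is presumably added because recent `datasets` releases no longer run a dataset's custom loading script without an explicit opt-in. A minimal sketch of the adjusted call (the path is a placeholder for `args.train_data_dir`):

```python
from datasets import load_dataset

dataset = load_dataset(
    "path/to/train_data_dir",  # placeholder for args.train_data_dir
    cache_dir=None,            # args.cache_dir in the script
    trust_remote_code=True,    # explicitly allow the dataset's loading script to execute
)
```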