mirror of https://github.com/huggingface/diffusers.git
synced 2025-12-25 13:54:45 +08:00

Compare commits (10 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 5311f564ed | |
| | 3b7f514a1c | |
| | 7c0a861894 | |
| | a73ae3e5b0 | |
| | 06505ba4b4 | |
| | 13457002c0 | |
| | 302b86bd0b | |
| | d87d5edf66 | |
| | e795a4c6f8 | |
| | 4293b9f54f | |
README.md (36 lines changed)
@@ -64,7 +64,12 @@ The class provides functionality to compute previous image according to alpha, b
 ## Quickstart
 
-**Check out this notebook: https://colab.research.google.com/drive/1nMfF04cIxg6FujxsNYi9kiTRrzj4_eZU?usp=sharing**
+In order to get started, we recommend taking a look at two notebooks:
+
+- The [Diffusers](https://colab.research.google.com/drive/1aEFVu0CvcIBzSNIQ7F71ujYYplAX4Bml?usp=sharing#scrollTo=PzW5ublpBuUt) notebook, which showcases an end-to-end example of usage for diffusion models, schedulers and pipelines.
+  Take a look at this notebook to learn how to use the pipeline abstraction, which takes care of everything (model, scheduler, noise handling) for you, and to get an understanding of each independent building block in the library (see the sketch just after this list).
+- The [Training diffusers](https://colab.research.google.com/drive/1qqJmz7JJ04suJzEF4Hn4-Acb8rfL-eA3?usp=sharing) notebook, which summarizes diffusion model training methods. This notebook takes a step-by-step approach to training your
+  diffusion model on an image dataset, with explanatory graphics.
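The pipeline abstraction mentioned above can be exercised in a few lines. Below is a minimal sketch, assuming the 0.1.x-era `DDPMPipeline` API; the model id and the `output_type`/`"sample"` return convention are assumptions carried over from other pipelines in this diff, not guarantees about this class:

```python
# Minimal sketch of the pipeline abstraction: one object bundles the model,
# the scheduler, and the denoising loop. The call signature and return format
# are assumptions about the 0.1.x-era API; the model id is illustrative.
import torch
from diffusers import DDPMPipeline

pipeline = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32")

generator = torch.manual_seed(0)  # fix the noise for reproducible samples
image_pil = pipeline(generator=generator, output_type="pil")["sample"][0]
image_pil.save("generated_image.png")
```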
 
 ### Installation
@@ -133,3 +138,32 @@ image_pil.save("generated_image.png")
 ```python
 ```
 
+## In the works
+
+For the first release, 🤗 Diffusers focuses on text-to-image diffusion techniques. However, diffusers can be used for much more than that! Over the upcoming releases, we'll be focusing on:
+
+- Diffusers for audio
+- Diffusers for reinforcement learning (initial work happening in https://github.com/huggingface/diffusers/pull/105)
+- Diffusers for video generation
+- Diffusers for molecule generation (initial work happening in https://github.com/huggingface/diffusers/pull/54)
+
+A few pipeline components are already being worked on, namely:
+
+- BDDMPipeline for spectrogram-to-sound vocoding
+- GLIDEPipeline to support OpenAI's GLIDE model
+- Grad-TTS for text-to-audio generation / conditional audio generation
+
+We want diffusers to be a toolbox useful for diffusion models in general; if you find yourself limited in any way by the current API, or would like to see additional models, schedulers, or techniques, please open a [GitHub issue](https://github.com/huggingface/diffusers/issues) mentioning what you would like to see.
+
+## Credits
+
+This library concretizes previous work by many different authors and would not have been possible without their great research and implementations. We'd like to thank, in particular, the following implementations which have helped us in our development and without which the API could not have been as polished as it is today:
+
+- @CompVis' latent diffusion models library, available [here](https://github.com/CompVis/latent-diffusion)
+- @hojonathanho's original DDPM implementation, available [here](https://github.com/hojonathanho/diffusion), as well as the extremely useful translation into PyTorch by @pesser, available [here](https://github.com/pesser/pytorch_diffusion)
+- @ermongroup's DDIM implementation, available [here](https://github.com/ermongroup/ddim)
+- @yang-song's Score-VE and Score-VP implementations, available [here](https://github.com/yang-song/score_sde_pytorch)
+
+We also want to thank @heejkoo for the very helpful overview of papers, code and resources on diffusion models, available [here](https://github.com/heejkoo/Awesome-Diffusion-Models).
@@ -30,4 +30,4 @@ with a `set_format(...)` method.
 - The [`DDPMScheduler`] was proposed in [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) and can be found in [scheduling_ddpm.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddpm.py).
   An example of how to use this scheduler can be found in [pipeline_ddpm.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_ddpm.py).
 - The [`DDIMScheduler`] was proposed in [Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502) and can be found in [scheduling_ddim.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddim.py). An example of how to use this scheduler can be found in [pipeline_ddim.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_ddim.py).
-- The [`PNMDScheduler`] was proposed in [Pseudo Numerical Methods for Diffusion Models on Manifolds](https://arxiv.org/abs/2202.09778) and can be found in [scheduling_pndm.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_pndm.py). An example of how to use this scheduler can be found in [pipeline_pndm.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py).
+- The [`PNDMScheduler`] was proposed in [Pseudo Numerical Methods for Diffusion Models on Manifolds](https://arxiv.org/abs/2202.09778) and can be found in [scheduling_pndm.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_pndm.py). An example of how to use this scheduler can be found in [pipeline_pndm.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py).
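Because the schedulers above share the same step-based interface, a denoising loop can be written once and the scheduler swapped out. A schematic sketch follows; the constructor mirrors the `DDPMScheduler(num_train_timesteps=1000, tensor_format="pt")` call that appears later in this diff, the `["sample"]` / `["prev_sample"]` dictionary outputs mirror the pipeline code below, and the model id is illustrative:

```python
# Schematic denoising loop built directly on a scheduler instead of a pipeline.
# The step() signature and dictionary keys are assumptions based on the
# pipeline code elsewhere in this diff; the model id is illustrative.
import torch
from diffusers import DDPMScheduler, UNet2DModel

model = UNet2DModel.from_pretrained("google/ddpm-cifar10-32")
scheduler = DDPMScheduler(num_train_timesteps=1000, tensor_format="pt")

sample = torch.randn(1, 3, model.config.sample_size, model.config.sample_size)
for t in reversed(range(1000)):  # walk the full noise schedule backwards
    with torch.no_grad():
        noise_pred = model(sample, t)["sample"]  # predict the noise residual
    sample = scheduler.step(noise_pred, t, sample)["prev_sample"]  # one denoising step
```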
@@ -1,5 +1,13 @@
 ## Training examples
 
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+```bash
+pip install diffusers[training] accelerate datasets
+```
+
 ### Unconditional Flowers
 
 The command to train a DDPM UNet model on the Oxford Flowers dataset:
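The command itself falls outside the hunk shown above. For orientation, a hypothetical invocation might look like the following; the script name, dataset id, and every flag are illustrative (the flag names echo the argparse definitions further down in this diff) rather than copied from the repository:

```bash
# Hypothetical invocation; script name, dataset id, and flags are illustrative.
accelerate launch train_unconditional.py \
  --dataset="huggan/flowers-102-categories" \
  --resolution=64 \
  --output_dir="ddpm-flowers-64" \
  --train_batch_size=16 \
  --num_epochs=100 \
  --learning_rate=1e-4 \
  --use_ema
```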
@@ -1,201 +0,0 @@
import argparse
import os

import torch
import torch.nn.functional as F

import bitsandbytes as bnb
import PIL.Image
from accelerate import Accelerator
from datasets import load_dataset
from diffusers import DDPMScheduler, Glide, GlideUNetModel
from diffusers.hub_utils import init_git_repo, push_to_hub
from diffusers.optimization import get_scheduler
from diffusers.utils import logging
from torchvision.transforms import (
    CenterCrop,
    Compose,
    InterpolationMode,
    Normalize,
    RandomHorizontalFlip,
    Resize,
    ToTensor,
)
from tqdm.auto import tqdm


logger = logging.get_logger(__name__)


def main(args):
    accelerator = Accelerator(mixed_precision=args.mixed_precision)

    pipeline = Glide.from_pretrained("fusing/glide-base")
    model = pipeline.text_unet
    noise_scheduler = DDPMScheduler(timesteps=1000, tensor_format="pt")
    optimizer = bnb.optim.Adam8bit(model.parameters(), lr=args.lr)

    augmentations = Compose(
        [
            Resize(args.resolution, interpolation=InterpolationMode.BILINEAR),
            CenterCrop(args.resolution),
            RandomHorizontalFlip(),
            ToTensor(),
            Normalize([0.5], [0.5]),
        ]
    )
    dataset = load_dataset(args.dataset, split="train")

    text_encoder = pipeline.text_encoder.eval()

    def transforms(examples):
        images = [augmentations(image.convert("RGB")) for image in examples["image"]]
        text_inputs = pipeline.tokenizer(examples["caption"], padding="max_length", max_length=77, return_tensors="pt")
        text_inputs = text_inputs.input_ids.to(accelerator.device)
        with torch.no_grad():
            text_embeddings = accelerator.unwrap_model(text_encoder)(text_inputs).last_hidden_state
        return {"images": images, "text_embeddings": text_embeddings}

    dataset.set_transform(transforms)
    train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

    lr_scheduler = get_scheduler(
        "linear",
        optimizer=optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps,
    )

    model, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, text_encoder, optimizer, train_dataloader, lr_scheduler
    )

    if args.push_to_hub:
        repo = init_git_repo(args, at_init=True)

    # Train!
    is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
    world_size = torch.distributed.get_world_size() if is_distributed else 1
    total_train_batch_size = args.batch_size * args.gradient_accumulation_steps * world_size
    max_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_epochs
    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataloader.dataset)}")
    logger.info(f" Num Epochs = {args.num_epochs}")
    logger.info(f" Instantaneous batch size per device = {args.batch_size}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {max_steps}")

    for epoch in range(args.num_epochs):
        model.train()
        with tqdm(total=len(train_dataloader), unit="ba") as pbar:
            pbar.set_description(f"Epoch {epoch}")
            for step, batch in enumerate(train_dataloader):
                clean_images = batch["images"]
                batch_size, n_channels, height, width = clean_images.shape
                noise_samples = torch.randn(clean_images.shape).to(clean_images.device)
                timesteps = torch.randint(
                    0, noise_scheduler.timesteps, (batch_size,), device=clean_images.device
                ).long()

                # add noise onto the clean images according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_images = noise_scheduler.training_step(clean_images, noise_samples, timesteps)

                if step % args.gradient_accumulation_steps != 0:
                    with accelerator.no_sync(model):
                        model_output = model(noisy_images, timesteps, batch["text_embeddings"])
                        model_output, model_var_values = torch.split(model_output, n_channels, dim=1)
                        # Learn the variance using the variational bound, but don't let
                        # it affect our mean prediction.
                        frozen_out = torch.cat([model_output.detach(), model_var_values], dim=1)

                        # predict the noise residual
                        loss = F.mse_loss(model_output, noise_samples)
                        loss = loss / args.gradient_accumulation_steps
                        accelerator.backward(loss)
                        optimizer.step()
                else:
                    model_output = model(noisy_images, timesteps, batch["text_embeddings"])
                    model_output, model_var_values = torch.split(model_output, n_channels, dim=1)
                    # Learn the variance using the variational bound, but don't let
                    # it affect our mean prediction.
                    frozen_out = torch.cat([model_output.detach(), model_var_values], dim=1)

                    # predict the noise residual
                    loss = F.mse_loss(model_output, noise_samples)
                    loss = loss / args.gradient_accumulation_steps
                    accelerator.backward(loss)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad()
                pbar.update(1)
                pbar.set_postfix(loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"])

        accelerator.wait_for_everyone()

        # Generate a sample image for visual inspection
        if accelerator.is_main_process:
            model.eval()
            with torch.no_grad():
                pipeline.unet = accelerator.unwrap_model(model)

                generator = torch.manual_seed(0)
                # run pipeline in inference (sample random noise and denoise)
                image = pipeline("a clip art of a corgi", generator=generator, num_upscale_inference_steps=50)

                # process image to PIL
                image_processed = image.squeeze(0)
                image_processed = ((image_processed + 1) * 127.5).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
                image_pil = PIL.Image.fromarray(image_processed)

                # save image
                test_dir = os.path.join(args.output_dir, "test_samples")
                os.makedirs(test_dir, exist_ok=True)
                image_pil.save(f"{test_dir}/{epoch:04d}.png")

            # save the model
            if args.push_to_hub:
                push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
            else:
                pipeline.save_pretrained(args.output_dir)
        accelerator.wait_for_everyone()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument("--local_rank", type=int, default=-1)
    parser.add_argument("--dataset", type=str, default="fusing/dog_captions")
    parser.add_argument("--output_dir", type=str, default="glide-text2image")
    parser.add_argument("--overwrite_output_dir", action="store_true")
    parser.add_argument("--resolution", type=int, default=64)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--num_epochs", type=int, default=100)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--warmup_steps", type=int, default=500)
    parser.add_argument("--push_to_hub", action="store_true")
    parser.add_argument("--hub_token", type=str, default=None)
    parser.add_argument("--hub_model_id", type=str, default=None)
    parser.add_argument("--hub_private_repo", action="store_true")
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose "
            "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10 "
            "and an Nvidia Ampere GPU."
        ),
    )

    args = parser.parse_args()
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    main(args)
@@ -1,216 +0,0 @@
import argparse
import os

import torch
import torch.nn.functional as F

import bitsandbytes as bnb
import PIL.Image
from accelerate import Accelerator
from datasets import load_dataset
from diffusers import DDPMScheduler, LatentDiffusion, UNetLDMModel
from diffusers.hub_utils import init_git_repo, push_to_hub
from diffusers.optimization import get_scheduler
from diffusers.utils import logging
from torchvision.transforms import (
    CenterCrop,
    Compose,
    InterpolationMode,
    Normalize,
    RandomHorizontalFlip,
    Resize,
    ToTensor,
)
from tqdm.auto import tqdm


logger = logging.get_logger(__name__)


def main(args):
    accelerator = Accelerator(mixed_precision=args.mixed_precision)

    pipeline = LatentDiffusion.from_pretrained("fusing/latent-diffusion-text2im-large")
    pipeline.unet = None  # this model will be trained from scratch now
    model = UNetLDMModel(
        attention_resolutions=[4, 2, 1],
        channel_mult=[1, 2, 4, 4],
        context_dim=1280,
        conv_resample=True,
        dims=2,
        dropout=0,
        image_size=8,
        in_channels=4,
        model_channels=320,
        num_heads=8,
        num_res_blocks=2,
        out_channels=4,
        resblock_updown=False,
        transformer_depth=1,
        use_new_attention_order=False,
        use_scale_shift_norm=False,
        use_spatial_transformer=True,
        legacy=False,
    )
    noise_scheduler = DDPMScheduler(timesteps=1000, tensor_format="pt")
    optimizer = bnb.optim.Adam8bit(model.parameters(), lr=args.lr)

    augmentations = Compose(
        [
            Resize(args.resolution, interpolation=InterpolationMode.BILINEAR),
            CenterCrop(args.resolution),
            RandomHorizontalFlip(),
            ToTensor(),
            Normalize([0.5], [0.5]),
        ]
    )
    dataset = load_dataset(args.dataset, split="train")

    text_encoder = pipeline.bert.eval()
    vqvae = pipeline.vqvae.eval()

    def transforms(examples):
        images = [augmentations(image.convert("RGB")) for image in examples["image"]]
        text_inputs = pipeline.tokenizer(examples["caption"], padding="max_length", max_length=77, return_tensors="pt")
        with torch.no_grad():
            text_embeddings = accelerator.unwrap_model(text_encoder)(text_inputs.input_ids.cpu()).last_hidden_state
            images = 1 / 0.18215 * torch.stack(images, dim=0)
            latents = accelerator.unwrap_model(vqvae).encode(images.cpu()).mode()
        return {"images": images, "text_embeddings": text_embeddings, "latents": latents}

    dataset.set_transform(transforms)
    train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

    lr_scheduler = get_scheduler(
        "linear",
        optimizer=optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps,
    )

    model, text_encoder, vqvae, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, text_encoder, vqvae, optimizer, train_dataloader, lr_scheduler
    )
    text_encoder = text_encoder.cpu()
    vqvae = vqvae.cpu()

    if args.push_to_hub:
        repo = init_git_repo(args, at_init=True)

    # Train!
    is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
    world_size = torch.distributed.get_world_size() if is_distributed else 1
    total_train_batch_size = args.batch_size * args.gradient_accumulation_steps * world_size
    max_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_epochs
    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataloader.dataset)}")
    logger.info(f" Num Epochs = {args.num_epochs}")
    logger.info(f" Instantaneous batch size per device = {args.batch_size}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {max_steps}")

    global_step = 0
    for epoch in range(args.num_epochs):
        model.train()
        with tqdm(total=len(train_dataloader), unit="ba") as pbar:
            pbar.set_description(f"Epoch {epoch}")
            for step, batch in enumerate(train_dataloader):
                clean_latents = batch["latents"]
                noise_samples = torch.randn(clean_latents.shape).to(clean_latents.device)
                bsz = clean_latents.shape[0]
                timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_latents.device).long()

                # add noise onto the clean latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.training_step(clean_latents, noise_samples, timesteps)

                if step % args.gradient_accumulation_steps != 0:
                    with accelerator.no_sync(model):
                        output = model(noisy_latents, timesteps, context=batch["text_embeddings"])
                        # predict the noise residual
                        loss = F.mse_loss(output, noise_samples)
                        loss = loss / args.gradient_accumulation_steps
                        accelerator.backward(loss)
                        optimizer.step()
                else:
                    output = model(noisy_latents, timesteps, context=batch["text_embeddings"])
                    # predict the noise residual
                    loss = F.mse_loss(output, noise_samples)
                    loss = loss / args.gradient_accumulation_steps
                    accelerator.backward(loss)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad()
                pbar.update(1)
                pbar.set_postfix(loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"])
                global_step += 1

        accelerator.wait_for_everyone()

        # Generate a sample image for visual inspection
        if accelerator.is_main_process:
            model.eval()
            with torch.no_grad():
                pipeline.unet = accelerator.unwrap_model(model)

                generator = torch.manual_seed(0)
                # run pipeline in inference (sample random noise and denoise)
                image = pipeline(
                    ["a clip art of a corgi"], generator=generator, eta=0.3, guidance_scale=6.0, num_inference_steps=50
                )

                # process image to PIL
                image_processed = image.cpu().permute(0, 2, 3, 1)
                image_processed = image_processed * 255.0
                image_processed = image_processed.type(torch.uint8).numpy()
                image_pil = PIL.Image.fromarray(image_processed[0])

                # save image
                test_dir = os.path.join(args.output_dir, "test_samples")
                os.makedirs(test_dir, exist_ok=True)
                image_pil.save(f"{test_dir}/{epoch:04d}.png")

            # save the model
            if args.push_to_hub:
                push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
            else:
                pipeline.save_pretrained(args.output_dir)
        accelerator.wait_for_everyone()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument("--local_rank", type=int, default=-1)
    parser.add_argument("--dataset", type=str, default="fusing/dog_captions")
    parser.add_argument("--output_dir", type=str, default="ldm-text2image")
    parser.add_argument("--overwrite_output_dir", action="store_true")
    parser.add_argument("--resolution", type=int, default=128)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--num_epochs", type=int, default=100)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=16)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--warmup_steps", type=int, default=500)
    parser.add_argument("--push_to_hub", action="store_true")
    parser.add_argument("--hub_token", type=str, default=None)
    parser.add_argument("--hub_model_id", type=str, default=None)
    parser.add_argument("--hub_private_repo", action="store_true")
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose "
            "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10 "
            "and an Nvidia Ampere GPU."
        ),
    )

    args = parser.parse_args()
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    main(args)
@@ -7,7 +7,7 @@ import torch.nn.functional as F
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from datasets import load_dataset
-from diffusers import DDPMPipeline, DDPMScheduler, UNetUnconditionalModel
+from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
 from diffusers.hub_utils import init_git_repo, push_to_hub
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import EMAModel
@@ -34,27 +34,27 @@ def main(args):
         logging_dir=logging_dir,
     )
 
-    model = UNetUnconditionalModel(
-        image_size=args.resolution,
+    model = UNet2DModel(
+        sample_size=args.resolution,
         in_channels=3,
         out_channels=3,
-        num_res_blocks=2,
-        block_channels=(128, 128, 256, 256, 512, 512),
-        down_blocks=(
-            "UNetResDownBlock2D",
-            "UNetResDownBlock2D",
-            "UNetResDownBlock2D",
-            "UNetResDownBlock2D",
-            "UNetResAttnDownBlock2D",
-            "UNetResDownBlock2D",
+        layers_per_block=2,
+        block_out_channels=(128, 128, 256, 256, 512, 512),
+        down_block_types=(
+            "DownBlock2D",
+            "DownBlock2D",
+            "DownBlock2D",
+            "DownBlock2D",
+            "AttnDownBlock2D",
+            "DownBlock2D",
         ),
-        up_blocks=(
-            "UNetResUpBlock2D",
-            "UNetResAttnUpBlock2D",
-            "UNetResUpBlock2D",
-            "UNetResUpBlock2D",
-            "UNetResUpBlock2D",
-            "UNetResUpBlock2D",
+        up_block_types=(
+            "UpBlock2D",
+            "AttnUpBlock2D",
+            "UpBlock2D",
+            "UpBlock2D",
+            "UpBlock2D",
+            "UpBlock2D",
         ),
     )
     noise_scheduler = DDPMScheduler(num_train_timesteps=1000, tensor_format="pt")
@@ -147,9 +147,9 @@ def main(args):
 
         accelerator.wait_for_everyone()
 
-        # Generate a sample image for visual inspection
+        # Generate sample images for visual inspection
         if accelerator.is_main_process:
-            with torch.no_grad():
+            if epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1:
                 pipeline = DDPMPipeline(
                     unet=accelerator.unwrap_model(ema_model.averaged_model if args.use_ema else model),
                     scheduler=noise_scheduler,
@@ -157,13 +157,13 @@ def main(args):
 
                 generator = torch.manual_seed(0)
                 # run pipeline in inference (sample random noise and denoise)
-                images = pipeline(generator=generator, batch_size=args.eval_batch_size)
+                images = pipeline(generator=generator, batch_size=args.eval_batch_size, output_type="numpy")["sample"]
 
-                # denormalize the images and save to tensorboard
-                images_processed = (images.cpu() + 1.0) * 127.5
-                images_processed = images_processed.clamp(0, 255).type(torch.uint8).numpy()
-
-                accelerator.trackers[0].writer.add_images("test_samples", images_processed, epoch)
+                # denormalize the images and save to tensorboard
+                images_processed = (images * 255).round().astype("uint8")
+                accelerator.trackers[0].writer.add_images(
+                    "test_samples", images_processed.transpose(0, 3, 1, 2), epoch
+                )
 
             if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
                 # save the model
@@ -186,7 +186,8 @@ if __name__ == "__main__":
     parser.add_argument("--train_batch_size", type=int, default=16)
     parser.add_argument("--eval_batch_size", type=int, default=16)
     parser.add_argument("--num_epochs", type=int, default=100)
-    parser.add_argument("--save_model_epochs", type=int, default=5)
+    parser.add_argument("--save_images_epochs", type=int, default=10)
+    parser.add_argument("--save_model_epochs", type=int, default=10)
     parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
     parser.add_argument("--learning_rate", type=float, default=1e-4)
     parser.add_argument("--lr_scheduler", type=str, default="cosine")
@@ -194,7 +195,7 @@ if __name__ == "__main__":
     parser.add_argument("--adam_beta1", type=float, default=0.95)
     parser.add_argument("--adam_beta2", type=float, default=0.999)
     parser.add_argument("--adam_weight_decay", type=float, default=1e-6)
-    parser.add_argument("--adam_epsilon", type=float, default=1e-3)
+    parser.add_argument("--adam_epsilon", type=float, default=1e-08)
     parser.add_argument("--use_ema", action="store_true", default=True)
     parser.add_argument("--ema_inv_gamma", type=float, default=1.0)
     parser.add_argument("--ema_power", type=float, default=3 / 4)
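The `--use_ema`, `--ema_inv_gamma`, and `--ema_power` flags above drive the `EMAModel` helper imported at the top of this file. A minimal sketch of how such a wrapper is typically threaded through a training loop; the `EMAModel` constructor and `step()` signatures are assumptions about the 0.1.x API rather than facts taken from this diff:

```python
# Minimal sketch of keeping an exponential moving average of model weights.
# EMAModel's constructor and step() signatures are assumptions, not verified API.
import torch
from diffusers.training_utils import EMAModel

model = torch.nn.Linear(4, 4)  # stand-in for the UNet
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
ema_model = EMAModel(model, inv_gamma=1.0, power=3 / 4)  # mirrors the flag defaults above

for _ in range(10):
    loss = model(torch.randn(8, 4)).pow(2).mean()  # stand-in training objective
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    ema_model.step(model)  # fold the fresh weights into the running average
```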
setup.py (4 lines changed)
@@ -81,6 +81,7 @@ _deps = [
     "filelock",
     "flake8>=3.8.3",
     "huggingface-hub",
+    "importlib_metadata",
     "isort>=5.5.4",
     "numpy",
     "pytest",
@@ -168,6 +169,7 @@ extras["test"] = [
 extras["dev"] = extras["quality"] + extras["test"] + extras["training"]
 
 install_requires = [
+    deps["importlib_metadata"],
     deps["filelock"],
     deps["huggingface-hub"],
     deps["numpy"],
@@ -179,7 +181,7 @@ install_requires = [
 
 setup(
     name="diffusers",
-    version="0.1.0",
+    version="0.1.1",
     description="Diffusers",
     long_description=open("README.md", "r", encoding="utf-8").read(),
     long_description_content_type="text/markdown",
@@ -4,7 +4,7 @@
 from .utils import is_inflect_available, is_transformers_available, is_unidecode_available
 
 
-__version__ = "0.1.0"
+__version__ = "0.1.1"
 
 from .modeling_utils import ModelMixin
 from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
 
@@ -7,6 +7,7 @@ deps = {
     "filelock": "filelock",
     "flake8": "flake8>=3.8.3",
     "huggingface-hub": "huggingface-hub",
+    "importlib_metadata": "importlib_metadata",
     "isort": "isort>=5.5.4",
     "numpy": "numpy",
     "pytest": "pytest",
 
@@ -21,14 +21,13 @@ from typing import Optional
 
 from diffusers import DiffusionPipeline
 from huggingface_hub import HfFolder, Repository, whoami
-from utils import is_modelcards_available
+
+from .utils import is_modelcards_available, logging
 
 
 if is_modelcards_available():
     from modelcards import CardData, ModelCard
 
-from .utils import logging
-
 
 logger = logging.get_logger(__name__)
 
@@ -6,31 +6,33 @@ from tqdm.auto import tqdm
 
 
 class ScoreSdeVePipeline(DiffusionPipeline):
-    def __init__(self, model, scheduler):
+    def __init__(self, unet, scheduler):
         super().__init__()
-        self.register_modules(model=model, scheduler=scheduler)
+        self.register_modules(unet=unet, scheduler=scheduler)
 
     @torch.no_grad()
-    def __call__(self, num_inference_steps=2000, generator=None, output_type="pil"):
-        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, torch_device=None, output_type="pil"):
 
-        img_size = self.model.config.sample_size
-        shape = (1, 3, img_size, img_size)
+        if torch_device is None:
+            torch_device = "cuda" if torch.cuda.is_available() else "cpu"
 
-        model = self.model.to(device)
+        img_size = self.unet.config.sample_size
+        shape = (batch_size, 3, img_size, img_size)
+
+        model = self.unet.to(torch_device)
 
         sample = torch.randn(*shape) * self.scheduler.config.sigma_max
-        sample = sample.to(device)
+        sample = sample.to(torch_device)
 
         self.scheduler.set_timesteps(num_inference_steps)
         self.scheduler.set_sigmas(num_inference_steps)
 
         for i, t in tqdm(enumerate(self.scheduler.timesteps)):
-            sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=device)
+            sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=torch_device)
 
             # correction step
             for _ in range(self.scheduler.correct_steps):
-                model_output = self.model(sample, sigma_t)["sample"]
+                model_output = self.unet(sample, sigma_t)["sample"]
                 sample = self.scheduler.step_correct(model_output, sample)["prev_sample"]
 
             # prediction step
@@ -39,7 +41,7 @@ class ScoreSdeVePipeline(DiffusionPipeline):
 
         sample, sample_mean = output["prev_sample"], output["prev_sample_mean"]
 
-        sample = sample.clamp(0, 1)
+        sample = sample_mean.clamp(0, 1)
         sample = sample.cpu().permute(0, 2, 3, 1).numpy()
         if output_type == "pil":
             sample = self.numpy_to_pil(sample)
 
@@ -15,4 +15,4 @@ with a `set_format(...)` method.
 
 - The DDPM scheduler was proposed in [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) and can be found in [scheduling_ddpm.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddpm.py). An example of how to use this scheduler can be found in [pipeline_ddpm.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_ddpm.py).
 - The DDIM scheduler was proposed in [Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502) and can be found in [scheduling_ddim.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddim.py). An example of how to use this scheduler can be found in [pipeline_ddim.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_ddim.py).
-- The PNMD scheduler was proposed in [Pseudo Numerical Methods for Diffusion Models on Manifolds](https://arxiv.org/abs/2202.09778) and can be found in [scheduling_pndm.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_pndm.py). An example of how to use this scheduler can be found in [pipeline_pndm.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py).
+- The PNDM scheduler was proposed in [Pseudo Numerical Methods for Diffusion Models on Manifolds](https://arxiv.org/abs/2202.09778) and can be found in [scheduling_pndm.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_pndm.py). An example of how to use this scheduler can be found in [pipeline_pndm.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py).
@@ -848,15 +848,12 @@ class PipelineTesterMixin(unittest.TestCase):
 
     @slow
     def test_score_sde_ve_pipeline(self):
-        model = UNet2DModel.from_pretrained("google/ncsnpp-church-256")
+        model_id = "google/ncsnpp-church-256"
+        model = UNet2DModel.from_pretrained(model_id)
 
-        torch.manual_seed(0)
-        if torch.cuda.is_available():
-            torch.cuda.manual_seed_all(0)
+        scheduler = ScoreSdeVeScheduler.from_config(model_id)
 
-        scheduler = ScoreSdeVeScheduler.from_config("google/ncsnpp-church-256")
-
-        sde_ve = ScoreSdeVePipeline(model=model, scheduler=scheduler)
+        sde_ve = ScoreSdeVePipeline(unet=model, scheduler=scheduler)
 
         torch.manual_seed(0)
         image = sde_ve(num_inference_steps=300, output_type="numpy")["sample"]