Compare commits

..

2 Commits

Author SHA1 Message Date
Dhruv Nair
858dfd6411 update 2023-12-06 12:14:35 +00:00
Dhruv Nair
6cb2178a91 Revert "fix"
This reverts commit f90a5139a2.
2023-12-06 06:44:02 +00:00
8 changed files with 40 additions and 67 deletions

View File

@@ -1,6 +1,6 @@
# Latent Consistency Distillation Example:
[Latent Consistency Models (LCMs)](https://arxiv.org/abs/2310.04378) is a method to distill a latent diffusion model to enable swift inference with minimal steps. This example demonstrates how to use latent consistency distillation to distill stable-diffusion-v1.5 for inference with few timesteps.
[Latent Consistency Models (LCMs)](https://arxiv.org/abs/2310.04378) is method to distill latent diffusion model to enable swift inference with minimal steps. This example demonstrates how to use the latent consistency distillation to distill stable-diffusion-v1.5 for less timestep inference.
## Full model distillation
@@ -24,7 +24,7 @@ Then cd in the example folder and run
pip install -r requirements.txt
```
And initialize an [🤗 Accelerate](https://github.com/huggingface/accelerate/) environment with:
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
```bash
accelerate config
@@ -46,16 +46,12 @@ write_basic_config()
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
#### Example
The following uses the [Conceptual Captions 12M (CC12M) dataset](https://github.com/google-research-datasets/conceptual-12m) as an example, and for illustrative purposes only. For best results you may consider large and high-quality text-image datasets such as [LAION](https://laion.ai/blog/laion-400-open-dataset/). You may also need to search the hyperparameter space according to the dataset you use.
#### Example with LAION-A6+ dataset
```bash
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export OUTPUT_DIR="path/to/saved/model"
accelerate launch train_lcm_distill_sd_wds.py \
--pretrained_teacher_model=$MODEL_NAME \
runwayml/stable-diffusion-v1-5
PROGRAM="train_lcm_distill_sd_wds.py \
--pretrained_teacher_model=$MODEL_DIR \
--output_dir=$OUTPUT_DIR \
--mixed_precision=fp16 \
--resolution=512 \
@@ -63,7 +59,7 @@ accelerate launch train_lcm_distill_sd_wds.py \
--max_train_steps=1000 \
--max_train_samples=4000000 \
--dataloader_num_workers=8 \
--train_shards_path_or_url="pipe:curl -L -s https://huggingface.co/datasets/laion/conceptual-captions-12m-webdataset/resolve/main/data/{00000..01099}.tar?download=true" \
--train_shards_path_or_url='pipe:aws s3 cp s3://muse-datasets/laion-aesthetic6plus-min512-data/{00000..01210}.tar -' \
--validation_steps=200 \
--checkpointing_steps=200 --checkpoints_total_limit=10 \
--train_batch_size=12 \
@@ -73,23 +69,19 @@ accelerate launch train_lcm_distill_sd_wds.py \
--resume_from_checkpoint=latest \
--report_to=wandb \
--seed=453645634 \
--push_to_hub
--push_to_hub \
```
## LCM-LoRA
Instead of fine-tuning the full model, we can also just train a LoRA that can be injected into any SDXL model.
### Example
The following uses the [Conceptual Captions 12M (CC12M) dataset](https://github.com/google-research-datasets/conceptual-12m) as an example. For best results you may consider large and high-quality text-image datasets such as [LAION](https://laion.ai/blog/laion-400-open-dataset/).
### Example with LAION-A6+ dataset
```bash
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export OUTPUT_DIR="path/to/saved/model"
accelerate launch train_lcm_distill_lora_sd_wds.py \
--pretrained_teacher_model=$MODEL_NAME \
runwayml/stable-diffusion-v1-5
PROGRAM="train_lcm_distill_lora_sd_wds.py \
--pretrained_teacher_model=$MODEL_DIR \
--output_dir=$OUTPUT_DIR \
--mixed_precision=fp16 \
--resolution=512 \
@@ -98,7 +90,7 @@ accelerate launch train_lcm_distill_lora_sd_wds.py \
--max_train_steps=1000 \
--max_train_samples=4000000 \
--dataloader_num_workers=8 \
--train_shards_path_or_url="pipe:curl -L -s https://huggingface.co/datasets/laion/conceptual-captions-12m-webdataset/resolve/main/data/{00000..01099}.tar?download=true" \
--train_shards_path_or_url='pipe:aws s3 cp s3://muse-datasets/laion-aesthetic6plus-min512-data/{00000..01210}.tar -' \
--validation_steps=200 \
--checkpointing_steps=200 --checkpoints_total_limit=10 \
--train_batch_size=12 \

View File

@@ -1,6 +1,6 @@
# Latent Consistency Distillation Example:
[Latent Consistency Models (LCMs)](https://arxiv.org/abs/2310.04378) is a method to distill a latent diffusion model to enable swift inference with minimal steps. This example demonstrates how to use latent consistency distillation to distill SDXL for inference with few timesteps.
[Latent Consistency Models (LCMs)](https://arxiv.org/abs/2310.04378) is method to distill latent diffusion model to enable swift inference with minimal steps. This example demonstrates how to use the latent consistency distillation to distill SDXL for less timestep inference.
## Full model distillation
@@ -24,7 +24,7 @@ Then cd in the example folder and run
pip install -r requirements.txt
```
And initialize an [🤗 Accelerate](https://github.com/huggingface/accelerate/) environment with:
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
```bash
accelerate config
@@ -46,16 +46,12 @@ write_basic_config()
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
#### Example
The following uses the [Conceptual Captions 12M (CC12M) dataset](https://github.com/google-research-datasets/conceptual-12m) as an example, and for illustrative purposes only. For best results you may consider large and high-quality text-image datasets such as [LAION](https://laion.ai/blog/laion-400-open-dataset/). You may also need to search the hyperparameter space according to the dataset you use.
#### Example with LAION-A6+ dataset
```bash
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
export OUTPUT_DIR="path/to/saved/model"
accelerate launch train_lcm_distill_sdxl_wds.py \
--pretrained_teacher_model=$MODEL_NAME \
export MODEL_DIR="stabilityai/stable-diffusion-xl-base-1.0"
PROGRAM="train_lcm_distill_sdxl_wds.py \
--pretrained_teacher_model=$MODEL_DIR \
--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
--output_dir=$OUTPUT_DIR \
--mixed_precision=fp16 \
@@ -64,7 +60,7 @@ accelerate launch train_lcm_distill_sdxl_wds.py \
--max_train_steps=1000 \
--max_train_samples=4000000 \
--dataloader_num_workers=8 \
--train_shards_path_or_url="pipe:curl -L -s https://huggingface.co/datasets/laion/conceptual-captions-12m-webdataset/resolve/main/data/{00000..01099}.tar?download=true" \
--train_shards_path_or_url='pipe:aws s3 cp s3://muse-datasets/laion-aesthetic6plus-min512-data/{00000..01210}.tar -' \
--validation_steps=200 \
--checkpointing_steps=200 --checkpoints_total_limit=10 \
--train_batch_size=12 \
@@ -81,15 +77,11 @@ accelerate launch train_lcm_distill_sdxl_wds.py \
Instead of fine-tuning the full model, we can also just train a LoRA that can be injected into any SDXL model.
### Example
The following uses the [Conceptual Captions 12M (CC12M) dataset](https://github.com/google-research-datasets/conceptual-12m) as an example. For best results you may consider large and high-quality text-image datasets such as [LAION](https://laion.ai/blog/laion-400-open-dataset/).
### Example with LAION-A6+ dataset
```bash
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
export OUTPUT_DIR="path/to/saved/model"
accelerate launch train_lcm_distill_lora_sdxl_wds.py \
export MODEL_DIR="stabilityai/stable-diffusion-xl-base-1.0"
PROGRAM="train_lcm_distill_lora_sdxl_wds.py \
--pretrained_teacher_model=$MODEL_DIR \
--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
--output_dir=$OUTPUT_DIR \
@@ -100,7 +92,7 @@ accelerate launch train_lcm_distill_lora_sdxl_wds.py \
--max_train_steps=1000 \
--max_train_samples=4000000 \
--dataloader_num_workers=8 \
--train_shards_path_or_url="pipe:curl -L -s https://huggingface.co/datasets/laion/conceptual-captions-12m-webdataset/resolve/main/data/{00000..01099}.tar?download=true" \
--train_shards_path_or_url='pipe:aws s3 cp s3://muse-datasets/laion-aesthetic6plus-min512-data/{00000..01210}.tar -' \
--validation_steps=200 \
--checkpointing_steps=200 --checkpoints_total_limit=10 \
--train_batch_size=12 \

View File

@@ -1123,7 +1123,7 @@ def main(args):
for epoch in range(first_epoch, args.num_train_epochs):
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(unet):
image, text = batch
image, text, _, _ = batch
image = image.to(accelerator.device, non_blocking=True)
encoded_text = compute_embeddings_fn(text)

View File

@@ -68,11 +68,6 @@ from diffusers.utils.import_utils import is_xformers_available
MAX_SEQ_LENGTH = 77
# Adjust for your dataset
WDS_JSON_WIDTH = "width" # original_width for LAION
WDS_JSON_HEIGHT = "height" # original_height for LAION
MIN_SIZE = 700 # ~960 for LAION, ideal: 1024 if the dataset contains large images
if is_wandb_available():
import wandb
@@ -151,10 +146,10 @@ class WebdatasetFilter:
try:
if "json" in x:
x_json = json.loads(x["json"])
filter_size = (x_json.get(WDS_JSON_WIDTH, 0.0) or 0.0) >= self.min_size and x_json.get(
WDS_JSON_HEIGHT, 0
filter_size = (x_json.get("original_width", 0.0) or 0.0) >= self.min_size and x_json.get(
"original_height", 0
) >= self.min_size
filter_watermark = (x_json.get("pwatermark", 0.0) or 0.0) <= self.max_pwatermark
filter_watermark = (x_json.get("pwatermark", 1.0) or 1.0) <= self.max_pwatermark
return filter_size and filter_watermark
else:
return False
@@ -185,7 +180,7 @@ class Text2ImageDataset:
if use_fix_crop_and_size:
return (resolution, resolution)
else:
return (int(json.get(WDS_JSON_WIDTH, 0.0)), int(json.get(WDS_JSON_HEIGHT, 0.0)))
return (int(json.get("original_width", 0.0)), int(json.get("original_height", 0.0)))
def transform(example):
# resize image
@@ -217,7 +212,7 @@ class Text2ImageDataset:
pipeline = [
wds.ResampledShards(train_shards_path_or_url),
tarfile_to_samples_nothrow,
wds.select(WebdatasetFilter(min_size=MIN_SIZE)),
wds.select(WebdatasetFilter(min_size=960)),
wds.shuffle(shuffle_buffer_size),
*processing_pipeline,
wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),

View File

@@ -1106,7 +1106,7 @@ def main(args):
for epoch in range(first_epoch, args.num_train_epochs):
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(unet):
image, text = batch
image, text, _, _ = batch
image = image.to(accelerator.device, non_blocking=True)
encoded_text = compute_embeddings_fn(text)

View File

@@ -67,11 +67,6 @@ from diffusers.utils.import_utils import is_xformers_available
MAX_SEQ_LENGTH = 77
# Adjust for your dataset
WDS_JSON_WIDTH = "width" # original_width for LAION
WDS_JSON_HEIGHT = "height" # original_height for LAION
MIN_SIZE = 700 # ~960 for LAION, ideal: 1024 if the dataset contains large images
if is_wandb_available():
import wandb
@@ -133,10 +128,10 @@ class WebdatasetFilter:
try:
if "json" in x:
x_json = json.loads(x["json"])
filter_size = (x_json.get(WDS_JSON_WIDTH, 0.0) or 0.0) >= self.min_size and x_json.get(
WDS_JSON_HEIGHT, 0
filter_size = (x_json.get("original_width", 0.0) or 0.0) >= self.min_size and x_json.get(
"original_height", 0
) >= self.min_size
filter_watermark = (x_json.get("pwatermark", 0.0) or 0.0) <= self.max_pwatermark
filter_watermark = (x_json.get("pwatermark", 1.0) or 1.0) <= self.max_pwatermark
return filter_size and filter_watermark
else:
return False
@@ -167,7 +162,7 @@ class Text2ImageDataset:
if use_fix_crop_and_size:
return (resolution, resolution)
else:
return (int(json.get(WDS_JSON_WIDTH, 0.0)), int(json.get(WDS_JSON_HEIGHT, 0.0)))
return (int(json.get("original_width", 0.0)), int(json.get("original_height", 0.0)))
def transform(example):
# resize image
@@ -199,7 +194,7 @@ class Text2ImageDataset:
pipeline = [
wds.ResampledShards(train_shards_path_or_url),
tarfile_to_samples_nothrow,
wds.select(WebdatasetFilter(min_size=MIN_SIZE)),
wds.select(WebdatasetFilter(min_size=960)),
wds.shuffle(shuffle_buffer_size),
*processing_pipeline,
wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),

View File

@@ -446,9 +446,8 @@ def convert_ldm_unet_checkpoint(
new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
# Relevant to StableDiffusionUpscalePipeline
if "num_class_embeds" in config:
if (config["num_class_embeds"] is not None) and ("label_emb.weight" in unet_state_dict):
new_checkpoint["class_embedding.weight"] = unet_state_dict["label_emb.weight"]
if (config["num_class_embeds"] is not None) and ("label_emb.weight" in unet_state_dict):
new_checkpoint["class_embedding.weight"] = unet_state_dict["label_emb.weight"]
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]

View File