Compare commits

...

89 Commits

Author SHA1 Message Date
Dhruv Nair
9c7e2f6721 Merge branch 'pipeline-interrupt' of https://github.com/huggingface/diffusers into pipeline-interrupt 2023-12-26 05:58:58 +00:00
Dhruv Nair
59a1524aad update 2023-12-26 05:58:52 +00:00
Sayak Paul
e23f6051a1 Merge branch 'main' into pipeline-interrupt 2023-12-26 08:56:18 +05:30
dg845
a3d31e3a3e Change LCM-LoRA README Script Example Learning Rates to 1e-4 (#6304)
Change README LCM-LoRA example learning rates to 1e-4.
2023-12-25 21:29:20 +05:30
Jianqi Pan
84c403aedb fix: cannot set guidance_scale (#6326)
fix: set guidance_scale
2023-12-25 21:16:57 +05:30
Sayak Paul
f4b0b26f7e [Tests] Speed up example tests (#6319)
* remove validation args from textual onverson tests

* reduce number of train steps in textual inversion tests

* fix: directories.

* debig

* fix: directories.

* remove validation tests from textual onversion

* try reducing the time of test_text_to_image_checkpointing_use_ema

* fix: directories

* speed up test_text_to_image_checkpointing

* speed up test_text_to_image_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints

* fix

* speed up test_instruct_pix2pix_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints

* set checkpoints_total_limit to 2.

* test_text_to_image_lora_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints speed up

* speed up test_unconditional_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints

* debug

* fix: directories.

* speed up test_instruct_pix2pix_checkpointing_checkpoints_total_limit

* speed up: test_controlnet_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints

* speed up test_controlnet_sdxl

* speed up dreambooth tests

* speed up test_dreambooth_lora_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints

* speed up test_custom_diffusion_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints

* speed up test_text_to_image_lora_sdxl_text_encoder_checkpointing_checkpoints_total_limit

* speed up # checkpoint-2 should have been deleted

* speed up examples/text_to_image/test_text_to_image.py::TextToImage::test_text_to_image_checkpointing_checkpoints_total_limit

* additional speed ups

* style
2023-12-25 19:50:48 +05:30
Dhruv Nair
6648af6d83 fix 2023-12-25 13:12:12 +00:00
Sayak Paul
a894d9f921 Merge branch 'main' into pipeline-interrupt 2023-12-25 12:02:07 +05:30
Sayak Paul
89459a5d56 fix: lora peft dummy components (#6308)
* fix: lora peft dummy components

* fix: dummy components
2023-12-25 11:26:45 +05:30
Sayak Paul
008d9818a2 fix: t2i apdater paper link (#6314) 2023-12-25 10:45:14 +05:30
mwkldeveloper
2d43094ffc fix RuntimeError: Input type (float) and bias type (c10::Half) should be the same in train_text_to_image_lora.py (#6259)
* fix RuntimeError: Input type (float) and bias type (c10::Half) should be the same

* format source code

* format code

* remove the autocast blocks within the pipeline

* add autocast blocks to pipeline caller in train_text_to_image_lora.py
2023-12-24 14:34:35 +05:30
Celestial Phineas
7c05b975b7 Fix typos in the ValueError for a nested image list as StableDiffusionControlNetPipeline input. (#6286)
Fixed typos in the `ValueError` for a nested image list as input.
2023-12-24 14:32:24 +05:30
Dhruv Nair
fe574c8b29 LoRA Unfusion test fix (#6291)
update

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2023-12-24 14:31:48 +05:30
Sayak Paul
90b9479903 [LoRA PEFT] fix LoRA loading so that correct alphas are parsed (#6225)
* initialize alpha too.

* add: test

* remove config parsing

* store rank

* debug

* remove faulty test
2023-12-24 09:59:41 +05:30
apolinário
df76a39e1b Fix Prodigy optimizer in SDXL Dreambooth script (#6290)
* Fix ProdigyOPT in SDXL Dreambooth script

* style

* style
2023-12-22 06:42:04 -06:00
Bingxin Ke
3369bc810a [Community Pipeline] Add Marigold Monocular Depth Estimation (#6249)
* [Community Pipeline] Add Marigold Monocular Depth Estimation

- add single-file pipeline
- update README

* fix format - add one blank line

* format script with ruff

* use direct image link in example code

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2023-12-22 15:41:46 +05:30
Pedro Cuenca
7fe47596af Allow diffusers to load with Flax, w/o PyTorch (#6272) 2023-12-22 09:37:30 +01:00
Dhruv Nair
59d1caa238 Remove peft tests from old lora backend tests (#6273)
update
2023-12-22 13:35:52 +05:30
Dhruv Nair
c022e52923 Remove ONNX inpaint legacy (#6269)
update

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2023-12-22 13:35:21 +05:30
Will Berman
4039815276 open muse (#5437)
amused

rename

Update docs/source/en/api/pipelines/amused.md

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

AdaLayerNormContinuous default values

custom micro conditioning

micro conditioning docs

put lookup from codebook in constructor

fix conversion script

remove manual fused flash attn kernel

add training script

temp remove training script

add dummy gradient checkpointing func

clarify temperatures is an instance variable by setting it

remove additional SkipFF block args

hardcode norm args

rename tests folder

fix paths and samples

fix tests

add training script

training readme

lora saving and loading

non-lora saving/loading

some readme fixes

guards

Update docs/source/en/api/pipelines/amused.md

Co-authored-by: Suraj Patil <surajp815@gmail.com>

Update examples/amused/README.md

Co-authored-by: Suraj Patil <surajp815@gmail.com>

Update examples/amused/train_amused.py

Co-authored-by: Suraj Patil <surajp815@gmail.com>

vae upcasting

add fp16 integration tests

use tuple for micro cond

copyrights

remove casts

delegate to torch.nn.LayerNorm

move temperature to pipeline call

upsampling/downsampling changes
2023-12-21 11:40:55 -08:00
Sayak Paul
5b186b7128 [Refactor] move ldm3d out of stable_diffusion. (#6263)
ldm3d.
2023-12-21 18:59:55 +05:30
Sayak Paul
ab0459f2b7 [Deprecated pipelines] remove pix2pix zero from init (#6268)
remove pix2pix zero from init
2023-12-21 18:17:28 +05:30
Sayak Paul
9c7cc36011 [Refactor] move panorama out of stable_diffusion (#6262)
* move panorama out.

* fix: diffedit

* fix: import.

* fix: impirt
2023-12-21 18:17:05 +05:30
Sayak Paul
325f6c53ed [Refactor] move attend and excite out of stable_diffusion. (#6261)
* move attend and excite out.

* fix: import

* fix diffedit
2023-12-21 16:49:32 +05:30
Benjamin Bossan
43979c2890 TST Fix LoRA test that fails with PEFT >= 0.7.0 (#6216)
See #6185 for context.

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2023-12-21 11:50:05 +01:00
Sayak Paul
9ea6ac1b07 [Refactor] move sag out of stable_diffusion (#6264)
move sag out of .
2023-12-21 16:09:49 +05:30
Sayak Paul
2c34c7d6dd [Refactor] move gligen out of stable diffusion. (#6265)
* move gligen out of stable diffusion.

* fix: import

* fix import module
2023-12-21 15:26:52 +05:30
Sayak Paul
bffadde126 [Refactor] move k diffusion out of stable_diffusion (#6267)
move k diffusion out of stable_diffusion
2023-12-21 15:24:24 +05:30
YShow
35a969d297 [Training] remove depcreated method from lora scripts again (#6266)
* remove depcreated method from lora scripts

* check code quality
2023-12-21 14:17:52 +05:30
sayakpaul
c5ff469d0e Revert "move attend and excite out of stable_diffusion"
This reverts commit bcecfbc873.
2023-12-21 12:35:58 +05:30
sayakpaul
bcecfbc873 move attend and excite out of stable_diffusion 2023-12-21 12:35:09 +05:30
Sayak Paul
6269045c5b [Refactor] move diffedit out of stable_diffusion (#6260)
* move diffedit out of stable_diffuson

* fix: import

* style

* fix: import
2023-12-21 12:26:36 +05:30
lvzi
6ca9c4af05 fix: unscale fp16 gradient problem & potential error (#6086) (#6231)
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2023-12-21 09:09:26 +05:30
dependabot[bot]
0532cece97 Bump transformers from 4.34.0 to 4.36.0 in /examples/research_projects/realfill (#6255)
Bump transformers in /examples/research_projects/realfill

Bumps [transformers](https://github.com/huggingface/transformers) from 4.34.0 to 4.36.0.
- [Release notes](https://github.com/huggingface/transformers/releases)
- [Commits](https://github.com/huggingface/transformers/compare/v4.34.0...v4.36.0)

---
updated-dependencies:
- dependency-name: transformers
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-12-21 09:03:17 +05:30
Sayak Paul
22b45304bf [Refactor upsamplers and downsamplers] separate out upsamplers and downsamplers. (#6128)
* separate out upsamplers and downsamplers.

* import all the necessary blocks in resnet for backward comp.

* move upsample2d and downsample2d to utils.

* move downsample_2d to downsamplers.py

* apply feedback

* fix import

* samplers -> sampling
2023-12-20 21:01:33 +05:30
Beinsezii
457abdf2cf EulerAncestral add rescale_betas_zero_snr (#6187)
* EulerAncestral add `rescale_betas_zero_snr`

Uses same infinite sigma fix from EulerDiscrete. Interestingly the
ancestral version had the opposite problem: too much contrast instead of
too little.

* UT for EulerAncestral `rescale_betas_zero_snr`

* EulerAncestral upcast samples during step()

It helps this scheduler too, particularly when the model is using bf16.

While the noise dtype is still the model's it's automatically upcasted
for the add so all it affects is determinism.

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2023-12-20 13:09:25 +05:30
hako-mikan
ff43dba7ea [Fix] Fix Regional Prompting Pipeline (#6188)
* Update regional_prompting_stable_diffusion.py

* reformat

* reformat

* reformat

* reformat

* reformat

* reformat

* reformat

* regormat

* reformat

* reformat

* reformat

* reformat

* Update regional_prompting_stable_diffusion.py

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2023-12-20 10:37:19 +05:30
Sayak Paul
5fedea920f Merge branch 'main' into pipeline-interrupt 2023-12-20 10:05:00 +05:30
Steven Liu
5433962992 [docs] Batched seeds (#6237)
batched seed
2023-12-19 16:50:18 -08:00
raven
df476d9f63 [Docs] Fix a code example in the ControlNet Inpainting documentation (#6236)
fix document on masked image in inpainting controlnet
2023-12-19 13:14:37 -08:00
YiYi Xu
3e71a20650 [refactor embeddings]pixart-alpha (#6212)
pixart-alpha

Co-authored-by: yiyixuxu <yixu310@gmail,com>
2023-12-19 07:07:24 -10:00
Sayak Paul
bf40d7d82a add peft dependency to fast push tests (#6229)
* add peft dependency

* add peft dependency at the correct place.
2023-12-19 13:26:25 +05:30
Dhruv Nair
32ff4773d4 ControlNetXS fixes. (#6228)
update
2023-12-19 11:58:34 +05:30
Sayak Paul
288ceebea5 [T2I LoRA training] fix: unscale fp16 gradient problem (#6119)
* fix: unscale fp16 gradient problem

* fix for dreambooth lora sdxl

* make the type-casting conditional.

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2023-12-19 09:54:17 +05:30
Sayak Paul
9221da4063 fix: init for vae during pixart tests (#6215)
* fix: init for vae during pixart tests

* print the values

* add flatten

* correct assertion value for test_inference

* correct assertion values for test_inference_non_square_images

* run styling

* debug test_inference_with_multiple_images_per_prompt

* fix assertion values for test_inference_with_multiple_images_per_prompt
2023-12-18 18:16:57 -10:00
Dhruv Nair
852024a34a Merge branch 'main' into pipeline-interrupt 2023-12-19 04:00:14 +00:00
YiYi Xu
57fde871e1 offload the optional module image_encoder (#6151)
* offload image_encoder

* add test

---------

Co-authored-by: yiyixuxu <yixu310@gmail,com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2023-12-18 15:10:01 -10:00
Fabio Rigano
68e962395c Add converter method for ip adapters (#6150)
* Add converter method for ip adapters

* Move converter method

* Update to image proj converter

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2023-12-18 13:46:43 -10:00
Dhruv Nair
781775ea56 Slow Test for Pipelines minor fixes (#6221)
update
2023-12-19 00:45:51 +05:30
Patrick von Platen
fa3c86beaf [SVD] Fix guidance scale (#6002)
* [SVD] Fix guidance scale

* make style
2023-12-18 19:33:24 +01:00
Haofan Wang
7d0a47f387 Update train_text_to_image_lora.py (#6144)
* Update train_text_to_image_lora.py

* Fix typo?

---------

Co-authored-by: M. Tolga Cangöz <46008593+standardAI@users.noreply.github.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2023-12-18 19:33:05 +01:00
Aryan V S
67b3d3267e Support img2img and inpaint in lpw-xl (#6114)
* add img2img and inpaint support to lpw-xl

* update community README

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2023-12-18 19:19:11 +01:00
TilmannR
4e77056885 Update README.md (#6191)
Typo: The script for LoRA training is `train_text_to_image_lora_prior.py` not `train_text_to_image_prior_lora.py`.

Alternatively you could rename the file and keep the README.md unchanged.
2023-12-18 19:08:29 +01:00
Dhruv Nair
a0c54828a1 Deprecate Pipelines (#6169)
* deprecate pipe

* make style

* update

* add deprecation message

* format

* remove tests for deprecated pipelines

* remove deprecation message

* make style

* fix copies

* clean up

* clean

* clean

* clean

* clean up

* clean up

* clean up toctree

* clean up

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2023-12-18 23:08:29 +05:30
Patrick von Platen
8d891e6e1b [Torch Compile] Fix torch compile for svd vae (#6217) 2023-12-18 18:21:17 +01:00
Patrick von Platen
cce1fe2d41 [Text-to-Video] Clean up pipeline (#6213)
* make style

* make style

* make style

* make style
2023-12-18 18:21:09 +01:00
Abin Thomas
d816bcb5e8 Fix t2i. blog url (#6205) 2023-12-18 09:12:28 -08:00
d8ahazard
6976cab7ca Fix possible re-conversion issues after extracting from safetensors (#6097)
* Fix possible re-conversion issues after extracting from diffusers

Properly rename specific vae keys.

* Whoops
2023-12-18 11:51:20 +01:00
Dhruv Nair
fcbed3fa79 Fix SDXL Inpainting from single file with Refiner Model (#6147)
* update

* update

* update
2023-12-18 11:45:37 +01:00
Sayak Paul
b98b314b7a [Training] remove depcreated method from lora scripts. (#6207)
remove depcreated method from lora scripts.
2023-12-18 15:52:43 +05:30
Omar Sanseviero
74558ff65b Nit fix to training params (#6200) 2023-12-18 11:06:16 +01:00
Yudong Jin
49644babd3 Fix the test script in examples/text_to_image/README.md (#6209)
* Update examples/text_to_image/README.md

* Update examples/text_to_image/README.md

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2023-12-18 15:36:00 +05:30
Sayak Paul
56b3b21693 [Refactor autoencoders] feat: introduce autoencoders module (#6129)
* feat: introduce autoencoders module

* more changes for styling and copy fixing

* path changes in the docs.

* fix: import structure in init.

* fix controlnetxs import
2023-12-18 12:42:15 +05:30
Sayak Paul
9cef07da5a [Benchmarks] fix: lcm benchmarking reporting (#6198)
* fix: lcm benchmarking reporting

* fix generate_csv_dict call.
2023-12-17 15:32:11 +05:30
Sayak Paul
2d94c7838e [Core] feat: enable fused attention projections for other SD and SDXL pipelines (#6179)
* feat: enable fused attention projections for other SD and SDXL pipelines

* add: test for SD fused projections.
2023-12-16 08:45:54 +05:30
Sayak Paul
a81334e3f0 [LoRA] add an error message when dealing with _best_guess_weight_name ofline (#6184)
* add an error message when dealing with _best_guess_weight_name ofline

* simplify condition
2023-12-16 08:36:08 +05:30
Dhruv Nair
2f2775d5f0 Merge branch 'pipeline-interrupt' of https://github.com/huggingface/diffusers into pipeline-interrupt 2023-12-04 15:31:41 +00:00
Dhruv Nair
52c45764ea fix quality issues 2023-12-04 15:30:53 +00:00
Dhruv Nair
e8bf891380 Merge branch 'main' into pipeline-interrupt 2023-12-04 12:58:31 +00:00
Patrick von Platen
fa213518ee Merge branch 'main' into pipeline-interrupt 2023-12-04 12:06:18 +01:00
Dhruv Nair
aefc1d72df Merge branch 'pipeline-interrupt' of https://github.com/huggingface/diffusers into pipeline-interrupt 2023-12-04 07:31:02 +00:00
Dhruv Nair
2e31ef21c4 Merge branch 'main' into pipeline-interrupt 2023-12-04 07:30:33 +00:00
Patrick von Platen
69f68d34ab Merge branch 'main' into pipeline-interrupt 2023-12-01 16:51:08 +01:00
Dhruv Nair
f4f3aed0a7 Merge branch 'main' into pipeline-interrupt 2023-11-30 17:47:55 +05:30
Dhruv Nair
dacaacc054 update 2023-11-30 10:38:00 +00:00
Dhruv Nair
5d415e970d Update docs/source/en/tutorials/interrupting_diffusion_process.md
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2023-11-30 15:49:37 +05:30
Dhruv Nair
f608d21779 Update docs/source/en/tutorials/interrupting_diffusion_process.md
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2023-11-30 15:49:29 +05:30
Dhruv Nair
3d69fd0087 add tutorial 2023-11-22 12:25:04 +00:00
Dhruv Nair
832f487a61 Merge branch 'main' into pipeline-interrupt 2023-11-22 11:38:00 +00:00
Dhruv Nair
36b4de2e21 add docs 2023-11-22 11:12:21 +00:00
Dhruv Nair
b693e254d7 Revert "make fix copies"
This reverts commit 914b35332b.
2023-11-21 12:23:46 +00:00
Dhruv Nair
914b35332b make fix copies 2023-11-21 12:07:46 +00:00
Dhruv Nair
fb832f7fdc Merge branch 'pipeline-interrupt' of https://github.com/huggingface/diffusers into pipeline-interrupt 2023-11-21 12:02:28 +00:00
Dhruv Nair
717cb97b83 add interrupt property 2023-11-21 12:01:17 +00:00
Dhruv Nair
6e61b0fb79 updatemsmq 2023-11-21 12:00:32 +00:00
Dhruv Nair
78201dd10f Merge branch 'main' into pipeline-interrupt 2023-11-21 11:33:23 +00:00
Dhruv Nair
b9a49ccfe5 Merge branch 'main' into pipeline-interrupt 2023-11-20 20:56:31 +05:30
Dhruv Nair
ac92d8513c add tests 2023-11-20 11:55:40 +00:00
Dhruv Nair
4683961b53 add interruptable pipelines 2023-11-20 10:23:20 +00:00
202 changed files with 10409 additions and 4762 deletions

View File

@@ -98,6 +98,7 @@ jobs:
- name: Run example PyTorch CPU tests
if: ${{ matrix.config.framework == 'pytorch_examples' }}
run: |
python -m pip install peft
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
--make-reports=tests_${{ matrix.config.report }} \
examples

View File

@@ -162,6 +162,25 @@ class LCMLoRATextToImageBenchmark(TextToImageBenchmark):
guidance_scale=1.0,
)
def benchmark(self, args):
flush()
print(f"[INFO] {self.pipe.__class__.__name__}: Running benchmark with: {vars(args)}\n")
time = benchmark_fn(self.run_inference, self.pipe, args) # in seconds.
memory = bytes_to_giga_bytes(torch.cuda.max_memory_allocated()) # in GBs.
benchmark_info = BenchmarkInfo(time=time, memory=memory)
pipeline_class_name = str(self.pipe.__class__.__name__)
flush()
csv_dict = generate_csv_dict(
pipeline_cls=pipeline_class_name, ckpt=self.lora_id, args=args, benchmark_info=benchmark_info
)
filepath = self.get_result_filepath(args)
write_to_csv(filepath, csv_dict)
print(f"Logs written to: {filepath}")
flush()
class ImageToImageBenchmark(TextToImageBenchmark):
pipeline_class = AutoPipelineForImage2Image

View File

@@ -244,14 +244,12 @@
- sections:
- local: api/pipelines/overview
title: Overview
- local: api/pipelines/alt_diffusion
title: AltDiffusion
- local: api/pipelines/amused
title: aMUSEd
- local: api/pipelines/animatediff
title: AnimateDiff
- local: api/pipelines/attend_and_excite
title: Attend-and-Excite
- local: api/pipelines/audio_diffusion
title: Audio Diffusion
- local: api/pipelines/audioldm
title: AudioLDM
- local: api/pipelines/audioldm2
@@ -270,8 +268,6 @@
title: ControlNet-XS
- local: api/pipelines/controlnetxs_sdxl
title: ControlNet-XS with Stable Diffusion XL
- local: api/pipelines/cycle_diffusion
title: Cycle Diffusion
- local: api/pipelines/dance_diffusion
title: Dance Diffusion
- local: api/pipelines/ddim
@@ -302,26 +298,14 @@
title: MusicLDM
- local: api/pipelines/paint_by_example
title: Paint by Example
- local: api/pipelines/paradigms
title: Parallel Sampling of Diffusion Models
- local: api/pipelines/pix2pix_zero
title: Pix2Pix Zero
- local: api/pipelines/pixart
title: PixArt-α
- local: api/pipelines/pndm
title: PNDM
- local: api/pipelines/repaint
title: RePaint
- local: api/pipelines/score_sde_ve
title: Score SDE VE
- local: api/pipelines/self_attention_guidance
title: Self-Attention Guidance
- local: api/pipelines/semantic_stable_diffusion
title: Semantic Guidance
- local: api/pipelines/shap_e
title: Shap-E
- local: api/pipelines/spectrogram_diffusion
title: Spectrogram Diffusion
- sections:
- local: api/pipelines/stable_diffusion/overview
title: Overview
@@ -356,26 +340,16 @@
title: Stable Diffusion
- local: api/pipelines/stable_unclip
title: Stable unCLIP
- local: api/pipelines/stochastic_karras_ve
title: Stochastic Karras VE
- local: api/pipelines/model_editing
title: Text-to-image model editing
- local: api/pipelines/text_to_video
title: Text-to-video
- local: api/pipelines/text_to_video_zero
title: Text2Video-Zero
- local: api/pipelines/unclip
title: unCLIP
- local: api/pipelines/latent_diffusion_uncond
title: Unconditional Latent Diffusion
- local: api/pipelines/unidiffuser
title: UniDiffuser
- local: api/pipelines/value_guided_sampling
title: Value-guided sampling
- local: api/pipelines/versatile_diffusion
title: Versatile Diffusion
- local: api/pipelines/vq_diffusion
title: VQ Diffusion
- local: api/pipelines/wuerstchen
title: Wuerstchen
title: Pipelines

View File

@@ -49,12 +49,12 @@ make_image_grid([original_image, mask_image, image], rows=1, cols=3)
## AsymmetricAutoencoderKL
[[autodoc]] models.autoencoder_asym_kl.AsymmetricAutoencoderKL
[[autodoc]] models.autoencoders.autoencoder_asym_kl.AsymmetricAutoencoderKL
## AutoencoderKLOutput
[[autodoc]] models.autoencoder_kl.AutoencoderKLOutput
[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput
## DecoderOutput
[[autodoc]] models.vae.DecoderOutput
[[autodoc]] models.autoencoders.vae.DecoderOutput

View File

@@ -54,4 +54,4 @@ image
## AutoencoderTinyOutput
[[autodoc]] models.autoencoder_tiny.AutoencoderTinyOutput
[[autodoc]] models.autoencoders.autoencoder_tiny.AutoencoderTinyOutput

View File

@@ -36,11 +36,11 @@ model = AutoencoderKL.from_single_file(url)
## AutoencoderKLOutput
[[autodoc]] models.autoencoder_kl.AutoencoderKLOutput
[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput
## DecoderOutput
[[autodoc]] models.vae.DecoderOutput
[[autodoc]] models.autoencoders.vae.DecoderOutput
## FlaxAutoencoderKL

View File

@@ -1,47 +0,0 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# AltDiffusion
AltDiffusion was proposed in [AltCLIP: Altering the Language Encoder in CLIP for Extended Language Capabilities](https://huggingface.co/papers/2211.06679) by Zhongzhi Chen, Guang Liu, Bo-Wen Zhang, Fulong Ye, Qinghong Yang, Ledell Wu.
The abstract from the paper is:
*In this work, we present a conceptually simple and effective method to train a strong bilingual/multilingual multimodal representation model. Starting from the pre-trained multimodal representation model CLIP released by OpenAI, we altered its text encoder with a pre-trained multilingual text encoder XLM-R, and aligned both languages and image representations by a two-stage training schema consisting of teacher learning and contrastive learning. We validate our method through evaluations of a wide range of tasks. We set new state-of-the-art performances on a bunch of tasks including ImageNet-CN, Flicker30k-CN, COCO-CN and XTD. Further, we obtain very close performances with CLIP on almost all tasks, suggesting that one can simply alter the text encoder in CLIP for extended capabilities such as multilingual understanding. Our models and code are available at [this https URL](https://github.com/FlagAI-Open/FlagAI).*
## Tips
`AltDiffusion` is conceptually the same as [Stable Diffusion](./stable_diffusion/overview).
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
## AltDiffusionPipeline
[[autodoc]] AltDiffusionPipeline
- all
- __call__
## AltDiffusionImg2ImgPipeline
[[autodoc]] AltDiffusionImg2ImgPipeline
- all
- __call__
## AltDiffusionPipelineOutput
[[autodoc]] pipelines.alt_diffusion.AltDiffusionPipelineOutput
- all
- __call__

View File

@@ -0,0 +1,30 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# aMUSEd
Amused is a lightweight text to image model based off of the [muse](https://arxiv.org/pdf/2301.00704.pdf) architecture. Amused is particularly useful in applications that require a lightweight and fast model such as generating many images quickly at once.
Amused is a vqvae token based transformer that can generate an image in fewer forward passes than many diffusion models. In contrast with muse, it uses the smaller text encoder CLIP-L/14 instead of t5-xxl. Due to its small parameter count and few forward pass generation process, amused can generate many images quickly. This benefit is seen particularly at larger batch sizes.
| Model | Params |
|-------|--------|
| [amused-256](https://huggingface.co/huggingface/amused-256) | 603M |
| [amused-512](https://huggingface.co/huggingface/amused-512) | 608M |
## AmusedPipeline
[[autodoc]] AmusedPipeline
- __call__
- all
- enable_xformers_memory_efficient_attention
- disable_xformers_memory_efficient_attention

View File

@@ -1,35 +0,0 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Audio Diffusion
[Audio Diffusion](https://github.com/teticio/audio-diffusion) is by Robert Dargavel Smith, and it leverages the recent advances in image generation from diffusion models by converting audio samples to and from Mel spectrogram images.
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
## AudioDiffusionPipeline
[[autodoc]] AudioDiffusionPipeline
- all
- __call__
## AudioPipelineOutput
[[autodoc]] pipelines.AudioPipelineOutput
## ImagePipelineOutput
[[autodoc]] pipelines.ImagePipelineOutput
## Mel
[[autodoc]] Mel

View File

@@ -1,33 +0,0 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Cycle Diffusion
Cycle Diffusion is a text guided image-to-image generation model proposed in [Unifying Diffusion Models' Latent Space, with Applications to CycleDiffusion and Guidance](https://huggingface.co/papers/2210.05559) by Chen Henry Wu, Fernando De la Torre.
The abstract from the paper is:
*Diffusion models have achieved unprecedented performance in generative modeling. The commonly-adopted formulation of the latent code of diffusion models is a sequence of gradually denoised samples, as opposed to the simpler (e.g., Gaussian) latent space of GANs, VAEs, and normalizing flows. This paper provides an alternative, Gaussian formulation of the latent space of various diffusion models, as well as an invertible DPM-Encoder that maps images into the latent space. While our formulation is purely based on the definition of diffusion models, we demonstrate several intriguing consequences. (1) Empirically, we observe that a common latent space emerges from two diffusion models trained independently on related domains. In light of this finding, we propose CycleDiffusion, which uses DPM-Encoder for unpaired image-to-image translation. Furthermore, applying CycleDiffusion to text-to-image diffusion models, we show that large-scale text-to-image diffusion models can be used as zero-shot image-to-image editors. (2) One can guide pre-trained diffusion models and GANs by controlling the latent codes in a unified, plug-and-play formulation based on energy-based models. Using the CLIP model and a face recognition model as guidance, we demonstrate that diffusion models have better coverage of low-density sub-populations and individuals than GANs. The code is publicly available at [this https URL](https://github.com/ChenWu98/cycle-diffusion).*
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
## CycleDiffusionPipeline
[[autodoc]] CycleDiffusionPipeline
- all
- __call__
## StableDiffusionPiplineOutput
[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput

View File

@@ -1,35 +0,0 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Unconditional Latent Diffusion
Unconditional Latent Diffusion was proposed in [High-Resolution Image Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) by Robin Rombach, Andreas Blattmann, Dominik Lorenz, Patrick Esser, Björn Ommer.
The abstract from the paper is:
*By decomposing the image formation process into a sequential application of denoising autoencoders, diffusion models (DMs) achieve state-of-the-art synthesis results on image data and beyond. Additionally, their formulation allows for a guiding mechanism to control the image generation process without retraining. However, since these models typically operate directly in pixel space, optimization of powerful DMs often consumes hundreds of GPU days and inference is expensive due to sequential evaluations. To enable DM training on limited computational resources while retaining their quality and flexibility, we apply them in the latent space of powerful pretrained autoencoders. In contrast to previous work, training diffusion models on such a representation allows for the first time to reach a near-optimal point between complexity reduction and detail preservation, greatly boosting visual fidelity. By introducing cross-attention layers into the model architecture, we turn diffusion models into powerful and flexible generators for general conditioning inputs such as text or bounding boxes and high-resolution synthesis becomes possible in a convolutional manner. Our latent diffusion models (LDMs) achieve a new state of the art for image inpainting and highly competitive performance on various tasks, including unconditional image generation, semantic scene synthesis, and super-resolution, while significantly reducing computational requirements compared to pixel-based DMs.*
The original codebase can be found at [CompVis/latent-diffusion](https://github.com/CompVis/latent-diffusion).
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
## LDMPipeline
[[autodoc]] LDMPipeline
- all
- __call__
## ImagePipelineOutput
[[autodoc]] pipelines.ImagePipelineOutput

View File

@@ -1,35 +0,0 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Text-to-image model editing
[Editing Implicit Assumptions in Text-to-Image Diffusion Models](https://huggingface.co/papers/2303.08084) is by Hadas Orgad, Bahjat Kawar, and Yonatan Belinkov. This pipeline enables editing diffusion model weights, such that its assumptions of a given concept are changed. The resulting change is expected to take effect in all prompt generations related to the edited concept.
The abstract from the paper is:
*Text-to-image diffusion models often make implicit assumptions about the world when generating images. While some assumptions are useful (e.g., the sky is blue), they can also be outdated, incorrect, or reflective of social biases present in the training data. Thus, there is a need to control these assumptions without requiring explicit user input or costly re-training. In this work, we aim to edit a given implicit assumption in a pre-trained diffusion model. Our Text-to-Image Model Editing method, TIME for short, receives a pair of inputs: a "source" under-specified prompt for which the model makes an implicit assumption (e.g., "a pack of roses"), and a "destination" prompt that describes the same setting, but with a specified desired attribute (e.g., "a pack of blue roses"). TIME then updates the model's cross-attention layers, as these layers assign visual meaning to textual tokens. We edit the projection matrices in these layers such that the source prompt is projected close to the destination prompt. Our method is highly efficient, as it modifies a mere 2.2% of the model's parameters in under one second. To evaluate model editing approaches, we introduce TIMED (TIME Dataset), containing 147 source and destination prompt pairs from various domains. Our experiments (using Stable Diffusion) show that TIME is successful in model editing, generalizes well for related prompts unseen during editing, and imposes minimal effect on unrelated generations.*
You can find additional information about model editing on the [project page](https://time-diffusion.github.io/), [original codebase](https://github.com/bahjat-kawar/time-diffusion), and try it out in a [demo](https://huggingface.co/spaces/bahjat-kawar/time-diffusion).
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
## StableDiffusionModelEditingPipeline
[[autodoc]] StableDiffusionModelEditingPipeline
- __call__
- all
## StableDiffusionPipelineOutput
[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput

View File

@@ -1,51 +0,0 @@
<!--Copyright 2023 ParaDiGMS authors and The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Parallel Sampling of Diffusion Models
[Parallel Sampling of Diffusion Models](https://huggingface.co/papers/2305.16317) is by Andy Shih, Suneel Belkhale, Stefano Ermon, Dorsa Sadigh, Nima Anari.
The abstract from the paper is:
*Diffusion models are powerful generative models but suffer from slow sampling, often taking 1000 sequential denoising steps for one sample. As a result, considerable efforts have been directed toward reducing the number of denoising steps, but these methods hurt sample quality. Instead of reducing the number of denoising steps (trading quality for speed), in this paper we explore an orthogonal approach: can we run the denoising steps in parallel (trading compute for speed)? In spite of the sequential nature of the denoising steps, we show that surprisingly it is possible to parallelize sampling via Picard iterations, by guessing the solution of future denoising steps and iteratively refining until convergence. With this insight, we present ParaDiGMS, a novel method to accelerate the sampling of pretrained diffusion models by denoising multiple steps in parallel. ParaDiGMS is the first diffusion sampling method that enables trading compute for speed and is even compatible with existing fast sampling techniques such as DDIM and DPMSolver. Using ParaDiGMS, we improve sampling speed by 2-4x across a range of robotics and image generation models, giving state-of-the-art sampling speeds of 0.2s on 100-step DiffusionPolicy and 14.6s on 1000-step StableDiffusion-v2 with no measurable degradation of task reward, FID score, or CLIP score.*
The original codebase can be found at [AndyShih12/paradigms](https://github.com/AndyShih12/paradigms), and the pipeline was contributed by [AndyShih12](https://github.com/AndyShih12). ❤️
## Tips
This pipeline improves sampling speed by running denoising steps in parallel, at the cost of increased total FLOPs.
Therefore, it is better to call this pipeline when running on multiple GPUs. Otherwise, without enough GPU bandwidth
sampling may be even slower than sequential sampling.
The two parameters to play with are `parallel` (batch size) and `tolerance`.
- If it fits in memory, for a 1000-step DDPM you can aim for a batch size of around 100 (for example, 8 GPUs and `batch_per_device=12` to get `parallel=96`). A higher batch size may not fit in memory, and lower batch size gives less parallelism.
- For tolerance, using a higher tolerance may get better speedups but can risk sample quality degradation. If there is quality degradation with the default tolerance, then use a lower tolerance like `0.001`.
For a 1000-step DDPM on 8 A100 GPUs, you can expect around a 3x speedup from [`StableDiffusionParadigmsPipeline`] compared to the [`StableDiffusionPipeline`]
by setting `parallel=80` and `tolerance=0.1`.
🤗 Diffusers offers [distributed inference support](../../training/distributed_inference) for generating multiple prompts
in parallel on multiple GPUs. But [`StableDiffusionParadigmsPipeline`] is designed for speeding up sampling of a single prompt by using multiple GPUs.
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
## StableDiffusionParadigmsPipeline
[[autodoc]] StableDiffusionParadigmsPipeline
- __call__
- all
## StableDiffusionPipelineOutput
[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput

View File

@@ -1,289 +0,0 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Pix2Pix Zero
[Zero-shot Image-to-Image Translation](https://huggingface.co/papers/2302.03027) is by Gaurav Parmar, Krishna Kumar Singh, Richard Zhang, Yijun Li, Jingwan Lu, and Jun-Yan Zhu.
The abstract from the paper is:
*Large-scale text-to-image generative models have shown their remarkable ability to synthesize diverse and high-quality images. However, it is still challenging to directly apply these models for editing real images for two reasons. First, it is hard for users to come up with a perfect text prompt that accurately describes every visual detail in the input image. Second, while existing models can introduce desirable changes in certain regions, they often dramatically alter the input content and introduce unexpected changes in unwanted regions. In this work, we propose pix2pix-zero, an image-to-image translation method that can preserve the content of the original image without manual prompting. We first automatically discover editing directions that reflect desired edits in the text embedding space. To preserve the general content structure after editing, we further propose cross-attention guidance, which aims to retain the cross-attention maps of the input image throughout the diffusion process. In addition, our method does not need additional training for these edits and can directly use the existing pre-trained text-to-image diffusion model. We conduct extensive experiments and show that our method outperforms existing and concurrent works for both real and synthetic image editing.*
You can find additional information about Pix2Pix Zero on the [project page](https://pix2pixzero.github.io/), [original codebase](https://github.com/pix2pixzero/pix2pix-zero), and try it out in a [demo](https://huggingface.co/spaces/pix2pix-zero-library/pix2pix-zero-demo).
## Tips
* The pipeline can be conditioned on real input images. Check out the code examples below to know more.
* The pipeline exposes two arguments namely `source_embeds` and `target_embeds`
that let you control the direction of the semantic edits in the final image to be generated. Let's say,
you wanted to translate from "cat" to "dog". In this case, the edit direction will be "cat -> dog". To reflect
this in the pipeline, you simply have to set the embeddings related to the phrases including "cat" to
`source_embeds` and "dog" to `target_embeds`. Refer to the code example below for more details.
* When you're using this pipeline from a prompt, specify the _source_ concept in the prompt. Taking
the above example, a valid input prompt would be: "a high resolution painting of a **cat** in the style of van gogh".
* If you wanted to reverse the direction in the example above, i.e., "dog -> cat", then it's recommended to:
* Swap the `source_embeds` and `target_embeds`.
* Change the input prompt to include "dog".
* To learn more about how the source and target embeddings are generated, refer to the [original paper](https://arxiv.org/abs/2302.03027). Below, we also provide some directions on how to generate the embeddings.
* Note that the quality of the outputs generated with this pipeline is dependent on how good the `source_embeds` and `target_embeds` are. Please, refer to [this discussion](#generating-source-and-target-embeddings) for some suggestions on the topic.
## Available Pipelines:
| Pipeline | Tasks | Demo
|---|---|:---:|
| [StableDiffusionPix2PixZeroPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py) | *Text-Based Image Editing* | [🤗 Space](https://huggingface.co/spaces/pix2pix-zero-library/pix2pix-zero-demo) |
<!-- TODO: add Colab -->
## Usage example
### Based on an image generated with the input prompt
```python
import requests
import torch
from diffusers import DDIMScheduler, StableDiffusionPix2PixZeroPipeline
def download(embedding_url, local_filepath):
r = requests.get(embedding_url)
with open(local_filepath, "wb") as f:
f.write(r.content)
model_ckpt = "CompVis/stable-diffusion-v1-4"
pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained(
model_ckpt, conditions_input_image=False, torch_dtype=torch.float16
)
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
pipeline.to("cuda")
prompt = "a high resolution painting of a cat in the style of van gogh"
src_embs_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/embeddings_sd_1.4/cat.pt"
target_embs_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/embeddings_sd_1.4/dog.pt"
for url in [src_embs_url, target_embs_url]:
download(url, url.split("/")[-1])
src_embeds = torch.load(src_embs_url.split("/")[-1])
target_embeds = torch.load(target_embs_url.split("/")[-1])
image = pipeline(
prompt,
source_embeds=src_embeds,
target_embeds=target_embeds,
num_inference_steps=50,
cross_attention_guidance_amount=0.15,
).images[0]
image
```
### Based on an input image
When the pipeline is conditioned on an input image, we first obtain an inverted
noise from it using a `DDIMInverseScheduler` with the help of a generated caption. Then the inverted noise is used to start the generation process.
First, let's load our pipeline:
```py
import torch
from transformers import BlipForConditionalGeneration, BlipProcessor
from diffusers import DDIMScheduler, DDIMInverseScheduler, StableDiffusionPix2PixZeroPipeline
captioner_id = "Salesforce/blip-image-captioning-base"
processor = BlipProcessor.from_pretrained(captioner_id)
model = BlipForConditionalGeneration.from_pretrained(captioner_id, torch_dtype=torch.float16, low_cpu_mem_usage=True)
sd_model_ckpt = "CompVis/stable-diffusion-v1-4"
pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained(
sd_model_ckpt,
caption_generator=model,
caption_processor=processor,
torch_dtype=torch.float16,
safety_checker=None,
)
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config)
pipeline.enable_model_cpu_offload()
```
Then, we load an input image for conditioning and obtain a suitable caption for it:
```py
from diffusers.utils import load_image
img_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/test_images/cats/cat_6.png"
raw_image = load_image(url).resize((512, 512))
caption = pipeline.generate_caption(raw_image)
caption
```
Then we employ the generated caption and the input image to get the inverted noise:
```py
generator = torch.manual_seed(0)
inv_latents = pipeline.invert(caption, image=raw_image, generator=generator).latents
```
Now, generate the image with edit directions:
```py
# See the "Generating source and target embeddings" section below to
# automate the generation of these captions with a pre-trained model like Flan-T5 as explained below.
source_prompts = ["a cat sitting on the street", "a cat playing in the field", "a face of a cat"]
target_prompts = ["a dog sitting on the street", "a dog playing in the field", "a face of a dog"]
source_embeds = pipeline.get_embeds(source_prompts, batch_size=2)
target_embeds = pipeline.get_embeds(target_prompts, batch_size=2)
image = pipeline(
caption,
source_embeds=source_embeds,
target_embeds=target_embeds,
num_inference_steps=50,
cross_attention_guidance_amount=0.15,
generator=generator,
latents=inv_latents,
negative_prompt=caption,
).images[0]
image
```
## Generating source and target embeddings
The authors originally used the [GPT-3 API](https://openai.com/api/) to generate the source and target captions for discovering
edit directions. However, we can also leverage open source and public models for the same purpose.
Below, we provide an end-to-end example with the [Flan-T5](https://huggingface.co/docs/transformers/model_doc/flan-t5) model
for generating captions and [CLIP](https://huggingface.co/docs/transformers/model_doc/clip) for
computing embeddings on the generated captions.
**1. Load the generation model**:
```py
import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xl")
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-xl", device_map="auto", torch_dtype=torch.float16)
```
**2. Construct a starting prompt**:
```py
source_concept = "cat"
target_concept = "dog"
source_text = f"Provide a caption for images containing a {source_concept}. "
"The captions should be in English and should be no longer than 150 characters."
target_text = f"Provide a caption for images containing a {target_concept}. "
"The captions should be in English and should be no longer than 150 characters."
```
Here, we're interested in the "cat -> dog" direction.
**3. Generate captions**:
We can use a utility like so for this purpose.
```py
def generate_captions(input_prompt):
input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids.to("cuda")
outputs = model.generate(
input_ids, temperature=0.8, num_return_sequences=16, do_sample=True, max_new_tokens=128, top_k=10
)
return tokenizer.batch_decode(outputs, skip_special_tokens=True)
```
And then we just call it to generate our captions:
```py
source_captions = generate_captions(source_text)
target_captions = generate_captions(target_concept)
print(source_captions, target_captions, sep='\n')
```
We encourage you to play around with the different parameters supported by the
`generate()` method ([documentation](https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.generation_tf_utils.TFGenerationMixin.generate)) for the generation quality you are looking for.
**4. Load the embedding model**:
Here, we need to use the same text encoder model used by the subsequent Stable Diffusion model.
```py
from diffusers import StableDiffusionPix2PixZeroPipeline
pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
)
pipeline = pipeline.to("cuda")
tokenizer = pipeline.tokenizer
text_encoder = pipeline.text_encoder
```
**5. Compute embeddings**:
```py
import torch
def embed_captions(sentences, tokenizer, text_encoder, device="cuda"):
with torch.no_grad():
embeddings = []
for sent in sentences:
text_inputs = tokenizer(
sent,
padding="max_length",
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
prompt_embeds = text_encoder(text_input_ids.to(device), attention_mask=None)[0]
embeddings.append(prompt_embeds)
return torch.concatenate(embeddings, dim=0).mean(dim=0).unsqueeze(0)
source_embeddings = embed_captions(source_captions, tokenizer, text_encoder)
target_embeddings = embed_captions(target_captions, tokenizer, text_encoder)
```
And you're done! [Here](https://colab.research.google.com/drive/1tz2C1EdfZYAPlzXXbTnf-5PRBiR8_R1F?usp=sharing) is a Colab Notebook that you can use to interact with the entire process.
Now, you can use these embeddings directly while calling the pipeline:
```py
from diffusers import DDIMScheduler
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
image = pipeline(
prompt,
source_embeds=source_embeddings,
target_embeds=target_embeddings,
num_inference_steps=50,
cross_attention_guidance_amount=0.15,
).images[0]
image
```
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
## StableDiffusionPix2PixZeroPipeline
[[autodoc]] StableDiffusionPix2PixZeroPipeline
- __call__
- all

View File

@@ -1,35 +0,0 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# PNDM
[Pseudo Numerical Methods for Diffusion Models on Manifolds](https://huggingface.co/papers/2202.09778) (PNDM) is by Luping Liu, Yi Ren, Zhijie Lin and Zhou Zhao.
The abstract from the paper is:
*Denoising Diffusion Probabilistic Models (DDPMs) can generate high-quality samples such as image and audio samples. However, DDPMs require hundreds to thousands of iterations to produce final samples. Several prior works have successfully accelerated DDPMs through adjusting the variance schedule (e.g., Improved Denoising Diffusion Probabilistic Models) or the denoising equation (e.g., Denoising Diffusion Implicit Models (DDIMs)). However, these acceleration methods cannot maintain the quality of samples and even introduce new noise at a high speedup rate, which limit their practicability. To accelerate the inference process while keeping the sample quality, we provide a fresh perspective that DDPMs should be treated as solving differential equations on manifolds. Under such a perspective, we propose pseudo numerical methods for diffusion models (PNDMs). Specifically, we figure out how to solve differential equations on manifolds and show that DDIMs are simple cases of pseudo numerical methods. We change several classical numerical methods to corresponding pseudo numerical methods and find that the pseudo linear multi-step method is the best in most situations. According to our experiments, by directly using pre-trained models on Cifar10, CelebA and LSUN, PNDMs can generate higher quality synthetic images with only 50 steps compared with 1000-step DDIMs (20x speedup), significantly outperform DDIMs with 250 steps (by around 0.4 in FID) and have good generalization on different variance schedules.*
The original codebase can be found at [luping-liu/PNDM](https://github.com/luping-liu/PNDM).
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
## PNDMPipeline
[[autodoc]] PNDMPipeline
- all
- __call__
## ImagePipelineOutput
[[autodoc]] pipelines.ImagePipelineOutput

View File

@@ -1,37 +0,0 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# RePaint
[RePaint: Inpainting using Denoising Diffusion Probabilistic Models](https://huggingface.co/papers/2201.09865) is by Andreas Lugmayr, Martin Danelljan, Andres Romero, Fisher Yu, Radu Timofte, Luc Van Gool.
The abstract from the paper is:
*Free-form inpainting is the task of adding new content to an image in the regions specified by an arbitrary binary mask. Most existing approaches train for a certain distribution of masks, which limits their generalization capabilities to unseen mask types. Furthermore, training with pixel-wise and perceptual losses often leads to simple textural extensions towards the missing areas instead of semantically meaningful generation. In this work, we propose RePaint: A Denoising Diffusion Probabilistic Model (DDPM) based inpainting approach that is applicable to even extreme masks. We employ a pretrained unconditional DDPM as the generative prior. To condition the generation process, we only alter the reverse diffusion iterations by sampling the unmasked regions using the given image information. Since this technique does not modify or condition the original DDPM network itself, the model produces high-quality and diverse output images for any inpainting form. We validate our method for both faces and general-purpose image inpainting using standard and extreme masks.
RePaint outperforms state-of-the-art Autoregressive, and GAN approaches for at least five out of six mask distributions.*
The original codebase can be found at [andreas128/RePaint](https://github.com/andreas128/RePaint).
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
## RePaintPipeline
[[autodoc]] RePaintPipeline
- all
- __call__
## ImagePipelineOutput
[[autodoc]] pipelines.ImagePipelineOutput

View File

@@ -1,35 +0,0 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Score SDE VE
[Score-Based Generative Modeling through Stochastic Differential Equations](https://huggingface.co/papers/2011.13456) (Score SDE) is by Yang Song, Jascha Sohl-Dickstein, Diederik P. Kingma, Abhishek Kumar, Stefano Ermon and Ben Poole. This pipeline implements the variance expanding (VE) variant of the stochastic differential equation method.
The abstract from the paper is:
*Creating noise from data is easy; creating data from noise is generative modeling. We present a stochastic differential equation (SDE) that smoothly transforms a complex data distribution to a known prior distribution by slowly injecting noise, and a corresponding reverse-time SDE that transforms the prior distribution back into the data distribution by slowly removing the noise. Crucially, the reverse-time SDE depends only on the time-dependent gradient field (\aka, score) of the perturbed data distribution. By leveraging advances in score-based generative modeling, we can accurately estimate these scores with neural networks, and use numerical SDE solvers to generate samples. We show that this framework encapsulates previous approaches in score-based generative modeling and diffusion probabilistic modeling, allowing for new sampling procedures and new modeling capabilities. In particular, we introduce a predictor-corrector framework to correct errors in the evolution of the discretized reverse-time SDE. We also derive an equivalent neural ODE that samples from the same distribution as the SDE, but additionally enables exact likelihood computation, and improved sampling efficiency. In addition, we provide a new way to solve inverse problems with score-based models, as demonstrated with experiments on class-conditional generation, image inpainting, and colorization. Combined with multiple architectural improvements, we achieve record-breaking performance for unconditional image generation on CIFAR-10 with an Inception score of 9.89 and FID of 2.20, a competitive likelihood of 2.99 bits/dim, and demonstrate high fidelity generation of 1024 x 1024 images for the first time from a score-based generative model.*
The original codebase can be found at [yang-song/score_sde_pytorch](https://github.com/yang-song/score_sde_pytorch).
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
## ScoreSdeVePipeline
[[autodoc]] ScoreSdeVePipeline
- all
- __call__
## ImagePipelineOutput
[[autodoc]] pipelines.ImagePipelineOutput

View File

@@ -1,37 +0,0 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Spectrogram Diffusion
[Spectrogram Diffusion](https://huggingface.co/papers/2206.05408) is by Curtis Hawthorne, Ian Simon, Adam Roberts, Neil Zeghidour, Josh Gardner, Ethan Manilow, and Jesse Engel.
*An ideal music synthesizer should be both interactive and expressive, generating high-fidelity audio in realtime for arbitrary combinations of instruments and notes. Recent neural synthesizers have exhibited a tradeoff between domain-specific models that offer detailed control of only specific instruments, or raw waveform models that can train on any music but with minimal control and slow generation. In this work, we focus on a middle ground of neural synthesizers that can generate audio from MIDI sequences with arbitrary combinations of instruments in realtime. This enables training on a wide range of transcription datasets with a single model, which in turn offers note-level control of composition and instrumentation across a wide range of instruments. We use a simple two-stage process: MIDI to spectrograms with an encoder-decoder Transformer, then spectrograms to audio with a generative adversarial network (GAN) spectrogram inverter. We compare training the decoder as an autoregressive model and as a Denoising Diffusion Probabilistic Model (DDPM) and find that the DDPM approach is superior both qualitatively and as measured by audio reconstruction and Fréchet distance metrics. Given the interactivity and generality of this approach, we find this to be a promising first step towards interactive and expressive neural synthesis for arbitrary combinations of instruments and notes.*
The original codebase can be found at [magenta/music-spectrogram-diffusion](https://github.com/magenta/music-spectrogram-diffusion).
![img](https://storage.googleapis.com/music-synthesis-with-spectrogram-diffusion/architecture.png)
As depicted above the model takes as input a MIDI file and tokenizes it into a sequence of 5 second intervals. Each tokenized interval then together with positional encodings is passed through the Note Encoder and its representation is concatenated with the previous window's generated spectrogram representation obtained via the Context Encoder. For the initial 5 second window this is set to zero. The resulting context is then used as conditioning to sample the denoised Spectrogram from the MIDI window and we concatenate this spectrogram to the final output as well as use it for the context of the next MIDI window. The process repeats till we have gone over all the MIDI inputs. Finally a MelGAN decoder converts the potentially long spectrogram to audio which is the final result of this pipeline.
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
## SpectrogramDiffusionPipeline
[[autodoc]] SpectrogramDiffusionPipeline
- all
- __call__
## AudioPipelineOutput
[[autodoc]] pipelines.AudioPipelineOutput

View File

@@ -31,14 +31,14 @@ Make sure to check out the Stable Diffusion [Tips](overview#tips) section to lea
## StableDiffusionLDM3DPipeline
[[autodoc]] pipelines.stable_diffusion.pipeline_stable_diffusion_ldm3d.StableDiffusionLDM3DPipeline
[[autodoc]] pipelines.stable_diffusion_ldm3d.pipeline_stable_diffusion_ldm3d.StableDiffusionLDM3DPipeline
- all
- __call__
## LDM3DPipelineOutput
[[autodoc]] pipelines.stable_diffusion.pipeline_stable_diffusion_ldm3d.LDM3DPipelineOutput
[[autodoc]] pipelines.stable_diffusion_ldm3d.pipeline_stable_diffusion_ldm3d.LDM3DPipelineOutput
- all
- __call__

View File

@@ -1,33 +0,0 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Stochastic Karras VE
[Elucidating the Design Space of Diffusion-Based Generative Models](https://huggingface.co/papers/2206.00364) is by Tero Karras, Miika Aittala, Timo Aila and Samuli Laine. This pipeline implements the stochastic sampling tailored to variance expanding (VE) models.
The abstract from the paper:
*We argue that the theory and practice of diffusion-based generative models are currently unnecessarily convoluted and seek to remedy the situation by presenting a design space that clearly separates the concrete design choices. This lets us identify several changes to both the sampling and training processes, as well as preconditioning of the score networks. Together, our improvements yield new state-of-the-art FID of 1.79 for CIFAR-10 in a class-conditional setting and 1.97 in an unconditional setting, with much faster sampling (35 network evaluations per image) than prior designs. To further demonstrate their modular nature, we show that our design changes dramatically improve both the efficiency and quality obtainable with pre-trained score networks from previous work, including improving the FID of a previously trained ImageNet-64 model from 2.07 to near-SOTA 1.55, and after re-training with our proposed improvements to a new SOTA of 1.36.*
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
## KarrasVePipeline
[[autodoc]] KarrasVePipeline
- all
- __call__
## ImagePipelineOutput
[[autodoc]] pipelines.ImagePipelineOutput

View File

@@ -1,54 +0,0 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Versatile Diffusion
Versatile Diffusion was proposed in [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://huggingface.co/papers/2211.08332) by Xingqian Xu, Zhangyang Wang, Eric Zhang, Kai Wang, Humphrey Shi.
The abstract from the paper is:
*Recent advances in diffusion models have set an impressive milestone in many generation tasks, and trending works such as DALL-E2, Imagen, and Stable Diffusion have attracted great interest. Despite the rapid landscape changes, recent new approaches focus on extensions and performance rather than capacity, thus requiring separate models for separate tasks. In this work, we expand the existing single-flow diffusion pipeline into a multi-task multimodal network, dubbed Versatile Diffusion (VD), that handles multiple flows of text-to-image, image-to-text, and variations in one unified model. The pipeline design of VD instantiates a unified multi-flow diffusion framework, consisting of sharable and swappable layer modules that enable the crossmodal generality beyond images and text. Through extensive experiments, we demonstrate that VD successfully achieves the following: a) VD outperforms the baseline approaches and handles all its base tasks with competitive quality; b) VD enables novel extensions such as disentanglement of style and semantics, dual- and multi-context blending, etc.; c) The success of our multi-flow multimodal framework over images and text may inspire further diffusion-based universal AI research.*
## Tips
You can load the more memory intensive "all-in-one" [`VersatileDiffusionPipeline`] that supports all the tasks or use the individual pipelines which are more memory efficient.
| **Pipeline** | **Supported tasks** |
|------------------------------------------------------|-----------------------------------|
| [`VersatileDiffusionPipeline`] | all of the below |
| [`VersatileDiffusionTextToImagePipeline`] | text-to-image |
| [`VersatileDiffusionImageVariationPipeline`] | image variation |
| [`VersatileDiffusionDualGuidedPipeline`] | image-text dual guided generation |
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
## VersatileDiffusionPipeline
[[autodoc]] VersatileDiffusionPipeline
## VersatileDiffusionTextToImagePipeline
[[autodoc]] VersatileDiffusionTextToImagePipeline
- all
- __call__
## VersatileDiffusionImageVariationPipeline
[[autodoc]] VersatileDiffusionImageVariationPipeline
- all
- __call__
## VersatileDiffusionDualGuidedPipeline
[[autodoc]] VersatileDiffusionDualGuidedPipeline
- all
- __call__

View File

@@ -1,35 +0,0 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# VQ Diffusion
[Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://huggingface.co/papers/2111.14822) is by Shuyang Gu, Dong Chen, Jianmin Bao, Fang Wen, Bo Zhang, Dongdong Chen, Lu Yuan, Baining Guo.
The abstract from the paper is:
*We present the vector quantized diffusion (VQ-Diffusion) model for text-to-image generation. This method is based on a vector quantized variational autoencoder (VQ-VAE) whose latent space is modeled by a conditional variant of the recently developed Denoising Diffusion Probabilistic Model (DDPM). We find that this latent-space method is well-suited for text-to-image generation tasks because it not only eliminates the unidirectional bias with existing methods but also allows us to incorporate a mask-and-replace diffusion strategy to avoid the accumulation of errors, which is a serious problem with existing methods. Our experiments show that the VQ-Diffusion produces significantly better text-to-image generation results when compared with conventional autoregressive (AR) models with similar numbers of parameters. Compared with previous GAN-based text-to-image methods, our VQ-Diffusion can handle more complex scenes and improve the synthesized image quality by a large margin. Finally, we show that the image generation computation in our method can be made highly efficient by reparameterization. With traditional AR methods, the text-to-image generation time increases linearly with the output image resolution and hence is quite time consuming even for normal size images. The VQ-Diffusion allows us to achieve a better trade-off between quality and speed. Our experiments indicate that the VQ-Diffusion model with the reparameterization is fifteen times faster than traditional AR methods while achieving a better image quality.*
The original codebase can be found at [microsoft/VQ-Diffusion](https://github.com/microsoft/VQ-Diffusion).
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
## VQDiffusionPipeline
[[autodoc]] VQDiffusionPipeline
- all
- __call__
## ImagePipelineOutput
[[autodoc]] pipelines.ImagePipelineOutput

View File

@@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License.
# T2I-Adapter
[T2I-Adapter]((https://hf.co/papers/2302.08453)) is a lightweight adapter model that provides an additional conditioning input image (line art, canny, sketch, depth, pose) to better control image generation. It is similar to a ControlNet, but it is a lot smaller (~77M parameters and ~300MB file size) because its only inserts weights into the UNet instead of copying and training it.
[T2I-Adapter](https://hf.co/papers/2302.08453) is a lightweight adapter model that provides an additional conditioning input image (line art, canny, sketch, depth, pose) to better control image generation. It is similar to a ControlNet, but it is a lot smaller (~77M parameters and ~300MB file size) because its only inserts weights into the UNet instead of copying and training it.
The T2I-Adapter is only available for training with the Stable Diffusion XL (SDXL) model.
@@ -224,4 +224,4 @@ image.save("./output.png")
Congratulations on training a T2I-Adapter model! 🎉 To learn more:
- Read the [Efficient Controllable Generation for SDXL with T2I-Adapters](https://www.cs.cmu.edu/~custom-diffusion/) blog post to learn more details about the experimental results from the T2I-Adapter team.
- Read the [Efficient Controllable Generation for SDXL with T2I-Adapters](https://huggingface.co/blog/t2i-sdxl-adapters) blog post to learn more details about the experimental results from the T2I-Adapter team.

View File

@@ -186,7 +186,7 @@ accelerate launch train_unconditional.py \
If you're training with more than one GPU, add the `--multi_gpu` parameter to the training command:
```bash
accelerate launch --mixed_precision="fp16" --multi_gpu train_unconditional.py \
accelerate launch --multi_gpu train_unconditional.py \
--dataset_name="huggan/flowers-102-categories" \
--output_dir="ddpm-ema-flowers-64" \
--mixed_precision="fp16" \

View File

@@ -63,3 +63,42 @@ With callbacks, you can implement features such as dynamic CFG without having to
🤗 Diffusers currently only supports `callback_on_step_end`, but feel free to open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) if you have a cool use-case and require a callback function with a different execution point!
</Tip>
## Using Callbacks to interrupt the Diffusion Process
The following Pipelines support interrupting the diffusion process via callback
- [StableDiffusionPipeline](../api/pipelines/stable_diffusion/overview.md)
- [StableDiffusionImg2ImgPipeline](..api/pipelines/stable_diffusion/img2img.md)
- [StableDiffusionInpaintPipeline](..api/pipelines/stable_diffusion/inpaint.md)
- [StableDiffusionXLPipeline](../api/pipelines/stable_diffusion/stable_diffusion_xl.md)
- [StableDiffusionXLImg2ImgPipeline](../api/pipelines/stable_diffusion/stable_diffusion_xl.md)
- [StableDiffusionXLInpaintPipeline](../api/pipelines/stable_diffusion/stable_diffusion_xl.md)
Interrupting the diffusion process is particularly useful when building UIs that work with Diffusers because it allows users to stop the generation process if they're unhappy with the intermediate results. You can incorporate this into your pipeline with a callback.
This callback function should take the following arguments: `pipe`, `i`, `t`, and `callback_kwargs` (this must be returned). Set the pipeline's `_interrupt` attribute to `True` to stop the diffusion process after a certain number of steps. You are also free to implement your own custom stopping logic inside the callback.
In this example, the diffusion process is stopped after 10 steps even though `num_inference_steps` is set to 50.
```python
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.enable_model_cpu_offload()
num_inference_steps = 50
def interrupt_callback(pipe, i, t, callback_kwargs):
stop_idx = 10
if i == stop_idx:
pipe._interrupt = True
return callback_kwargs
pipe(
"A photo of a cat",
num_inference_steps=num_inference_steps,
callback_on_step_end=interrupt_callback,
)
```

View File

@@ -203,7 +203,7 @@ def make_inpaint_condition(image, image_mask):
image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0
assert image.shape[0:1] == image_mask.shape[0:1]
image[image_mask > 0.5] = 1.0 # set as masked pixel
image[image_mask > 0.5] = -1.0 # set as masked pixel
image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
return image

View File

@@ -41,6 +41,20 @@ Now, define four different `Generator`s and assign each `Generator` a seed (`0`
generator = [torch.Generator(device="cuda").manual_seed(i) for i in range(4)]
```
<Tip warning={true}>
To create a batched seed, you should use a list comprehension that iterates over the length specified in `range()`. This creates a unique `Generator` object for each image in the batch. If you only multiply the `Generator` by the batch size, this only creates one `Generator` object that is used sequentially for each image in the batch.
For example, if you want to use the same seed to create 4 identical images:
```py
[torch.Generator().manual_seed(seed)] * 4
[torch.Generator().manual_seed(seed) for _ in range(4)]
```
</Tip>
Generate the images and have a look:
```python

326
examples/amused/README.md Normal file
View File

@@ -0,0 +1,326 @@
## Amused training
Amused can be finetuned on simple datasets relatively cheaply and quickly. Using 8bit optimizers, lora, and gradient accumulation, amused can be finetuned with as little as 5.5 GB. Here are a set of examples for finetuning amused on some relatively simple datasets. These training recipies are aggressively oriented towards minimal resources and fast verification -- i.e. the batch sizes are quite low and the learning rates are quite high. For optimal quality, you will probably want to increase the batch sizes and decrease learning rates.
All training examples use fp16 mixed precision and gradient checkpointing. We don't show 8 bit adam + lora as its about the same memory use as just using lora (bitsandbytes uses full precision optimizer states for weights below a minimum size).
### Finetuning the 256 checkpoint
These examples finetune on this [nouns](https://huggingface.co/datasets/m1guelpf/nouns) dataset.
Example results:
![noun1](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/noun1.png) ![noun2](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/noun2.png) ![noun3](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/noun3.png)
#### Full finetuning
Batch size: 8, Learning rate: 1e-4, Gives decent results in 750-1000 steps
| Batch Size | Gradient Accumulation Steps | Effective Total Batch Size | Memory Used |
|------------|-----------------------------|------------------|-------------|
| 8 | 1 | 8 | 19.7 GB |
| 4 | 2 | 8 | 18.3 GB |
| 1 | 8 | 8 | 17.9 GB |
```sh
accelerate launch train_amused.py \
--output_dir <output path> \
--train_batch_size <batch size> \
--gradient_accumulation_steps <gradient accumulation steps> \
--learning_rate 1e-4 \
--pretrained_model_name_or_path huggingface/amused-256 \
--instance_data_dataset 'm1guelpf/nouns' \
--image_key image \
--prompt_key text \
--resolution 256 \
--mixed_precision fp16 \
--lr_scheduler constant \
--validation_prompts \
'a pixel art character with square red glasses, a baseball-shaped head and a orange-colored body on a dark background' \
'a pixel art character with square orange glasses, a lips-shaped head and a red-colored body on a light background' \
'a pixel art character with square blue glasses, a microwave-shaped head and a purple-colored body on a sunny background' \
'a pixel art character with square red glasses, a baseball-shaped head and a blue-colored body on an orange background' \
'a pixel art character with square red glasses' \
'a pixel art character' \
'square red glasses on a pixel art character' \
'square red glasses on a pixel art character with a baseball-shaped head' \
--max_train_steps 10000 \
--checkpointing_steps 500 \
--validation_steps 250 \
--gradient_checkpointing
```
#### Full finetuning + 8 bit adam
Note that this training config keeps the batch size low and the learning rate high to get results fast with low resources. However, due to 8 bit adam, it will diverge eventually. If you want to train for longer, you will have to up the batch size and lower the learning rate.
Batch size: 16, Learning rate: 2e-5, Gives decent results in ~750 steps
| Batch Size | Gradient Accumulation Steps | Effective Total Batch Size | Memory Used |
|------------|-----------------------------|------------------|-------------|
| 16 | 1 | 16 | 20.1 GB |
| 8 | 2 | 16 | 15.6 GB |
| 1 | 16 | 16 | 10.7 GB |
```sh
accelerate launch train_amused.py \
--output_dir <output path> \
--train_batch_size <batch size> \
--gradient_accumulation_steps <gradient accumulation steps> \
--learning_rate 2e-5 \
--use_8bit_adam \
--pretrained_model_name_or_path huggingface/amused-256 \
--instance_data_dataset 'm1guelpf/nouns' \
--image_key image \
--prompt_key text \
--resolution 256 \
--mixed_precision fp16 \
--lr_scheduler constant \
--validation_prompts \
'a pixel art character with square red glasses, a baseball-shaped head and a orange-colored body on a dark background' \
'a pixel art character with square orange glasses, a lips-shaped head and a red-colored body on a light background' \
'a pixel art character with square blue glasses, a microwave-shaped head and a purple-colored body on a sunny background' \
'a pixel art character with square red glasses, a baseball-shaped head and a blue-colored body on an orange background' \
'a pixel art character with square red glasses' \
'a pixel art character' \
'square red glasses on a pixel art character' \
'square red glasses on a pixel art character with a baseball-shaped head' \
--max_train_steps 10000 \
--checkpointing_steps 500 \
--validation_steps 250 \
--gradient_checkpointing
```
#### Full finetuning + lora
Batch size: 16, Learning rate: 8e-4, Gives decent results in 1000-1250 steps
| Batch Size | Gradient Accumulation Steps | Effective Total Batch Size | Memory Used |
|------------|-----------------------------|------------------|-------------|
| 16 | 1 | 16 | 14.1 GB |
| 8 | 2 | 16 | 10.1 GB |
| 1 | 16 | 16 | 6.5 GB |
```sh
accelerate launch train_amused.py \
--output_dir <output path> \
--train_batch_size <batch size> \
--gradient_accumulation_steps <gradient accumulation steps> \
--learning_rate 8e-4 \
--use_lora \
--pretrained_model_name_or_path huggingface/amused-256 \
--instance_data_dataset 'm1guelpf/nouns' \
--image_key image \
--prompt_key text \
--resolution 256 \
--mixed_precision fp16 \
--lr_scheduler constant \
--validation_prompts \
'a pixel art character with square red glasses, a baseball-shaped head and a orange-colored body on a dark background' \
'a pixel art character with square orange glasses, a lips-shaped head and a red-colored body on a light background' \
'a pixel art character with square blue glasses, a microwave-shaped head and a purple-colored body on a sunny background' \
'a pixel art character with square red glasses, a baseball-shaped head and a blue-colored body on an orange background' \
'a pixel art character with square red glasses' \
'a pixel art character' \
'square red glasses on a pixel art character' \
'square red glasses on a pixel art character with a baseball-shaped head' \
--max_train_steps 10000 \
--checkpointing_steps 500 \
--validation_steps 250 \
--gradient_checkpointing
```
### Finetuning the 512 checkpoint
These examples finetune on this [minecraft](https://huggingface.co/monadical-labs/minecraft-preview) dataset.
Example results:
![minecraft1](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/minecraft1.png) ![minecraft2](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/minecraft2.png) ![minecraft3](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/minecraft3.png)
#### Full finetuning
Batch size: 8, Learning rate: 8e-5, Gives decent results in 500-1000 steps
| Batch Size | Gradient Accumulation Steps | Effective Total Batch Size | Memory Used |
|------------|-----------------------------|------------------|-------------|
| 8 | 1 | 8 | 24.2 GB |
| 4 | 2 | 8 | 19.7 GB |
| 1 | 8 | 8 | 16.99 GB |
```sh
accelerate launch train_amused.py \
--output_dir <output path> \
--train_batch_size <batch size> \
--gradient_accumulation_steps <gradient accumulation steps> \
--learning_rate 8e-5 \
--pretrained_model_name_or_path huggingface/amused-512 \
--instance_data_dataset 'monadical-labs/minecraft-preview' \
--prompt_prefix 'minecraft ' \
--image_key image \
--prompt_key text \
--resolution 512 \
--mixed_precision fp16 \
--lr_scheduler constant \
--validation_prompts \
'minecraft Avatar' \
'minecraft character' \
'minecraft' \
'minecraft president' \
'minecraft pig' \
--max_train_steps 10000 \
--checkpointing_steps 500 \
--validation_steps 250 \
--gradient_checkpointing
```
#### Full finetuning + 8 bit adam
Batch size: 8, Learning rate: 5e-6, Gives decent results in 500-1000 steps
| Batch Size | Gradient Accumulation Steps | Effective Total Batch Size | Memory Used |
|------------|-----------------------------|------------------|-------------|
| 8 | 1 | 8 | 21.2 GB |
| 4 | 2 | 8 | 13.3 GB |
| 1 | 8 | 8 | 9.9 GB |
```sh
accelerate launch train_amused.py \
--output_dir <output path> \
--train_batch_size <batch size> \
--gradient_accumulation_steps <gradient accumulation steps> \
--learning_rate 5e-6 \
--pretrained_model_name_or_path huggingface/amused-512 \
--instance_data_dataset 'monadical-labs/minecraft-preview' \
--prompt_prefix 'minecraft ' \
--image_key image \
--prompt_key text \
--resolution 512 \
--mixed_precision fp16 \
--lr_scheduler constant \
--validation_prompts \
'minecraft Avatar' \
'minecraft character' \
'minecraft' \
'minecraft president' \
'minecraft pig' \
--max_train_steps 10000 \
--checkpointing_steps 500 \
--validation_steps 250 \
--gradient_checkpointing
```
#### Full finetuning + lora
Batch size: 8, Learning rate: 1e-4, Gives decent results in 500-1000 steps
| Batch Size | Gradient Accumulation Steps | Effective Total Batch Size | Memory Used |
|------------|-----------------------------|------------------|-------------|
| 8 | 1 | 8 | 12.7 GB |
| 4 | 2 | 8 | 9.0 GB |
| 1 | 8 | 8 | 5.6 GB |
```sh
accelerate launch train_amused.py \
--output_dir <output path> \
--train_batch_size <batch size> \
--gradient_accumulation_steps <gradient accumulation steps> \
--learning_rate 1e-4 \
--use_lora \
--pretrained_model_name_or_path huggingface/amused-512 \
--instance_data_dataset 'monadical-labs/minecraft-preview' \
--prompt_prefix 'minecraft ' \
--image_key image \
--prompt_key text \
--resolution 512 \
--mixed_precision fp16 \
--lr_scheduler constant \
--validation_prompts \
'minecraft Avatar' \
'minecraft character' \
'minecraft' \
'minecraft president' \
'minecraft pig' \
--max_train_steps 10000 \
--checkpointing_steps 500 \
--validation_steps 250 \
--gradient_checkpointing
```
### Styledrop
[Styledrop](https://arxiv.org/abs/2306.00983) is an efficient finetuning method for learning a new style from just one or very few images. It has an optional first stage to generate human picked additional training samples. The additional training samples can be used to augment the initial images. Our examples exclude the optional additional image selection stage and instead we just finetune on a single image.
This is our example style image:
![example](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/A%20mushroom%20in%20%5BV%5D%20style.png)
Download it to your local directory with
```sh
wget https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/A%20mushroom%20in%20%5BV%5D%20style.png
```
#### 256
Example results:
![glowing_256_1](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/glowing_256_1.png) ![glowing_256_2](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/glowing_256_2.png) ![glowing_256_3](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/glowing_256_3.png)
Learning rate: 4e-4, Gives decent results in 1500-2000 steps
Memory used: 6.5 GB
```sh
accelerate launch train_amused.py \
--output_dir <output path> \
--mixed_precision fp16 \
--report_to wandb \
--use_lora \
--pretrained_model_name_or_path huggingface/amused-256 \
--train_batch_size 1 \
--lr_scheduler constant \
--learning_rate 4e-4 \
--validation_prompts \
'A chihuahua walking on the street in [V] style' \
'A banana on the table in [V] style' \
'A church on the street in [V] style' \
'A tabby cat walking in the forest in [V] style' \
--instance_data_image 'A mushroom in [V] style.png' \
--max_train_steps 10000 \
--checkpointing_steps 500 \
--validation_steps 100 \
--resolution 256
```
#### 512
Example results:
![glowing_512_1](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/glowing_512_1.png) ![glowing_512_2](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/glowing_512_2.png) ![glowing_512_3](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/glowing_512_3.png)
Learning rate: 1e-3, Lora alpha 1, Gives decent results in 1500-2000 steps
Memory used: 5.6 GB
```
accelerate launch train_amused.py \
--output_dir <output path> \
--mixed_precision fp16 \
--report_to wandb \
--use_lora \
--pretrained_model_name_or_path huggingface/amused-512 \
--train_batch_size 1 \
--lr_scheduler constant \
--learning_rate 1e-3 \
--validation_prompts \
'A chihuahua walking on the street in [V] style' \
'A banana on the table in [V] style' \
'A church on the street in [V] style' \
'A tabby cat walking in the forest in [V] style' \
--instance_data_image 'A mushroom in [V] style.png' \
--max_train_steps 100000 \
--checkpointing_steps 500 \
--validation_steps 100 \
--resolution 512 \
--lora_alpha 1
```

View File

@@ -0,0 +1,972 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
import logging
import math
import os
import shutil
from contextlib import nullcontext
from pathlib import Path
import torch
import torch.nn.functional as F
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
from PIL import Image
from PIL.ImageOps import exif_transpose
from torch.utils.data import DataLoader, Dataset, default_collate
from torchvision import transforms
from transformers import (
CLIPTextModelWithProjection,
CLIPTokenizer,
)
import diffusers.optimization
from diffusers import AmusedPipeline, AmusedScheduler, EMAModel, UVit2DModel, VQModel
from diffusers.loaders import LoraLoaderMixin
from diffusers.utils import is_wandb_available
if is_wandb_available():
import wandb
logger = get_logger(__name__, log_level="INFO")
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--pretrained_model_name_or_path",
type=str,
default=None,
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--revision",
type=str,
default=None,
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--variant",
type=str,
default=None,
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
)
parser.add_argument(
"--instance_data_dataset",
type=str,
default=None,
required=False,
help="A Hugging Face dataset containing the training images",
)
parser.add_argument(
"--instance_data_dir",
type=str,
default=None,
required=False,
help="A folder containing the training data of instance images.",
)
parser.add_argument(
"--instance_data_image", type=str, default=None, required=False, help="A single training image"
)
parser.add_argument(
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
)
parser.add_argument(
"--dataloader_num_workers",
type=int,
default=0,
help=(
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
),
)
parser.add_argument(
"--allow_tf32",
action="store_true",
help=(
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
),
)
parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
parser.add_argument("--ema_decay", type=float, default=0.9999)
parser.add_argument("--ema_update_after_step", type=int, default=0)
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
parser.add_argument(
"--output_dir",
type=str,
default="muse_training",
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
parser.add_argument(
"--logging_dir",
type=str,
default="logs",
help=(
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
)
parser.add_argument(
"--max_train_steps",
type=int,
default=None,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument(
"--checkpointing_steps",
type=int,
default=500,
help=(
"Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
"In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
"Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
"See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
"instructions."
),
)
parser.add_argument(
"--logging_steps",
type=int,
default=50,
)
parser.add_argument(
"--checkpoints_total_limit",
type=int,
default=None,
help=(
"Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
" See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
" for more details"
),
)
parser.add_argument(
"--resume_from_checkpoint",
type=str,
default=None,
help=(
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
),
)
parser.add_argument(
"--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument(
"--learning_rate",
type=float,
default=0.0003,
help="Initial learning rate (after the potential warmup period) to use.",
)
parser.add_argument(
"--scale_lr",
action="store_true",
default=False,
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
)
parser.add_argument(
"--lr_scheduler",
type=str,
default="constant",
help=(
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "constant_with_warmup"]'
),
)
parser.add_argument(
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
)
parser.add_argument(
"--validation_steps",
type=int,
default=100,
help=(
"Run validation every X steps. Validation consists of running the prompt"
" `args.validation_prompt` multiple times: `args.num_validation_images`"
" and logging the images."
),
)
parser.add_argument(
"--mixed_precision",
type=str,
default=None,
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
),
)
parser.add_argument(
"--report_to",
type=str,
default="wandb",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
),
)
parser.add_argument("--validation_prompts", type=str, nargs="*")
parser.add_argument(
"--resolution",
type=int,
default=512,
help=(
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
" resolution"
),
)
parser.add_argument("--split_vae_encode", type=int, required=False, default=None)
parser.add_argument("--min_masking_rate", type=float, default=0.0)
parser.add_argument("--cond_dropout_prob", type=float, default=0.0)
parser.add_argument("--max_grad_norm", default=None, type=float, help="Max gradient norm.", required=False)
parser.add_argument("--use_lora", action="store_true", help="Fine tune the model using LoRa")
parser.add_argument("--text_encoder_use_lora", action="store_true", help="Fine tune the model using LoRa")
parser.add_argument("--lora_r", default=16, type=int)
parser.add_argument("--lora_alpha", default=32, type=int)
parser.add_argument("--lora_target_modules", default=["to_q", "to_k", "to_v"], type=str, nargs="+")
parser.add_argument("--text_encoder_lora_r", default=16, type=int)
parser.add_argument("--text_encoder_lora_alpha", default=32, type=int)
parser.add_argument("--text_encoder_lora_target_modules", default=["to_q", "to_k", "to_v"], type=str, nargs="+")
parser.add_argument("--train_text_encoder", action="store_true")
parser.add_argument("--image_key", type=str, required=False)
parser.add_argument("--prompt_key", type=str, required=False)
parser.add_argument(
"--gradient_checkpointing",
action="store_true",
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
)
parser.add_argument("--prompt_prefix", type=str, required=False, default=None)
args = parser.parse_args()
if args.report_to == "wandb":
if not is_wandb_available():
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
num_datasources = sum(
[x is not None for x in [args.instance_data_dir, args.instance_data_image, args.instance_data_dataset]]
)
if num_datasources != 1:
raise ValueError(
"provide one and only one of `--instance_data_dir`, `--instance_data_image`, or `--instance_data_dataset`"
)
if args.instance_data_dir is not None:
if not os.path.exists(args.instance_data_dir):
raise ValueError(f"Does not exist: `--args.instance_data_dir` {args.instance_data_dir}")
if args.instance_data_image is not None:
if not os.path.exists(args.instance_data_image):
raise ValueError(f"Does not exist: `--args.instance_data_image` {args.instance_data_image}")
if args.instance_data_dataset is not None and (args.image_key is None or args.prompt_key is None):
raise ValueError("`--instance_data_dataset` requires setting `--image_key` and `--prompt_key`")
return args
class InstanceDataRootDataset(Dataset):
def __init__(
self,
instance_data_root,
tokenizer,
size=512,
):
self.size = size
self.tokenizer = tokenizer
self.instance_images_path = list(Path(instance_data_root).iterdir())
def __len__(self):
return len(self.instance_images_path)
def __getitem__(self, index):
image_path = self.instance_images_path[index % len(self.instance_images_path)]
instance_image = Image.open(image_path)
rv = process_image(instance_image, self.size)
prompt = os.path.splitext(os.path.basename(image_path))[0]
rv["prompt_input_ids"] = tokenize_prompt(self.tokenizer, prompt)[0]
return rv
class InstanceDataImageDataset(Dataset):
def __init__(
self,
instance_data_image,
train_batch_size,
size=512,
):
self.value = process_image(Image.open(instance_data_image), size)
self.train_batch_size = train_batch_size
def __len__(self):
# Needed so a full batch of the data can be returned. Otherwise will return
# batches of size 1
return self.train_batch_size
def __getitem__(self, index):
return self.value
class HuggingFaceDataset(Dataset):
def __init__(
self,
hf_dataset,
tokenizer,
image_key,
prompt_key,
prompt_prefix=None,
size=512,
):
self.size = size
self.image_key = image_key
self.prompt_key = prompt_key
self.tokenizer = tokenizer
self.hf_dataset = hf_dataset
self.prompt_prefix = prompt_prefix
def __len__(self):
return len(self.hf_dataset)
def __getitem__(self, index):
item = self.hf_dataset[index]
rv = process_image(item[self.image_key], self.size)
prompt = item[self.prompt_key]
if self.prompt_prefix is not None:
prompt = self.prompt_prefix + prompt
rv["prompt_input_ids"] = tokenize_prompt(self.tokenizer, prompt)[0]
return rv
def process_image(image, size):
image = exif_transpose(image)
if not image.mode == "RGB":
image = image.convert("RGB")
orig_height = image.height
orig_width = image.width
image = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)(image)
c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(size, size))
image = transforms.functional.crop(image, c_top, c_left, size, size)
image = transforms.ToTensor()(image)
micro_conds = torch.tensor(
[orig_width, orig_height, c_top, c_left, 6.0],
)
return {"image": image, "micro_conds": micro_conds}
def tokenize_prompt(tokenizer, prompt):
return tokenizer(
prompt,
truncation=True,
padding="max_length",
max_length=77,
return_tensors="pt",
).input_ids
def encode_prompt(text_encoder, input_ids):
outputs = text_encoder(input_ids, return_dict=True, output_hidden_states=True)
encoder_hidden_states = outputs.hidden_states[-2]
cond_embeds = outputs[0]
return encoder_hidden_states, cond_embeds
def main(args):
if args.allow_tf32:
torch.backends.cuda.matmul.allow_tf32 = True
logging_dir = Path(args.output_dir, args.logging_dir)
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with=args.report_to,
project_config=accelerator_project_config,
)
if accelerator.is_main_process:
os.makedirs(args.output_dir, exist_ok=True)
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)
if accelerator.is_main_process:
accelerator.init_trackers("amused", config=vars(copy.deepcopy(args)))
if args.seed is not None:
set_seed(args.seed)
# TODO - will have to fix loading if training text encoder
text_encoder = CLIPTextModelWithProjection.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
)
tokenizer = CLIPTokenizer.from_pretrained(
args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, variant=args.variant
)
vq_model = VQModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="vqvae", revision=args.revision, variant=args.variant
)
if args.train_text_encoder:
if args.text_encoder_use_lora:
lora_config = LoraConfig(
r=args.text_encoder_lora_r,
lora_alpha=args.text_encoder_lora_alpha,
target_modules=args.text_encoder_lora_target_modules,
)
text_encoder.add_adapter(lora_config)
text_encoder.train()
text_encoder.requires_grad_(True)
else:
text_encoder.eval()
text_encoder.requires_grad_(False)
vq_model.requires_grad_(False)
model = UVit2DModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="transformer",
revision=args.revision,
variant=args.variant,
)
if args.use_lora:
lora_config = LoraConfig(
r=args.lora_r,
lora_alpha=args.lora_alpha,
target_modules=args.lora_target_modules,
)
model.add_adapter(lora_config)
model.train()
if args.gradient_checkpointing:
model.enable_gradient_checkpointing()
if args.train_text_encoder:
text_encoder.gradient_checkpointing_enable()
if args.use_ema:
ema = EMAModel(
model.parameters(),
decay=args.ema_decay,
update_after_step=args.ema_update_after_step,
model_cls=UVit2DModel,
model_config=model.config,
)
def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
transformer_lora_layers_to_save = None
text_encoder_lora_layers_to_save = None
for model_ in models:
if isinstance(model_, type(accelerator.unwrap_model(model))):
if args.use_lora:
transformer_lora_layers_to_save = get_peft_model_state_dict(model_)
else:
model_.save_pretrained(os.path.join(output_dir, "transformer"))
elif isinstance(model_, type(accelerator.unwrap_model(text_encoder))):
if args.text_encoder_use_lora:
text_encoder_lora_layers_to_save = get_peft_model_state_dict(model_)
else:
model_.save_pretrained(os.path.join(output_dir, "text_encoder"))
else:
raise ValueError(f"unexpected save model: {model_.__class__}")
# make sure to pop weight so that corresponding model is not saved again
weights.pop()
if transformer_lora_layers_to_save is not None or text_encoder_lora_layers_to_save is not None:
LoraLoaderMixin.save_lora_weights(
output_dir,
transformer_lora_layers=transformer_lora_layers_to_save,
text_encoder_lora_layers=text_encoder_lora_layers_to_save,
)
if args.use_ema:
ema.save_pretrained(os.path.join(output_dir, "ema_model"))
def load_model_hook(models, input_dir):
transformer = None
text_encoder_ = None
while len(models) > 0:
model_ = models.pop()
if isinstance(model_, type(accelerator.unwrap_model(model))):
if args.use_lora:
transformer = model_
else:
load_model = UVit2DModel.from_pretrained(os.path.join(input_dir, "transformer"))
model_.load_state_dict(load_model.state_dict())
del load_model
elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
if args.text_encoder_use_lora:
text_encoder_ = model_
else:
load_model = CLIPTextModelWithProjection.from_pretrained(os.path.join(input_dir, "text_encoder"))
model_.load_state_dict(load_model.state_dict())
del load_model
else:
raise ValueError(f"unexpected save model: {model.__class__}")
if transformer is not None or text_encoder_ is not None:
lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
LoraLoaderMixin.load_lora_into_text_encoder(
lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_
)
LoraLoaderMixin.load_lora_into_transformer(
lora_state_dict, network_alphas=network_alphas, transformer=transformer
)
if args.use_ema:
load_from = EMAModel.from_pretrained(os.path.join(input_dir, "ema_model"), model_cls=UVit2DModel)
ema.load_state_dict(load_from.state_dict())
del load_from
accelerator.register_load_state_pre_hook(load_model_hook)
accelerator.register_save_state_pre_hook(save_model_hook)
if args.scale_lr:
args.learning_rate = (
args.learning_rate * args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
)
if args.use_8bit_adam:
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError(
"Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
)
optimizer_cls = bnb.optim.AdamW8bit
else:
optimizer_cls = torch.optim.AdamW
# no decay on bias and layernorm and embedding
no_decay = ["bias", "layer_norm.weight", "mlm_ln.weight", "embeddings.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
"weight_decay": args.adam_weight_decay,
},
{
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
"weight_decay": 0.0,
},
]
if args.train_text_encoder:
optimizer_grouped_parameters.append(
{"params": text_encoder.parameters(), "weight_decay": args.adam_weight_decay}
)
optimizer = optimizer_cls(
optimizer_grouped_parameters,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
logger.info("Creating dataloaders and lr_scheduler")
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
if args.instance_data_dir is not None:
dataset = InstanceDataRootDataset(
instance_data_root=args.instance_data_dir,
tokenizer=tokenizer,
size=args.resolution,
)
elif args.instance_data_image is not None:
dataset = InstanceDataImageDataset(
instance_data_image=args.instance_data_image,
train_batch_size=args.train_batch_size,
size=args.resolution,
)
elif args.instance_data_dataset is not None:
dataset = HuggingFaceDataset(
hf_dataset=load_dataset(args.instance_data_dataset, split="train"),
tokenizer=tokenizer,
image_key=args.image_key,
prompt_key=args.prompt_key,
prompt_prefix=args.prompt_prefix,
size=args.resolution,
)
else:
assert False
train_dataloader = DataLoader(
dataset,
batch_size=args.train_batch_size,
shuffle=True,
num_workers=args.dataloader_num_workers,
collate_fn=default_collate,
)
train_dataloader.num_batches = len(train_dataloader)
lr_scheduler = diffusers.optimization.get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_training_steps=args.max_train_steps * accelerator.num_processes,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
)
logger.info("Preparing model, optimizer and dataloaders")
if args.train_text_encoder:
model, optimizer, lr_scheduler, train_dataloader, text_encoder = accelerator.prepare(
model, optimizer, lr_scheduler, train_dataloader, text_encoder
)
else:
model, optimizer, lr_scheduler, train_dataloader = accelerator.prepare(
model, optimizer, lr_scheduler, train_dataloader
)
train_dataloader.num_batches = len(train_dataloader)
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
if not args.train_text_encoder:
text_encoder.to(device=accelerator.device, dtype=weight_dtype)
vq_model.to(device=accelerator.device)
if args.use_ema:
ema.to(accelerator.device)
with nullcontext() if args.train_text_encoder else torch.no_grad():
empty_embeds, empty_clip_embeds = encode_prompt(
text_encoder, tokenize_prompt(tokenizer, "").to(text_encoder.device, non_blocking=True)
)
# There is a single image, we can just pre-encode the single prompt
if args.instance_data_image is not None:
prompt = os.path.splitext(os.path.basename(args.instance_data_image))[0]
encoder_hidden_states, cond_embeds = encode_prompt(
text_encoder, tokenize_prompt(tokenizer, prompt).to(text_encoder.device, non_blocking=True)
)
encoder_hidden_states = encoder_hidden_states.repeat(args.train_batch_size, 1, 1)
cond_embeds = cond_embeds.repeat(args.train_batch_size, 1)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
# Afterwards we recalculate our number of training epochs.
# Note: We are not doing epoch based training here, but just using this for book keeping and being able to
# reuse the same training loop with other datasets/loaders.
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
# Train!
logger.info("***** Running training *****")
logger.info(f" Num training steps = {args.max_train_steps}")
logger.info(f" Instantaneous batch size per device = { args.train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
resume_from_checkpoint = args.resume_from_checkpoint
if resume_from_checkpoint:
if resume_from_checkpoint == "latest":
# Get the most recent checkpoint
dirs = os.listdir(args.output_dir)
dirs = [d for d in dirs if d.startswith("checkpoint")]
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
if len(dirs) > 0:
resume_from_checkpoint = os.path.join(args.output_dir, dirs[-1])
else:
resume_from_checkpoint = None
if resume_from_checkpoint is None:
accelerator.print(
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
)
else:
accelerator.print(f"Resuming from checkpoint {resume_from_checkpoint}")
if resume_from_checkpoint is None:
global_step = 0
first_epoch = 0
else:
accelerator.load_state(resume_from_checkpoint)
global_step = int(os.path.basename(resume_from_checkpoint).split("-")[1])
first_epoch = global_step // num_update_steps_per_epoch
# As stated above, we are not doing epoch based training here, but just using this for book keeping and being able to
# reuse the same training loop with other datasets/loaders.
for epoch in range(first_epoch, num_train_epochs):
for batch in train_dataloader:
with torch.no_grad():
micro_conds = batch["micro_conds"].to(accelerator.device, non_blocking=True)
pixel_values = batch["image"].to(accelerator.device, non_blocking=True)
batch_size = pixel_values.shape[0]
split_batch_size = args.split_vae_encode if args.split_vae_encode is not None else batch_size
num_splits = math.ceil(batch_size / split_batch_size)
image_tokens = []
for i in range(num_splits):
start_idx = i * split_batch_size
end_idx = min((i + 1) * split_batch_size, batch_size)
bs = pixel_values.shape[0]
image_tokens.append(
vq_model.quantize(vq_model.encode(pixel_values[start_idx:end_idx]).latents)[2][2].reshape(
bs, -1
)
)
image_tokens = torch.cat(image_tokens, dim=0)
batch_size, seq_len = image_tokens.shape
timesteps = torch.rand(batch_size, device=image_tokens.device)
mask_prob = torch.cos(timesteps * math.pi * 0.5)
mask_prob = mask_prob.clip(args.min_masking_rate)
num_token_masked = (seq_len * mask_prob).round().clamp(min=1)
batch_randperm = torch.rand(batch_size, seq_len, device=image_tokens.device).argsort(dim=-1)
mask = batch_randperm < num_token_masked.unsqueeze(-1)
mask_id = accelerator.unwrap_model(model).config.vocab_size - 1
input_ids = torch.where(mask, mask_id, image_tokens)
labels = torch.where(mask, image_tokens, -100)
if args.cond_dropout_prob > 0.0:
assert encoder_hidden_states is not None
batch_size = encoder_hidden_states.shape[0]
mask = (
torch.zeros((batch_size, 1, 1), device=encoder_hidden_states.device).float().uniform_(0, 1)
< args.cond_dropout_prob
)
empty_embeds_ = empty_embeds.expand(batch_size, -1, -1)
encoder_hidden_states = torch.where(
(encoder_hidden_states * mask).bool(), encoder_hidden_states, empty_embeds_
)
empty_clip_embeds_ = empty_clip_embeds.expand(batch_size, -1)
cond_embeds = torch.where((cond_embeds * mask.squeeze(-1)).bool(), cond_embeds, empty_clip_embeds_)
bs = input_ids.shape[0]
vae_scale_factor = 2 ** (len(vq_model.config.block_out_channels) - 1)
resolution = args.resolution // vae_scale_factor
input_ids = input_ids.reshape(bs, resolution, resolution)
if "prompt_input_ids" in batch:
with nullcontext() if args.train_text_encoder else torch.no_grad():
encoder_hidden_states, cond_embeds = encode_prompt(
text_encoder, batch["prompt_input_ids"].to(accelerator.device, non_blocking=True)
)
# Train Step
with accelerator.accumulate(model):
codebook_size = accelerator.unwrap_model(model).config.codebook_size
logits = (
model(
input_ids=input_ids,
encoder_hidden_states=encoder_hidden_states,
micro_conds=micro_conds,
pooled_text_emb=cond_embeds,
)
.reshape(bs, codebook_size, -1)
.permute(0, 2, 1)
.reshape(-1, codebook_size)
)
loss = F.cross_entropy(
logits,
labels.view(-1),
ignore_index=-100,
reduction="mean",
)
# Gather the losses across all processes for logging (if we use distributed training).
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
avg_masking_rate = accelerator.gather(mask_prob.repeat(args.train_batch_size)).mean()
accelerator.backward(loss)
if args.max_grad_norm is not None and accelerator.sync_gradients:
accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
if args.use_ema:
ema.step(model.parameters())
if (global_step + 1) % args.logging_steps == 0:
logs = {
"step_loss": avg_loss.item(),
"lr": lr_scheduler.get_last_lr()[0],
"avg_masking_rate": avg_masking_rate.item(),
}
accelerator.log(logs, step=global_step + 1)
logger.info(
f"Step: {global_step + 1} "
f"Loss: {avg_loss.item():0.4f} "
f"LR: {lr_scheduler.get_last_lr()[0]:0.6f}"
)
if (global_step + 1) % args.checkpointing_steps == 0:
save_checkpoint(args, accelerator, global_step + 1)
if (global_step + 1) % args.validation_steps == 0 and accelerator.is_main_process:
if args.use_ema:
ema.store(model.parameters())
ema.copy_to(model.parameters())
with torch.no_grad():
logger.info("Generating images...")
model.eval()
if args.train_text_encoder:
text_encoder.eval()
scheduler = AmusedScheduler.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="scheduler",
revision=args.revision,
variant=args.variant,
)
pipe = AmusedPipeline(
transformer=accelerator.unwrap_model(model),
tokenizer=tokenizer,
text_encoder=text_encoder,
vqvae=vq_model,
scheduler=scheduler,
)
pil_images = pipe(prompt=args.validation_prompts).images
wandb_images = [
wandb.Image(image, caption=args.validation_prompts[i])
for i, image in enumerate(pil_images)
]
wandb.log({"generated_images": wandb_images}, step=global_step + 1)
model.train()
if args.train_text_encoder:
text_encoder.train()
if args.use_ema:
ema.restore(model.parameters())
global_step += 1
# Stop training if max steps is reached
if global_step >= args.max_train_steps:
break
# End for
accelerator.wait_for_everyone()
# Evaluate and save checkpoint at the end of training
save_checkpoint(args, accelerator, global_step)
# Save the final trained checkpoint
if accelerator.is_main_process:
model = accelerator.unwrap_model(model)
if args.use_ema:
ema.copy_to(model.parameters())
model.save_pretrained(args.output_dir)
accelerator.end_training()
def save_checkpoint(args, accelerator, global_step):
output_dir = args.output_dir
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if accelerator.is_main_process and args.checkpoints_total_limit is not None:
checkpoints = os.listdir(output_dir)
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
if len(checkpoints) >= args.checkpoints_total_limit:
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
removing_checkpoints = checkpoints[0:num_to_remove]
logger.info(
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
)
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
for removing_checkpoint in removing_checkpoints:
removing_checkpoint = os.path.join(output_dir, removing_checkpoint)
shutil.rmtree(removing_checkpoint)
save_path = Path(output_dir) / f"checkpoint-{global_step}"
accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}")
if __name__ == "__main__":
main(parse_args())

View File

@@ -8,6 +8,7 @@ If a community doesn't work as expected, please open an issue and ping the autho
| Example | Description | Code Example | Colab | Author |
|:--------------------------------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------:|
| Marigold Monocular Depth Estimation | A universal monocular depth estimator, utilizing Stable Diffusion, delivering sharp predictions in the wild. (See the [project page](https://marigoldmonodepth.github.io) and [full codebase](https://github.com/prs-eth/marigold) for more details.) | [Marigold Depth Estimation](#marigold-depth-estimation) | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/toshas/marigold) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/12G8reD13DdpMie5ZQlaFNo2WCGeNUH-u?usp=sharing) | [Bingxin Ke](https://github.com/markkua) and [Anton Obukhov](https://github.com/toshas) |
| LLM-grounded Diffusion (LMD+) | LMD greatly improves the prompt following ability of text-to-image generation models by introducing an LLM as a front-end prompt parser and layout planner. [Project page.](https://llm-grounded-diffusion.github.io/) [See our full codebase (also with diffusers).](https://github.com/TonyLianLong/LLM-groundedDiffusion) | [LLM-grounded Diffusion (LMD+)](#llm-grounded-diffusion) | [Huggingface Demo](https://huggingface.co/spaces/longlian/llm-grounded-diffusion) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1SXzMSeAB-LJYISb2yrUOdypLz4OYWUKj) | [Long (Tony) Lian](https://tonylian.com/) |
| CLIP Guided Stable Diffusion | Doing CLIP guidance for text to image generation with Stable Diffusion | [CLIP Guided Stable Diffusion](#clip-guided-stable-diffusion) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/CLIP_Guided_Stable_diffusion_with_diffusers.ipynb) | [Suraj Patil](https://github.com/patil-suraj/) |
| One Step U-Net (Dummy) | Example showcasing of how to use Community Pipelines (see https://github.com/huggingface/diffusers/issues/841) | [One Step U-Net](#one-step-unet) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) |
@@ -41,7 +42,7 @@ If a community doesn't work as expected, please open an issue and ping the autho
| TensorRT Stable Diffusion Inpainting Pipeline | Accelerates the Stable Diffusion Inpainting Pipeline using TensorRT | [TensorRT Stable Diffusion Inpainting Pipeline](#tensorrt-inpainting-stable-diffusion-pipeline) | - | [Asfiya Baig](https://github.com/asfiyab-nvidia) |
| IADB Pipeline | Implementation of [Iterative α-(de)Blending: a Minimalist Deterministic Diffusion Model](https://arxiv.org/abs/2305.03486) | [IADB Pipeline](#iadb-pipeline) | - | [Thomas Chambon](https://github.com/tchambon)
| Zero1to3 Pipeline | Implementation of [Zero-1-to-3: Zero-shot One Image to 3D Object](https://arxiv.org/abs/2303.11328) | [Zero1to3 Pipeline](#Zero1to3-pipeline) | - | [Xin Kong](https://github.com/kxhit) |
Stable Diffusion XL Long Weighted Prompt Pipeline | A pipeline support unlimited length of prompt and negative prompt, use A1111 style of prompt weighting | [Stable Diffusion XL Long Weighted Prompt Pipeline](#stable-diffusion-xl-long-weighted-prompt-pipeline) | - | [Andrew Zhu](https://xhinker.medium.com/) |
| Stable Diffusion XL Long Weighted Prompt Pipeline | A pipeline support unlimited length of prompt and negative prompt, use A1111 style of prompt weighting | [Stable Diffusion XL Long Weighted Prompt Pipeline](#stable-diffusion-xl-long-weighted-prompt-pipeline) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1LsqilswLR40XLLcp6XFOl5nKb_wOe26W?usp=sharing) | [Andrew Zhu](https://xhinker.medium.com/) |
FABRIC - Stable Diffusion with feedback Pipeline | pipeline supports feedback from liked and disliked images | [Stable Diffusion Fabric Pipeline](#stable-diffusion-fabric-pipeline) | - | [Shauray Singh](https://shauray8.github.io/about_shauray/) |
sketch inpaint - Inpainting with non-inpaint Stable Diffusion | sketch inpaint much like in automatic1111 | [Masked Im2Im Stable Diffusion Pipeline](#stable-diffusion-masked-im2im) | - | [Anatoly Belikov](https://github.com/noskill) |
prompt-to-prompt | change parts of a prompt and retain image structure (see [paper page](https://prompt-to-prompt.github.io/)) | [Prompt2Prompt Pipeline](#prompt2prompt-pipeline) | - | [Umer H. Adil](https://twitter.com/UmerHAdil) |
@@ -61,6 +62,53 @@ pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", custo
## Example usages
### Marigold Depth Estimation
Marigold is a universal monocular depth estimator that delivers accurate and sharp predictions in the wild. Based on Stable Diffusion, it is trained exclusively with synthetic depth data and excels in zero-shot adaptation to real-world imagery. This pipeline is an official implementation of the inference process. More details can be found on our [project page](https://marigoldmonodepth.github.io) and [full codebase](https://github.com/prs-eth/marigold) (also implemented with diffusers).
![Marigold Teaser](https://marigoldmonodepth.github.io/images/teaser_collage_compressed.jpg)
This depth estimation pipeline processes a single input image through multiple diffusion denoising stages to estimate depth maps. These maps are subsequently merged to produce the final output. Below is an example code snippet, including optional arguments:
```python
import numpy as np
from PIL import Image
from diffusers import DiffusionPipeline
from diffusers.utils import load_image
pipe = DiffusionPipeline.from_pretrained(
"Bingxin/Marigold",
custom_pipeline="marigold_depth_estimation"
# torch_dtype=torch.float16, # (optional) Run with half-precision (16-bit float).
)
pipe.to("cuda")
img_path_or_url = "https://share.phys.ethz.ch/~pf/bingkedata/marigold/pipeline_example.jpg"
image: Image.Image = load_image(img_path_or_url)
pipeline_output = pipe(
image, # Input image.
# denoising_steps=10, # (optional) Number of denoising steps of each inference pass. Default: 10.
# ensemble_size=10, # (optional) Number of inference passes in the ensemble. Default: 10.
# processing_res=768, # (optional) Maximum resolution of processing. If set to 0: will not resize at all. Defaults to 768.
# match_input_res=True, # (optional) Resize depth prediction to match input resolution.
# batch_size=0, # (optional) Inference batch size, no bigger than `num_ensemble`. If set to 0, the script will automatically decide the proper batch size. Defaults to 0.
# color_map="Spectral", # (optional) Colormap used to colorize the depth map. Defaults to "Spectral".
# show_progress_bar=True, # (optional) If true, will show progress bars of the inference progress.
)
depth: np.ndarray = pipeline_output.depth_np # Predicted depth map
depth_colored: Image.Image = pipeline_output.depth_colored # Colorized prediction
# Save as uint16 PNG
depth_uint16 = (depth * 65535.0).astype(np.uint16)
Image.fromarray(depth_uint16).save("./depth_map.png", mode="I;16")
# Save colorized depth map
depth_colored.save("./depth_colored.png")
```
### LLM-grounded Diffusion
LMD and LMD+ greatly improves the prompt understanding ability of text-to-image generation models by introducing an LLM as a front-end prompt parser and layout planner. It improves spatial reasoning, the understanding of negation, attribute binding, generative numeracy, etc. in a unified manner without explicitly aiming for each. LMD is completely training-free (i.e., uses SD model off-the-shelf). LMD+ takes in additional adapters for better control. This is a reproduction of LMD+ model used in our work. [Project page.](https://llm-grounded-diffusion.github.io/) [See our full codebase (also with diffusers).](https://github.com/TonyLianLong/LLM-groundedDiffusion)
@@ -1619,10 +1667,11 @@ This approach is using (optional) CoCa model to avoid writing image description.
This SDXL pipeline support unlimited length prompt and negative prompt, compatible with A1111 prompt weighted style.
You can provide both `prompt` and `prompt_2`. if only one prompt is provided, `prompt_2` will be a copy of the provided `prompt`. Here is a sample code to use this pipeline.
You can provide both `prompt` and `prompt_2`. If only one prompt is provided, `prompt_2` will be a copy of the provided `prompt`. Here is a sample code to use this pipeline.
```python
from diffusers import DiffusionPipeline
from diffusers.utils import load_image
import torch
pipe = DiffusionPipeline.from_pretrained(
@@ -1633,25 +1682,52 @@ pipe = DiffusionPipeline.from_pretrained(
, custom_pipeline = "lpw_stable_diffusion_xl",
)
prompt = "photo of a cute (white) cat running on the grass"*20
prompt2 = "chasing (birds:1.5)"*20
prompt = "photo of a cute (white) cat running on the grass" * 20
prompt2 = "chasing (birds:1.5)" * 20
prompt = f"{prompt},{prompt2}"
neg_prompt = "blur, low quality, carton, animate"
pipe.to("cuda")
images = pipe(
prompt = prompt
, negative_prompt = neg_prompt
).images[0]
# text2img
t2i_images = pipe(
prompt=prompt,
negative_prompt=neg_prompt,
).images # alternatively, you can call the .text2img() function
# img2img
input_image = load_image("/path/to/local/image.png") # or URL to your input image
i2i_images = pipe.img2img(
prompt=prompt,
negative_prompt=neg_prompt,
image=input_image,
strength=0.8, # higher strength will result in more variation compared to original image
).images
# inpaint
input_mask = load_image("/path/to/local/mask.png") # or URL to your input inpainting mask
inpaint_images = pipe.inpaint(
prompt="photo of a cute (black) cat running on the grass" * 20,
negative_prompt=neg_prompt,
image=input_image,
mask=input_mask,
strength=0.6, # higher strength will result in more variation compared to original image
).images
pipe.to("cpu")
torch.cuda.empty_cache()
images
from IPython.display import display # assuming you are using this code in a notebook
display(t2i_images[0])
display(i2i_images[0])
display(inpaint_images[0])
```
In the above code, the `prompt2` is appended to the `prompt`, which is more than 77 tokens. "birds" are showing up in the result.
![Stable Diffusion XL Long Weighted Prompt Pipeline sample](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl_long_weighted_prompt.png)
For more results, checkout [PR #6114](https://github.com/huggingface/diffusers/pull/6114).
## Example Images Mixing (with CoCa)
```python
import requests

View File

@@ -11,10 +11,11 @@ import os
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from PIL import Image
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from diffusers import DiffusionPipeline, StableDiffusionXLPipeline
from diffusers.image_processor import VaeImageProcessor
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.attention_processor import (
@@ -23,7 +24,7 @@ from diffusers.models.attention_processor import (
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
)
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import (
is_accelerate_available,
@@ -461,6 +462,65 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
return noise_cfg
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
**kwargs,
):
"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used,
`timesteps` must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin):
r"""
Pipeline for text-to-image generation using Stable Diffusion XL.
@@ -526,6 +586,9 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
)
self.default_sample_size = self.unet.config.sample_size
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
@@ -813,6 +876,7 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
prompt_2,
height,
width,
strength,
callback_steps,
negative_prompt=None,
negative_prompt_2=None,
@@ -824,6 +888,9 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
@@ -880,23 +947,263 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
# get the original timestep using init_timestep
if denoising_start is None:
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
t_start = max(num_inference_steps - init_timestep, 0)
else:
t_start = 0
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
# Strength is irrelevant if we directly request a timestep to start at;
# that is, strength is determined by the denoising_start instead.
if denoising_start is not None:
discrete_timestep_cutoff = int(
round(
self.scheduler.config.num_train_timesteps
- (denoising_start * self.scheduler.config.num_train_timesteps)
)
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device)
num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
# if the scheduler is a 2nd order scheduler we might have to do +1
# because `num_inference_steps` might be even given that every timestep
# (except the highest one) is duplicated. If `num_inference_steps` is even it would
# mean that we cut the timesteps in the middle of the denoising step
# (between 1st and 2nd devirative) which leads to incorrect results. By adding 1
# we ensure that the denoising process always ends after the 2nd derivate step of the scheduler
num_inference_steps = num_inference_steps + 1
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
# because t_n+1 >= t_n, we slice the timesteps starting from the end
timesteps = timesteps[-num_inference_steps:]
return timesteps, num_inference_steps
return timesteps, num_inference_steps - t_start
def prepare_latents(
self,
image,
mask,
width,
height,
num_channels_latents,
timestep,
batch_size,
num_images_per_prompt,
dtype,
device,
generator=None,
add_noise=True,
latents=None,
is_strength_max=True,
return_noise=False,
return_image_latents=False,
):
batch_size *= num_images_per_prompt
if image is None:
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
elif mask is None:
if not isinstance(image, (torch.Tensor, Image.Image, list)):
raise ValueError(
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
)
# Offload text encoder if `enable_model_cpu_offload` was enabled
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.text_encoder_2.to("cpu")
torch.cuda.empty_cache()
image = image.to(device=device, dtype=dtype)
if image.shape[1] == 4:
init_latents = image
else:
# make sure the VAE is in float32 mode, as it overflows in float16
if self.vae.config.force_upcast:
image = image.float()
self.vae.to(dtype=torch.float32)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
elif isinstance(generator, list):
init_latents = [
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
for i in range(batch_size)
]
init_latents = torch.cat(init_latents, dim=0)
else:
init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
if self.vae.config.force_upcast:
self.vae.to(dtype)
init_latents = init_latents.to(dtype)
init_latents = self.vae.config.scaling_factor * init_latents
if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
# expand init_latents for batch_size
additional_image_per_prompt = batch_size // init_latents.shape[0]
init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
raise ValueError(
f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
)
else:
init_latents = torch.cat([init_latents], dim=0)
if add_noise:
shape = init_latents.shape
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
# get latents
init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
latents = init_latents
return latents
else:
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if (image is None or timestep is None) and not is_strength_max:
raise ValueError(
"Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
"However, either the image or the noise timestep has not been provided."
)
if image.shape[1] == 4:
image_latents = image.to(device=device, dtype=dtype)
image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
elif return_image_latents or (latents is None and not is_strength_max):
image = image.to(device=device, dtype=dtype)
image_latents = self._encode_vae_image(image=image, generator=generator)
image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
if latents is None and add_noise:
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
# if strength is 1. then initialise the latents to noise, else initial to image + noise
latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
# if pure noise then scale the initial latents by the Scheduler's init sigma
latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
elif add_noise:
noise = latents.to(device)
latents = noise * self.scheduler.init_noise_sigma
else:
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = image_latents.to(device)
outputs = (latents,)
if return_noise:
outputs += (noise,)
if return_image_latents:
outputs += (image_latents,)
return outputs
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
dtype = image.dtype
if self.vae.config.force_upcast:
image = image.float()
self.vae.to(dtype=torch.float32)
if isinstance(generator, list):
image_latents = [
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
if self.vae.config.force_upcast:
self.vae.to(dtype)
image_latents = image_latents.to(dtype)
image_latents = self.vae.config.scaling_factor * image_latents
return image_latents
def prepare_mask_latents(
self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
):
# resize the mask to latents shape as we concatenate the mask to the latents
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
# and half precision
mask = torch.nn.functional.interpolate(
mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
)
mask = mask.to(device=device, dtype=dtype)
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
if mask.shape[0] < batch_size:
if not batch_size % mask.shape[0] == 0:
raise ValueError(
"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
" of masks that you pass is divisible by the total requested batch size."
)
mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
if masked_image is not None and masked_image.shape[1] == 4:
masked_image_latents = masked_image
else:
masked_image_latents = None
if masked_image is not None:
if masked_image_latents is None:
masked_image = masked_image.to(device=device, dtype=dtype)
masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
if masked_image_latents.shape[0] < batch_size:
if not batch_size % masked_image_latents.shape[0] == 0:
raise ValueError(
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
" Make sure the number of images that you pass is divisible by the total requested batch size."
)
masked_image_latents = masked_image_latents.repeat(
batch_size // masked_image_latents.shape[0], 1, 1, 1
)
masked_image_latents = (
torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
)
# aligning device to prevent device errors when concating it with the latent model input
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
return mask, masked_image_latents
def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
add_time_ids = list(original_size + crops_coords_top_left + target_size)
@@ -934,15 +1241,52 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
self.vae.decoder.conv_in.to(dtype)
self.vae.decoder.mid_block.to(dtype)
@property
def guidance_scale(self):
return self._guidance_scale
@property
def guidance_rescale(self):
return self._guidance_rescale
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
@property
def cross_attention_kwargs(self):
return self._cross_attention_kwargs
@property
def denoising_end(self):
return self._denoising_end
@property
def denoising_start(self):
return self._denoising_start
@property
def num_timesteps(self):
return self._num_timesteps
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: str = None,
prompt_2: Optional[str] = None,
image: Optional[PipelineImageInput] = None,
mask_image: Optional[PipelineImageInput] = None,
masked_image_latents: Optional[torch.FloatTensor] = None,
height: Optional[int] = None,
width: Optional[int] = None,
strength: float = 0.8,
num_inference_steps: int = 50,
timesteps: List[int] = None,
denoising_start: Optional[float] = None,
denoising_end: Optional[float] = None,
guidance_scale: float = 5.0,
negative_prompt: Optional[str] = None,
@@ -975,20 +1319,46 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
prompt_2 (`str`):
The prompt to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
used in both text-encoders
image (`PipelineImageInput`, *optional*):
`Image`, or tensor representing an image batch, that will be used as the starting point for the
process.
mask_image (`PipelineImageInput`, *optional*):
`Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
strength (`float`, *optional*, defaults to 0.8):
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
`image` will be used as a starting point, adding more noise to it the larger the `strength`. The
number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
noise will be maximum and the denoising process will run for the full number of iterations specified in
`num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
denoising_start (`float`, *optional*):
When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be
bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and
it is assumed that the passed `image` is a partly denoised image. Note that when this is specified,
strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline
is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refine Image
Quality**](https://huggingface.co/docs/diffusers/using-diffusers/sdxl#refine-image-quality).
denoising_end (`float`, *optional*):
When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
completed before it is intentionally prematurely terminated. As a result, the returned sample will
still retain a substantial amount of noise as determined by the discrete timesteps selected by the
scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
"Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
still retain a substantial amount of noise (ca. final 20% of timesteps still needed) and should be
denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the
final 20% of the scheduler. The denoising_end parameter should ideally be utilized when this pipeline
forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refine Image
Quality**](https://huggingface.co/docs/diffusers/using-diffusers/sdxl#refine-image-quality).
guidance_scale (`float`, *optional*, defaults to 5.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -1084,6 +1454,7 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
prompt_2,
height,
width,
strength,
callback_steps,
negative_prompt,
negative_prompt_2,
@@ -1093,6 +1464,12 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
negative_pooled_prompt_embeds,
)
self._guidance_scale = guidance_scale
self._guidance_rescale = guidance_rescale
self._cross_attention_kwargs = cross_attention_kwargs
self._denoising_end = denoising_end
self._denoising_start = denoising_start
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -1121,28 +1498,126 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
) = get_weighted_text_embeddings_sdxl(
pipe=self, prompt=prompt, neg_prompt=negative_prompt, num_images_per_prompt=num_images_per_prompt
)
dtype = prompt_embeds.dtype
if isinstance(image, Image.Image):
image = self.image_processor.preprocess(image, height=height, width=width)
if image is not None:
image = image.to(device=self.device, dtype=dtype)
if isinstance(mask_image, Image.Image):
mask = self.mask_processor.preprocess(mask_image, height=height, width=width)
else:
mask = mask_image
if mask_image is not None:
mask = mask.to(device=self.device, dtype=dtype)
if masked_image_latents is not None:
masked_image = masked_image_latents
elif image.shape[1] == 4:
# if image is in latent space, we can't mask it
masked_image = None
else:
masked_image = image * (mask < 0.5)
else:
mask = None
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
def denoising_value_valid(dnv):
return isinstance(self.denoising_end, float) and 0 < dnv < 1
timesteps = self.scheduler.timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
if image is not None:
timesteps, num_inference_steps = self.get_timesteps(
num_inference_steps,
strength,
device,
denoising_start=self.denoising_start if denoising_value_valid else None,
)
# check that number of inference steps is not < 1 - as this doesn't make sense
if num_inference_steps < 1:
raise ValueError(
f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
is_strength_max = strength == 1.0
add_noise = True if self.denoising_start is None else False
# 5. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
num_channels_latents = self.vae.config.latent_channels
num_channels_unet = self.unet.config.in_channels
return_image_latents = num_channels_unet == 4
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
image=image,
mask=mask,
width=width,
height=height,
num_channels_latents=num_channels_unet,
timestep=latent_timestep,
batch_size=batch_size,
num_images_per_prompt=num_images_per_prompt,
dtype=prompt_embeds.dtype,
device=device,
generator=generator,
add_noise=add_noise,
latents=latents,
is_strength_max=is_strength_max,
return_noise=True,
return_image_latents=return_image_latents,
)
if mask is not None:
if return_image_latents:
latents, noise, image_latents = latents
else:
latents, noise = latents
# 5.1. Prepare mask latent variables
if mask is not None:
mask, masked_image_latents = self.prepare_mask_latents(
mask=mask,
masked_image=masked_image,
batch_size=batch_size * num_images_per_prompt,
height=height,
width=width,
dtype=prompt_embeds.dtype,
device=device,
generator=generator,
do_classifier_free_guidance=self.do_classifier_free_guidance,
)
# 8. Check that sizes of mask, masked image and latents match
if num_channels_unet == 9:
# default case for runwayml/stable-diffusion-inpainting
num_channels_mask = mask.shape[1]
num_channels_masked_image = masked_image_latents.shape[1]
if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet:
raise ValueError(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
elif num_channels_unet != 4:
raise ValueError(
f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
height, width = latents.shape[-2:]
height = height * self.vae_scale_factor
width = width * self.vae_scale_factor
original_size = original_size or (height, width)
target_size = target_size or (height, width)
# 7. Prepare added time ids & embeddings
add_text_embeds = pooled_prompt_embeds
add_time_ids = self._get_add_time_ids(
@@ -1158,20 +1633,41 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
add_text_embeds = add_text_embeds.to(device)
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
# 8. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
# 7.1 Apply denoising_end
if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
if (
self.denoising_end is not None
and self.denoising_start is not None
and denoising_value_valid(self.denoising_end)
and denoising_value_valid(self.denoising_start)
and self.denoising_start >= self.denoising_end
):
raise ValueError(
f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: "
+ f" {self.denoising_end} when using type float."
)
elif self.denoising_end is not None and denoising_value_valid(self.denoising_end):
discrete_timestep_cutoff = int(
round(
self.scheduler.config.num_train_timesteps
- (denoising_end * self.scheduler.config.num_train_timesteps)
- (self.denoising_end * self.scheduler.config.num_train_timesteps)
)
)
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
timesteps = timesteps[:num_inference_steps]
# 8. Optionally get Guidance Scale Embedding
timestep_cond = None
if self.unet.config.time_cond_proj_dim is not None:
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
timestep_cond = self.get_guidance_scale_embedding(
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
).to(device=device, dtype=latents.dtype)
self._num_timesteps = len(timesteps)
# 9. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
@@ -1179,13 +1675,17 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
if mask is not None and num_channels_unet == 9:
latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
# predict the noise residual
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
timestep_cond=timestep_cond,
cross_attention_kwargs=self.cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
@@ -1202,6 +1702,22 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
if mask is not None and num_channels_unet == 4:
init_latents_proper = image_latents
if self.do_classifier_free_guidance:
init_mask, _ = mask.chunk(2)
else:
init_mask = mask
if i < len(timesteps) - 1:
noise_timestep = timesteps[i + 1]
init_latents_proper = self.scheduler.add_noise(
init_latents_proper, noise, torch.tensor([noise_timestep])
)
latents = (1 - init_mask) * init_latents_proper + init_mask * latents
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
@@ -1241,6 +1757,204 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
return StableDiffusionXLPipelineOutput(images=image)
def text2img(
self,
prompt: str = None,
prompt_2: Optional[str] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
timesteps: List[int] = None,
denoising_start: Optional[float] = None,
denoising_end: Optional[float] = None,
guidance_scale: float = 5.0,
negative_prompt: Optional[str] = None,
negative_prompt_2: Optional[str] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
original_size: Optional[Tuple[int, int]] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
target_size: Optional[Tuple[int, int]] = None,
):
return self.__call__(
prompt=prompt,
prompt_2=prompt_2,
height=height,
width=width,
num_inference_steps=num_inference_steps,
timesteps=timesteps,
denoising_start=denoising_start,
denoising_end=denoising_end,
guidance_scale=guidance_scale,
negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2,
num_images_per_prompt=num_images_per_prompt,
eta=eta,
generator=generator,
latents=latents,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
output_type=output_type,
return_dict=return_dict,
callback=callback,
callback_steps=callback_steps,
cross_attention_kwargs=cross_attention_kwargs,
guidance_rescale=guidance_rescale,
original_size=original_size,
crops_coords_top_left=crops_coords_top_left,
target_size=target_size,
)
def img2img(
self,
prompt: str = None,
prompt_2: Optional[str] = None,
image: Optional[PipelineImageInput] = None,
height: Optional[int] = None,
width: Optional[int] = None,
strength: float = 0.8,
num_inference_steps: int = 50,
timesteps: List[int] = None,
denoising_start: Optional[float] = None,
denoising_end: Optional[float] = None,
guidance_scale: float = 5.0,
negative_prompt: Optional[str] = None,
negative_prompt_2: Optional[str] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
original_size: Optional[Tuple[int, int]] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
target_size: Optional[Tuple[int, int]] = None,
):
return self.__call__(
prompt=prompt,
prompt_2=prompt_2,
image=image,
height=height,
width=width,
strength=strength,
num_inference_steps=num_inference_steps,
timesteps=timesteps,
denoising_start=denoising_start,
denoising_end=denoising_end,
guidance_scale=guidance_scale,
negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2,
num_images_per_prompt=num_images_per_prompt,
eta=eta,
generator=generator,
latents=latents,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
output_type=output_type,
return_dict=return_dict,
callback=callback,
callback_steps=callback_steps,
cross_attention_kwargs=cross_attention_kwargs,
guidance_rescale=guidance_rescale,
original_size=original_size,
crops_coords_top_left=crops_coords_top_left,
target_size=target_size,
)
def inpaint(
self,
prompt: str = None,
prompt_2: Optional[str] = None,
image: Optional[PipelineImageInput] = None,
mask_image: Optional[PipelineImageInput] = None,
masked_image_latents: Optional[torch.FloatTensor] = None,
height: Optional[int] = None,
width: Optional[int] = None,
strength: float = 0.8,
num_inference_steps: int = 50,
timesteps: List[int] = None,
denoising_start: Optional[float] = None,
denoising_end: Optional[float] = None,
guidance_scale: float = 5.0,
negative_prompt: Optional[str] = None,
negative_prompt_2: Optional[str] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
original_size: Optional[Tuple[int, int]] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
target_size: Optional[Tuple[int, int]] = None,
):
return self.__call__(
prompt=prompt,
prompt_2=prompt_2,
image=image,
mask_image=mask_image,
masked_image_latents=masked_image_latents,
height=height,
width=width,
strength=strength,
num_inference_steps=num_inference_steps,
timesteps=timesteps,
denoising_start=denoising_start,
denoising_end=denoising_end,
guidance_scale=guidance_scale,
negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2,
num_images_per_prompt=num_images_per_prompt,
eta=eta,
generator=generator,
latents=latents,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
output_type=output_type,
return_dict=return_dict,
callback=callback,
callback_steps=callback_steps,
cross_attention_kwargs=cross_attention_kwargs,
guidance_rescale=guidance_rescale,
original_size=original_size,
crops_coords_top_left=crops_coords_top_left,
target_size=target_size,
)
# Overrride to properly handle the loading and unloading of the additional text encoder.
def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
# We could have accessed the unet config from `lora_state_dict()` too. We pass

View File

@@ -0,0 +1,602 @@
# Copyright 2023 Bingxin Ke, ETH Zurich and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
# More information about the method can be found at https://marigoldmonodepth.github.io
# --------------------------------------------------------------------------
import math
from typing import Dict, Union
import matplotlib
import numpy as np
import torch
from PIL import Image
from scipy.optimize import minimize
from torch.utils.data import DataLoader, TensorDataset
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import (
AutoencoderKL,
DDIMScheduler,
DiffusionPipeline,
UNet2DConditionModel,
)
from diffusers.utils import BaseOutput, check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.20.1.dev0")
class MarigoldDepthOutput(BaseOutput):
"""
Output class for Marigold monocular depth prediction pipeline.
Args:
depth_np (`np.ndarray`):
Predicted depth map, with depth values in the range of [0, 1].
depth_colored (`PIL.Image.Image`):
Colorized depth map, with the shape of [3, H, W] and values in [0, 1].
uncertainty (`None` or `np.ndarray`):
Uncalibrated uncertainty(MAD, median absolute deviation) coming from ensembling.
"""
depth_np: np.ndarray
depth_colored: Image.Image
uncertainty: Union[None, np.ndarray]
class MarigoldPipeline(DiffusionPipeline):
"""
Pipeline for monocular depth estimation using Marigold: https://marigoldmonodepth.github.io.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
unet (`UNet2DConditionModel`):
Conditional U-Net to denoise the depth latent, conditioned on image latent.
vae (`AutoencoderKL`):
Variational Auto-Encoder (VAE) Model to encode and decode images and depth maps
to and from latent representations.
scheduler (`DDIMScheduler`):
A scheduler to be used in combination with `unet` to denoise the encoded image latents.
text_encoder (`CLIPTextModel`):
Text-encoder, for empty text embedding.
tokenizer (`CLIPTokenizer`):
CLIP tokenizer.
"""
rgb_latent_scale_factor = 0.18215
depth_latent_scale_factor = 0.18215
def __init__(
self,
unet: UNet2DConditionModel,
vae: AutoencoderKL,
scheduler: DDIMScheduler,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
):
super().__init__()
self.register_modules(
unet=unet,
vae=vae,
scheduler=scheduler,
text_encoder=text_encoder,
tokenizer=tokenizer,
)
self.empty_text_embed = None
@torch.no_grad()
def __call__(
self,
input_image: Image,
denoising_steps: int = 10,
ensemble_size: int = 10,
processing_res: int = 768,
match_input_res: bool = True,
batch_size: int = 0,
color_map: str = "Spectral",
show_progress_bar: bool = True,
ensemble_kwargs: Dict = None,
) -> MarigoldDepthOutput:
"""
Function invoked when calling the pipeline.
Args:
input_image (`Image`):
Input RGB (or gray-scale) image.
processing_res (`int`, *optional*, defaults to `768`):
Maximum resolution of processing.
If set to 0: will not resize at all.
match_input_res (`bool`, *optional*, defaults to `True`):
Resize depth prediction to match input resolution.
Only valid if `limit_input_res` is not None.
denoising_steps (`int`, *optional*, defaults to `10`):
Number of diffusion denoising steps (DDIM) during inference.
ensemble_size (`int`, *optional*, defaults to `10`):
Number of predictions to be ensembled.
batch_size (`int`, *optional*, defaults to `0`):
Inference batch size, no bigger than `num_ensemble`.
If set to 0, the script will automatically decide the proper batch size.
show_progress_bar (`bool`, *optional*, defaults to `True`):
Display a progress bar of diffusion denoising.
color_map (`str`, *optional*, defaults to `"Spectral"`):
Colormap used to colorize the depth map.
ensemble_kwargs (`dict`, *optional*, defaults to `None`):
Arguments for detailed ensembling settings.
Returns:
`MarigoldDepthOutput`: Output class for Marigold monocular depth prediction pipeline, including:
- **depth_np** (`np.ndarray`) Predicted depth map, with depth values in the range of [0, 1]
- **depth_colored** (`PIL.Image.Image`) Colorized depth map, with the shape of [3, H, W] and values in [0, 1]
- **uncertainty** (`None` or `np.ndarray`) Uncalibrated uncertainty(MAD, median absolute deviation)
coming from ensembling. None if `ensemble_size = 1`
"""
device = self.device
input_size = input_image.size
if not match_input_res:
assert processing_res is not None, "Value error: `resize_output_back` is only valid with "
assert processing_res >= 0
assert denoising_steps >= 1
assert ensemble_size >= 1
# ----------------- Image Preprocess -----------------
# Resize image
if processing_res > 0:
input_image = self.resize_max_res(input_image, max_edge_resolution=processing_res)
# Convert the image to RGB, to 1.remove the alpha channel 2.convert B&W to 3-channel
input_image = input_image.convert("RGB")
image = np.asarray(input_image)
# Normalize rgb values
rgb = np.transpose(image, (2, 0, 1)) # [H, W, rgb] -> [rgb, H, W]
rgb_norm = rgb / 255.0
rgb_norm = torch.from_numpy(rgb_norm).to(self.dtype)
rgb_norm = rgb_norm.to(device)
assert rgb_norm.min() >= 0.0 and rgb_norm.max() <= 1.0
# ----------------- Predicting depth -----------------
# Batch repeated input image
duplicated_rgb = torch.stack([rgb_norm] * ensemble_size)
single_rgb_dataset = TensorDataset(duplicated_rgb)
if batch_size > 0:
_bs = batch_size
else:
_bs = self._find_batch_size(
ensemble_size=ensemble_size,
input_res=max(rgb_norm.shape[1:]),
dtype=self.dtype,
)
single_rgb_loader = DataLoader(single_rgb_dataset, batch_size=_bs, shuffle=False)
# Predict depth maps (batched)
depth_pred_ls = []
if show_progress_bar:
iterable = tqdm(single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False)
else:
iterable = single_rgb_loader
for batch in iterable:
(batched_img,) = batch
depth_pred_raw = self.single_infer(
rgb_in=batched_img,
num_inference_steps=denoising_steps,
show_pbar=show_progress_bar,
)
depth_pred_ls.append(depth_pred_raw.detach().clone())
depth_preds = torch.concat(depth_pred_ls, axis=0).squeeze()
torch.cuda.empty_cache() # clear vram cache for ensembling
# ----------------- Test-time ensembling -----------------
if ensemble_size > 1:
depth_pred, pred_uncert = self.ensemble_depths(depth_preds, **(ensemble_kwargs or {}))
else:
depth_pred = depth_preds
pred_uncert = None
# ----------------- Post processing -----------------
# Scale prediction to [0, 1]
min_d = torch.min(depth_pred)
max_d = torch.max(depth_pred)
depth_pred = (depth_pred - min_d) / (max_d - min_d)
# Convert to numpy
depth_pred = depth_pred.cpu().numpy().astype(np.float32)
# Resize back to original resolution
if match_input_res:
pred_img = Image.fromarray(depth_pred)
pred_img = pred_img.resize(input_size)
depth_pred = np.asarray(pred_img)
# Clip output range
depth_pred = depth_pred.clip(0, 1)
# Colorize
depth_colored = self.colorize_depth_maps(
depth_pred, 0, 1, cmap=color_map
).squeeze() # [3, H, W], value in (0, 1)
depth_colored = (depth_colored * 255).astype(np.uint8)
depth_colored_hwc = self.chw2hwc(depth_colored)
depth_colored_img = Image.fromarray(depth_colored_hwc)
return MarigoldDepthOutput(
depth_np=depth_pred,
depth_colored=depth_colored_img,
uncertainty=pred_uncert,
)
def _encode_empty_text(self):
"""
Encode text embedding for empty prompt.
"""
prompt = ""
text_inputs = self.tokenizer(
prompt,
padding="do_not_pad",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)
@torch.no_grad()
def single_infer(self, rgb_in: torch.Tensor, num_inference_steps: int, show_pbar: bool) -> torch.Tensor:
"""
Perform an individual depth prediction without ensembling.
Args:
rgb_in (`torch.Tensor`):
Input RGB image.
num_inference_steps (`int`):
Number of diffusion denoisign steps (DDIM) during inference.
show_pbar (`bool`):
Display a progress bar of diffusion denoising.
Returns:
`torch.Tensor`: Predicted depth map.
"""
device = rgb_in.device
# Set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps # [T]
# Encode image
rgb_latent = self._encode_rgb(rgb_in)
# Initial depth map (noise)
depth_latent = torch.randn(rgb_latent.shape, device=device, dtype=self.dtype) # [B, 4, h, w]
# Batched empty text embedding
if self.empty_text_embed is None:
self._encode_empty_text()
batch_empty_text_embed = self.empty_text_embed.repeat((rgb_latent.shape[0], 1, 1)) # [B, 2, 1024]
# Denoising loop
if show_pbar:
iterable = tqdm(
enumerate(timesteps),
total=len(timesteps),
leave=False,
desc=" " * 4 + "Diffusion denoising",
)
else:
iterable = enumerate(timesteps)
for i, t in iterable:
unet_input = torch.cat([rgb_latent, depth_latent], dim=1) # this order is important
# predict the noise residual
noise_pred = self.unet(unet_input, t, encoder_hidden_states=batch_empty_text_embed).sample # [B, 4, h, w]
# compute the previous noisy sample x_t -> x_t-1
depth_latent = self.scheduler.step(noise_pred, t, depth_latent).prev_sample
torch.cuda.empty_cache()
depth = self._decode_depth(depth_latent)
# clip prediction
depth = torch.clip(depth, -1.0, 1.0)
# shift to [0, 1]
depth = (depth + 1.0) / 2.0
return depth
def _encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
"""
Encode RGB image into latent.
Args:
rgb_in (`torch.Tensor`):
Input RGB image to be encoded.
Returns:
`torch.Tensor`: Image latent.
"""
# encode
h = self.vae.encoder(rgb_in)
moments = self.vae.quant_conv(h)
mean, logvar = torch.chunk(moments, 2, dim=1)
# scale latent
rgb_latent = mean * self.rgb_latent_scale_factor
return rgb_latent
def _decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:
"""
Decode depth latent into depth map.
Args:
depth_latent (`torch.Tensor`):
Depth latent to be decoded.
Returns:
`torch.Tensor`: Decoded depth map.
"""
# scale latent
depth_latent = depth_latent / self.depth_latent_scale_factor
# decode
z = self.vae.post_quant_conv(depth_latent)
stacked = self.vae.decoder(z)
# mean of output channels
depth_mean = stacked.mean(dim=1, keepdim=True)
return depth_mean
@staticmethod
def resize_max_res(img: Image.Image, max_edge_resolution: int) -> Image.Image:
"""
Resize image to limit maximum edge length while keeping aspect ratio.
Args:
img (`Image.Image`):
Image to be resized.
max_edge_resolution (`int`):
Maximum edge length (pixel).
Returns:
`Image.Image`: Resized image.
"""
original_width, original_height = img.size
downscale_factor = min(max_edge_resolution / original_width, max_edge_resolution / original_height)
new_width = int(original_width * downscale_factor)
new_height = int(original_height * downscale_factor)
resized_img = img.resize((new_width, new_height))
return resized_img
@staticmethod
def colorize_depth_maps(depth_map, min_depth, max_depth, cmap="Spectral", valid_mask=None):
"""
Colorize depth maps.
"""
assert len(depth_map.shape) >= 2, "Invalid dimension"
if isinstance(depth_map, torch.Tensor):
depth = depth_map.detach().clone().squeeze().numpy()
elif isinstance(depth_map, np.ndarray):
depth = depth_map.copy().squeeze()
# reshape to [ (B,) H, W ]
if depth.ndim < 3:
depth = depth[np.newaxis, :, :]
# colorize
cm = matplotlib.colormaps[cmap]
depth = ((depth - min_depth) / (max_depth - min_depth)).clip(0, 1)
img_colored_np = cm(depth, bytes=False)[:, :, :, 0:3] # value from 0 to 1
img_colored_np = np.rollaxis(img_colored_np, 3, 1)
if valid_mask is not None:
if isinstance(depth_map, torch.Tensor):
valid_mask = valid_mask.detach().numpy()
valid_mask = valid_mask.squeeze() # [H, W] or [B, H, W]
if valid_mask.ndim < 3:
valid_mask = valid_mask[np.newaxis, np.newaxis, :, :]
else:
valid_mask = valid_mask[:, np.newaxis, :, :]
valid_mask = np.repeat(valid_mask, 3, axis=1)
img_colored_np[~valid_mask] = 0
if isinstance(depth_map, torch.Tensor):
img_colored = torch.from_numpy(img_colored_np).float()
elif isinstance(depth_map, np.ndarray):
img_colored = img_colored_np
return img_colored
@staticmethod
def chw2hwc(chw):
assert 3 == len(chw.shape)
if isinstance(chw, torch.Tensor):
hwc = torch.permute(chw, (1, 2, 0))
elif isinstance(chw, np.ndarray):
hwc = np.moveaxis(chw, 0, -1)
return hwc
@staticmethod
def _find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int:
"""
Automatically search for suitable operating batch size.
Args:
ensemble_size (`int`):
Number of predictions to be ensembled.
input_res (`int`):
Operating resolution of the input image.
Returns:
`int`: Operating batch size.
"""
# Search table for suggested max. inference batch size
bs_search_table = [
# tested on A100-PCIE-80GB
{"res": 768, "total_vram": 79, "bs": 35, "dtype": torch.float32},
{"res": 1024, "total_vram": 79, "bs": 20, "dtype": torch.float32},
# tested on A100-PCIE-40GB
{"res": 768, "total_vram": 39, "bs": 15, "dtype": torch.float32},
{"res": 1024, "total_vram": 39, "bs": 8, "dtype": torch.float32},
{"res": 768, "total_vram": 39, "bs": 30, "dtype": torch.float16},
{"res": 1024, "total_vram": 39, "bs": 15, "dtype": torch.float16},
# tested on RTX3090, RTX4090
{"res": 512, "total_vram": 23, "bs": 20, "dtype": torch.float32},
{"res": 768, "total_vram": 23, "bs": 7, "dtype": torch.float32},
{"res": 1024, "total_vram": 23, "bs": 3, "dtype": torch.float32},
{"res": 512, "total_vram": 23, "bs": 40, "dtype": torch.float16},
{"res": 768, "total_vram": 23, "bs": 18, "dtype": torch.float16},
{"res": 1024, "total_vram": 23, "bs": 10, "dtype": torch.float16},
# tested on GTX1080Ti
{"res": 512, "total_vram": 10, "bs": 5, "dtype": torch.float32},
{"res": 768, "total_vram": 10, "bs": 2, "dtype": torch.float32},
{"res": 512, "total_vram": 10, "bs": 10, "dtype": torch.float16},
{"res": 768, "total_vram": 10, "bs": 5, "dtype": torch.float16},
{"res": 1024, "total_vram": 10, "bs": 3, "dtype": torch.float16},
]
if not torch.cuda.is_available():
return 1
total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3
filtered_bs_search_table = [s for s in bs_search_table if s["dtype"] == dtype]
for settings in sorted(
filtered_bs_search_table,
key=lambda k: (k["res"], -k["total_vram"]),
):
if input_res <= settings["res"] and total_vram >= settings["total_vram"]:
bs = settings["bs"]
if bs > ensemble_size:
bs = ensemble_size
elif bs > math.ceil(ensemble_size / 2) and bs < ensemble_size:
bs = math.ceil(ensemble_size / 2)
return bs
return 1
@staticmethod
def ensemble_depths(
input_images: torch.Tensor,
regularizer_strength: float = 0.02,
max_iter: int = 2,
tol: float = 1e-3,
reduction: str = "median",
max_res: int = None,
):
"""
To ensemble multiple affine-invariant depth images (up to scale and shift),
by aligning estimating the scale and shift
"""
def inter_distances(tensors: torch.Tensor):
"""
To calculate the distance between each two depth maps.
"""
distances = []
for i, j in torch.combinations(torch.arange(tensors.shape[0])):
arr1 = tensors[i : i + 1]
arr2 = tensors[j : j + 1]
distances.append(arr1 - arr2)
dist = torch.concatenate(distances, dim=0)
return dist
device = input_images.device
dtype = input_images.dtype
np_dtype = np.float32
original_input = input_images.clone()
n_img = input_images.shape[0]
ori_shape = input_images.shape
if max_res is not None:
scale_factor = torch.min(max_res / torch.tensor(ori_shape[-2:]))
if scale_factor < 1:
downscaler = torch.nn.Upsample(scale_factor=scale_factor, mode="nearest")
input_images = downscaler(torch.from_numpy(input_images)).numpy()
# init guess
_min = np.min(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1)
_max = np.max(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1)
s_init = 1.0 / (_max - _min).reshape((-1, 1, 1))
t_init = (-1 * s_init.flatten() * _min.flatten()).reshape((-1, 1, 1))
x = np.concatenate([s_init, t_init]).reshape(-1).astype(np_dtype)
input_images = input_images.to(device)
# objective function
def closure(x):
l = len(x)
s = x[: int(l / 2)]
t = x[int(l / 2) :]
s = torch.from_numpy(s).to(dtype=dtype).to(device)
t = torch.from_numpy(t).to(dtype=dtype).to(device)
transformed_arrays = input_images * s.view((-1, 1, 1)) + t.view((-1, 1, 1))
dists = inter_distances(transformed_arrays)
sqrt_dist = torch.sqrt(torch.mean(dists**2))
if "mean" == reduction:
pred = torch.mean(transformed_arrays, dim=0)
elif "median" == reduction:
pred = torch.median(transformed_arrays, dim=0).values
else:
raise ValueError
near_err = torch.sqrt((0 - torch.min(pred)) ** 2)
far_err = torch.sqrt((1 - torch.max(pred)) ** 2)
err = sqrt_dist + (near_err + far_err) * regularizer_strength
err = err.detach().cpu().numpy().astype(np_dtype)
return err
res = minimize(
closure,
x,
method="BFGS",
tol=tol,
options={"maxiter": max_iter, "disp": False},
)
x = res.x
l = len(x)
s = x[: int(l / 2)]
t = x[int(l / 2) :]
# Prediction
s = torch.from_numpy(s).to(dtype=dtype).to(device)
t = torch.from_numpy(t).to(dtype=dtype).to(device)
transformed_arrays = original_input * s.view(-1, 1, 1) + t.view(-1, 1, 1)
if "mean" == reduction:
aligned_images = torch.mean(transformed_arrays, dim=0)
std = torch.std(transformed_arrays, dim=0)
uncertainty = std
elif "median" == reduction:
aligned_images = torch.median(transformed_arrays, dim=0).values
# MAD (median absolute deviation) as uncertainty indicator
abs_dev = torch.abs(transformed_arrays - aligned_images)
mad = torch.median(abs_dev, dim=0).values
uncertainty = mad
else:
raise ValueError(f"Unknown reduction method: {reduction}")
# Scale and shift to [0, 1]
_min = torch.min(aligned_images)
_max = torch.max(aligned_images)
aligned_images = (aligned_images - _min) / (_max - _min)
uncertainty /= _max - _min
return aligned_images, uncertainty

View File

@@ -73,7 +73,14 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
requires_safety_checker: bool = True,
):
super().__init__(
vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
vae,
text_encoder,
tokenizer,
unet,
scheduler,
safety_checker,
feature_extractor,
requires_safety_checker,
)
self.register_modules(
vae=vae,
@@ -102,22 +109,22 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
return_dict: bool = True,
rp_args: Dict[str, str] = None,
):
active = KBRK in prompt[0] if type(prompt) == list else KBRK in prompt # noqa: E721
active = KBRK in prompt[0] if isinstance(prompt, list) else KBRK in prompt
if negative_prompt is None:
negative_prompt = "" if type(prompt) == str else [""] * len(prompt) # noqa: E721
negative_prompt = "" if isinstance(prompt, str) else [""] * len(prompt)
device = self._execution_device
regions = 0
self.power = int(rp_args["power"]) if "power" in rp_args else 1
prompts = prompt if type(prompt) == list else [prompt] # noqa: E721
n_prompts = negative_prompt if type(negative_prompt) == list else [negative_prompt] # noqa: E721
prompts = prompt if isinstance(prompt, list) else [prompt]
n_prompts = negative_prompt if isinstance(prompt, str) else [negative_prompt]
self.batch = batch = num_images_per_prompt * len(prompts)
all_prompts_cn, all_prompts_p = promptsmaker(prompts, num_images_per_prompt)
all_n_prompts_cn, _ = promptsmaker(n_prompts, num_images_per_prompt)
cn = len(all_prompts_cn) == len(all_n_prompts_cn)
equal = len(all_prompts_cn) == len(all_n_prompts_cn)
if Compel:
compel = Compel(tokenizer=self.tokenizer, text_encoder=self.text_encoder)
@@ -129,7 +136,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
return torch.cat(embl)
conds = getcompelembs(all_prompts_cn)
unconds = getcompelembs(all_n_prompts_cn) if cn else getcompelembs(n_prompts)
unconds = getcompelembs(all_n_prompts_cn)
embs = getcompelembs(prompts)
n_embs = getcompelembs(n_prompts)
prompt = negative_prompt = None
@@ -137,7 +144,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
conds = self.encode_prompt(prompts, device, 1, True)[0]
unconds = (
self.encode_prompt(n_prompts, device, 1, True)[0]
if cn
if equal
else self.encode_prompt(all_n_prompts_cn, device, 1, True)[0]
)
embs = n_embs = None
@@ -206,8 +213,11 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
else:
px, nx = hidden_states.chunk(2)
if cn:
hidden_states = torch.cat([px for i in range(regions)] + [nx for i in range(regions)], 0)
if equal:
hidden_states = torch.cat(
[px for i in range(regions)] + [nx for i in range(regions)],
0,
)
encoder_hidden_states = torch.cat([conds] + [unconds])
else:
hidden_states = torch.cat([px for i in range(regions)] + [nx], 0)
@@ -289,9 +299,9 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
if any(x in mode for x in ["COL", "ROW"]):
reshaped = hidden_states.reshape(hidden_states.size()[0], h, w, hidden_states.size()[2])
center = reshaped.shape[0] // 2
px = reshaped[0:center] if cn else reshaped[0:-batch]
nx = reshaped[center:] if cn else reshaped[-batch:]
outs = [px, nx] if cn else [px]
px = reshaped[0:center] if equal else reshaped[0:-batch]
nx = reshaped[center:] if equal else reshaped[-batch:]
outs = [px, nx] if equal else [px]
for out in outs:
c = 0
for i, ocell in enumerate(ocells):
@@ -321,15 +331,16 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
:,
]
c += 1
px, nx = (px[0:batch], nx[0:batch]) if cn else (px[0:batch], nx)
px, nx = (px[0:batch], nx[0:batch]) if equal else (px[0:batch], nx)
hidden_states = torch.cat([nx, px], 0) if revers else torch.cat([px, nx], 0)
hidden_states = hidden_states.reshape(xshape)
#### Regional Prompting Prompt mode
elif "PRO" in mode:
center = reshaped.shape[0] // 2
px = reshaped[0:center] if cn else reshaped[0:-batch]
nx = reshaped[center:] if cn else reshaped[-batch:]
px, nx = (
torch.chunk(hidden_states) if equal else hidden_states[0:-batch],
hidden_states[-batch:],
)
if (h, w) in self.attnmasks and self.maskready:
@@ -340,8 +351,8 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
out[b] = out[b] + out[r * batch + b]
return out
px, nx = (mask(px), mask(nx)) if cn else (mask(px), nx)
px, nx = (px[0:batch], nx[0:batch]) if cn else (px[0:batch], nx)
px, nx = (mask(px), mask(nx)) if equal else (mask(px), nx)
px, nx = (px[0:batch], nx[0:batch]) if equal else (px[0:batch], nx)
hidden_states = torch.cat([nx, px], 0) if revers else torch.cat([px, nx], 0)
return hidden_states
@@ -378,7 +389,15 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
save_mask = False
if mode == "PROMPT" and save_mask:
saveattnmaps(self, output, height, width, thresholds, num_inference_steps // 2, regions)
saveattnmaps(
self,
output,
height,
width,
thresholds,
num_inference_steps // 2,
regions,
)
return output
@@ -437,7 +456,11 @@ def make_cells(ratios):
def make_emblist(self, prompts):
with torch.no_grad():
tokens = self.tokenizer(
prompts, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt"
prompts,
max_length=self.tokenizer.model_max_length,
padding=True,
truncation=True,
return_tensors="pt",
).input_ids.to(self.device)
embs = self.text_encoder(tokens, output_hidden_states=True).last_hidden_state.to(self.device, dtype=self.dtype)
return embs
@@ -563,7 +586,15 @@ def tokendealer(self, all_prompts):
def scaled_dot_product_attention(
self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, getattn=False
self,
query,
key,
value,
attn_mask=None,
dropout_p=0.0,
is_causal=False,
scale=None,
getattn=False,
) -> torch.Tensor:
# Efficient implementation equivalent to the following:
L, S = query.size(-2), key.size(-2)

View File

@@ -1004,7 +1004,7 @@ class TensorRTStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
"""
self.generator = generator
self.denoising_steps = num_inference_steps
self.guidance_scale = guidance_scale
self._guidance_scale = guidance_scale
# Pre-compute latent input scales and linear multistep coefficients
self.scheduler.set_timesteps(self.denoising_steps, device=self.torch_device)

View File

@@ -94,7 +94,7 @@ accelerate launch train_lcm_distill_lora_sd_wds.py \
--mixed_precision=fp16 \
--resolution=512 \
--lora_rank=64 \
--learning_rate=1e-6 --loss_type="huber" --adam_weight_decay=0.0 \
--learning_rate=1e-4 --loss_type="huber" --adam_weight_decay=0.0 \
--max_train_steps=1000 \
--max_train_samples=4000000 \
--dataloader_num_workers=8 \

View File

@@ -96,7 +96,7 @@ accelerate launch train_lcm_distill_lora_sdxl_wds.py \
--mixed_precision=fp16 \
--resolution=1024 \
--lora_rank=64 \
--learning_rate=1e-6 --loss_type="huber" --use_fix_crop_and_size --adam_weight_decay=0.0 \
--learning_rate=1e-4 --loss_type="huber" --use_fix_crop_and_size --adam_weight_decay=0.0 \
--max_train_steps=1000 \
--max_train_samples=4000000 \
--dataloader_num_workers=8 \

View File

@@ -65,7 +65,7 @@ class ControlNet(ExamplesTestsAccelerate):
--train_batch_size=1
--gradient_accumulation_steps=1
--controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet
--max_train_steps=9
--max_train_steps=6
--checkpointing_steps=2
""".split()
@@ -73,7 +73,7 @@ class ControlNet(ExamplesTestsAccelerate):
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
{"checkpoint-2", "checkpoint-4", "checkpoint-6"},
)
resume_run_args = f"""
@@ -85,18 +85,15 @@ class ControlNet(ExamplesTestsAccelerate):
--train_batch_size=1
--gradient_accumulation_steps=1
--controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet
--max_train_steps=11
--max_train_steps=8
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-8
--checkpoints_total_limit=3
--resume_from_checkpoint=checkpoint-6
--checkpoints_total_limit=2
""".split()
run_command(self._launch_args + resume_run_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-8", "checkpoint-10", "checkpoint-12"},
)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
class ControlNetSDXL(ExamplesTestsAccelerate):
@@ -111,7 +108,7 @@ class ControlNetSDXL(ExamplesTestsAccelerate):
--train_batch_size=1
--gradient_accumulation_steps=1
--controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet-sdxl
--max_train_steps=9
--max_train_steps=4
--checkpointing_steps=2
""".split()

View File

@@ -76,10 +76,7 @@ class CustomDiffusion(ExamplesTestsAccelerate):
run_command(self._launch_args + test_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-4", "checkpoint-6"},
)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-4", "checkpoint-6"})
def test_custom_diffusion_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
with tempfile.TemporaryDirectory() as tmpdir:
@@ -93,7 +90,7 @@ class CustomDiffusion(ExamplesTestsAccelerate):
--train_batch_size=1
--modifier_token=<new1>
--dataloader_num_workers=0
--max_train_steps=9
--max_train_steps=4
--checkpointing_steps=2
--no_safe_serialization
""".split()
@@ -102,7 +99,7 @@ class CustomDiffusion(ExamplesTestsAccelerate):
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
{"checkpoint-2", "checkpoint-4"},
)
resume_run_args = f"""
@@ -115,16 +112,13 @@ class CustomDiffusion(ExamplesTestsAccelerate):
--train_batch_size=1
--modifier_token=<new1>
--dataloader_num_workers=0
--max_train_steps=11
--max_train_steps=8
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-8
--checkpoints_total_limit=3
--resume_from_checkpoint=checkpoint-4
--checkpoints_total_limit=2
--no_safe_serialization
""".split()
run_command(self._launch_args + resume_run_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-6", "checkpoint-8", "checkpoint-10"},
)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})

View File

@@ -89,7 +89,7 @@ class DreamBooth(ExamplesTestsAccelerate):
with tempfile.TemporaryDirectory() as tmpdir:
# Run training script with checkpointing
# max_train_steps == 5, checkpointing_steps == 2
# max_train_steps == 4, checkpointing_steps == 2
# Should create checkpoints at steps 2, 4
initial_run_args = f"""
@@ -100,7 +100,7 @@ class DreamBooth(ExamplesTestsAccelerate):
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 5
--max_train_steps 4
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
@@ -114,7 +114,7 @@ class DreamBooth(ExamplesTestsAccelerate):
# check can run the original fully trained output pipeline
pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
pipe(instance_prompt, num_inference_steps=2)
pipe(instance_prompt, num_inference_steps=1)
# check checkpoint directories exist
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
@@ -123,7 +123,7 @@ class DreamBooth(ExamplesTestsAccelerate):
# check can run an intermediate checkpoint
unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet")
pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, unet=unet, safety_checker=None)
pipe(instance_prompt, num_inference_steps=2)
pipe(instance_prompt, num_inference_steps=1)
# Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))
@@ -138,7 +138,7 @@ class DreamBooth(ExamplesTestsAccelerate):
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 7
--max_train_steps 6
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
@@ -153,7 +153,7 @@ class DreamBooth(ExamplesTestsAccelerate):
# check can run new fully trained pipeline
pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
pipe(instance_prompt, num_inference_steps=2)
pipe(instance_prompt, num_inference_steps=1)
# check old checkpoints do not exist
self.assertFalse(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
@@ -196,7 +196,7 @@ class DreamBooth(ExamplesTestsAccelerate):
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=9
--max_train_steps=4
--checkpointing_steps=2
""".split()
@@ -204,7 +204,7 @@ class DreamBooth(ExamplesTestsAccelerate):
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
{"checkpoint-2", "checkpoint-4"},
)
resume_run_args = f"""
@@ -216,15 +216,12 @@ class DreamBooth(ExamplesTestsAccelerate):
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=11
--max_train_steps=8
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-8
--checkpoints_total_limit=3
--resume_from_checkpoint=checkpoint-4
--checkpoints_total_limit=2
""".split()
run_command(self._launch_args + resume_run_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-6", "checkpoint-8", "checkpoint-10"},
)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})

View File

@@ -135,16 +135,13 @@ class DreamBoothLoRA(ExamplesTestsAccelerate):
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=9
--max_train_steps=4
--checkpointing_steps=2
""".split()
run_command(self._launch_args + test_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})
resume_run_args = f"""
examples/dreambooth/train_dreambooth_lora.py
@@ -155,18 +152,15 @@ class DreamBoothLoRA(ExamplesTestsAccelerate):
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=11
--max_train_steps=8
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-8
--checkpoints_total_limit=3
--resume_from_checkpoint=checkpoint-4
--checkpoints_total_limit=2
""".split()
run_command(self._launch_args + resume_run_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-6", "checkpoint-8", "checkpoint-10"},
)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
def test_dreambooth_lora_if_model(self):
with tempfile.TemporaryDirectory() as tmpdir:
@@ -328,7 +322,7 @@ class DreamBoothLoRASDXL(ExamplesTestsAccelerate):
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 7
--max_train_steps 6
--checkpointing_steps=2
--checkpoints_total_limit=2
--learning_rate 5.0e-04
@@ -342,14 +336,11 @@ class DreamBoothLoRASDXL(ExamplesTestsAccelerate):
pipe = DiffusionPipeline.from_pretrained(pipeline_path)
pipe.load_lora_weights(tmpdir)
pipe("a prompt", num_inference_steps=2)
pipe("a prompt", num_inference_steps=1)
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
# checkpoint-2 should have been deleted
{"checkpoint-4", "checkpoint-6"},
)
# checkpoint-2 should have been deleted
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-4", "checkpoint-6"})
def test_dreambooth_lora_sdxl_text_encoder_checkpointing_checkpoints_total_limit(self):
pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"

View File

@@ -64,39 +64,6 @@ check_min_version("0.25.0.dev0")
logger = get_logger(__name__)
# TODO: This function should be removed once training scripts are rewritten in PEFT
def text_encoder_lora_state_dict(text_encoder):
state_dict = {}
def text_encoder_attn_modules(text_encoder):
from transformers import CLIPTextModel, CLIPTextModelWithProjection
attn_modules = []
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
name = f"text_model.encoder.layers.{i}.self_attn"
mod = layer.self_attn
attn_modules.append((name, mod))
return attn_modules
for name, module in text_encoder_attn_modules(text_encoder):
for k, v in module.q_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
for k, v in module.k_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
for k, v in module.v_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
for k, v in module.out_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
return state_dict
def save_model_card(
repo_id: str,
images=None,
@@ -860,6 +827,7 @@ def main(args):
# now we will add new LoRA weights to the attention layers
unet_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
init_lora_weights="gaussian",
target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"],
)
@@ -868,7 +836,10 @@ def main(args):
# The text encoder comes from 🤗 transformers, we will also attach adapters to it.
if args.train_text_encoder:
text_lora_config = LoraConfig(
r=args.rank, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"]
r=args.rank,
lora_alpha=args.rank,
init_lora_weights="gaussian",
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
)
text_encoder.add_adapter(text_lora_config)

View File

@@ -64,39 +64,6 @@ check_min_version("0.25.0.dev0")
logger = get_logger(__name__)
# TODO: This function should be removed once training scripts are rewritten in PEFT
def text_encoder_lora_state_dict(text_encoder):
state_dict = {}
def text_encoder_attn_modules(text_encoder):
from transformers import CLIPTextModel, CLIPTextModelWithProjection
attn_modules = []
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
name = f"text_model.encoder.layers.{i}.self_attn"
mod = layer.self_attn
attn_modules.append((name, mod))
return attn_modules
for name, module in text_encoder_attn_modules(text_encoder):
for k, v in module.q_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
for k, v in module.k_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
for k, v in module.v_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
for k, v in module.out_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
return state_dict
def save_model_card(
repo_id: str,
images=None,
@@ -1011,7 +978,10 @@ def main(args):
# now we will add new LoRA weights to the attention layers
unet_lora_config = LoraConfig(
r=args.rank, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"]
r=args.rank,
lora_alpha=args.rank,
init_lora_weights="gaussian",
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
unet.add_adapter(unet_lora_config)
@@ -1019,11 +989,25 @@ def main(args):
# So, instead, we monkey-patch the forward calls of its attention-blocks.
if args.train_text_encoder:
text_lora_config = LoraConfig(
r=args.rank, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"]
r=args.rank,
lora_alpha=args.rank,
init_lora_weights="gaussian",
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
)
text_encoder_one.add_adapter(text_lora_config)
text_encoder_two.add_adapter(text_lora_config)
# Make sure the trainable params are in float32.
if args.mixed_precision == "fp16":
models = [unet]
if args.train_text_encoder:
models.extend([text_encoder_one, text_encoder_two])
for model in models:
for param in model.parameters():
# only upcast trainable parameters (LoRA) into fp32
if param.requires_grad:
param.data = param.to(torch.float32)
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
@@ -1166,10 +1150,26 @@ def main(args):
optimizer_class = prodigyopt.Prodigy
if args.learning_rate <= 0.1:
logger.warn(
"Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
)
if args.train_text_encoder and args.text_encoder_lr:
logger.warn(
f"Learning rates were provided both for the unet and the text encoder- e.g. text_encoder_lr:"
f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
f"When using prodigy only learning_rate is used as the initial learning rate."
)
# changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be
# --learning_rate
params_to_optimize[1]["lr"] = args.learning_rate
params_to_optimize[2]["lr"] = args.learning_rate
optimizer = optimizer_class(
params_to_optimize,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
beta3=args.prodigy_beta3,
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
decouple=args.prodigy_decouple,

View File

@@ -40,7 +40,7 @@ class InstructPix2Pix(ExamplesTestsAccelerate):
--resolution=64
--random_flip
--train_batch_size=1
--max_train_steps=7
--max_train_steps=6
--checkpointing_steps=2
--checkpoints_total_limit=2
--output_dir {tmpdir}
@@ -63,7 +63,7 @@ class InstructPix2Pix(ExamplesTestsAccelerate):
--resolution=64
--random_flip
--train_batch_size=1
--max_train_steps=9
--max_train_steps=4
--checkpointing_steps=2
--output_dir {tmpdir}
--seed=0
@@ -74,7 +74,7 @@ class InstructPix2Pix(ExamplesTestsAccelerate):
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
{"checkpoint-2", "checkpoint-4"},
)
resume_run_args = f"""
@@ -84,12 +84,12 @@ class InstructPix2Pix(ExamplesTestsAccelerate):
--resolution=64
--random_flip
--train_batch_size=1
--max_train_steps=11
--max_train_steps=8
--checkpointing_steps=2
--output_dir {tmpdir}
--seed=0
--resume_from_checkpoint=checkpoint-8
--checkpoints_total_limit=3
--resume_from_checkpoint=checkpoint-4
--checkpoints_total_limit=2
""".split()
run_command(self._launch_args + resume_run_args)
@@ -97,5 +97,5 @@ class InstructPix2Pix(ExamplesTestsAccelerate):
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-6", "checkpoint-8", "checkpoint-10"},
{"checkpoint-6", "checkpoint-8"},
)

View File

@@ -1,6 +1,6 @@
diffusers==0.20.1
accelerate==0.23.0
transformers==4.34.0
transformers==4.36.0
peft==0.5.0
torch==2.0.1
torchvision>=0.16

View File

@@ -101,8 +101,8 @@ accelerate launch --mixed_precision="fp16" train_text_to_image.py \
Once the training is finished the model will be saved in the `output_dir` specified in the command. In this example it's `sd-pokemon-model`. To load the fine-tuned model for inference just pass that path to `StableDiffusionPipeline`
```python
import torch
from diffusers import StableDiffusionPipeline
model_path = "path_to_saved_model"
@@ -114,12 +114,13 @@ image.save("yoda-pokemon.png")
```
Checkpoints only save the unet, so to run inference from a checkpoint, just load the unet
```python
import torch
from diffusers import StableDiffusionPipeline, UNet2DConditionModel
model_path = "path_to_saved_model"
unet = UNet2DConditionModel.from_pretrained(model_path + "/checkpoint-<N>/unet")
unet = UNet2DConditionModel.from_pretrained(model_path + "/checkpoint-<N>/unet", torch_dtype=torch.float16)
pipe = StableDiffusionPipeline.from_pretrained("<initial model>", unet=unet, torch_dtype=torch.float16)
pipe.to("cuda")

View File

@@ -64,7 +64,7 @@ class TextToImage(ExamplesTestsAccelerate):
with tempfile.TemporaryDirectory() as tmpdir:
# Run training script with checkpointing
# max_train_steps == 5, checkpointing_steps == 2
# max_train_steps == 4, checkpointing_steps == 2
# Should create checkpoints at steps 2, 4
initial_run_args = f"""
@@ -76,7 +76,7 @@ class TextToImage(ExamplesTestsAccelerate):
--random_flip
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 5
--max_train_steps 4
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
@@ -89,7 +89,7 @@ class TextToImage(ExamplesTestsAccelerate):
run_command(self._launch_args + initial_run_args)
pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
pipe(prompt, num_inference_steps=2)
pipe(prompt, num_inference_steps=1)
# check checkpoint directories exist
self.assertEqual(
@@ -100,12 +100,12 @@ class TextToImage(ExamplesTestsAccelerate):
# check can run an intermediate checkpoint
unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet")
pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, unet=unet, safety_checker=None)
pipe(prompt, num_inference_steps=2)
pipe(prompt, num_inference_steps=1)
# Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))
# Run training script for 7 total steps resuming from checkpoint 4
# Run training script for 2 total steps resuming from checkpoint 4
resume_run_args = f"""
examples/text_to_image/train_text_to_image.py
@@ -116,13 +116,13 @@ class TextToImage(ExamplesTestsAccelerate):
--random_flip
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 7
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--checkpointing_steps=2
--checkpointing_steps=1
--resume_from_checkpoint=checkpoint-4
--seed=0
""".split()
@@ -131,16 +131,13 @@ class TextToImage(ExamplesTestsAccelerate):
# check can run new fully trained pipeline
pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
pipe(prompt, num_inference_steps=2)
pipe(prompt, num_inference_steps=1)
# no checkpoint-2 -> check old checkpoints do not exist
# check new checkpoints exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{
# no checkpoint-2 -> check old checkpoints do not exist
# check new checkpoints exist
"checkpoint-4",
"checkpoint-6",
},
{"checkpoint-4", "checkpoint-5"},
)
def test_text_to_image_checkpointing_use_ema(self):
@@ -149,7 +146,7 @@ class TextToImage(ExamplesTestsAccelerate):
with tempfile.TemporaryDirectory() as tmpdir:
# Run training script with checkpointing
# max_train_steps == 5, checkpointing_steps == 2
# max_train_steps == 4, checkpointing_steps == 2
# Should create checkpoints at steps 2, 4
initial_run_args = f"""
@@ -161,7 +158,7 @@ class TextToImage(ExamplesTestsAccelerate):
--random_flip
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 5
--max_train_steps 4
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
@@ -186,12 +183,12 @@ class TextToImage(ExamplesTestsAccelerate):
# check can run an intermediate checkpoint
unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet")
pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, unet=unet, safety_checker=None)
pipe(prompt, num_inference_steps=2)
pipe(prompt, num_inference_steps=1)
# Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))
# Run training script for 7 total steps resuming from checkpoint 4
# Run training script for 2 total steps resuming from checkpoint 4
resume_run_args = f"""
examples/text_to_image/train_text_to_image.py
@@ -202,13 +199,13 @@ class TextToImage(ExamplesTestsAccelerate):
--random_flip
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 7
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--checkpointing_steps=2
--checkpointing_steps=1
--resume_from_checkpoint=checkpoint-4
--use_ema
--seed=0
@@ -218,16 +215,13 @@ class TextToImage(ExamplesTestsAccelerate):
# check can run new fully trained pipeline
pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
pipe(prompt, num_inference_steps=2)
pipe(prompt, num_inference_steps=1)
# no checkpoint-2 -> check old checkpoints do not exist
# check new checkpoints exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{
# no checkpoint-2 -> check old checkpoints do not exist
# check new checkpoints exist
"checkpoint-4",
"checkpoint-6",
},
{"checkpoint-4", "checkpoint-5"},
)
def test_text_to_image_checkpointing_checkpoints_total_limit(self):
@@ -236,7 +230,7 @@ class TextToImage(ExamplesTestsAccelerate):
with tempfile.TemporaryDirectory() as tmpdir:
# Run training script with checkpointing
# max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
# max_train_steps == 6, checkpointing_steps == 2, checkpoints_total_limit == 2
# Should create checkpoints at steps 2, 4, 6
# with checkpoint at step 2 deleted
@@ -249,7 +243,7 @@ class TextToImage(ExamplesTestsAccelerate):
--random_flip
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 7
--max_train_steps 6
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
@@ -263,14 +257,11 @@ class TextToImage(ExamplesTestsAccelerate):
run_command(self._launch_args + initial_run_args)
pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
pipe(prompt, num_inference_steps=2)
pipe(prompt, num_inference_steps=1)
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
# checkpoint-2 should have been deleted
{"checkpoint-4", "checkpoint-6"},
)
# checkpoint-2 should have been deleted
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-4", "checkpoint-6"})
def test_text_to_image_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
@@ -278,8 +269,8 @@ class TextToImage(ExamplesTestsAccelerate):
with tempfile.TemporaryDirectory() as tmpdir:
# Run training script with checkpointing
# max_train_steps == 9, checkpointing_steps == 2
# Should create checkpoints at steps 2, 4, 6, 8
# max_train_steps == 4, checkpointing_steps == 2
# Should create checkpoints at steps 2, 4
initial_run_args = f"""
examples/text_to_image/train_text_to_image.py
@@ -290,7 +281,7 @@ class TextToImage(ExamplesTestsAccelerate):
--random_flip
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 9
--max_train_steps 4
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
@@ -303,15 +294,15 @@ class TextToImage(ExamplesTestsAccelerate):
run_command(self._launch_args + initial_run_args)
pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
pipe(prompt, num_inference_steps=2)
pipe(prompt, num_inference_steps=1)
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
{"checkpoint-2", "checkpoint-4"},
)
# resume and we should try to checkpoint at 10, where we'll have to remove
# resume and we should try to checkpoint at 6, where we'll have to remove
# checkpoint-2 and checkpoint-4 instead of just a single previous checkpoint
resume_run_args = f"""
@@ -323,27 +314,27 @@ class TextToImage(ExamplesTestsAccelerate):
--random_flip
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 11
--max_train_steps 8
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-8
--checkpoints_total_limit=3
--resume_from_checkpoint=checkpoint-4
--checkpoints_total_limit=2
--seed=0
""".split()
run_command(self._launch_args + resume_run_args)
pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
pipe(prompt, num_inference_steps=2)
pipe(prompt, num_inference_steps=1)
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-6", "checkpoint-8", "checkpoint-10"},
{"checkpoint-6", "checkpoint-8"},
)

View File

@@ -41,7 +41,7 @@ class TextToImageLoRA(ExamplesTestsAccelerate):
with tempfile.TemporaryDirectory() as tmpdir:
# Run training script with checkpointing
# max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
# max_train_steps == 6, checkpointing_steps == 2, checkpoints_total_limit == 2
# Should create checkpoints at steps 2, 4, 6
# with checkpoint at step 2 deleted
@@ -52,7 +52,7 @@ class TextToImageLoRA(ExamplesTestsAccelerate):
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 7
--max_train_steps 6
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
@@ -66,14 +66,11 @@ class TextToImageLoRA(ExamplesTestsAccelerate):
pipe = DiffusionPipeline.from_pretrained(pipeline_path)
pipe.load_lora_weights(tmpdir)
pipe(prompt, num_inference_steps=2)
pipe(prompt, num_inference_steps=1)
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
# checkpoint-2 should have been deleted
{"checkpoint-4", "checkpoint-6"},
)
# checkpoint-2 should have been deleted
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-4", "checkpoint-6"})
def test_text_to_image_lora_checkpointing_checkpoints_total_limit(self):
pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
@@ -81,7 +78,7 @@ class TextToImageLoRA(ExamplesTestsAccelerate):
with tempfile.TemporaryDirectory() as tmpdir:
# Run training script with checkpointing
# max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
# max_train_steps == 6, checkpointing_steps == 2, checkpoints_total_limit == 2
# Should create checkpoints at steps 2, 4, 6
# with checkpoint at step 2 deleted
@@ -94,7 +91,7 @@ class TextToImageLoRA(ExamplesTestsAccelerate):
--random_flip
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 7
--max_train_steps 6
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
@@ -112,14 +109,11 @@ class TextToImageLoRA(ExamplesTestsAccelerate):
"hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None
)
pipe.load_lora_weights(tmpdir)
pipe(prompt, num_inference_steps=2)
pipe(prompt, num_inference_steps=1)
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
# checkpoint-2 should have been deleted
{"checkpoint-4", "checkpoint-6"},
)
# checkpoint-2 should have been deleted
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-4", "checkpoint-6"})
def test_text_to_image_lora_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
@@ -127,8 +121,8 @@ class TextToImageLoRA(ExamplesTestsAccelerate):
with tempfile.TemporaryDirectory() as tmpdir:
# Run training script with checkpointing
# max_train_steps == 9, checkpointing_steps == 2
# Should create checkpoints at steps 2, 4, 6, 8
# max_train_steps == 4, checkpointing_steps == 2
# Should create checkpoints at steps 2, 4
initial_run_args = f"""
examples/text_to_image/train_text_to_image_lora.py
@@ -139,7 +133,7 @@ class TextToImageLoRA(ExamplesTestsAccelerate):
--random_flip
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 9
--max_train_steps 4
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
@@ -156,15 +150,15 @@ class TextToImageLoRA(ExamplesTestsAccelerate):
"hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None
)
pipe.load_lora_weights(tmpdir)
pipe(prompt, num_inference_steps=2)
pipe(prompt, num_inference_steps=1)
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
{"checkpoint-2", "checkpoint-4"},
)
# resume and we should try to checkpoint at 10, where we'll have to remove
# resume and we should try to checkpoint at 6, where we'll have to remove
# checkpoint-2 and checkpoint-4 instead of just a single previous checkpoint
resume_run_args = f"""
@@ -176,15 +170,15 @@ class TextToImageLoRA(ExamplesTestsAccelerate):
--random_flip
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 11
--max_train_steps 8
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-8
--checkpoints_total_limit=3
--resume_from_checkpoint=checkpoint-4
--checkpoints_total_limit=2
--seed=0
--num_validation_images=0
""".split()
@@ -195,12 +189,12 @@ class TextToImageLoRA(ExamplesTestsAccelerate):
"hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None
)
pipe.load_lora_weights(tmpdir)
pipe(prompt, num_inference_steps=2)
pipe(prompt, num_inference_steps=1)
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-6", "checkpoint-8", "checkpoint-10"},
{"checkpoint-6", "checkpoint-8"},
)
@@ -272,7 +266,7 @@ class TextToImageLoRASDXL(ExamplesTestsAccelerate):
with tempfile.TemporaryDirectory() as tmpdir:
# Run training script with checkpointing
# max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
# max_train_steps == 6, checkpointing_steps == 2, checkpoints_total_limit == 2
# Should create checkpoints at steps 2, 4, 6
# with checkpoint at step 2 deleted
@@ -283,7 +277,7 @@ class TextToImageLoRASDXL(ExamplesTestsAccelerate):
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 7
--max_train_steps 6
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
@@ -298,11 +292,8 @@ class TextToImageLoRASDXL(ExamplesTestsAccelerate):
pipe = DiffusionPipeline.from_pretrained(pipeline_path)
pipe.load_lora_weights(tmpdir)
pipe(prompt, num_inference_steps=2)
pipe(prompt, num_inference_steps=1)
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
# checkpoint-2 should have been deleted
{"checkpoint-4", "checkpoint-6"},
)
# checkpoint-2 should have been deleted
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-4", "checkpoint-6"})

View File

@@ -54,39 +54,6 @@ check_min_version("0.25.0.dev0")
logger = get_logger(__name__, log_level="INFO")
# TODO: This function should be removed once training scripts are rewritten in PEFT
def text_encoder_lora_state_dict(text_encoder):
state_dict = {}
def text_encoder_attn_modules(text_encoder):
from transformers import CLIPTextModel, CLIPTextModelWithProjection
attn_modules = []
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
name = f"text_model.encoder.layers.{i}.self_attn"
mod = layer.self_attn
attn_modules.append((name, mod))
return attn_modules
for name, module in text_encoder_attn_modules(text_encoder):
for k, v in module.q_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
for k, v in module.k_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
for k, v in module.v_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
for k, v in module.out_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
return state_dict
def save_model_card(repo_id: str, images=None, base_model=str, dataset_name=str, repo_folder=None):
img_str = ""
for i, image in enumerate(images):
@@ -485,7 +452,10 @@ def main():
param.requires_grad_(False)
unet_lora_config = LoraConfig(
r=args.rank, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"]
r=args.rank,
lora_alpha=args.rank,
init_lora_weights="gaussian",
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
# Move unet, vae and text_encoder to device and cast to weight_dtype
@@ -493,7 +463,13 @@ def main():
vae.to(accelerator.device, dtype=weight_dtype)
text_encoder.to(accelerator.device, dtype=weight_dtype)
# Add adapter and make sure the trainable params are in float32.
unet.add_adapter(unet_lora_config)
if args.mixed_precision == "fp16":
for param in unet.parameters():
# only upcast trainable parameters (LoRA) into fp32
if param.requires_grad:
param.data = param.to(torch.float32)
if args.enable_xformers_memory_efficient_attention:
if is_xformers_available():
@@ -832,7 +808,8 @@ def main():
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
unet_lora_state_dict = get_peft_model_state_dict(unet)
unwrapped_unet = accelerator.unwrap_model(unet)
unet_lora_state_dict = get_peft_model_state_dict(unwrapped_unet)
StableDiffusionPipeline.save_lora_weights(
save_directory=save_path,
@@ -870,10 +847,11 @@ def main():
if args.seed is not None:
generator = generator.manual_seed(args.seed)
images = []
for _ in range(args.num_validation_images):
images.append(
pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
)
with torch.cuda.amp.autocast():
for _ in range(args.num_validation_images):
images.append(
pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
)
for tracker in accelerator.trackers:
if tracker.name == "tensorboard":
@@ -897,7 +875,8 @@ def main():
if accelerator.is_main_process:
unet = unet.to(torch.float32)
unet_lora_state_dict = get_peft_model_state_dict(unet)
unwrapped_unet = accelerator.unwrap_model(unet)
unet_lora_state_dict = get_peft_model_state_dict(unwrapped_unet)
StableDiffusionPipeline.save_lora_weights(
save_directory=args.output_dir,
unet_lora_layers=unet_lora_state_dict,
@@ -919,40 +898,46 @@ def main():
ignore_patterns=["step_*", "epoch_*"],
)
# Final inference
# Load previous pipeline
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype
)
pipeline = pipeline.to(accelerator.device)
# Final inference
# Load previous pipeline
if args.validation_prompt is not None:
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline = pipeline.to(accelerator.device)
# load attention processors
pipeline.unet.load_attn_procs(args.output_dir)
# load attention processors
pipeline.load_lora_weights(args.output_dir)
# run inference
generator = torch.Generator(device=accelerator.device)
if args.seed is not None:
generator = generator.manual_seed(args.seed)
images = []
for _ in range(args.num_validation_images):
images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0])
if accelerator.is_main_process:
for tracker in accelerator.trackers:
if len(images) != 0:
if tracker.name == "tensorboard":
np_images = np.stack([np.asarray(img) for img in images])
tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
if tracker.name == "wandb":
tracker.log(
{
"test": [
wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
for i, image in enumerate(images)
]
}
# run inference
generator = torch.Generator(device=accelerator.device)
if args.seed is not None:
generator = generator.manual_seed(args.seed)
images = []
with torch.cuda.amp.autocast():
for _ in range(args.num_validation_images):
images.append(
pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
)
for tracker in accelerator.trackers:
if len(images) != 0:
if tracker.name == "tensorboard":
np_images = np.stack([np.asarray(img) for img in images])
tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
if tracker.name == "wandb":
tracker.log(
{
"test": [
wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
for i, image in enumerate(images)
]
}
)
accelerator.end_training()

View File

@@ -22,7 +22,6 @@ import os
import random
import shutil
from pathlib import Path
from typing import Dict
import datasets
import numpy as np
@@ -63,39 +62,6 @@ check_min_version("0.25.0.dev0")
logger = get_logger(__name__)
# TODO: This function should be removed once training scripts are rewritten in PEFT
def text_encoder_lora_state_dict(text_encoder):
state_dict = {}
def text_encoder_attn_modules(text_encoder):
from transformers import CLIPTextModel, CLIPTextModelWithProjection
attn_modules = []
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
name = f"text_model.encoder.layers.{i}.self_attn"
mod = layer.self_attn
attn_modules.append((name, mod))
return attn_modules
for name, module in text_encoder_attn_modules(text_encoder):
for k, v in module.q_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
for k, v in module.k_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
for k, v in module.v_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
for k, v in module.out_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
return state_dict
def save_model_card(
repo_id: str,
images=None,
@@ -469,22 +435,6 @@ DATASET_NAME_MAPPING = {
}
def unet_attn_processors_state_dict(unet) -> Dict[str, torch.tensor]:
"""
Returns:
a state dict containing just the attention processor parameters.
"""
attn_processors = unet.attn_processors
attn_processors_state_dict = {}
for attn_processor_key, attn_processor in attn_processors.items():
for parameter_key, parameter in attn_processor.state_dict().items():
attn_processors_state_dict[f"{attn_processor_key}.{parameter_key}"] = parameter
return attn_processors_state_dict
def tokenize_prompt(tokenizer, prompt):
text_inputs = tokenizer(
prompt,
@@ -659,7 +609,10 @@ def main(args):
# now we will add new LoRA weights to the attention layers
# Set correct lora layers
unet_lora_config = LoraConfig(
r=args.rank, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"]
r=args.rank,
lora_alpha=args.rank,
init_lora_weights="gaussian",
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
unet.add_adapter(unet_lora_config)
@@ -668,11 +621,25 @@ def main(args):
if args.train_text_encoder:
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
text_lora_config = LoraConfig(
r=args.rank, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"]
r=args.rank,
lora_alpha=args.rank,
init_lora_weights="gaussian",
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
)
text_encoder_one.add_adapter(text_lora_config)
text_encoder_two.add_adapter(text_lora_config)
# Make sure the trainable params are in float32.
if args.mixed_precision == "fp16":
models = [unet]
if args.train_text_encoder:
models.extend([text_encoder_one, text_encoder_two])
for model in models:
for param in model.parameters():
# only upcast trainable parameters (LoRA) into fp32
if param.requires_grad:
param.data = param.to(torch.float32)
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
@@ -1220,6 +1187,9 @@ def main(args):
torch.cuda.empty_cache()
# Final inference
# Make sure vae.dtype is consistent with the unet.dtype
if args.mixed_precision == "fp16":
vae.to(weight_dtype)
# Load previous pipeline
pipeline = StableDiffusionXLPipeline.from_pretrained(
args.pretrained_model_name_or_path,

View File

@@ -40,8 +40,6 @@ class TextualInversion(ExamplesTestsAccelerate):
--learnable_property object
--placeholder_token <cat-toy>
--initializer_token a
--validation_prompt <cat-toy>
--validation_steps 1
--save_steps 1
--num_vectors 2
--resolution 64
@@ -68,8 +66,6 @@ class TextualInversion(ExamplesTestsAccelerate):
--learnable_property object
--placeholder_token <cat-toy>
--initializer_token a
--validation_prompt <cat-toy>
--validation_steps 1
--save_steps 1
--num_vectors 2
--resolution 64
@@ -102,14 +98,12 @@ class TextualInversion(ExamplesTestsAccelerate):
--learnable_property object
--placeholder_token <cat-toy>
--initializer_token a
--validation_prompt <cat-toy>
--validation_steps 1
--save_steps 1
--num_vectors 2
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 3
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
@@ -123,7 +117,7 @@ class TextualInversion(ExamplesTestsAccelerate):
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-1", "checkpoint-2", "checkpoint-3"},
{"checkpoint-1", "checkpoint-2"},
)
resume_run_args = f"""
@@ -133,21 +127,19 @@ class TextualInversion(ExamplesTestsAccelerate):
--learnable_property object
--placeholder_token <cat-toy>
--initializer_token a
--validation_prompt <cat-toy>
--validation_steps 1
--save_steps 1
--num_vectors 2
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 4
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--checkpointing_steps=1
--resume_from_checkpoint=checkpoint-3
--resume_from_checkpoint=checkpoint-2
--checkpoints_total_limit=2
""".split()
@@ -156,5 +148,5 @@ class TextualInversion(ExamplesTestsAccelerate):
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-3", "checkpoint-4"},
{"checkpoint-2", "checkpoint-3"},
)

View File

@@ -90,10 +90,10 @@ class Unconditional(ExamplesTestsAccelerate):
--train_batch_size 1
--num_epochs 1
--gradient_accumulation_steps 1
--ddpm_num_inference_steps 2
--ddpm_num_inference_steps 1
--learning_rate 1e-3
--lr_warmup_steps 5
--checkpointing_steps=1
--checkpointing_steps=2
""".split()
run_command(self._launch_args + initial_run_args)
@@ -101,7 +101,7 @@ class Unconditional(ExamplesTestsAccelerate):
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-1", "checkpoint-2", "checkpoint-3", "checkpoint-4", "checkpoint-5", "checkpoint-6"},
{"checkpoint-2", "checkpoint-4", "checkpoint-6"},
)
resume_run_args = f"""
@@ -113,12 +113,12 @@ class Unconditional(ExamplesTestsAccelerate):
--train_batch_size 1
--num_epochs 2
--gradient_accumulation_steps 1
--ddpm_num_inference_steps 2
--ddpm_num_inference_steps 1
--learning_rate 1e-3
--lr_warmup_steps 5
--resume_from_checkpoint=checkpoint-6
--checkpointing_steps=2
--checkpoints_total_limit=3
--checkpoints_total_limit=2
""".split()
run_command(self._launch_args + resume_run_args)
@@ -126,5 +126,5 @@ class Unconditional(ExamplesTestsAccelerate):
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-8", "checkpoint-10", "checkpoint-12"},
{"checkpoint-10", "checkpoint-12"},
)

View File

@@ -77,7 +77,7 @@ First, you need to set up your development environment as explained in the [inst
```bash
export DATASET_NAME="lambdalabs/pokemon-blip-captions"
accelerate launch train_text_to_image_prior_lora.py \
accelerate launch train_text_to_image_lora_prior.py \
--mixed_precision="fp16" \
--dataset_name=$DATASET_NAME --caption_column="text" \
--resolution=768 \

523
scripts/convert_amused.py Normal file
View File

@@ -0,0 +1,523 @@
import inspect
import os
from argparse import ArgumentParser
import numpy as np
import torch
from muse import MaskGiTUViT, VQGANModel
from muse import PipelineMuse as OldPipelineMuse
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from diffusers import VQModel
from diffusers.models.attention_processor import AttnProcessor
from diffusers.models.uvit_2d import UVit2DModel
from diffusers.pipelines.amused.pipeline_amused import AmusedPipeline
from diffusers.schedulers import AmusedScheduler
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_math_sdp(True)
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
torch.use_deterministic_algorithms(True)
# Enable CUDNN deterministic mode
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cuda.matmul.allow_tf32 = False
device = "cuda"
def main():
args = ArgumentParser()
args.add_argument("--model_256", action="store_true")
args.add_argument("--write_to", type=str, required=False, default=None)
args.add_argument("--transformer_path", type=str, required=False, default=None)
args = args.parse_args()
transformer_path = args.transformer_path
subfolder = "transformer"
if transformer_path is None:
if args.model_256:
transformer_path = "openMUSE/muse-256"
else:
transformer_path = (
"../research-run-512-checkpoints/research-run-512-with-downsample-checkpoint-554000/unwrapped_model/"
)
subfolder = None
old_transformer = MaskGiTUViT.from_pretrained(transformer_path, subfolder=subfolder)
old_transformer.to(device)
old_vae = VQGANModel.from_pretrained("openMUSE/muse-512", subfolder="vae")
old_vae.to(device)
vqvae = make_vqvae(old_vae)
tokenizer = CLIPTokenizer.from_pretrained("openMUSE/muse-512", subfolder="text_encoder")
text_encoder = CLIPTextModelWithProjection.from_pretrained("openMUSE/muse-512", subfolder="text_encoder")
text_encoder.to(device)
transformer = make_transformer(old_transformer, args.model_256)
scheduler = AmusedScheduler(mask_token_id=old_transformer.config.mask_token_id)
new_pipe = AmusedPipeline(
vqvae=vqvae, tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, scheduler=scheduler
)
old_pipe = OldPipelineMuse(
vae=old_vae, transformer=old_transformer, text_encoder=text_encoder, tokenizer=tokenizer
)
old_pipe.to(device)
if args.model_256:
transformer_seq_len = 256
orig_size = (256, 256)
else:
transformer_seq_len = 1024
orig_size = (512, 512)
old_out = old_pipe(
"dog",
generator=torch.Generator(device).manual_seed(0),
transformer_seq_len=transformer_seq_len,
orig_size=orig_size,
timesteps=12,
)[0]
new_out = new_pipe("dog", generator=torch.Generator(device).manual_seed(0)).images[0]
old_out = np.array(old_out)
new_out = np.array(new_out)
diff = np.abs(old_out.astype(np.float64) - new_out.astype(np.float64))
# assert diff diff.sum() == 0
print("skipping pipeline full equivalence check")
print(f"max diff: {diff.max()}, diff.sum() / diff.size {diff.sum() / diff.size}")
if args.model_256:
assert diff.max() <= 3
assert diff.sum() / diff.size < 0.7
else:
assert diff.max() <= 1
assert diff.sum() / diff.size < 0.4
if args.write_to is not None:
new_pipe.save_pretrained(args.write_to)
def make_transformer(old_transformer, model_256):
args = dict(old_transformer.config)
force_down_up_sample = args["force_down_up_sample"]
signature = inspect.signature(UVit2DModel.__init__)
args_ = {
"downsample": force_down_up_sample,
"upsample": force_down_up_sample,
"block_out_channels": args["block_out_channels"][0],
"sample_size": 16 if model_256 else 32,
}
for s in list(signature.parameters.keys()):
if s in ["self", "downsample", "upsample", "sample_size", "block_out_channels"]:
continue
args_[s] = args[s]
new_transformer = UVit2DModel(**args_)
new_transformer.to(device)
new_transformer.set_attn_processor(AttnProcessor())
state_dict = old_transformer.state_dict()
state_dict["cond_embed.linear_1.weight"] = state_dict.pop("cond_embed.0.weight")
state_dict["cond_embed.linear_2.weight"] = state_dict.pop("cond_embed.2.weight")
for i in range(22):
state_dict[f"transformer_layers.{i}.norm1.norm.weight"] = state_dict.pop(
f"transformer_layers.{i}.attn_layer_norm.weight"
)
state_dict[f"transformer_layers.{i}.norm1.linear.weight"] = state_dict.pop(
f"transformer_layers.{i}.self_attn_adaLN_modulation.mapper.weight"
)
state_dict[f"transformer_layers.{i}.attn1.to_q.weight"] = state_dict.pop(
f"transformer_layers.{i}.attention.query.weight"
)
state_dict[f"transformer_layers.{i}.attn1.to_k.weight"] = state_dict.pop(
f"transformer_layers.{i}.attention.key.weight"
)
state_dict[f"transformer_layers.{i}.attn1.to_v.weight"] = state_dict.pop(
f"transformer_layers.{i}.attention.value.weight"
)
state_dict[f"transformer_layers.{i}.attn1.to_out.0.weight"] = state_dict.pop(
f"transformer_layers.{i}.attention.out.weight"
)
state_dict[f"transformer_layers.{i}.norm2.norm.weight"] = state_dict.pop(
f"transformer_layers.{i}.crossattn_layer_norm.weight"
)
state_dict[f"transformer_layers.{i}.norm2.linear.weight"] = state_dict.pop(
f"transformer_layers.{i}.cross_attn_adaLN_modulation.mapper.weight"
)
state_dict[f"transformer_layers.{i}.attn2.to_q.weight"] = state_dict.pop(
f"transformer_layers.{i}.crossattention.query.weight"
)
state_dict[f"transformer_layers.{i}.attn2.to_k.weight"] = state_dict.pop(
f"transformer_layers.{i}.crossattention.key.weight"
)
state_dict[f"transformer_layers.{i}.attn2.to_v.weight"] = state_dict.pop(
f"transformer_layers.{i}.crossattention.value.weight"
)
state_dict[f"transformer_layers.{i}.attn2.to_out.0.weight"] = state_dict.pop(
f"transformer_layers.{i}.crossattention.out.weight"
)
state_dict[f"transformer_layers.{i}.norm3.norm.weight"] = state_dict.pop(
f"transformer_layers.{i}.ffn.pre_mlp_layer_norm.weight"
)
state_dict[f"transformer_layers.{i}.norm3.linear.weight"] = state_dict.pop(
f"transformer_layers.{i}.ffn.adaLN_modulation.mapper.weight"
)
wi_0_weight = state_dict.pop(f"transformer_layers.{i}.ffn.wi_0.weight")
wi_1_weight = state_dict.pop(f"transformer_layers.{i}.ffn.wi_1.weight")
proj_weight = torch.concat([wi_1_weight, wi_0_weight], dim=0)
state_dict[f"transformer_layers.{i}.ff.net.0.proj.weight"] = proj_weight
state_dict[f"transformer_layers.{i}.ff.net.2.weight"] = state_dict.pop(f"transformer_layers.{i}.ffn.wo.weight")
if force_down_up_sample:
state_dict["down_block.downsample.norm.weight"] = state_dict.pop("down_blocks.0.downsample.0.norm.weight")
state_dict["down_block.downsample.conv.weight"] = state_dict.pop("down_blocks.0.downsample.1.weight")
state_dict["up_block.upsample.norm.weight"] = state_dict.pop("up_blocks.0.upsample.0.norm.weight")
state_dict["up_block.upsample.conv.weight"] = state_dict.pop("up_blocks.0.upsample.1.weight")
state_dict["mlm_layer.layer_norm.weight"] = state_dict.pop("mlm_layer.layer_norm.norm.weight")
for i in range(3):
state_dict[f"down_block.res_blocks.{i}.norm.weight"] = state_dict.pop(
f"down_blocks.0.res_blocks.{i}.norm.norm.weight"
)
state_dict[f"down_block.res_blocks.{i}.channelwise_linear_1.weight"] = state_dict.pop(
f"down_blocks.0.res_blocks.{i}.channelwise.0.weight"
)
state_dict[f"down_block.res_blocks.{i}.channelwise_norm.gamma"] = state_dict.pop(
f"down_blocks.0.res_blocks.{i}.channelwise.2.gamma"
)
state_dict[f"down_block.res_blocks.{i}.channelwise_norm.beta"] = state_dict.pop(
f"down_blocks.0.res_blocks.{i}.channelwise.2.beta"
)
state_dict[f"down_block.res_blocks.{i}.channelwise_linear_2.weight"] = state_dict.pop(
f"down_blocks.0.res_blocks.{i}.channelwise.4.weight"
)
state_dict[f"down_block.res_blocks.{i}.cond_embeds_mapper.weight"] = state_dict.pop(
f"down_blocks.0.res_blocks.{i}.adaLN_modulation.mapper.weight"
)
state_dict[f"down_block.attention_blocks.{i}.norm1.weight"] = state_dict.pop(
f"down_blocks.0.attention_blocks.{i}.attn_layer_norm.weight"
)
state_dict[f"down_block.attention_blocks.{i}.attn1.to_q.weight"] = state_dict.pop(
f"down_blocks.0.attention_blocks.{i}.attention.query.weight"
)
state_dict[f"down_block.attention_blocks.{i}.attn1.to_k.weight"] = state_dict.pop(
f"down_blocks.0.attention_blocks.{i}.attention.key.weight"
)
state_dict[f"down_block.attention_blocks.{i}.attn1.to_v.weight"] = state_dict.pop(
f"down_blocks.0.attention_blocks.{i}.attention.value.weight"
)
state_dict[f"down_block.attention_blocks.{i}.attn1.to_out.0.weight"] = state_dict.pop(
f"down_blocks.0.attention_blocks.{i}.attention.out.weight"
)
state_dict[f"down_block.attention_blocks.{i}.norm2.weight"] = state_dict.pop(
f"down_blocks.0.attention_blocks.{i}.crossattn_layer_norm.weight"
)
state_dict[f"down_block.attention_blocks.{i}.attn2.to_q.weight"] = state_dict.pop(
f"down_blocks.0.attention_blocks.{i}.crossattention.query.weight"
)
state_dict[f"down_block.attention_blocks.{i}.attn2.to_k.weight"] = state_dict.pop(
f"down_blocks.0.attention_blocks.{i}.crossattention.key.weight"
)
state_dict[f"down_block.attention_blocks.{i}.attn2.to_v.weight"] = state_dict.pop(
f"down_blocks.0.attention_blocks.{i}.crossattention.value.weight"
)
state_dict[f"down_block.attention_blocks.{i}.attn2.to_out.0.weight"] = state_dict.pop(
f"down_blocks.0.attention_blocks.{i}.crossattention.out.weight"
)
state_dict[f"up_block.res_blocks.{i}.norm.weight"] = state_dict.pop(
f"up_blocks.0.res_blocks.{i}.norm.norm.weight"
)
state_dict[f"up_block.res_blocks.{i}.channelwise_linear_1.weight"] = state_dict.pop(
f"up_blocks.0.res_blocks.{i}.channelwise.0.weight"
)
state_dict[f"up_block.res_blocks.{i}.channelwise_norm.gamma"] = state_dict.pop(
f"up_blocks.0.res_blocks.{i}.channelwise.2.gamma"
)
state_dict[f"up_block.res_blocks.{i}.channelwise_norm.beta"] = state_dict.pop(
f"up_blocks.0.res_blocks.{i}.channelwise.2.beta"
)
state_dict[f"up_block.res_blocks.{i}.channelwise_linear_2.weight"] = state_dict.pop(
f"up_blocks.0.res_blocks.{i}.channelwise.4.weight"
)
state_dict[f"up_block.res_blocks.{i}.cond_embeds_mapper.weight"] = state_dict.pop(
f"up_blocks.0.res_blocks.{i}.adaLN_modulation.mapper.weight"
)
state_dict[f"up_block.attention_blocks.{i}.norm1.weight"] = state_dict.pop(
f"up_blocks.0.attention_blocks.{i}.attn_layer_norm.weight"
)
state_dict[f"up_block.attention_blocks.{i}.attn1.to_q.weight"] = state_dict.pop(
f"up_blocks.0.attention_blocks.{i}.attention.query.weight"
)
state_dict[f"up_block.attention_blocks.{i}.attn1.to_k.weight"] = state_dict.pop(
f"up_blocks.0.attention_blocks.{i}.attention.key.weight"
)
state_dict[f"up_block.attention_blocks.{i}.attn1.to_v.weight"] = state_dict.pop(
f"up_blocks.0.attention_blocks.{i}.attention.value.weight"
)
state_dict[f"up_block.attention_blocks.{i}.attn1.to_out.0.weight"] = state_dict.pop(
f"up_blocks.0.attention_blocks.{i}.attention.out.weight"
)
state_dict[f"up_block.attention_blocks.{i}.norm2.weight"] = state_dict.pop(
f"up_blocks.0.attention_blocks.{i}.crossattn_layer_norm.weight"
)
state_dict[f"up_block.attention_blocks.{i}.attn2.to_q.weight"] = state_dict.pop(
f"up_blocks.0.attention_blocks.{i}.crossattention.query.weight"
)
state_dict[f"up_block.attention_blocks.{i}.attn2.to_k.weight"] = state_dict.pop(
f"up_blocks.0.attention_blocks.{i}.crossattention.key.weight"
)
state_dict[f"up_block.attention_blocks.{i}.attn2.to_v.weight"] = state_dict.pop(
f"up_blocks.0.attention_blocks.{i}.crossattention.value.weight"
)
state_dict[f"up_block.attention_blocks.{i}.attn2.to_out.0.weight"] = state_dict.pop(
f"up_blocks.0.attention_blocks.{i}.crossattention.out.weight"
)
for key in list(state_dict.keys()):
if key.startswith("up_blocks.0"):
key_ = "up_block." + ".".join(key.split(".")[2:])
state_dict[key_] = state_dict.pop(key)
if key.startswith("down_blocks.0"):
key_ = "down_block." + ".".join(key.split(".")[2:])
state_dict[key_] = state_dict.pop(key)
new_transformer.load_state_dict(state_dict)
input_ids = torch.randint(0, 10, (1, 32, 32), device=old_transformer.device)
encoder_hidden_states = torch.randn((1, 77, 768), device=old_transformer.device)
cond_embeds = torch.randn((1, 768), device=old_transformer.device)
micro_conds = torch.tensor([[512, 512, 0, 0, 6]], dtype=torch.float32, device=old_transformer.device)
old_out = old_transformer(input_ids.reshape(1, -1), encoder_hidden_states, cond_embeds, micro_conds)
old_out = old_out.reshape(1, 32, 32, 8192).permute(0, 3, 1, 2)
new_out = new_transformer(input_ids, encoder_hidden_states, cond_embeds, micro_conds)
# NOTE: these differences are solely due to using the geglu block that has a single linear layer of
# double output dimension instead of two different linear layers
max_diff = (old_out - new_out).abs().max()
total_diff = (old_out - new_out).abs().sum()
print(f"Transformer max_diff: {max_diff} total_diff: {total_diff}")
assert max_diff < 0.01
assert total_diff < 1500
return new_transformer
def make_vqvae(old_vae):
new_vae = VQModel(
act_fn="silu",
block_out_channels=[128, 256, 256, 512, 768],
down_block_types=[
"DownEncoderBlock2D",
"DownEncoderBlock2D",
"DownEncoderBlock2D",
"DownEncoderBlock2D",
"DownEncoderBlock2D",
],
in_channels=3,
latent_channels=64,
layers_per_block=2,
norm_num_groups=32,
num_vq_embeddings=8192,
out_channels=3,
sample_size=32,
up_block_types=[
"UpDecoderBlock2D",
"UpDecoderBlock2D",
"UpDecoderBlock2D",
"UpDecoderBlock2D",
"UpDecoderBlock2D",
],
mid_block_add_attention=False,
lookup_from_codebook=True,
)
new_vae.to(device)
# fmt: off
new_state_dict = {}
old_state_dict = old_vae.state_dict()
new_state_dict["encoder.conv_in.weight"] = old_state_dict.pop("encoder.conv_in.weight")
new_state_dict["encoder.conv_in.bias"] = old_state_dict.pop("encoder.conv_in.bias")
convert_vae_block_state_dict(old_state_dict, "encoder.down.0", new_state_dict, "encoder.down_blocks.0")
convert_vae_block_state_dict(old_state_dict, "encoder.down.1", new_state_dict, "encoder.down_blocks.1")
convert_vae_block_state_dict(old_state_dict, "encoder.down.2", new_state_dict, "encoder.down_blocks.2")
convert_vae_block_state_dict(old_state_dict, "encoder.down.3", new_state_dict, "encoder.down_blocks.3")
convert_vae_block_state_dict(old_state_dict, "encoder.down.4", new_state_dict, "encoder.down_blocks.4")
new_state_dict["encoder.mid_block.resnets.0.norm1.weight"] = old_state_dict.pop("encoder.mid.block_1.norm1.weight")
new_state_dict["encoder.mid_block.resnets.0.norm1.bias"] = old_state_dict.pop("encoder.mid.block_1.norm1.bias")
new_state_dict["encoder.mid_block.resnets.0.conv1.weight"] = old_state_dict.pop("encoder.mid.block_1.conv1.weight")
new_state_dict["encoder.mid_block.resnets.0.conv1.bias"] = old_state_dict.pop("encoder.mid.block_1.conv1.bias")
new_state_dict["encoder.mid_block.resnets.0.norm2.weight"] = old_state_dict.pop("encoder.mid.block_1.norm2.weight")
new_state_dict["encoder.mid_block.resnets.0.norm2.bias"] = old_state_dict.pop("encoder.mid.block_1.norm2.bias")
new_state_dict["encoder.mid_block.resnets.0.conv2.weight"] = old_state_dict.pop("encoder.mid.block_1.conv2.weight")
new_state_dict["encoder.mid_block.resnets.0.conv2.bias"] = old_state_dict.pop("encoder.mid.block_1.conv2.bias")
new_state_dict["encoder.mid_block.resnets.1.norm1.weight"] = old_state_dict.pop("encoder.mid.block_2.norm1.weight")
new_state_dict["encoder.mid_block.resnets.1.norm1.bias"] = old_state_dict.pop("encoder.mid.block_2.norm1.bias")
new_state_dict["encoder.mid_block.resnets.1.conv1.weight"] = old_state_dict.pop("encoder.mid.block_2.conv1.weight")
new_state_dict["encoder.mid_block.resnets.1.conv1.bias"] = old_state_dict.pop("encoder.mid.block_2.conv1.bias")
new_state_dict["encoder.mid_block.resnets.1.norm2.weight"] = old_state_dict.pop("encoder.mid.block_2.norm2.weight")
new_state_dict["encoder.mid_block.resnets.1.norm2.bias"] = old_state_dict.pop("encoder.mid.block_2.norm2.bias")
new_state_dict["encoder.mid_block.resnets.1.conv2.weight"] = old_state_dict.pop("encoder.mid.block_2.conv2.weight")
new_state_dict["encoder.mid_block.resnets.1.conv2.bias"] = old_state_dict.pop("encoder.mid.block_2.conv2.bias")
new_state_dict["encoder.conv_norm_out.weight"] = old_state_dict.pop("encoder.norm_out.weight")
new_state_dict["encoder.conv_norm_out.bias"] = old_state_dict.pop("encoder.norm_out.bias")
new_state_dict["encoder.conv_out.weight"] = old_state_dict.pop("encoder.conv_out.weight")
new_state_dict["encoder.conv_out.bias"] = old_state_dict.pop("encoder.conv_out.bias")
new_state_dict["quant_conv.weight"] = old_state_dict.pop("quant_conv.weight")
new_state_dict["quant_conv.bias"] = old_state_dict.pop("quant_conv.bias")
new_state_dict["quantize.embedding.weight"] = old_state_dict.pop("quantize.embedding.weight")
new_state_dict["post_quant_conv.weight"] = old_state_dict.pop("post_quant_conv.weight")
new_state_dict["post_quant_conv.bias"] = old_state_dict.pop("post_quant_conv.bias")
new_state_dict["decoder.conv_in.weight"] = old_state_dict.pop("decoder.conv_in.weight")
new_state_dict["decoder.conv_in.bias"] = old_state_dict.pop("decoder.conv_in.bias")
new_state_dict["decoder.mid_block.resnets.0.norm1.weight"] = old_state_dict.pop("decoder.mid.block_1.norm1.weight")
new_state_dict["decoder.mid_block.resnets.0.norm1.bias"] = old_state_dict.pop("decoder.mid.block_1.norm1.bias")
new_state_dict["decoder.mid_block.resnets.0.conv1.weight"] = old_state_dict.pop("decoder.mid.block_1.conv1.weight")
new_state_dict["decoder.mid_block.resnets.0.conv1.bias"] = old_state_dict.pop("decoder.mid.block_1.conv1.bias")
new_state_dict["decoder.mid_block.resnets.0.norm2.weight"] = old_state_dict.pop("decoder.mid.block_1.norm2.weight")
new_state_dict["decoder.mid_block.resnets.0.norm2.bias"] = old_state_dict.pop("decoder.mid.block_1.norm2.bias")
new_state_dict["decoder.mid_block.resnets.0.conv2.weight"] = old_state_dict.pop("decoder.mid.block_1.conv2.weight")
new_state_dict["decoder.mid_block.resnets.0.conv2.bias"] = old_state_dict.pop("decoder.mid.block_1.conv2.bias")
new_state_dict["decoder.mid_block.resnets.1.norm1.weight"] = old_state_dict.pop("decoder.mid.block_2.norm1.weight")
new_state_dict["decoder.mid_block.resnets.1.norm1.bias"] = old_state_dict.pop("decoder.mid.block_2.norm1.bias")
new_state_dict["decoder.mid_block.resnets.1.conv1.weight"] = old_state_dict.pop("decoder.mid.block_2.conv1.weight")
new_state_dict["decoder.mid_block.resnets.1.conv1.bias"] = old_state_dict.pop("decoder.mid.block_2.conv1.bias")
new_state_dict["decoder.mid_block.resnets.1.norm2.weight"] = old_state_dict.pop("decoder.mid.block_2.norm2.weight")
new_state_dict["decoder.mid_block.resnets.1.norm2.bias"] = old_state_dict.pop("decoder.mid.block_2.norm2.bias")
new_state_dict["decoder.mid_block.resnets.1.conv2.weight"] = old_state_dict.pop("decoder.mid.block_2.conv2.weight")
new_state_dict["decoder.mid_block.resnets.1.conv2.bias"] = old_state_dict.pop("decoder.mid.block_2.conv2.bias")
convert_vae_block_state_dict(old_state_dict, "decoder.up.0", new_state_dict, "decoder.up_blocks.4")
convert_vae_block_state_dict(old_state_dict, "decoder.up.1", new_state_dict, "decoder.up_blocks.3")
convert_vae_block_state_dict(old_state_dict, "decoder.up.2", new_state_dict, "decoder.up_blocks.2")
convert_vae_block_state_dict(old_state_dict, "decoder.up.3", new_state_dict, "decoder.up_blocks.1")
convert_vae_block_state_dict(old_state_dict, "decoder.up.4", new_state_dict, "decoder.up_blocks.0")
new_state_dict["decoder.conv_norm_out.weight"] = old_state_dict.pop("decoder.norm_out.weight")
new_state_dict["decoder.conv_norm_out.bias"] = old_state_dict.pop("decoder.norm_out.bias")
new_state_dict["decoder.conv_out.weight"] = old_state_dict.pop("decoder.conv_out.weight")
new_state_dict["decoder.conv_out.bias"] = old_state_dict.pop("decoder.conv_out.bias")
# fmt: on
assert len(old_state_dict.keys()) == 0
new_vae.load_state_dict(new_state_dict)
input = torch.randn((1, 3, 512, 512), device=device)
input = input.clamp(-1, 1)
old_encoder_output = old_vae.quant_conv(old_vae.encoder(input))
new_encoder_output = new_vae.quant_conv(new_vae.encoder(input))
assert (old_encoder_output == new_encoder_output).all()
old_decoder_output = old_vae.decoder(old_vae.post_quant_conv(old_encoder_output))
new_decoder_output = new_vae.decoder(new_vae.post_quant_conv(new_encoder_output))
# assert (old_decoder_output == new_decoder_output).all()
print("kipping vae decoder equivalence check")
print(f"vae decoder diff {(old_decoder_output - new_decoder_output).float().abs().sum()}")
old_output = old_vae(input)[0]
new_output = new_vae(input)[0]
# assert (old_output == new_output).all()
print("skipping full vae equivalence check")
print(f"vae full diff { (old_output - new_output).float().abs().sum()}")
return new_vae
def convert_vae_block_state_dict(old_state_dict, prefix_from, new_state_dict, prefix_to):
# fmt: off
new_state_dict[f"{prefix_to}.resnets.0.norm1.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.norm1.weight")
new_state_dict[f"{prefix_to}.resnets.0.norm1.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.norm1.bias")
new_state_dict[f"{prefix_to}.resnets.0.conv1.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.conv1.weight")
new_state_dict[f"{prefix_to}.resnets.0.conv1.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.conv1.bias")
new_state_dict[f"{prefix_to}.resnets.0.norm2.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.norm2.weight")
new_state_dict[f"{prefix_to}.resnets.0.norm2.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.norm2.bias")
new_state_dict[f"{prefix_to}.resnets.0.conv2.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.conv2.weight")
new_state_dict[f"{prefix_to}.resnets.0.conv2.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.conv2.bias")
if f"{prefix_from}.block.0.nin_shortcut.weight" in old_state_dict:
new_state_dict[f"{prefix_to}.resnets.0.conv_shortcut.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.nin_shortcut.weight")
new_state_dict[f"{prefix_to}.resnets.0.conv_shortcut.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.nin_shortcut.bias")
new_state_dict[f"{prefix_to}.resnets.1.norm1.weight"] = old_state_dict.pop(f"{prefix_from}.block.1.norm1.weight")
new_state_dict[f"{prefix_to}.resnets.1.norm1.bias"] = old_state_dict.pop(f"{prefix_from}.block.1.norm1.bias")
new_state_dict[f"{prefix_to}.resnets.1.conv1.weight"] = old_state_dict.pop(f"{prefix_from}.block.1.conv1.weight")
new_state_dict[f"{prefix_to}.resnets.1.conv1.bias"] = old_state_dict.pop(f"{prefix_from}.block.1.conv1.bias")
new_state_dict[f"{prefix_to}.resnets.1.norm2.weight"] = old_state_dict.pop(f"{prefix_from}.block.1.norm2.weight")
new_state_dict[f"{prefix_to}.resnets.1.norm2.bias"] = old_state_dict.pop(f"{prefix_from}.block.1.norm2.bias")
new_state_dict[f"{prefix_to}.resnets.1.conv2.weight"] = old_state_dict.pop(f"{prefix_from}.block.1.conv2.weight")
new_state_dict[f"{prefix_to}.resnets.1.conv2.bias"] = old_state_dict.pop(f"{prefix_from}.block.1.conv2.bias")
if f"{prefix_from}.downsample.conv.weight" in old_state_dict:
new_state_dict[f"{prefix_to}.downsamplers.0.conv.weight"] = old_state_dict.pop(f"{prefix_from}.downsample.conv.weight")
new_state_dict[f"{prefix_to}.downsamplers.0.conv.bias"] = old_state_dict.pop(f"{prefix_from}.downsample.conv.bias")
if f"{prefix_from}.upsample.conv.weight" in old_state_dict:
new_state_dict[f"{prefix_to}.upsamplers.0.conv.weight"] = old_state_dict.pop(f"{prefix_from}.upsample.conv.weight")
new_state_dict[f"{prefix_to}.upsamplers.0.conv.bias"] = old_state_dict.pop(f"{prefix_from}.upsample.conv.bias")
if f"{prefix_from}.block.2.norm1.weight" in old_state_dict:
new_state_dict[f"{prefix_to}.resnets.2.norm1.weight"] = old_state_dict.pop(f"{prefix_from}.block.2.norm1.weight")
new_state_dict[f"{prefix_to}.resnets.2.norm1.bias"] = old_state_dict.pop(f"{prefix_from}.block.2.norm1.bias")
new_state_dict[f"{prefix_to}.resnets.2.conv1.weight"] = old_state_dict.pop(f"{prefix_from}.block.2.conv1.weight")
new_state_dict[f"{prefix_to}.resnets.2.conv1.bias"] = old_state_dict.pop(f"{prefix_from}.block.2.conv1.bias")
new_state_dict[f"{prefix_to}.resnets.2.norm2.weight"] = old_state_dict.pop(f"{prefix_from}.block.2.norm2.weight")
new_state_dict[f"{prefix_to}.resnets.2.norm2.bias"] = old_state_dict.pop(f"{prefix_from}.block.2.norm2.bias")
new_state_dict[f"{prefix_to}.resnets.2.conv2.weight"] = old_state_dict.pop(f"{prefix_from}.block.2.conv2.weight")
new_state_dict[f"{prefix_to}.resnets.2.conv2.bias"] = old_state_dict.pop(f"{prefix_from}.block.2.conv2.bias")
# fmt: on
if __name__ == "__main__":
main()

View File

@@ -12,9 +12,9 @@ from safetensors.torch import load_file as stl
from tqdm import tqdm
from diffusers import AutoencoderKL, ConsistencyDecoderVAE, DiffusionPipeline, StableDiffusionPipeline, UNet2DModel
from diffusers.models.autoencoders.vae import Encoder
from diffusers.models.embeddings import TimestepEmbedding
from diffusers.models.unet_2d_blocks import ResnetDownsampleBlock2D, ResnetUpsampleBlock2D, UNetMidBlock2D
from diffusers.models.vae import Encoder
args = ArgumentParser()

View File

@@ -159,6 +159,14 @@ vae_conversion_map_attn = [
("proj_out.", "proj_attn."),
]
# This is probably not the most ideal solution, but it does work.
vae_extra_conversion_map = [
("to_q", "q"),
("to_k", "k"),
("to_v", "v"),
("to_out.0", "proj_out"),
]
def reshape_weight_for_sd(w):
# convert HF linear weights to SD conv2d weights
@@ -178,11 +186,20 @@ def convert_vae_state_dict(vae_state_dict):
mapping[k] = v
new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
weights_to_convert = ["q", "k", "v", "proj_out"]
keys_to_rename = {}
for k, v in new_state_dict.items():
for weight_name in weights_to_convert:
if f"mid.attn_1.{weight_name}.weight" in k:
print(f"Reshaping {k} for SD format")
new_state_dict[k] = reshape_weight_for_sd(v)
for weight_name, real_weight_name in vae_extra_conversion_map:
if f"mid.attn_1.{weight_name}.weight" in k or f"mid.attn_1.{weight_name}.bias" in k:
keys_to_rename[k] = k.replace(weight_name, real_weight_name)
for k, v in keys_to_rename.items():
if k in new_state_dict:
print(f"Renaming {k} to {v}")
new_state_dict[v] = reshape_weight_for_sd(new_state_dict[k])
del new_state_dict[k]
return new_state_dict

View File

@@ -95,6 +95,7 @@ else:
"UNet3DConditionModel",
"UNetMotionModel",
"UNetSpatioTemporalConditionModel",
"UVit2DModel",
"VQModel",
]
)
@@ -131,6 +132,7 @@ else:
)
_import_structure["schedulers"].extend(
[
"AmusedScheduler",
"CMStochasticIterativeScheduler",
"DDIMInverseScheduler",
"DDIMParallelScheduler",
@@ -202,6 +204,9 @@ else:
[
"AltDiffusionImg2ImgPipeline",
"AltDiffusionPipeline",
"AmusedImg2ImgPipeline",
"AmusedInpaintPipeline",
"AmusedPipeline",
"AnimateDiffPipeline",
"AudioLDM2Pipeline",
"AudioLDM2ProjectionModel",
@@ -472,6 +477,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
UNet3DConditionModel,
UNetMotionModel,
UNetSpatioTemporalConditionModel,
UVit2DModel,
VQModel,
)
from .optimization import (
@@ -506,6 +512,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
ScoreSdeVePipeline,
)
from .schedulers import (
AmusedScheduler,
CMStochasticIterativeScheduler,
DDIMInverseScheduler,
DDIMParallelScheduler,
@@ -560,6 +567,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .pipelines import (
AltDiffusionImg2ImgPipeline,
AltDiffusionPipeline,
AmusedImg2ImgPipeline,
AmusedInpaintPipeline,
AmusedPipeline,
AnimateDiffPipeline,
AudioLDM2Pipeline,
AudioLDM2ProjectionModel,

View File

@@ -18,6 +18,7 @@ from typing import Callable, Dict, List, Optional, Union
import safetensors
import torch
from huggingface_hub import model_info
from huggingface_hub.constants import HF_HUB_OFFLINE
from huggingface_hub.utils import validate_hf_hub_args
from packaging import version
from torch import nn
@@ -58,6 +59,7 @@ logger = logging.get_logger(__name__)
TEXT_ENCODER_NAME = "text_encoder"
UNET_NAME = "unet"
TRANSFORMER_NAME = "transformer"
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
@@ -73,6 +75,7 @@ class LoraLoaderMixin:
text_encoder_name = TEXT_ENCODER_NAME
unet_name = UNET_NAME
transformer_name = TRANSFORMER_NAME
num_fused_loras = 0
def load_lora_weights(
@@ -229,7 +232,9 @@ class LoraLoaderMixin:
# determine `weight_name`.
if weight_name is None:
weight_name = cls._best_guess_weight_name(
pretrained_model_name_or_path_or_dict, file_extension=".safetensors"
pretrained_model_name_or_path_or_dict,
file_extension=".safetensors",
local_files_only=local_files_only,
)
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
@@ -255,7 +260,7 @@ class LoraLoaderMixin:
if model_file is None:
if weight_name is None:
weight_name = cls._best_guess_weight_name(
pretrained_model_name_or_path_or_dict, file_extension=".bin"
pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only
)
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
@@ -294,7 +299,12 @@ class LoraLoaderMixin:
return state_dict, network_alphas
@classmethod
def _best_guess_weight_name(cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors"):
def _best_guess_weight_name(
cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
):
if local_files_only or HF_HUB_OFFLINE:
raise ValueError("When using the offline mode, you must specify a `weight_name`.")
targeted_files = []
if os.path.isfile(pretrained_model_name_or_path_or_dict):
@@ -653,6 +663,89 @@ class LoraLoaderMixin:
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />
@classmethod
def load_lora_into_transformer(
cls, state_dict, network_alphas, transformer, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None
):
"""
This will load the LoRA layers specified in `state_dict` into `transformer`.
Parameters:
state_dict (`dict`):
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
encoder lora layers.
network_alphas (`Dict[str, float]`):
See `LoRALinearLayer` for more details.
unet (`UNet2DConditionModel`):
The UNet model to load the LoRA layers into.
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
argument to `True` will raise an error.
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
"""
low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
keys = list(state_dict.keys())
transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
state_dict = {
k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys
}
if network_alphas is not None:
alpha_keys = [k for k in network_alphas.keys() if k.startswith(cls.transformer_name)]
network_alphas = {
k.replace(f"{cls.transformer_name}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
}
if len(state_dict.keys()) > 0:
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
if adapter_name in getattr(transformer, "peft_config", {}):
raise ValueError(
f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
)
rank = {}
for key, val in state_dict.items():
if "lora_B" in key:
rank[key] = val.shape[1]
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict)
lora_config = LoraConfig(**lora_config_kwargs)
# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(transformer)
# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
# otherwise loading LoRA weights will lead to an error
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
if incompatible_keys is not None:
# check only for unexpected keys
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
if unexpected_keys:
logger.warning(
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
f" {unexpected_keys}. "
)
# Offload back.
if is_model_cpu_offload:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />
@property
def lora_scale(self) -> float:
# property function that returns the lora scale which can be set at run time by the pipeline.
@@ -778,6 +871,7 @@ class LoraLoaderMixin:
save_directory: Union[str, os.PathLike],
unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
transformer_lora_layers: Dict[str, torch.nn.Module] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
@@ -812,8 +906,10 @@ class LoraLoaderMixin:
layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
return layers_state_dict
if not (unet_lora_layers or text_encoder_lora_layers):
raise ValueError("You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`.")
if not (unet_lora_layers or text_encoder_lora_layers or transformer_lora_layers):
raise ValueError(
"You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`, or `transformer_lora_layers`."
)
if unet_lora_layers:
state_dict.update(pack_weights(unet_lora_layers, "unet"))
@@ -821,6 +917,9 @@ class LoraLoaderMixin:
if text_encoder_lora_layers:
state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
if transformer_lora_layers:
state_dict.update(pack_weights(transformer_lora_layers, "transformer"))
# Save the model
cls.write_lora_layers(
state_dict=state_dict,

View File

@@ -169,10 +169,12 @@ class FromSingleFileMixin:
load_safety_checker = kwargs.pop("load_safety_checker", True)
prediction_type = kwargs.pop("prediction_type", None)
text_encoder = kwargs.pop("text_encoder", None)
text_encoder_2 = kwargs.pop("text_encoder_2", None)
vae = kwargs.pop("vae", None)
controlnet = kwargs.pop("controlnet", None)
adapter = kwargs.pop("adapter", None)
tokenizer = kwargs.pop("tokenizer", None)
tokenizer_2 = kwargs.pop("tokenizer_2", None)
torch_dtype = kwargs.pop("torch_dtype", None)
@@ -274,8 +276,10 @@ class FromSingleFileMixin:
load_safety_checker=load_safety_checker,
prediction_type=prediction_type,
text_encoder=text_encoder,
text_encoder_2=text_encoder_2,
vae=vae,
tokenizer=tokenizer,
tokenizer_2=tokenizer_2,
original_config_file=original_config_file,
config_files=config_files,
local_files_only=local_files_only,

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from collections import OrderedDict, defaultdict
from collections import defaultdict
from contextlib import nullcontext
from typing import Callable, Dict, List, Optional, Union
@@ -664,6 +664,80 @@ class UNet2DConditionLoadersMixin:
if hasattr(self, "peft_config"):
self.peft_config.pop(adapter_name, None)
def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict):
updated_state_dict = {}
image_projection = None
if "proj.weight" in state_dict:
# IP-Adapter
num_image_text_embeds = 4
clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
cross_attention_dim = state_dict["proj.weight"].shape[0] // 4
image_projection = ImageProjection(
cross_attention_dim=cross_attention_dim,
image_embed_dim=clip_embeddings_dim,
num_image_text_embeds=num_image_text_embeds,
)
for key, value in state_dict.items():
diffusers_name = key.replace("proj", "image_embeds")
updated_state_dict[diffusers_name] = value
elif "proj.3.weight" in state_dict:
# IP-Adapter Full
clip_embeddings_dim = state_dict["proj.0.weight"].shape[0]
cross_attention_dim = state_dict["proj.3.weight"].shape[0]
image_projection = MLPProjection(
cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim
)
for key, value in state_dict.items():
diffusers_name = key.replace("proj.0", "ff.net.0.proj")
diffusers_name = diffusers_name.replace("proj.2", "ff.net.2")
diffusers_name = diffusers_name.replace("proj.3", "norm")
updated_state_dict[diffusers_name] = value
else:
# IP-Adapter Plus
num_image_text_embeds = state_dict["latents"].shape[1]
embed_dims = state_dict["proj_in.weight"].shape[1]
output_dims = state_dict["proj_out.weight"].shape[0]
hidden_dims = state_dict["latents"].shape[2]
heads = state_dict["layers.0.0.to_q.weight"].shape[0] // 64
image_projection = Resampler(
embed_dims=embed_dims,
output_dims=output_dims,
hidden_dims=hidden_dims,
heads=heads,
num_queries=num_image_text_embeds,
)
for key, value in state_dict.items():
diffusers_name = key.replace("0.to", "2.to")
diffusers_name = diffusers_name.replace("1.0.weight", "3.0.weight")
diffusers_name = diffusers_name.replace("1.0.bias", "3.0.bias")
diffusers_name = diffusers_name.replace("1.1.weight", "3.1.net.0.proj.weight")
diffusers_name = diffusers_name.replace("1.3.weight", "3.1.net.2.weight")
if "norm1" in diffusers_name:
updated_state_dict[diffusers_name.replace("0.norm1", "0")] = value
elif "norm2" in diffusers_name:
updated_state_dict[diffusers_name.replace("0.norm2", "1")] = value
elif "to_kv" in diffusers_name:
v_chunk = value.chunk(2, dim=0)
updated_state_dict[diffusers_name.replace("to_kv", "to_k")] = v_chunk[0]
updated_state_dict[diffusers_name.replace("to_kv", "to_v")] = v_chunk[1]
elif "to_out" in diffusers_name:
updated_state_dict[diffusers_name.replace("to_out", "to_out.0")] = value
else:
updated_state_dict[diffusers_name] = value
image_projection.load_state_dict(updated_state_dict)
return image_projection
def _load_ip_adapter_weights(self, state_dict):
from ..models.attention_processor import (
AttnProcessor,
@@ -724,103 +798,8 @@ class UNet2DConditionLoadersMixin:
self.set_attn_processor(attn_procs)
# create image projection layers.
if "proj.weight" in state_dict["image_proj"]:
# IP-Adapter
clip_embeddings_dim = state_dict["image_proj"]["proj.weight"].shape[-1]
cross_attention_dim = state_dict["image_proj"]["proj.weight"].shape[0] // 4
image_projection = ImageProjection(
cross_attention_dim=cross_attention_dim,
image_embed_dim=clip_embeddings_dim,
num_image_text_embeds=num_image_text_embeds,
)
image_projection.to(dtype=self.dtype, device=self.device)
# load image projection layer weights
image_proj_state_dict = {}
image_proj_state_dict.update(
{
"image_embeds.weight": state_dict["image_proj"]["proj.weight"],
"image_embeds.bias": state_dict["image_proj"]["proj.bias"],
"norm.weight": state_dict["image_proj"]["norm.weight"],
"norm.bias": state_dict["image_proj"]["norm.bias"],
}
)
image_projection.load_state_dict(image_proj_state_dict)
del image_proj_state_dict
elif "proj.3.weight" in state_dict["image_proj"]:
clip_embeddings_dim = state_dict["image_proj"]["proj.0.weight"].shape[0]
cross_attention_dim = state_dict["image_proj"]["proj.3.weight"].shape[0]
image_projection = MLPProjection(
cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim
)
image_projection.to(dtype=self.dtype, device=self.device)
# load image projection layer weights
image_proj_state_dict = {}
image_proj_state_dict.update(
{
"ff.net.0.proj.weight": state_dict["image_proj"]["proj.0.weight"],
"ff.net.0.proj.bias": state_dict["image_proj"]["proj.0.bias"],
"ff.net.2.weight": state_dict["image_proj"]["proj.2.weight"],
"ff.net.2.bias": state_dict["image_proj"]["proj.2.bias"],
"norm.weight": state_dict["image_proj"]["proj.3.weight"],
"norm.bias": state_dict["image_proj"]["proj.3.bias"],
}
)
image_projection.load_state_dict(image_proj_state_dict)
del image_proj_state_dict
else:
# IP-Adapter Plus
embed_dims = state_dict["image_proj"]["proj_in.weight"].shape[1]
output_dims = state_dict["image_proj"]["proj_out.weight"].shape[0]
hidden_dims = state_dict["image_proj"]["latents"].shape[2]
heads = state_dict["image_proj"]["layers.0.0.to_q.weight"].shape[0] // 64
image_projection = Resampler(
embed_dims=embed_dims,
output_dims=output_dims,
hidden_dims=hidden_dims,
heads=heads,
num_queries=num_image_text_embeds,
)
image_proj_state_dict = state_dict["image_proj"]
new_sd = OrderedDict()
for k, v in image_proj_state_dict.items():
if "0.to" in k:
k = k.replace("0.to", "2.to")
elif "1.0.weight" in k:
k = k.replace("1.0.weight", "3.0.weight")
elif "1.0.bias" in k:
k = k.replace("1.0.bias", "3.0.bias")
elif "1.1.weight" in k:
k = k.replace("1.1.weight", "3.1.net.0.proj.weight")
elif "1.3.weight" in k:
k = k.replace("1.3.weight", "3.1.net.2.weight")
if "norm1" in k:
new_sd[k.replace("0.norm1", "0")] = v
elif "norm2" in k:
new_sd[k.replace("0.norm2", "1")] = v
elif "to_kv" in k:
v_chunk = v.chunk(2, dim=0)
new_sd[k.replace("to_kv", "to_k")] = v_chunk[0]
new_sd[k.replace("to_kv", "to_v")] = v_chunk[1]
elif "to_out" in k:
new_sd[k.replace("to_out", "to_out.0")] = v
else:
new_sd[k] = v
image_projection.load_state_dict(new_sd)
del image_proj_state_dict
# convert IP-Adapter Image Projection layers to diffusers
image_projection = self._convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"])
self.encoder_hid_proj = image_projection.to(device=self.device, dtype=self.dtype)
self.config.encoder_hid_dim_type = "ip_image_proj"
delete_adapter_layers

View File

@@ -26,11 +26,11 @@ _import_structure = {}
if is_torch_available():
_import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
_import_structure["autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
_import_structure["autoencoder_kl"] = ["AutoencoderKL"]
_import_structure["autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
_import_structure["autoencoder_tiny"] = ["AutoencoderTiny"]
_import_structure["consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
_import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
_import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
_import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
_import_structure["controlnet"] = ["ControlNetModel"]
_import_structure["controlnetxs"] = ["ControlNetXSModel"]
_import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
@@ -47,6 +47,7 @@ if is_torch_available():
_import_structure["unet_kandinsky3"] = ["Kandinsky3UNet"]
_import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
_import_structure["unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"]
_import_structure["uvit_2d"] = ["UVit2DModel"]
_import_structure["vq_model"] = ["VQModel"]
if is_flax_available():
@@ -58,11 +59,13 @@ if is_flax_available():
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if is_torch_available():
from .adapter import MultiAdapter, T2IAdapter
from .autoencoder_asym_kl import AsymmetricAutoencoderKL
from .autoencoder_kl import AutoencoderKL
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
from .autoencoder_tiny import AutoencoderTiny
from .consistency_decoder_vae import ConsistencyDecoderVAE
from .autoencoders import (
AsymmetricAutoencoderKL,
AutoencoderKL,
AutoencoderKLTemporalDecoder,
AutoencoderTiny,
ConsistencyDecoderVAE,
)
from .controlnet import ControlNetModel
from .controlnetxs import ControlNetXSModel
from .dual_transformer_2d import DualTransformer2DModel
@@ -79,6 +82,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .unet_kandinsky3 import Kandinsky3UNet
from .unet_motion_model import MotionAdapter, UNetMotionModel
from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
from .uvit_2d import UVit2DModel
from .vq_model import VQModel
if is_flax_available():

View File

@@ -14,6 +14,7 @@
from typing import Any, Dict, Optional
import torch
import torch.nn.functional as F
from torch import nn
from ..utils import USE_PEFT_BACKEND
@@ -22,7 +23,7 @@ from .activations import GEGLU, GELU, ApproximateGELU
from .attention_processor import Attention
from .embeddings import SinusoidalPositionalEmbedding
from .lora import LoRACompatibleLinear
from .normalization import AdaLayerNorm, AdaLayerNormZero
from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
def _chunked_feed_forward(
@@ -148,6 +149,11 @@ class BasicTransformerBlock(nn.Module):
attention_type: str = "default",
positional_embeddings: Optional[str] = None,
num_positional_embeddings: Optional[int] = None,
ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
ada_norm_bias: Optional[int] = None,
ff_inner_dim: Optional[int] = None,
ff_bias: bool = True,
attention_out_bias: bool = True,
):
super().__init__()
self.only_cross_attention = only_cross_attention
@@ -156,6 +162,7 @@ class BasicTransformerBlock(nn.Module):
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
self.use_layer_norm = norm_type == "layer_norm"
self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
raise ValueError(
@@ -179,6 +186,15 @@ class BasicTransformerBlock(nn.Module):
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
elif self.use_ada_layer_norm_zero:
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
elif self.use_ada_layer_norm_continuous:
self.norm1 = AdaLayerNormContinuous(
dim,
ada_norm_continous_conditioning_embedding_dim,
norm_elementwise_affine,
norm_eps,
ada_norm_bias,
"rms_norm",
)
else:
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
@@ -190,6 +206,7 @@ class BasicTransformerBlock(nn.Module):
bias=attention_bias,
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
upcast_attention=upcast_attention,
out_bias=attention_out_bias,
)
# 2. Cross-Attn
@@ -197,11 +214,20 @@ class BasicTransformerBlock(nn.Module):
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
# the second cross attention block.
self.norm2 = (
AdaLayerNorm(dim, num_embeds_ada_norm)
if self.use_ada_layer_norm
else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
)
if self.use_ada_layer_norm:
self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
elif self.use_ada_layer_norm_continuous:
self.norm2 = AdaLayerNormContinuous(
dim,
ada_norm_continous_conditioning_embedding_dim,
norm_elementwise_affine,
norm_eps,
ada_norm_bias,
"rms_norm",
)
else:
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
@@ -210,20 +236,32 @@ class BasicTransformerBlock(nn.Module):
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
out_bias=attention_out_bias,
) # is self-attn if encoder_hidden_states is none
else:
self.norm2 = None
self.attn2 = None
# 3. Feed-forward
if not self.use_ada_layer_norm_single:
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
if self.use_ada_layer_norm_continuous:
self.norm3 = AdaLayerNormContinuous(
dim,
ada_norm_continous_conditioning_embedding_dim,
norm_elementwise_affine,
norm_eps,
ada_norm_bias,
"layer_norm",
)
elif not self.use_ada_layer_norm_single:
self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
self.ff = FeedForward(
dim,
dropout=dropout,
activation_fn=activation_fn,
final_dropout=final_dropout,
inner_dim=ff_inner_dim,
bias=ff_bias,
)
# 4. Fuser
@@ -252,6 +290,7 @@ class BasicTransformerBlock(nn.Module):
timestep: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
class_labels: Optional[torch.LongTensor] = None,
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
) -> torch.FloatTensor:
# Notice that normalization is always applied before the real computation in the following blocks.
# 0. Self-Attention
@@ -265,6 +304,8 @@ class BasicTransformerBlock(nn.Module):
)
elif self.use_layer_norm:
norm_hidden_states = self.norm1(hidden_states)
elif self.use_ada_layer_norm_continuous:
norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
elif self.use_ada_layer_norm_single:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
@@ -314,6 +355,8 @@ class BasicTransformerBlock(nn.Module):
# For PixArt norm2 isn't applied here:
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
norm_hidden_states = hidden_states
elif self.use_ada_layer_norm_continuous:
norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
else:
raise ValueError("Incorrect norm")
@@ -329,7 +372,9 @@ class BasicTransformerBlock(nn.Module):
hidden_states = attn_output + hidden_states
# 4. Feed-forward
if not self.use_ada_layer_norm_single:
if self.use_ada_layer_norm_continuous:
norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
elif not self.use_ada_layer_norm_single:
norm_hidden_states = self.norm3(hidden_states)
if self.use_ada_layer_norm_zero:
@@ -490,6 +535,78 @@ class TemporalBasicTransformerBlock(nn.Module):
return hidden_states
class SkipFFTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
kv_input_dim: int,
kv_input_dim_proj_use_bias: bool,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
attention_bias: bool = False,
attention_out_bias: bool = True,
):
super().__init__()
if kv_input_dim != dim:
self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias)
else:
self.kv_mapper = None
self.norm1 = RMSNorm(dim, 1e-06)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim,
out_bias=attention_out_bias,
)
self.norm2 = RMSNorm(dim, 1e-06)
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
out_bias=attention_out_bias,
)
def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs):
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
if self.kv_mapper is not None:
encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states))
norm_hidden_states = self.norm1(hidden_states)
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
**cross_attention_kwargs,
)
hidden_states = attn_output + hidden_states
norm_hidden_states = self.norm2(hidden_states)
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
**cross_attention_kwargs,
)
hidden_states = attn_output + hidden_states
return hidden_states
class FeedForward(nn.Module):
r"""
A feed-forward layer.
@@ -512,10 +629,12 @@ class FeedForward(nn.Module):
dropout: float = 0.0,
activation_fn: str = "geglu",
final_dropout: bool = False,
inner_dim=None,
bias: bool = True,
):
super().__init__()
inner_dim = int(dim * mult)
if inner_dim is None:
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear

View File

@@ -0,0 +1,5 @@
from .autoencoder_asym_kl import AsymmetricAutoencoderKL
from .autoencoder_kl import AutoencoderKL
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
from .autoencoder_tiny import AutoencoderTiny
from .consistency_decoder_vae import ConsistencyDecoderVAE

View File

@@ -16,10 +16,10 @@ from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils.accelerate_utils import apply_forward_hook
from .modeling_outputs import AutoencoderKLOutput
from .modeling_utils import ModelMixin
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils.accelerate_utils import apply_forward_hook
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder

View File

@@ -16,10 +16,10 @@ from typing import Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import FromOriginalVAEMixin
from ..utils.accelerate_utils import apply_forward_hook
from .attention_processor import (
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalVAEMixin
from ...utils.accelerate_utils import apply_forward_hook
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
Attention,
@@ -27,8 +27,8 @@ from .attention_processor import (
AttnAddedKVProcessor,
AttnProcessor,
)
from .modeling_outputs import AutoencoderKLOutput
from .modeling_utils import ModelMixin
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder

View File

@@ -16,14 +16,14 @@ from typing import Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import FromOriginalVAEMixin
from ..utils import is_torch_version
from ..utils.accelerate_utils import apply_forward_hook
from .attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
from .modeling_outputs import AutoencoderKLOutput
from .modeling_utils import ModelMixin
from .unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalVAEMixin
from ...utils import is_torch_version
from ...utils.accelerate_utils import apply_forward_hook
from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from ..unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder

View File

@@ -18,10 +18,10 @@ from typing import Optional, Tuple, Union
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from ..utils.accelerate_utils import apply_forward_hook
from .modeling_utils import ModelMixin
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import BaseOutput
from ...utils.accelerate_utils import apply_forward_hook
from ..modeling_utils import ModelMixin
from .vae import DecoderOutput, DecoderTiny, EncoderTiny

View File

@@ -18,20 +18,20 @@ import torch
import torch.nn.functional as F
from torch import nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..schedulers import ConsistencyDecoderScheduler
from ..utils import BaseOutput
from ..utils.accelerate_utils import apply_forward_hook
from ..utils.torch_utils import randn_tensor
from .attention_processor import (
from ...configuration_utils import ConfigMixin, register_to_config
from ...schedulers import ConsistencyDecoderScheduler
from ...utils import BaseOutput
from ...utils.accelerate_utils import apply_forward_hook
from ...utils.torch_utils import randn_tensor
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
)
from .modeling_utils import ModelMixin
from .unet_2d import UNet2DModel
from ..modeling_utils import ModelMixin
from ..unet_2d import UNet2DModel
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder
@@ -153,7 +153,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
self.use_slicing = False
self.use_tiling = False
# Copied from diffusers.models.autoencoder_kl.AutoencoderKL.enable_tiling
# Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.enable_tiling
def enable_tiling(self, use_tiling: bool = True):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
@@ -162,7 +162,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
"""
self.use_tiling = use_tiling
# Copied from diffusers.models.autoencoder_kl.AutoencoderKL.disable_tiling
# Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.disable_tiling
def disable_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
@@ -170,7 +170,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
"""
self.enable_tiling(False)
# Copied from diffusers.models.autoencoder_kl.AutoencoderKL.enable_slicing
# Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.enable_slicing
def enable_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
@@ -178,7 +178,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
"""
self.use_slicing = True
# Copied from diffusers.models.autoencoder_kl.AutoencoderKL.disable_slicing
# Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.disable_slicing
def disable_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
@@ -333,14 +333,14 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
return DecoderOutput(sample=x_0)
# Copied from diffusers.models.autoencoder_kl.AutoencoderKL.blend_v
# Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.blend_v
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[2], b.shape[2], blend_extent)
for y in range(blend_extent):
b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
return b
# Copied from diffusers.models.autoencoder_kl.AutoencoderKL.blend_h
# Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.blend_h
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
for x in range(blend_extent):

View File

@@ -18,11 +18,11 @@ import numpy as np
import torch
import torch.nn as nn
from ..utils import BaseOutput, is_torch_version
from ..utils.torch_utils import randn_tensor
from .activations import get_activation
from .attention_processor import SpatialNorm
from .unet_2d_blocks import (
from ...utils import BaseOutput, is_torch_version
from ...utils.torch_utils import randn_tensor
from ..activations import get_activation
from ..attention_processor import SpatialNorm
from ..unet_2d_blocks import (
AutoencoderTinyBlock,
UNetMidBlock2D,
get_down_block,
@@ -77,6 +77,7 @@ class Encoder(nn.Module):
norm_num_groups: int = 32,
act_fn: str = "silu",
double_z: bool = True,
mid_block_add_attention=True,
):
super().__init__()
self.layers_per_block = layers_per_block
@@ -124,6 +125,7 @@ class Encoder(nn.Module):
attention_head_dim=block_out_channels[-1],
resnet_groups=norm_num_groups,
temb_channels=None,
add_attention=mid_block_add_attention,
)
# out
@@ -213,6 +215,7 @@ class Decoder(nn.Module):
norm_num_groups: int = 32,
act_fn: str = "silu",
norm_type: str = "group", # group, spatial
mid_block_add_attention=True,
):
super().__init__()
self.layers_per_block = layers_per_block
@@ -240,6 +243,7 @@ class Decoder(nn.Module):
attention_head_dim=block_out_channels[-1],
resnet_groups=norm_num_groups,
temb_channels=temb_channels,
add_attention=mid_block_add_attention,
)
# up

View File

@@ -23,10 +23,8 @@ from torch.nn.modules.normalization import GroupNorm
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, logging
from .attention_processor import (
AttentionProcessor,
)
from .autoencoder_kl import AutoencoderKL
from .attention_processor import USE_PEFT_BACKEND, AttentionProcessor
from .autoencoders import AutoencoderKL
from .lora import LoRACompatibleConv
from .modeling_utils import ModelMixin
from .unet_2d_blocks import (
@@ -817,11 +815,23 @@ def increase_block_input_in_encoder_resnet(unet: UNet2DConditionModel, block_no,
norm_kwargs = {a: getattr(old_norm1, a) for a in norm_args}
norm_kwargs["num_channels"] += by # surgery done here
# conv1
conv1_args = (
"in_channels out_channels kernel_size stride padding dilation groups bias padding_mode lora_layer".split(" ")
)
conv1_args = [
"in_channels",
"out_channels",
"kernel_size",
"stride",
"padding",
"dilation",
"groups",
"bias",
"padding_mode",
]
if not USE_PEFT_BACKEND:
conv1_args.append("lora_layer")
for a in conv1_args:
assert hasattr(old_conv1, a)
conv1_kwargs = {a: getattr(old_conv1, a) for a in conv1_args}
conv1_kwargs["bias"] = "bias" in conv1_kwargs # as param, bias is a boolean, but as attr, it's a tensor.
conv1_kwargs["in_channels"] += by # surgery done here
@@ -839,25 +849,42 @@ def increase_block_input_in_encoder_resnet(unet: UNet2DConditionModel, block_no,
}
# swap old with new modules
unet.down_blocks[block_no].resnets[resnet_idx].norm1 = GroupNorm(**norm_kwargs)
unet.down_blocks[block_no].resnets[resnet_idx].conv1 = LoRACompatibleConv(**conv1_kwargs)
unet.down_blocks[block_no].resnets[resnet_idx].conv_shortcut = LoRACompatibleConv(**conv_shortcut_args_kwargs)
unet.down_blocks[block_no].resnets[resnet_idx].conv1 = (
nn.Conv2d(**conv1_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv1_kwargs)
)
unet.down_blocks[block_no].resnets[resnet_idx].conv_shortcut = (
nn.Conv2d(**conv_shortcut_args_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv_shortcut_args_kwargs)
)
unet.down_blocks[block_no].resnets[resnet_idx].in_channels += by # surgery done here
def increase_block_input_in_encoder_downsampler(unet: UNet2DConditionModel, block_no, by):
"""Increase channels sizes to allow for additional concatted information from base model"""
old_down = unet.down_blocks[block_no].downsamplers[0].conv
# conv1
args = "in_channels out_channels kernel_size stride padding dilation groups bias padding_mode lora_layer".split(
" "
)
args = [
"in_channels",
"out_channels",
"kernel_size",
"stride",
"padding",
"dilation",
"groups",
"bias",
"padding_mode",
]
if not USE_PEFT_BACKEND:
args.append("lora_layer")
for a in args:
assert hasattr(old_down, a)
kwargs = {a: getattr(old_down, a) for a in args}
kwargs["bias"] = "bias" in kwargs # as param, bias is a boolean, but as attr, it's a tensor.
kwargs["in_channels"] += by # surgery done here
# swap old with new modules
unet.down_blocks[block_no].downsamplers[0].conv = LoRACompatibleConv(**kwargs)
unet.down_blocks[block_no].downsamplers[0].conv = (
nn.Conv2d(**kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**kwargs)
)
unet.down_blocks[block_no].downsamplers[0].channels += by # surgery done here
@@ -871,12 +898,20 @@ def increase_block_input_in_mid_resnet(unet: UNet2DConditionModel, by):
assert hasattr(old_norm1, a)
norm_kwargs = {a: getattr(old_norm1, a) for a in norm_args}
norm_kwargs["num_channels"] += by # surgery done here
# conv1
conv1_args = (
"in_channels out_channels kernel_size stride padding dilation groups bias padding_mode lora_layer".split(" ")
)
for a in conv1_args:
assert hasattr(old_conv1, a)
conv1_args = [
"in_channels",
"out_channels",
"kernel_size",
"stride",
"padding",
"dilation",
"groups",
"bias",
"padding_mode",
]
if not USE_PEFT_BACKEND:
conv1_args.append("lora_layer")
conv1_kwargs = {a: getattr(old_conv1, a) for a in conv1_args}
conv1_kwargs["bias"] = "bias" in conv1_kwargs # as param, bias is a boolean, but as attr, it's a tensor.
conv1_kwargs["in_channels"] += by # surgery done here
@@ -894,8 +929,12 @@ def increase_block_input_in_mid_resnet(unet: UNet2DConditionModel, by):
}
# swap old with new modules
unet.mid_block.resnets[0].norm1 = GroupNorm(**norm_kwargs)
unet.mid_block.resnets[0].conv1 = LoRACompatibleConv(**conv1_kwargs)
unet.mid_block.resnets[0].conv_shortcut = LoRACompatibleConv(**conv_shortcut_args_kwargs)
unet.mid_block.resnets[0].conv1 = (
nn.Conv2d(**conv1_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv1_kwargs)
)
unet.mid_block.resnets[0].conv_shortcut = (
nn.Conv2d(**conv_shortcut_args_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv_shortcut_args_kwargs)
)
unet.mid_block.resnets[0].in_channels += by # surgery done here

View File

@@ -0,0 +1,338 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..utils import USE_PEFT_BACKEND
from .lora import LoRACompatibleConv
from .normalization import RMSNorm
from .upsampling import upfirdn2d_native
class Downsample1D(nn.Module):
"""A 1D downsampling layer with an optional convolution.
Parameters:
channels (`int`):
number of channels in the inputs and outputs.
use_conv (`bool`, default `False`):
option to use a convolution.
out_channels (`int`, optional):
number of output channels. Defaults to `channels`.
padding (`int`, default `1`):
padding for the convolution.
name (`str`, default `conv`):
name of the downsampling 1D layer.
"""
def __init__(
self,
channels: int,
use_conv: bool = False,
out_channels: Optional[int] = None,
padding: int = 1,
name: str = "conv",
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.padding = padding
stride = 2
self.name = name
if use_conv:
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
else:
assert self.channels == self.out_channels
self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
assert inputs.shape[1] == self.channels
return self.conv(inputs)
class Downsample2D(nn.Module):
"""A 2D downsampling layer with an optional convolution.
Parameters:
channels (`int`):
number of channels in the inputs and outputs.
use_conv (`bool`, default `False`):
option to use a convolution.
out_channels (`int`, optional):
number of output channels. Defaults to `channels`.
padding (`int`, default `1`):
padding for the convolution.
name (`str`, default `conv`):
name of the downsampling 2D layer.
"""
def __init__(
self,
channels: int,
use_conv: bool = False,
out_channels: Optional[int] = None,
padding: int = 1,
name: str = "conv",
kernel_size=3,
norm_type=None,
eps=None,
elementwise_affine=None,
bias=True,
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.padding = padding
stride = 2
self.name = name
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
if norm_type == "ln_norm":
self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
elif norm_type == "rms_norm":
self.norm = RMSNorm(channels, eps, elementwise_affine)
elif norm_type is None:
self.norm = None
else:
raise ValueError(f"unknown norm_type: {norm_type}")
if use_conv:
conv = conv_cls(
self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
)
else:
assert self.channels == self.out_channels
conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
if name == "conv":
self.Conv2d_0 = conv
self.conv = conv
elif name == "Conv2d_0":
self.conv = conv
else:
self.conv = conv
def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
assert hidden_states.shape[1] == self.channels
if self.norm is not None:
hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
if self.use_conv and self.padding == 0:
pad = (0, 1, 0, 1)
hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
assert hidden_states.shape[1] == self.channels
if not USE_PEFT_BACKEND:
if isinstance(self.conv, LoRACompatibleConv):
hidden_states = self.conv(hidden_states, scale)
else:
hidden_states = self.conv(hidden_states)
else:
hidden_states = self.conv(hidden_states)
return hidden_states
class FirDownsample2D(nn.Module):
"""A 2D FIR downsampling layer with an optional convolution.
Parameters:
channels (`int`):
number of channels in the inputs and outputs.
use_conv (`bool`, default `False`):
option to use a convolution.
out_channels (`int`, optional):
number of output channels. Defaults to `channels`.
fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
kernel for the FIR filter.
"""
def __init__(
self,
channels: Optional[int] = None,
out_channels: Optional[int] = None,
use_conv: bool = False,
fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
):
super().__init__()
out_channels = out_channels if out_channels else channels
if use_conv:
self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
self.fir_kernel = fir_kernel
self.use_conv = use_conv
self.out_channels = out_channels
def _downsample_2d(
self,
hidden_states: torch.FloatTensor,
weight: Optional[torch.FloatTensor] = None,
kernel: Optional[torch.FloatTensor] = None,
factor: int = 2,
gain: float = 1,
) -> torch.FloatTensor:
"""Fused `Conv2d()` followed by `downsample_2d()`.
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
arbitrary order.
Args:
hidden_states (`torch.FloatTensor`):
Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
weight (`torch.FloatTensor`, *optional*):
Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
performed by `inChannels = x.shape[0] // numGroups`.
kernel (`torch.FloatTensor`, *optional*):
FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
corresponds to average pooling.
factor (`int`, *optional*, default to `2`):
Integer downsampling factor.
gain (`float`, *optional*, default to `1.0`):
Scaling factor for signal magnitude.
Returns:
output (`torch.FloatTensor`):
Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same
datatype as `x`.
"""
assert isinstance(factor, int) and factor >= 1
if kernel is None:
kernel = [1] * factor
# setup kernel
kernel = torch.tensor(kernel, dtype=torch.float32)
if kernel.ndim == 1:
kernel = torch.outer(kernel, kernel)
kernel /= torch.sum(kernel)
kernel = kernel * gain
if self.use_conv:
_, _, convH, convW = weight.shape
pad_value = (kernel.shape[0] - factor) + (convW - 1)
stride_value = [factor, factor]
upfirdn_input = upfirdn2d_native(
hidden_states,
torch.tensor(kernel, device=hidden_states.device),
pad=((pad_value + 1) // 2, pad_value // 2),
)
output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
else:
pad_value = kernel.shape[0] - factor
output = upfirdn2d_native(
hidden_states,
torch.tensor(kernel, device=hidden_states.device),
down=factor,
pad=((pad_value + 1) // 2, pad_value // 2),
)
return output
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
if self.use_conv:
downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
else:
hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
return hidden_states
# downsample/upsample layer used in k-upscaler, might be able to use FirDownsample2D/DirUpsample2D instead
class KDownsample2D(nn.Module):
r"""A 2D K-downsampling layer.
Parameters:
pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use.
"""
def __init__(self, pad_mode: str = "reflect"):
super().__init__()
self.pad_mode = pad_mode
kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]])
self.pad = kernel_1d.shape[1] // 2 - 1
self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode)
weight = inputs.new_zeros(
[
inputs.shape[1],
inputs.shape[1],
self.kernel.shape[0],
self.kernel.shape[1],
]
)
indices = torch.arange(inputs.shape[1], device=inputs.device)
kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
weight[indices, indices] = kernel
return F.conv2d(inputs, weight, stride=2)
def downsample_2d(
hidden_states: torch.FloatTensor,
kernel: Optional[torch.FloatTensor] = None,
factor: int = 2,
gain: float = 1,
) -> torch.FloatTensor:
r"""Downsample2D a batch of 2D images with the given filter.
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
shape is a multiple of the downsampling factor.
Args:
hidden_states (`torch.FloatTensor`)
Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
kernel (`torch.FloatTensor`, *optional*):
FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
corresponds to average pooling.
factor (`int`, *optional*, default to `2`):
Integer downsampling factor.
gain (`float`, *optional*, default to `1.0`):
Scaling factor for signal magnitude.
Returns:
output (`torch.FloatTensor`):
Tensor of the shape `[N, C, H // factor, W // factor]`
"""
assert isinstance(factor, int) and factor >= 1
if kernel is None:
kernel = [1] * factor
kernel = torch.tensor(kernel, dtype=torch.float32)
if kernel.ndim == 1:
kernel = torch.outer(kernel, kernel)
kernel /= torch.sum(kernel)
kernel = kernel * gain
pad_value = kernel.shape[0] - factor
output = upfirdn2d_native(
hidden_states,
kernel.to(device=hidden_states.device),
down=factor,
pad=((pad_value + 1) // 2, pad_value // 2),
)
return output

View File

@@ -197,11 +197,12 @@ class TimestepEmbedding(nn.Module):
out_dim: int = None,
post_act_fn: Optional[str] = None,
cond_proj_dim=None,
sample_proj_bias=True,
):
super().__init__()
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
self.linear_1 = linear_cls(in_channels, time_embed_dim)
self.linear_1 = linear_cls(in_channels, time_embed_dim, sample_proj_bias)
if cond_proj_dim is not None:
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
@@ -214,7 +215,7 @@ class TimestepEmbedding(nn.Module):
time_embed_dim_out = out_dim
else:
time_embed_dim_out = time_embed_dim
self.linear_2 = linear_cls(time_embed_dim, time_embed_dim_out)
self.linear_2 = linear_cls(time_embed_dim, time_embed_dim_out, sample_proj_bias)
if post_act_fn is None:
self.post_act = None
@@ -729,7 +730,7 @@ class PositionNet(nn.Module):
return objs
class CombinedTimestepSizeEmbeddings(nn.Module):
class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
"""
For PixArt-Alpha.
@@ -746,45 +747,27 @@ class CombinedTimestepSizeEmbeddings(nn.Module):
self.use_additional_conditions = use_additional_conditions
if use_additional_conditions:
self.use_additional_conditions = True
self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
def apply_condition(self, size: torch.Tensor, batch_size: int, embedder: nn.Module):
if size.ndim == 1:
size = size[:, None]
if size.shape[0] != batch_size:
size = size.repeat(batch_size // size.shape[0], 1)
if size.shape[0] != batch_size:
raise ValueError(f"`batch_size` should be {size.shape[0]} but found {batch_size}.")
current_batch_size, dims = size.shape[0], size.shape[1]
size = size.reshape(-1)
size_freq = self.additional_condition_proj(size).to(size.dtype)
size_emb = embedder(size_freq)
size_emb = size_emb.reshape(current_batch_size, dims * self.outdim)
return size_emb
def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
if self.use_additional_conditions:
resolution = self.apply_condition(resolution, batch_size=batch_size, embedder=self.resolution_embedder)
aspect_ratio = self.apply_condition(
aspect_ratio, batch_size=batch_size, embedder=self.aspect_ratio_embedder
)
conditioning = timesteps_emb + torch.cat([resolution, aspect_ratio], dim=1)
resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype)
resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1)
aspect_ratio_emb = self.additional_condition_proj(aspect_ratio.flatten()).to(hidden_dtype)
aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio_emb).reshape(batch_size, -1)
conditioning = timesteps_emb + torch.cat([resolution_emb, aspect_ratio_emb], dim=1)
else:
conditioning = timesteps_emb
return conditioning
class CaptionProjection(nn.Module):
class PixArtAlphaTextProjection(nn.Module):
"""
Projects caption embeddings. Also handles dropout for classifier-free guidance.
@@ -796,9 +779,8 @@ class CaptionProjection(nn.Module):
self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
self.act_1 = nn.GELU(approximate="tanh")
self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True)
self.register_buffer("y_embedding", nn.Parameter(torch.randn(num_tokens, in_features) / in_features**0.5))
def forward(self, caption, force_drop_ids=None):
def forward(self, caption):
hidden_states = self.linear_1(caption)
hidden_states = self.act_1(hidden_states)
hidden_states = self.linear_2(hidden_states)

View File

@@ -13,14 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numbers
from typing import Dict, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..utils import is_torch_version
from .activations import get_activation
from .embeddings import CombinedTimestepLabelEmbeddings, CombinedTimestepSizeEmbeddings
from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings
class AdaLayerNorm(nn.Module):
@@ -91,7 +93,7 @@ class AdaLayerNormSingle(nn.Module):
def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
super().__init__()
self.emb = CombinedTimestepSizeEmbeddings(
self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
)
@@ -146,3 +148,107 @@ class AdaGroupNorm(nn.Module):
x = F.group_norm(x, self.num_groups, eps=self.eps)
x = x * (1 + scale) + shift
return x
class AdaLayerNormContinuous(nn.Module):
def __init__(
self,
embedding_dim: int,
conditioning_embedding_dim: int,
# NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
# because the output is immediately scaled and shifted by the projected conditioning embeddings.
# Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
# However, this is how it was implemented in the original code, and it's rather likely you should
# set `elementwise_affine` to False.
elementwise_affine=True,
eps=1e-5,
bias=True,
norm_type="layer_norm",
):
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
if norm_type == "layer_norm":
self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
elif norm_type == "rms_norm":
self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
else:
raise ValueError(f"unknown norm_type {norm_type}")
def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
emb = self.linear(self.silu(conditioning_embedding))
scale, shift = torch.chunk(emb, 2, dim=1)
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
return x
if is_torch_version(">=", "2.1.0"):
LayerNorm = nn.LayerNorm
else:
# Has optional bias parameter compared to torch layer norm
# TODO: replace with torch layernorm once min required torch version >= 2.1
class LayerNorm(nn.Module):
def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True):
super().__init__()
self.eps = eps
if isinstance(dim, numbers.Integral):
dim = (dim,)
self.dim = torch.Size(dim)
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(dim))
self.bias = nn.Parameter(torch.zeros(dim)) if bias else None
else:
self.weight = None
self.bias = None
def forward(self, input):
return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps)
class RMSNorm(nn.Module):
def __init__(self, dim, eps: float, elementwise_affine: bool = True):
super().__init__()
self.eps = eps
if isinstance(dim, numbers.Integral):
dim = (dim,)
self.dim = torch.Size(dim)
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(dim))
else:
self.weight = None
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
if self.weight is not None:
# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
hidden_states = hidden_states * self.weight
else:
hidden_states = hidden_states.to(input_dtype)
return hidden_states
class GlobalResponseNorm(nn.Module):
# Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105
def __init__(self, dim):
super().__init__()
self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
def forward(self, x):
gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
nx = gx / (gx.mean(dim=-1, keepdim=True) + 1e-6)
return self.gamma * (x * nx) + self.beta + x

View File

@@ -23,562 +23,23 @@ import torch.nn.functional as F
from ..utils import USE_PEFT_BACKEND
from .activations import get_activation
from .attention_processor import SpatialNorm
from .downsampling import ( # noqa
Downsample1D,
Downsample2D,
FirDownsample2D,
KDownsample2D,
downsample_2d,
)
from .lora import LoRACompatibleConv, LoRACompatibleLinear
from .normalization import AdaGroupNorm
class Upsample1D(nn.Module):
"""A 1D upsampling layer with an optional convolution.
Parameters:
channels (`int`):
number of channels in the inputs and outputs.
use_conv (`bool`, default `False`):
option to use a convolution.
use_conv_transpose (`bool`, default `False`):
option to use a convolution transpose.
out_channels (`int`, optional):
number of output channels. Defaults to `channels`.
name (`str`, default `conv`):
name of the upsampling 1D layer.
"""
def __init__(
self,
channels: int,
use_conv: bool = False,
use_conv_transpose: bool = False,
out_channels: Optional[int] = None,
name: str = "conv",
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_conv_transpose = use_conv_transpose
self.name = name
self.conv = None
if use_conv_transpose:
self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
elif use_conv:
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
assert inputs.shape[1] == self.channels
if self.use_conv_transpose:
return self.conv(inputs)
outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
if self.use_conv:
outputs = self.conv(outputs)
return outputs
class Downsample1D(nn.Module):
"""A 1D downsampling layer with an optional convolution.
Parameters:
channels (`int`):
number of channels in the inputs and outputs.
use_conv (`bool`, default `False`):
option to use a convolution.
out_channels (`int`, optional):
number of output channels. Defaults to `channels`.
padding (`int`, default `1`):
padding for the convolution.
name (`str`, default `conv`):
name of the downsampling 1D layer.
"""
def __init__(
self,
channels: int,
use_conv: bool = False,
out_channels: Optional[int] = None,
padding: int = 1,
name: str = "conv",
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.padding = padding
stride = 2
self.name = name
if use_conv:
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
else:
assert self.channels == self.out_channels
self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
assert inputs.shape[1] == self.channels
return self.conv(inputs)
class Upsample2D(nn.Module):
"""A 2D upsampling layer with an optional convolution.
Parameters:
channels (`int`):
number of channels in the inputs and outputs.
use_conv (`bool`, default `False`):
option to use a convolution.
use_conv_transpose (`bool`, default `False`):
option to use a convolution transpose.
out_channels (`int`, optional):
number of output channels. Defaults to `channels`.
name (`str`, default `conv`):
name of the upsampling 2D layer.
"""
def __init__(
self,
channels: int,
use_conv: bool = False,
use_conv_transpose: bool = False,
out_channels: Optional[int] = None,
name: str = "conv",
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_conv_transpose = use_conv_transpose
self.name = name
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
conv = None
if use_conv_transpose:
conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
elif use_conv:
conv = conv_cls(self.channels, self.out_channels, 3, padding=1)
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
if name == "conv":
self.conv = conv
else:
self.Conv2d_0 = conv
def forward(
self,
hidden_states: torch.FloatTensor,
output_size: Optional[int] = None,
scale: float = 1.0,
) -> torch.FloatTensor:
assert hidden_states.shape[1] == self.channels
if self.use_conv_transpose:
return self.conv(hidden_states)
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
# TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
# https://github.com/pytorch/pytorch/issues/86679
dtype = hidden_states.dtype
if dtype == torch.bfloat16:
hidden_states = hidden_states.to(torch.float32)
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
if hidden_states.shape[0] >= 64:
hidden_states = hidden_states.contiguous()
# if `output_size` is passed we force the interpolation output
# size and do not make use of `scale_factor=2`
if output_size is None:
hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
else:
hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
# If the input is bfloat16, we cast back to bfloat16
if dtype == torch.bfloat16:
hidden_states = hidden_states.to(dtype)
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
if self.use_conv:
if self.name == "conv":
if isinstance(self.conv, LoRACompatibleConv) and not USE_PEFT_BACKEND:
hidden_states = self.conv(hidden_states, scale)
else:
hidden_states = self.conv(hidden_states)
else:
if isinstance(self.Conv2d_0, LoRACompatibleConv) and not USE_PEFT_BACKEND:
hidden_states = self.Conv2d_0(hidden_states, scale)
else:
hidden_states = self.Conv2d_0(hidden_states)
return hidden_states
class Downsample2D(nn.Module):
"""A 2D downsampling layer with an optional convolution.
Parameters:
channels (`int`):
number of channels in the inputs and outputs.
use_conv (`bool`, default `False`):
option to use a convolution.
out_channels (`int`, optional):
number of output channels. Defaults to `channels`.
padding (`int`, default `1`):
padding for the convolution.
name (`str`, default `conv`):
name of the downsampling 2D layer.
"""
def __init__(
self,
channels: int,
use_conv: bool = False,
out_channels: Optional[int] = None,
padding: int = 1,
name: str = "conv",
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.padding = padding
stride = 2
self.name = name
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
if use_conv:
conv = conv_cls(self.channels, self.out_channels, 3, stride=stride, padding=padding)
else:
assert self.channels == self.out_channels
conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
if name == "conv":
self.Conv2d_0 = conv
self.conv = conv
elif name == "Conv2d_0":
self.conv = conv
else:
self.conv = conv
def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
assert hidden_states.shape[1] == self.channels
if self.use_conv and self.padding == 0:
pad = (0, 1, 0, 1)
hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
assert hidden_states.shape[1] == self.channels
if not USE_PEFT_BACKEND:
if isinstance(self.conv, LoRACompatibleConv):
hidden_states = self.conv(hidden_states, scale)
else:
hidden_states = self.conv(hidden_states)
else:
hidden_states = self.conv(hidden_states)
return hidden_states
class FirUpsample2D(nn.Module):
"""A 2D FIR upsampling layer with an optional convolution.
Parameters:
channels (`int`, optional):
number of channels in the inputs and outputs.
use_conv (`bool`, default `False`):
option to use a convolution.
out_channels (`int`, optional):
number of output channels. Defaults to `channels`.
fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
kernel for the FIR filter.
"""
def __init__(
self,
channels: Optional[int] = None,
out_channels: Optional[int] = None,
use_conv: bool = False,
fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
):
super().__init__()
out_channels = out_channels if out_channels else channels
if use_conv:
self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
self.use_conv = use_conv
self.fir_kernel = fir_kernel
self.out_channels = out_channels
def _upsample_2d(
self,
hidden_states: torch.FloatTensor,
weight: Optional[torch.FloatTensor] = None,
kernel: Optional[torch.FloatTensor] = None,
factor: int = 2,
gain: float = 1,
) -> torch.FloatTensor:
"""Fused `upsample_2d()` followed by `Conv2d()`.
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
arbitrary order.
Args:
hidden_states (`torch.FloatTensor`):
Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
weight (`torch.FloatTensor`, *optional*):
Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
performed by `inChannels = x.shape[0] // numGroups`.
kernel (`torch.FloatTensor`, *optional*):
FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
corresponds to nearest-neighbor upsampling.
factor (`int`, *optional*): Integer upsampling factor (default: 2).
gain (`float`, *optional*): Scaling factor for signal magnitude (default: 1.0).
Returns:
output (`torch.FloatTensor`):
Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
datatype as `hidden_states`.
"""
assert isinstance(factor, int) and factor >= 1
# Setup filter kernel.
if kernel is None:
kernel = [1] * factor
# setup kernel
kernel = torch.tensor(kernel, dtype=torch.float32)
if kernel.ndim == 1:
kernel = torch.outer(kernel, kernel)
kernel /= torch.sum(kernel)
kernel = kernel * (gain * (factor**2))
if self.use_conv:
convH = weight.shape[2]
convW = weight.shape[3]
inC = weight.shape[1]
pad_value = (kernel.shape[0] - factor) - (convW - 1)
stride = (factor, factor)
# Determine data dimensions.
output_shape = (
(hidden_states.shape[2] - 1) * factor + convH,
(hidden_states.shape[3] - 1) * factor + convW,
)
output_padding = (
output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH,
output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW,
)
assert output_padding[0] >= 0 and output_padding[1] >= 0
num_groups = hidden_states.shape[1] // inC
# Transpose weights.
weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))
inverse_conv = F.conv_transpose2d(
hidden_states,
weight,
stride=stride,
output_padding=output_padding,
padding=0,
)
output = upfirdn2d_native(
inverse_conv,
torch.tensor(kernel, device=inverse_conv.device),
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
)
else:
pad_value = kernel.shape[0] - factor
output = upfirdn2d_native(
hidden_states,
torch.tensor(kernel, device=hidden_states.device),
up=factor,
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
)
return output
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
if self.use_conv:
height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
else:
height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
return height
class FirDownsample2D(nn.Module):
"""A 2D FIR downsampling layer with an optional convolution.
Parameters:
channels (`int`):
number of channels in the inputs and outputs.
use_conv (`bool`, default `False`):
option to use a convolution.
out_channels (`int`, optional):
number of output channels. Defaults to `channels`.
fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
kernel for the FIR filter.
"""
def __init__(
self,
channels: Optional[int] = None,
out_channels: Optional[int] = None,
use_conv: bool = False,
fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
):
super().__init__()
out_channels = out_channels if out_channels else channels
if use_conv:
self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
self.fir_kernel = fir_kernel
self.use_conv = use_conv
self.out_channels = out_channels
def _downsample_2d(
self,
hidden_states: torch.FloatTensor,
weight: Optional[torch.FloatTensor] = None,
kernel: Optional[torch.FloatTensor] = None,
factor: int = 2,
gain: float = 1,
) -> torch.FloatTensor:
"""Fused `Conv2d()` followed by `downsample_2d()`.
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
arbitrary order.
Args:
hidden_states (`torch.FloatTensor`):
Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
weight (`torch.FloatTensor`, *optional*):
Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
performed by `inChannels = x.shape[0] // numGroups`.
kernel (`torch.FloatTensor`, *optional*):
FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
corresponds to average pooling.
factor (`int`, *optional*, default to `2`):
Integer downsampling factor.
gain (`float`, *optional*, default to `1.0`):
Scaling factor for signal magnitude.
Returns:
output (`torch.FloatTensor`):
Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same
datatype as `x`.
"""
assert isinstance(factor, int) and factor >= 1
if kernel is None:
kernel = [1] * factor
# setup kernel
kernel = torch.tensor(kernel, dtype=torch.float32)
if kernel.ndim == 1:
kernel = torch.outer(kernel, kernel)
kernel /= torch.sum(kernel)
kernel = kernel * gain
if self.use_conv:
_, _, convH, convW = weight.shape
pad_value = (kernel.shape[0] - factor) + (convW - 1)
stride_value = [factor, factor]
upfirdn_input = upfirdn2d_native(
hidden_states,
torch.tensor(kernel, device=hidden_states.device),
pad=((pad_value + 1) // 2, pad_value // 2),
)
output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
else:
pad_value = kernel.shape[0] - factor
output = upfirdn2d_native(
hidden_states,
torch.tensor(kernel, device=hidden_states.device),
down=factor,
pad=((pad_value + 1) // 2, pad_value // 2),
)
return output
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
if self.use_conv:
downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
else:
hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
return hidden_states
# downsample/upsample layer used in k-upscaler, might be able to use FirDownsample2D/DirUpsample2D instead
class KDownsample2D(nn.Module):
r"""A 2D K-downsampling layer.
Parameters:
pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use.
"""
def __init__(self, pad_mode: str = "reflect"):
super().__init__()
self.pad_mode = pad_mode
kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]])
self.pad = kernel_1d.shape[1] // 2 - 1
self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode)
weight = inputs.new_zeros(
[
inputs.shape[1],
inputs.shape[1],
self.kernel.shape[0],
self.kernel.shape[1],
]
)
indices = torch.arange(inputs.shape[1], device=inputs.device)
kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
weight[indices, indices] = kernel
return F.conv2d(inputs, weight, stride=2)
class KUpsample2D(nn.Module):
r"""A 2D K-upsampling layer.
Parameters:
pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use.
"""
def __init__(self, pad_mode: str = "reflect"):
super().__init__()
self.pad_mode = pad_mode
kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2
self.pad = kernel_1d.shape[1] // 2 - 1
self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode)
weight = inputs.new_zeros(
[
inputs.shape[1],
inputs.shape[1],
self.kernel.shape[0],
self.kernel.shape[1],
]
)
indices = torch.arange(inputs.shape[1], device=inputs.device)
kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
weight[indices, indices] = kernel
return F.conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1)
from .upsampling import ( # noqa
FirUpsample2D,
KUpsample2D,
Upsample1D,
Upsample2D,
upfirdn2d_native,
upsample_2d,
)
class ResnetBlock2D(nn.Module):
@@ -894,151 +355,6 @@ class ResidualTemporalBlock1D(nn.Module):
return out + self.residual_conv(inputs)
def upsample_2d(
hidden_states: torch.FloatTensor,
kernel: Optional[torch.FloatTensor] = None,
factor: int = 2,
gain: float = 1,
) -> torch.FloatTensor:
r"""Upsample2D a batch of 2D images with the given filter.
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
`gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
a: multiple of the upsampling factor.
Args:
hidden_states (`torch.FloatTensor`):
Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
kernel (`torch.FloatTensor`, *optional*):
FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
corresponds to nearest-neighbor upsampling.
factor (`int`, *optional*, default to `2`):
Integer upsampling factor.
gain (`float`, *optional*, default to `1.0`):
Scaling factor for signal magnitude (default: 1.0).
Returns:
output (`torch.FloatTensor`):
Tensor of the shape `[N, C, H * factor, W * factor]`
"""
assert isinstance(factor, int) and factor >= 1
if kernel is None:
kernel = [1] * factor
kernel = torch.tensor(kernel, dtype=torch.float32)
if kernel.ndim == 1:
kernel = torch.outer(kernel, kernel)
kernel /= torch.sum(kernel)
kernel = kernel * (gain * (factor**2))
pad_value = kernel.shape[0] - factor
output = upfirdn2d_native(
hidden_states,
kernel.to(device=hidden_states.device),
up=factor,
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
)
return output
def downsample_2d(
hidden_states: torch.FloatTensor,
kernel: Optional[torch.FloatTensor] = None,
factor: int = 2,
gain: float = 1,
) -> torch.FloatTensor:
r"""Downsample2D a batch of 2D images with the given filter.
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
shape is a multiple of the downsampling factor.
Args:
hidden_states (`torch.FloatTensor`)
Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
kernel (`torch.FloatTensor`, *optional*):
FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
corresponds to average pooling.
factor (`int`, *optional*, default to `2`):
Integer downsampling factor.
gain (`float`, *optional*, default to `1.0`):
Scaling factor for signal magnitude.
Returns:
output (`torch.FloatTensor`):
Tensor of the shape `[N, C, H // factor, W // factor]`
"""
assert isinstance(factor, int) and factor >= 1
if kernel is None:
kernel = [1] * factor
kernel = torch.tensor(kernel, dtype=torch.float32)
if kernel.ndim == 1:
kernel = torch.outer(kernel, kernel)
kernel /= torch.sum(kernel)
kernel = kernel * gain
pad_value = kernel.shape[0] - factor
output = upfirdn2d_native(
hidden_states,
kernel.to(device=hidden_states.device),
down=factor,
pad=((pad_value + 1) // 2, pad_value // 2),
)
return output
def upfirdn2d_native(
tensor: torch.Tensor,
kernel: torch.Tensor,
up: int = 1,
down: int = 1,
pad: Tuple[int, int] = (0, 0),
) -> torch.Tensor:
up_x = up_y = up
down_x = down_y = down
pad_x0 = pad_y0 = pad[0]
pad_x1 = pad_y1 = pad[1]
_, channel, in_h, in_w = tensor.shape
tensor = tensor.reshape(-1, in_h, in_w, 1)
_, in_h, in_w, minor = tensor.shape
kernel_h, kernel_w = kernel.shape
out = tensor.view(-1, in_h, 1, in_w, 1, minor)
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
out = out.view(-1, in_h * up_y, in_w * up_x, minor)
out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
out = out.to(tensor.device) # Move back to mps if necessary
out = out[
:,
max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
:,
]
out = out.permute(0, 3, 1, 2)
out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
out = F.conv2d(out, w)
out = out.reshape(
-1,
minor,
in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
)
out = out.permute(0, 2, 3, 1)
out = out[:, ::down_y, ::down_x, :]
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
return out.view(-1, channel, out_h, out_w)
class TemporalConvLayer(nn.Module):
"""
Temporal convolutional layer that can be used for video (sequence of images) input Code mostly copied from:

View File

@@ -22,7 +22,7 @@ from ..configuration_utils import ConfigMixin, register_to_config
from ..models.embeddings import ImagePositionalEmbeddings
from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version
from .attention import BasicTransformerBlock
from .embeddings import CaptionProjection, PatchEmbed
from .embeddings import PatchEmbed, PixArtAlphaTextProjection
from .lora import LoRACompatibleConv, LoRACompatibleLinear
from .modeling_utils import ModelMixin
from .normalization import AdaLayerNormSingle
@@ -235,7 +235,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
self.caption_projection = None
if caption_channels is not None:
self.caption_projection = CaptionProjection(in_features=caption_channels, hidden_size=inner_dim)
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
self.gradient_checkpointing = False

View File

@@ -0,0 +1,454 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..utils import USE_PEFT_BACKEND
from .lora import LoRACompatibleConv
from .normalization import RMSNorm
class Upsample1D(nn.Module):
"""A 1D upsampling layer with an optional convolution.
Parameters:
channels (`int`):
number of channels in the inputs and outputs.
use_conv (`bool`, default `False`):
option to use a convolution.
use_conv_transpose (`bool`, default `False`):
option to use a convolution transpose.
out_channels (`int`, optional):
number of output channels. Defaults to `channels`.
name (`str`, default `conv`):
name of the upsampling 1D layer.
"""
def __init__(
self,
channels: int,
use_conv: bool = False,
use_conv_transpose: bool = False,
out_channels: Optional[int] = None,
name: str = "conv",
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_conv_transpose = use_conv_transpose
self.name = name
self.conv = None
if use_conv_transpose:
self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
elif use_conv:
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
assert inputs.shape[1] == self.channels
if self.use_conv_transpose:
return self.conv(inputs)
outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
if self.use_conv:
outputs = self.conv(outputs)
return outputs
class Upsample2D(nn.Module):
"""A 2D upsampling layer with an optional convolution.
Parameters:
channels (`int`):
number of channels in the inputs and outputs.
use_conv (`bool`, default `False`):
option to use a convolution.
use_conv_transpose (`bool`, default `False`):
option to use a convolution transpose.
out_channels (`int`, optional):
number of output channels. Defaults to `channels`.
name (`str`, default `conv`):
name of the upsampling 2D layer.
"""
def __init__(
self,
channels: int,
use_conv: bool = False,
use_conv_transpose: bool = False,
out_channels: Optional[int] = None,
name: str = "conv",
kernel_size: Optional[int] = None,
padding=1,
norm_type=None,
eps=None,
elementwise_affine=None,
bias=True,
interpolate=True,
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_conv_transpose = use_conv_transpose
self.name = name
self.interpolate = interpolate
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
if norm_type == "ln_norm":
self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
elif norm_type == "rms_norm":
self.norm = RMSNorm(channels, eps, elementwise_affine)
elif norm_type is None:
self.norm = None
else:
raise ValueError(f"unknown norm_type: {norm_type}")
conv = None
if use_conv_transpose:
if kernel_size is None:
kernel_size = 4
conv = nn.ConvTranspose2d(
channels, self.out_channels, kernel_size=kernel_size, stride=2, padding=padding, bias=bias
)
elif use_conv:
if kernel_size is None:
kernel_size = 3
conv = conv_cls(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias)
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
if name == "conv":
self.conv = conv
else:
self.Conv2d_0 = conv
def forward(
self,
hidden_states: torch.FloatTensor,
output_size: Optional[int] = None,
scale: float = 1.0,
) -> torch.FloatTensor:
assert hidden_states.shape[1] == self.channels
if self.norm is not None:
hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
if self.use_conv_transpose:
return self.conv(hidden_states)
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
# TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
# https://github.com/pytorch/pytorch/issues/86679
dtype = hidden_states.dtype
if dtype == torch.bfloat16:
hidden_states = hidden_states.to(torch.float32)
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
if hidden_states.shape[0] >= 64:
hidden_states = hidden_states.contiguous()
# if `output_size` is passed we force the interpolation output
# size and do not make use of `scale_factor=2`
if self.interpolate:
if output_size is None:
hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
else:
hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
# If the input is bfloat16, we cast back to bfloat16
if dtype == torch.bfloat16:
hidden_states = hidden_states.to(dtype)
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
if self.use_conv:
if self.name == "conv":
if isinstance(self.conv, LoRACompatibleConv) and not USE_PEFT_BACKEND:
hidden_states = self.conv(hidden_states, scale)
else:
hidden_states = self.conv(hidden_states)
else:
if isinstance(self.Conv2d_0, LoRACompatibleConv) and not USE_PEFT_BACKEND:
hidden_states = self.Conv2d_0(hidden_states, scale)
else:
hidden_states = self.Conv2d_0(hidden_states)
return hidden_states
class FirUpsample2D(nn.Module):
"""A 2D FIR upsampling layer with an optional convolution.
Parameters:
channels (`int`, optional):
number of channels in the inputs and outputs.
use_conv (`bool`, default `False`):
option to use a convolution.
out_channels (`int`, optional):
number of output channels. Defaults to `channels`.
fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
kernel for the FIR filter.
"""
def __init__(
self,
channels: Optional[int] = None,
out_channels: Optional[int] = None,
use_conv: bool = False,
fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
):
super().__init__()
out_channels = out_channels if out_channels else channels
if use_conv:
self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
self.use_conv = use_conv
self.fir_kernel = fir_kernel
self.out_channels = out_channels
def _upsample_2d(
self,
hidden_states: torch.FloatTensor,
weight: Optional[torch.FloatTensor] = None,
kernel: Optional[torch.FloatTensor] = None,
factor: int = 2,
gain: float = 1,
) -> torch.FloatTensor:
"""Fused `upsample_2d()` followed by `Conv2d()`.
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
arbitrary order.
Args:
hidden_states (`torch.FloatTensor`):
Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
weight (`torch.FloatTensor`, *optional*):
Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
performed by `inChannels = x.shape[0] // numGroups`.
kernel (`torch.FloatTensor`, *optional*):
FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
corresponds to nearest-neighbor upsampling.
factor (`int`, *optional*): Integer upsampling factor (default: 2).
gain (`float`, *optional*): Scaling factor for signal magnitude (default: 1.0).
Returns:
output (`torch.FloatTensor`):
Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
datatype as `hidden_states`.
"""
assert isinstance(factor, int) and factor >= 1
# Setup filter kernel.
if kernel is None:
kernel = [1] * factor
# setup kernel
kernel = torch.tensor(kernel, dtype=torch.float32)
if kernel.ndim == 1:
kernel = torch.outer(kernel, kernel)
kernel /= torch.sum(kernel)
kernel = kernel * (gain * (factor**2))
if self.use_conv:
convH = weight.shape[2]
convW = weight.shape[3]
inC = weight.shape[1]
pad_value = (kernel.shape[0] - factor) - (convW - 1)
stride = (factor, factor)
# Determine data dimensions.
output_shape = (
(hidden_states.shape[2] - 1) * factor + convH,
(hidden_states.shape[3] - 1) * factor + convW,
)
output_padding = (
output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH,
output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW,
)
assert output_padding[0] >= 0 and output_padding[1] >= 0
num_groups = hidden_states.shape[1] // inC
# Transpose weights.
weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))
inverse_conv = F.conv_transpose2d(
hidden_states,
weight,
stride=stride,
output_padding=output_padding,
padding=0,
)
output = upfirdn2d_native(
inverse_conv,
torch.tensor(kernel, device=inverse_conv.device),
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
)
else:
pad_value = kernel.shape[0] - factor
output = upfirdn2d_native(
hidden_states,
torch.tensor(kernel, device=hidden_states.device),
up=factor,
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
)
return output
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
if self.use_conv:
height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
else:
height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
return height
class KUpsample2D(nn.Module):
r"""A 2D K-upsampling layer.
Parameters:
pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use.
"""
def __init__(self, pad_mode: str = "reflect"):
super().__init__()
self.pad_mode = pad_mode
kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2
self.pad = kernel_1d.shape[1] // 2 - 1
self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode)
weight = inputs.new_zeros(
[
inputs.shape[1],
inputs.shape[1],
self.kernel.shape[0],
self.kernel.shape[1],
]
)
indices = torch.arange(inputs.shape[1], device=inputs.device)
kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
weight[indices, indices] = kernel
return F.conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1)
def upfirdn2d_native(
tensor: torch.Tensor,
kernel: torch.Tensor,
up: int = 1,
down: int = 1,
pad: Tuple[int, int] = (0, 0),
) -> torch.Tensor:
up_x = up_y = up
down_x = down_y = down
pad_x0 = pad_y0 = pad[0]
pad_x1 = pad_y1 = pad[1]
_, channel, in_h, in_w = tensor.shape
tensor = tensor.reshape(-1, in_h, in_w, 1)
_, in_h, in_w, minor = tensor.shape
kernel_h, kernel_w = kernel.shape
out = tensor.view(-1, in_h, 1, in_w, 1, minor)
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
out = out.view(-1, in_h * up_y, in_w * up_x, minor)
out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
out = out.to(tensor.device) # Move back to mps if necessary
out = out[
:,
max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
:,
]
out = out.permute(0, 3, 1, 2)
out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
out = F.conv2d(out, w)
out = out.reshape(
-1,
minor,
in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
)
out = out.permute(0, 2, 3, 1)
out = out[:, ::down_y, ::down_x, :]
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
return out.view(-1, channel, out_h, out_w)
def upsample_2d(
hidden_states: torch.FloatTensor,
kernel: Optional[torch.FloatTensor] = None,
factor: int = 2,
gain: float = 1,
) -> torch.FloatTensor:
r"""Upsample2D a batch of 2D images with the given filter.
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
`gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
a: multiple of the upsampling factor.
Args:
hidden_states (`torch.FloatTensor`):
Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
kernel (`torch.FloatTensor`, *optional*):
FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
corresponds to nearest-neighbor upsampling.
factor (`int`, *optional*, default to `2`):
Integer upsampling factor.
gain (`float`, *optional*, default to `1.0`):
Scaling factor for signal magnitude (default: 1.0).
Returns:
output (`torch.FloatTensor`):
Tensor of the shape `[N, C, H * factor, W * factor]`
"""
assert isinstance(factor, int) and factor >= 1
if kernel is None:
kernel = [1] * factor
kernel = torch.tensor(kernel, dtype=torch.float32)
if kernel.ndim == 1:
kernel = torch.outer(kernel, kernel)
kernel /= torch.sum(kernel)
kernel = kernel * (gain * (factor**2))
pad_value = kernel.shape[0] - factor
output = upfirdn2d_native(
hidden_states,
kernel.to(device=hidden_states.device),
up=factor,
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
)
return output

View File

@@ -0,0 +1,471 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Union
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.checkpoint import checkpoint
from ..configuration_utils import ConfigMixin, register_to_config
from .attention import BasicTransformerBlock, SkipFFTransformerBlock
from .attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
)
from .embeddings import TimestepEmbedding, get_timestep_embedding
from .modeling_utils import ModelMixin
from .normalization import GlobalResponseNorm, RMSNorm
from .resnet import Downsample2D, Upsample2D
class UVit2DModel(ModelMixin, ConfigMixin):
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
# global config
hidden_size: int = 1024,
use_bias: bool = False,
hidden_dropout: float = 0.0,
# conditioning dimensions
cond_embed_dim: int = 768,
micro_cond_encode_dim: int = 256,
micro_cond_embed_dim: int = 1280,
encoder_hidden_size: int = 768,
# num tokens
vocab_size: int = 8256, # codebook_size + 1 (for the mask token) rounded
codebook_size: int = 8192,
# `UVit2DConvEmbed`
in_channels: int = 768,
block_out_channels: int = 768,
num_res_blocks: int = 3,
downsample: bool = False,
upsample: bool = False,
block_num_heads: int = 12,
# `TransformerLayer`
num_hidden_layers: int = 22,
num_attention_heads: int = 16,
# `Attention`
attention_dropout: float = 0.0,
# `FeedForward`
intermediate_size: int = 2816,
# `Norm`
layer_norm_eps: float = 1e-6,
ln_elementwise_affine: bool = True,
sample_size: int = 64,
):
super().__init__()
self.encoder_proj = nn.Linear(encoder_hidden_size, hidden_size, bias=use_bias)
self.encoder_proj_layer_norm = RMSNorm(hidden_size, layer_norm_eps, ln_elementwise_affine)
self.embed = UVit2DConvEmbed(
in_channels, block_out_channels, vocab_size, ln_elementwise_affine, layer_norm_eps, use_bias
)
self.cond_embed = TimestepEmbedding(
micro_cond_embed_dim + cond_embed_dim, hidden_size, sample_proj_bias=use_bias
)
self.down_block = UVitBlock(
block_out_channels,
num_res_blocks,
hidden_size,
hidden_dropout,
ln_elementwise_affine,
layer_norm_eps,
use_bias,
block_num_heads,
attention_dropout,
downsample,
False,
)
self.project_to_hidden_norm = RMSNorm(block_out_channels, layer_norm_eps, ln_elementwise_affine)
self.project_to_hidden = nn.Linear(block_out_channels, hidden_size, bias=use_bias)
self.transformer_layers = nn.ModuleList(
[
BasicTransformerBlock(
dim=hidden_size,
num_attention_heads=num_attention_heads,
attention_head_dim=hidden_size // num_attention_heads,
dropout=hidden_dropout,
cross_attention_dim=hidden_size,
attention_bias=use_bias,
norm_type="ada_norm_continuous",
ada_norm_continous_conditioning_embedding_dim=hidden_size,
norm_elementwise_affine=ln_elementwise_affine,
norm_eps=layer_norm_eps,
ada_norm_bias=use_bias,
ff_inner_dim=intermediate_size,
ff_bias=use_bias,
attention_out_bias=use_bias,
)
for _ in range(num_hidden_layers)
]
)
self.project_from_hidden_norm = RMSNorm(hidden_size, layer_norm_eps, ln_elementwise_affine)
self.project_from_hidden = nn.Linear(hidden_size, block_out_channels, bias=use_bias)
self.up_block = UVitBlock(
block_out_channels,
num_res_blocks,
hidden_size,
hidden_dropout,
ln_elementwise_affine,
layer_norm_eps,
use_bias,
block_num_heads,
attention_dropout,
downsample=False,
upsample=upsample,
)
self.mlm_layer = ConvMlmLayer(
block_out_channels, in_channels, use_bias, ln_elementwise_affine, layer_norm_eps, codebook_size
)
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
pass
def forward(self, input_ids, encoder_hidden_states, pooled_text_emb, micro_conds, cross_attention_kwargs=None):
encoder_hidden_states = self.encoder_proj(encoder_hidden_states)
encoder_hidden_states = self.encoder_proj_layer_norm(encoder_hidden_states)
micro_cond_embeds = get_timestep_embedding(
micro_conds.flatten(), self.config.micro_cond_encode_dim, flip_sin_to_cos=True, downscale_freq_shift=0
)
micro_cond_embeds = micro_cond_embeds.reshape((input_ids.shape[0], -1))
pooled_text_emb = torch.cat([pooled_text_emb, micro_cond_embeds], dim=1)
pooled_text_emb = pooled_text_emb.to(dtype=self.dtype)
pooled_text_emb = self.cond_embed(pooled_text_emb).to(encoder_hidden_states.dtype)
hidden_states = self.embed(input_ids)
hidden_states = self.down_block(
hidden_states,
pooled_text_emb=pooled_text_emb,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
)
batch_size, channels, height, width = hidden_states.shape
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels)
hidden_states = self.project_to_hidden_norm(hidden_states)
hidden_states = self.project_to_hidden(hidden_states)
for layer in self.transformer_layers:
if self.training and self.gradient_checkpointing:
def layer_(*args):
return checkpoint(layer, *args)
else:
layer_ = layer
hidden_states = layer_(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs={"pooled_text_emb": pooled_text_emb},
)
hidden_states = self.project_from_hidden_norm(hidden_states)
hidden_states = self.project_from_hidden(hidden_states)
hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)
hidden_states = self.up_block(
hidden_states,
pooled_text_emb=pooled_text_emb,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
)
logits = self.mlm_layer(hidden_states)
return logits
@property
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor, _remove_lora=_remove_lora)
else:
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
"""
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnAddedKVProcessor()
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnProcessor()
else:
raise ValueError(
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
self.set_attn_processor(processor, _remove_lora=True)
class UVit2DConvEmbed(nn.Module):
def __init__(self, in_channels, block_out_channels, vocab_size, elementwise_affine, eps, bias):
super().__init__()
self.embeddings = nn.Embedding(vocab_size, in_channels)
self.layer_norm = RMSNorm(in_channels, eps, elementwise_affine)
self.conv = nn.Conv2d(in_channels, block_out_channels, kernel_size=1, bias=bias)
def forward(self, input_ids):
embeddings = self.embeddings(input_ids)
embeddings = self.layer_norm(embeddings)
embeddings = embeddings.permute(0, 3, 1, 2)
embeddings = self.conv(embeddings)
return embeddings
class UVitBlock(nn.Module):
def __init__(
self,
channels,
num_res_blocks: int,
hidden_size,
hidden_dropout,
ln_elementwise_affine,
layer_norm_eps,
use_bias,
block_num_heads,
attention_dropout,
downsample: bool,
upsample: bool,
):
super().__init__()
if downsample:
self.downsample = Downsample2D(
channels,
use_conv=True,
padding=0,
name="Conv2d_0",
kernel_size=2,
norm_type="rms_norm",
eps=layer_norm_eps,
elementwise_affine=ln_elementwise_affine,
bias=use_bias,
)
else:
self.downsample = None
self.res_blocks = nn.ModuleList(
[
ConvNextBlock(
channels,
layer_norm_eps,
ln_elementwise_affine,
use_bias,
hidden_dropout,
hidden_size,
)
for i in range(num_res_blocks)
]
)
self.attention_blocks = nn.ModuleList(
[
SkipFFTransformerBlock(
channels,
block_num_heads,
channels // block_num_heads,
hidden_size,
use_bias,
attention_dropout,
channels,
attention_bias=use_bias,
attention_out_bias=use_bias,
)
for _ in range(num_res_blocks)
]
)
if upsample:
self.upsample = Upsample2D(
channels,
use_conv_transpose=True,
kernel_size=2,
padding=0,
name="conv",
norm_type="rms_norm",
eps=layer_norm_eps,
elementwise_affine=ln_elementwise_affine,
bias=use_bias,
interpolate=False,
)
else:
self.upsample = None
def forward(self, x, pooled_text_emb, encoder_hidden_states, cross_attention_kwargs):
if self.downsample is not None:
x = self.downsample(x)
for res_block, attention_block in zip(self.res_blocks, self.attention_blocks):
x = res_block(x, pooled_text_emb)
batch_size, channels, height, width = x.shape
x = x.view(batch_size, channels, height * width).permute(0, 2, 1)
x = attention_block(
x, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs
)
x = x.permute(0, 2, 1).view(batch_size, channels, height, width)
if self.upsample is not None:
x = self.upsample(x)
return x
class ConvNextBlock(nn.Module):
def __init__(
self, channels, layer_norm_eps, ln_elementwise_affine, use_bias, hidden_dropout, hidden_size, res_ffn_factor=4
):
super().__init__()
self.depthwise = nn.Conv2d(
channels,
channels,
kernel_size=3,
padding=1,
groups=channels,
bias=use_bias,
)
self.norm = RMSNorm(channels, layer_norm_eps, ln_elementwise_affine)
self.channelwise_linear_1 = nn.Linear(channels, int(channels * res_ffn_factor), bias=use_bias)
self.channelwise_act = nn.GELU()
self.channelwise_norm = GlobalResponseNorm(int(channels * res_ffn_factor))
self.channelwise_linear_2 = nn.Linear(int(channels * res_ffn_factor), channels, bias=use_bias)
self.channelwise_dropout = nn.Dropout(hidden_dropout)
self.cond_embeds_mapper = nn.Linear(hidden_size, channels * 2, use_bias)
def forward(self, x, cond_embeds):
x_res = x
x = self.depthwise(x)
x = x.permute(0, 2, 3, 1)
x = self.norm(x)
x = self.channelwise_linear_1(x)
x = self.channelwise_act(x)
x = self.channelwise_norm(x)
x = self.channelwise_linear_2(x)
x = self.channelwise_dropout(x)
x = x.permute(0, 3, 1, 2)
x = x + x_res
scale, shift = self.cond_embeds_mapper(F.silu(cond_embeds)).chunk(2, dim=1)
x = x * (1 + scale[:, :, None, None]) + shift[:, :, None, None]
return x
class ConvMlmLayer(nn.Module):
def __init__(
self,
block_out_channels: int,
in_channels: int,
use_bias: bool,
ln_elementwise_affine: bool,
layer_norm_eps: float,
codebook_size: int,
):
super().__init__()
self.conv1 = nn.Conv2d(block_out_channels, in_channels, kernel_size=1, bias=use_bias)
self.layer_norm = RMSNorm(in_channels, layer_norm_eps, ln_elementwise_affine)
self.conv2 = nn.Conv2d(in_channels, codebook_size, kernel_size=1, bias=use_bias)
def forward(self, hidden_states):
hidden_states = self.conv1(hidden_states)
hidden_states = self.layer_norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
logits = self.conv2(hidden_states)
return logits

View File

@@ -20,8 +20,8 @@ import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from ..utils.accelerate_utils import apply_forward_hook
from .autoencoders.vae import Decoder, DecoderOutput, Encoder, VectorQuantizer
from .modeling_utils import ModelMixin
from .vae import Decoder, DecoderOutput, Encoder, VectorQuantizer
@dataclass
@@ -88,6 +88,9 @@ class VQModel(ModelMixin, ConfigMixin):
vq_embed_dim: Optional[int] = None,
scaling_factor: float = 0.18215,
norm_type: str = "group", # group, spatial
mid_block_add_attention=True,
lookup_from_codebook=False,
force_upcast=False,
):
super().__init__()
@@ -101,6 +104,7 @@ class VQModel(ModelMixin, ConfigMixin):
act_fn=act_fn,
norm_num_groups=norm_num_groups,
double_z=False,
mid_block_add_attention=mid_block_add_attention,
)
vq_embed_dim = vq_embed_dim if vq_embed_dim is not None else latent_channels
@@ -119,6 +123,7 @@ class VQModel(ModelMixin, ConfigMixin):
act_fn=act_fn,
norm_num_groups=norm_num_groups,
norm_type=norm_type,
mid_block_add_attention=mid_block_add_attention,
)
@apply_forward_hook
@@ -133,11 +138,13 @@ class VQModel(ModelMixin, ConfigMixin):
@apply_forward_hook
def decode(
self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True
self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True, shape=None
) -> Union[DecoderOutput, torch.FloatTensor]:
# also go through quantization layer
if not force_not_quantize:
quant, _, _ = self.quantize(h)
elif self.config.lookup_from_codebook:
quant = self.quantize.get_codebook_entry(h, shape)
else:
quant = h
quant2 = self.post_quant_conv(quant)

View File

@@ -20,6 +20,7 @@ _dummy_objects = {}
_import_structure = {
"controlnet": [],
"controlnet_xs": [],
"deprecated": [],
"latent_diffusion": [],
"stable_diffusion": [],
"stable_diffusion_xl": [],
@@ -44,16 +45,20 @@ else:
_import_structure["ddpm"] = ["DDPMPipeline"]
_import_structure["dit"] = ["DiTPipeline"]
_import_structure["latent_diffusion"].extend(["LDMSuperResolutionPipeline"])
_import_structure["latent_diffusion_uncond"] = ["LDMPipeline"]
_import_structure["pipeline_utils"] = [
"AudioPipelineOutput",
"DiffusionPipeline",
"ImagePipelineOutput",
]
_import_structure["pndm"] = ["PNDMPipeline"]
_import_structure["repaint"] = ["RePaintPipeline"]
_import_structure["score_sde_ve"] = ["ScoreSdeVePipeline"]
_import_structure["stochastic_karras_ve"] = ["KarrasVePipeline"]
_import_structure["deprecated"].extend(
[
"PNDMPipeline",
"LDMPipeline",
"RePaintPipeline",
"ScoreSdeVePipeline",
"KarrasVePipeline",
]
)
try:
if not (is_torch_available() and is_librosa_available()):
raise OptionalDependencyNotAvailable()
@@ -62,7 +67,23 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_librosa_objects))
else:
_import_structure["audio_diffusion"] = ["AudioDiffusionPipeline", "Mel"]
_import_structure["deprecated"].extend(["AudioDiffusionPipeline", "Mel"])
try:
if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils import dummy_transformers_and_torch_and_note_seq_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_transformers_and_torch_and_note_seq_objects))
else:
_import_structure["deprecated"].extend(
[
"MidiProcessor",
"SpectrogramDiffusionPipeline",
]
)
try:
if not (is_torch_available() and is_transformers_available()):
raise OptionalDependencyNotAvailable()
@@ -71,10 +92,23 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["alt_diffusion"] = [
"AltDiffusionImg2ImgPipeline",
"AltDiffusionPipeline",
]
_import_structure["deprecated"].extend(
[
"VQDiffusionPipeline",
"AltDiffusionPipeline",
"AltDiffusionImg2ImgPipeline",
"CycleDiffusionPipeline",
"StableDiffusionInpaintPipelineLegacy",
"StableDiffusionPix2PixZeroPipeline",
"StableDiffusionParadigmsPipeline",
"StableDiffusionModelEditingPipeline",
"VersatileDiffusionDualGuidedPipeline",
"VersatileDiffusionImageVariationPipeline",
"VersatileDiffusionPipeline",
"VersatileDiffusionTextToImagePipeline",
]
)
_import_structure["amused"] = ["AmusedImg2ImgPipeline", "AmusedInpaintPipeline", "AmusedPipeline"]
_import_structure["animatediff"] = ["AnimateDiffPipeline"]
_import_structure["audioldm"] = ["AudioLDMPipeline"]
_import_structure["audioldm2"] = [
@@ -146,32 +180,26 @@ else:
_import_structure["stable_diffusion"].extend(
[
"CLIPImageProjection",
"CycleDiffusionPipeline",
"StableDiffusionAttendAndExcitePipeline",
"StableDiffusionDepth2ImgPipeline",
"StableDiffusionDiffEditPipeline",
"StableDiffusionGLIGENPipeline",
"StableDiffusionGLIGENPipeline",
"StableDiffusionGLIGENTextImagePipeline",
"StableDiffusionImageVariationPipeline",
"StableDiffusionImg2ImgPipeline",
"StableDiffusionInpaintPipeline",
"StableDiffusionInpaintPipelineLegacy",
"StableDiffusionInstructPix2PixPipeline",
"StableDiffusionLatentUpscalePipeline",
"StableDiffusionLDM3DPipeline",
"StableDiffusionModelEditingPipeline",
"StableDiffusionPanoramaPipeline",
"StableDiffusionParadigmsPipeline",
"StableDiffusionPipeline",
"StableDiffusionPix2PixZeroPipeline",
"StableDiffusionSAGPipeline",
"StableDiffusionUpscalePipeline",
"StableUnCLIPImg2ImgPipeline",
"StableUnCLIPPipeline",
"StableDiffusionLDM3DPipeline",
]
)
_import_structure["stable_diffusion_attend_and_excite"] = ["StableDiffusionAttendAndExcitePipeline"]
_import_structure["stable_diffusion_safe"] = ["StableDiffusionPipelineSafe"]
_import_structure["stable_diffusion_sag"] = ["StableDiffusionSAGPipeline"]
_import_structure["stable_diffusion_gligen"] = [
"StableDiffusionGLIGENPipeline",
"StableDiffusionGLIGENTextImagePipeline",
]
_import_structure["stable_video_diffusion"] = ["StableVideoDiffusionPipeline"]
_import_structure["stable_diffusion_xl"].extend(
[
@@ -181,6 +209,9 @@ else:
"StableDiffusionXLPipeline",
]
)
_import_structure["stable_diffusion_diffedit"] = ["StableDiffusionDiffEditPipeline"]
_import_structure["stable_diffusion_ldm3d"] = ["StableDiffusionLDM3DPipeline"]
_import_structure["stable_diffusion_panorama"] = ["StableDiffusionPanoramaPipeline"]
_import_structure["t2i_adapter"] = [
"StableDiffusionAdapterPipeline",
"StableDiffusionXLAdapterPipeline",
@@ -198,13 +229,6 @@ else:
"UniDiffuserPipeline",
"UniDiffuserTextDecoder",
]
_import_structure["versatile_diffusion"] = [
"VersatileDiffusionDualGuidedPipeline",
"VersatileDiffusionImageVariationPipeline",
"VersatileDiffusionPipeline",
"VersatileDiffusionTextToImagePipeline",
]
_import_structure["vq_diffusion"] = ["VQDiffusionPipeline"]
_import_structure["wuerstchen"] = [
"WuerstchenCombinedPipeline",
"WuerstchenDecoderPipeline",
@@ -231,7 +255,6 @@ else:
[
"OnnxStableDiffusionImg2ImgPipeline",
"OnnxStableDiffusionInpaintPipeline",
"OnnxStableDiffusionInpaintPipelineLegacy",
"OnnxStableDiffusionPipeline",
"OnnxStableDiffusionUpscalePipeline",
"StableDiffusionOnnxPipeline",
@@ -248,7 +271,7 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_k_diffusion_objects))
else:
_import_structure["stable_diffusion"].extend(["StableDiffusionKDiffusionPipeline"])
_import_structure["stable_diffusion_k_diffusion"] = ["StableDiffusionKDiffusionPipeline"]
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
@@ -279,18 +302,6 @@ else:
"FlaxStableDiffusionXLPipeline",
]
)
try:
if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils import dummy_transformers_and_torch_and_note_seq_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_transformers_and_torch_and_note_seq_objects))
else:
_import_structure["spectrogram_diffusion"] = [
"MidiProcessor",
"SpectrogramDiffusionPipeline",
]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -309,18 +320,14 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .dance_diffusion import DanceDiffusionPipeline
from .ddim import DDIMPipeline
from .ddpm import DDPMPipeline
from .deprecated import KarrasVePipeline, LDMPipeline, PNDMPipeline, RePaintPipeline, ScoreSdeVePipeline
from .dit import DiTPipeline
from .latent_diffusion import LDMSuperResolutionPipeline
from .latent_diffusion_uncond import LDMPipeline
from .pipeline_utils import (
AudioPipelineOutput,
DiffusionPipeline,
ImagePipelineOutput,
)
from .pndm import PNDMPipeline
from .repaint import RePaintPipeline
from .score_sde_ve import ScoreSdeVePipeline
from .stochastic_karras_ve import KarrasVePipeline
try:
if not (is_torch_available() and is_librosa_available()):
@@ -328,7 +335,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from ..utils.dummy_torch_and_librosa_objects import *
else:
from .audio_diffusion import AudioDiffusionPipeline, Mel
from .deprecated import AudioDiffusionPipeline, Mel
try:
if not (is_torch_available() and is_transformers_available()):
@@ -336,7 +343,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from ..utils.dummy_torch_and_transformers_objects import *
else:
from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
from .amused import AmusedImg2ImgPipeline, AmusedInpaintPipeline, AmusedPipeline
from .animatediff import AnimateDiffPipeline
from .audioldm import AudioLDMPipeline
from .audioldm2 import (
@@ -366,6 +373,20 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
IFPipeline,
IFSuperResolutionPipeline,
)
from .deprecated import (
AltDiffusionImg2ImgPipeline,
AltDiffusionPipeline,
CycleDiffusionPipeline,
StableDiffusionInpaintPipelineLegacy,
StableDiffusionModelEditingPipeline,
StableDiffusionParadigmsPipeline,
StableDiffusionPix2PixZeroPipeline,
VersatileDiffusionDualGuidedPipeline,
VersatileDiffusionImageVariationPipeline,
VersatileDiffusionPipeline,
VersatileDiffusionTextToImagePipeline,
VQDiffusionPipeline,
)
from .kandinsky import (
KandinskyCombinedPipeline,
KandinskyImg2ImgCombinedPipeline,
@@ -403,30 +424,24 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
from .stable_diffusion import (
CLIPImageProjection,
CycleDiffusionPipeline,
StableDiffusionAttendAndExcitePipeline,
StableDiffusionDepth2ImgPipeline,
StableDiffusionDiffEditPipeline,
StableDiffusionGLIGENPipeline,
StableDiffusionGLIGENTextImagePipeline,
StableDiffusionImageVariationPipeline,
StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipeline,
StableDiffusionInpaintPipelineLegacy,
StableDiffusionInstructPix2PixPipeline,
StableDiffusionLatentUpscalePipeline,
StableDiffusionLDM3DPipeline,
StableDiffusionModelEditingPipeline,
StableDiffusionPanoramaPipeline,
StableDiffusionParadigmsPipeline,
StableDiffusionPipeline,
StableDiffusionPix2PixZeroPipeline,
StableDiffusionSAGPipeline,
StableDiffusionUpscalePipeline,
StableUnCLIPImg2ImgPipeline,
StableUnCLIPPipeline,
)
from .stable_diffusion_attend_and_excite import StableDiffusionAttendAndExcitePipeline
from .stable_diffusion_diffedit import StableDiffusionDiffEditPipeline
from .stable_diffusion_gligen import StableDiffusionGLIGENPipeline, StableDiffusionGLIGENTextImagePipeline
from .stable_diffusion_ldm3d import StableDiffusionLDM3DPipeline
from .stable_diffusion_panorama import StableDiffusionPanoramaPipeline
from .stable_diffusion_safe import StableDiffusionPipelineSafe
from .stable_diffusion_sag import StableDiffusionSAGPipeline
from .stable_diffusion_xl import (
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLInpaintPipeline,
@@ -451,13 +466,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
UniDiffuserPipeline,
UniDiffuserTextDecoder,
)
from .versatile_diffusion import (
VersatileDiffusionDualGuidedPipeline,
VersatileDiffusionImageVariationPipeline,
VersatileDiffusionPipeline,
VersatileDiffusionTextToImagePipeline,
)
from .vq_diffusion import VQDiffusionPipeline
from .wuerstchen import (
WuerstchenCombinedPipeline,
WuerstchenDecoderPipeline,
@@ -482,7 +490,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .stable_diffusion import (
OnnxStableDiffusionImg2ImgPipeline,
OnnxStableDiffusionInpaintPipeline,
OnnxStableDiffusionInpaintPipelineLegacy,
OnnxStableDiffusionPipeline,
OnnxStableDiffusionUpscalePipeline,
StableDiffusionOnnxPipeline,
@@ -494,7 +501,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from ..utils.dummy_torch_and_transformers_and_k_diffusion_objects import *
else:
from .stable_diffusion import StableDiffusionKDiffusionPipeline
from .stable_diffusion_k_diffusion import StableDiffusionKDiffusionPipeline
try:
if not is_flax_available():
@@ -527,7 +534,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from ..utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403
else:
from .spectrogram_diffusion import (
from .deprecated import (
MidiProcessor,
SpectrogramDiffusionPipeline,
)

View File

@@ -0,0 +1,62 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import (
AmusedImg2ImgPipeline,
AmusedInpaintPipeline,
AmusedPipeline,
)
_dummy_objects.update(
{
"AmusedPipeline": AmusedPipeline,
"AmusedImg2ImgPipeline": AmusedImg2ImgPipeline,
"AmusedInpaintPipeline": AmusedInpaintPipeline,
}
)
else:
_import_structure["pipeline_amused"] = ["AmusedPipeline"]
_import_structure["pipeline_amused_img2img"] = ["AmusedImg2ImgPipeline"]
_import_structure["pipeline_amused_inpaint"] = ["AmusedInpaintPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import (
AmusedPipeline,
)
else:
from .pipeline_amused import AmusedPipeline
from .pipeline_amused_img2img import AmusedImg2ImgPipeline
from .pipeline_amused_inpaint import AmusedInpaintPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)

View File

@@ -0,0 +1,328 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from ...image_processor import VaeImageProcessor
from ...models import UVit2DModel, VQModel
from ...schedulers import AmusedScheduler
from ...utils import replace_example_docstring
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import AmusedPipeline
>>> pipe = AmusedPipeline.from_pretrained(
... "huggingface/amused-512", variant="fp16", torch_dtype=torch.float16
... )
>>> pipe = pipe.to("cuda")
>>> prompt = "a photo of an astronaut riding a horse on mars"
>>> image = pipe(prompt).images[0]
```
"""
class AmusedPipeline(DiffusionPipeline):
image_processor: VaeImageProcessor
vqvae: VQModel
tokenizer: CLIPTokenizer
text_encoder: CLIPTextModelWithProjection
transformer: UVit2DModel
scheduler: AmusedScheduler
model_cpu_offload_seq = "text_encoder->transformer->vqvae"
def __init__(
self,
vqvae: VQModel,
tokenizer: CLIPTokenizer,
text_encoder: CLIPTextModelWithProjection,
transformer: UVit2DModel,
scheduler: AmusedScheduler,
):
super().__init__()
self.register_modules(
vqvae=vqvae,
tokenizer=tokenizer,
text_encoder=text_encoder,
transformer=transformer,
scheduler=scheduler,
)
self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Optional[Union[List[str], str]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 12,
guidance_scale: float = 10.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.IntTensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_encoder_hidden_states: Optional[torch.Tensor] = None,
output_type="pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
micro_conditioning_aesthetic_score: int = 6,
micro_conditioning_crop_coord: Tuple[int, int] = (0, 0),
temperature: Union[int, Tuple[int, int], List[int]] = (2, 0),
):
"""
The call function to the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
height (`int`, *optional*, defaults to `self.transformer.config.sample_size * self.vae_scale_factor`):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 16):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 10.0):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
latents (`torch.IntTensor`, *optional*):
Pre-generated tokens representing latent vectors in `self.vqvae`, to be used as inputs for image
gneration. If not provided, the starting latents will be completely masked.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument. A single vector from the
pooled and projected final hidden states.
encoder_hidden_states (`torch.FloatTensor`, *optional*):
Pre-generated penultimate hidden states from the text encoder providing additional text conditioning.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
negative_encoder_hidden_states (`torch.FloatTensor`, *optional*):
Analogous to `encoder_hidden_states` for the positive prompt.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that calls every `callback_steps` steps during inference. The function is called with the
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function is called. If not specified, the callback is called at
every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
micro_conditioning_aesthetic_score (`int`, *optional*, defaults to 6):
The targeted aesthetic score according to the laion aesthetic classifier. See https://laion.ai/blog/laion-aesthetics/
and the micro-conditioning section of https://arxiv.org/abs/2307.01952.
micro_conditioning_crop_coord (`Tuple[int]`, *optional*, defaults to (0, 0)):
The targeted height, width crop coordinates. See the micro-conditioning section of https://arxiv.org/abs/2307.01952.
temperature (`Union[int, Tuple[int, int], List[int]]`, *optional*, defaults to (2, 0)):
Configures the temperature scheduler on `self.scheduler` see `AmusedScheduler#set_timesteps`.
Examples:
Returns:
[`~pipelines.pipeline_utils.ImagePipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.pipeline_utils.ImagePipelineOutput`] is returned, otherwise a
`tuple` is returned where the first element is a list with the generated images.
"""
if (prompt_embeds is not None and encoder_hidden_states is None) or (
prompt_embeds is None and encoder_hidden_states is not None
):
raise ValueError("pass either both `prompt_embeds` and `encoder_hidden_states` or neither")
if (negative_prompt_embeds is not None and negative_encoder_hidden_states is None) or (
negative_prompt_embeds is None and negative_encoder_hidden_states is not None
):
raise ValueError(
"pass either both `negatve_prompt_embeds` and `negative_encoder_hidden_states` or neither"
)
if (prompt is None and prompt_embeds is None) or (prompt is not None and prompt_embeds is not None):
raise ValueError("pass only one of `prompt` or `prompt_embeds`")
if isinstance(prompt, str):
prompt = [prompt]
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
batch_size = batch_size * num_images_per_prompt
if height is None:
height = self.transformer.config.sample_size * self.vae_scale_factor
if width is None:
width = self.transformer.config.sample_size * self.vae_scale_factor
if prompt_embeds is None:
input_ids = self.tokenizer(
prompt,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=self.tokenizer.model_max_length,
).input_ids.to(self._execution_device)
outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
prompt_embeds = outputs.text_embeds
encoder_hidden_states = outputs.hidden_states[-2]
prompt_embeds = prompt_embeds.repeat(num_images_per_prompt, 1)
encoder_hidden_states = encoder_hidden_states.repeat(num_images_per_prompt, 1, 1)
if guidance_scale > 1.0:
if negative_prompt_embeds is None:
if negative_prompt is None:
negative_prompt = [""] * len(prompt)
if isinstance(negative_prompt, str):
negative_prompt = [negative_prompt]
input_ids = self.tokenizer(
negative_prompt,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=self.tokenizer.model_max_length,
).input_ids.to(self._execution_device)
outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
negative_prompt_embeds = outputs.text_embeds
negative_encoder_hidden_states = outputs.hidden_states[-2]
negative_prompt_embeds = negative_prompt_embeds.repeat(num_images_per_prompt, 1)
negative_encoder_hidden_states = negative_encoder_hidden_states.repeat(num_images_per_prompt, 1, 1)
prompt_embeds = torch.concat([negative_prompt_embeds, prompt_embeds])
encoder_hidden_states = torch.concat([negative_encoder_hidden_states, encoder_hidden_states])
# Note that the micro conditionings _do_ flip the order of width, height for the original size
# and the crop coordinates. This is how it was done in the original code base
micro_conds = torch.tensor(
[
width,
height,
micro_conditioning_crop_coord[0],
micro_conditioning_crop_coord[1],
micro_conditioning_aesthetic_score,
],
device=self._execution_device,
dtype=encoder_hidden_states.dtype,
)
micro_conds = micro_conds.unsqueeze(0)
micro_conds = micro_conds.expand(2 * batch_size if guidance_scale > 1.0 else batch_size, -1)
shape = (batch_size, height // self.vae_scale_factor, width // self.vae_scale_factor)
if latents is None:
latents = torch.full(
shape, self.scheduler.config.mask_token_id, dtype=torch.long, device=self._execution_device
)
self.scheduler.set_timesteps(num_inference_steps, temperature, self._execution_device)
num_warmup_steps = len(self.scheduler.timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, timestep in enumerate(self.scheduler.timesteps):
if guidance_scale > 1.0:
model_input = torch.cat([latents] * 2)
else:
model_input = latents
model_output = self.transformer(
model_input,
micro_conds=micro_conds,
pooled_text_emb=prompt_embeds,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
)
if guidance_scale > 1.0:
uncond_logits, cond_logits = model_output.chunk(2)
model_output = uncond_logits + guidance_scale * (cond_logits - uncond_logits)
latents = self.scheduler.step(
model_output=model_output,
timestep=timestep,
sample=latents,
generator=generator,
).prev_sample
if i == len(self.scheduler.timesteps) - 1 or (
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, timestep, latents)
if output_type == "latent":
output = latents
else:
needs_upcasting = self.vqvae.dtype == torch.float16 and self.vqvae.config.force_upcast
if needs_upcasting:
self.vqvae.float()
output = self.vqvae.decode(
latents,
force_not_quantize=True,
shape=(
batch_size,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
self.vqvae.config.latent_channels,
),
).sample.clip(0, 1)
output = self.image_processor.postprocess(output, output_type)
if needs_upcasting:
self.vqvae.half()
self.maybe_free_model_hooks()
if not return_dict:
return (output,)
return ImagePipelineOutput(output)

View File

@@ -0,0 +1,347 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...models import UVit2DModel, VQModel
from ...schedulers import AmusedScheduler
from ...utils import replace_example_docstring
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import AmusedImg2ImgPipeline
>>> from diffusers.utils import load_image
>>> pipe = AmusedImg2ImgPipeline.from_pretrained(
... "huggingface/amused-512", variant="fp16", torch_dtype=torch.float16
... )
>>> pipe = pipe.to("cuda")
>>> prompt = "winter mountains"
>>> input_image = (
... load_image(
... "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains.jpg"
... )
... .resize((512, 512))
... .convert("RGB")
... )
>>> image = pipe(prompt, input_image).images[0]
```
"""
class AmusedImg2ImgPipeline(DiffusionPipeline):
image_processor: VaeImageProcessor
vqvae: VQModel
tokenizer: CLIPTokenizer
text_encoder: CLIPTextModelWithProjection
transformer: UVit2DModel
scheduler: AmusedScheduler
model_cpu_offload_seq = "text_encoder->transformer->vqvae"
# TODO - when calling self.vqvae.quantize, it uses self.vqvae.quantize.embedding.weight before
# the forward method of self.vqvae.quantize, so the hook doesn't get called to move the parameter
# off the meta device. There should be a way to fix this instead of just not offloading it
_exclude_from_cpu_offload = ["vqvae"]
def __init__(
self,
vqvae: VQModel,
tokenizer: CLIPTokenizer,
text_encoder: CLIPTextModelWithProjection,
transformer: UVit2DModel,
scheduler: AmusedScheduler,
):
super().__init__()
self.register_modules(
vqvae=vqvae,
tokenizer=tokenizer,
text_encoder=text_encoder,
transformer=transformer,
scheduler=scheduler,
)
self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Optional[Union[List[str], str]] = None,
image: PipelineImageInput = None,
strength: float = 0.5,
num_inference_steps: int = 12,
guidance_scale: float = 10.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[torch.Generator] = None,
prompt_embeds: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_encoder_hidden_states: Optional[torch.Tensor] = None,
output_type="pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
micro_conditioning_aesthetic_score: int = 6,
micro_conditioning_crop_coord: Tuple[int, int] = (0, 0),
temperature: Union[int, Tuple[int, int], List[int]] = (2, 0),
):
"""
The call function to the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
latents as `image`, but if passing latents directly it is not encoded again.
strength (`float`, *optional*, defaults to 0.5):
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
starting point and more noise is added the higher the `strength`. The number of denoising steps depends
on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
essentially ignores `image`.
num_inference_steps (`int`, *optional*, defaults to 16):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 10.0):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument. A single vector from the
pooled and projected final hidden states.
encoder_hidden_states (`torch.FloatTensor`, *optional*):
Pre-generated penultimate hidden states from the text encoder providing additional text conditioning.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
negative_encoder_hidden_states (`torch.FloatTensor`, *optional*):
Analogous to `encoder_hidden_states` for the positive prompt.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that calls every `callback_steps` steps during inference. The function is called with the
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function is called. If not specified, the callback is called at
every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
micro_conditioning_aesthetic_score (`int`, *optional*, defaults to 6):
The targeted aesthetic score according to the laion aesthetic classifier. See https://laion.ai/blog/laion-aesthetics/
and the micro-conditioning section of https://arxiv.org/abs/2307.01952.
micro_conditioning_crop_coord (`Tuple[int]`, *optional*, defaults to (0, 0)):
The targeted height, width crop coordinates. See the micro-conditioning section of https://arxiv.org/abs/2307.01952.
temperature (`Union[int, Tuple[int, int], List[int]]`, *optional*, defaults to (2, 0)):
Configures the temperature scheduler on `self.scheduler` see `AmusedScheduler#set_timesteps`.
Examples:
Returns:
[`~pipelines.pipeline_utils.ImagePipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.pipeline_utils.ImagePipelineOutput`] is returned, otherwise a
`tuple` is returned where the first element is a list with the generated images.
"""
if (prompt_embeds is not None and encoder_hidden_states is None) or (
prompt_embeds is None and encoder_hidden_states is not None
):
raise ValueError("pass either both `prompt_embeds` and `encoder_hidden_states` or neither")
if (negative_prompt_embeds is not None and negative_encoder_hidden_states is None) or (
negative_prompt_embeds is None and negative_encoder_hidden_states is not None
):
raise ValueError(
"pass either both `negatve_prompt_embeds` and `negative_encoder_hidden_states` or neither"
)
if (prompt is None and prompt_embeds is None) or (prompt is not None and prompt_embeds is not None):
raise ValueError("pass only one of `prompt` or `prompt_embeds`")
if isinstance(prompt, str):
prompt = [prompt]
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
batch_size = batch_size * num_images_per_prompt
if prompt_embeds is None:
input_ids = self.tokenizer(
prompt,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=self.tokenizer.model_max_length,
).input_ids.to(self._execution_device)
outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
prompt_embeds = outputs.text_embeds
encoder_hidden_states = outputs.hidden_states[-2]
prompt_embeds = prompt_embeds.repeat(num_images_per_prompt, 1)
encoder_hidden_states = encoder_hidden_states.repeat(num_images_per_prompt, 1, 1)
if guidance_scale > 1.0:
if negative_prompt_embeds is None:
if negative_prompt is None:
negative_prompt = [""] * len(prompt)
if isinstance(negative_prompt, str):
negative_prompt = [negative_prompt]
input_ids = self.tokenizer(
negative_prompt,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=self.tokenizer.model_max_length,
).input_ids.to(self._execution_device)
outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
negative_prompt_embeds = outputs.text_embeds
negative_encoder_hidden_states = outputs.hidden_states[-2]
negative_prompt_embeds = negative_prompt_embeds.repeat(num_images_per_prompt, 1)
negative_encoder_hidden_states = negative_encoder_hidden_states.repeat(num_images_per_prompt, 1, 1)
prompt_embeds = torch.concat([negative_prompt_embeds, prompt_embeds])
encoder_hidden_states = torch.concat([negative_encoder_hidden_states, encoder_hidden_states])
image = self.image_processor.preprocess(image)
height, width = image.shape[-2:]
# Note that the micro conditionings _do_ flip the order of width, height for the original size
# and the crop coordinates. This is how it was done in the original code base
micro_conds = torch.tensor(
[
width,
height,
micro_conditioning_crop_coord[0],
micro_conditioning_crop_coord[1],
micro_conditioning_aesthetic_score,
],
device=self._execution_device,
dtype=encoder_hidden_states.dtype,
)
micro_conds = micro_conds.unsqueeze(0)
micro_conds = micro_conds.expand(2 * batch_size if guidance_scale > 1.0 else batch_size, -1)
self.scheduler.set_timesteps(num_inference_steps, temperature, self._execution_device)
num_inference_steps = int(len(self.scheduler.timesteps) * strength)
start_timestep_idx = len(self.scheduler.timesteps) - num_inference_steps
needs_upcasting = self.vqvae.dtype == torch.float16 and self.vqvae.config.force_upcast
if needs_upcasting:
self.vqvae.float()
latents = self.vqvae.encode(image.to(dtype=self.vqvae.dtype, device=self._execution_device)).latents
latents_bsz, channels, latents_height, latents_width = latents.shape
latents = self.vqvae.quantize(latents)[2][2].reshape(latents_bsz, latents_height, latents_width)
latents = self.scheduler.add_noise(
latents, self.scheduler.timesteps[start_timestep_idx - 1], generator=generator
)
latents = latents.repeat(num_images_per_prompt, 1, 1)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i in range(start_timestep_idx, len(self.scheduler.timesteps)):
timestep = self.scheduler.timesteps[i]
if guidance_scale > 1.0:
model_input = torch.cat([latents] * 2)
else:
model_input = latents
model_output = self.transformer(
model_input,
micro_conds=micro_conds,
pooled_text_emb=prompt_embeds,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
)
if guidance_scale > 1.0:
uncond_logits, cond_logits = model_output.chunk(2)
model_output = uncond_logits + guidance_scale * (cond_logits - uncond_logits)
latents = self.scheduler.step(
model_output=model_output,
timestep=timestep,
sample=latents,
generator=generator,
).prev_sample
if i == len(self.scheduler.timesteps) - 1 or ((i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, timestep, latents)
if output_type == "latent":
output = latents
else:
output = self.vqvae.decode(
latents,
force_not_quantize=True,
shape=(
batch_size,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
self.vqvae.config.latent_channels,
),
).sample.clip(0, 1)
output = self.image_processor.postprocess(output, output_type)
if needs_upcasting:
self.vqvae.half()
self.maybe_free_model_hooks()
if not return_dict:
return (output,)
return ImagePipelineOutput(output)

View File

@@ -0,0 +1,378 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...models import UVit2DModel, VQModel
from ...schedulers import AmusedScheduler
from ...utils import replace_example_docstring
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import AmusedInpaintPipeline
>>> from diffusers.utils import load_image
>>> pipe = AmusedInpaintPipeline.from_pretrained(
... "huggingface/amused-512", variant="fp16", torch_dtype=torch.float16
... )
>>> pipe = pipe.to("cuda")
>>> prompt = "fall mountains"
>>> input_image = (
... load_image(
... "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1.jpg"
... )
... .resize((512, 512))
... .convert("RGB")
... )
>>> mask = (
... load_image(
... "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1_mask.png"
... )
... .resize((512, 512))
... .convert("L")
... )
>>> pipe(prompt, input_image, mask).images[0].save("out.png")
```
"""
class AmusedInpaintPipeline(DiffusionPipeline):
image_processor: VaeImageProcessor
vqvae: VQModel
tokenizer: CLIPTokenizer
text_encoder: CLIPTextModelWithProjection
transformer: UVit2DModel
scheduler: AmusedScheduler
model_cpu_offload_seq = "text_encoder->transformer->vqvae"
# TODO - when calling self.vqvae.quantize, it uses self.vqvae.quantize.embedding.weight before
# the forward method of self.vqvae.quantize, so the hook doesn't get called to move the parameter
# off the meta device. There should be a way to fix this instead of just not offloading it
_exclude_from_cpu_offload = ["vqvae"]
def __init__(
self,
vqvae: VQModel,
tokenizer: CLIPTokenizer,
text_encoder: CLIPTextModelWithProjection,
transformer: UVit2DModel,
scheduler: AmusedScheduler,
):
super().__init__()
self.register_modules(
vqvae=vqvae,
tokenizer=tokenizer,
text_encoder=text_encoder,
transformer=transformer,
scheduler=scheduler,
)
self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor,
do_normalize=False,
do_binarize=True,
do_convert_grayscale=True,
do_resize=True,
)
self.scheduler.register_to_config(masking_schedule="linear")
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Optional[Union[List[str], str]] = None,
image: PipelineImageInput = None,
mask_image: PipelineImageInput = None,
strength: float = 1.0,
num_inference_steps: int = 12,
guidance_scale: float = 10.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[torch.Generator] = None,
prompt_embeds: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_encoder_hidden_states: Optional[torch.Tensor] = None,
output_type="pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
micro_conditioning_aesthetic_score: int = 6,
micro_conditioning_crop_coord: Tuple[int, int] = (0, 0),
temperature: Union[int, Tuple[int, int], List[int]] = (2, 0),
):
"""
The call function to the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
latents as `image`, but if passing latents directly it is not encoded again.
mask_image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
`Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask
are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a
single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one
color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B,
H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W,
1)`, or `(H, W)`.
strength (`float`, *optional*, defaults to 1.0):
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
starting point and more noise is added the higher the `strength`. The number of denoising steps depends
on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
essentially ignores `image`.
num_inference_steps (`int`, *optional*, defaults to 16):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 10.0):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument. A single vector from the
pooled and projected final hidden states.
encoder_hidden_states (`torch.FloatTensor`, *optional*):
Pre-generated penultimate hidden states from the text encoder providing additional text conditioning.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
negative_encoder_hidden_states (`torch.FloatTensor`, *optional*):
Analogous to `encoder_hidden_states` for the positive prompt.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that calls every `callback_steps` steps during inference. The function is called with the
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function is called. If not specified, the callback is called at
every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
micro_conditioning_aesthetic_score (`int`, *optional*, defaults to 6):
The targeted aesthetic score according to the laion aesthetic classifier. See https://laion.ai/blog/laion-aesthetics/
and the micro-conditioning section of https://arxiv.org/abs/2307.01952.
micro_conditioning_crop_coord (`Tuple[int]`, *optional*, defaults to (0, 0)):
The targeted height, width crop coordinates. See the micro-conditioning section of https://arxiv.org/abs/2307.01952.
temperature (`Union[int, Tuple[int, int], List[int]]`, *optional*, defaults to (2, 0)):
Configures the temperature scheduler on `self.scheduler` see `AmusedScheduler#set_timesteps`.
Examples:
Returns:
[`~pipelines.pipeline_utils.ImagePipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.pipeline_utils.ImagePipelineOutput`] is returned, otherwise a
`tuple` is returned where the first element is a list with the generated images.
"""
if (prompt_embeds is not None and encoder_hidden_states is None) or (
prompt_embeds is None and encoder_hidden_states is not None
):
raise ValueError("pass either both `prompt_embeds` and `encoder_hidden_states` or neither")
if (negative_prompt_embeds is not None and negative_encoder_hidden_states is None) or (
negative_prompt_embeds is None and negative_encoder_hidden_states is not None
):
raise ValueError(
"pass either both `negatve_prompt_embeds` and `negative_encoder_hidden_states` or neither"
)
if (prompt is None and prompt_embeds is None) or (prompt is not None and prompt_embeds is not None):
raise ValueError("pass only one of `prompt` or `prompt_embeds`")
if isinstance(prompt, str):
prompt = [prompt]
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
batch_size = batch_size * num_images_per_prompt
if prompt_embeds is None:
input_ids = self.tokenizer(
prompt,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=self.tokenizer.model_max_length,
).input_ids.to(self._execution_device)
outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
prompt_embeds = outputs.text_embeds
encoder_hidden_states = outputs.hidden_states[-2]
prompt_embeds = prompt_embeds.repeat(num_images_per_prompt, 1)
encoder_hidden_states = encoder_hidden_states.repeat(num_images_per_prompt, 1, 1)
if guidance_scale > 1.0:
if negative_prompt_embeds is None:
if negative_prompt is None:
negative_prompt = [""] * len(prompt)
if isinstance(negative_prompt, str):
negative_prompt = [negative_prompt]
input_ids = self.tokenizer(
negative_prompt,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=self.tokenizer.model_max_length,
).input_ids.to(self._execution_device)
outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
negative_prompt_embeds = outputs.text_embeds
negative_encoder_hidden_states = outputs.hidden_states[-2]
negative_prompt_embeds = negative_prompt_embeds.repeat(num_images_per_prompt, 1)
negative_encoder_hidden_states = negative_encoder_hidden_states.repeat(num_images_per_prompt, 1, 1)
prompt_embeds = torch.concat([negative_prompt_embeds, prompt_embeds])
encoder_hidden_states = torch.concat([negative_encoder_hidden_states, encoder_hidden_states])
image = self.image_processor.preprocess(image)
height, width = image.shape[-2:]
# Note that the micro conditionings _do_ flip the order of width, height for the original size
# and the crop coordinates. This is how it was done in the original code base
micro_conds = torch.tensor(
[
width,
height,
micro_conditioning_crop_coord[0],
micro_conditioning_crop_coord[1],
micro_conditioning_aesthetic_score,
],
device=self._execution_device,
dtype=encoder_hidden_states.dtype,
)
micro_conds = micro_conds.unsqueeze(0)
micro_conds = micro_conds.expand(2 * batch_size if guidance_scale > 1.0 else batch_size, -1)
self.scheduler.set_timesteps(num_inference_steps, temperature, self._execution_device)
num_inference_steps = int(len(self.scheduler.timesteps) * strength)
start_timestep_idx = len(self.scheduler.timesteps) - num_inference_steps
needs_upcasting = self.vqvae.dtype == torch.float16 and self.vqvae.config.force_upcast
if needs_upcasting:
self.vqvae.float()
latents = self.vqvae.encode(image.to(dtype=self.vqvae.dtype, device=self._execution_device)).latents
latents_bsz, channels, latents_height, latents_width = latents.shape
latents = self.vqvae.quantize(latents)[2][2].reshape(latents_bsz, latents_height, latents_width)
mask = self.mask_processor.preprocess(
mask_image, height // self.vae_scale_factor, width // self.vae_scale_factor
)
mask = mask.reshape(mask.shape[0], latents_height, latents_width).bool().to(latents.device)
latents[mask] = self.scheduler.config.mask_token_id
starting_mask_ratio = mask.sum() / latents.numel()
latents = latents.repeat(num_images_per_prompt, 1, 1)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i in range(start_timestep_idx, len(self.scheduler.timesteps)):
timestep = self.scheduler.timesteps[i]
if guidance_scale > 1.0:
model_input = torch.cat([latents] * 2)
else:
model_input = latents
model_output = self.transformer(
model_input,
micro_conds=micro_conds,
pooled_text_emb=prompt_embeds,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
)
if guidance_scale > 1.0:
uncond_logits, cond_logits = model_output.chunk(2)
model_output = uncond_logits + guidance_scale * (cond_logits - uncond_logits)
latents = self.scheduler.step(
model_output=model_output,
timestep=timestep,
sample=latents,
generator=generator,
starting_mask_ratio=starting_mask_ratio,
).prev_sample
if i == len(self.scheduler.timesteps) - 1 or ((i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, timestep, latents)
if output_type == "latent":
output = latents
else:
output = self.vqvae.decode(
latents,
force_not_quantize=True,
shape=(
batch_size,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
self.vqvae.config.latent_channels,
),
).sample.clip(0, 1)
output = self.image_processor.postprocess(output, output_type)
if needs_upcasting:
self.vqvae.half()
self.maybe_free_model_hooks()
if not return_dict:
return (output,)
return ImagePipelineOutput(output)

View File

@@ -106,7 +106,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
"""
model_cpu_offload_seq = "text_encoder->unet->vae"
model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
_optional_components = ["feature_extractor", "image_encoder"]
def __init__(

View File

@@ -176,7 +176,7 @@ class StableDiffusionControlNetPipeline(
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
model_cpu_offload_seq = "text_encoder->unet->vae"
model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
_exclude_from_cpu_offload = ["safety_checker"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
@@ -633,7 +633,7 @@ class StableDiffusionControlNetPipeline(
# When `image` is a nested list:
# (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
elif any(isinstance(i, list) for i in image):
raise ValueError("A single batch of multiple conditionings are supported at the moment.")
raise ValueError("A single batch of multiple conditionings is not supported at the moment.")
elif len(image) != len(self.controlnet.nets):
raise ValueError(
f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
@@ -659,7 +659,7 @@ class StableDiffusionControlNetPipeline(
):
if isinstance(controlnet_conditioning_scale, list):
if any(isinstance(i, list) for i in controlnet_conditioning_scale):
raise ValueError("A single batch of multiple conditionings are supported at the moment.")
raise ValueError("A single batch of multiple conditionings is not supported at the moment.")
elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
self.controlnet.nets
):

View File

@@ -291,7 +291,7 @@ class StableDiffusionControlNetInpaintPipeline(
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
model_cpu_offload_seq = "text_encoder->unet->vae"
model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
_exclude_from_cpu_offload = ["safety_checker"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]

View File

@@ -165,7 +165,7 @@ class StableDiffusionXLControlNetPipeline(
"""
# leave controlnet out on purpose because it iterates with unet
model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
_optional_components = [
"tokenizer",
"tokenizer_2",

View File

@@ -0,0 +1,3 @@
# Deprecated Pipelines
This folder contains pipelines that have very low usage as measured by model downloads, issues and PRs. While you can still use the pipelines just as before, we will stop testing the pipelines and will not accept any changes to existing files.

View File

@@ -0,0 +1,153 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_librosa_available,
is_note_seq_available,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_pt_objects
_dummy_objects.update(get_objects_from_module(dummy_pt_objects))
else:
_import_structure["latent_diffusion_uncond"] = ["LDMPipeline"]
_import_structure["pndm"] = ["PNDMPipeline"]
_import_structure["repaint"] = ["RePaintPipeline"]
_import_structure["score_sde_ve"] = ["ScoreSdeVePipeline"]
_import_structure["stochastic_karras_ve"] = ["KarrasVePipeline"]
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["alt_diffusion"] = [
"AltDiffusionImg2ImgPipeline",
"AltDiffusionPipeline",
"AltDiffusionPipelineOutput",
]
_import_structure["versatile_diffusion"] = [
"VersatileDiffusionDualGuidedPipeline",
"VersatileDiffusionImageVariationPipeline",
"VersatileDiffusionPipeline",
"VersatileDiffusionTextToImagePipeline",
]
_import_structure["vq_diffusion"] = ["VQDiffusionPipeline"]
_import_structure["stable_diffusion_variants"] = [
"CycleDiffusionPipeline",
"StableDiffusionInpaintPipelineLegacy",
"StableDiffusionPix2PixZeroPipeline",
"StableDiffusionParadigmsPipeline",
"StableDiffusionModelEditingPipeline",
]
try:
if not (is_torch_available() and is_librosa_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_librosa_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_librosa_objects))
else:
_import_structure["audio_diffusion"] = ["AudioDiffusionPipeline", "Mel"]
try:
if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_transformers_and_torch_and_note_seq_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_transformers_and_torch_and_note_seq_objects))
else:
_import_structure["spectrogram_diffusion"] = ["MidiProcessor", "SpectrogramDiffusionPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_pt_objects import *
else:
from .latent_diffusion_uncond import LDMPipeline
from .pndm import PNDMPipeline
from .repaint import RePaintPipeline
from .score_sde_ve import ScoreSdeVePipeline
from .stochastic_karras_ve import KarrasVePipeline
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline, AltDiffusionPipelineOutput
from .audio_diffusion import AudioDiffusionPipeline, Mel
from .spectrogram_diffusion import SpectrogramDiffusionPipeline
from .stable_diffusion_variants import (
CycleDiffusionPipeline,
StableDiffusionInpaintPipelineLegacy,
StableDiffusionModelEditingPipeline,
StableDiffusionParadigmsPipeline,
StableDiffusionPix2PixZeroPipeline,
)
from .stochastic_karras_ve import KarrasVePipeline
from .versatile_diffusion import (
VersatileDiffusionDualGuidedPipeline,
VersatileDiffusionImageVariationPipeline,
VersatileDiffusionPipeline,
VersatileDiffusionTextToImagePipeline,
)
from .vq_diffusion import VQDiffusionPipeline
try:
if not (is_torch_available() and is_librosa_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_librosa_objects import *
else:
from .audio_diffusion import AudioDiffusionPipeline, Mel
try:
if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403
else:
from .spectrogram_diffusion import (
MidiProcessor,
SpectrogramDiffusionPipeline,
)
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)

View File

@@ -1,6 +1,6 @@
from typing import TYPE_CHECKING
from ...utils import (
from ....utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
@@ -17,7 +17,7 @@ try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects
from ....utils import dummy_torch_and_transformers_objects
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
@@ -32,7 +32,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
from ....utils.dummy_torch_and_transformers_objects import *
else:
from .modeling_roberta_series import RobertaSeriesModelWithTransformation

View File

@@ -19,13 +19,14 @@ import torch
from packaging import version
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, XLMRobertaTokenizer
from ...configuration_utils import FrozenDict
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
from ....configuration_utils import FrozenDict
from ....image_processor import PipelineImageInput, VaeImageProcessor
from ....loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ....models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from ....models.attention_processor import FusedAttnProcessor2_0
from ....models.lora import adjust_lora_scale_text_encoder
from ....schedulers import KarrasDiffusionSchedulers
from ....utils import (
USE_PEFT_BACKEND,
deprecate,
logging,
@@ -33,9 +34,9 @@ from ...utils import (
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from ....utils.torch_utils import randn_tensor
from ...pipeline_utils import DiffusionPipeline
from ...stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from .modeling_roberta_series import RobertaSeriesModelWithTransformation
from .pipeline_output import AltDiffusionPipelineOutput
@@ -118,7 +119,6 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker
class AltDiffusionPipeline(
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin
):
@@ -155,7 +155,7 @@ class AltDiffusionPipeline(
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
model_cpu_offload_seq = "text_encoder->unet->vae"
model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
_exclude_from_cpu_offload = ["safety_checker"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
@@ -655,6 +655,65 @@ class AltDiffusionPipeline(
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
Args:
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
"""
self.fusing_unet = False
self.fusing_vae = False
if unet:
self.fusing_unet = True
self.unet.fuse_qkv_projections()
self.unet.set_attn_processor(FusedAttnProcessor2_0())
if vae:
if not isinstance(self.vae, AutoencoderKL):
raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")
self.fusing_vae = True
self.vae.fuse_qkv_projections()
self.vae.set_attn_processor(FusedAttnProcessor2_0())
def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
"""Disable QKV projection fusion if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
Args:
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
"""
if unet:
if not self.fusing_unet:
logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
else:
self.unet.unfuse_qkv_projections()
self.fusing_unet = False
if vae:
if not self.fusing_vae:
logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
else:
self.vae.unfuse_qkv_projections()
self.fusing_vae = False
def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
"""
See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

View File

@@ -21,13 +21,14 @@ import torch
from packaging import version
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, XLMRobertaTokenizer
from ...configuration_utils import FrozenDict
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
from ....configuration_utils import FrozenDict
from ....image_processor import PipelineImageInput, VaeImageProcessor
from ....loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ....models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from ....models.attention_processor import FusedAttnProcessor2_0
from ....models.lora import adjust_lora_scale_text_encoder
from ....schedulers import KarrasDiffusionSchedulers
from ....utils import (
PIL_INTERPOLATION,
USE_PEFT_BACKEND,
deprecate,
@@ -36,9 +37,9 @@ from ...utils import (
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from ....utils.torch_utils import randn_tensor
from ...pipeline_utils import DiffusionPipeline
from ...stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from .modeling_roberta_series import RobertaSeriesModelWithTransformation
from .pipeline_output import AltDiffusionPipelineOutput
@@ -158,7 +159,6 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker
class AltDiffusionImg2ImgPipeline(
DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FromSingleFileMixin
):
@@ -195,7 +195,7 @@ class AltDiffusionImg2ImgPipeline(
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
model_cpu_offload_seq = "text_encoder->unet->vae"
model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
_exclude_from_cpu_offload = ["safety_checker"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
@@ -715,6 +715,65 @@ class AltDiffusionImg2ImgPipeline(
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
Args:
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
"""
self.fusing_unet = False
self.fusing_vae = False
if unet:
self.fusing_unet = True
self.unet.fuse_qkv_projections()
self.unet.set_attn_processor(FusedAttnProcessor2_0())
if vae:
if not isinstance(self.vae, AutoencoderKL):
raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")
self.fusing_vae = True
self.vae.fuse_qkv_projections()
self.vae.set_attn_processor(FusedAttnProcessor2_0())
def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
"""Disable QKV projection fusion if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
Args:
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
"""
if unet:
if not self.fusing_unet:
logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
else:
self.unet.unfuse_qkv_projections()
self.fusing_unet = False
if vae:
if not self.fusing_vae:
logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
else:
self.vae.unfuse_qkv_projections()
self.fusing_vae = False
def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
"""
See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

View File

@@ -4,7 +4,7 @@ from typing import List, Optional, Union
import numpy as np
import PIL.Image
from ...utils import (
from ....utils import (
BaseOutput,
)

View File

@@ -1,6 +1,6 @@
from typing import TYPE_CHECKING
from ...utils import DIFFUSERS_SLOW_IMPORT, _LazyModule
from ....utils import DIFFUSERS_SLOW_IMPORT, _LazyModule
_import_structure = {

View File

@@ -15,8 +15,8 @@
import numpy as np # noqa: E402
from ...configuration_utils import ConfigMixin, register_to_config
from ...schedulers.scheduling_utils import SchedulerMixin
from ....configuration_utils import ConfigMixin, register_to_config
from ....schedulers.scheduling_utils import SchedulerMixin
try:

View File

@@ -20,10 +20,10 @@ import numpy as np
import torch
from PIL import Image
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import DDIMScheduler, DDPMScheduler
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import AudioPipelineOutput, BaseOutput, DiffusionPipeline, ImagePipelineOutput
from ....models import AutoencoderKL, UNet2DConditionModel
from ....schedulers import DDIMScheduler, DDPMScheduler
from ....utils.torch_utils import randn_tensor
from ...pipeline_utils import AudioPipelineOutput, BaseOutput, DiffusionPipeline, ImagePipelineOutput
from .mel import Mel

View File

@@ -1,6 +1,6 @@
from typing import TYPE_CHECKING
from ...utils import DIFFUSERS_SLOW_IMPORT, _LazyModule
from ....utils import DIFFUSERS_SLOW_IMPORT, _LazyModule
_import_structure = {"pipeline_latent_diffusion_uncond": ["LDMPipeline"]}

View File

@@ -17,10 +17,10 @@ from typing import List, Optional, Tuple, Union
import torch
from ...models import UNet2DModel, VQModel
from ...schedulers import DDIMScheduler
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ....models import UNet2DModel, VQModel
from ....schedulers import DDIMScheduler
from ....utils.torch_utils import randn_tensor
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
class LDMPipeline(DiffusionPipeline):

View File

@@ -1,6 +1,6 @@
from typing import TYPE_CHECKING
from ...utils import DIFFUSERS_SLOW_IMPORT, _LazyModule
from ....utils import DIFFUSERS_SLOW_IMPORT, _LazyModule
_import_structure = {"pipeline_pndm": ["PNDMPipeline"]}

View File

@@ -17,10 +17,10 @@ from typing import List, Optional, Tuple, Union
import torch
from ...models import UNet2DModel
from ...schedulers import PNDMScheduler
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ....models import UNet2DModel
from ....schedulers import PNDMScheduler
from ....utils.torch_utils import randn_tensor
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
class PNDMPipeline(DiffusionPipeline):

Some files were not shown because too many files have changed in this diff Show More