mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-06 12:34:13 +08:00
[Flux] Add advanced training script + support textual inversion inference (#9434)
* add ostris trainer to README & add cache latents of vae
* add ostris trainer to README & add cache latents of vae
* style
* readme
* add test for latent caching
* add ostris noise scheduler
9ee1ef2a0a/toolkit/samplers/custom_flowmatch_sampler.py (L95)
* style
* fix import
* style
* fix tests
* style
* --change upcasting of transformer?
* update readme according to main
* add pivotal tuning for CLIP
* fix imports, encode_prompt call,add TextualInversionLoaderMixin to FluxPipeline for inference
* TextualInversionLoaderMixin support for FluxPipeline for inference
* move changes to advanced flux script, revert canonical
* add latent caching to canonical script
* revert changes to canonical script to keep it separate from https://github.com/huggingface/diffusers/pull/9160
* revert changes to canonical script to keep it separate from https://github.com/huggingface/diffusers/pull/9160
* style
* remove redundant line and change code block placement to align with logic
* add initializer_token arg
* add transformer frac for range support from pure textual inversion to the orig pivotal tuning
* support pure textual inversion - wip
* adjustments to support pure textual inversion and transformer optimization in only part of the epochs
* fix logic when using initializer token
* fix pure_textual_inversion_condition
* fix ti/pivotal loading of last validation run
* remove embeddings loading for ti in final training run (to avoid adding huggingface hub dependency)
* support pivotal for t5
* adapt pivotal for T5 encoder
* adapt pivotal for T5 encoder and support in flux pipeline
* t5 pivotal support + support fo pivotal for clip only or both
* fix param chaining
* fix param chaining
* README first draft
* readme
* readme
* readme
* style
* fix import
* style
* add fix from https://github.com/huggingface/diffusers/pull/9419
* add to readme, change function names
* te lr changes
* readme
* change concept tokens logic
* fix indices
* change arg name
* style
* dummy test
* revert dummy test
* reorder pivoting
* add warning in case the token abstraction is not the instance prompt
* experimental - wip - specific block training
* fix documentation and token abstraction processing
* remove transformer block specification feature (for now)
* style
* fix copies
* fix indexing issue when --initializer_concept has different amounts
* add if TextualInversionLoaderMixin to all flux pipelines
* style
* fix import
* fix imports
* address review comments - remove necessary prints & comments, use pin_memory=True, use free_memory utils, unify warning and prints
* style
* logger info fix
* make lora target modules configurable and change the default
* make lora target modules configurable and change the default
* style
* make lora target modules configurable and change the default, add notes to readme
* style
* add tests
* style
* fix repo id
* add updated requirements for advanced flux
* fix indices of t5 pivotal tuning embeddings
* fix path in test
* remove `pin_memory`
* fix filename of embedding
* fix filename of embedding
---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
This commit is contained in:
353
examples/advanced_diffusion_training/README_flux.md
Normal file
353
examples/advanced_diffusion_training/README_flux.md
Normal file
@@ -0,0 +1,353 @@
|
||||
# Advanced diffusion training examples
|
||||
|
||||
## Train Dreambooth LoRA with Flux.1 Dev
|
||||
> [!TIP]
|
||||
> 💡 This example follows some of the techniques and recommended practices covered in the community derived guide we made for SDXL training: [LoRA training scripts of the world, unite!](https://huggingface.co/blog/sdxl_lora_advanced_script).
|
||||
> As many of these are architecture agnostic & generally relevant to fine-tuning of diffusion models we suggest to take a look 🤗
|
||||
|
||||
[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text-to-image models like flux, stable diffusion given just a few(3~5) images of a subject.
|
||||
|
||||
LoRA - Low-Rank Adaption of Large Language Models, was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*
|
||||
In a nutshell, LoRA allows to adapt pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages:
|
||||
- Previous pretrained weights are kept frozen so that the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114)
|
||||
- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.
|
||||
- LoRA attention layers allow to control to which extent the model is adapted towards new training images via a `scale` parameter.
|
||||
[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in
|
||||
the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository.
|
||||
|
||||
The `train_dreambooth_lora_flux_advanced.py` script shows how to implement dreambooth-LoRA, combining the training process shown in `train_dreambooth_lora_flux.py`, with
|
||||
advanced features and techniques, inspired and built upon contributions by [Nataniel Ruiz](https://twitter.com/natanielruizg): [Dreambooth](https://dreambooth.github.io), [Rinon Gal](https://twitter.com/RinonGal): [Textual Inversion](https://textual-inversion.github.io), [Ron Mokady](https://twitter.com/MokadyRon): [Pivotal Tuning](https://arxiv.org/abs/2106.05744), [Simo Ryu](https://twitter.com/cloneofsimo): [cog-sdxl](https://github.com/replicate/cog-sdxl),
|
||||
[ostris](https://x.com/ostrisai):[ai-toolkit](https://github.com/ostris/ai-toolkit), [bghira](https://github.com/bghira):[SimpleTuner](https://github.com/bghira/SimpleTuner), [Kohya](https://twitter.com/kohya_tech/): [sd-scripts](https://github.com/kohya-ss/sd-scripts), [The Last Ben](https://twitter.com/__TheBen): [fast-stable-diffusion](https://github.com/TheLastBen/fast-stable-diffusion) ❤️
|
||||
|
||||
> [!NOTE]
|
||||
> 💡If this is your first time training a Dreambooth LoRA, congrats!🥳
|
||||
> You might want to familiarize yourself more with the techniques: [Dreambooth blog](https://huggingface.co/blog/dreambooth), [Using LoRA for Efficient Stable Diffusion Fine-Tuning blog](https://huggingface.co/blog/lora)
|
||||
|
||||
## Running locally with PyTorch
|
||||
|
||||
### Installing the dependencies
|
||||
|
||||
Before running the scripts, make sure to install the library's training dependencies:
|
||||
|
||||
**Important**
|
||||
|
||||
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
|
||||
```bash
|
||||
git clone https://github.com/huggingface/diffusers
|
||||
cd diffusers
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
Then cd in the `examples/advanced_diffusion_training` folder and run
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
Or for a default accelerate configuration without answering questions about your environment
|
||||
|
||||
```bash
|
||||
accelerate config default
|
||||
```
|
||||
|
||||
Or if your environment doesn't support an interactive shell e.g. a notebook
|
||||
|
||||
```python
|
||||
from accelerate.utils import write_basic_config
|
||||
write_basic_config()
|
||||
```
|
||||
|
||||
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
|
||||
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
|
||||
|
||||
### Target Modules
|
||||
When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the Unet that relate the image representations with the prompts that describe them.
|
||||
More recently, SOTA text-to-image diffusion models replaced the Unet with a diffusion Transformer(DiT). With this change, we may also want to explore
|
||||
applying LoRA training onto different types of layers and blocks. To allow more flexibility and control over the targeted modules we added `--lora_layers`- in which you can specify in a comma seperated string
|
||||
the exact modules for LoRA training. Here are some examples of target modules you can provide:
|
||||
- for attention only layers: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0"`
|
||||
- to train the same modules as in the fal trainer: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2"`
|
||||
- to train the same modules as in ostris ai-toolkit / replicate trainer: `--lora_blocks="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2,norm1_context.linear, norm1.linear,norm.linear,proj_mlp,proj_out"`
|
||||
> [!NOTE]
|
||||
> `--lora_layers` can also be used to specify which **blocks** to apply LoRA training to. To do so, simply add a block prefix to each layer in the comma seperated string:
|
||||
> **single DiT blocks**: to target the ith single transformer block, add the prefix `single_transformer_blocks.i`, e.g. - `single_transformer_blocks.i.attn.to_k`
|
||||
> **MMDiT blocks**: to target the ith MMDiT block, add the prefix `transformer_blocks.i`, e.g. - `transformer_blocks.i.attn.to_k`
|
||||
> [!NOTE]
|
||||
> keep in mind that while training more layers can improve quality and expressiveness, it also increases the size of the output LoRA weights.
|
||||
|
||||
### Pivotal Tuning (and more)
|
||||
**Training with text encoder(s)**
|
||||
|
||||
Alongside the Transformer, LoRA fine-tuning of the text encoders is also supported. In addition to the text encoder optimization
|
||||
available with `train_dreambooth_lora_flux_advanced.py`, in the advanced script **pivotal tuning** is also supported.
|
||||
[pivotal tuning](https://huggingface.co/blog/sdxl_lora_advanced_script#pivotal-tuning) combines Textual Inversion with regular diffusion fine-tuning -
|
||||
we insert new tokens into the text encoders of the model, instead of reusing existing ones.
|
||||
We then optimize the newly-inserted token embeddings to represent the new concept.
|
||||
|
||||
To do so, just specify `--train_text_encoder_ti` while launching training (for regular text encoder optimizations, use `--train_text_encoder`).
|
||||
Please keep the following points in mind:
|
||||
|
||||
* Flux uses two text encoders - [CLIP](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux#diffusers.FluxPipeline.text_encoder) & [T5](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux#diffusers.FluxPipeline.text_encoder_2) , by default `--train_text_encoder_ti` performs pivotal tuning for the **CLIP** encoder only.
|
||||
To activate pivotal tuning for both encoders, add the flag `--enable_t5_ti`.
|
||||
* When not fine-tuning the text encoders, we ALWAYS precompute the text embeddings to save memory.
|
||||
* **pure textual inversion** - to support the full range from pivotal tuning to textual inversion we introduce `--train_transformer_frac` which controls the amount of epochs the transformer LoRA layers are trained. By default, `--train_transformer_frac==1`, to trigger a textual inversion run set `--train_transformer_frac==0`. Values between 0 and 1 are supported as well, and we welcome the community to experiment w/ different settings and share the results!
|
||||
* **token initializer** - similar to the original textual inversion work, you can specify a token of your choosing as the starting point for training. By default, when enabling `--train_text_encoder_ti`, the new inserted tokens are initialized randomly. You can specify a token in `--initializer_token` such that the starting point for the trained embeddings will be the embeddings associated with your chosen `--initializer_token`.
|
||||
|
||||
## Training examples
|
||||
|
||||
Now let's get our dataset. For this example we will use some cool images of 3d rendered icons: https://huggingface.co/datasets/linoyts/3d_icon.
|
||||
|
||||
Let's first download it locally:
|
||||
|
||||
```python
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
local_dir = "./3d_icon"
|
||||
snapshot_download(
|
||||
"LinoyTsaban/3d_icon",
|
||||
local_dir=local_dir, repo_type="dataset",
|
||||
ignore_patterns=".gitattributes",
|
||||
)
|
||||
```
|
||||
|
||||
Let's review some of the advanced features we're going to be using for this example:
|
||||
- **custom captions**:
|
||||
To use custom captioning, first ensure that you have the datasets library installed, otherwise you can install it by
|
||||
```bash
|
||||
pip install datasets
|
||||
```
|
||||
|
||||
Now we'll simply specify the name of the dataset and caption column (in this case it's "prompt")
|
||||
|
||||
```
|
||||
--dataset_name=./3d_icon
|
||||
--caption_column=prompt
|
||||
```
|
||||
|
||||
You can also load a dataset straight from by specifying it's name in `dataset_name`.
|
||||
Look [here](https://huggingface.co/blog/sdxl_lora_advanced_script#custom-captioning) for more info on creating/loadin your own caption dataset.
|
||||
|
||||
- **optimizer**: for this example, we'll use [prodigy](https://huggingface.co/blog/sdxl_lora_advanced_script#adaptive-optimizers) - an adaptive optimizer
|
||||
- **pivotal tuning**
|
||||
|
||||
### Example #1: Pivotal tuning
|
||||
**Now, we can launch training:**
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="black-forest-labs/FLUX.1-dev"
|
||||
export DATASET_NAME="./3d_icon"
|
||||
export OUTPUT_DIR="3d-icon-Flux-LoRA"
|
||||
|
||||
accelerate launch train_dreambooth_lora_flux_advanced.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--dataset_name=$DATASET_NAME \
|
||||
--instance_prompt="3d icon in the style of TOK" \
|
||||
--validation_prompt="a TOK icon of an astronaut riding a horse, in the style of TOK" \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--caption_column="prompt" \
|
||||
--mixed_precision="bf16" \
|
||||
--resolution=1024 \
|
||||
--train_batch_size=1 \
|
||||
--repeats=1 \
|
||||
--report_to="wandb"\
|
||||
--gradient_accumulation_steps=1 \
|
||||
--gradient_checkpointing \
|
||||
--learning_rate=1.0 \
|
||||
--text_encoder_lr=1.0 \
|
||||
--optimizer="prodigy"\
|
||||
--train_text_encoder_ti\
|
||||
--train_text_encoder_ti_frac=0.5\
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--rank=8 \
|
||||
--max_train_steps=1000 \
|
||||
--checkpointing_steps=2000 \
|
||||
--seed="0" \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
To better track our training experiments, we're using the following flags in the command above:
|
||||
|
||||
* `report_to="wandb` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
|
||||
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
|
||||
|
||||
Our experiments were conducted on a single 40GB A100 GPU.
|
||||
|
||||
### Example #2: Pivotal tuning with T5
|
||||
Now let's try that with T5 as well, so instead of only optimizing the CLIP embeddings associated with newly inserted tokens, we'll optimize
|
||||
the T5 embeddings as well. We can do this by simply adding `--enable_t5_ti` to the previous configuration:
|
||||
```bash
|
||||
export MODEL_NAME="black-forest-labs/FLUX.1-dev"
|
||||
export DATASET_NAME="./3d_icon"
|
||||
export OUTPUT_DIR="3d-icon-Flux-LoRA"
|
||||
|
||||
accelerate launch train_dreambooth_lora_flux_advanced.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--dataset_name=$DATASET_NAME \
|
||||
--instance_prompt="3d icon in the style of TOK" \
|
||||
--validation_prompt="a TOK icon of an astronaut riding a horse, in the style of TOK" \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--caption_column="prompt" \
|
||||
--mixed_precision="bf16" \
|
||||
--resolution=1024 \
|
||||
--train_batch_size=1 \
|
||||
--repeats=1 \
|
||||
--report_to="wandb"\
|
||||
--gradient_accumulation_steps=1 \
|
||||
--gradient_checkpointing \
|
||||
--learning_rate=1.0 \
|
||||
--text_encoder_lr=1.0 \
|
||||
--optimizer="prodigy"\
|
||||
--train_text_encoder_ti\
|
||||
--enable_t5_ti\
|
||||
--train_text_encoder_ti_frac=0.5\
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--rank=8 \
|
||||
--max_train_steps=1000 \
|
||||
--checkpointing_steps=2000 \
|
||||
--seed="0" \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
### Example #3: Textual Inversion
|
||||
To explore a pure textual inversion - i.e. only optimizing the text embeddings w/o training transformer LoRA layers, we
|
||||
can set the value for `--train_transformer_frac` - which is responsible for the percent of epochs in which the transformer is
|
||||
trained. By setting `--train_transformer_frac == 0` and enabling `--train_text_encoder_ti` we trigger a textual inversion train
|
||||
run.
|
||||
```bash
|
||||
export MODEL_NAME="black-forest-labs/FLUX.1-dev"
|
||||
export DATASET_NAME="./3d_icon"
|
||||
export OUTPUT_DIR="3d-icon-Flux-LoRA"
|
||||
|
||||
accelerate launch train_dreambooth_lora_flux_advanced.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--dataset_name=$DATASET_NAME \
|
||||
--instance_prompt="3d icon in the style of TOK" \
|
||||
--validation_prompt="a TOK icon of an astronaut riding a horse, in the style of TOK" \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--caption_column="prompt" \
|
||||
--mixed_precision="bf16" \
|
||||
--resolution=1024 \
|
||||
--train_batch_size=1 \
|
||||
--repeats=1 \
|
||||
--report_to="wandb"\
|
||||
--gradient_accumulation_steps=1 \
|
||||
--gradient_checkpointing \
|
||||
--learning_rate=1.0 \
|
||||
--text_encoder_lr=1.0 \
|
||||
--optimizer="prodigy"\
|
||||
--train_text_encoder_ti\
|
||||
--enable_t5_ti\
|
||||
--train_text_encoder_ti_frac=0.5\
|
||||
--train_transformer_frac=0\
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--rank=8 \
|
||||
--max_train_steps=1000 \
|
||||
--checkpointing_steps=2000 \
|
||||
--seed="0" \
|
||||
--push_to_hub
|
||||
```
|
||||
### Inference - pivotal tuning
|
||||
|
||||
Once training is done, we can perform inference like so:
|
||||
1. starting with loading the transformer lora weights
|
||||
```python
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download, upload_file
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
from safetensors.torch import load_file
|
||||
|
||||
username = "linoyts"
|
||||
repo_id = f"{username}/3d-icon-Flux-LoRA"
|
||||
|
||||
pipe = AutoPipelineForText2Image.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to('cuda')
|
||||
|
||||
|
||||
pipe.load_lora_weights(repo_id, weight_name="pytorch_lora_weights.safetensors")
|
||||
```
|
||||
2. now we load the pivotal tuning embeddings
|
||||
💡note that if you didn't enable `--enable_t5_ti`, you only load the embeddings to the CLIP encoder
|
||||
|
||||
```python
|
||||
text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
|
||||
tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
|
||||
|
||||
embedding_path = hf_hub_download(repo_id=repo_id, filename="3d-icon-Flux-LoRA_emb.safetensors", repo_type="model")
|
||||
|
||||
state_dict = load_file(embedding_path)
|
||||
# load embeddings of text_encoder 1 (CLIP ViT-L/14)
|
||||
pipe.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
|
||||
# load embeddings of text_encoder 2 (T5 XXL) - ignore this line if you didn't enable `--enable_t5_ti`
|
||||
pipe.load_textual_inversion(state_dict["t5"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
|
||||
```
|
||||
|
||||
3. let's generate images
|
||||
|
||||
```python
|
||||
instance_token = "<s0><s1>"
|
||||
prompt = f"a {instance_token} icon of an orange llama eating ramen, in the style of {instance_token}"
|
||||
|
||||
image = pipe(prompt=prompt, num_inference_steps=25, cross_attention_kwargs={"scale": 1.0}).images[0]
|
||||
image.save("llama.png")
|
||||
```
|
||||
|
||||
### Inference - pure textual inversion
|
||||
In this case, we don't load transformer layers as before, since we only optimize the text embeddings. The output of a textual inversion train run is a
|
||||
`.safetensors` file containing the trained embeddings for the new tokens either for the CLIP encoder, or for both encoders (CLIP and T5)
|
||||
|
||||
1. starting with loading the embeddings.
|
||||
💡note that here too, if you didn't enable `--enable_t5_ti`, you only load the embeddings to the CLIP encoder
|
||||
|
||||
```python
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download, upload_file
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
from safetensors.torch import load_file
|
||||
|
||||
username = "linoyts"
|
||||
repo_id = f"{username}/3d-icon-Flux-LoRA"
|
||||
|
||||
pipe = AutoPipelineForText2Image.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to('cuda')
|
||||
|
||||
text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
|
||||
tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
|
||||
|
||||
embedding_path = hf_hub_download(repo_id=repo_id, filename="3d-icon-Flux-LoRA_emb.safetensors", repo_type="model")
|
||||
|
||||
state_dict = load_file(embedding_path)
|
||||
# load embeddings of text_encoder 1 (CLIP ViT-L/14)
|
||||
pipe.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
|
||||
# load embeddings of text_encoder 2 (T5 XXL) - ignore this line if you didn't enable `--enable_t5_ti`
|
||||
pipe.load_textual_inversion(state_dict["t5"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
|
||||
```
|
||||
2. let's generate images
|
||||
|
||||
```python
|
||||
instance_token = "<s0><s1>"
|
||||
prompt = f"a {instance_token} icon of an orange llama eating ramen, in the style of {instance_token}"
|
||||
|
||||
image = pipe(prompt=prompt, num_inference_steps=25, cross_attention_kwargs={"scale": 1.0}).images[0]
|
||||
image.save("llama.png")
|
||||
```
|
||||
|
||||
### Comfy UI / AUTOMATIC1111 Inference
|
||||
The new script fully supports textual inversion loading with Comfy UI and AUTOMATIC1111 formats!
|
||||
|
||||
**AUTOMATIC1111 / SD.Next** \
|
||||
In AUTOMATIC1111/SD.Next we will load a LoRA and a textual embedding at the same time.
|
||||
- *LoRA*: Besides the diffusers format, the script will also train a WebUI compatible LoRA. It is generated as `{your_lora_name}.safetensors`. You can then include it in your `models/Lora` directory.
|
||||
- *Embedding*: the embedding is the same for diffusers and WebUI. You can download your `{lora_name}_emb.safetensors` file from a trained model, and include it in your `embeddings` directory.
|
||||
|
||||
You can then run inference by prompting `a y2k_emb webpage about the movie Mean Girls <lora:y2k:0.9>`. You can use the `y2k_emb` token normally, including increasing its weight by doing `(y2k_emb:1.2)`.
|
||||
|
||||
**ComfyUI** \
|
||||
In ComfyUI we will load a LoRA and a textual embedding at the same time.
|
||||
- *LoRA*: Besides the diffusers format, the script will also train a ComfyUI compatible LoRA. It is generated as `{your_lora_name}.safetensors`. You can then include it in your `models/Lora` directory. Then you will load the LoRALoader node and hook that up with your model and CLIP. [Official guide for loading LoRAs](https://comfyanonymous.github.io/ComfyUI_examples/lora/)
|
||||
- *Embedding*: the embedding is the same for diffusers and WebUI. You can download your `{lora_name}_emb.safetensors` file from a trained model, and include it in your `models/embeddings` directory and use it in your prompts like `embedding:y2k_emb`. [Official guide for loading embeddings](https://comfyanonymous.github.io/ComfyUI_examples/textual_inversion_embeddings/).
|
||||
@@ -0,0 +1,8 @@
|
||||
accelerate>=0.31.0
|
||||
torchvision
|
||||
transformers>=4.41.2
|
||||
ftfy
|
||||
tensorboard
|
||||
Jinja2
|
||||
peft>=0.11.1
|
||||
sentencepiece
|
||||
@@ -0,0 +1,283 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
import safetensors
|
||||
|
||||
|
||||
sys.path.append("..")
|
||||
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
logger = logging.getLogger()
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
|
||||
class DreamBoothLoRAFluxAdvanced(ExamplesTestsAccelerate):
|
||||
instance_data_dir = "docs/source/en/imgs"
|
||||
instance_prompt = "photo"
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe"
|
||||
script_path = "examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py"
|
||||
|
||||
def test_dreambooth_lora_flux(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# when not training the text encoder, all the parameters in the state dict should start
|
||||
# with `"transformer"` in their names.
|
||||
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
|
||||
self.assertTrue(starts_with_transformer)
|
||||
|
||||
def test_dreambooth_lora_text_encoder_flux(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--train_text_encoder
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
starts_with_expected_prefix = all(
|
||||
(key.startswith("transformer") or key.startswith("text_encoder")) for key in lora_state_dict.keys()
|
||||
)
|
||||
self.assertTrue(starts_with_expected_prefix)
|
||||
|
||||
def test_dreambooth_lora_pivotal_tuning_flux_clip(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--train_text_encoder_ti
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
# make sure embeddings were also saved
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, f"{os.path.basename(tmpdir)}_emb.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
textual_inversion_state_dict = safetensors.torch.load_file(
|
||||
os.path.join(tmpdir, f"{os.path.basename(tmpdir)}_emb.safetensors")
|
||||
)
|
||||
is_clip = all("clip_l" in k for k in textual_inversion_state_dict.keys())
|
||||
self.assertTrue(is_clip)
|
||||
|
||||
# when performing pivotal tuning, all the parameters in the state dict should start
|
||||
# with `"transformer"` in their names.
|
||||
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
|
||||
self.assertTrue(starts_with_transformer)
|
||||
|
||||
def test_dreambooth_lora_pivotal_tuning_flux_clip_t5(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--train_text_encoder_ti
|
||||
--enable_t5_ti
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
# make sure embeddings were also saved
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, f"{os.path.basename(tmpdir)}_emb.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
textual_inversion_state_dict = safetensors.torch.load_file(
|
||||
os.path.join(tmpdir, f"{os.path.basename(tmpdir)}_emb.safetensors")
|
||||
)
|
||||
is_te = all(("clip_l" in k or "t5" in k) for k in textual_inversion_state_dict.keys())
|
||||
self.assertTrue(is_te)
|
||||
|
||||
# when performing pivotal tuning, all the parameters in the state dict should start
|
||||
# with `"transformer"` in their names.
|
||||
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
|
||||
self.assertTrue(starts_with_transformer)
|
||||
|
||||
def test_dreambooth_lora_latent_caching(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--cache_latents
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# when not training the text encoder, all the parameters in the state dict should start
|
||||
# with `"transformer"` in their names.
|
||||
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
|
||||
self.assertTrue(starts_with_transformer)
|
||||
|
||||
def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--instance_prompt={self.instance_prompt}
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=6
|
||||
--checkpoints_total_limit=2
|
||||
--checkpointing_steps=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-4", "checkpoint-6"},
|
||||
)
|
||||
|
||||
def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--instance_prompt={self.instance_prompt}
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=4
|
||||
--checkpointing_steps=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})
|
||||
|
||||
resume_run_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--instance_prompt={self.instance_prompt}
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=8
|
||||
--checkpointing_steps=2
|
||||
--resume_from_checkpoint=checkpoint-4
|
||||
--checkpoints_total_limit=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + resume_run_args)
|
||||
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
|
||||
File diff suppressed because it is too large
Load Diff
@@ -20,7 +20,7 @@ import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin
|
||||
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
|
||||
from ...models.autoencoders import AutoencoderKL
|
||||
from ...models.transformers import FluxTransformer2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
@@ -137,7 +137,12 @@ def retrieve_timesteps(
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
|
||||
class FluxPipeline(
|
||||
DiffusionPipeline,
|
||||
FluxLoraLoaderMixin,
|
||||
FromSingleFileMixin,
|
||||
TextualInversionLoaderMixin,
|
||||
):
|
||||
r"""
|
||||
The Flux pipeline for text-to-image generation.
|
||||
|
||||
@@ -212,6 +217,9 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
|
||||
|
||||
text_inputs = self.tokenizer_2(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
@@ -255,6 +263,9 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
|
||||
@@ -25,7 +25,7 @@ from transformers import (
|
||||
)
|
||||
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin
|
||||
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
|
||||
from ...models.autoencoders import AutoencoderKL
|
||||
from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
|
||||
from ...models.transformers import FluxTransformer2DModel
|
||||
@@ -238,6 +238,9 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
||||
|
||||
text_inputs = self.tokenizer_2(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
@@ -281,6 +284,9 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
|
||||
@@ -11,7 +11,7 @@ from transformers import (
|
||||
)
|
||||
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin
|
||||
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
|
||||
from ...models.autoencoders import AutoencoderKL
|
||||
from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
|
||||
from ...models.transformers import FluxTransformer2DModel
|
||||
@@ -251,6 +251,9 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
|
||||
|
||||
text_inputs = self.tokenizer_2(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
@@ -295,6 +298,9 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
|
||||
@@ -12,7 +12,7 @@ from transformers import (
|
||||
)
|
||||
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin
|
||||
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
|
||||
from ...models.autoencoders import AutoencoderKL
|
||||
from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
|
||||
from ...models.transformers import FluxTransformer2DModel
|
||||
@@ -261,6 +261,9 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
|
||||
|
||||
text_inputs = self.tokenizer_2(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
@@ -305,6 +308,9 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
|
||||
@@ -20,7 +20,7 @@ import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
|
||||
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FluxLoraLoaderMixin
|
||||
from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models.autoencoders import AutoencoderKL
|
||||
from ...models.transformers import FluxTransformer2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
@@ -235,6 +235,9 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
|
||||
|
||||
text_inputs = self.tokenizer_2(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
@@ -279,6 +282,9 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
|
||||
@@ -21,7 +21,7 @@ import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
|
||||
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FluxLoraLoaderMixin
|
||||
from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models.autoencoders import AutoencoderKL
|
||||
from ...models.transformers import FluxTransformer2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
@@ -239,6 +239,9 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
|
||||
|
||||
text_inputs = self.tokenizer_2(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
@@ -283,6 +286,9 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
|
||||
Reference in New Issue
Block a user