Compare commits

...

13 Commits

| Author | SHA1 | Message | Date |
|--------|------|---------|------|
| sayakpaul | 1b98e10614 | guard accelerate import. | 2026-01-07 09:50:34 +05:30 |
| Sayak Paul | 2f35e145e3 | Merge branch 'main' into fsdp | 2026-01-07 09:48:39 +05:30 |
| Sayak Paul | 3b2e491d13 | Merge branch 'main' into fsdp | 2026-01-07 09:38:39 +05:30 |
| Sayak Paul | f392c60cde | Update examples/dreambooth/README_flux2.md | 2026-01-07 09:38:28 +05:30 |
| js1234567 | 8da9ea7d4a | Add FSDP option for Flux2 | 2026-01-07 09:53:11 +08:00 |
| Sayak Paul | b54e0e634e | Merge branch 'main' into fsdp | 2026-01-06 20:29:42 +05:30 |
| js1234567 | af339debf4 | Add FSDP option for Flux2 | 2025-12-24 17:11:05 +08:00 |
| js1234567 | 6cfac4642f | Add FSDP option for Flux2 | 2025-12-24 15:43:21 +08:00 |
| js1234567 | 8bce38c086 | Add FSDP option for Flux2 | 2025-12-23 16:23:48 +08:00 |
| js1234567 | f931ec31a5 | Add FSDP option for Flux2 | 2025-12-23 15:56:13 +08:00 |
| github-actions[bot] | 647c66aaf3 | Apply style fixes | 2025-12-22 05:42:17 +00:00 |
| Sayak Paul | 0052b21f52 | Merge branch 'main' into fsdp | 2025-12-22 10:50:28 +05:30 |
| js1234567 | c766e27c77 | Add FSDP option for Flux2 | 2025-12-19 15:12:32 +08:00 |
4 changed files with 339 additions and 51 deletions

View File

@@ -98,6 +98,9 @@ Flux.2 uses Mistral Small 3.1 as text encoder which is quite large and can take
This way, the text encoder model is not loaded into memory during training.
> [!NOTE]
> To enable remote text encoding you must either be logged in to your Hugging Face account (`hf auth login`) or pass a token with `--hub_token`.
### FSDP Text Encoder
Flux.2 uses Mistral Small 3.1 as its text encoder, which is quite large and can take up a lot of memory. To mitigate this, you can pass the `--fsdp_text_encoder` flag to shard the text encoder with FSDP and compute the prompt embeddings in a distributed fashion.
This way, the memory cost of the text encoder is spread across the participating devices.
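Under the hood, the flag wraps the text encoder with the FSDP helpers added to `diffusers.training_utils` in this PR. The sketch below mirrors, in simplified form, the code added to the training script; `accelerator`, `args`, and `text_encoding_pipeline` are assumed to come from the script's existing setup:
```python
import torch.distributed as dist

from diffusers.training_utils import get_fsdp_kwargs_from_accelerator, wrap_with_fsdp

# Shard the Mistral text encoder across ranks so that no single GPU holds the full model.
fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator)
text_encoding_pipeline.text_encoder = wrap_with_fsdp(
    model=text_encoding_pipeline.text_encoder,
    device=accelerator.device,
    offload=args.offload,     # optionally offload sharded parameters to CPU
    limit_all_gathers=True,
    use_orig_params=True,
    fsdp_kwargs=fsdp_kwargs,  # e.g. the sharding strategy taken from the accelerate config
)
dist.barrier()                # make sure every rank finishes wrapping before encoding
```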
### CPU Offloading
To offload parts of the model to CPU memory, you can use the `--offload` flag. This will offload the VAE and text encoder to CPU memory and only move them to the GPU when needed.
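In the training script this is implemented with the `offload_models` context manager from `diffusers.training_utils`; a simplified sketch of the pattern, with `text_encoding_pipeline`, `accelerator`, and `compute_text_embeddings` taken from the script:
```python
from diffusers.training_utils import offload_models

# Move the text-encoding pipeline to the GPU only while embeddings are computed;
# with offload=True it is returned to CPU when the block exits.
with offload_models(text_encoding_pipeline, device=accelerator.device, offload=True):
    prompt_embeds, text_ids = compute_text_embeddings(prompts, text_encoding_pipeline)
```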
### Latent Caching
@@ -166,6 +169,26 @@ To better track our training experiments, we're using the following flags in the
> [!NOTE]
> If you want to train using long prompts with the T5 text encoder, you can use `--max_sequence_length` to set the token limit. The default is 77, but it can be increased to as high as 512. Note that this will use more resources and may slow down the training in some cases.
### FSDP on the Transformer
By configuring Accelerate with FSDP, the transformer blocks will be wrapped automatically. For example, set the configuration to:
```shell
distributed_type: FSDP
fsdp_config:
fsdp_version: 2
fsdp_offload_params: false
fsdp_sharding_strategy: HYBRID_SHARD
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: Flux2TransformerBlock, Flux2SingleTransformerBlock
fsdp_forward_prefetch: true
fsdp_sync_module_states: false
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_use_orig_params: false
fsdp_activation_checkpointing: true
fsdp_reshard_after_forward: true
fsdp_cpu_ram_efficient_loading: false
```
## LoRA + DreamBooth
[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that achieves performance close to full fine-tuning while training only a fraction of the parameters.
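For reference, a minimal sketch of how a LoRA adapter is attached to the transformer with `peft`; the rank and `target_modules` below are illustrative placeholders rather than the script's defaults:
```python
from peft import LoraConfig

# Only the injected low-rank update matrices are trained; the base weights stay frozen.
transformer_lora_config = LoraConfig(
    r=16,                      # rank of the low-rank update
    lora_alpha=16,             # scaling applied to the update
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],  # illustrative attention projections
)
transformer.add_adapter(transformer_lora_config)  # as done in the training scripts
```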

View File

@@ -44,6 +44,7 @@ import shutil
import warnings
from contextlib import nullcontext
from pathlib import Path
from typing import Any
import numpy as np
import torch
@@ -75,13 +76,16 @@ from diffusers import (
from diffusers.optimization import get_scheduler
from diffusers.training_utils import (
_collate_lora_metadata,
_to_cpu_contiguous,
cast_training_params,
compute_density_for_timestep_sampling,
compute_loss_weighting_for_sd3,
find_nearest_bucket,
free_memory,
get_fsdp_kwargs_from_accelerator,
offload_models,
parse_buckets_string,
wrap_with_fsdp,
)
from diffusers.utils import (
check_min_version,
@@ -93,6 +97,9 @@ from diffusers.utils.import_utils import is_torch_npu_available
from diffusers.utils.torch_utils import is_compiled_module
if getattr(torch, "distributed", None) is not None:
import torch.distributed as dist
if is_wandb_available():
import wandb
@@ -722,6 +729,7 @@ def parse_args(input_args=None):
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU")
parser.add_argument("--fsdp_text_encoder", action="store_true", help="Use FSDP for text encoder")
if input_args is not None:
args = parser.parse_args(input_args)
@@ -1219,7 +1227,11 @@ def main(args):
if args.bnb_quantization_config_path is not None
else {"device": accelerator.device, "dtype": weight_dtype}
)
transformer.to(**transformer_to_kwargs)
is_fsdp = accelerator.state.fsdp_plugin is not None
if not is_fsdp:
transformer.to(**transformer_to_kwargs)
if args.do_fp8_training:
convert_to_float8_training(
transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True)
@@ -1263,17 +1275,42 @@ def main(args):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
transformer_lora_layers_to_save = None
modules_to_save = {}
for model in models:
if isinstance(model, type(unwrap_model(transformer))):
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
modules_to_save["transformer"] = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")
transformer_cls = type(unwrap_model(transformer))
# make sure to pop weight so that corresponding model is not saved again
# 1) Validate and pick the transformer model
modules_to_save: dict[str, Any] = {}
transformer_model = None
for model in models:
if isinstance(unwrap_model(model), transformer_cls):
transformer_model = model
modules_to_save["transformer"] = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")
if transformer_model is None:
raise ValueError("No transformer model found in 'models'")
# 2) Optionally gather FSDP state dict once
state_dict = accelerator.get_state_dict(transformer_model) if is_fsdp else None
# 3) Only main process materializes the LoRA state dict
transformer_lora_layers_to_save = None
if accelerator.is_main_process:
peft_kwargs = {}
if is_fsdp:
peft_kwargs["state_dict"] = state_dict
transformer_lora_layers_to_save = get_peft_model_state_dict(
unwrap_model(transformer_model) if is_fsdp else transformer_model,
**peft_kwargs,
)
if is_fsdp:
transformer_lora_layers_to_save = _to_cpu_contiguous(transformer_lora_layers_to_save)
# make sure to pop weight so that corresponding model is not saved again
if weights:
weights.pop()
Flux2Pipeline.save_lora_weights(
@@ -1285,13 +1322,20 @@ def main(args):
def load_model_hook(models, input_dir):
transformer_ = None
while len(models) > 0:
model = models.pop()
if not is_fsdp:
while len(models) > 0:
model = models.pop()
if isinstance(model, type(unwrap_model(transformer))):
transformer_ = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")
if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
transformer_ = unwrap_model(model)
else:
raise ValueError(f"unexpected save model: {model.__class__}")
else:
transformer_ = Flux2Transformer2DModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="transformer",
)
transformer_.add_adapter(transformer_lora_config)
lora_state_dict = Flux2Pipeline.lora_state_dict(input_dir)
@@ -1507,6 +1551,21 @@ def main(args):
args.validation_prompt, text_encoding_pipeline
)
# Init FSDP for text encoder
if args.fsdp_text_encoder:
fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator)
text_encoder_fsdp = wrap_with_fsdp(
model=text_encoding_pipeline.text_encoder,
device=accelerator.device,
offload=args.offload,
limit_all_gathers=True,
use_orig_params=True,
fsdp_kwargs=fsdp_kwargs,
)
text_encoding_pipeline.text_encoder = text_encoder_fsdp
dist.barrier()
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
# pack the statically computed variables appropriately here. This is so that we don't
# have to pass them to the dataloader.
@@ -1536,6 +1595,8 @@ def main(args):
if train_dataset.custom_instance_prompts:
if args.remote_text_encoder:
prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"])
elif args.fsdp_text_encoder:
prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
else:
with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
@@ -1777,7 +1838,7 @@ def main(args):
progress_bar.update(1)
global_step += 1
if accelerator.is_main_process:
if accelerator.is_main_process or is_fsdp:
if global_step % args.checkpointing_steps == 0:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
@@ -1836,15 +1897,41 @@ def main(args):
# Save the lora layers
accelerator.wait_for_everyone()
if is_fsdp:
transformer = unwrap_model(transformer)
state_dict = accelerator.get_state_dict(transformer)
if accelerator.is_main_process:
modules_to_save = {}
transformer = unwrap_model(transformer)
if args.bnb_quantization_config_path is None:
if args.upcast_before_saving:
transformer.to(torch.float32)
else:
transformer = transformer.to(weight_dtype)
transformer_lora_layers = get_peft_model_state_dict(transformer)
if is_fsdp:
if args.bnb_quantization_config_path is None:
if args.upcast_before_saving:
state_dict = {
k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
}
else:
state_dict = {
k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
}
transformer_lora_layers = get_peft_model_state_dict(
transformer,
state_dict=state_dict,
)
transformer_lora_layers = {
k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
for k, v in transformer_lora_layers.items()
}
else:
transformer = unwrap_model(transformer)
if args.bnb_quantization_config_path is None:
if args.upcast_before_saving:
transformer.to(torch.float32)
else:
transformer = transformer.to(weight_dtype)
transformer_lora_layers = get_peft_model_state_dict(transformer)
modules_to_save["transformer"] = transformer
Flux2Pipeline.save_lora_weights(

View File

@@ -43,6 +43,7 @@ import random
import shutil
from contextlib import nullcontext
from pathlib import Path
from typing import Any
import numpy as np
import torch
@@ -74,13 +75,16 @@ from diffusers.optimization import get_scheduler
from diffusers.pipelines.flux2.image_processor import Flux2ImageProcessor
from diffusers.training_utils import (
_collate_lora_metadata,
_to_cpu_contiguous,
cast_training_params,
compute_density_for_timestep_sampling,
compute_loss_weighting_for_sd3,
find_nearest_bucket,
free_memory,
get_fsdp_kwargs_from_accelerator,
offload_models,
parse_buckets_string,
wrap_with_fsdp,
)
from diffusers.utils import (
check_min_version,
@@ -93,6 +97,9 @@ from diffusers.utils.import_utils import is_torch_npu_available
from diffusers.utils.torch_utils import is_compiled_module
if getattr(torch, "distributed", None) is not None:
import torch.distributed as dist
if is_wandb_available():
import wandb
@@ -691,6 +698,7 @@ def parse_args(input_args=None):
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU")
parser.add_argument("--fsdp_text_encoder", action="store_true", help="Use FSDP for text encoder")
if input_args is not None:
args = parser.parse_args(input_args)
@@ -1156,7 +1164,11 @@ def main(args):
if args.bnb_quantization_config_path is not None
else {"device": accelerator.device, "dtype": weight_dtype}
)
transformer.to(**transformer_to_kwargs)
is_fsdp = accelerator.state.fsdp_plugin is not None
if not is_fsdp:
transformer.to(**transformer_to_kwargs)
if args.do_fp8_training:
convert_to_float8_training(
transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True)
@@ -1200,17 +1212,42 @@ def main(args):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
transformer_lora_layers_to_save = None
modules_to_save = {}
for model in models:
if isinstance(model, type(unwrap_model(transformer))):
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
modules_to_save["transformer"] = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")
transformer_cls = type(unwrap_model(transformer))
# make sure to pop weight so that corresponding model is not saved again
# 1) Validate and pick the transformer model
modules_to_save: dict[str, Any] = {}
transformer_model = None
for model in models:
if isinstance(unwrap_model(model), transformer_cls):
transformer_model = model
modules_to_save["transformer"] = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")
if transformer_model is None:
raise ValueError("No transformer model found in 'models'")
# 2) Optionally gather FSDP state dict once
state_dict = accelerator.get_state_dict(transformer_model) if is_fsdp else None
# 3) Only main process materializes the LoRA state dict
transformer_lora_layers_to_save = None
if accelerator.is_main_process:
peft_kwargs = {}
if is_fsdp:
peft_kwargs["state_dict"] = state_dict
transformer_lora_layers_to_save = get_peft_model_state_dict(
unwrap_model(transformer_model) if is_fsdp else transformer_model,
**peft_kwargs,
)
if is_fsdp:
transformer_lora_layers_to_save = _to_cpu_contiguous(transformer_lora_layers_to_save)
# make sure to pop weight so that corresponding model is not saved again
if weights:
weights.pop()
Flux2Pipeline.save_lora_weights(
@@ -1222,13 +1259,20 @@ def main(args):
def load_model_hook(models, input_dir):
transformer_ = None
while len(models) > 0:
model = models.pop()
if not is_fsdp:
while len(models) > 0:
model = models.pop()
if isinstance(model, type(unwrap_model(transformer))):
transformer_ = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")
if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
transformer_ = unwrap_model(model)
else:
raise ValueError(f"unexpected save model: {model.__class__}")
else:
transformer_ = Flux2Transformer2DModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="transformer",
)
transformer_.add_adapter(transformer_lora_config)
lora_state_dict = Flux2Pipeline.lora_state_dict(input_dir)
@@ -1430,6 +1474,21 @@ def main(args):
args.validation_prompt, text_encoding_pipeline
)
# Init FSDP for text encoder
if args.fsdp_text_encoder:
fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator)
text_encoder_fsdp = wrap_with_fsdp(
model=text_encoding_pipeline.text_encoder,
device=accelerator.device,
offload=args.offload,
limit_all_gathers=True,
use_orig_params=True,
fsdp_kwargs=fsdp_kwargs,
)
text_encoding_pipeline.text_encoder = text_encoder_fsdp
dist.barrier()
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
# pack the statically computed variables appropriately here. This is so that we don't
# have to pass them to the dataloader.
@@ -1461,6 +1520,8 @@ def main(args):
if train_dataset.custom_instance_prompts:
if args.remote_text_encoder:
prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"])
elif args.fsdp_text_encoder:
prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
else:
with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
@@ -1700,7 +1761,7 @@ def main(args):
progress_bar.update(1)
global_step += 1
if accelerator.is_main_process:
if accelerator.is_main_process or is_fsdp:
if global_step % args.checkpointing_steps == 0:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
@@ -1759,15 +1820,41 @@ def main(args):
# Save the lora layers
accelerator.wait_for_everyone()
if is_fsdp:
transformer = unwrap_model(transformer)
state_dict = accelerator.get_state_dict(transformer)
if accelerator.is_main_process:
modules_to_save = {}
transformer = unwrap_model(transformer)
if args.bnb_quantization_config_path is None:
if args.upcast_before_saving:
transformer.to(torch.float32)
else:
transformer = transformer.to(weight_dtype)
transformer_lora_layers = get_peft_model_state_dict(transformer)
if is_fsdp:
if args.bnb_quantization_config_path is None:
if args.upcast_before_saving:
state_dict = {
k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
}
else:
state_dict = {
k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
}
transformer_lora_layers = get_peft_model_state_dict(
transformer,
state_dict=state_dict,
)
transformer_lora_layers = {
k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
for k, v in transformer_lora_layers.items()
}
else:
transformer = unwrap_model(transformer)
if args.bnb_quantization_config_path is None:
if args.upcast_before_saving:
transformer.to(torch.float32)
else:
transformer = transformer.to(weight_dtype)
transformer_lora_layers = get_peft_model_state_dict(transformer)
modules_to_save["transformer"] = transformer
Flux2Pipeline.save_lora_weights(

View File

@@ -6,11 +6,18 @@ import random
import re
import warnings
from contextlib import contextmanager
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from functools import partial
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
import numpy as np
import torch
if getattr(torch, "distributed", None) is not None:
from torch.distributed.fsdp import CPUOffload, ShardingStrategy
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from .models import UNet2DConditionModel
from .pipelines import DiffusionPipeline
from .schedulers import SchedulerMixin
@@ -18,6 +25,7 @@ from .utils import (
convert_state_dict_to_diffusers,
convert_state_dict_to_peft,
deprecate,
is_accelerate_available,
is_peft_available,
is_torch_npu_available,
is_torchvision_available,
@@ -31,6 +39,9 @@ if is_transformers_available():
if transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
import deepspeed
if is_accelerate_available():
from accelerate.logging import get_logger
if is_peft_available():
from peft import set_peft_model_state_dict
@@ -394,6 +405,86 @@ def find_nearest_bucket(h, w, bucket_options):
return best_bucket_idx
def _to_cpu_contiguous(state_dicts) -> dict:
return {k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v for k, v in state_dicts.items()}
def get_fsdp_kwargs_from_accelerator(accelerator) -> dict:
"""
Extract and convert FSDP config from Accelerator into PyTorch FSDP kwargs.
"""
    fsdp_plugin = getattr(accelerator.state, "fsdp_plugin", None)
    if fsdp_plugin is None:
        raise ValueError("Accelerate is not configured with FSDP. Please enable FSDP in your `accelerate` config.")
    # Use the plugin's sharding strategy, falling back to FULL_SHARD when it is unset.
    kwargs = {"sharding_strategy": fsdp_plugin.sharding_strategy or ShardingStrategy.FULL_SHARD}
    return kwargs
def wrap_with_fsdp(
model: torch.nn.Module,
device: Union[str, torch.device],
offload: bool = True,
use_orig_params: bool = True,
limit_all_gathers: bool = True,
fsdp_kwargs: Optional[Dict[str, Any]] = None,
transformer_layer_cls: Optional[Set[Type[torch.nn.Module]]] = None,
) -> FSDP:
"""
Wrap a model with FSDP using common defaults and optional transformer auto-wrapping.
Args:
model: Model to wrap
device: Target device (e.g., accelerator.device)
offload: Whether to enable CPU parameter offloading
use_orig_params: Whether to use original parameters
limit_all_gathers: Whether to limit all gathers
fsdp_kwargs: FSDP arguments (sharding_strategy, etc.) — usually from Accelerate config
transformer_layer_cls: Classes for auto-wrapping (if not using policy from fsdp_kwargs)
Returns:
FSDP-wrapped model
"""
logger = get_logger(__name__)
if transformer_layer_cls is None:
# Set the default layers if transformer_layer_cls is not provided
transformer_layer_cls = type(model.model.language_model.layers[0])
logger.info(f"transformer_layer_cls is not provided, auto-inferred as {transformer_layer_cls.__name__}")
# Add auto-wrap policy if transformer layers specified
auto_wrap_policy = partial(
transformer_auto_wrap_policy,
transformer_layer_cls={transformer_layer_cls},
)
config = {
"device_id": device,
"cpu_offload": CPUOffload(offload_params=offload) if offload else None,
"use_orig_params": use_orig_params,
"limit_all_gathers": limit_all_gathers,
"auto_wrap_policy": auto_wrap_policy,
}
if fsdp_kwargs:
config.update(fsdp_kwargs)
fsdp_model = FSDP(model, **config)
return fsdp_model
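# Illustrative usage (not part of this module): wrap a text encoder, passing the
# decoder-layer class explicitly instead of letting it be auto-inferred. The
# `accelerator` is assumed to come from an FSDP-enabled Accelerate setup and
# `MistralDecoderLayer` is a placeholder for the encoder's block class.
#
#   fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator)
#   text_encoder = wrap_with_fsdp(
#       model=text_encoder,
#       device=accelerator.device,
#       offload=False,
#       transformer_layer_cls=MistralDecoderLayer,
#       fsdp_kwargs=fsdp_kwargs,
#   )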
# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
class EMAModel:
"""