Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-11 14:15:40 +08:00)

Compare commits: modular-cu...leisuzz-fs (13 commits)
| SHA1 |
|---|
| 1b98e10614 |
| 2f35e145e3 |
| 3b2e491d13 |
| f392c60cde |
| 8da9ea7d4a |
| b54e0e634e |
| af339debf4 |
| 6cfac4642f |
| 8bce38c086 |
| f931ec31a5 |
| 647c66aaf3 |
| 0052b21f52 |
| c766e27c77 |
@@ -98,6 +98,9 @@ Flux.2 uses Mistral Small 3.1 as text encoder which is quite large and can take

This way, the text encoder model is not loaded into memory during training.

> [!NOTE]
> To enable remote text encoding you must either be logged in to your HuggingFace account (`hf auth login`) OR pass a token with `--hub_token`.
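As a rough sketch of how this fits into a launch command (the script name, model id, and data flags below are illustrative placeholders; only `--remote_text_encoder` and `--hub_token` come from this change):

```shell
accelerate launch train_dreambooth_lora_flux2.py \
  --pretrained_model_name_or_path="black-forest-labs/FLUX.2-dev" \
  --instance_data_dir="./dog" \
  --instance_prompt="a photo of sks dog" \
  --output_dir="flux2-lora" \
  --remote_text_encoder \
  --hub_token=$HF_TOKEN
```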

### FSDP Text Encoder

Flux.2 uses Mistral Small 3.1 as its text encoder, which is quite large and can take up a lot of memory. To mitigate this, we can pass the `--fsdp_text_encoder` flag to shard the text encoder with FSDP while computing the prompt embeddings.

This way, the memory cost of the text encoder is distributed across the participating GPUs/nodes (a combined launch sketch is shown after the next section).

### CPU Offloading

To offload parts of the model to CPU memory, you can use the `--offload` flag. This will offload the VAE and text encoder to CPU memory and only move them to the GPU when needed.
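For example, a single launch could combine CPU offloading with the FSDP text encoder from the earlier section (a sketch only; every flag other than `--fsdp_text_encoder` and `--offload` is an assumed placeholder):

```shell
accelerate launch --num_processes=4 train_dreambooth_lora_flux2.py \
  --pretrained_model_name_or_path="black-forest-labs/FLUX.2-dev" \
  --instance_data_dir="./dog" \
  --instance_prompt="a photo of sks dog" \
  --output_dir="flux2-lora" \
  --fsdp_text_encoder \
  --offload
```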

### Latent Caching

@@ -166,6 +169,26 @@ To better track our training experiments, we're using the following flags in the

> [!NOTE]
> If you want to train using long prompts with the T5 text encoder, you can use `--max_sequence_length` to set the token limit. The default is 77, but it can be increased to as high as 512. Note that this will use more resources and may slow down the training in some cases.

### FSDP on the transformer

With an accelerate configuration that enables FSDP, the transformer blocks are wrapped automatically. E.g. set the configuration to:

```yaml
distributed_type: FSDP
fsdp_config:
  fsdp_version: 2
  fsdp_offload_params: false
  fsdp_sharding_strategy: HYBRID_SHARD
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: Flux2TransformerBlock, Flux2SingleTransformerBlock
  fsdp_forward_prefetch: true
  fsdp_sync_module_states: false
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_use_orig_params: false
  fsdp_activation_checkpointing: true
  fsdp_reshard_after_forward: true
  fsdp_cpu_ram_efficient_loading: false
```
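Assuming the snippet above is saved as an accelerate config file (the file name and the training flags below are only illustrative), such a run is then started by pointing `accelerate launch` at that config:

```shell
accelerate launch --config_file=fsdp_config.yaml train_dreambooth_lora_flux2.py \
  --pretrained_model_name_or_path="black-forest-labs/FLUX.2-dev" \
  --instance_data_dir="./dog" \
  --instance_prompt="a photo of sks dog" \
  --output_dir="flux2-lora"
```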

## LoRA + DreamBooth

[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning-like performance with only a fraction of learnable parameters.
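To get a rough sense of the savings (the layer size here is purely illustrative): for a single 4096x4096 linear layer, a rank-16 LoRA trains two low-rank factors totalling 2 * 4096 * 16 ≈ 131K parameters, versus roughly 16.8M if the full matrix were fine-tuned, i.e. well under 1% of that layer's weights.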
@@ -44,6 +44,7 @@ import shutil
import warnings
from contextlib import nullcontext
from pathlib import Path
from typing import Any

import numpy as np
import torch
@@ -75,13 +76,16 @@ from diffusers import (
from diffusers.optimization import get_scheduler
from diffusers.training_utils import (
    _collate_lora_metadata,
    _to_cpu_contiguous,
    cast_training_params,
    compute_density_for_timestep_sampling,
    compute_loss_weighting_for_sd3,
    find_nearest_bucket,
    free_memory,
    get_fsdp_kwargs_from_accelerator,
    offload_models,
    parse_buckets_string,
    wrap_with_fsdp,
)
from diffusers.utils import (
    check_min_version,
@@ -93,6 +97,9 @@ from diffusers.utils.import_utils import is_torch_npu_available
from diffusers.utils.torch_utils import is_compiled_module


if getattr(torch, "distributed", None) is not None:
    import torch.distributed as dist

if is_wandb_available():
    import wandb

@@ -722,6 +729,7 @@ def parse_args(input_args=None):
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enable Flash Attention for NPU")
    parser.add_argument("--fsdp_text_encoder", action="store_true", help="Use FSDP for text encoder")

    if input_args is not None:
        args = parser.parse_args(input_args)
@@ -1219,7 +1227,11 @@ def main(args):
        if args.bnb_quantization_config_path is not None
        else {"device": accelerator.device, "dtype": weight_dtype}
    )
    transformer.to(**transformer_to_kwargs)

    is_fsdp = accelerator.state.fsdp_plugin is not None
    if not is_fsdp:
        transformer.to(**transformer_to_kwargs)

    if args.do_fp8_training:
        convert_to_float8_training(
            transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True)
@@ -1263,17 +1275,42 @@ def main(args):

    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
    def save_model_hook(models, weights, output_dir):
        if accelerator.is_main_process:
            transformer_lora_layers_to_save = None
            modules_to_save = {}
            for model in models:
                if isinstance(model, type(unwrap_model(transformer))):
                    transformer_lora_layers_to_save = get_peft_model_state_dict(model)
                    modules_to_save["transformer"] = model
                else:
                    raise ValueError(f"unexpected save model: {model.__class__}")
        transformer_cls = type(unwrap_model(transformer))

            # make sure to pop weight so that corresponding model is not saved again
        # 1) Validate and pick the transformer model
        modules_to_save: dict[str, Any] = {}
        transformer_model = None

        for model in models:
            if isinstance(unwrap_model(model), transformer_cls):
                transformer_model = model
                modules_to_save["transformer"] = model
            else:
                raise ValueError(f"unexpected save model: {model.__class__}")

        if transformer_model is None:
            raise ValueError("No transformer model found in 'models'")

        # 2) Optionally gather FSDP state dict once
        state_dict = accelerator.get_state_dict(model) if is_fsdp else None

        # 3) Only main process materializes the LoRA state dict
        transformer_lora_layers_to_save = None
        if accelerator.is_main_process:
            peft_kwargs = {}
            if is_fsdp:
                peft_kwargs["state_dict"] = state_dict

            transformer_lora_layers_to_save = get_peft_model_state_dict(
                unwrap_model(transformer_model) if is_fsdp else transformer_model,
                **peft_kwargs,
            )

            if is_fsdp:
                transformer_lora_layers_to_save = _to_cpu_contiguous(transformer_lora_layers_to_save)

        # make sure to pop weight so that corresponding model is not saved again
        if weights:
            weights.pop()

        Flux2Pipeline.save_lora_weights(
@@ -1285,13 +1322,20 @@ def main(args):
    def load_model_hook(models, input_dir):
        transformer_ = None

        while len(models) > 0:
            model = models.pop()
        if not is_fsdp:
            while len(models) > 0:
                model = models.pop()

            if isinstance(model, type(unwrap_model(transformer))):
                transformer_ = model
            else:
                raise ValueError(f"unexpected save model: {model.__class__}")
                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
                    transformer_ = unwrap_model(model)
                else:
                    raise ValueError(f"unexpected save model: {model.__class__}")
        else:
            transformer_ = Flux2Transformer2DModel.from_pretrained(
                args.pretrained_model_name_or_path,
                subfolder="transformer",
            )
            transformer_.add_adapter(transformer_lora_config)

        lora_state_dict = Flux2Pipeline.lora_state_dict(input_dir)
@@ -1507,6 +1551,21 @@ def main(args):
            args.validation_prompt, text_encoding_pipeline
        )

    # Init FSDP for text encoder
    if args.fsdp_text_encoder:
        fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator)
        text_encoder_fsdp = wrap_with_fsdp(
            model=text_encoding_pipeline.text_encoder,
            device=accelerator.device,
            offload=args.offload,
            limit_all_gathers=True,
            use_orig_params=True,
            fsdp_kwargs=fsdp_kwargs,
        )

        text_encoding_pipeline.text_encoder = text_encoder_fsdp
        dist.barrier()

    # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
    # pack the statically computed variables appropriately here. This is so that we don't
    # have to pass them to the dataloader.
@@ -1536,6 +1595,8 @@ def main(args):
        if train_dataset.custom_instance_prompts:
            if args.remote_text_encoder:
                prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"])
            elif args.fsdp_text_encoder:
                prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
            else:
                with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
                    prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
@@ -1777,7 +1838,7 @@ def main(args):
                progress_bar.update(1)
                global_step += 1

                if accelerator.is_main_process:
                if accelerator.is_main_process or is_fsdp:
                    if global_step % args.checkpointing_steps == 0:
                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                        if args.checkpoints_total_limit is not None:
@@ -1836,15 +1897,41 @@ def main(args):

    # Save the lora layers
    accelerator.wait_for_everyone()

    if is_fsdp:
        transformer = unwrap_model(transformer)
        state_dict = accelerator.get_state_dict(transformer)
    if accelerator.is_main_process:
        modules_to_save = {}
        transformer = unwrap_model(transformer)
        if args.bnb_quantization_config_path is None:
            if args.upcast_before_saving:
                transformer.to(torch.float32)
            else:
                transformer = transformer.to(weight_dtype)
        transformer_lora_layers = get_peft_model_state_dict(transformer)
        if is_fsdp:
            if args.bnb_quantization_config_path is None:
                if args.upcast_before_saving:
                    state_dict = {
                        k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
                    }
                else:
                    state_dict = {
                        k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
                    }

            transformer_lora_layers = get_peft_model_state_dict(
                transformer,
                state_dict=state_dict,
            )
            transformer_lora_layers = {
                k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
                for k, v in transformer_lora_layers.items()
            }

        else:
            transformer = unwrap_model(transformer)
            if args.bnb_quantization_config_path is None:
                if args.upcast_before_saving:
                    transformer.to(torch.float32)
                else:
                    transformer = transformer.to(weight_dtype)
            transformer_lora_layers = get_peft_model_state_dict(transformer)

        modules_to_save["transformer"] = transformer

        Flux2Pipeline.save_lora_weights(

@@ -43,6 +43,7 @@ import random
import shutil
from contextlib import nullcontext
from pathlib import Path
from typing import Any

import numpy as np
import torch
@@ -74,13 +75,16 @@ from diffusers.optimization import get_scheduler
from diffusers.pipelines.flux2.image_processor import Flux2ImageProcessor
from diffusers.training_utils import (
    _collate_lora_metadata,
    _to_cpu_contiguous,
    cast_training_params,
    compute_density_for_timestep_sampling,
    compute_loss_weighting_for_sd3,
    find_nearest_bucket,
    free_memory,
    get_fsdp_kwargs_from_accelerator,
    offload_models,
    parse_buckets_string,
    wrap_with_fsdp,
)
from diffusers.utils import (
    check_min_version,
@@ -93,6 +97,9 @@ from diffusers.utils.import_utils import is_torch_npu_available
from diffusers.utils.torch_utils import is_compiled_module


if getattr(torch, "distributed", None) is not None:
    import torch.distributed as dist

if is_wandb_available():
    import wandb

@@ -691,6 +698,7 @@ def parse_args(input_args=None):

    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enable Flash Attention for NPU")
    parser.add_argument("--fsdp_text_encoder", action="store_true", help="Use FSDP for text encoder")

    if input_args is not None:
        args = parser.parse_args(input_args)
@@ -1156,7 +1164,11 @@ def main(args):
        if args.bnb_quantization_config_path is not None
        else {"device": accelerator.device, "dtype": weight_dtype}
    )
    transformer.to(**transformer_to_kwargs)

    is_fsdp = accelerator.state.fsdp_plugin is not None
    if not is_fsdp:
        transformer.to(**transformer_to_kwargs)

    if args.do_fp8_training:
        convert_to_float8_training(
            transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True)
@@ -1200,17 +1212,42 @@ def main(args):

    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
    def save_model_hook(models, weights, output_dir):
        if accelerator.is_main_process:
            transformer_lora_layers_to_save = None
            modules_to_save = {}
            for model in models:
                if isinstance(model, type(unwrap_model(transformer))):
                    transformer_lora_layers_to_save = get_peft_model_state_dict(model)
                    modules_to_save["transformer"] = model
                else:
                    raise ValueError(f"unexpected save model: {model.__class__}")
        transformer_cls = type(unwrap_model(transformer))

            # make sure to pop weight so that corresponding model is not saved again
        # 1) Validate and pick the transformer model
        modules_to_save: dict[str, Any] = {}
        transformer_model = None

        for model in models:
            if isinstance(unwrap_model(model), transformer_cls):
                transformer_model = model
                modules_to_save["transformer"] = model
            else:
                raise ValueError(f"unexpected save model: {model.__class__}")

        if transformer_model is None:
            raise ValueError("No transformer model found in 'models'")

        # 2) Optionally gather FSDP state dict once
        state_dict = accelerator.get_state_dict(model) if is_fsdp else None

        # 3) Only main process materializes the LoRA state dict
        transformer_lora_layers_to_save = None
        if accelerator.is_main_process:
            peft_kwargs = {}
            if is_fsdp:
                peft_kwargs["state_dict"] = state_dict

            transformer_lora_layers_to_save = get_peft_model_state_dict(
                unwrap_model(transformer_model) if is_fsdp else transformer_model,
                **peft_kwargs,
            )

            if is_fsdp:
                transformer_lora_layers_to_save = _to_cpu_contiguous(transformer_lora_layers_to_save)

        # make sure to pop weight so that corresponding model is not saved again
        if weights:
            weights.pop()

        Flux2Pipeline.save_lora_weights(
@@ -1222,13 +1259,20 @@ def main(args):
    def load_model_hook(models, input_dir):
        transformer_ = None

        while len(models) > 0:
            model = models.pop()
        if not is_fsdp:
            while len(models) > 0:
                model = models.pop()

            if isinstance(model, type(unwrap_model(transformer))):
                transformer_ = model
            else:
                raise ValueError(f"unexpected save model: {model.__class__}")
                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
                    transformer_ = unwrap_model(model)
                else:
                    raise ValueError(f"unexpected save model: {model.__class__}")
        else:
            transformer_ = Flux2Transformer2DModel.from_pretrained(
                args.pretrained_model_name_or_path,
                subfolder="transformer",
            )
            transformer_.add_adapter(transformer_lora_config)

        lora_state_dict = Flux2Pipeline.lora_state_dict(input_dir)
@@ -1430,6 +1474,21 @@ def main(args):
            args.validation_prompt, text_encoding_pipeline
        )

    # Init FSDP for text encoder
    if args.fsdp_text_encoder:
        fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator)
        text_encoder_fsdp = wrap_with_fsdp(
            model=text_encoding_pipeline.text_encoder,
            device=accelerator.device,
            offload=args.offload,
            limit_all_gathers=True,
            use_orig_params=True,
            fsdp_kwargs=fsdp_kwargs,
        )

        text_encoding_pipeline.text_encoder = text_encoder_fsdp
        dist.barrier()

    # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
    # pack the statically computed variables appropriately here. This is so that we don't
    # have to pass them to the dataloader.
@@ -1461,6 +1520,8 @@ def main(args):
        if train_dataset.custom_instance_prompts:
            if args.remote_text_encoder:
                prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"])
            elif args.fsdp_text_encoder:
                prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
            else:
                with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
                    prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
@@ -1700,7 +1761,7 @@ def main(args):
                progress_bar.update(1)
                global_step += 1

                if accelerator.is_main_process:
                if accelerator.is_main_process or is_fsdp:
                    if global_step % args.checkpointing_steps == 0:
                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                        if args.checkpoints_total_limit is not None:
@@ -1759,15 +1820,41 @@ def main(args):

    # Save the lora layers
    accelerator.wait_for_everyone()

    if is_fsdp:
        transformer = unwrap_model(transformer)
        state_dict = accelerator.get_state_dict(transformer)
    if accelerator.is_main_process:
        modules_to_save = {}
        transformer = unwrap_model(transformer)
        if args.bnb_quantization_config_path is None:
            if args.upcast_before_saving:
                transformer.to(torch.float32)
            else:
                transformer = transformer.to(weight_dtype)
        transformer_lora_layers = get_peft_model_state_dict(transformer)
        if is_fsdp:
            if args.bnb_quantization_config_path is None:
                if args.upcast_before_saving:
                    state_dict = {
                        k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
                    }
                else:
                    state_dict = {
                        k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
                    }

            transformer_lora_layers = get_peft_model_state_dict(
                transformer,
                state_dict=state_dict,
            )
            transformer_lora_layers = {
                k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
                for k, v in transformer_lora_layers.items()
            }

        else:
            transformer = unwrap_model(transformer)
            if args.bnb_quantization_config_path is None:
                if args.upcast_before_saving:
                    transformer.to(torch.float32)
                else:
                    transformer = transformer.to(weight_dtype)
            transformer_lora_layers = get_peft_model_state_dict(transformer)

        modules_to_save["transformer"] = transformer

        Flux2Pipeline.save_lora_weights(

@@ -6,11 +6,18 @@ import random
import re
import warnings
from contextlib import contextmanager
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from functools import partial
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union

import numpy as np
import torch


if getattr(torch, "distributed", None) is not None:
    from torch.distributed.fsdp import CPUOffload, ShardingStrategy
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

from .models import UNet2DConditionModel
from .pipelines import DiffusionPipeline
from .schedulers import SchedulerMixin
@@ -18,6 +25,7 @@ from .utils import (
    convert_state_dict_to_diffusers,
    convert_state_dict_to_peft,
    deprecate,
    is_accelerate_available,
    is_peft_available,
    is_torch_npu_available,
    is_torchvision_available,
@@ -31,6 +39,9 @@ if is_transformers_available():
    if transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
        import deepspeed

if is_accelerate_available():
    from accelerate.logging import get_logger

if is_peft_available():
    from peft import set_peft_model_state_dict

@@ -394,6 +405,86 @@ def find_nearest_bucket(h, w, bucket_options):
    return best_bucket_idx


def _to_cpu_contiguous(state_dicts) -> dict:
    return {k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v for k, v in state_dicts.items()}

def get_fsdp_kwargs_from_accelerator(accelerator) -> dict:
    """
    Extract and convert FSDP config from the Accelerator into PyTorch FSDP kwargs.
    """
    fsdp_plugin = getattr(accelerator.state, "fsdp_plugin", None)

    if fsdp_plugin is None:
        raise ValueError("Accelerate isn't configured to handle FSDP. Please update your installation.")

    kwargs = {}
    # Use the plugin's sharding strategy, falling back to FULL_SHARD when it is unset.
    kwargs["sharding_strategy"] = fsdp_plugin.sharding_strategy or ShardingStrategy.FULL_SHARD

    return kwargs

def wrap_with_fsdp(
    model: torch.nn.Module,
    device: Union[str, torch.device],
    offload: bool = True,
    use_orig_params: bool = True,
    limit_all_gathers: bool = True,
    fsdp_kwargs: Optional[Dict[str, Any]] = None,
    transformer_layer_cls: Optional[Set[Type[torch.nn.Module]]] = None,
) -> FSDP:
    """
    Wrap a model with FSDP using common defaults and optional transformer auto-wrapping.

    Args:
        model: Model to wrap.
        device: Target device (e.g., accelerator.device).
        offload: Whether to enable CPU parameter offloading.
        use_orig_params: Whether to use original parameters.
        limit_all_gathers: Whether to limit all-gathers.
        fsdp_kwargs: Extra FSDP arguments (sharding_strategy, etc.), usually from the Accelerate config.
        transformer_layer_cls: Transformer layer classes for auto-wrapping (if not using a policy from fsdp_kwargs).

    Returns:
        FSDP-wrapped model.
    """

    logger = get_logger(__name__)

    if transformer_layer_cls is None:
        # Default to the class of the model's transformer layers if none is provided.
        transformer_layer_cls = type(model.model.language_model.layers[0])
        logger.info(f"transformer_layer_cls is not provided, auto-inferred as {transformer_layer_cls.__name__}")

    # Auto-wrap every submodule of the given transformer layer class(es).
    auto_wrap_policy = partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls=transformer_layer_cls if isinstance(transformer_layer_cls, set) else {transformer_layer_cls},
    )

    config = {
        "device_id": device,
        "cpu_offload": CPUOffload(offload_params=offload) if offload else None,
        "use_orig_params": use_orig_params,
        "limit_all_gathers": limit_all_gathers,
        "auto_wrap_policy": auto_wrap_policy,
    }

    if fsdp_kwargs:
        config.update(fsdp_kwargs)

    fsdp_model = FSDP(model, **config)
    return fsdp_model

# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
class EMAModel:
    """