mirror of
https://github.com/huggingface/diffusers.git
synced 2026-04-09 09:17:10 +08:00
* feat: add LLaDA2 and BlockRefinement pipelines for discrete text diffusion

  Add support for LLaDA2/LLaDA2.1 discrete diffusion text generation:
  - BlockRefinementPipeline: block-wise iterative refinement with confidence-based token commitment, supporting an editing threshold for LLaDA2.1 models
  - LLaDA2Pipeline: convenience wrapper with LLaDA2-specific defaults
  - DiscreteDiffusionPipelineMixin: shared SAR sampling utilities (top-k, top-p, temperature) and prompt/prefix helpers
  - compute_confidence_aware_loss: CAP-style training loss
  - Examples: sampling scripts for LLaDA2 and block refinement, training scripts with Qwen causal LM
  - Docs and tests included

* feat: add BlockRefinementScheduler for commit-by-confidence scheduling

  Extract the confidence-based token commit logic from BlockRefinementPipeline into a dedicated BlockRefinementScheduler, following diffusers conventions. The scheduler owns:
  - Transfer schedule computation (get_num_transfer_tokens)
  - Timestep management (set_timesteps)
  - Step logic: confidence-based mask-filling and optional token editing

  The pipeline now delegates scheduling to self.scheduler.step() and accepts a scheduler parameter in __init__.

* test: add unit tests for BlockRefinementScheduler

  12 tests covering set_timesteps, get_num_transfer_tokens, step logic (confidence-based commits, threshold behavior, editing, prompt masking, batched inputs, tuple output).

* docs: add toctree entries and standalone scheduler doc page
  - Add BlockRefinement and LLaDA2 to docs sidebar navigation
  - Add BlockRefinementScheduler to schedulers sidebar navigation
  - Move scheduler autodoc to its own page under api/schedulers/

* feat: add --revision flag and fix dtype deprecation in sample_llada2.py
  - Add --revision argument for loading model revisions from the Hub
  - Replace deprecated torch_dtype with dtype for transformers 5.x compat

* fix: use 1/0 attention mask instead of 0/-inf for LLaDA2 compat

  LLaDA2 models expect a boolean-style (1/0) attention mask, not an additive (0/-inf) mask. The model internally converts to additive, so passing 0/-inf caused double-masking and gibberish output.

* refactor: consolidate training scripts into single train_block_refinement.py
  - Remove toy train_block_refinement_cap.py (self-contained demo with tiny model)
  - Rename train_block_refinement_qwen_cap.py to train_block_refinement.py (already works with any causal LM via AutoModelForCausalLM)
  - Fix torch_dtype deprecation and update README with correct script names

* fix formatting

* docs: improve LLaDA2 and BlockRefinement documentation
  - Add usage examples with real model IDs and working code
  - Add recommended parameters table for LLaDA2.1 quality/speed modes
  - Note that editing is LLaDA2.1-only (not for LLaDA2.0 models)
  - Remove misleading config defaults section from BlockRefinement docs

* feat: set LLaDA2Pipeline defaults to recommended model parameters
  - threshold: 0.95 -> 0.7 (quality mode)
  - max_post_steps: 0 -> 16 (recommended for LLaDA2.1, harmless for 2.0)
  - eos_early_stop: False -> True (stop at EOS token)

  block_length=32, steps=32, temperature=0.0 were already correct. editing_threshold remains None (users enable for LLaDA2.1 models).

* feat: default editing_threshold=0.5 for LLaDA2.1 quality mode

  LLaDA2.1 is the current generation. Users with LLaDA2.0 models can disable editing by passing editing_threshold=None.

* fix: align sampling utilities with official LLaDA2 implementation
  - top_p filtering: add shift-right to preserve at least one token above threshold (matches official code line 1210)
  - temperature ordering: apply scaling before top-k/top-p filtering so filtering operates on scaled logits (matches official code lines 1232-1235)
  - greedy branch: return argmax directly when temperature=0 without filtering (matches official code lines 1226-1230)

* refactor: remove duplicate prompt encoding, reuse mixin's _prepare_input_ids

  LLaDA2Pipeline._prepare_prompt_ids was a near-copy of DiscreteDiffusionPipelineMixin._prepare_input_ids. Remove the duplicate and call the mixin method directly. Also simplify _extract_input_ids since we always pass return_dict=True.

* formatting

* fix: replace deprecated torch_dtype with dtype in examples and docstrings
  - Update EXAMPLE_DOC_STRING to use dtype= and the LLaDA2.1-mini model ID
  - Fix sample_block_refinement.py to use dtype=

* remove BlockRefinementPipeline
* cleanup
* fix readme
* Update src/diffusers/pipelines/llada2/pipeline_llada2.py (review suggestions, repeated across many commits) Co-authored-by: YiYi Xu <yixu310@gmail.com>
* removed DiscreteDiffusionPipelineMixin
* add support for 2d masks for flash attn
* Update src/diffusers/training_utils.py (review suggestions, two commits) Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
* fix issues from review
* added tests
* formatting
* add check_eos_finished to scheduler
* Update src/diffusers/pipelines/llada2/pipeline_llada2.py and src/diffusers/schedulers/scheduling_block_refinement.py (review suggestions, repeated across many commits) Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
* fix renaming issues and types
* remove duplicate check
* Update docs/source/en/api/pipelines/llada2.md Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
* Update src/diffusers/pipelines/llada2/pipeline_llada2.py (review suggestions) Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
904 lines
36 KiB
Python
import contextlib
import copy
import gc
import math
import random
import re
import warnings
from contextlib import contextmanager
from functools import partial
from typing import Any, Iterable

import numpy as np
import torch
import torch.nn.functional as F


if getattr(torch, "distributed", None) is not None:
    from torch.distributed.fsdp import CPUOffload, ShardingStrategy
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

from .models import UNet2DConditionModel
from .pipelines import DiffusionPipeline
from .schedulers import SchedulerMixin
from .utils import (
    convert_state_dict_to_diffusers,
    convert_state_dict_to_peft,
    deprecate,
    is_accelerate_available,
    is_peft_available,
    is_torch_npu_available,
    is_torchvision_available,
    is_transformers_available,
)


if is_transformers_available():
    import transformers

    if transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
        import deepspeed

if is_accelerate_available():
    from accelerate.logging import get_logger

if is_peft_available():
    from peft import set_peft_model_state_dict

if is_torchvision_available():
    from torchvision import transforms

if is_torch_npu_available():
    import torch_npu  # noqa: F401

def set_seed(seed: int):
    """
    Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`.

    Args:
        seed (`int`): The seed to set.

    Returns:
        `None`
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if is_torch_npu_available():
        torch.npu.manual_seed_all(seed)
    else:
        torch.cuda.manual_seed_all(seed)
        # ^^ safe to call this function even if cuda is not available

def compute_snr(noise_scheduler, timesteps):
    """
    Computes SNR as per
    https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
    for the given timesteps using the provided noise scheduler.

    Args:
        noise_scheduler (`NoiseScheduler`):
            An object containing the noise schedule parameters, specifically `alphas_cumprod`, which is used to
            compute the SNR values.
        timesteps (`torch.Tensor`):
            A tensor of timesteps for which the SNR is computed.

    Returns:
        `torch.Tensor`: A tensor containing the computed SNR values for each timestep.
    """
    alphas_cumprod = noise_scheduler.alphas_cumprod
    sqrt_alphas_cumprod = alphas_cumprod**0.5
    sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5

    # Expand the tensors.
    # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
    sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
    while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
        sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
    alpha = sqrt_alphas_cumprod.expand(timesteps.shape)

    sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
    while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
        sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
    sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)

    # Compute SNR.
    snr = (alpha / sigma) ** 2
    return snr

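# Usage sketch (illustrative, not part of this module): Min-SNR loss weighting in a
# DDPM-style training step, assuming `model_pred`, `target`, `timesteps`, and
# `noise_scheduler` come from the usual training loop, with gamma=5.0 as in the paper.
#
#   snr = compute_snr(noise_scheduler, timesteps)
#   mse_loss_weights = torch.stack([snr, 5.0 * torch.ones_like(snr)], dim=1).min(dim=1)[0] / snr
#   loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
#   loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
#   loss = loss.mean()
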
def compute_confidence_aware_loss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    *,
    lambda_conf: float = 0.0,
    temperature: float = 1.0,
    per_token_weights: torch.Tensor | None = None,
    ignore_index: int = -100,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Computes a confidence-aware training loss for token classification-style heads.

    This loss combines:
    - `loss_sft`: standard supervised cross-entropy on all non-ignored labels.
    - `loss_conf`: an entropy penalty applied only on tokens that are already predicted correctly.

    Args:
        logits (`torch.Tensor`): Logits of shape `(..., vocab_size)`.
        labels (`torch.Tensor`): Labels of shape `(...)`, matching `logits.shape[:-1]`. Values set to `ignore_index`
            are excluded from both losses.
        lambda_conf (`float`, *optional*, defaults to `0.0`): Weight for the confidence term.
        temperature (`float`, *optional*, defaults to `1.0`): Temperature used for the entropy term only. Lower
            values sharpen the distribution and change the strength of the confidence gradients.
        per_token_weights (`torch.Tensor`, *optional*): Optional weights of shape `(...)` to reweight both losses
            per token (e.g. schedule-aware weights). Tokens with weight `0` contribute nothing.
        ignore_index (`int`, *optional*, defaults to `-100`): Ignore index for labels.

    Returns:
        `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]`: `(loss, loss_sft, loss_conf)`.
    """
    if logits.ndim < 2:
        raise ValueError(f"`logits` must have at least 2 dims, got shape {tuple(logits.shape)}.")
    if labels.shape != logits.shape[:-1]:
        raise ValueError(
            f"`labels` shape must match `logits.shape[:-1]`, got labels={tuple(labels.shape)} logits={tuple(logits.shape)}."
        )
    if temperature <= 0:
        raise ValueError(f"`temperature` must be > 0, got {temperature}.")

    valid = labels.ne(ignore_index)
    if per_token_weights is None:
        weights = torch.ones_like(labels, dtype=logits.dtype)
    else:
        if per_token_weights.shape != labels.shape:
            raise ValueError(
                f"`per_token_weights` shape must match `labels` shape, got {tuple(per_token_weights.shape)} != {tuple(labels.shape)}."
            )
        weights = per_token_weights.to(dtype=logits.dtype)

    # Supervised CE (optionally weighted).
    vocab_size = logits.shape[-1]
    per_token_nll = F.cross_entropy(
        logits.reshape(-1, vocab_size),
        labels.reshape(-1),
        reduction="none",
        ignore_index=ignore_index,
    ).reshape_as(labels)

    denom_sft = (weights * valid.to(weights.dtype)).sum().clamp_min(1)
    loss_sft = (per_token_nll * weights * valid.to(per_token_nll.dtype)).sum() / denom_sft

    # Confidence loss: penalize entropy only where prediction is already correct.
    if lambda_conf == 0.0:
        loss_conf = torch.zeros((), device=logits.device, dtype=loss_sft.dtype)
        return loss_sft, loss_sft, loss_conf

    with torch.no_grad():
        pred = logits.argmax(dim=-1)
        correct = valid & pred.eq(labels)

    scaled_logits = logits.float()
    if temperature != 1.0:
        scaled_logits = scaled_logits / float(temperature)

    probs = torch.softmax(scaled_logits, dim=-1)
    eps = torch.finfo(probs.dtype).tiny
    log_probs = torch.log(probs.clamp_min(eps))
    entropy = -(probs * log_probs).sum(dim=-1).to(dtype=logits.dtype)

    denom_conf = (weights * correct.to(weights.dtype)).sum().clamp_min(1)
    loss_conf = (entropy * weights * correct.to(entropy.dtype)).sum() / denom_conf

    loss = loss_sft + float(lambda_conf) * loss_conf
    return loss, loss_sft, loss_conf

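# Usage sketch (illustrative, not part of this module): a CAP-style training step,
# assuming `model` is a causal LM returning logits of shape (batch, seq_len, vocab)
# and `labels` marks ignored positions with -100. `lambda_conf=0.1` is an
# illustrative value, not a recommended default.
#
#   logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
#   loss, loss_sft, loss_conf = compute_confidence_aware_loss(
#       logits, labels, lambda_conf=0.1, temperature=1.0
#   )
#   loss.backward()
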
def resolve_interpolation_mode(interpolation_type: str):
    """
    Maps a string describing an interpolation function to the corresponding torchvision `InterpolationMode` enum.
    The full list of supported enums is documented at
    https://pytorch.org/vision/0.9/transforms.html#torchvision.transforms.functional.InterpolationMode.

    Args:
        interpolation_type (`str`):
            A string describing an interpolation method. Currently, `bilinear`, `bicubic`, `box`, `nearest`,
            `nearest_exact`, `hamming`, and `lanczos` are supported, corresponding to the supported interpolation
            modes in torchvision.

    Returns:
        `torchvision.transforms.InterpolationMode`: an `InterpolationMode` enum used by torchvision's `resize`
        transform.
    """
    if not is_torchvision_available():
        raise ImportError(
            "Please make sure to install `torchvision` to be able to use the `resolve_interpolation_mode()` function."
        )

    if interpolation_type == "bilinear":
        interpolation_mode = transforms.InterpolationMode.BILINEAR
    elif interpolation_type == "bicubic":
        interpolation_mode = transforms.InterpolationMode.BICUBIC
    elif interpolation_type == "box":
        interpolation_mode = transforms.InterpolationMode.BOX
    elif interpolation_type == "nearest":
        interpolation_mode = transforms.InterpolationMode.NEAREST
    elif interpolation_type == "nearest_exact":
        interpolation_mode = transforms.InterpolationMode.NEAREST_EXACT
    elif interpolation_type == "hamming":
        interpolation_mode = transforms.InterpolationMode.HAMMING
    elif interpolation_type == "lanczos":
        interpolation_mode = transforms.InterpolationMode.LANCZOS
    else:
        raise ValueError(
            f"The given interpolation mode {interpolation_type} is not supported. Currently supported interpolation"
            f" modes are `bilinear`, `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`."
        )

    return interpolation_mode

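# Usage sketch (illustrative): mapping a string config value to a torchvision
# resize transform.
#
#   interpolation = resolve_interpolation_mode("bilinear")
#   resize = transforms.Resize(512, interpolation=interpolation)
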
def compute_dream_and_update_latents(
    unet: UNet2DConditionModel,
    noise_scheduler: SchedulerMixin,
    timesteps: torch.Tensor,
    noise: torch.Tensor,
    noisy_latents: torch.Tensor,
    target: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    dream_detail_preservation: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Implements "DREAM (Diffusion Rectification and Estimation-Adaptive Models)" from
    https://huggingface.co/papers/2312.00210. DREAM helps align training with sampling to make training more
    efficient and accurate, at the cost of an extra forward step without gradients.

    Args:
        `unet`: The unet used to make the prediction.
        `noise_scheduler`: The noise scheduler used to add noise for the given timestep.
        `timesteps`: The timesteps for the noise_scheduler to use.
        `noise`: A tensor of noise in the shape of noisy_latents.
        `noisy_latents`: Previously noised latents from the training loop.
        `target`: The ground-truth tensor to predict after eps is removed.
        `encoder_hidden_states`: Text embeddings from the text model.
        `dream_detail_preservation`: A float value that indicates detail preservation level.
          See reference.

    Returns:
        `tuple[torch.Tensor, torch.Tensor]`: Adjusted noisy_latents and target.
    """
    alphas_cumprod = noise_scheduler.alphas_cumprod.to(timesteps.device)[timesteps, None, None, None]
    sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5

    # The paper uses lambda = sqrt(1 - alpha) ** p, with p = 1 in their experiments.
    dream_lambda = sqrt_one_minus_alphas_cumprod**dream_detail_preservation

    pred = None
    with torch.no_grad():
        pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

    _noisy_latents, _target = (None, None)
    if noise_scheduler.config.prediction_type == "epsilon":
        predicted_noise = pred
        delta_noise = (noise - predicted_noise).detach()
        delta_noise.mul_(dream_lambda)
        _noisy_latents = noisy_latents.add(sqrt_one_minus_alphas_cumprod * delta_noise)
        _target = target.add(delta_noise)
    elif noise_scheduler.config.prediction_type == "v_prediction":
        raise NotImplementedError("DREAM has not been implemented for v-prediction")
    else:
        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

    return _noisy_latents, _target

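# Usage sketch (illustrative): applying DREAM right before the loss computation in
# an epsilon-prediction training loop, assuming the usual training-script names.
#
#   noisy_latents, target = compute_dream_and_update_latents(
#       unet, noise_scheduler, timesteps, noise, noisy_latents, target,
#       encoder_hidden_states, dream_detail_preservation=1.0,
#   )
#   model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
#   loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
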
def unet_lora_state_dict(unet: UNet2DConditionModel) -> dict[str, torch.Tensor]:
    r"""
    Returns:
        A state dict containing just the LoRA parameters.
    """
    lora_state_dict = {}

    for name, module in unet.named_modules():
        if hasattr(module, "set_lora_layer"):
            lora_layer = getattr(module, "lora_layer")
            if lora_layer is not None:
                current_lora_layer_sd = lora_layer.state_dict()
                for lora_layer_matrix_name, lora_param in current_lora_layer_sd.items():
                    # The matrix name can either be "down" or "up".
                    lora_state_dict[f"{name}.lora.{lora_layer_matrix_name}"] = lora_param

    return lora_state_dict

def cast_training_params(model: torch.nn.Module | list[torch.nn.Module], dtype=torch.float32):
    """
    Casts the training parameters of the model to the specified data type.

    Args:
        model: The PyTorch model whose parameters will be cast.
        dtype: The data type to which the model parameters will be cast.
    """
    if not isinstance(model, list):
        model = [model]
    for m in model:
        for param in m.parameters():
            # only upcast trainable parameters into fp32
            if param.requires_grad:
                param.data = param.to(dtype)

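# Usage sketch (illustrative): after injecting LoRA adapters into a model kept in
# fp16/bf16, upcast only the trainable LoRA parameters to fp32 for stable
# optimization. `lora_config` is assumed to be a peft `LoraConfig`.
#
#   unet.add_adapter(lora_config)
#   if mixed_precision == "fp16":
#       cast_training_params(unet, dtype=torch.float32)
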
def _set_state_dict_into_text_encoder(
    lora_state_dict: dict[str, torch.Tensor], prefix: str, text_encoder: torch.nn.Module
):
    """
    Sets the `lora_state_dict` into `text_encoder` coming from `transformers`.

    Args:
        lora_state_dict: The state dictionary to be set.
        prefix: String identifier to retrieve the portion of the state dict that belongs to `text_encoder`.
        text_encoder: Where the `lora_state_dict` is to be set.
    """

    text_encoder_state_dict = {k.replace(prefix, ""): v for k, v in lora_state_dict.items() if k.startswith(prefix)}
    text_encoder_state_dict = convert_state_dict_to_peft(convert_state_dict_to_diffusers(text_encoder_state_dict))
    set_peft_model_state_dict(text_encoder, text_encoder_state_dict, adapter_name="default")

def _collate_lora_metadata(modules_to_save: dict[str, torch.nn.Module]) -> dict[str, Any]:
    metadatas = {}
    for module_name, module in modules_to_save.items():
        if module is not None:
            metadatas[f"{module_name}_lora_adapter_metadata"] = module.peft_config["default"].to_dict()
    return metadatas

def compute_density_for_timestep_sampling(
    weighting_scheme: str,
    batch_size: int,
    logit_mean: float | None = None,
    logit_std: float | None = None,
    mode_scale: float | None = None,
    device: torch.device | str = "cpu",
    generator: torch.Generator | None = None,
):
    """
    Compute the density for sampling the timesteps when doing SD3 training.

    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.

    SD3 paper reference: https://huggingface.co/papers/2403.03206v1.
    """
    if weighting_scheme == "logit_normal":
        u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device=device, generator=generator)
        u = torch.nn.functional.sigmoid(u)
    elif weighting_scheme == "mode":
        u = torch.rand(size=(batch_size,), device=device, generator=generator)
        u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
    else:
        u = torch.rand(size=(batch_size,), device=device, generator=generator)
    return u

def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
    """
    Computes loss weighting scheme for SD3 training.

    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.

    SD3 paper reference: https://huggingface.co/papers/2403.03206v1.
    """
    if weighting_scheme == "sigma_sqrt":
        weighting = (sigmas**-2.0).float()
    elif weighting_scheme == "cosmap":
        bot = 1 - 2 * sigmas + 2 * sigmas**2
        weighting = 2 / (math.pi * bot)
    else:
        weighting = torch.ones_like(sigmas)
    return weighting

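# Usage sketch (illustrative): sampling timesteps and loss weights for SD3-style
# flow-matching training, assuming `noise_scheduler_copy` is a scheduler whose
# `sigmas` were used to noise the latents.
#
#   u = compute_density_for_timestep_sampling(
#       weighting_scheme="logit_normal", batch_size=bsz, logit_mean=0.0, logit_std=1.0
#   )
#   indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
#   timesteps = noise_scheduler_copy.timesteps[indices]
#   weighting = compute_loss_weighting_for_sd3(weighting_scheme="sigma_sqrt", sigmas=sigmas)
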
def free_memory():
    """
    Runs garbage collection, then clears the cache of the available accelerator.
    """
    gc.collect()

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif torch.backends.mps.is_available():
        torch.mps.empty_cache()
    elif is_torch_npu_available():
        torch_npu.npu.empty_cache()
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.empty_cache()

@contextmanager
def offload_models(*modules: torch.nn.Module | DiffusionPipeline, device: str | torch.device, offload: bool = True):
    """
    Context manager that, if offload=True, moves each module to `device` on enter, then moves it back to its
    original device on exit.

    Args:
        device (`str` or `torch.device`): Device to move the `modules` to.
        offload (`bool`): Flag to enable offloading.
    """
    if offload:
        is_model = not any(isinstance(m, DiffusionPipeline) for m in modules)
        # record where each module was
        if is_model:
            original_devices = [next(m.parameters()).device for m in modules]
        else:
            assert len(modules) == 1
            # For DiffusionPipeline, wrap the device in a list to make it iterable
            original_devices = [modules[0].device]
        # move to target device
        for m in modules:
            m.to(device)

    try:
        yield
    finally:
        if offload:
            # move back to original devices
            for m, orig_dev in zip(modules, original_devices):
                m.to(orig_dev)

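# Usage sketch (illustrative): temporarily move a text encoder onto the accelerator
# for prompt encoding, then return it to its original device when done.
#
#   with offload_models(text_encoder, device=accelerator.device, offload=True):
#       prompt_embeds = text_encoder(input_ids)[0]
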
def parse_buckets_string(buckets_str):
    """Parses a string defining buckets into a list of (height, width) tuples."""
    if not buckets_str:
        raise ValueError("Bucket string cannot be empty.")

    bucket_pairs = buckets_str.strip().split(";")
    parsed_buckets = []
    for pair_str in bucket_pairs:
        match = re.match(r"^\s*(\d+)\s*,\s*(\d+)\s*$", pair_str)
        if not match:
            raise ValueError(f"Invalid bucket format: '{pair_str}'. Expected 'height,width'.")
        try:
            height = int(match.group(1))
            width = int(match.group(2))
            if height <= 0 or width <= 0:
                raise ValueError("Bucket dimensions must be positive integers.")
            if height % 8 != 0 or width % 8 != 0:
                warnings.warn(f"Bucket dimension ({height},{width}) not divisible by 8. This might cause issues.")
            parsed_buckets.append((height, width))
        except ValueError as e:
            raise ValueError(f"Invalid integer in bucket pair '{pair_str}': {e}") from e

    if not parsed_buckets:
        raise ValueError("No valid buckets found in the provided string.")

    return parsed_buckets

def find_nearest_bucket(h, w, bucket_options):
    """Finds the closest bucket to the given height and width."""
    min_metric = float("inf")
    best_bucket_idx = None
    for bucket_idx, (bucket_h, bucket_w) in enumerate(bucket_options):
        metric = abs(h * bucket_w - w * bucket_h)
        if metric <= min_metric:
            min_metric = metric
            best_bucket_idx = bucket_idx
    return best_bucket_idx

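# Usage sketch (illustrative): parse a CLI bucket string and pick the bucket whose
# aspect ratio is closest to the input image.
#
#   buckets = parse_buckets_string("512,512;768,512;512,768")
#   idx = find_nearest_bucket(h=720, w=480, bucket_options=buckets)
#   target_h, target_w = buckets[idx]  # -> (768, 512)
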
def _to_cpu_contiguous(state_dicts) -> dict:
    return {k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v for k, v in state_dicts.items()}

def get_fsdp_kwargs_from_accelerator(accelerator) -> dict:
    """
    Extract and convert FSDP config from Accelerator into PyTorch FSDP kwargs.
    """

    kwargs = {}
    fsdp_plugin = getattr(accelerator.state, "fsdp_plugin", None)

    if fsdp_plugin is None:
        raise ValueError("Accelerate isn't configured to handle FSDP. Please update your installation.")

    # FSDP is enabled: use the plugin's strategy, or the default if None
    kwargs["sharding_strategy"] = fsdp_plugin.sharding_strategy or ShardingStrategy.FULL_SHARD

    return kwargs

def wrap_with_fsdp(
    model: torch.nn.Module,
    device: str | torch.device,
    offload: bool = True,
    use_orig_params: bool = True,
    limit_all_gathers: bool = True,
    fsdp_kwargs: dict[str, Any] | None = None,
    transformer_layer_cls: set[type[torch.nn.Module]] | None = None,
) -> FSDP:
    """
    Wrap a model with FSDP using common defaults and optional transformer auto-wrapping.

    Args:
        model: Model to wrap
        device: Target device (e.g., accelerator.device)
        offload: Whether to enable CPU parameter offloading
        use_orig_params: Whether to use original parameters
        limit_all_gathers: Whether to limit all gathers
        fsdp_kwargs: FSDP arguments (sharding_strategy, etc.), usually taken from the Accelerate config
        transformer_layer_cls: Set of layer classes for auto-wrapping (if not using a policy from fsdp_kwargs)

    Returns:
        FSDP-wrapped model
    """

    logger = get_logger(__name__)

    if transformer_layer_cls is None:
        # Auto-infer the transformer block class if none is provided
        inferred_cls = type(model.model.language_model.layers[0])
        logger.info(f"transformer_layer_cls is not provided, auto-inferred as {inferred_cls.__name__}")
        transformer_layer_cls = {inferred_cls}

    # Add auto-wrap policy for the specified transformer layer classes
    auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls=transformer_layer_cls)

    config = {
        "device_id": device,
        "cpu_offload": CPUOffload(offload_params=offload) if offload else None,
        "use_orig_params": use_orig_params,
        "limit_all_gathers": limit_all_gathers,
        "auto_wrap_policy": auto_wrap_policy,
    }

    if fsdp_kwargs:
        config.update(fsdp_kwargs)

    fsdp_model = FSDP(model, **config)
    return fsdp_model

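# Usage sketch (illustrative): wrap a transformer for sharded training, reusing the
# sharding strategy that Accelerate was configured with. Assumes the model exposes
# `model.model.language_model.layers`, or that `transformer_layer_cls` is passed.
#
#   fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator)
#   model = wrap_with_fsdp(model, device=accelerator.device, offload=False, fsdp_kwargs=fsdp_kwargs)
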
# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
class EMAModel:
    """
    Exponential Moving Average of model weights
    """

    def __init__(
        self,
        parameters: Iterable[torch.nn.Parameter],
        decay: float = 0.9999,
        min_decay: float = 0.0,
        update_after_step: int = 0,
        use_ema_warmup: bool = False,
        inv_gamma: float | int = 1.0,
        power: float | int = 2 / 3,
        foreach: bool = False,
        model_cls: Any | None = None,
        model_config: dict[str, Any] | None = None,
        **kwargs,
    ):
        """
        Args:
            parameters (Iterable[torch.nn.Parameter]): The parameters to track.
            decay (float): The decay factor for the exponential moving average.
            min_decay (float): The minimum decay factor for the exponential moving average.
            update_after_step (int): The number of steps to wait before starting to update the EMA weights.
            use_ema_warmup (bool): Whether to use EMA warmup.
            inv_gamma (float):
                Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True.
            power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True.
            foreach (bool): Use torch._foreach functions for updating shadow parameters. Should be faster.
            device (str | torch.device | None): The device to store the EMA weights on. If None, the EMA
                weights will be stored on CPU.

        @crowsonkb's notes on EMA Warmup:
            If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you
            plan to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M
            steps), gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K
            steps, 0.9999 at 215.4k steps).
        """

        if isinstance(parameters, torch.nn.Module):
            deprecation_message = (
                "Passing a `torch.nn.Module` to `ExponentialMovingAverage` is deprecated. "
                "Please pass the parameters of the module instead."
            )
            deprecate(
                "passing a `torch.nn.Module` to `ExponentialMovingAverage`",
                "1.0.0",
                deprecation_message,
                standard_warn=False,
            )
            parameters = parameters.parameters()

            # set use_ema_warmup to True if a torch.nn.Module is passed for backwards compatibility
            use_ema_warmup = True

        if kwargs.get("max_value", None) is not None:
            deprecation_message = "The `max_value` argument is deprecated. Please use `decay` instead."
            deprecate("max_value", "1.0.0", deprecation_message, standard_warn=False)
            decay = kwargs["max_value"]

        if kwargs.get("min_value", None) is not None:
            deprecation_message = "The `min_value` argument is deprecated. Please use `min_decay` instead."
            deprecate("min_value", "1.0.0", deprecation_message, standard_warn=False)
            min_decay = kwargs["min_value"]

        parameters = list(parameters)
        self.shadow_params = [p.clone().detach() for p in parameters]

        if kwargs.get("device", None) is not None:
            deprecation_message = "The `device` argument is deprecated. Please use `to` instead."
            deprecate("device", "1.0.0", deprecation_message, standard_warn=False)
            self.to(device=kwargs["device"])

        self.temp_stored_params = None

        self.decay = decay
        self.min_decay = min_decay
        self.update_after_step = update_after_step
        self.use_ema_warmup = use_ema_warmup
        self.inv_gamma = inv_gamma
        self.power = power
        self.optimization_step = 0
        self.cur_decay_value = None  # set in `step()`
        self.foreach = foreach

        self.model_cls = model_cls
        self.model_config = model_config

    @classmethod
    def from_pretrained(cls, path, model_cls, foreach=False) -> "EMAModel":
        _, ema_kwargs = model_cls.from_config(path, return_unused_kwargs=True)
        model = model_cls.from_pretrained(path)

        ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config, foreach=foreach)

        ema_model.load_state_dict(ema_kwargs)
        return ema_model

    def save_pretrained(self, path):
        if self.model_cls is None:
            raise ValueError("`save_pretrained` can only be used if `model_cls` was defined at __init__.")

        if self.model_config is None:
            raise ValueError("`save_pretrained` can only be used if `model_config` was defined at __init__.")

        model = self.model_cls.from_config(self.model_config)
        state_dict = self.state_dict()
        state_dict.pop("shadow_params", None)

        model.register_to_config(**state_dict)
        self.copy_to(model.parameters())
        model.save_pretrained(path)

    def get_decay(self, optimization_step: int) -> float:
        """
        Compute the decay factor for the exponential moving average.
        """
        step = max(0, optimization_step - self.update_after_step - 1)

        if step <= 0:
            return 0.0

        if self.use_ema_warmup:
            cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power
        else:
            cur_decay_value = (1 + step) / (10 + step)

        cur_decay_value = min(cur_decay_value, self.decay)
        # make sure decay is not smaller than min_decay
        cur_decay_value = max(cur_decay_value, self.min_decay)
        return cur_decay_value

    @torch.no_grad()
    def step(self, parameters: Iterable[torch.nn.Parameter]):
        if isinstance(parameters, torch.nn.Module):
            deprecation_message = (
                "Passing a `torch.nn.Module` to `ExponentialMovingAverage.step` is deprecated. "
                "Please pass the parameters of the module instead."
            )
            deprecate(
                "passing a `torch.nn.Module` to `ExponentialMovingAverage.step`",
                "1.0.0",
                deprecation_message,
                standard_warn=False,
            )
            parameters = parameters.parameters()

        parameters = list(parameters)

        self.optimization_step += 1

        # Compute the decay factor for the exponential moving average.
        decay = self.get_decay(self.optimization_step)
        self.cur_decay_value = decay
        one_minus_decay = 1 - decay

        context_manager = contextlib.nullcontext()

        if self.foreach:
            if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
                context_manager = deepspeed.zero.GatheredParameters(parameters, modifier_rank=None)

            with context_manager:
                params_grad = [param for param in parameters if param.requires_grad]
                s_params_grad = [
                    s_param for s_param, param in zip(self.shadow_params, parameters) if param.requires_grad
                ]

                if len(params_grad) < len(parameters):
                    torch._foreach_copy_(
                        [s_param for s_param, param in zip(self.shadow_params, parameters) if not param.requires_grad],
                        [param for param in parameters if not param.requires_grad],
                        non_blocking=True,
                    )

                torch._foreach_sub_(
                    s_params_grad, torch._foreach_sub(s_params_grad, params_grad), alpha=one_minus_decay
                )

        else:
            for s_param, param in zip(self.shadow_params, parameters):
                if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
                    context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None)

                with context_manager:
                    if param.requires_grad:
                        s_param.sub_(one_minus_decay * (s_param - param))
                    else:
                        s_param.copy_(param)

    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        """
        Copy current averaged parameters into given collection of parameters.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored moving averages. If `None`, the parameters with which this
                `ExponentialMovingAverage` was initialized will be used.
        """
        parameters = list(parameters)
        if self.foreach:
            torch._foreach_copy_(
                [param.data for param in parameters],
                [s_param.to(param.device).data for s_param, param in zip(self.shadow_params, parameters)],
            )
        else:
            for s_param, param in zip(self.shadow_params, parameters):
                param.data.copy_(s_param.to(param.device).data)

    def pin_memory(self) -> None:
        r"""
        Move internal buffers of the ExponentialMovingAverage to pinned memory. Useful for non-blocking transfers
        for offloading EMA params to the host.
        """

        self.shadow_params = [p.pin_memory() for p in self.shadow_params]

    def to(self, device=None, dtype=None, non_blocking=False) -> None:
        r"""
        Move internal buffers of the ExponentialMovingAverage to `device`.

        Args:
            device: like `device` argument to `torch.Tensor.to`
        """
        # .to() on the tensors handles None correctly
        self.shadow_params = [
            p.to(device=device, dtype=dtype, non_blocking=non_blocking)
            if p.is_floating_point()
            else p.to(device=device, non_blocking=non_blocking)
            for p in self.shadow_params
        ]

    def state_dict(self) -> dict:
        r"""
        Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during
        checkpointing to save the ema state dict.
        """
        # Following PyTorch conventions, references to tensors are returned:
        # "returns a reference to the state and not its copy!" -
        # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
        return {
            "decay": self.decay,
            "min_decay": self.min_decay,
            "optimization_step": self.optimization_step,
            "update_after_step": self.update_after_step,
            "use_ema_warmup": self.use_ema_warmup,
            "inv_gamma": self.inv_gamma,
            "power": self.power,
            "shadow_params": self.shadow_params,
        }

    def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        r"""
        Saves the current parameters for restoring later.

        Args:
            parameters: Iterable of `torch.nn.Parameter`. The parameters to be temporarily stored.
        """
        self.temp_stored_params = [param.detach().cpu().clone() for param in parameters]

    def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        r"""
        Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters
        without affecting the original optimization process. Store the parameters before the `copy_to()` method.
        After validation (or model saving), use this to restore the former parameters.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters. If `None`, the parameters with which this
                `ExponentialMovingAverage` was initialized will be used.
        """

        if self.temp_stored_params is None:
            raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights to `restore()`")
        if self.foreach:
            torch._foreach_copy_(
                [param.data for param in parameters], [c_param.data for c_param in self.temp_stored_params]
            )
        else:
            for c_param, param in zip(self.temp_stored_params, parameters):
                param.data.copy_(c_param.data)

        # Better memory-wise.
        self.temp_stored_params = None

    def load_state_dict(self, state_dict: dict) -> None:
        r"""
        Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to load
        the ema state dict.

        Args:
            state_dict (dict): EMA state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        # deepcopy, to be consistent with module API
        state_dict = copy.deepcopy(state_dict)

        self.decay = state_dict.get("decay", self.decay)
        if self.decay < 0.0 or self.decay > 1.0:
            raise ValueError("Decay must be between 0 and 1")

        self.min_decay = state_dict.get("min_decay", self.min_decay)
        if not isinstance(self.min_decay, float):
            raise ValueError("Invalid min_decay")

        self.optimization_step = state_dict.get("optimization_step", self.optimization_step)
        if not isinstance(self.optimization_step, int):
            raise ValueError("Invalid optimization_step")

        self.update_after_step = state_dict.get("update_after_step", self.update_after_step)
        if not isinstance(self.update_after_step, int):
            raise ValueError("Invalid update_after_step")

        self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup)
        if not isinstance(self.use_ema_warmup, bool):
            raise ValueError("Invalid use_ema_warmup")

        self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma)
        if not isinstance(self.inv_gamma, (float, int)):
            raise ValueError("Invalid inv_gamma")

        self.power = state_dict.get("power", self.power)
        if not isinstance(self.power, (float, int)):
            raise ValueError("Invalid power")

        shadow_params = state_dict.get("shadow_params", None)
        if shadow_params is not None:
            self.shadow_params = shadow_params
            if not isinstance(self.shadow_params, list):
                raise ValueError("shadow_params must be a list")
            if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
                raise ValueError("shadow_params must all be Tensors")
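
# Usage sketch (illustrative): maintaining an EMA copy of the UNet during training
# and swapping it in for validation.
#
#   ema_unet = EMAModel(unet.parameters(), decay=0.9999, model_cls=UNet2DConditionModel, model_config=unet.config)
#   for batch in dataloader:
#       ...  # forward/backward/optimizer.step()
#       ema_unet.step(unet.parameters())
#
#   ema_unet.store(unet.parameters())    # stash the raw weights
#   ema_unet.copy_to(unet.parameters())  # validate with the EMA weights
#   ...                                  # run validation / save checkpoint
#   ema_unet.restore(unet.parameters())  # put the raw weights back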