mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-09 14:04:37 +08:00
Compare commits
7 Commits
faster-loa
...
add-stocha
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ff1012f8cf | ||
|
|
25bc77d8f8 | ||
|
|
9c35a89921 | ||
|
|
32d9aef997 | ||
|
|
9edc5beddc | ||
|
|
f87956e9cf | ||
|
|
690adb5bd9 |
@@ -1915,22 +1915,17 @@ def main(args):
|
||||
free_memory()
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
|
||||
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
|
||||
overrode_max_train_steps = False
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if args.max_train_steps is None:
|
||||
len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
|
||||
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
|
||||
num_training_steps_for_scheduler = (
|
||||
args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
|
||||
)
|
||||
else:
|
||||
num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
overrode_max_train_steps = True
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
args.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=num_warmup_steps_for_scheduler,
|
||||
num_training_steps=num_training_steps_for_scheduler,
|
||||
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
||||
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
||||
num_cycles=args.lr_num_cycles,
|
||||
power=args.lr_power,
|
||||
)
|
||||
@@ -1954,6 +1949,7 @@ def main(args):
|
||||
lr_scheduler,
|
||||
)
|
||||
else:
|
||||
print("I SHOULD BE HERE")
|
||||
transformer, text_encoder_one, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
transformer, text_encoder_one, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
@@ -1965,14 +1961,8 @@ def main(args):
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if args.max_train_steps is None:
|
||||
if overrode_max_train_steps:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
if num_training_steps_for_scheduler != args.max_train_steps:
|
||||
logger.warning(
|
||||
f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
|
||||
f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
|
||||
f"This inconsistency may result in the learning rate scheduler not functioning properly."
|
||||
)
|
||||
# Afterwards we recalculate our number of training epochs
|
||||
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
|
||||
@@ -1407,22 +1407,17 @@ def main(args):
|
||||
tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
|
||||
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
|
||||
overrode_max_train_steps = False
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if args.max_train_steps is None:
|
||||
len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
|
||||
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
|
||||
num_training_steps_for_scheduler = (
|
||||
args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
|
||||
)
|
||||
else:
|
||||
num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
overrode_max_train_steps = True
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
args.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=num_warmup_steps_for_scheduler,
|
||||
num_training_steps=num_training_steps_for_scheduler,
|
||||
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
||||
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
||||
num_cycles=args.lr_num_cycles,
|
||||
power=args.lr_power,
|
||||
)
|
||||
@@ -1449,14 +1444,8 @@ def main(args):
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if args.max_train_steps is None:
|
||||
if overrode_max_train_steps:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
if num_training_steps_for_scheduler != args.max_train_steps:
|
||||
logger.warning(
|
||||
f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
|
||||
f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
|
||||
f"This inconsistency may result in the learning rate scheduler not functioning properly."
|
||||
)
|
||||
# Afterwards we recalculate our number of training epochs
|
||||
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
|
||||
@@ -1524,22 +1524,17 @@ def main(args):
|
||||
free_memory()
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
|
||||
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
|
||||
overrode_max_train_steps = False
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if args.max_train_steps is None:
|
||||
len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
|
||||
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
|
||||
num_training_steps_for_scheduler = (
|
||||
args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
|
||||
)
|
||||
else:
|
||||
num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
overrode_max_train_steps = True
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
args.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=num_warmup_steps_for_scheduler,
|
||||
num_training_steps=num_training_steps_for_scheduler,
|
||||
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
||||
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
||||
num_cycles=args.lr_num_cycles,
|
||||
power=args.lr_power,
|
||||
)
|
||||
@@ -1566,14 +1561,8 @@ def main(args):
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if args.max_train_steps is None:
|
||||
if overrode_max_train_steps:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
if num_training_steps_for_scheduler != args.max_train_steps:
|
||||
logger.warning(
|
||||
f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
|
||||
f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
|
||||
f"This inconsistency may result in the learning rate scheduler not functioning properly."
|
||||
)
|
||||
# Afterwards we recalculate our number of training epochs
|
||||
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
|
||||
@@ -1523,22 +1523,17 @@ def main(args):
|
||||
tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
|
||||
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
|
||||
overrode_max_train_steps = False
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if args.max_train_steps is None:
|
||||
len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
|
||||
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
|
||||
num_training_steps_for_scheduler = (
|
||||
args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
|
||||
)
|
||||
else:
|
||||
num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
overrode_max_train_steps = True
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
args.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=num_warmup_steps_for_scheduler,
|
||||
num_training_steps=num_training_steps_for_scheduler,
|
||||
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
||||
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
||||
num_cycles=args.lr_num_cycles,
|
||||
power=args.lr_power,
|
||||
)
|
||||
@@ -1555,14 +1550,7 @@ def main(args):
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if args.max_train_steps is None:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
if num_training_steps_for_scheduler != args.max_train_steps:
|
||||
logger.warning(
|
||||
f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
|
||||
f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
|
||||
f"This inconsistency may result in the learning rate scheduler not functioning properly."
|
||||
)
|
||||
if overrode_max_train_steps:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
# Afterwards we recalculate our number of training epochs
|
||||
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
@@ -18,7 +18,7 @@ import importlib
|
||||
import inspect
|
||||
import os
|
||||
from array import array
|
||||
from collections import OrderedDict, defaultdict
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Union
|
||||
from zipfile import is_zipfile
|
||||
@@ -38,7 +38,6 @@ from ..utils import (
|
||||
_get_model_file,
|
||||
deprecate,
|
||||
is_accelerate_available,
|
||||
is_accelerator_device,
|
||||
is_gguf_available,
|
||||
is_torch_available,
|
||||
is_torch_version,
|
||||
@@ -305,51 +304,6 @@ def load_model_dict_into_meta(
|
||||
return offload_index, state_dict_index
|
||||
|
||||
|
||||
# Taken from
|
||||
# https://github.com/huggingface/transformers/blob/6daa3eeba582facb57cd71db8efb66998b12942f/src/transformers/modeling_utils.py#L5852C1-L5861C26
|
||||
def _expand_device_map(device_map, param_names):
|
||||
new_device_map = {}
|
||||
for module, device in device_map.items():
|
||||
new_device_map.update(
|
||||
{p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
|
||||
)
|
||||
return new_device_map
|
||||
|
||||
|
||||
# Adapted from https://github.com/huggingface/transformers/blob/6daa3eeba582facb57cd71db8efb66998b12942f/src/transformers/modeling_utils.py#L5874
|
||||
# We don't incorporate the `tp_plan` stuff as we don't support it yet.
|
||||
def _caching_allocator_warmup(model, device_map: Dict, factor=2) -> Dict:
|
||||
# Remove disk, cpu and meta devices, and cast to proper torch.device
|
||||
accelerator_device_map = {
|
||||
param: torch.device(device) for param, device in device_map.items() if is_accelerator_device(device)
|
||||
}
|
||||
if not len(accelerator_device_map):
|
||||
return
|
||||
|
||||
total_byte_count = defaultdict(lambda: 0)
|
||||
for param_name, device in accelerator_device_map.items():
|
||||
param = model.get_parameter_or_buffer(param_name)
|
||||
# The dtype of different parameters may be different with composite models or `keep_in_fp32_modules`
|
||||
param_byte_count = param.numel() * param.element_size()
|
||||
total_byte_count[device] += param_byte_count
|
||||
|
||||
# This will kick off the caching allocator to avoid having to Malloc afterwards
|
||||
for device, byte_count in total_byte_count.items():
|
||||
if device.type == "cuda":
|
||||
index = device.index if device.index is not None else torch.cuda.current_device()
|
||||
device_memory = torch.cuda.mem_get_info(index)[0]
|
||||
# Allow up to (max device memory - 1.2 GiB) in resource-constrained hardware configurations. Trying to reserve more
|
||||
# than that amount might sometimes lead to unecesary cuda OOM, if the last parameter to be loaded on the device is large,
|
||||
# and the remaining reserved memory portion is smaller than the param size -> torch will then try to fully re-allocate all
|
||||
# the param size, instead of using the remaining reserved part, and allocating only the difference, which can lead
|
||||
# to OOM. See https://github.com/huggingface/transformers/issues/37436#issuecomment-2808982161 for more details.
|
||||
# Note that we use an absolute value instead of device proportion here, as a 8GiB device could still allocate too much
|
||||
# if using e.g. 90% of device size, while a 140GiB device would allocate too little
|
||||
byte_count = min(byte_count, max(0, int(device_memory - 1.2 * 1024**3)))
|
||||
# Allocate memory
|
||||
_ = torch.empty(byte_count // factor, dtype=torch.float16, device=device, requires_grad=False)
|
||||
|
||||
|
||||
def _load_state_dict_into_model(
|
||||
model_to_load, state_dict: OrderedDict, assign_to_params_buffers: bool = False
|
||||
) -> List[str]:
|
||||
|
||||
@@ -63,9 +63,7 @@ from ..utils.hub_utils import (
|
||||
populate_model_card,
|
||||
)
|
||||
from .model_loading_utils import (
|
||||
_caching_allocator_warmup,
|
||||
_determine_device_map,
|
||||
_expand_device_map,
|
||||
_fetch_index_file,
|
||||
_fetch_index_file_legacy,
|
||||
_load_state_dict_into_model,
|
||||
@@ -1376,24 +1374,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
else:
|
||||
return super().float(*args)
|
||||
|
||||
# Taken from `transformers`.
|
||||
# https://github.com/huggingface/transformers/blob/6daa3eeba582facb57cd71db8efb66998b12942f/src/transformers/modeling_utils.py#L5351C5-L5365C81
|
||||
def get_parameter_or_buffer(self, target: str):
|
||||
"""
|
||||
Return the parameter or buffer given by `target` if it exists, otherwise throw an error. This combines
|
||||
`get_parameter()` and `get_buffer()` in a single handy function. Note that it only work if `target` is a leaf
|
||||
of the model.
|
||||
"""
|
||||
try:
|
||||
return self.get_parameter(target)
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
return self.get_buffer(target)
|
||||
except AttributeError:
|
||||
pass
|
||||
raise AttributeError(f"`{target}` is neither a parameter nor a buffer.")
|
||||
|
||||
@classmethod
|
||||
def _load_pretrained_model(
|
||||
cls,
|
||||
@@ -1430,11 +1410,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
assign_to_params_buffers = None
|
||||
error_msgs = []
|
||||
|
||||
# Optionally, warmup cuda to load the weights much faster on devices
|
||||
if device_map is not None:
|
||||
expanded_device_map = _expand_device_map(device_map, expected_keys)
|
||||
_caching_allocator_warmup(model, expanded_device_map, factor=2 if hf_quantizer is None else 4)
|
||||
|
||||
# Deal with offload
|
||||
if device_map is not None and "disk" in device_map.values():
|
||||
if offload_folder is None:
|
||||
|
||||
@@ -80,6 +80,8 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
Whether to use beta sigmas for step sizes in the noise schedule during sampling.
|
||||
time_shift_type (`str`, defaults to "exponential"):
|
||||
The type of dynamic resolution-dependent timestep shifting to apply. Either "exponential" or "linear".
|
||||
stochastic_sampling (`bool`, defaults to False):
|
||||
Whether to use stochastic sampling.
|
||||
"""
|
||||
|
||||
_compatibles = []
|
||||
@@ -101,6 +103,7 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
use_exponential_sigmas: Optional[bool] = False,
|
||||
use_beta_sigmas: Optional[bool] = False,
|
||||
time_shift_type: str = "exponential",
|
||||
stochastic_sampling: bool = False,
|
||||
):
|
||||
if self.config.use_beta_sigmas and not is_scipy_available():
|
||||
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
|
||||
@@ -437,13 +440,25 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
lower_mask = sigmas < per_token_sigmas[None] - 1e-6
|
||||
lower_sigmas = lower_mask * sigmas
|
||||
lower_sigmas, _ = lower_sigmas.max(dim=0)
|
||||
dt = (per_token_sigmas - lower_sigmas)[..., None]
|
||||
|
||||
current_sigma = per_token_sigmas[..., None]
|
||||
next_sigma = lower_sigmas[..., None]
|
||||
dt = current_sigma - next_sigma
|
||||
else:
|
||||
sigma = self.sigmas[self.step_index]
|
||||
sigma_next = self.sigmas[self.step_index + 1]
|
||||
sigma_idx = self.step_index
|
||||
sigma = self.sigmas[sigma_idx]
|
||||
sigma_next = self.sigmas[sigma_idx + 1]
|
||||
|
||||
current_sigma = sigma
|
||||
next_sigma = sigma_next
|
||||
dt = sigma_next - sigma
|
||||
|
||||
prev_sample = sample + dt * model_output
|
||||
if self.config.stochastic_sampling:
|
||||
x0 = sample - current_sigma * model_output
|
||||
noise = torch.randn_like(sample)
|
||||
prev_sample = (1.0 - next_sigma) * x0 + next_sigma * noise
|
||||
else:
|
||||
prev_sample = sample + dt * model_output
|
||||
|
||||
# upon completion increase step index by one
|
||||
self._step_index += 1
|
||||
|
||||
@@ -129,7 +129,6 @@ from .state_dict_utils import (
|
||||
convert_unet_state_dict_to_peft,
|
||||
state_dict_all_zero,
|
||||
)
|
||||
from .testing_utils import is_accelerator_device
|
||||
from .typing_utils import _get_detailed_type, _is_valid_type
|
||||
|
||||
|
||||
|
||||
@@ -1289,18 +1289,6 @@ if is_torch_available():
|
||||
update_mapping_from_spec(BACKEND_MAX_MEMORY_ALLOCATED, "MAX_MEMORY_ALLOCATED_FN")
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
# Taken from
|
||||
# https://github.com/huggingface/transformers/blob/6daa3eeba582facb57cd71db8efb66998b12942f/src/transformers/modeling_utils.py#L5864C1-L5871C64
|
||||
def is_accelerator_device(device: Union[str, int, torch.device]) -> bool:
|
||||
"""Check if the device is an accelerator. We need to function, as device_map can be "disk" as well, which is not
|
||||
a proper `torch.device`.
|
||||
"""
|
||||
if device == "disk":
|
||||
return False
|
||||
else:
|
||||
return torch.device(device).type not in ["meta", "cpu"]
|
||||
|
||||
# Modified from https://github.com/huggingface/transformers/blob/cdfb018d0300fef3b07d9220f3efe9c2a9974662/src/transformers/testing_utils.py#L3090
|
||||
|
||||
# Type definition of key used in `Expectations` class.
|
||||
|
||||
Reference in New Issue
Block a user