mirror of
https://github.com/huggingface/diffusers.git
synced 2026-04-15 20:27:07 +08:00
Compare commits
6 Commits
unet-model
...
bnb-test-f
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fe38d77603 | ||
|
|
526498d219 | ||
|
|
6a339ce637 | ||
|
|
26bb7fa0cb | ||
|
|
5063aa5566 | ||
|
|
62b1071609 |
@@ -1533,9 +1533,9 @@ def main(args):
|
||||
# from the cat above, but collate_fn also doubles the prompts list. Use half the
|
||||
# prompts count to avoid a 2x over-repeat that produces more embeddings than latents.
|
||||
num_repeat_elements = len(prompts) // 2 if args.with_prior_preservation else len(prompts)
|
||||
prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1)
|
||||
prompt_embeds = prompt_embeds.repeat_interleave(num_repeat_elements, dim=0)
|
||||
if prompt_embeds_mask is not None:
|
||||
prompt_embeds_mask = prompt_embeds_mask.repeat(num_repeat_elements, 1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.repeat_interleave(num_repeat_elements, dim=0)
|
||||
# Convert images to latent space
|
||||
if args.cache_latents:
|
||||
model_input = latents_cache[step].sample()
|
||||
@@ -1602,10 +1602,11 @@ def main(args):
|
||||
# Chunk the noise and model_pred into two parts and compute the loss on each part separately.
|
||||
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
|
||||
target, target_prior = torch.chunk(target, 2, dim=0)
|
||||
weighting, weighting_prior = torch.chunk(weighting, 2, dim=0)
|
||||
|
||||
# Compute prior loss
|
||||
prior_loss = torch.mean(
|
||||
(weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(
|
||||
(weighting_prior.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(
|
||||
target_prior.shape[0], -1
|
||||
),
|
||||
1,
|
||||
|
||||
2
setup.py
2
setup.py
@@ -146,6 +146,7 @@ _deps = [
|
||||
"phonemizer",
|
||||
"opencv-python",
|
||||
"timm",
|
||||
"flashpack",
|
||||
]
|
||||
|
||||
# this is a lookup table with items like:
|
||||
@@ -250,6 +251,7 @@ extras["gguf"] = deps_list("gguf", "accelerate")
|
||||
extras["optimum_quanto"] = deps_list("optimum_quanto", "accelerate")
|
||||
extras["torchao"] = deps_list("torchao", "accelerate")
|
||||
extras["nvidia_modelopt"] = deps_list("nvidia_modelopt[hf]")
|
||||
extras["flashpack"] = deps_list("flashpack")
|
||||
|
||||
if os.name == "nt": # windows
|
||||
extras["flax"] = [] # jax is not supported on windows
|
||||
|
||||
@@ -53,4 +53,5 @@ deps = {
|
||||
"phonemizer": "phonemizer",
|
||||
"opencv-python": "opencv-python",
|
||||
"timm": "timm",
|
||||
"flashpack": "flashpack",
|
||||
}
|
||||
|
||||
@@ -540,7 +540,7 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
|
||||
f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12. Please update with `pip install -U kernels`."
|
||||
)
|
||||
|
||||
if backend == AttentionBackendName.FLASH_4_HUB and not is_kernels_available(">=", "0.12.3"):
|
||||
if backend == AttentionBackendName.FLASH_4_HUB and not is_kernels_version(">=", "0.12.3"):
|
||||
raise RuntimeError(
|
||||
f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12.3. Please update with `pip install -U kernels`."
|
||||
)
|
||||
|
||||
@@ -42,6 +42,7 @@ from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer
|
||||
from ..quantizers.quantization_config import QuantizationMethod
|
||||
from ..utils import (
|
||||
CONFIG_NAME,
|
||||
FLASHPACK_WEIGHTS_NAME,
|
||||
FLAX_WEIGHTS_NAME,
|
||||
HF_ENABLE_PARALLEL_LOADING,
|
||||
SAFE_WEIGHTS_INDEX_NAME,
|
||||
@@ -55,6 +56,7 @@ from ..utils import (
|
||||
is_accelerate_available,
|
||||
is_bitsandbytes_available,
|
||||
is_bitsandbytes_version,
|
||||
is_flashpack_available,
|
||||
is_peft_available,
|
||||
is_torch_version,
|
||||
logging,
|
||||
@@ -673,6 +675,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
variant: str | None = None,
|
||||
max_shard_size: int | str = "10GB",
|
||||
push_to_hub: bool = False,
|
||||
use_flashpack: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -725,7 +728,12 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
" the logger on the traceback to understand the reason why the quantized model is not serializable."
|
||||
)
|
||||
|
||||
weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
|
||||
weights_name = WEIGHTS_NAME
|
||||
if use_flashpack:
|
||||
weights_name = FLASHPACK_WEIGHTS_NAME
|
||||
elif safe_serialization:
|
||||
weights_name = SAFETENSORS_WEIGHTS_NAME
|
||||
|
||||
weights_name = _add_variant(weights_name, variant)
|
||||
weights_name_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
|
||||
".safetensors", "{suffix}.safetensors"
|
||||
@@ -752,58 +760,74 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
# Save the model
|
||||
state_dict = model_to_save.state_dict()
|
||||
|
||||
# Save the model
|
||||
state_dict_split = split_torch_state_dict_into_shards(
|
||||
state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
|
||||
)
|
||||
|
||||
# Clean the folder from a previous save
|
||||
if is_main_process:
|
||||
for filename in os.listdir(save_directory):
|
||||
if filename in state_dict_split.filename_to_tensors.keys():
|
||||
continue
|
||||
full_filename = os.path.join(save_directory, filename)
|
||||
if not os.path.isfile(full_filename):
|
||||
continue
|
||||
weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
|
||||
weights_without_ext = weights_without_ext.replace("{suffix}", "")
|
||||
filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
|
||||
# make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
|
||||
if (
|
||||
filename.startswith(weights_without_ext)
|
||||
and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
|
||||
):
|
||||
os.remove(full_filename)
|
||||
|
||||
for filename, tensors in state_dict_split.filename_to_tensors.items():
|
||||
shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
|
||||
filepath = os.path.join(save_directory, filename)
|
||||
if safe_serialization:
|
||||
# At some point we will need to deal better with save_function (used for TPU and other distributed
|
||||
# joyfulness), but for now this enough.
|
||||
safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
|
||||
if use_flashpack:
|
||||
if is_flashpack_available():
|
||||
import flashpack
|
||||
else:
|
||||
torch.save(shard, filepath)
|
||||
logger.error(
|
||||
"Saving a FlashPack checkpoint in PyTorch, requires both PyTorch and flashpack to be installed. Please see "
|
||||
"https://pytorch.org/ and https://github.com/fal-ai/flashpack for installation instructions."
|
||||
)
|
||||
raise ImportError("Please install torch and flashpack to save a FlashPack checkpoint in PyTorch.")
|
||||
|
||||
if state_dict_split.is_sharded:
|
||||
index = {
|
||||
"metadata": state_dict_split.metadata,
|
||||
"weight_map": state_dict_split.tensor_to_filename,
|
||||
}
|
||||
save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
|
||||
save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
|
||||
# Save the index as well
|
||||
with open(save_index_file, "w", encoding="utf-8") as f:
|
||||
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
|
||||
f.write(content)
|
||||
logger.info(
|
||||
f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
|
||||
f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the "
|
||||
f"index located at {save_index_file}."
|
||||
flashpack.serialization.pack_to_file(
|
||||
state_dict_or_model=state_dict,
|
||||
destination_path=os.path.join(save_directory, weights_name),
|
||||
target_dtype=self.dtype,
|
||||
)
|
||||
else:
|
||||
path_to_weights = os.path.join(save_directory, weights_name)
|
||||
logger.info(f"Model weights saved in {path_to_weights}")
|
||||
# Save the model
|
||||
state_dict_split = split_torch_state_dict_into_shards(
|
||||
state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
|
||||
)
|
||||
|
||||
# Clean the folder from a previous save
|
||||
if is_main_process:
|
||||
for filename in os.listdir(save_directory):
|
||||
if filename in state_dict_split.filename_to_tensors.keys():
|
||||
continue
|
||||
full_filename = os.path.join(save_directory, filename)
|
||||
if not os.path.isfile(full_filename):
|
||||
continue
|
||||
weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
|
||||
weights_without_ext = weights_without_ext.replace("{suffix}", "")
|
||||
filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
|
||||
# make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
|
||||
if (
|
||||
filename.startswith(weights_without_ext)
|
||||
and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
|
||||
):
|
||||
os.remove(full_filename)
|
||||
|
||||
for filename, tensors in state_dict_split.filename_to_tensors.items():
|
||||
shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
|
||||
filepath = os.path.join(save_directory, filename)
|
||||
if safe_serialization:
|
||||
# At some point we will need to deal better with save_function (used for TPU and other distributed
|
||||
# joyfulness), but for now this enough.
|
||||
safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
|
||||
else:
|
||||
torch.save(shard, filepath)
|
||||
|
||||
if state_dict_split.is_sharded:
|
||||
index = {
|
||||
"metadata": state_dict_split.metadata,
|
||||
"weight_map": state_dict_split.tensor_to_filename,
|
||||
}
|
||||
save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
|
||||
save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
|
||||
# Save the index as well
|
||||
with open(save_index_file, "w", encoding="utf-8") as f:
|
||||
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
|
||||
f.write(content)
|
||||
logger.info(
|
||||
f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
|
||||
f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the "
|
||||
f"index located at {save_index_file}."
|
||||
)
|
||||
else:
|
||||
path_to_weights = os.path.join(save_directory, weights_name)
|
||||
logger.info(f"Model weights saved in {path_to_weights}")
|
||||
|
||||
if push_to_hub:
|
||||
# Create a new empty model card and eventually tag it
|
||||
@@ -940,6 +964,12 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
disable_mmap ('bool', *optional*, defaults to 'False'):
|
||||
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
|
||||
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
|
||||
use_flashpack (`bool`, *optional*, defaults to `False`):
|
||||
If set to `True`, the model is loaded from `flashpack` weights.
|
||||
flashpack_kwargs(`dict[str, Any]`, *optional*, defaults to `{}`):
|
||||
Kwargs passed to
|
||||
[`flashpack.deserialization.assign_from_file`](https://github.com/fal-ai/flashpack/blob/f1aa91c5cd9532a3dbf5bcc707ab9b01c274b76c/src/flashpack/deserialization.py#L408-L422)
|
||||
|
||||
|
||||
> [!TIP] > To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in
|
||||
with `hf > auth login`. You can also activate the special >
|
||||
@@ -984,6 +1014,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
dduf_entries: dict[str, DDUFEntry] | None = kwargs.pop("dduf_entries", None)
|
||||
disable_mmap = kwargs.pop("disable_mmap", False)
|
||||
parallel_config: ParallelConfig | ContextParallelConfig | None = kwargs.pop("parallel_config", None)
|
||||
use_flashpack = kwargs.pop("use_flashpack", False)
|
||||
flashpack_kwargs = kwargs.pop("flashpack_kwargs", {})
|
||||
|
||||
is_parallel_loading_enabled = HF_ENABLE_PARALLEL_LOADING
|
||||
if is_parallel_loading_enabled and not low_cpu_mem_usage:
|
||||
@@ -1212,30 +1244,37 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
subfolder=subfolder or "",
|
||||
dduf_entries=dduf_entries,
|
||||
)
|
||||
elif use_safetensors:
|
||||
try:
|
||||
resolved_model_file = _get_model_file(
|
||||
pretrained_model_name_or_path,
|
||||
weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
commit_hash=commit_hash,
|
||||
dduf_entries=dduf_entries,
|
||||
)
|
||||
else:
|
||||
if use_flashpack:
|
||||
weights_name = FLASHPACK_WEIGHTS_NAME
|
||||
elif use_safetensors:
|
||||
weights_name = _add_variant(SAFETENSORS_WEIGHTS_NAME, variant)
|
||||
else:
|
||||
weights_name = None
|
||||
if weights_name is not None:
|
||||
try:
|
||||
resolved_model_file = _get_model_file(
|
||||
pretrained_model_name_or_path,
|
||||
weights_name=weights_name,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
commit_hash=commit_hash,
|
||||
dduf_entries=dduf_entries,
|
||||
)
|
||||
|
||||
except IOError as e:
|
||||
logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
|
||||
if not allow_pickle:
|
||||
raise
|
||||
logger.warning(
|
||||
"Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
|
||||
)
|
||||
except IOError as e:
|
||||
logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
|
||||
if not allow_pickle:
|
||||
raise
|
||||
logger.warning(
|
||||
"Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
|
||||
)
|
||||
|
||||
if resolved_model_file is None and not is_sharded:
|
||||
resolved_model_file = _get_model_file(
|
||||
@@ -1275,6 +1314,44 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
with ContextManagers(init_contexts):
|
||||
model = cls.from_config(config, **unused_kwargs)
|
||||
|
||||
if use_flashpack:
|
||||
if is_flashpack_available():
|
||||
import flashpack
|
||||
else:
|
||||
logger.error(
|
||||
"Loading a FlashPack checkpoint in PyTorch, requires both PyTorch and flashpack to be installed. Please see "
|
||||
"https://pytorch.org/ and https://github.com/fal-ai/flashpack for installation instructions."
|
||||
)
|
||||
raise ImportError("Please install torch and flashpack to load a FlashPack checkpoint in PyTorch.")
|
||||
|
||||
if device_map is None:
|
||||
logger.warning(
|
||||
"`device_map` has not been provided for FlashPack, model will be on `cpu` - provide `device_map` to fully utilize "
|
||||
"the benefit of FlashPack."
|
||||
)
|
||||
flashpack_device = torch.device("cpu")
|
||||
else:
|
||||
device = device_map[""]
|
||||
if isinstance(device, str) and device in ["auto", "balanced", "balanced_low_0", "sequential"]:
|
||||
raise ValueError(
|
||||
"FlashPack `device_map` should not be one of `auto`, `balanced`, `balanced_low_0`, `sequential`. Use a specific device instead, e.g., `device_map='cuda'` or `device_map='cuda:0'"
|
||||
)
|
||||
flashpack_device = torch.device(device) if not isinstance(device, torch.device) else device
|
||||
|
||||
flashpack.mixin.assign_from_file(
|
||||
model=model,
|
||||
path=resolved_model_file[0],
|
||||
device=flashpack_device,
|
||||
**flashpack_kwargs,
|
||||
)
|
||||
if dtype_orig is not None:
|
||||
torch.set_default_dtype(dtype_orig)
|
||||
if output_loading_info:
|
||||
logger.warning("`output_loading_info` is not supported with FlashPack.")
|
||||
return model, {}
|
||||
|
||||
return model
|
||||
|
||||
if dtype_orig is not None:
|
||||
torch.set_default_dtype(dtype_orig)
|
||||
|
||||
|
||||
@@ -44,7 +44,7 @@ class ErnieImageTransformer2DModelOutput(BaseOutput):
|
||||
|
||||
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
|
||||
assert dim % 2 == 0
|
||||
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
|
||||
scale = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim
|
||||
omega = 1.0 / (theta**scale)
|
||||
out = torch.einsum("...n,d->...nd", pos, omega)
|
||||
return out.float()
|
||||
@@ -400,8 +400,8 @@ class ErnieImageTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
]
|
||||
|
||||
# AdaLN
|
||||
sample = self.time_proj(timestep.to(dtype))
|
||||
sample = sample.to(self.time_embedding.linear_1.weight.dtype)
|
||||
sample = self.time_proj(timestep)
|
||||
sample = sample.to(dtype=dtype)
|
||||
c = self.time_embedding(sample)
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [
|
||||
t.unsqueeze(0).expand(S, -1, -1).contiguous() for t in self.adaLN_modulation(c).chunk(6, dim=-1)
|
||||
|
||||
@@ -445,10 +445,14 @@ class WanAnimateFaceBlockAttnProcessor:
|
||||
# B --> batch_size, T --> reduced inference segment len, N --> face_encoder_num_heads + 1, C --> attn.dim
|
||||
B, T, N, C = encoder_hidden_states.shape
|
||||
|
||||
# Flatten T and N so the K/V projections see a 3D tensor; BnB int8 matmul only
|
||||
# accepts 2D/3D inputs and would otherwise fail on this 4D activation.
|
||||
encoder_hidden_states = encoder_hidden_states.flatten(1, 2) # [B, T, N, C] --> [B, T * N, C]
|
||||
|
||||
query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)
|
||||
|
||||
query = query.unflatten(2, (attn.heads, -1)) # [B, S, H * D] --> [B, S, H, D]
|
||||
key = key.view(B, T, N, attn.heads, -1) # [B, T, N, H * D_kv] --> [B, T, N, H, D_kv]
|
||||
key = key.view(B, T, N, attn.heads, -1) # [B, T * N, H * D_kv] --> [B, T, N, H, D_kv]
|
||||
value = value.view(B, T, N, attn.heads, -1)
|
||||
|
||||
query = attn.norm_q(query)
|
||||
|
||||
@@ -877,10 +877,7 @@ class FluxPipeline(
|
||||
self.scheduler.config.get("max_shift", 1.15),
|
||||
)
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
timestep_device = "cpu"
|
||||
else:
|
||||
timestep_device = device
|
||||
timestep_device = device
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
num_inference_steps,
|
||||
|
||||
@@ -28,6 +28,7 @@ from packaging import version
|
||||
|
||||
from .. import __version__
|
||||
from ..utils import (
|
||||
FLASHPACK_WEIGHTS_NAME,
|
||||
FLAX_WEIGHTS_NAME,
|
||||
ONNX_EXTERNAL_WEIGHTS_NAME,
|
||||
ONNX_WEIGHTS_NAME,
|
||||
@@ -194,6 +195,7 @@ def filter_model_files(filenames):
|
||||
FLAX_WEIGHTS_NAME,
|
||||
ONNX_WEIGHTS_NAME,
|
||||
ONNX_EXTERNAL_WEIGHTS_NAME,
|
||||
FLASHPACK_WEIGHTS_NAME,
|
||||
]
|
||||
|
||||
if is_transformers_available():
|
||||
@@ -413,6 +415,9 @@ def get_class_obj_and_candidates(
|
||||
"""Simple helper method to retrieve class object of module as well as potential parent class objects"""
|
||||
component_folder = os.path.join(cache_dir, component_name) if component_name and cache_dir else None
|
||||
|
||||
if class_name.startswith("FlashPack"):
|
||||
class_name = class_name.removeprefix("FlashPack")
|
||||
|
||||
if is_pipeline_module:
|
||||
pipeline_module = getattr(pipelines, library_name)
|
||||
|
||||
@@ -760,6 +765,7 @@ def load_sub_model(
|
||||
provider_options: Any,
|
||||
disable_mmap: bool,
|
||||
quantization_config: Any | None = None,
|
||||
use_flashpack: bool = False,
|
||||
):
|
||||
"""Helper method to load the module `name` from `library_name` and `class_name`"""
|
||||
from ..quantizers import PipelineQuantizationConfig
|
||||
@@ -838,6 +844,9 @@ def load_sub_model(
|
||||
loading_kwargs["variant"] = model_variants.pop(name, None)
|
||||
loading_kwargs["use_safetensors"] = use_safetensors
|
||||
|
||||
if is_diffusers_model:
|
||||
loading_kwargs["use_flashpack"] = use_flashpack
|
||||
|
||||
if from_flax:
|
||||
loading_kwargs["from_flax"] = True
|
||||
|
||||
@@ -887,7 +896,7 @@ def load_sub_model(
|
||||
# else load from the root directory
|
||||
loaded_sub_model = load_method(cached_folder, **loading_kwargs)
|
||||
|
||||
if isinstance(loaded_sub_model, torch.nn.Module) and isinstance(device_map, dict):
|
||||
if isinstance(loaded_sub_model, torch.nn.Module) and isinstance(device_map, dict) and not use_flashpack:
|
||||
# remove hooks
|
||||
remove_hook_from_module(loaded_sub_model, recurse=True)
|
||||
needs_offloading_to_cpu = device_map[""] == "cpu"
|
||||
@@ -1093,6 +1102,7 @@ def _get_ignore_patterns(
|
||||
allow_pickle: bool,
|
||||
use_onnx: bool,
|
||||
is_onnx: bool,
|
||||
use_flashpack: bool,
|
||||
variant: str | None = None,
|
||||
) -> list[str]:
|
||||
if (
|
||||
@@ -1118,6 +1128,9 @@ def _get_ignore_patterns(
|
||||
if not use_onnx:
|
||||
ignore_patterns += ["*.onnx", "*.pb"]
|
||||
|
||||
elif use_flashpack:
|
||||
ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb", "*.msgpack"]
|
||||
|
||||
else:
|
||||
ignore_patterns = ["*.safetensors", "*.msgpack"]
|
||||
|
||||
|
||||
@@ -244,6 +244,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
variant: str | None = None,
|
||||
max_shard_size: int | str | None = None,
|
||||
push_to_hub: bool = False,
|
||||
use_flashpack: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -341,6 +342,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
save_method_accept_safe = "safe_serialization" in save_method_signature.parameters
|
||||
save_method_accept_variant = "variant" in save_method_signature.parameters
|
||||
save_method_accept_max_shard_size = "max_shard_size" in save_method_signature.parameters
|
||||
save_method_accept_flashpack = "use_flashpack" in save_method_signature.parameters
|
||||
save_method_accept_peft_format = "save_peft_format" in save_method_signature.parameters
|
||||
|
||||
save_kwargs = {}
|
||||
@@ -351,6 +353,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
if save_method_accept_max_shard_size and max_shard_size is not None:
|
||||
# max_shard_size is expected to not be None in ModelMixin
|
||||
save_kwargs["max_shard_size"] = max_shard_size
|
||||
if save_method_accept_flashpack:
|
||||
save_kwargs["use_flashpack"] = use_flashpack
|
||||
if save_method_accept_peft_format:
|
||||
# Set save_peft_format=False for transformers>=5.0.0 compatibility
|
||||
# In transformers 5.0.0+, the default save_peft_format=True adds "base_model.model" prefix
|
||||
@@ -781,6 +785,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
use_onnx = kwargs.pop("use_onnx", None)
|
||||
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
|
||||
quantization_config = kwargs.pop("quantization_config", None)
|
||||
use_flashpack = kwargs.pop("use_flashpack", False)
|
||||
disable_mmap = kwargs.pop("disable_mmap", False)
|
||||
|
||||
if torch_dtype is not None and not isinstance(torch_dtype, dict) and not isinstance(torch_dtype, torch.dtype):
|
||||
@@ -1071,6 +1076,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
provider_options=provider_options,
|
||||
disable_mmap=disable_mmap,
|
||||
quantization_config=quantization_config,
|
||||
use_flashpack=use_flashpack,
|
||||
)
|
||||
logger.info(
|
||||
f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."
|
||||
@@ -1576,6 +1582,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
Whether or not to allow for custom pipelines and components defined on the Hub in their own files. This
|
||||
option should only be set to `True` for repositories you trust and in which you have read the code, as
|
||||
it will execute code present on the Hub on your local machine.
|
||||
use_flashpack (`bool`, *optional*, defaults to `False`):
|
||||
If set to `True`, FlashPack weights will always be downloaded if present. If set to `False`, FlashPack
|
||||
weights will never be downloaded.
|
||||
|
||||
Returns:
|
||||
`os.PathLike`:
|
||||
@@ -1600,6 +1609,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
|
||||
trust_remote_code = kwargs.pop("trust_remote_code", False)
|
||||
dduf_file: dict[str, DDUFEntry] | None = kwargs.pop("dduf_file", None)
|
||||
use_flashpack = kwargs.pop("use_flashpack", False)
|
||||
|
||||
if dduf_file:
|
||||
if custom_pipeline:
|
||||
@@ -1719,6 +1729,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
allow_pickle,
|
||||
use_onnx,
|
||||
pipeline_class._is_onnx,
|
||||
use_flashpack,
|
||||
variant,
|
||||
)
|
||||
|
||||
|
||||
@@ -24,6 +24,8 @@ from .constants import (
|
||||
DEPRECATED_REVISION_ARGS,
|
||||
DIFFUSERS_DYNAMIC_MODULE_NAME,
|
||||
DIFFUSERS_LOAD_ID_FIELDS,
|
||||
FLASHPACK_FILE_EXTENSION,
|
||||
FLASHPACK_WEIGHTS_NAME,
|
||||
FLAX_WEIGHTS_NAME,
|
||||
GGUF_FILE_EXTENSION,
|
||||
HF_ENABLE_PARALLEL_LOADING,
|
||||
@@ -76,6 +78,7 @@ from .import_utils import (
|
||||
is_flash_attn_3_available,
|
||||
is_flash_attn_available,
|
||||
is_flash_attn_version,
|
||||
is_flashpack_available,
|
||||
is_flax_available,
|
||||
is_ftfy_available,
|
||||
is_gguf_available,
|
||||
|
||||
@@ -34,6 +34,8 @@ ONNX_WEIGHTS_NAME = "model.onnx"
|
||||
SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors"
|
||||
SAFE_WEIGHTS_INDEX_NAME = "diffusion_pytorch_model.safetensors.index.json"
|
||||
SAFETENSORS_FILE_EXTENSION = "safetensors"
|
||||
FLASHPACK_WEIGHTS_NAME = "model.flashpack"
|
||||
FLASHPACK_FILE_EXTENSION = "flashpack"
|
||||
GGUF_FILE_EXTENSION = "gguf"
|
||||
ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb"
|
||||
HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co")
|
||||
|
||||
@@ -230,6 +230,7 @@ _flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_at
|
||||
_aiter_available, _aiter_version = _is_package_available("aiter", get_dist_name=True)
|
||||
_kornia_available, _kornia_version = _is_package_available("kornia")
|
||||
_nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("modelopt", get_dist_name=True)
|
||||
_flashpack_available, _flashpack_version = _is_package_available("flashpack")
|
||||
_av_available, _av_version = _is_package_available("av")
|
||||
|
||||
|
||||
@@ -361,6 +362,10 @@ def is_gguf_available():
|
||||
return _gguf_available
|
||||
|
||||
|
||||
def is_flashpack_available():
|
||||
return _flashpack_available
|
||||
|
||||
|
||||
def is_torchao_available():
|
||||
return _torchao_available
|
||||
|
||||
|
||||
@@ -205,6 +205,11 @@ class BaseModelTesterConfig:
|
||||
"""Additional kwargs to pass to from_pretrained (e.g., subfolder, variant)."""
|
||||
return {}
|
||||
|
||||
@property
|
||||
def torch_dtype(self) -> torch.dtype:
|
||||
"""Compute dtype used to build dummy inputs and cast inputs where needed."""
|
||||
return torch.float32
|
||||
|
||||
@property
|
||||
def output_shape(self) -> Optional[tuple]:
|
||||
"""Expected output shape for output validation tests."""
|
||||
|
||||
@@ -359,15 +359,7 @@ class QuantizationTesterMixin:
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
assert not self._is_module_quantized(module), f"Module {name} is still quantized after dequantize()"
|
||||
|
||||
# Get model dtype from first parameter
|
||||
model_dtype = next(model.parameters()).dtype
|
||||
|
||||
inputs = self.get_dummy_inputs()
|
||||
# Cast inputs to model dtype
|
||||
inputs = {
|
||||
k: v.to(model_dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
|
||||
for k, v in inputs.items()
|
||||
}
|
||||
output = model(**inputs, return_dict=False)[0]
|
||||
assert output is not None, "Model output is None after dequantization"
|
||||
assert not torch.isnan(output).any(), "Model output contains NaN after dequantization"
|
||||
@@ -575,33 +567,28 @@ class BitsAndBytesTesterMixin(BitsAndBytesConfigMixin, QuantizationTesterMixin):
|
||||
|
||||
@torch.no_grad()
|
||||
def test_bnb_keep_modules_in_fp32(self):
|
||||
if not hasattr(self.model_class, "_keep_in_fp32_modules"):
|
||||
pytest.skip(f"{self.model_class.__name__} does not have _keep_in_fp32_modules")
|
||||
fp32_modules = getattr(self.model_class, "_keep_in_fp32_modules", None)
|
||||
if not fp32_modules:
|
||||
pytest.skip(f"{self.model_class.__name__} does not declare _keep_in_fp32_modules")
|
||||
|
||||
config_kwargs = BitsAndBytesConfigMixin.BNB_CONFIGS["4bit_nf4"]
|
||||
|
||||
original_fp32_modules = getattr(self.model_class, "_keep_in_fp32_modules", None)
|
||||
self.model_class._keep_in_fp32_modules = ["proj_out"]
|
||||
model = self._create_quantized_model(config_kwargs)
|
||||
model.to(torch_device)
|
||||
|
||||
try:
|
||||
model = self._create_quantized_model(config_kwargs)
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
if any(fp32_name in name for fp32_name in fp32_modules):
|
||||
assert module.weight.dtype == torch.float32, (
|
||||
f"Module {name} should be FP32 but is {module.weight.dtype}"
|
||||
)
|
||||
else:
|
||||
assert module.weight.dtype == torch.uint8, (
|
||||
f"Module {name} should be uint8 but is {module.weight.dtype}"
|
||||
)
|
||||
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
if any(fp32_name in name for fp32_name in model._keep_in_fp32_modules):
|
||||
assert module.weight.dtype == torch.float32, (
|
||||
f"Module {name} should be FP32 but is {module.weight.dtype}"
|
||||
)
|
||||
else:
|
||||
assert module.weight.dtype == torch.uint8, (
|
||||
f"Module {name} should be uint8 but is {module.weight.dtype}"
|
||||
)
|
||||
|
||||
inputs = self.get_dummy_inputs()
|
||||
_ = model(**inputs)
|
||||
finally:
|
||||
if original_fp32_modules is not None:
|
||||
self.model_class._keep_in_fp32_modules = original_fp32_modules
|
||||
inputs = self.get_dummy_inputs()
|
||||
_ = model(**inputs)
|
||||
|
||||
def test_bnb_modules_to_not_convert(self):
|
||||
"""Test that modules_to_not_convert parameter works correctly."""
|
||||
|
||||
@@ -320,6 +320,51 @@ class TestFluxSingleFile(FluxTransformerTesterConfig, SingleFileTesterMixin):
|
||||
class TestFluxTransformerBitsAndBytes(FluxTransformerTesterConfig, BitsAndBytesTesterMixin):
|
||||
"""BitsAndBytes quantization tests for Flux Transformer."""
|
||||
|
||||
@property
|
||||
def torch_dtype(self):
|
||||
return torch.float16
|
||||
|
||||
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
|
||||
height = width = 4
|
||||
num_latent_channels = 4
|
||||
num_image_channels = 3
|
||||
sequence_length = 48
|
||||
embedding_dim = 32
|
||||
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(batch_size, height * width, num_latent_channels),
|
||||
generator=self.generator,
|
||||
device=torch_device,
|
||||
dtype=self.torch_dtype,
|
||||
),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(batch_size, sequence_length, embedding_dim),
|
||||
generator=self.generator,
|
||||
device=torch_device,
|
||||
dtype=self.torch_dtype,
|
||||
),
|
||||
"pooled_projections": randn_tensor(
|
||||
(batch_size, embedding_dim),
|
||||
generator=self.generator,
|
||||
device=torch_device,
|
||||
dtype=self.torch_dtype,
|
||||
),
|
||||
"img_ids": randn_tensor(
|
||||
(height * width, num_image_channels),
|
||||
generator=self.generator,
|
||||
device=torch_device,
|
||||
dtype=self.torch_dtype,
|
||||
),
|
||||
"txt_ids": randn_tensor(
|
||||
(sequence_length, num_image_channels),
|
||||
generator=self.generator,
|
||||
device=torch_device,
|
||||
dtype=self.torch_dtype,
|
||||
),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
|
||||
}
|
||||
|
||||
|
||||
class TestFluxTransformerQuanto(FluxTransformerTesterConfig, QuantoTesterMixin):
|
||||
"""Quanto quantization tests for Flux Transformer."""
|
||||
|
||||
74
tests/others/test_flashpack.py
Normal file
74
tests/others/test_flashpack.py
Normal file
@@ -0,0 +1,74 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pathlib
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
from diffusers.models.auto_model import AutoModel
|
||||
|
||||
from ..testing_utils import is_torch_available, require_flashpack, require_torch_gpu
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
class FlashPackTests(unittest.TestCase):
    """Save/load round-trip tests for models and pipelines using ``use_flashpack=True``."""

    # Hub repository used as the source checkpoint by every test in this class.
    model_id: str = "hf-internal-testing/tiny-flux-pipe"

    def _save_transformer(self, save_dir: str) -> None:
        """Load the ``transformer`` sub-model of ``model_id`` and save it to ``save_dir`` with flashpack."""
        model = AutoModel.from_pretrained(self.model_id, subfolder="transformer")
        model.save_pretrained(save_dir, use_flashpack=True)

    @require_flashpack
    def test_save_load_model(self):
        """Saving a model writes ``model.flashpack`` and the result can be loaded again."""
        with tempfile.TemporaryDirectory() as temp_dir:
            self._save_transformer(temp_dir)
            self.assertTrue((pathlib.Path(temp_dir) / "model.flashpack").exists())
            # Loading only has to succeed; the returned model is not inspected.
            AutoModel.from_pretrained(temp_dir, use_flashpack=True)

    @require_flashpack
    def test_save_load_pipeline(self):
        """Saving a pipeline writes ``model.flashpack`` for ``transformer`` and ``vae`` and it reloads."""
        pipeline = AutoPipelineForText2Image.from_pretrained(self.model_id)
        with tempfile.TemporaryDirectory() as temp_dir:
            pipeline.save_pretrained(temp_dir, use_flashpack=True)
            save_root = pathlib.Path(temp_dir)
            self.assertTrue((save_root / "transformer" / "model.flashpack").exists())
            self.assertTrue((save_root / "vae" / "model.flashpack").exists())
            # Loading only has to succeed; the returned pipeline is not inspected.
            AutoPipelineForText2Image.from_pretrained(temp_dir, use_flashpack=True)

    @require_torch_gpu
    @require_flashpack
    def test_load_model_device_str(self):
        """A string device in ``device_map`` places the loaded model on CUDA."""
        with tempfile.TemporaryDirectory() as temp_dir:
            self._save_transformer(temp_dir)
            model = AutoModel.from_pretrained(temp_dir, use_flashpack=True, device_map={"": "cuda"})
            # assertEqual reports the actual device type when the check fails.
            self.assertEqual(model.device.type, "cuda")

    @require_torch_gpu
    @require_flashpack
    def test_load_model_device(self):
        """A ``torch.device`` in ``device_map`` places the loaded model on CUDA."""
        with tempfile.TemporaryDirectory() as temp_dir:
            self._save_transformer(temp_dir)
            model = AutoModel.from_pretrained(temp_dir, use_flashpack=True, device_map={"": torch.device("cuda")})
            self.assertEqual(model.device.type, "cuda")

    @require_flashpack
    def test_load_model_device_auto(self):
        """Loading with ``device_map={"": "auto"}`` is expected to raise ``ValueError``."""
        with tempfile.TemporaryDirectory() as temp_dir:
            self._save_transformer(temp_dir)
            with self.assertRaises(ValueError):
                AutoModel.from_pretrained(temp_dir, use_flashpack=True, device_map={"": "auto"})
|
||||
@@ -34,6 +34,7 @@ from diffusers.utils.import_utils import (
|
||||
is_accelerate_available,
|
||||
is_bitsandbytes_available,
|
||||
is_compel_available,
|
||||
is_flashpack_available,
|
||||
is_flax_available,
|
||||
is_gguf_available,
|
||||
is_kernels_available,
|
||||
@@ -737,6 +738,13 @@ def require_accelerate(test_case):
|
||||
return pytest.mark.skipif(not is_accelerate_available(), reason="test requires accelerate")(test_case)
|
||||
|
||||
|
||||
def require_flashpack(test_case):
    """
    Decorator for tests that depend on flashpack. The test is marked with a pytest ``skipif`` so it is skipped when
    `is_flashpack_available()` returns False.
    """
    flashpack_missing = not is_flashpack_available()
    skip_marker = pytest.mark.skipif(flashpack_missing, reason="test requires flashpack")
    return skip_marker(test_case)
|
||||
|
||||
|
||||
def require_peft_version_greater(peft_version):
|
||||
"""
|
||||
Decorator marking a test that requires PEFT backend with a specific version, this would require some specific
|
||||
|
||||
Reference in New Issue
Block a user