Compare commits

..

9 Commits

Author SHA1 Message Date
Sayak Paul
63afbae360 Merge branch 'main' into autoencoderkl-tests-refactor 2026-04-12 12:23:22 +05:30
Sayak Paul
bbe044acd9 Merge branch 'main' into autoencoderkl-tests-refactor 2026-04-06 15:39:01 +02:00
Sayak Paul
4529468ca8 Merge branch 'main' into autoencoderkl-tests-refactor 2026-04-06 10:24:31 +02:00
sayakpaul
5e9da5e998 up 2026-04-03 07:49:08 +02:00
Sayak Paul
bd0af5096e Merge branch 'main' into autoencoderkl-tests-refactor 2026-04-03 11:11:35 +05:30
sayakpaul
771174ac68 confirm coverage 2026-03-30 15:41:14 +05:30
Sayak Paul
d37402866c Merge branch 'main' into autoencoderkl-tests-refactor 2026-03-30 15:19:52 +05:30
sayakpaul
d10190b1b7 fix tests 2026-03-30 14:58:08 +05:30
sayakpaul
a1c3e6ccbb refactor autoencoderkl tests 2026-03-30 13:33:26 +05:30
19 changed files with 153 additions and 394 deletions

View File

@@ -1533,9 +1533,9 @@ def main(args):
# from the cat above, but collate_fn also doubles the prompts list. Use half the
# prompts count to avoid a 2x over-repeat that produces more embeddings than latents.
num_repeat_elements = len(prompts) // 2 if args.with_prior_preservation else len(prompts)
prompt_embeds = prompt_embeds.repeat_interleave(num_repeat_elements, dim=0)
prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1)
if prompt_embeds_mask is not None:
prompt_embeds_mask = prompt_embeds_mask.repeat_interleave(num_repeat_elements, dim=0)
prompt_embeds_mask = prompt_embeds_mask.repeat(num_repeat_elements, 1)
# Convert images to latent space
if args.cache_latents:
model_input = latents_cache[step].sample()
@@ -1602,11 +1602,10 @@ def main(args):
# Chunk the noise and model_pred into two parts and compute the loss on each part separately.
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
target, target_prior = torch.chunk(target, 2, dim=0)
weighting, weighting_prior = torch.chunk(weighting, 2, dim=0)
# Compute prior loss
prior_loss = torch.mean(
(weighting_prior.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(
(weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(
target_prior.shape[0], -1
),
1,

View File

@@ -146,7 +146,6 @@ _deps = [
"phonemizer",
"opencv-python",
"timm",
"flashpack",
]
# this is a lookup table with items like:
@@ -251,7 +250,6 @@ extras["gguf"] = deps_list("gguf", "accelerate")
extras["optimum_quanto"] = deps_list("optimum_quanto", "accelerate")
extras["torchao"] = deps_list("torchao", "accelerate")
extras["nvidia_modelopt"] = deps_list("nvidia_modelopt[hf]")
extras["flashpack"] = deps_list("flashpack")
if os.name == "nt": # windows
extras["flax"] = [] # jax is not supported on windows

View File

@@ -53,5 +53,4 @@ deps = {
"phonemizer": "phonemizer",
"opencv-python": "opencv-python",
"timm": "timm",
"flashpack": "flashpack",
}

View File

@@ -540,7 +540,7 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12. Please update with `pip install -U kernels`."
)
if backend == AttentionBackendName.FLASH_4_HUB and not is_kernels_version(">=", "0.12.3"):
if backend == AttentionBackendName.FLASH_4_HUB and not is_kernels_available(">=", "0.12.3"):
raise RuntimeError(
f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12.3. Please update with `pip install -U kernels`."
)

View File

@@ -42,7 +42,6 @@ from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer
from ..quantizers.quantization_config import QuantizationMethod
from ..utils import (
CONFIG_NAME,
FLASHPACK_WEIGHTS_NAME,
FLAX_WEIGHTS_NAME,
HF_ENABLE_PARALLEL_LOADING,
SAFE_WEIGHTS_INDEX_NAME,
@@ -56,7 +55,6 @@ from ..utils import (
is_accelerate_available,
is_bitsandbytes_available,
is_bitsandbytes_version,
is_flashpack_available,
is_peft_available,
is_torch_version,
logging,
@@ -675,7 +673,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
variant: str | None = None,
max_shard_size: int | str = "10GB",
push_to_hub: bool = False,
use_flashpack: bool = False,
**kwargs,
):
"""
@@ -728,12 +725,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
" the logger on the traceback to understand the reason why the quantized model is not serializable."
)
weights_name = WEIGHTS_NAME
if use_flashpack:
weights_name = FLASHPACK_WEIGHTS_NAME
elif safe_serialization:
weights_name = SAFETENSORS_WEIGHTS_NAME
weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
weights_name = _add_variant(weights_name, variant)
weights_name_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
".safetensors", "{suffix}.safetensors"
@@ -760,74 +752,58 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
# Save the model
state_dict = model_to_save.state_dict()
if use_flashpack:
if is_flashpack_available():
import flashpack
else:
logger.error(
"Saving a FlashPack checkpoint in PyTorch, requires both PyTorch and flashpack to be installed. Please see "
"https://pytorch.org/ and https://github.com/fal-ai/flashpack for installation instructions."
)
raise ImportError("Please install torch and flashpack to save a FlashPack checkpoint in PyTorch.")
# Save the model
state_dict_split = split_torch_state_dict_into_shards(
state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
)
flashpack.serialization.pack_to_file(
state_dict_or_model=state_dict,
destination_path=os.path.join(save_directory, weights_name),
target_dtype=self.dtype,
# Clean the folder from a previous save
if is_main_process:
for filename in os.listdir(save_directory):
if filename in state_dict_split.filename_to_tensors.keys():
continue
full_filename = os.path.join(save_directory, filename)
if not os.path.isfile(full_filename):
continue
weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
weights_without_ext = weights_without_ext.replace("{suffix}", "")
filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
# make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
if (
filename.startswith(weights_without_ext)
and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
):
os.remove(full_filename)
for filename, tensors in state_dict_split.filename_to_tensors.items():
shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
filepath = os.path.join(save_directory, filename)
if safe_serialization:
# At some point we will need to deal better with save_function (used for TPU and other distributed
# joyfulness), but for now this enough.
safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
else:
torch.save(shard, filepath)
if state_dict_split.is_sharded:
index = {
"metadata": state_dict_split.metadata,
"weight_map": state_dict_split.tensor_to_filename,
}
save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
# Save the index as well
with open(save_index_file, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)
logger.info(
f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
else:
# Save the model
state_dict_split = split_torch_state_dict_into_shards(
state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
)
# Clean the folder from a previous save
if is_main_process:
for filename in os.listdir(save_directory):
if filename in state_dict_split.filename_to_tensors.keys():
continue
full_filename = os.path.join(save_directory, filename)
if not os.path.isfile(full_filename):
continue
weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
weights_without_ext = weights_without_ext.replace("{suffix}", "")
filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
# make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
if (
filename.startswith(weights_without_ext)
and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
):
os.remove(full_filename)
for filename, tensors in state_dict_split.filename_to_tensors.items():
shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
filepath = os.path.join(save_directory, filename)
if safe_serialization:
# At some point we will need to deal better with save_function (used for TPU and other distributed
# joyfulness), but for now this enough.
safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
else:
torch.save(shard, filepath)
if state_dict_split.is_sharded:
index = {
"metadata": state_dict_split.metadata,
"weight_map": state_dict_split.tensor_to_filename,
}
save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
# Save the index as well
with open(save_index_file, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)
logger.info(
f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
else:
path_to_weights = os.path.join(save_directory, weights_name)
logger.info(f"Model weights saved in {path_to_weights}")
path_to_weights = os.path.join(save_directory, weights_name)
logger.info(f"Model weights saved in {path_to_weights}")
if push_to_hub:
# Create a new empty model card and eventually tag it
@@ -964,12 +940,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
disable_mmap ('bool', *optional*, defaults to 'False'):
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
use_flashpack (`bool`, *optional*, defaults to `False`):
If set to `True`, the model is loaded from `flashpack` weights.
flashpack_kwargs(`dict[str, Any]`, *optional*, defaults to `{}`):
Kwargs passed to
[`flashpack.deserialization.assign_from_file`](https://github.com/fal-ai/flashpack/blob/f1aa91c5cd9532a3dbf5bcc707ab9b01c274b76c/src/flashpack/deserialization.py#L408-L422)
> [!TIP]
> To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log in
> with `hf auth login`. You can also activate the special
@@ -1014,8 +984,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
dduf_entries: dict[str, DDUFEntry] | None = kwargs.pop("dduf_entries", None)
disable_mmap = kwargs.pop("disable_mmap", False)
parallel_config: ParallelConfig | ContextParallelConfig | None = kwargs.pop("parallel_config", None)
use_flashpack = kwargs.pop("use_flashpack", False)
flashpack_kwargs = kwargs.pop("flashpack_kwargs", {})
is_parallel_loading_enabled = HF_ENABLE_PARALLEL_LOADING
if is_parallel_loading_enabled and not low_cpu_mem_usage:
@@ -1244,37 +1212,30 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
subfolder=subfolder or "",
dduf_entries=dduf_entries,
)
else:
if use_flashpack:
weights_name = FLASHPACK_WEIGHTS_NAME
elif use_safetensors:
weights_name = _add_variant(SAFETENSORS_WEIGHTS_NAME, variant)
else:
weights_name = None
if weights_name is not None:
try:
resolved_model_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=weights_name,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
commit_hash=commit_hash,
dduf_entries=dduf_entries,
)
elif use_safetensors:
try:
resolved_model_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
commit_hash=commit_hash,
dduf_entries=dduf_entries,
)
except IOError as e:
logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
if not allow_pickle:
raise
logger.warning(
"Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
)
except IOError as e:
logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
if not allow_pickle:
raise
logger.warning(
"Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
)
if resolved_model_file is None and not is_sharded:
resolved_model_file = _get_model_file(
@@ -1314,44 +1275,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
with ContextManagers(init_contexts):
model = cls.from_config(config, **unused_kwargs)
if use_flashpack:
if is_flashpack_available():
import flashpack
else:
logger.error(
"Loading a FlashPack checkpoint in PyTorch, requires both PyTorch and flashpack to be installed. Please see "
"https://pytorch.org/ and https://github.com/fal-ai/flashpack for installation instructions."
)
raise ImportError("Please install torch and flashpack to load a FlashPack checkpoint in PyTorch.")
if device_map is None:
logger.warning(
"`device_map` has not been provided for FlashPack, model will be on `cpu` - provide `device_map` to fully utilize "
"the benefit of FlashPack."
)
flashpack_device = torch.device("cpu")
else:
device = device_map[""]
if isinstance(device, str) and device in ["auto", "balanced", "balanced_low_0", "sequential"]:
raise ValueError(
"FlashPack `device_map` should not be one of `auto`, `balanced`, `balanced_low_0`, `sequential`. Use a specific device instead, e.g., `device_map='cuda'` or `device_map='cuda:0'"
)
flashpack_device = torch.device(device) if not isinstance(device, torch.device) else device
flashpack.mixin.assign_from_file(
model=model,
path=resolved_model_file[0],
device=flashpack_device,
**flashpack_kwargs,
)
if dtype_orig is not None:
torch.set_default_dtype(dtype_orig)
if output_loading_info:
logger.warning("`output_loading_info` is not supported with FlashPack.")
return model, {}
return model
if dtype_orig is not None:
torch.set_default_dtype(dtype_orig)

View File

@@ -44,7 +44,7 @@ class ErnieImageTransformer2DModelOutput(BaseOutput):
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
assert dim % 2 == 0
scale = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
omega = 1.0 / (theta**scale)
out = torch.einsum("...n,d->...nd", pos, omega)
return out.float()
@@ -400,8 +400,8 @@ class ErnieImageTransformer2DModel(ModelMixin, ConfigMixin):
]
# AdaLN
sample = self.time_proj(timestep)
sample = sample.to(dtype=dtype)
sample = self.time_proj(timestep.to(dtype))
sample = sample.to(self.time_embedding.linear_1.weight.dtype)
c = self.time_embedding(sample)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [
t.unsqueeze(0).expand(S, -1, -1).contiguous() for t in self.adaLN_modulation(c).chunk(6, dim=-1)

View File

@@ -445,14 +445,10 @@ class WanAnimateFaceBlockAttnProcessor:
# B --> batch_size, T --> reduced inference segment len, N --> face_encoder_num_heads + 1, C --> attn.dim
B, T, N, C = encoder_hidden_states.shape
# Flatten T and N so the K/V projections see a 3D tensor; BnB int8 matmul only
# accepts 2D/3D inputs and would otherwise fail on this 4D activation.
encoder_hidden_states = encoder_hidden_states.flatten(1, 2) # [B, T, N, C] --> [B, T * N, C]
query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)
query = query.unflatten(2, (attn.heads, -1)) # [B, S, H * D] --> [B, S, H, D]
key = key.view(B, T, N, attn.heads, -1) # [B, T * N, H * D_kv] --> [B, T, N, H, D_kv]
key = key.view(B, T, N, attn.heads, -1) # [B, T, N, H * D_kv] --> [B, T, N, H, D_kv]
value = value.view(B, T, N, attn.heads, -1)
query = attn.norm_q(query)

View File

@@ -877,7 +877,10 @@ class FluxPipeline(
self.scheduler.config.get("max_shift", 1.15),
)
timestep_device = device
if XLA_AVAILABLE:
timestep_device = "cpu"
else:
timestep_device = device
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,

View File

@@ -28,7 +28,6 @@ from packaging import version
from .. import __version__
from ..utils import (
FLASHPACK_WEIGHTS_NAME,
FLAX_WEIGHTS_NAME,
ONNX_EXTERNAL_WEIGHTS_NAME,
ONNX_WEIGHTS_NAME,
@@ -195,7 +194,6 @@ def filter_model_files(filenames):
FLAX_WEIGHTS_NAME,
ONNX_WEIGHTS_NAME,
ONNX_EXTERNAL_WEIGHTS_NAME,
FLASHPACK_WEIGHTS_NAME,
]
if is_transformers_available():
@@ -415,9 +413,6 @@ def get_class_obj_and_candidates(
"""Simple helper method to retrieve class object of module as well as potential parent class objects"""
component_folder = os.path.join(cache_dir, component_name) if component_name and cache_dir else None
if class_name.startswith("FlashPack"):
class_name = class_name.removeprefix("FlashPack")
if is_pipeline_module:
pipeline_module = getattr(pipelines, library_name)
@@ -765,7 +760,6 @@ def load_sub_model(
provider_options: Any,
disable_mmap: bool,
quantization_config: Any | None = None,
use_flashpack: bool = False,
):
"""Helper method to load the module `name` from `library_name` and `class_name`"""
from ..quantizers import PipelineQuantizationConfig
@@ -844,9 +838,6 @@ def load_sub_model(
loading_kwargs["variant"] = model_variants.pop(name, None)
loading_kwargs["use_safetensors"] = use_safetensors
if is_diffusers_model:
loading_kwargs["use_flashpack"] = use_flashpack
if from_flax:
loading_kwargs["from_flax"] = True
@@ -896,7 +887,7 @@ def load_sub_model(
# else load from the root directory
loaded_sub_model = load_method(cached_folder, **loading_kwargs)
if isinstance(loaded_sub_model, torch.nn.Module) and isinstance(device_map, dict) and not use_flashpack:
if isinstance(loaded_sub_model, torch.nn.Module) and isinstance(device_map, dict):
# remove hooks
remove_hook_from_module(loaded_sub_model, recurse=True)
needs_offloading_to_cpu = device_map[""] == "cpu"
@@ -1102,7 +1093,6 @@ def _get_ignore_patterns(
allow_pickle: bool,
use_onnx: bool,
is_onnx: bool,
use_flashpack: bool,
variant: str | None = None,
) -> list[str]:
if (
@@ -1128,9 +1118,6 @@ def _get_ignore_patterns(
if not use_onnx:
ignore_patterns += ["*.onnx", "*.pb"]
elif use_flashpack:
ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb", "*.msgpack"]
else:
ignore_patterns = ["*.safetensors", "*.msgpack"]

View File

@@ -244,7 +244,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
variant: str | None = None,
max_shard_size: int | str | None = None,
push_to_hub: bool = False,
use_flashpack: bool = False,
**kwargs,
):
"""
@@ -342,7 +341,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
save_method_accept_safe = "safe_serialization" in save_method_signature.parameters
save_method_accept_variant = "variant" in save_method_signature.parameters
save_method_accept_max_shard_size = "max_shard_size" in save_method_signature.parameters
save_method_accept_flashpack = "use_flashpack" in save_method_signature.parameters
save_method_accept_peft_format = "save_peft_format" in save_method_signature.parameters
save_kwargs = {}
@@ -353,8 +351,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
if save_method_accept_max_shard_size and max_shard_size is not None:
# max_shard_size is expected to not be None in ModelMixin
save_kwargs["max_shard_size"] = max_shard_size
if save_method_accept_flashpack:
save_kwargs["use_flashpack"] = use_flashpack
if save_method_accept_peft_format:
# Set save_peft_format=False for transformers>=5.0.0 compatibility
# In transformers 5.0.0+, the default save_peft_format=True adds "base_model.model" prefix
@@ -785,7 +781,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
use_onnx = kwargs.pop("use_onnx", None)
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
quantization_config = kwargs.pop("quantization_config", None)
use_flashpack = kwargs.pop("use_flashpack", False)
disable_mmap = kwargs.pop("disable_mmap", False)
if torch_dtype is not None and not isinstance(torch_dtype, dict) and not isinstance(torch_dtype, torch.dtype):
@@ -1076,7 +1071,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
provider_options=provider_options,
disable_mmap=disable_mmap,
quantization_config=quantization_config,
use_flashpack=use_flashpack,
)
logger.info(
f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."
@@ -1582,9 +1576,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
Whether or not to allow for custom pipelines and components defined on the Hub in their own files. This
option should only be set to `True` for repositories you trust and in which you have read the code, as
it will execute code present on the Hub on your local machine.
use_flashpack (`bool`, *optional*, defaults to `False`):
If set to `True`, FlashPack weights will always be downloaded if present. If set to `False`, FlashPack
weights will never be downloaded.
Returns:
`os.PathLike`:
@@ -1609,7 +1600,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
trust_remote_code = kwargs.pop("trust_remote_code", False)
dduf_file: dict[str, DDUFEntry] | None = kwargs.pop("dduf_file", None)
use_flashpack = kwargs.pop("use_flashpack", False)
if dduf_file:
if custom_pipeline:
@@ -1729,7 +1719,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
allow_pickle,
use_onnx,
pipeline_class._is_onnx,
use_flashpack,
variant,
)

View File

@@ -24,8 +24,6 @@ from .constants import (
DEPRECATED_REVISION_ARGS,
DIFFUSERS_DYNAMIC_MODULE_NAME,
DIFFUSERS_LOAD_ID_FIELDS,
FLASHPACK_FILE_EXTENSION,
FLASHPACK_WEIGHTS_NAME,
FLAX_WEIGHTS_NAME,
GGUF_FILE_EXTENSION,
HF_ENABLE_PARALLEL_LOADING,
@@ -78,7 +76,6 @@ from .import_utils import (
is_flash_attn_3_available,
is_flash_attn_available,
is_flash_attn_version,
is_flashpack_available,
is_flax_available,
is_ftfy_available,
is_gguf_available,

View File

@@ -34,8 +34,6 @@ ONNX_WEIGHTS_NAME = "model.onnx"
SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors"
SAFE_WEIGHTS_INDEX_NAME = "diffusion_pytorch_model.safetensors.index.json"
SAFETENSORS_FILE_EXTENSION = "safetensors"
FLASHPACK_WEIGHTS_NAME = "model.flashpack"
FLASHPACK_FILE_EXTENSION = "flashpack"
GGUF_FILE_EXTENSION = "gguf"
ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb"
HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co")

View File

@@ -230,7 +230,6 @@ _flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_at
_aiter_available, _aiter_version = _is_package_available("aiter", get_dist_name=True)
_kornia_available, _kornia_version = _is_package_available("kornia")
_nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("modelopt", get_dist_name=True)
_flashpack_available, _flashpack_version = _is_package_available("flashpack")
_av_available, _av_version = _is_package_available("av")
@@ -362,10 +361,6 @@ def is_gguf_available():
return _gguf_available
def is_flashpack_available():
return _flashpack_available
def is_torchao_available():
return _torchao_available

View File

@@ -14,18 +14,18 @@
# limitations under the License.
import gc
import unittest
import pytest
import torch
from parameterized import parameterized
from diffusers import AutoencoderKL
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
load_hf_numpy,
require_torch_accelerator,
require_torch_accelerator_with_fp16,
@@ -35,22 +35,30 @@ from ...testing_utils import (
torch_all_close,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
from ..testing_utils import BaseModelTesterConfig, MemoryTesterMixin, ModelTesterMixin, TrainingTesterMixin
from .testing_utils import NewAutoencoderTesterMixin
enable_full_determinism()
class AutoencoderKLTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKL
main_input_name = "sample"
base_precision = 1e-2
class AutoencoderKLTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return AutoencoderKL
def get_autoencoder_kl_config(self, block_out_channels=None, norm_num_groups=None):
@property
def output_shape(self):
return (3, 32, 32)
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self, block_out_channels=None, norm_num_groups=None):
block_out_channels = block_out_channels or [2, 4]
norm_num_groups = norm_num_groups or 2
init_dict = {
return {
"block_out_channels": block_out_channels,
"in_channels": 3,
"out_channels": 3,
@@ -59,42 +67,27 @@ class AutoencoderKLTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.Test
"latent_channels": 4,
"norm_num_groups": norm_num_groups,
}
return init_dict
@property
def dummy_input(self):
def get_dummy_inputs(self):
batch_size = 4
num_channels = 3
sizes = (32, 32)
image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
image = randn_tensor((batch_size, num_channels, *sizes), generator=self.generator, device=torch_device)
return {"sample": image}
@property
def input_shape(self):
return (3, 32, 32)
@property
def output_shape(self):
return (3, 32, 32)
def prepare_init_args_and_inputs_for_common(self):
init_dict = self.get_autoencoder_kl_config()
inputs_dict = self.dummy_input
return init_dict, inputs_dict
class TestAutoencoderKL(AutoencoderKLTesterConfig, ModelTesterMixin, TrainingTesterMixin):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"Decoder", "Encoder", "UNetMidBlock2D"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
def test_from_pretrained_hub(self):
model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True)
self.assertIsNotNone(model)
self.assertEqual(len(loading_info["missing_keys"]), 0)
assert model is not None
assert len(loading_info["missing_keys"]) == 0
model.to(torch_device)
image = model(**self.dummy_input)
image = model(**self.get_dummy_inputs())
assert image is not None, "Make sure output is not None"
@@ -168,17 +161,24 @@ class AutoencoderKLTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.Test
]
)
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
assert torch_all_close(output_slice, expected_output_slice, rtol=1e-2)
class TestAutoencoderKLMemory(AutoencoderKLTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for AutoencoderKL."""
class TestAutoencoderKLSlicingTiling(AutoencoderKLTesterConfig, NewAutoencoderTesterMixin):
"""Slicing and tiling tests for AutoencoderKL."""
@slow
class AutoencoderKLIntegrationTests(unittest.TestCase):
class AutoencoderKLIntegrationTests:
def get_file_format(self, seed, shape):
return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
def tearDown(self):
def teardown_method(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
backend_empty_cache(torch_device)
@@ -341,10 +341,7 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
@parameterized.expand([(13,), (16,), (27,)])
@require_torch_gpu
@unittest.skipIf(
not is_xformers_available(),
reason="xformers is not required when using PyTorch 2.0.",
)
@pytest.mark.skipif(not is_xformers_available(), reason="xformers is not required when using PyTorch 2.0.")
def test_stable_diffusion_decode_xformers_vs_2_0_fp16(self, seed):
model = self.get_sd_vae_model(fp16=True)
encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64), fp16=True)
@@ -362,10 +359,7 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
@parameterized.expand([(13,), (16,), (37,)])
@require_torch_gpu
@unittest.skipIf(
not is_xformers_available(),
reason="xformers is not required when using PyTorch 2.0.",
)
@pytest.mark.skipif(not is_xformers_available(), reason="xformers is not required when using PyTorch 2.0.")
def test_stable_diffusion_decode_xformers_vs_2_0(self, seed):
model = self.get_sd_vae_model()
encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))

View File

@@ -205,11 +205,6 @@ class BaseModelTesterConfig:
"""Additional kwargs to pass to from_pretrained (e.g., subfolder, variant)."""
return {}
@property
def torch_dtype(self) -> torch.dtype:
"""Compute dtype used to build dummy inputs and cast inputs where needed."""
return torch.float32
@property
def output_shape(self) -> Optional[tuple]:
"""Expected output shape for output validation tests."""

View File

@@ -359,7 +359,15 @@ class QuantizationTesterMixin:
if isinstance(module, torch.nn.Linear):
assert not self._is_module_quantized(module), f"Module {name} is still quantized after dequantize()"
# Get model dtype from first parameter
model_dtype = next(model.parameters()).dtype
inputs = self.get_dummy_inputs()
# Cast inputs to model dtype
inputs = {
k: v.to(model_dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
for k, v in inputs.items()
}
output = model(**inputs, return_dict=False)[0]
assert output is not None, "Model output is None after dequantization"
assert not torch.isnan(output).any(), "Model output contains NaN after dequantization"
@@ -567,28 +575,33 @@ class BitsAndBytesTesterMixin(BitsAndBytesConfigMixin, QuantizationTesterMixin):
@torch.no_grad()
def test_bnb_keep_modules_in_fp32(self):
fp32_modules = getattr(self.model_class, "_keep_in_fp32_modules", None)
if not fp32_modules:
pytest.skip(f"{self.model_class.__name__} does not declare _keep_in_fp32_modules")
if not hasattr(self.model_class, "_keep_in_fp32_modules"):
pytest.skip(f"{self.model_class.__name__} does not have _keep_in_fp32_modules")
config_kwargs = BitsAndBytesConfigMixin.BNB_CONFIGS["4bit_nf4"]
model = self._create_quantized_model(config_kwargs)
model.to(torch_device)
original_fp32_modules = getattr(self.model_class, "_keep_in_fp32_modules", None)
self.model_class._keep_in_fp32_modules = ["proj_out"]
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
if any(fp32_name in name for fp32_name in fp32_modules):
assert module.weight.dtype == torch.float32, (
f"Module {name} should be FP32 but is {module.weight.dtype}"
)
else:
assert module.weight.dtype == torch.uint8, (
f"Module {name} should be uint8 but is {module.weight.dtype}"
)
try:
model = self._create_quantized_model(config_kwargs)
inputs = self.get_dummy_inputs()
_ = model(**inputs)
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
if any(fp32_name in name for fp32_name in model._keep_in_fp32_modules):
assert module.weight.dtype == torch.float32, (
f"Module {name} should be FP32 but is {module.weight.dtype}"
)
else:
assert module.weight.dtype == torch.uint8, (
f"Module {name} should be uint8 but is {module.weight.dtype}"
)
inputs = self.get_dummy_inputs()
_ = model(**inputs)
finally:
if original_fp32_modules is not None:
self.model_class._keep_in_fp32_modules = original_fp32_modules
def test_bnb_modules_to_not_convert(self):
"""Test that modules_to_not_convert parameter works correctly."""

View File

@@ -320,51 +320,6 @@ class TestFluxSingleFile(FluxTransformerTesterConfig, SingleFileTesterMixin):
class TestFluxTransformerBitsAndBytes(FluxTransformerTesterConfig, BitsAndBytesTesterMixin):
    """BitsAndBytes quantization tests for Flux Transformer."""

    @property
    def torch_dtype(self):
        # BnB quantization tests run the transformer in half precision.
        return torch.float16

    def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
        """Build a minimal set of random inputs for the Flux transformer.

        Args:
            batch_size: Leading dimension of the batched tensors.

        Returns:
            Mapping of input names to tensors placed on ``torch_device`` in
            ``self.torch_dtype``.
        """
        height = width = 4
        num_latent_channels = 4
        num_image_channels = 3
        sequence_length = 48
        embedding_dim = 32
        # One shape per named input; iteration order below fixes the order in
        # which random draws consume `self.generator`.
        shapes = {
            "hidden_states": (batch_size, height * width, num_latent_channels),
            "encoder_hidden_states": (batch_size, sequence_length, embedding_dim),
            "pooled_projections": (batch_size, embedding_dim),
            "img_ids": (height * width, num_image_channels),
            "txt_ids": (sequence_length, num_image_channels),
        }
        common = {"generator": self.generator, "device": torch_device, "dtype": self.torch_dtype}
        inputs = {name: randn_tensor(shape, **common) for name, shape in shapes.items()}
        inputs["timestep"] = torch.tensor([1.0]).to(torch_device, self.torch_dtype)
        return inputs
class TestFluxTransformerQuanto(FluxTransformerTesterConfig, QuantoTesterMixin):
"""Quanto quantization tests for Flux Transformer."""

View File

@@ -1,74 +0,0 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pathlib
import tempfile
import unittest
from diffusers import AutoPipelineForText2Image
from diffusers.models.auto_model import AutoModel
from ..testing_utils import is_torch_available, require_flashpack, require_torch_gpu
if is_torch_available():
import torch
class FlashPackTests(unittest.TestCase):
    """Round-trip save/load tests for the flashpack checkpoint format."""

    # Tiny test pipeline so downloads and round-trips stay fast.
    model_id: str = "hf-internal-testing/tiny-flux-pipe"

    @require_flashpack
    def test_save_load_model(self):
        """Saving a single model writes `model.flashpack` and it reloads."""
        model = AutoModel.from_pretrained(self.model_id, subfolder="transformer")
        with tempfile.TemporaryDirectory() as tmpdir:
            model.save_pretrained(tmpdir, use_flashpack=True)
            flashpack_file = pathlib.Path(tmpdir) / "model.flashpack"
            self.assertTrue(flashpack_file.exists())
            model = AutoModel.from_pretrained(tmpdir, use_flashpack=True)

    @require_flashpack
    def test_save_load_pipeline(self):
        """Saving a pipeline writes one flashpack file per component."""
        pipeline = AutoPipelineForText2Image.from_pretrained(self.model_id)
        with tempfile.TemporaryDirectory() as tmpdir:
            pipeline.save_pretrained(tmpdir, use_flashpack=True)
            root = pathlib.Path(tmpdir)
            self.assertTrue((root / "transformer" / "model.flashpack").exists())
            self.assertTrue((root / "vae" / "model.flashpack").exists())
            pipeline = AutoPipelineForText2Image.from_pretrained(tmpdir, use_flashpack=True)

    @require_torch_gpu
    @require_flashpack
    def test_load_model_device_str(self):
        """A string device in `device_map` places the reloaded model on CUDA."""
        model = AutoModel.from_pretrained(self.model_id, subfolder="transformer")
        with tempfile.TemporaryDirectory() as tmpdir:
            model.save_pretrained(tmpdir, use_flashpack=True)
            model = AutoModel.from_pretrained(tmpdir, use_flashpack=True, device_map={"": "cuda"})
            self.assertEqual(model.device.type, "cuda")

    @require_torch_gpu
    @require_flashpack
    def test_load_model_device(self):
        """A `torch.device` in `device_map` places the reloaded model on CUDA."""
        model = AutoModel.from_pretrained(self.model_id, subfolder="transformer")
        with tempfile.TemporaryDirectory() as tmpdir:
            model.save_pretrained(tmpdir, use_flashpack=True)
            model = AutoModel.from_pretrained(tmpdir, use_flashpack=True, device_map={"": torch.device("cuda")})
            self.assertEqual(model.device.type, "cuda")

    @require_flashpack
    def test_load_model_device_auto(self):
        """`device_map={"": "auto"}` is rejected for flashpack loading."""
        model = AutoModel.from_pretrained(self.model_id, subfolder="transformer")
        with tempfile.TemporaryDirectory() as tmpdir:
            model.save_pretrained(tmpdir, use_flashpack=True)
            with self.assertRaises(ValueError):
                model = AutoModel.from_pretrained(tmpdir, use_flashpack=True, device_map={"": "auto"})

View File

@@ -34,7 +34,6 @@ from diffusers.utils.import_utils import (
is_accelerate_available,
is_bitsandbytes_available,
is_compel_available,
is_flashpack_available,
is_flax_available,
is_gguf_available,
is_kernels_available,
@@ -738,13 +737,6 @@ def require_accelerate(test_case):
return pytest.mark.skipif(not is_accelerate_available(), reason="test requires accelerate")(test_case)
def require_flashpack(test_case):
    """Skip the decorated test when the flashpack package is not installed."""
    flashpack_missing = not is_flashpack_available()
    skip_marker = pytest.mark.skipif(flashpack_missing, reason="test requires flashpack")
    return skip_marker(test_case)
def require_peft_version_greater(peft_version):
"""
Decorator marking a test that requires PEFT backend with a specific version, this would require some specific