mirror of
https://github.com/huggingface/diffusers.git
synced 2026-03-14 20:45:59 +08:00
Compare commits
4 Commits
tests-load
...
z-image-te
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
73b23dc92e | ||
|
|
c15472d2c4 | ||
|
|
d15761686a | ||
|
|
5352999e14 |
@@ -532,6 +532,8 @@
|
||||
title: ControlNet-XS with Stable Diffusion XL
|
||||
- local: api/pipelines/controlnet_union
|
||||
title: ControlNetUnion
|
||||
- local: api/pipelines/cosmos
|
||||
title: Cosmos
|
||||
- local: api/pipelines/ddim
|
||||
title: DDIM
|
||||
- local: api/pipelines/ddpm
|
||||
@@ -675,8 +677,6 @@
|
||||
title: CogVideoX
|
||||
- local: api/pipelines/consisid
|
||||
title: ConsisID
|
||||
- local: api/pipelines/cosmos
|
||||
title: Cosmos
|
||||
- local: api/pipelines/framepack
|
||||
title: Framepack
|
||||
- local: api/pipelines/helios
|
||||
|
||||
@@ -21,31 +21,29 @@
|
||||
> [!TIP]
|
||||
> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
## Basic usage
|
||||
## Loading original format checkpoints
|
||||
|
||||
Original format checkpoints that have not been converted to diffusers-expected format can be loaded using the `from_single_file` method.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import Cosmos2_5_PredictBasePipeline
|
||||
from diffusers.utils import export_to_video
|
||||
from diffusers import Cosmos2TextToImagePipeline, CosmosTransformer3DModel
|
||||
|
||||
model_id = "nvidia/Cosmos-Predict2.5-2B"
|
||||
pipe = Cosmos2_5_PredictBasePipeline.from_pretrained(
|
||||
model_id, revision="diffusers/base/post-trained", torch_dtype=torch.bfloat16
|
||||
)
|
||||
model_id = "nvidia/Cosmos-Predict2-2B-Text2Image"
|
||||
transformer = CosmosTransformer3DModel.from_single_file(
|
||||
"https://huggingface.co/nvidia/Cosmos-Predict2-2B-Text2Image/blob/main/model.pt",
|
||||
torch_dtype=torch.bfloat16,
|
||||
).to("cuda")
|
||||
pipe = Cosmos2TextToImagePipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
prompt = "As the red light shifts to green, the red bus at the intersection begins to move forward, its headlights cutting through the falling snow. The snowy tire tracks deepen as the vehicle inches ahead, casting fresh lines onto the slushy road. Around it, streetlights glow warmer, illuminating the drifting flakes and wet reflections on the asphalt. Other cars behind start to edge forward, their beams joining the scene. The stillness of the urban street transitions into motion as the quiet snowfall is punctuated by the slow advance of traffic through the frosty city corridor."
|
||||
prompt = "A close-up shot captures a vibrant yellow scrubber vigorously working on a grimy plate, its bristles moving in circular motions to lift stubborn grease and food residue. The dish, once covered in remnants of a hearty meal, gradually reveals its original glossy surface. Suds form and bubble around the scrubber, creating a satisfying visual of cleanliness in progress. The sound of scrubbing fills the air, accompanied by the gentle clinking of the dish against the sink. As the scrubber continues its task, the dish transforms, gleaming under the bright kitchen lights, symbolizing the triumph of cleanliness over mess."
|
||||
negative_prompt = "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality."
|
||||
|
||||
output = pipe(
|
||||
image=None,
|
||||
video=None,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
num_frames=93,
|
||||
generator=torch.Generator().manual_seed(1),
|
||||
).frames[0]
|
||||
export_to_video(output, "text2world.mp4", fps=16)
|
||||
prompt=prompt, negative_prompt=negative_prompt, generator=torch.Generator().manual_seed(1)
|
||||
).images[0]
|
||||
output.save("output.png")
|
||||
```
|
||||
|
||||
## Cosmos2_5_TransferPipeline
|
||||
|
||||
@@ -44,7 +44,6 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an
|
||||
| [ControlNet with Stable Diffusion XL](controlnet_sdxl) | text2image |
|
||||
| [ControlNet-XS](controlnetxs) | text2image |
|
||||
| [ControlNet-XS with Stable Diffusion XL](controlnetxs_sdxl) | text2image |
|
||||
| [Cosmos](cosmos) | text2video, video2video |
|
||||
| [Dance Diffusion](dance_diffusion) | unconditional audio generation |
|
||||
| [DDIM](ddim) | unconditional image generation |
|
||||
| [DDPM](ddpm) | unconditional image generation |
|
||||
|
||||
@@ -60,16 +60,6 @@ class ContextParallelConfig:
|
||||
rotate_method (`str`, *optional*, defaults to `"allgather"`):
|
||||
Method to use for rotating key/value states across devices in ring attention. Currently, only `"allgather"`
|
||||
is supported.
|
||||
ulysses_anything (`bool`, *optional*, defaults to `False`):
|
||||
Whether to enable "Ulysses Anything" mode, which supports arbitrary sequence lengths and head counts that
|
||||
are not evenly divisible by `ulysses_degree`. When enabled, `ulysses_degree` must be greater than 1 and
|
||||
`ring_degree` must be 1.
|
||||
mesh (`torch.distributed.device_mesh.DeviceMesh`, *optional*):
|
||||
A custom device mesh to use for context parallelism. If provided, this mesh will be used instead of
|
||||
creating a new one. This is useful when combining context parallelism with other parallelism strategies
|
||||
(e.g., FSDP, tensor parallelism) that share the same device mesh. The mesh must have both "ring" and
|
||||
"ulysses" dimensions. Use size 1 for dimensions not being used (e.g., `mesh_shape=(2, 1, 4)` with
|
||||
`mesh_dim_names=("ring", "ulysses", "fsdp")` for ring attention only with FSDP).
|
||||
|
||||
"""
|
||||
|
||||
@@ -78,7 +68,6 @@ class ContextParallelConfig:
|
||||
convert_to_fp32: bool = True
|
||||
# TODO: support alltoall
|
||||
rotate_method: Literal["allgather", "alltoall"] = "allgather"
|
||||
mesh: torch.distributed.device_mesh.DeviceMesh | None = None
|
||||
# Whether to enable ulysses anything attention to support
|
||||
# any sequence lengths and any head numbers.
|
||||
ulysses_anything: bool = False
|
||||
@@ -135,7 +124,7 @@ class ContextParallelConfig:
|
||||
f"The product of `ring_degree` ({self.ring_degree}) and `ulysses_degree` ({self.ulysses_degree}) must not exceed the world size ({world_size})."
|
||||
)
|
||||
|
||||
self._flattened_mesh = self._mesh["ring", "ulysses"]._flatten()
|
||||
self._flattened_mesh = self._mesh._flatten()
|
||||
self._ring_mesh = self._mesh["ring"]
|
||||
self._ulysses_mesh = self._mesh["ulysses"]
|
||||
self._ring_local_rank = self._ring_mesh.get_local_rank()
|
||||
|
||||
@@ -1567,7 +1567,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
mesh = None
|
||||
if config.context_parallel_config is not None:
|
||||
cp_config = config.context_parallel_config
|
||||
mesh = cp_config.mesh or torch.distributed.device_mesh.init_device_mesh(
|
||||
mesh = torch.distributed.device_mesh.init_device_mesh(
|
||||
device_type=device_type,
|
||||
mesh_shape=cp_config.mesh_shape,
|
||||
mesh_dim_names=cp_config.mesh_dim_names,
|
||||
|
||||
@@ -95,7 +95,6 @@ from .pag import (
|
||||
StableDiffusionXLPAGPipeline,
|
||||
)
|
||||
from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
|
||||
from .prx import PRXPipeline
|
||||
from .qwenimage import (
|
||||
QwenImageControlNetPipeline,
|
||||
QwenImageEditInpaintPipeline,
|
||||
@@ -186,7 +185,6 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
|
||||
("z-image-controlnet-inpaint", ZImageControlNetInpaintPipeline),
|
||||
("z-image-omni", ZImageOmniPipeline),
|
||||
("ovis", OvisImagePipeline),
|
||||
("prx", PRXPipeline),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -82,16 +82,13 @@ EXAMPLE_DOC_STRING = """
|
||||
```python
|
||||
>>> import cv2
|
||||
>>> import numpy as np
|
||||
>>> from PIL import Image
|
||||
>>> import torch
|
||||
>>> from diffusers import Cosmos2_5_TransferPipeline, AutoModel
|
||||
>>> from diffusers.utils import export_to_video, load_video
|
||||
|
||||
>>> model_id = "nvidia/Cosmos-Transfer2.5-2B"
|
||||
>>> # Load a Transfer2.5 controlnet variant (edge, depth, seg, or blur)
|
||||
>>> controlnet = AutoModel.from_pretrained(
|
||||
... model_id, revision="diffusers/controlnet/general/edge", torch_dtype=torch.bfloat16
|
||||
... )
|
||||
>>> controlnet = AutoModel.from_pretrained(model_id, revision="diffusers/controlnet/general/edge")
|
||||
>>> pipe = Cosmos2_5_TransferPipeline.from_pretrained(
|
||||
... model_id, controlnet=controlnet, revision="diffusers/general", torch_dtype=torch.bfloat16
|
||||
... )
|
||||
|
||||
@@ -36,7 +36,7 @@ from typing import Any, Callable
|
||||
|
||||
from packaging import version
|
||||
|
||||
from ..utils import deprecate, is_torch_available, is_torchao_available, is_torchao_version, logging
|
||||
from ..utils import is_torch_available, is_torchao_available, is_torchao_version, logging
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -844,8 +844,6 @@ class QuantoConfig(QuantizationConfigMixin):
|
||||
modules_to_not_convert: list[str] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
deprecation_message = "`QuantoConfig` is deprecated and will be removed in version 1.0.0."
|
||||
deprecate("QuantoConfig", "1.0.0", deprecation_message)
|
||||
self.quant_method = QuantizationMethod.QUANTO
|
||||
self.weights_dtype = weights_dtype
|
||||
self.modules_to_not_convert = modules_to_not_convert
|
||||
|
||||
@@ -3,7 +3,6 @@ from typing import TYPE_CHECKING, Any
|
||||
from diffusers.utils.import_utils import is_optimum_quanto_version
|
||||
|
||||
from ...utils import (
|
||||
deprecate,
|
||||
get_module_from_name,
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
@@ -43,9 +42,6 @@ class QuantoQuantizer(DiffusersQuantizer):
|
||||
super().__init__(quantization_config, **kwargs)
|
||||
|
||||
def validate_environment(self, *args, **kwargs):
|
||||
deprecation_message = "The Quanto quantizer is deprecated and will be removed in version 1.0.0."
|
||||
deprecate("QuantoQuantizer", "1.0.0", deprecation_message)
|
||||
|
||||
if not is_optimum_quanto_available():
|
||||
raise ImportError(
|
||||
"Loading an optimum-quanto quantized model requires optimum-quanto library (`pip install optimum-quanto`)"
|
||||
|
||||
@@ -60,7 +60,12 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
|
||||
model.eval()
|
||||
|
||||
# Move inputs to device
|
||||
inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
|
||||
inputs_on_device = {}
|
||||
for key, value in inputs_dict.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
inputs_on_device[key] = value.to(device)
|
||||
else:
|
||||
inputs_on_device[key] = value
|
||||
|
||||
# Enable context parallelism
|
||||
cp_config = ContextParallelConfig(**cp_dict)
|
||||
@@ -84,59 +89,6 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
def _custom_mesh_worker(
|
||||
rank,
|
||||
world_size,
|
||||
master_port,
|
||||
model_class,
|
||||
init_dict,
|
||||
cp_dict,
|
||||
mesh_shape,
|
||||
mesh_dim_names,
|
||||
inputs_dict,
|
||||
return_dict,
|
||||
):
|
||||
"""Worker function for context parallel testing with a user-provided custom DeviceMesh."""
|
||||
try:
|
||||
os.environ["MASTER_ADDR"] = "localhost"
|
||||
os.environ["MASTER_PORT"] = str(master_port)
|
||||
os.environ["RANK"] = str(rank)
|
||||
os.environ["WORLD_SIZE"] = str(world_size)
|
||||
|
||||
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
|
||||
|
||||
torch.cuda.set_device(rank)
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
|
||||
model = model_class(**init_dict)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
|
||||
|
||||
# DeviceMesh must be created after init_process_group, inside each worker process.
|
||||
mesh = torch.distributed.device_mesh.init_device_mesh(
|
||||
"cuda", mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names
|
||||
)
|
||||
cp_config = ContextParallelConfig(**cp_dict, mesh=mesh)
|
||||
model.enable_parallelism(config=cp_config)
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_on_device, return_dict=False)[0]
|
||||
|
||||
if rank == 0:
|
||||
return_dict["status"] = "success"
|
||||
return_dict["output_shape"] = list(output.shape)
|
||||
|
||||
except Exception as e:
|
||||
if rank == 0:
|
||||
return_dict["status"] = "error"
|
||||
return_dict["error"] = str(e)
|
||||
finally:
|
||||
if dist.is_initialized():
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
@is_context_parallel
|
||||
@require_torch_multi_accelerator
|
||||
class ContextParallelTesterMixin:
|
||||
@@ -174,48 +126,3 @@ class ContextParallelTesterMixin:
|
||||
assert return_dict.get("status") == "success", (
|
||||
f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"cp_type,mesh_shape,mesh_dim_names",
|
||||
[
|
||||
("ring_degree", (2, 1, 1), ("ring", "ulysses", "fsdp")),
|
||||
("ulysses_degree", (1, 2, 1), ("ring", "ulysses", "fsdp")),
|
||||
],
|
||||
ids=["ring-3d-fsdp", "ulysses-3d-fsdp"],
|
||||
)
|
||||
def test_context_parallel_custom_mesh(self, cp_type, mesh_shape, mesh_dim_names):
|
||||
if not torch.distributed.is_available():
|
||||
pytest.skip("torch.distributed is not available.")
|
||||
|
||||
if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
|
||||
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
|
||||
|
||||
world_size = 2
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in self.get_dummy_inputs().items()}
|
||||
cp_dict = {cp_type: world_size}
|
||||
|
||||
master_port = _find_free_port()
|
||||
manager = mp.Manager()
|
||||
return_dict = manager.dict()
|
||||
|
||||
mp.spawn(
|
||||
_custom_mesh_worker,
|
||||
args=(
|
||||
world_size,
|
||||
master_port,
|
||||
self.model_class,
|
||||
init_dict,
|
||||
cp_dict,
|
||||
mesh_shape,
|
||||
mesh_dim_names,
|
||||
inputs_dict,
|
||||
return_dict,
|
||||
),
|
||||
nprocs=world_size,
|
||||
join=True,
|
||||
)
|
||||
|
||||
assert return_dict.get("status") == "success", (
|
||||
f"Custom mesh context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
|
||||
)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -13,16 +12,23 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from diffusers import ZImageTransformer2DModel
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import IS_GITHUB_ACTIONS, torch_device
|
||||
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
|
||||
from ...testing_utils import assert_tensors_close, torch_device
|
||||
from ..testing_utils import (
|
||||
BaseModelTesterConfig,
|
||||
LoraTesterMixin,
|
||||
MemoryTesterMixin,
|
||||
ModelTesterMixin,
|
||||
TorchCompileTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
# Z-Image requires torch.use_deterministic_algorithms(False) due to complex64 RoPE operations
|
||||
@@ -36,44 +42,38 @@ if hasattr(torch.backends, "cuda"):
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
IS_GITHUB_ACTIONS,
|
||||
reason="Skipping test-suite inside the CI because the model has `torch.empty()` inside of it during init and we don't have a clear way to override it in the modeling tests.",
|
||||
)
|
||||
class ZImageTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = ZImageTransformer2DModel
|
||||
main_input_name = "x"
|
||||
# We override the items here because the transformer under consideration is small.
|
||||
model_split_percents = [0.9, 0.9, 0.9]
|
||||
def _concat_list_output(output):
|
||||
"""Model output `sample` is a list of tensors. Concatenate them for comparison."""
|
||||
return torch.cat([t.flatten() for t in output])
|
||||
|
||||
def prepare_dummy_input(self, height=16, width=16):
|
||||
batch_size = 1
|
||||
num_channels = 16
|
||||
embedding_dim = 16
|
||||
sequence_length = 16
|
||||
|
||||
hidden_states = [torch.randn((num_channels, 1, height, width)).to(torch_device) for _ in range(batch_size)]
|
||||
encoder_hidden_states = [
|
||||
torch.randn((sequence_length, embedding_dim)).to(torch_device) for _ in range(batch_size)
|
||||
]
|
||||
timestep = torch.tensor([0.0]).to(torch_device)
|
||||
|
||||
return {"x": hidden_states, "cap_feats": encoder_hidden_states, "t": timestep}
|
||||
class ZImageTransformerTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return ZImageTransformer2DModel
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
return self.prepare_dummy_input()
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
def output_shape(self) -> tuple[int, ...]:
|
||||
return (4, 32, 32)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
def input_shape(self) -> tuple[int, ...]:
|
||||
return (4, 32, 32)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
@property
|
||||
def model_split_percents(self) -> list:
|
||||
return [0.9, 0.9, 0.9]
|
||||
|
||||
@property
|
||||
def main_input_name(self) -> str:
|
||||
return "x"
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self):
|
||||
return {
|
||||
"all_patch_size": (2,),
|
||||
"all_f_patch_size": (1,),
|
||||
"in_channels": 16,
|
||||
@@ -89,83 +89,223 @@ class ZImageTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
"axes_dims": [8, 4, 4],
|
||||
"axes_lens": [256, 32, 32],
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def setUp(self):
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
def get_dummy_inputs(self) -> dict[str, torch.Tensor | list]:
|
||||
batch_size = 1
|
||||
num_channels = 16
|
||||
embedding_dim = 16
|
||||
sequence_length = 16
|
||||
height = 16
|
||||
width = 16
|
||||
|
||||
hidden_states = [
|
||||
randn_tensor((num_channels, 1, height, width), generator=self.generator, device=torch_device)
|
||||
for _ in range(batch_size)
|
||||
]
|
||||
encoder_hidden_states = [
|
||||
randn_tensor((sequence_length, embedding_dim), generator=self.generator, device=torch_device)
|
||||
for _ in range(batch_size)
|
||||
]
|
||||
timestep = torch.tensor([0.0]).to(torch_device)
|
||||
|
||||
return {"x": hidden_states, "cap_feats": encoder_hidden_states, "t": timestep}
|
||||
|
||||
|
||||
class TestZImageTransformer(ZImageTransformerTesterConfig, ModelTesterMixin):
|
||||
"""Core model tests for Z-Image Transformer."""
|
||||
|
||||
@torch.no_grad()
|
||||
def test_determinism(self, atol=1e-5, rtol=0):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
first = _concat_list_output(model(**inputs_dict, return_dict=False)[0])
|
||||
second = _concat_list_output(model(**inputs_dict, return_dict=False)[0])
|
||||
|
||||
mask = ~(torch.isnan(first) | torch.isnan(second))
|
||||
assert_tensors_close(
|
||||
first[mask], second[mask], atol=atol, rtol=rtol, msg="Model outputs are not deterministic"
|
||||
)
|
||||
|
||||
def test_from_save_pretrained(self, tmp_path, atol=5e-5, rtol=5e-5):
|
||||
torch.manual_seed(0)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(0)
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
model.save_pretrained(tmp_path)
|
||||
new_model = self.model_class.from_pretrained(tmp_path)
|
||||
new_model.to(torch_device)
|
||||
|
||||
for param_name in model.state_dict().keys():
|
||||
param_1 = model.state_dict()[param_name]
|
||||
param_2 = new_model.state_dict()[param_name]
|
||||
assert param_1.shape == param_2.shape
|
||||
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
image = _concat_list_output(model(**inputs_dict, return_dict=False)[0])
|
||||
new_image = _concat_list_output(new_model(**inputs_dict, return_dict=False)[0])
|
||||
|
||||
assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")
|
||||
|
||||
@torch.no_grad()
|
||||
def test_from_save_pretrained_variant(self, tmp_path, atol=5e-5, rtol=0):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
model.save_pretrained(tmp_path, variant="fp16")
|
||||
new_model = self.model_class.from_pretrained(tmp_path, variant="fp16")
|
||||
|
||||
with pytest.raises(OSError) as exc_info:
|
||||
self.model_class.from_pretrained(tmp_path)
|
||||
|
||||
assert "Error no file named diffusion_pytorch_model.bin found in directory" in str(exc_info.value)
|
||||
|
||||
new_model.to(torch_device)
|
||||
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
image = _concat_list_output(model(**inputs_dict, return_dict=False)[0])
|
||||
new_image = _concat_list_output(new_model(**inputs_dict, return_dict=False)[0])
|
||||
|
||||
assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")
|
||||
|
||||
@pytest.mark.skip("Model output `sample` is a list of tensors, not a single tensor.")
|
||||
def test_outputs_equivalence(self, atol=1e-5, rtol=0):
|
||||
pass
|
||||
|
||||
def test_sharded_checkpoints_with_parallel_loading(self, tmp_path, atol=1e-5, rtol=0):
|
||||
from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME, constants
|
||||
|
||||
from ..testing_utils.common import calculate_expected_num_shards, compute_module_persistent_sizes
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
torch.manual_seed(0)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(0)
|
||||
config = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**config).eval()
|
||||
model = model.to(torch_device)
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"ZImageTransformer2DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
base_output = _concat_list_output(model(**inputs_dict, return_dict=False)[0])
|
||||
|
||||
@unittest.skip("Test is not supported for handling main inputs that are lists.")
|
||||
def test_training(self):
|
||||
super().test_training()
|
||||
model_size = compute_module_persistent_sizes(model)[""]
|
||||
max_shard_size = int((model_size * 0.75) / (2**10))
|
||||
|
||||
@unittest.skip("Test is not supported for handling main inputs that are lists.")
|
||||
def test_ema_training(self):
|
||||
super().test_ema_training()
|
||||
original_parallel_loading = constants.HF_ENABLE_PARALLEL_LOADING
|
||||
original_parallel_workers = getattr(constants, "HF_PARALLEL_WORKERS", None)
|
||||
|
||||
@unittest.skip("Test is not supported for handling main inputs that are lists.")
|
||||
def test_effective_gradient_checkpointing(self):
|
||||
super().test_effective_gradient_checkpointing()
|
||||
try:
|
||||
model.cpu().save_pretrained(tmp_path, max_shard_size=f"{max_shard_size}KB")
|
||||
assert os.path.exists(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME))
|
||||
|
||||
@unittest.skip(
|
||||
"Test needs to be revisited. But we need to ensure `x_pad_token` and `cap_pad_token` are cast to the same dtype as the destination tensor before they are assigned to the padding indices."
|
||||
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME))
|
||||
actual_num_shards = len([file for file in os.listdir(tmp_path) if file.endswith(".safetensors")])
|
||||
assert actual_num_shards == expected_num_shards
|
||||
|
||||
constants.HF_ENABLE_PARALLEL_LOADING = False
|
||||
self.model_class.from_pretrained(tmp_path).eval().to(torch_device)
|
||||
|
||||
constants.HF_ENABLE_PARALLEL_LOADING = True
|
||||
constants.DEFAULT_HF_PARALLEL_LOADING_WORKERS = 2
|
||||
|
||||
torch.manual_seed(0)
|
||||
model_parallel = self.model_class.from_pretrained(tmp_path).eval()
|
||||
model_parallel = model_parallel.to(torch_device)
|
||||
|
||||
output_parallel = _concat_list_output(model_parallel(**inputs_dict, return_dict=False)[0])
|
||||
|
||||
assert_tensors_close(
|
||||
base_output, output_parallel, atol=atol, rtol=rtol, msg="Output should match with parallel loading"
|
||||
)
|
||||
finally:
|
||||
constants.HF_ENABLE_PARALLEL_LOADING = original_parallel_loading
|
||||
if original_parallel_workers is not None:
|
||||
constants.HF_PARALLEL_WORKERS = original_parallel_workers
|
||||
|
||||
|
||||
class TestZImageTransformerMemory(ZImageTransformerTesterConfig, MemoryTesterMixin):
|
||||
"""Memory optimization tests for Z-Image Transformer."""
|
||||
|
||||
@pytest.mark.skip(
|
||||
"Ensure `x_pad_token` and `cap_pad_token` are cast to the same dtype as the destination tensor before they are assigned to the padding indices."
|
||||
)
|
||||
def test_layerwise_casting_training(self):
|
||||
super().test_layerwise_casting_training()
|
||||
|
||||
@unittest.skip("Test is not supported for handling main inputs that are lists.")
|
||||
def test_outputs_equivalence(self):
|
||||
super().test_outputs_equivalence()
|
||||
|
||||
@unittest.skip("Test will pass if we change to deterministic values instead of empty in the DiT.")
|
||||
def test_group_offloading(self):
|
||||
super().test_group_offloading()
|
||||
|
||||
@unittest.skip("Test will pass if we change to deterministic values instead of empty in the DiT.")
|
||||
def test_group_offloading_with_disk(self):
|
||||
super().test_group_offloading_with_disk()
|
||||
pass
|
||||
|
||||
|
||||
class ZImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
|
||||
model_class = ZImageTransformer2DModel
|
||||
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
|
||||
class TestZImageTransformerTraining(ZImageTransformerTesterConfig, TrainingTesterMixin):
|
||||
"""Training tests for Z-Image Transformer."""
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return ZImageTransformerTests().prepare_init_args_and_inputs_for_common()
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
super().test_gradient_checkpointing_is_applied(expected_set={"ZImageTransformer2DModel"})
|
||||
|
||||
def prepare_dummy_input(self, height, width):
|
||||
return ZImageTransformerTests().prepare_dummy_input(height=height, width=width)
|
||||
@pytest.mark.skip("Test is not supported for handling main inputs that are lists.")
|
||||
def test_training(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
"The repeated block in this model is ZImageTransformerBlock, which is used for noise_refiner, context_refiner, and layers. As a consequence of this, the inputs recorded for the block would vary during compilation and full compilation with fullgraph=True would trigger recompilation at least thrice."
|
||||
@pytest.mark.skip("Test is not supported for handling main inputs that are lists.")
|
||||
def test_training_with_ema(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Test is not supported for handling main inputs that are lists.")
|
||||
def test_gradient_checkpointing_equivalence(self, loss_tolerance=1e-5, param_grad_tol=5e-5, skip=None):
|
||||
pass
|
||||
|
||||
|
||||
class TestZImageTransformerLoRA(ZImageTransformerTesterConfig, LoraTesterMixin):
|
||||
"""LoRA adapter tests for Z-Image Transformer."""
|
||||
|
||||
@pytest.mark.skip("Model output `sample` is a list of tensors, not a single tensor.")
|
||||
def test_save_load_lora_adapter(self, tmp_path, rank=4, lora_alpha=4, use_dora=False, atol=1e-4, rtol=1e-4):
|
||||
pass
|
||||
|
||||
|
||||
# TODO: Add pretrained_model_name_or_path once a tiny Z-Image model is available on the Hub
|
||||
# class TestZImageTransformerBitsAndBytes(ZImageTransformerTesterConfig, BitsAndBytesTesterMixin):
|
||||
# """BitsAndBytes quantization tests for Z-Image Transformer."""
|
||||
|
||||
|
||||
# TODO: Add pretrained_model_name_or_path once a tiny Z-Image model is available on the Hub
|
||||
# class TestZImageTransformerTorchAo(ZImageTransformerTesterConfig, TorchAoTesterMixin):
|
||||
# """TorchAo quantization tests for Z-Image Transformer."""
|
||||
|
||||
|
||||
class TestZImageTransformerCompile(ZImageTransformerTesterConfig, TorchCompileTesterMixin):
|
||||
"""Torch compile tests for Z-Image Transformer."""
|
||||
|
||||
@property
|
||||
def different_shapes_for_compilation(self):
|
||||
return [(4, 4), (4, 8), (8, 8)]
|
||||
|
||||
def get_dummy_inputs(self, height: int = 16, width: int = 16) -> dict[str, torch.Tensor | list]:
|
||||
batch_size = 1
|
||||
num_channels = 16
|
||||
embedding_dim = 16
|
||||
sequence_length = 16
|
||||
|
||||
hidden_states = [
|
||||
randn_tensor((num_channels, 1, height, width), generator=self.generator, device=torch_device)
|
||||
for _ in range(batch_size)
|
||||
]
|
||||
encoder_hidden_states = [
|
||||
randn_tensor((sequence_length, embedding_dim), generator=self.generator, device=torch_device)
|
||||
for _ in range(batch_size)
|
||||
]
|
||||
timestep = torch.tensor([0.0]).to(torch_device)
|
||||
|
||||
return {"x": hidden_states, "cap_feats": encoder_hidden_states, "t": timestep}
|
||||
|
||||
@pytest.mark.skip(
|
||||
"The repeated block in this model is ZImageTransformerBlock, which is used for noise_refiner, context_refiner, and layers. The inputs recorded for the block would vary during compilation and full compilation with fullgraph=True would trigger recompilation at least thrice."
|
||||
)
|
||||
def test_torch_compile_recompilation_and_graph_break(self):
|
||||
super().test_torch_compile_recompilation_and_graph_break()
|
||||
pass
|
||||
|
||||
@unittest.skip("Fullgraph AoT is broken")
|
||||
def test_compile_works_with_aot(self):
|
||||
super().test_compile_works_with_aot()
|
||||
@pytest.mark.skip("Fullgraph AoT is broken")
|
||||
def test_compile_works_with_aot(self, tmp_path):
|
||||
pass
|
||||
|
||||
@unittest.skip("Fullgraph is broken")
|
||||
@pytest.mark.skip("Fullgraph is broken")
|
||||
def test_compile_on_different_shapes(self):
|
||||
super().test_compile_on_different_shapes()
|
||||
pass
|
||||
|
||||
@@ -5,7 +5,6 @@ from typing import Callable
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
import diffusers
|
||||
from diffusers import AutoModel, ComponentsManager, ModularPipeline, ModularPipelineBlocks
|
||||
@@ -33,33 +32,6 @@ from ..testing_utils import (
|
||||
)
|
||||
|
||||
|
||||
def _get_specified_components(path_or_repo_id, cache_dir=None):
|
||||
if os.path.isdir(path_or_repo_id):
|
||||
config_path = os.path.join(path_or_repo_id, "modular_model_index.json")
|
||||
else:
|
||||
try:
|
||||
config_path = hf_hub_download(
|
||||
repo_id=path_or_repo_id,
|
||||
filename="modular_model_index.json",
|
||||
local_dir=cache_dir,
|
||||
)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
with open(config_path) as f:
|
||||
config = json.load(f)
|
||||
|
||||
components = set()
|
||||
for k, v in config.items():
|
||||
if isinstance(v, (str, int, float, bool)):
|
||||
continue
|
||||
for entry in v:
|
||||
if isinstance(entry, dict) and (entry.get("repo") or entry.get("pretrained_model_name_or_path")):
|
||||
components.add(k)
|
||||
break
|
||||
return components
|
||||
|
||||
|
||||
class ModularPipelineTesterMixin:
|
||||
"""
|
||||
It provides a set of common tests for each modular pipeline,
|
||||
@@ -388,39 +360,6 @@ class ModularPipelineTesterMixin:
|
||||
|
||||
assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
|
||||
|
||||
def test_load_expected_components_from_pretrained(self, tmp_path):
|
||||
pipe = self.get_pipeline()
|
||||
expected = _get_specified_components(self.pretrained_model_name_or_path, cache_dir=tmp_path)
|
||||
if not expected:
|
||||
pytest.skip("Skipping test as we couldn't fetch the expected components.")
|
||||
|
||||
actual = {
|
||||
name
|
||||
for name in pipe.components
|
||||
if getattr(pipe, name, None) is not None
|
||||
and getattr(getattr(pipe, name), "_diffusers_load_id", None) not in (None, "null")
|
||||
}
|
||||
assert expected == actual, f"Component mismatch: missing={expected - actual}, unexpected={actual - expected}"
|
||||
|
||||
def test_load_expected_components_from_save_pretrained(self, tmp_path):
|
||||
pipe = self.get_pipeline()
|
||||
save_dir = str(tmp_path / "saved-pipeline")
|
||||
pipe.save_pretrained(save_dir)
|
||||
|
||||
expected = _get_specified_components(save_dir)
|
||||
loaded_pipe = ModularPipeline.from_pretrained(save_dir)
|
||||
loaded_pipe.load_components(torch_dtype=torch.float32)
|
||||
|
||||
actual = {
|
||||
name
|
||||
for name in loaded_pipe.components
|
||||
if getattr(loaded_pipe, name, None) is not None
|
||||
and getattr(getattr(loaded_pipe, name), "_diffusers_load_id", None) not in (None, "null")
|
||||
}
|
||||
assert expected == actual, (
|
||||
f"Component mismatch after save/load: missing={expected - actual}, unexpected={actual - expected}"
|
||||
)
|
||||
|
||||
def test_modular_index_consistency(self, tmp_path):
|
||||
pipe = self.get_pipeline()
|
||||
components_spec = pipe._component_specs
|
||||
|
||||
Reference in New Issue
Block a user