mirror of
https://github.com/huggingface/diffusers.git
synced 2026-04-08 16:51:53 +08:00
Compare commits
1 Commits
autoencode
...
modular-pi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1c3b90986a |
@@ -35,6 +35,10 @@ Strive to write code as simple and explicit as possible.
|
||||
- Use `self.progress_bar(timesteps)` for progress tracking
|
||||
- Don't subclass an existing pipeline for a variant — DO NOT use an existing pipeline class (e.g., `FluxPipeline`) to override another pipeline (e.g., `FluxImg2ImgPipeline`) which will be a part of the core codebase (`src`)
|
||||
|
||||
### Modular Pipelines
|
||||
|
||||
- See [modular.md](modular.md) for modular pipeline conventions, patterns, and gotchas.
|
||||
|
||||
## Skills
|
||||
|
||||
Task-specific guides live in `.ai/skills/` and are loaded on demand by AI agents. Available skills include:
|
||||
|
||||
@@ -1,11 +1,6 @@
|
||||
# Modular Pipeline Conversion Reference
|
||||
# Modular pipeline conventions and rules
|
||||
|
||||
## When to use
|
||||
|
||||
Modular pipelines break a monolithic `__call__` into composable blocks. Convert when:
|
||||
- The model supports multiple workflows (T2V, I2V, V2V, etc.)
|
||||
- Users need to swap guidance strategies (CFG, CFG-Zero*, PAG)
|
||||
- You want to share blocks across pipeline variants
|
||||
Shared reference for modular pipeline conventions, patterns, and gotchas.
|
||||
|
||||
## File structure
|
||||
|
||||
@@ -14,7 +9,7 @@ src/diffusers/modular_pipelines/<model>/
|
||||
__init__.py # Lazy imports
|
||||
modular_pipeline.py # Pipeline class (tiny, mostly config)
|
||||
encoders.py # Text encoder + image/video VAE encoder blocks
|
||||
before_denoise.py # Pre-denoise setup blocks
|
||||
before_denoise.py # Pre-denoise setup blocks (timesteps, latent prep, noise)
|
||||
denoise.py # The denoising loop blocks
|
||||
decoders.py # VAE decode block
|
||||
modular_blocks_<model>.py # Block assembly (AutoBlocks)
|
||||
@@ -81,15 +76,21 @@ for i, t in enumerate(timesteps):
|
||||
latents = components.scheduler.step(noise_pred, t, latents, generator=generator)[0]
|
||||
```
|
||||
|
||||
## Key pattern: Chunk loops for video models
|
||||
## Key pattern: Denoising loop
|
||||
|
||||
Use `LoopSequentialPipelineBlocks` for outer loop:
|
||||
All models use `LoopSequentialPipelineBlocks` for the denoising loop (iterating over timesteps):
|
||||
```python
|
||||
class ChunkDenoiseStep(LoopSequentialPipelineBlocks):
|
||||
block_classes = [PrepareChunkStep, NoiseGenStep, DenoiseInnerStep, UpdateStep]
|
||||
class MyModelDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
|
||||
block_classes = [LoopBeforeDenoiser, LoopDenoiser, LoopAfterDenoiser]
|
||||
```
|
||||
|
||||
Note: blocks inside `LoopSequentialPipelineBlocks` receive `(components, block_state, k)` where `k` is the loop iteration index.
|
||||
Autoregressive video models (e.g. Helios) also use it for an outer chunk loop:
|
||||
```python
|
||||
class HeliosChunkDenoiseStep(LoopSequentialPipelineBlocks):
|
||||
block_classes = [ChunkHistorySlice, ChunkNoiseGen, ChunkDenoiseInner, ChunkUpdate]
|
||||
```
|
||||
|
||||
Note: sub-blocks inside `LoopSequentialPipelineBlocks` receive `(components, block_state, i, t)` for denoise loops or `(components, block_state, k)` for chunk loops.
|
||||
|
||||
## Key pattern: Workflow selection
|
||||
|
||||
@@ -136,6 +137,26 @@ ComponentSpec(
|
||||
)
|
||||
```
|
||||
|
||||
## Gotchas
|
||||
|
||||
1. **Importing from standard pipelines.** The modular and standard pipeline systems are parallel — modular blocks must not import from `diffusers.pipelines.*`. For shared utility methods (e.g. `_pack_latents`, `retrieve_timesteps`), either redefine as standalone functions or use `# Copied from diffusers.pipelines.<model>...` headers. See `wan/before_denoise.py` and `helios/before_denoise.py` for examples.
|
||||
|
||||
2. **Cross-importing between modular pipelines.** Don't import utilities from another model's modular pipeline (e.g. SD3 importing from `qwenimage.inputs`). If a utility is shared, move it to `modular_pipeline_utils.py` or copy it with a `# Copied from` header.
|
||||
|
||||
3. **Accepting `guidance_scale` as a pipeline input.** Users configure the guider separately (see [guider docs](https://huggingface.co/docs/diffusers/main/en/api/guiders)). Different guider types have different parameters; forwarding them through the pipeline doesn't scale. Don't manually set `components.guider.guidance_scale = ...` inside blocks. Same applies to computing `do_classifier_free_guidance` — that logic belongs in the guider.
|
||||
|
||||
4. **Accepting pre-computed outputs as inputs to skip encoding.** In standard pipelines we accept `prompt_embeds`, `negative_prompt_embeds`, `image_latents`, etc. so users can skip encoding steps. In modular pipelines this is unnecessary — users just pop out the encoder block and run it separately. Encoder blocks should only accept raw inputs (`prompt`, `image`, etc.).
|
||||
|
||||
5. **VAE encoding inside prepare-latents.** Image encoding should be its own block in `encoders.py` (e.g. `MyModelVaeEncoderStep`). The prepare-latents block should accept `image_latents`, not raw images. This lets users run encoding standalone. See `WanVaeEncoderStep` for reference.
|
||||
|
||||
6. **Instantiating components inline.** If a class like `VideoProcessor` is needed, register it as a `ComponentSpec` and access via `components.video_processor`. Don't create new instances inside block `__call__`.
|
||||
|
||||
7. **Deeply nested block structure.** Prefer flat sequences over nesting Auto blocks inside Sequential blocks inside Auto blocks. Put the `Auto` selection at the top level and make each workflow variant a flat `InsertableDict` of leaf blocks. See `flux2/modular_blocks_flux2_klein.py` for the pattern.
|
||||
|
||||
8. **Using `InputParam.template()` / `OutputParam.template()` when semantics don't match.** Templates carry predefined descriptions — e.g. the `"latents"` output template means "Denoised latents". Don't use it for initial noisy latents from a prepare-latents step. Use a plain `InputParam(...)` / `OutputParam(...)` with an accurate description instead.
|
||||
|
||||
9. **Test model paths pointing to contributor repos.** Tiny test models must live under `hf-internal-testing/`, not personal repos like `username/tiny-model`. Move the model before merge.
|
||||
|
||||
## Conversion checklist
|
||||
|
||||
- [ ] Read original pipeline's `__call__` end-to-end, map stages
|
||||
@@ -5,6 +5,7 @@ Review-specific rules for Claude. Focus on correctness — style is handled by r
|
||||
Before reviewing, read and apply the guidelines in:
|
||||
- [AGENTS.md](AGENTS.md) — coding style, copied code
|
||||
- [models.md](models.md) — model conventions, attention pattern, implementation rules, dependencies, gotchas
|
||||
- [modular.md](modular.md) — modular pipeline conventions, patterns, common mistakes
|
||||
- [skills/parity-testing/SKILL.md](skills/parity-testing/SKILL.md) — testing rules, comparison utilities
|
||||
- [skills/parity-testing/pitfalls.md](skills/parity-testing/pitfalls.md) — known pitfalls (dtype mismatches, config assumptions, etc.)
|
||||
|
||||
|
||||
@@ -82,7 +82,7 @@ See [../../models.md](../../models.md) for the attention pattern, implementation
|
||||
|
||||
## Modular Pipeline Conversion
|
||||
|
||||
See [modular-conversion.md](modular-conversion.md) for the full guide on converting standard pipelines to modular format, including block types, build order, guider abstraction, and conversion checklist.
|
||||
See [modular.md](../../modular.md) for the full guide on modular pipeline conventions, block types, build order, guider abstraction, gotchas, and conversion checklist.
|
||||
|
||||
---
|
||||
|
||||
|
||||
@@ -13,34 +13,24 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import unittest
|
||||
|
||||
from diffusers import AutoencoderDC
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import IS_GITHUB_ACTIONS, enable_full_determinism, torch_device
|
||||
from ..testing_utils import BaseModelTesterConfig, MemoryTesterMixin, ModelTesterMixin, TrainingTesterMixin
|
||||
from .testing_utils import NewAutoencoderTesterMixin
|
||||
from ...testing_utils import IS_GITHUB_ACTIONS, enable_full_determinism, floats_tensor, torch_device
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
from .testing_utils import AutoencoderTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class AutoencoderDCTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return AutoencoderDC
|
||||
class AutoencoderDCTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
|
||||
model_class = AutoencoderDC
|
||||
main_input_name = "sample"
|
||||
base_precision = 1e-2
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self):
|
||||
def get_autoencoder_dc_config(self):
|
||||
return {
|
||||
"in_channels": 3,
|
||||
"latent_channels": 4,
|
||||
@@ -66,29 +56,33 @@ class AutoencoderDCTesterConfig(BaseModelTesterConfig):
|
||||
"scaling_factor": 0.41407,
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 4
|
||||
num_channels = 3
|
||||
sizes = (32, 32)
|
||||
image = randn_tensor((batch_size, num_channels, *sizes), generator=self.generator, device=torch_device)
|
||||
|
||||
image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
|
||||
|
||||
return {"sample": image}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
class TestAutoencoderDC(AutoencoderDCTesterConfig, ModelTesterMixin):
|
||||
base_precision = 1e-2
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = self.get_autoencoder_dc_config()
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
class TestAutoencoderDCTraining(AutoencoderDCTesterConfig, TrainingTesterMixin):
|
||||
"""Training tests for AutoencoderDC."""
|
||||
@unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
|
||||
def test_layerwise_casting_inference(self):
|
||||
super().test_layerwise_casting_inference()
|
||||
|
||||
|
||||
class TestAutoencoderDCMemory(AutoencoderDCTesterConfig, MemoryTesterMixin):
|
||||
"""Memory optimization tests for AutoencoderDC."""
|
||||
|
||||
@pytest.mark.skipif(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
|
||||
@unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
|
||||
def test_layerwise_casting_memory(self):
|
||||
super().test_layerwise_casting_memory()
|
||||
|
||||
|
||||
class TestAutoencoderDCSlicingTiling(AutoencoderDCTesterConfig, NewAutoencoderTesterMixin):
|
||||
"""Slicing and tiling tests for AutoencoderDC."""
|
||||
|
||||
@@ -14,18 +14,18 @@
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from parameterized import parameterized
|
||||
|
||||
from diffusers import AutoencoderKL
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import (
|
||||
backend_empty_cache,
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
load_hf_numpy,
|
||||
require_torch_accelerator,
|
||||
require_torch_accelerator_with_fp16,
|
||||
@@ -35,30 +35,22 @@ from ...testing_utils import (
|
||||
torch_all_close,
|
||||
torch_device,
|
||||
)
|
||||
from ..testing_utils import BaseModelTesterConfig, MemoryTesterMixin, ModelTesterMixin, TrainingTesterMixin
|
||||
from .testing_utils import NewAutoencoderTesterMixin
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
from .testing_utils import AutoencoderTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class AutoencoderKLTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return AutoencoderKL
|
||||
class AutoencoderKLTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
|
||||
model_class = AutoencoderKL
|
||||
main_input_name = "sample"
|
||||
base_precision = 1e-2
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self, block_out_channels=None, norm_num_groups=None):
|
||||
def get_autoencoder_kl_config(self, block_out_channels=None, norm_num_groups=None):
|
||||
block_out_channels = block_out_channels or [2, 4]
|
||||
norm_num_groups = norm_num_groups or 2
|
||||
return {
|
||||
init_dict = {
|
||||
"block_out_channels": block_out_channels,
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
@@ -67,27 +59,42 @@ class AutoencoderKLTesterConfig(BaseModelTesterConfig):
|
||||
"latent_channels": 4,
|
||||
"norm_num_groups": norm_num_groups,
|
||||
}
|
||||
return init_dict
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 4
|
||||
num_channels = 3
|
||||
sizes = (32, 32)
|
||||
image = randn_tensor((batch_size, num_channels, *sizes), generator=self.generator, device=torch_device)
|
||||
|
||||
image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
|
||||
|
||||
return {"sample": image}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = self.get_autoencoder_kl_config()
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
class TestAutoencoderKL(AutoencoderKLTesterConfig, ModelTesterMixin, TrainingTesterMixin):
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"Decoder", "Encoder", "UNetMidBlock2D"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
def test_from_pretrained_hub(self):
|
||||
model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True)
|
||||
assert model is not None
|
||||
assert len(loading_info["missing_keys"]) == 0
|
||||
self.assertIsNotNone(model)
|
||||
self.assertEqual(len(loading_info["missing_keys"]), 0)
|
||||
|
||||
model.to(torch_device)
|
||||
image = model(**self.get_dummy_inputs())
|
||||
image = model(**self.dummy_input)
|
||||
|
||||
assert image is not None, "Make sure output is not None"
|
||||
|
||||
@@ -161,24 +168,17 @@ class TestAutoencoderKL(AutoencoderKLTesterConfig, ModelTesterMixin, TrainingTes
|
||||
]
|
||||
)
|
||||
|
||||
assert torch_all_close(output_slice, expected_output_slice, rtol=1e-2)
|
||||
|
||||
|
||||
class TestAutoencoderKLMemory(AutoencoderKLTesterConfig, MemoryTesterMixin):
|
||||
"""Memory optimization tests for AutoencoderKL."""
|
||||
|
||||
|
||||
class TestAutoencoderKLSlicingTiling(AutoencoderKLTesterConfig, NewAutoencoderTesterMixin):
|
||||
"""Slicing and tiling tests for AutoencoderKL."""
|
||||
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
|
||||
|
||||
|
||||
@slow
|
||||
class AutoencoderKLIntegrationTests:
|
||||
class AutoencoderKLIntegrationTests(unittest.TestCase):
|
||||
def get_file_format(self, seed, shape):
|
||||
return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
|
||||
|
||||
def teardown_method(self):
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
@@ -341,7 +341,10 @@ class AutoencoderKLIntegrationTests:
|
||||
|
||||
@parameterized.expand([(13,), (16,), (27,)])
|
||||
@require_torch_gpu
|
||||
@pytest.mark.skipif(not is_xformers_available(), reason="xformers is not required when using PyTorch 2.0.")
|
||||
@unittest.skipIf(
|
||||
not is_xformers_available(),
|
||||
reason="xformers is not required when using PyTorch 2.0.",
|
||||
)
|
||||
def test_stable_diffusion_decode_xformers_vs_2_0_fp16(self, seed):
|
||||
model = self.get_sd_vae_model(fp16=True)
|
||||
encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64), fp16=True)
|
||||
@@ -359,7 +362,10 @@ class AutoencoderKLIntegrationTests:
|
||||
|
||||
@parameterized.expand([(13,), (16,), (37,)])
|
||||
@require_torch_gpu
|
||||
@pytest.mark.skipif(not is_xformers_available(), reason="xformers is not required when using PyTorch 2.0.")
|
||||
@unittest.skipIf(
|
||||
not is_xformers_available(),
|
||||
reason="xformers is not required when using PyTorch 2.0.",
|
||||
)
|
||||
def test_stable_diffusion_decode_xformers_vs_2_0(self, seed):
|
||||
model = self.get_sd_vae_model()
|
||||
encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))
|
||||
|
||||
@@ -12,46 +12,60 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import CosmosTransformer3DModel
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..testing_utils import (
|
||||
BaseModelTesterConfig,
|
||||
MemoryTesterMixin,
|
||||
ModelTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
)
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class CosmosTransformerTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return CosmosTransformer3DModel
|
||||
class CosmosTransformer3DModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = CosmosTransformer3DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
|
||||
@property
|
||||
def output_shape(self) -> tuple[int, ...]:
|
||||
return (4, 1, 16, 16)
|
||||
def dummy_input(self):
|
||||
batch_size = 1
|
||||
num_channels = 4
|
||||
num_frames = 1
|
||||
height = 16
|
||||
width = 16
|
||||
text_embed_dim = 16
|
||||
sequence_length = 12
|
||||
fps = 30
|
||||
|
||||
@property
|
||||
def input_shape(self) -> tuple[int, ...]:
|
||||
return (4, 1, 16, 16)
|
||||
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
|
||||
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_embed_dim)).to(torch_device)
|
||||
attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
|
||||
padding_mask = torch.zeros(batch_size, 1, height, width).to(torch_device)
|
||||
|
||||
@property
|
||||
def main_input_name(self) -> str:
|
||||
return "hidden_states"
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self) -> dict[str, int | list | tuple | float | bool | str]:
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"timestep": timestep,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"attention_mask": attention_mask,
|
||||
"fps": fps,
|
||||
"padding_mask": padding_mask,
|
||||
}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (4, 1, 16, 16)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (4, 1, 16, 16)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"in_channels": 4,
|
||||
"out_channels": 4,
|
||||
"num_attention_heads": 2,
|
||||
@@ -66,68 +80,57 @@ class CosmosTransformerTesterConfig(BaseModelTesterConfig):
|
||||
"concat_padding_mask": True,
|
||||
"extra_pos_embed_type": "learnable",
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
|
||||
num_channels = 4
|
||||
num_frames = 1
|
||||
height = 16
|
||||
width = 16
|
||||
text_embed_dim = 16
|
||||
sequence_length = 12
|
||||
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(batch_size, num_channels, num_frames, height, width), generator=self.generator, device=torch_device
|
||||
),
|
||||
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(batch_size, sequence_length, text_embed_dim), generator=self.generator, device=torch_device
|
||||
),
|
||||
"attention_mask": torch.ones((batch_size, sequence_length)).to(torch_device),
|
||||
"fps": 30,
|
||||
"padding_mask": torch.zeros(batch_size, 1, height, width).to(torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestCosmosTransformer(CosmosTransformerTesterConfig, ModelTesterMixin):
|
||||
"""Core model tests for Cosmos Transformer."""
|
||||
|
||||
|
||||
class TestCosmosTransformerMemory(CosmosTransformerTesterConfig, MemoryTesterMixin):
|
||||
"""Memory optimization tests for Cosmos Transformer."""
|
||||
|
||||
|
||||
class TestCosmosTransformerTraining(CosmosTransformerTesterConfig, TrainingTesterMixin):
|
||||
"""Training tests for Cosmos Transformer."""
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"CosmosTransformer3DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class CosmosTransformerVideoToWorldTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return CosmosTransformer3DModel
|
||||
class CosmosTransformer3DModelVideoToWorldTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = CosmosTransformer3DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
|
||||
@property
|
||||
def output_shape(self) -> tuple[int, ...]:
|
||||
return (4, 1, 16, 16)
|
||||
def dummy_input(self):
|
||||
batch_size = 1
|
||||
num_channels = 4
|
||||
num_frames = 1
|
||||
height = 16
|
||||
width = 16
|
||||
text_embed_dim = 16
|
||||
sequence_length = 12
|
||||
fps = 30
|
||||
|
||||
@property
|
||||
def input_shape(self) -> tuple[int, ...]:
|
||||
return (4, 1, 16, 16)
|
||||
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
|
||||
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_embed_dim)).to(torch_device)
|
||||
attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
|
||||
condition_mask = torch.ones(batch_size, 1, num_frames, height, width).to(torch_device)
|
||||
padding_mask = torch.zeros(batch_size, 1, height, width).to(torch_device)
|
||||
|
||||
@property
|
||||
def main_input_name(self) -> str:
|
||||
return "hidden_states"
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self) -> dict[str, int | list | tuple | float | bool | str]:
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"timestep": timestep,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"attention_mask": attention_mask,
|
||||
"fps": fps,
|
||||
"condition_mask": condition_mask,
|
||||
"padding_mask": padding_mask,
|
||||
}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (4, 1, 16, 16)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (4, 1, 16, 16)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"in_channels": 4 + 1,
|
||||
"out_channels": 4,
|
||||
"num_attention_heads": 2,
|
||||
@@ -142,40 +145,8 @@ class CosmosTransformerVideoToWorldTesterConfig(BaseModelTesterConfig):
|
||||
"concat_padding_mask": True,
|
||||
"extra_pos_embed_type": "learnable",
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
|
||||
num_channels = 4
|
||||
num_frames = 1
|
||||
height = 16
|
||||
width = 16
|
||||
text_embed_dim = 16
|
||||
sequence_length = 12
|
||||
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(batch_size, num_channels, num_frames, height, width), generator=self.generator, device=torch_device
|
||||
),
|
||||
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(batch_size, sequence_length, text_embed_dim), generator=self.generator, device=torch_device
|
||||
),
|
||||
"attention_mask": torch.ones((batch_size, sequence_length)).to(torch_device),
|
||||
"fps": 30,
|
||||
"condition_mask": torch.ones(batch_size, 1, num_frames, height, width).to(torch_device),
|
||||
"padding_mask": torch.zeros(batch_size, 1, height, width).to(torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestCosmosTransformerVideoToWorld(CosmosTransformerVideoToWorldTesterConfig, ModelTesterMixin):
|
||||
"""Core model tests for Cosmos Transformer (Video-to-World)."""
|
||||
|
||||
|
||||
class TestCosmosTransformerVideoToWorldMemory(CosmosTransformerVideoToWorldTesterConfig, MemoryTesterMixin):
|
||||
"""Memory optimization tests for Cosmos Transformer (Video-to-World)."""
|
||||
|
||||
|
||||
class TestCosmosTransformerVideoToWorldTraining(CosmosTransformerVideoToWorldTesterConfig, TrainingTesterMixin):
|
||||
"""Training tests for Cosmos Transformer (Video-to-World)."""
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"CosmosTransformer3DModel"}
|
||||
|
||||
Reference in New Issue
Block a user