mirror of
https://github.com/huggingface/diffusers.git
synced 2026-03-29 03:47:58 +08:00
Compare commits
2 Commits
cosmos-tes
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f2be8bd6b3 | ||
|
|
7da22b9db5 |
3
.github/workflows/claude_review.yml
vendored
3
.github/workflows/claude_review.yml
vendored
@@ -32,6 +32,9 @@ jobs:
|
||||
)
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 1
|
||||
- uses: anthropics/claude-code-action@v1
|
||||
with:
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
|
||||
@@ -470,8 +470,8 @@ class TorchAoConfig(QuantizationConfigMixin):
|
||||
self.post_init()
|
||||
|
||||
def post_init(self):
|
||||
if is_torchao_version("<=", "0.9.0"):
|
||||
raise ValueError("TorchAoConfig requires torchao > 0.9.0. Please upgrade with `pip install -U torchao`.")
|
||||
if is_torchao_version("<", "0.15.0"):
|
||||
raise ValueError("TorchAoConfig requires torchao >= 0.15.0. Please upgrade with `pip install -U torchao`.")
|
||||
|
||||
from torchao.quantization.quant_api import AOBaseConfig
|
||||
|
||||
@@ -495,8 +495,8 @@ class TorchAoConfig(QuantizationConfigMixin):
|
||||
@classmethod
|
||||
def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
|
||||
"""Create configuration from a dictionary."""
|
||||
if not is_torchao_version(">", "0.9.0"):
|
||||
raise NotImplementedError("TorchAoConfig requires torchao > 0.9.0 for construction from dict")
|
||||
if not is_torchao_version(">=", "0.15.0"):
|
||||
raise NotImplementedError("TorchAoConfig requires torchao >= 0.15.0 for construction from dict")
|
||||
config_dict = config_dict.copy()
|
||||
quant_type = config_dict.pop("quant_type")
|
||||
|
||||
|
||||
@@ -113,7 +113,7 @@ if (
|
||||
is_torch_available()
|
||||
and is_torch_version(">=", "2.6.0")
|
||||
and is_torchao_available()
|
||||
and is_torchao_version(">=", "0.7.0")
|
||||
and is_torchao_version(">=", "0.15.0")
|
||||
):
|
||||
_update_torch_safe_globals()
|
||||
|
||||
@@ -168,10 +168,10 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
|
||||
raise ImportError(
|
||||
"Loading a TorchAO quantized model requires the torchao library. Please install with `pip install torchao`"
|
||||
)
|
||||
torchao_version = version.parse(importlib.metadata.version("torch"))
|
||||
if torchao_version < version.parse("0.7.0"):
|
||||
torchao_version = version.parse(importlib.metadata.version("torchao"))
|
||||
if torchao_version < version.parse("0.15.0"):
|
||||
raise RuntimeError(
|
||||
f"The minimum required version of `torchao` is 0.7.0, but the current version is {torchao_version}. Please upgrade with `pip install -U torchao`."
|
||||
f"The minimum required version of `torchao` is 0.15.0, but the current version is {torchao_version}. Please upgrade with `pip install -U torchao`."
|
||||
)
|
||||
|
||||
self.offload = False
|
||||
|
||||
@@ -12,46 +12,60 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import CosmosTransformer3DModel
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..testing_utils import (
|
||||
BaseModelTesterConfig,
|
||||
MemoryTesterMixin,
|
||||
ModelTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
)
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class CosmosTransformerTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return CosmosTransformer3DModel
|
||||
class CosmosTransformer3DModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = CosmosTransformer3DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
|
||||
@property
|
||||
def output_shape(self) -> tuple[int, ...]:
|
||||
return (4, 1, 16, 16)
|
||||
def dummy_input(self):
|
||||
batch_size = 1
|
||||
num_channels = 4
|
||||
num_frames = 1
|
||||
height = 16
|
||||
width = 16
|
||||
text_embed_dim = 16
|
||||
sequence_length = 12
|
||||
fps = 30
|
||||
|
||||
@property
|
||||
def input_shape(self) -> tuple[int, ...]:
|
||||
return (4, 1, 16, 16)
|
||||
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
|
||||
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_embed_dim)).to(torch_device)
|
||||
attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
|
||||
padding_mask = torch.zeros(batch_size, 1, height, width).to(torch_device)
|
||||
|
||||
@property
|
||||
def main_input_name(self) -> str:
|
||||
return "hidden_states"
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self) -> dict[str, int | list | tuple | float | bool | str]:
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"timestep": timestep,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"attention_mask": attention_mask,
|
||||
"fps": fps,
|
||||
"padding_mask": padding_mask,
|
||||
}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (4, 1, 16, 16)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (4, 1, 16, 16)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"in_channels": 4,
|
||||
"out_channels": 4,
|
||||
"num_attention_heads": 2,
|
||||
@@ -66,68 +80,57 @@ class CosmosTransformerTesterConfig(BaseModelTesterConfig):
|
||||
"concat_padding_mask": True,
|
||||
"extra_pos_embed_type": "learnable",
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
|
||||
num_channels = 4
|
||||
num_frames = 1
|
||||
height = 16
|
||||
width = 16
|
||||
text_embed_dim = 16
|
||||
sequence_length = 12
|
||||
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(batch_size, num_channels, num_frames, height, width), generator=self.generator, device=torch_device
|
||||
),
|
||||
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(batch_size, sequence_length, text_embed_dim), generator=self.generator, device=torch_device
|
||||
),
|
||||
"attention_mask": torch.ones((batch_size, sequence_length)).to(torch_device),
|
||||
"fps": 30,
|
||||
"padding_mask": torch.zeros(batch_size, 1, height, width).to(torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestCosmosTransformer(CosmosTransformerTesterConfig, ModelTesterMixin):
|
||||
"""Core model tests for Cosmos Transformer."""
|
||||
|
||||
|
||||
class TestCosmosTransformerMemory(CosmosTransformerTesterConfig, MemoryTesterMixin):
|
||||
"""Memory optimization tests for Cosmos Transformer."""
|
||||
|
||||
|
||||
class TestCosmosTransformerTraining(CosmosTransformerTesterConfig, TrainingTesterMixin):
|
||||
"""Training tests for Cosmos Transformer."""
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"CosmosTransformer3DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class CosmosTransformerVideoToWorldTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return CosmosTransformer3DModel
|
||||
class CosmosTransformer3DModelVideoToWorldTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = CosmosTransformer3DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
|
||||
@property
|
||||
def output_shape(self) -> tuple[int, ...]:
|
||||
return (4, 1, 16, 16)
|
||||
def dummy_input(self):
|
||||
batch_size = 1
|
||||
num_channels = 4
|
||||
num_frames = 1
|
||||
height = 16
|
||||
width = 16
|
||||
text_embed_dim = 16
|
||||
sequence_length = 12
|
||||
fps = 30
|
||||
|
||||
@property
|
||||
def input_shape(self) -> tuple[int, ...]:
|
||||
return (4, 1, 16, 16)
|
||||
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
|
||||
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_embed_dim)).to(torch_device)
|
||||
attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
|
||||
condition_mask = torch.ones(batch_size, 1, num_frames, height, width).to(torch_device)
|
||||
padding_mask = torch.zeros(batch_size, 1, height, width).to(torch_device)
|
||||
|
||||
@property
|
||||
def main_input_name(self) -> str:
|
||||
return "hidden_states"
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self) -> dict[str, int | list | tuple | float | bool | str]:
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"timestep": timestep,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"attention_mask": attention_mask,
|
||||
"fps": fps,
|
||||
"condition_mask": condition_mask,
|
||||
"padding_mask": padding_mask,
|
||||
}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (4, 1, 16, 16)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (4, 1, 16, 16)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"in_channels": 4 + 1,
|
||||
"out_channels": 4,
|
||||
"num_attention_heads": 2,
|
||||
@@ -142,40 +145,8 @@ class CosmosTransformerVideoToWorldTesterConfig(BaseModelTesterConfig):
|
||||
"concat_padding_mask": True,
|
||||
"extra_pos_embed_type": "learnable",
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
|
||||
num_channels = 4
|
||||
num_frames = 1
|
||||
height = 16
|
||||
width = 16
|
||||
text_embed_dim = 16
|
||||
sequence_length = 12
|
||||
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(batch_size, num_channels, num_frames, height, width), generator=self.generator, device=torch_device
|
||||
),
|
||||
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(batch_size, sequence_length, text_embed_dim), generator=self.generator, device=torch_device
|
||||
),
|
||||
"attention_mask": torch.ones((batch_size, sequence_length)).to(torch_device),
|
||||
"fps": 30,
|
||||
"condition_mask": torch.ones(batch_size, 1, num_frames, height, width).to(torch_device),
|
||||
"padding_mask": torch.zeros(batch_size, 1, height, width).to(torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestCosmosTransformerVideoToWorld(CosmosTransformerVideoToWorldTesterConfig, ModelTesterMixin):
|
||||
"""Core model tests for Cosmos Transformer (Video-to-World)."""
|
||||
|
||||
|
||||
class TestCosmosTransformerVideoToWorldMemory(CosmosTransformerVideoToWorldTesterConfig, MemoryTesterMixin):
|
||||
"""Memory optimization tests for Cosmos Transformer (Video-to-World)."""
|
||||
|
||||
|
||||
class TestCosmosTransformerVideoToWorldTraining(CosmosTransformerVideoToWorldTesterConfig, TrainingTesterMixin):
|
||||
"""Training tests for Cosmos Transformer (Video-to-World)."""
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"CosmosTransformer3DModel"}
|
||||
|
||||
@@ -14,13 +14,11 @@
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import importlib.metadata
|
||||
import tempfile
|
||||
import unittest
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
from packaging import version
|
||||
from parameterized import parameterized
|
||||
from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
|
||||
|
||||
@@ -82,18 +80,17 @@ if is_torchao_available():
|
||||
Float8WeightOnlyConfig,
|
||||
Int4WeightOnlyConfig,
|
||||
Int8DynamicActivationInt8WeightConfig,
|
||||
Int8DynamicActivationIntxWeightConfig,
|
||||
Int8WeightOnlyConfig,
|
||||
IntxWeightOnlyConfig,
|
||||
)
|
||||
from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
|
||||
from torchao.utils import get_model_size_in_bytes
|
||||
|
||||
if version.parse(importlib.metadata.version("torchao")) >= version.Version("0.10.0"):
|
||||
from torchao.quantization import Int8DynamicActivationIntxWeightConfig, IntxWeightOnlyConfig
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_torch_accelerator
|
||||
@require_torchao_version_greater_or_equal("0.14.0")
|
||||
@require_torchao_version_greater_or_equal("0.15.0")
|
||||
class TorchAoConfigTest(unittest.TestCase):
|
||||
def test_to_dict(self):
|
||||
"""
|
||||
@@ -128,7 +125,7 @@ class TorchAoConfigTest(unittest.TestCase):
|
||||
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
|
||||
@require_torch
|
||||
@require_torch_accelerator
|
||||
@require_torchao_version_greater_or_equal("0.14.0")
|
||||
@require_torchao_version_greater_or_equal("0.15.0")
|
||||
class TorchAoTest(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
gc.collect()
|
||||
@@ -527,7 +524,7 @@ class TorchAoTest(unittest.TestCase):
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
_ = pipe(**inputs)
|
||||
|
||||
@require_torchao_version_greater_or_equal("0.9.0")
|
||||
@require_torchao_version_greater_or_equal("0.15.0")
|
||||
def test_aobase_config(self):
|
||||
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
|
||||
components = self.get_dummy_components(quantization_config)
|
||||
@@ -540,7 +537,7 @@ class TorchAoTest(unittest.TestCase):
|
||||
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
|
||||
@require_torch
|
||||
@require_torch_accelerator
|
||||
@require_torchao_version_greater_or_equal("0.14.0")
|
||||
@require_torchao_version_greater_or_equal("0.15.0")
|
||||
class TorchAoSerializationTest(unittest.TestCase):
|
||||
model_name = "hf-internal-testing/tiny-flux-pipe"
|
||||
|
||||
@@ -650,7 +647,7 @@ class TorchAoSerializationTest(unittest.TestCase):
|
||||
self._check_serialization_expected_slice(quant_type, expected_slice, device)
|
||||
|
||||
|
||||
@require_torchao_version_greater_or_equal("0.14.0")
|
||||
@require_torchao_version_greater_or_equal("0.15.0")
|
||||
class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):
|
||||
@property
|
||||
def quantization_config(self):
|
||||
@@ -696,7 +693,7 @@ class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):
|
||||
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
|
||||
@require_torch
|
||||
@require_torch_accelerator
|
||||
@require_torchao_version_greater_or_equal("0.14.0")
|
||||
@require_torchao_version_greater_or_equal("0.15.0")
|
||||
@slow
|
||||
@nightly
|
||||
class SlowTorchAoTests(unittest.TestCase):
|
||||
@@ -854,7 +851,7 @@ class SlowTorchAoTests(unittest.TestCase):
|
||||
|
||||
@require_torch
|
||||
@require_torch_accelerator
|
||||
@require_torchao_version_greater_or_equal("0.14.0")
|
||||
@require_torchao_version_greater_or_equal("0.15.0")
|
||||
@slow
|
||||
@nightly
|
||||
class SlowTorchAoPreserializedModelTests(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user