Compare commits

3 Commits

DN6 | eff791831f | 2026-03-13 10:28:38 +05:30
    update

Dhruv Nair | 07c5ba8eee | 2026-03-11 16:42:11 +05:30
    [Context Parallel] Add support for custom device mesh (#13064)
    * add custom mesh support
    * update
    Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

Dhruv Nair | 897aed72fa | 2026-03-11 09:26:46 +05:30
    [Quantization] Deprecate Quanto (#13180)
    * update
    * update
7 changed files with 246 additions and 243 deletions

View File

@@ -60,6 +60,16 @@ class ContextParallelConfig:
         rotate_method (`str`, *optional*, defaults to `"allgather"`):
             Method to use for rotating key/value states across devices in ring attention. Currently, only `"allgather"`
             is supported.
+        ulysses_anything (`bool`, *optional*, defaults to `False`):
+            Whether to enable "Ulysses Anything" mode, which supports arbitrary sequence lengths and head counts that
+            are not evenly divisible by `ulysses_degree`. When enabled, `ulysses_degree` must be greater than 1 and
+            `ring_degree` must be 1.
+        mesh (`torch.distributed.device_mesh.DeviceMesh`, *optional*):
+            A custom device mesh to use for context parallelism. If provided, this mesh will be used instead of
+            creating a new one. This is useful when combining context parallelism with other parallelism strategies
+            (e.g., FSDP, tensor parallelism) that share the same device mesh. The mesh must have both "ring" and
+            "ulysses" dimensions. Use size 1 for dimensions not being used (e.g., `mesh_shape=(2, 1, 4)` with
+            `mesh_dim_names=("ring", "ulysses", "fsdp")` for ring attention only with FSDP).
     """
@@ -68,6 +78,7 @@ class ContextParallelConfig:
     convert_to_fp32: bool = True
     # TODO: support alltoall
     rotate_method: Literal["allgather", "alltoall"] = "allgather"
+    mesh: torch.distributed.device_mesh.DeviceMesh | None = None
     # Whether to enable ulysses anything attention to support
     # any sequence lengths and any head numbers.
     ulysses_anything: bool = False
@@ -124,7 +135,7 @@ class ContextParallelConfig:
                 f"The product of `ring_degree` ({self.ring_degree}) and `ulysses_degree` ({self.ulysses_degree}) must not exceed the world size ({world_size})."
             )
-        self._flattened_mesh = self._mesh._flatten()
+        self._flattened_mesh = self._mesh["ring", "ulysses"]._flatten()
         self._ring_mesh = self._mesh["ring"]
         self._ulysses_mesh = self._mesh["ulysses"]
         self._ring_local_rank = self._ring_mesh.get_local_rank()
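
The docstring above describes the contract but the diff carries no usage snippet. A minimal sketch, assuming a 2-GPU `torchrun` launch and a diffusers model that defines a `_cp_plan` (nothing here is part of the PR itself):

```python
# Sketch: ring attention over a user-created mesh. The mesh must carry both
# "ring" and "ulysses" dimensions; a dimension that is not used gets size 1.
# Launch with: torchrun --nproc_per_node=2 this_script.py
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

from diffusers import ContextParallelConfig

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())

mesh = init_device_mesh("cuda", mesh_shape=(2, 1), mesh_dim_names=("ring", "ulysses"))
config = ContextParallelConfig(ring_degree=2, mesh=mesh)
# model.enable_parallelism(config=config)  # `model`: any diffusers model exposing a CP plan
```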

View File

@@ -1567,7 +1567,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         mesh = None
         if config.context_parallel_config is not None:
             cp_config = config.context_parallel_config
-            mesh = torch.distributed.device_mesh.init_device_mesh(
+            mesh = cp_config.mesh or torch.distributed.device_mesh.init_device_mesh(
                 device_type=device_type,
                 mesh_shape=cp_config.mesh_shape,
                 mesh_dim_names=cp_config.mesh_dim_names,

View File

@@ -14,6 +14,7 @@
 import importlib
 import inspect
 import os
+import shutil
 import sys
 import traceback
 import warnings
@@ -1883,6 +1884,36 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
         )
         return pipeline

+    def _maybe_save_custom_code(self, save_directory: str | os.PathLike):
+        """Save custom code files (blocks config and Python modules) to the save directory."""
+        if self._blocks is None:
+            return
+
+        blocks_module = type(self._blocks).__module__
+        is_custom_code = not blocks_module.startswith("diffusers.") and blocks_module != "diffusers"
+        if not is_custom_code:
+            return
+
+        os.makedirs(save_directory, exist_ok=True)
+        self._blocks.save_pretrained(save_directory)
+
+        source_file = inspect.getfile(type(self._blocks))
+        module_file = os.path.basename(source_file)
+        dest_file = os.path.join(save_directory, module_file)
+        if os.path.abspath(source_file) != os.path.abspath(dest_file):
+            shutil.copyfile(source_file, dest_file)
+
+        from ..utils.dynamic_modules_utils import get_relative_import_files
+
+        for rel_file in get_relative_import_files(source_file):
+            rel_name = os.path.relpath(rel_file, os.path.dirname(source_file))
+            rel_dest = os.path.join(save_directory, rel_name)
+            if os.path.abspath(rel_file) != os.path.abspath(rel_dest):
+                os.makedirs(os.path.dirname(rel_dest), exist_ok=True)
+                shutil.copyfile(rel_file, rel_dest)
+
     def save_pretrained(
         self,
         save_directory: str | os.PathLike,
@@ -1998,6 +2029,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
                 component_spec_dict["subfolder"] = component_name
                 self.register_to_config(**{component_name: (library, class_name, component_spec_dict)})

+        self._maybe_save_custom_code(save_directory)
+
         self.save_config(save_directory=save_directory)

         if push_to_hub:
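
What this hook changes for callers, as a hypothetical sketch: `my_blocks` and `MyCustomBlocks` are made-up names, and `init_pipeline()` is assumed from the modular blocks API. Previously `save_pretrained` wrote only the config; the Python file defining custom blocks had to be shipped separately.

```python
# Hypothetical example: blocks defined outside the `diffusers` package.
# Saving the pipeline now also copies `my_blocks.py` (plus any modules it
# imports relatively) into the save directory next to the config.
from my_blocks import MyCustomBlocks  # made-up module defining custom pipeline blocks

pipeline = MyCustomBlocks().init_pipeline()
pipeline.save_pretrained("./saved-custom-pipeline")
```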

View File

@@ -36,7 +36,7 @@ from typing import Any, Callable
 from packaging import version

-from ..utils import is_torch_available, is_torchao_available, is_torchao_version, logging
+from ..utils import deprecate, is_torch_available, is_torchao_available, is_torchao_version, logging

 if is_torch_available():
@@ -844,6 +844,8 @@ class QuantoConfig(QuantizationConfigMixin):
         modules_to_not_convert: list[str] | None = None,
         **kwargs,
     ):
+        deprecation_message = "`QuantoConfig` is deprecated and will be removed in version 1.0.0."
+        deprecate("QuantoConfig", "1.0.0", deprecation_message)
         self.quant_method = QuantizationMethod.QUANTO
         self.weights_dtype = weights_dtype
         self.modules_to_not_convert = modules_to_not_convert
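
The new behavior is observable at construction time. A small sketch (the warning text is copied from the diff above; diffusers' `deprecate` helper surfaces it as a `FutureWarning`):

```python
import warnings

from diffusers import QuantoConfig

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    QuantoConfig(weights_dtype="int8")

# The deprecation should now be visible among the captured warnings.
assert any("deprecated" in str(w.message) for w in caught)
```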

View File

@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any
 from diffusers.utils.import_utils import is_optimum_quanto_version

 from ...utils import (
+    deprecate,
     get_module_from_name,
     is_accelerate_available,
     is_accelerate_version,
@@ -42,6 +43,9 @@ class QuantoQuantizer(DiffusersQuantizer):
         super().__init__(quantization_config, **kwargs)

     def validate_environment(self, *args, **kwargs):
+        deprecation_message = "The Quanto quantizer is deprecated and will be removed in version 1.0.0."
+        deprecate("QuantoQuantizer", "1.0.0", deprecation_message)
+
         if not is_optimum_quanto_available():
             raise ImportError(
                 "Loading an optimum-quanto quantized model requires optimum-quanto library (`pip install optimum-quanto`)"

View File

@@ -60,12 +60,7 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
         model.eval()

         # Move inputs to device
-        inputs_on_device = {}
-        for key, value in inputs_dict.items():
-            if isinstance(value, torch.Tensor):
-                inputs_on_device[key] = value.to(device)
-            else:
-                inputs_on_device[key] = value
+        inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}

         # Enable context parallelism
         cp_config = ContextParallelConfig(**cp_dict)
@@ -89,6 +84,59 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
             dist.destroy_process_group()


+def _custom_mesh_worker(
+    rank,
+    world_size,
+    master_port,
+    model_class,
+    init_dict,
+    cp_dict,
+    mesh_shape,
+    mesh_dim_names,
+    inputs_dict,
+    return_dict,
+):
+    """Worker function for context parallel testing with a user-provided custom DeviceMesh."""
+    try:
+        os.environ["MASTER_ADDR"] = "localhost"
+        os.environ["MASTER_PORT"] = str(master_port)
+        os.environ["RANK"] = str(rank)
+        os.environ["WORLD_SIZE"] = str(world_size)
+
+        dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
+        torch.cuda.set_device(rank)
+        device = torch.device(f"cuda:{rank}")
+
+        model = model_class(**init_dict)
+        model.to(device)
+        model.eval()
+
+        inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
+
+        # DeviceMesh must be created after init_process_group, inside each worker process.
+        mesh = torch.distributed.device_mesh.init_device_mesh(
+            "cuda", mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names
+        )
+        cp_config = ContextParallelConfig(**cp_dict, mesh=mesh)
+        model.enable_parallelism(config=cp_config)
+
+        with torch.no_grad():
+            output = model(**inputs_on_device, return_dict=False)[0]
+
+        if rank == 0:
+            return_dict["status"] = "success"
+            return_dict["output_shape"] = list(output.shape)
+    except Exception as e:
+        if rank == 0:
+            return_dict["status"] = "error"
+            return_dict["error"] = str(e)
+    finally:
+        if dist.is_initialized():
+            dist.destroy_process_group()
+
+
 @is_context_parallel
 @require_torch_multi_accelerator
 class ContextParallelTesterMixin:
@@ -126,3 +174,48 @@ class ContextParallelTesterMixin:
         assert return_dict.get("status") == "success", (
             f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
         )
+
+    @pytest.mark.parametrize(
+        "cp_type,mesh_shape,mesh_dim_names",
+        [
+            ("ring_degree", (2, 1, 1), ("ring", "ulysses", "fsdp")),
+            ("ulysses_degree", (1, 2, 1), ("ring", "ulysses", "fsdp")),
+        ],
+        ids=["ring-3d-fsdp", "ulysses-3d-fsdp"],
+    )
+    def test_context_parallel_custom_mesh(self, cp_type, mesh_shape, mesh_dim_names):
+        if not torch.distributed.is_available():
+            pytest.skip("torch.distributed is not available.")
+        if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
+            pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
+
+        world_size = 2
+        init_dict = self.get_init_dict()
+        inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in self.get_dummy_inputs().items()}
+        cp_dict = {cp_type: world_size}
+
+        master_port = _find_free_port()
+        manager = mp.Manager()
+        return_dict = manager.dict()
+
+        mp.spawn(
+            _custom_mesh_worker,
+            args=(
+                world_size,
+                master_port,
+                self.model_class,
+                init_dict,
+                cp_dict,
+                mesh_shape,
+                mesh_dim_names,
+                inputs_dict,
+                return_dict,
+            ),
+            nprocs=world_size,
+            join=True,
+        )
+
+        assert return_dict.get("status") == "success", (
+            f"Custom mesh context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
+        )
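
The `fsdp` dimension has size 1 in both parametrized cases above, so FSDP is present in the mesh but unused. For the motivating case named in the config docstring, a hedged sketch of actually sharing one mesh between context parallelism and FSDP (assumes 4 GPUs, a torch version exposing the composable `fully_shard` API, and a model with a CP plan; whether the two wrappings compose end to end depends on the model):

```python
# Sketch: one 3D mesh, "ring" consumed by context parallelism and "fsdp" by FSDP2.
# Launch with: torchrun --nproc_per_node=4 this_script.py
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard  # torch >= 2.6

from diffusers import ContextParallelConfig

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())

mesh = init_device_mesh("cuda", mesh_shape=(2, 1, 2), mesh_dim_names=("ring", "ulysses", "fsdp"))
# model = ...  # a diffusers model exposing `enable_parallelism`
# fully_shard(model, mesh=mesh["fsdp"])  # shard parameters over the "fsdp" dimension
# model.enable_parallelism(config=ContextParallelConfig(ring_degree=2, mesh=mesh))
```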

View File

@@ -1,3 +1,4 @@
+# coding=utf-8
 # Copyright 2025 HuggingFace Inc.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,23 +13,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

 import gc
 import os
+import unittest

-import pytest
 import torch

 from diffusers import ZImageTransformer2DModel
-from diffusers.utils.torch_utils import randn_tensor

-from ...testing_utils import assert_tensors_close, torch_device
-from ..testing_utils import (
-    BaseModelTesterConfig,
-    LoraTesterMixin,
-    MemoryTesterMixin,
-    ModelTesterMixin,
-    TorchCompileTesterMixin,
-    TrainingTesterMixin,
-)
+from ...testing_utils import IS_GITHUB_ACTIONS, torch_device
+from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin


 # Z-Image requires torch.use_deterministic_algorithms(False) due to complex64 RoPE operations
@@ -42,38 +36,44 @@ if hasattr(torch.backends, "cuda"):
     torch.backends.cuda.matmul.allow_tf32 = False


-def _concat_list_output(output):
-    """Model output `sample` is a list of tensors. Concatenate them for comparison."""
-    return torch.cat([t.flatten() for t in output])
-
-
-class ZImageTransformerTesterConfig(BaseModelTesterConfig):
-    @property
-    def model_class(self):
-        return ZImageTransformer2DModel
+@unittest.skipIf(
+    IS_GITHUB_ACTIONS,
+    reason="Skipping test-suite inside the CI because the model has `torch.empty()` inside of it during init and we don't have a clear way to override it in the modeling tests.",
+)
+class ZImageTransformerTests(ModelTesterMixin, unittest.TestCase):
+    model_class = ZImageTransformer2DModel
+    main_input_name = "x"
+
+    # We override the items here because the transformer under consideration is small.
+    model_split_percents = [0.9, 0.9, 0.9]
+
+    def prepare_dummy_input(self, height=16, width=16):
+        batch_size = 1
+        num_channels = 16
+        embedding_dim = 16
+        sequence_length = 16
+
+        hidden_states = [torch.randn((num_channels, 1, height, width)).to(torch_device) for _ in range(batch_size)]
+        encoder_hidden_states = [
+            torch.randn((sequence_length, embedding_dim)).to(torch_device) for _ in range(batch_size)
+        ]
+        timestep = torch.tensor([0.0]).to(torch_device)
+        return {"x": hidden_states, "cap_feats": encoder_hidden_states, "t": timestep}

     @property
-    def output_shape(self) -> tuple[int, ...]:
-        return (4, 32, 32)
+    def dummy_input(self):
+        return self.prepare_dummy_input()

     @property
-    def input_shape(self) -> tuple[int, ...]:
+    def input_shape(self):
         return (4, 32, 32)

     @property
-    def model_split_percents(self) -> list:
-        return [0.9, 0.9, 0.9]
-
-    @property
-    def main_input_name(self) -> str:
-        return "x"
-
-    @property
-    def generator(self):
-        return torch.Generator("cpu").manual_seed(0)
+    def output_shape(self):
+        return (4, 32, 32)

-    def get_init_dict(self):
-        return {
+    def prepare_init_args_and_inputs_for_common(self):
+        init_dict = {
             "all_patch_size": (2,),
             "all_f_patch_size": (1,),
             "in_channels": 16,
@@ -89,223 +89,83 @@ class ZImageTransformerTesterConfig(BaseModelTesterConfig):
             "axes_dims": [8, 4, 4],
             "axes_lens": [256, 32, 32],
         }
+        inputs_dict = self.dummy_input
+        return init_dict, inputs_dict

-    def get_dummy_inputs(self) -> dict[str, torch.Tensor | list]:
-        batch_size = 1
-        num_channels = 16
-        embedding_dim = 16
-        sequence_length = 16
-        height = 16
-        width = 16
-
-        hidden_states = [
-            randn_tensor((num_channels, 1, height, width), generator=self.generator, device=torch_device)
-            for _ in range(batch_size)
-        ]
-        encoder_hidden_states = [
-            randn_tensor((sequence_length, embedding_dim), generator=self.generator, device=torch_device)
-            for _ in range(batch_size)
-        ]
-        timestep = torch.tensor([0.0]).to(torch_device)
-        return {"x": hidden_states, "cap_feats": encoder_hidden_states, "t": timestep}
-
-
-class TestZImageTransformer(ZImageTransformerTesterConfig, ModelTesterMixin):
-    """Core model tests for Z-Image Transformer."""
-
-    @torch.no_grad()
-    def test_determinism(self, atol=1e-5, rtol=0):
-        model = self.model_class(**self.get_init_dict())
-        model.to(torch_device)
-        model.eval()
-
-        inputs_dict = self.get_dummy_inputs()
-        first = _concat_list_output(model(**inputs_dict, return_dict=False)[0])
-        second = _concat_list_output(model(**inputs_dict, return_dict=False)[0])
-
-        mask = ~(torch.isnan(first) | torch.isnan(second))
-        assert_tensors_close(
-            first[mask], second[mask], atol=atol, rtol=rtol, msg="Model outputs are not deterministic"
-        )
-
-    def test_from_save_pretrained(self, tmp_path, atol=5e-5, rtol=5e-5):
-        model = self.model_class(**self.get_init_dict())
-        model.to(torch_device)
-        model.eval()
-
-        model.save_pretrained(tmp_path)
-        new_model = self.model_class.from_pretrained(tmp_path)
-        new_model.to(torch_device)
-
-        for param_name in model.state_dict().keys():
-            param_1 = model.state_dict()[param_name]
-            param_2 = new_model.state_dict()[param_name]
-            assert param_1.shape == param_2.shape
-
-        inputs_dict = self.get_dummy_inputs()
-        image = _concat_list_output(model(**inputs_dict, return_dict=False)[0])
-        new_image = _concat_list_output(new_model(**inputs_dict, return_dict=False)[0])
-        assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")
-
-    @torch.no_grad()
-    def test_from_save_pretrained_variant(self, tmp_path, atol=5e-5, rtol=0):
-        model = self.model_class(**self.get_init_dict())
-        model.to(torch_device)
-        model.eval()
-
-        model.save_pretrained(tmp_path, variant="fp16")
-        new_model = self.model_class.from_pretrained(tmp_path, variant="fp16")
-
-        with pytest.raises(OSError) as exc_info:
-            self.model_class.from_pretrained(tmp_path)
-        assert "Error no file named diffusion_pytorch_model.bin found in directory" in str(exc_info.value)
-
-        new_model.to(torch_device)
-
-        inputs_dict = self.get_dummy_inputs()
-        image = _concat_list_output(model(**inputs_dict, return_dict=False)[0])
-        new_image = _concat_list_output(new_model(**inputs_dict, return_dict=False)[0])
-        assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")
-
-    @pytest.mark.skip("Model output `sample` is a list of tensors, not a single tensor.")
-    def test_outputs_equivalence(self, atol=1e-5, rtol=0):
-        pass
-
-    def test_sharded_checkpoints_with_parallel_loading(self, tmp_path, atol=1e-5, rtol=0):
-        from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME, constants
-
-        from ..testing_utils.common import calculate_expected_num_shards, compute_module_persistent_sizes
-
-        torch.manual_seed(0)
-        if torch.cuda.is_available():
-            torch.cuda.manual_seed_all(0)
-        config = self.get_init_dict()
-        inputs_dict = self.get_dummy_inputs()
-        model = self.model_class(**config).eval()
-        model = model.to(torch_device)
-
-        base_output = _concat_list_output(model(**inputs_dict, return_dict=False)[0])
-        model_size = compute_module_persistent_sizes(model)[""]
-        max_shard_size = int((model_size * 0.75) / (2**10))
-
-        original_parallel_loading = constants.HF_ENABLE_PARALLEL_LOADING
-        original_parallel_workers = getattr(constants, "HF_PARALLEL_WORKERS", None)
-        try:
-            model.cpu().save_pretrained(tmp_path, max_shard_size=f"{max_shard_size}KB")
-            assert os.path.exists(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME))
-
-            expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME))
-            actual_num_shards = len([file for file in os.listdir(tmp_path) if file.endswith(".safetensors")])
-            assert actual_num_shards == expected_num_shards
-
-            constants.HF_ENABLE_PARALLEL_LOADING = False
-            self.model_class.from_pretrained(tmp_path).eval().to(torch_device)
-
-            constants.HF_ENABLE_PARALLEL_LOADING = True
-            constants.DEFAULT_HF_PARALLEL_LOADING_WORKERS = 2
-            torch.manual_seed(0)
-            model_parallel = self.model_class.from_pretrained(tmp_path).eval()
-            model_parallel = model_parallel.to(torch_device)
-
-            output_parallel = _concat_list_output(model_parallel(**inputs_dict, return_dict=False)[0])
-            assert_tensors_close(
-                base_output, output_parallel, atol=atol, rtol=rtol, msg="Output should match with parallel loading"
-            )
-        finally:
-            constants.HF_ENABLE_PARALLEL_LOADING = original_parallel_loading
-            if original_parallel_workers is not None:
-                constants.HF_PARALLEL_WORKERS = original_parallel_workers
-
-
-class TestZImageTransformerMemory(ZImageTransformerTesterConfig, MemoryTesterMixin):
-    """Memory optimization tests for Z-Image Transformer."""
-
-    @pytest.mark.skip(
-        "Ensure `x_pad_token` and `cap_pad_token` are cast to the same dtype as the destination tensor before they are assigned to the padding indices."
-    )
-    def test_layerwise_casting_training(self):
-        pass
-
-
-class TestZImageTransformerTraining(ZImageTransformerTesterConfig, TrainingTesterMixin):
-    """Training tests for Z-Image Transformer."""
-
-    def test_gradient_checkpointing_is_applied(self):
-        super().test_gradient_checkpointing_is_applied(expected_set={"ZImageTransformer2DModel"})
-
-    @pytest.mark.skip("Test is not supported for handling main inputs that are lists.")
-    def test_training(self):
-        pass
-
-    @pytest.mark.skip("Test is not supported for handling main inputs that are lists.")
-    def test_training_with_ema(self):
-        pass
-
-    @pytest.mark.skip("Test is not supported for handling main inputs that are lists.")
-    def test_gradient_checkpointing_equivalence(self, loss_tolerance=1e-5, param_grad_tol=5e-5, skip=None):
-        pass
-
-
-class TestZImageTransformerLoRA(ZImageTransformerTesterConfig, LoraTesterMixin):
-    """LoRA adapter tests for Z-Image Transformer."""
-
-    @pytest.mark.skip("Model output `sample` is a list of tensors, not a single tensor.")
-    def test_save_load_lora_adapter(self, tmp_path, rank=4, lora_alpha=4, use_dora=False, atol=1e-4, rtol=1e-4):
-        pass
-
-
-# TODO: Add pretrained_model_name_or_path once a tiny Z-Image model is available on the Hub
-# class TestZImageTransformerBitsAndBytes(ZImageTransformerTesterConfig, BitsAndBytesTesterMixin):
-#     """BitsAndBytes quantization tests for Z-Image Transformer."""
-
-
-# TODO: Add pretrained_model_name_or_path once a tiny Z-Image model is available on the Hub
-# class TestZImageTransformerTorchAo(ZImageTransformerTesterConfig, TorchAoTesterMixin):
-#     """TorchAo quantization tests for Z-Image Transformer."""
-
-
-class TestZImageTransformerCompile(ZImageTransformerTesterConfig, TorchCompileTesterMixin):
-    """Torch compile tests for Z-Image Transformer."""
-
-    @property
-    def different_shapes_for_compilation(self):
-        return [(4, 4), (4, 8), (8, 8)]
-
-    def get_dummy_inputs(self, height: int = 16, width: int = 16) -> dict[str, torch.Tensor | list]:
-        batch_size = 1
-        num_channels = 16
-        embedding_dim = 16
-        sequence_length = 16
-
-        hidden_states = [
-            randn_tensor((num_channels, 1, height, width), generator=self.generator, device=torch_device)
-            for _ in range(batch_size)
-        ]
-        encoder_hidden_states = [
-            randn_tensor((sequence_length, embedding_dim), generator=self.generator, device=torch_device)
-            for _ in range(batch_size)
-        ]
-        timestep = torch.tensor([0.0]).to(torch_device)
-        return {"x": hidden_states, "cap_feats": encoder_hidden_states, "t": timestep}
-
-    @pytest.mark.skip(
-        "The repeated block in this model is ZImageTransformerBlock, which is used for noise_refiner, context_refiner, and layers. The inputs recorded for the block would vary during compilation and full compilation with fullgraph=True would trigger recompilation at least thrice."
-    )
-    def test_torch_compile_recompilation_and_graph_break(self):
-        pass
-
-    @pytest.mark.skip("Fullgraph AoT is broken")
-    def test_compile_works_with_aot(self, tmp_path):
-        pass
-
-    @pytest.mark.skip("Fullgraph is broken")
-    def test_compile_on_different_shapes(self):
-        pass
+    def setUp(self):
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.synchronize()
+        torch.manual_seed(0)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(0)
+
+    def tearDown(self):
+        super().tearDown()
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.synchronize()
+
+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"ZImageTransformer2DModel"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+    @unittest.skip("Test is not supported for handling main inputs that are lists.")
+    def test_training(self):
+        super().test_training()
+
+    @unittest.skip("Test is not supported for handling main inputs that are lists.")
+    def test_ema_training(self):
+        super().test_ema_training()
+
+    @unittest.skip("Test is not supported for handling main inputs that are lists.")
+    def test_effective_gradient_checkpointing(self):
+        super().test_effective_gradient_checkpointing()
+
+    @unittest.skip(
+        "Test needs to be revisited. But we need to ensure `x_pad_token` and `cap_pad_token` are cast to the same dtype as the destination tensor before they are assigned to the padding indices."
+    )
+    def test_layerwise_casting_training(self):
+        super().test_layerwise_casting_training()
+
+    @unittest.skip("Test is not supported for handling main inputs that are lists.")
+    def test_outputs_equivalence(self):
+        super().test_outputs_equivalence()
+
+    @unittest.skip("Test will pass if we change to deterministic values instead of empty in the DiT.")
+    def test_group_offloading(self):
+        super().test_group_offloading()
+
+    @unittest.skip("Test will pass if we change to deterministic values instead of empty in the DiT.")
+    def test_group_offloading_with_disk(self):
+        super().test_group_offloading_with_disk()
+
+
+class ZImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
+    model_class = ZImageTransformer2DModel
+    different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
+
+    def prepare_init_args_and_inputs_for_common(self):
+        return ZImageTransformerTests().prepare_init_args_and_inputs_for_common()
+
+    def prepare_dummy_input(self, height, width):
+        return ZImageTransformerTests().prepare_dummy_input(height=height, width=width)
+
+    @unittest.skip(
+        "The repeated block in this model is ZImageTransformerBlock, which is used for noise_refiner, context_refiner, and layers. As a consequence of this, the inputs recorded for the block would vary during compilation and full compilation with fullgraph=True would trigger recompilation at least thrice."
+    )
+    def test_torch_compile_recompilation_and_graph_break(self):
+        super().test_torch_compile_recompilation_and_graph_break()
+
+    @unittest.skip("Fullgraph AoT is broken")
+    def test_compile_works_with_aot(self):
+        super().test_compile_works_with_aot()
+
+    @unittest.skip("Fullgraph is broken")
+    def test_compile_on_different_shapes(self):
+        super().test_compile_on_different_shapes()