Compare commits

..

1 Commits

Author SHA1 Message Date
sayakpaul
10dfa9b722 up 2025-12-22 10:26:10 +05:30
4 changed files with 37 additions and 102 deletions

View File

@@ -27,7 +27,7 @@ from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import DecoderOutput, DiagonalGaussianDistribution
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -410,7 +410,7 @@ class HunyuanImageDecoder2D(nn.Module):
return h
class AutoencoderKLHunyuanImage(ModelMixin, ConfigMixin, FromOriginalModelMixin):
class AutoencoderKLHunyuanImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A VAE model for 2D images with spatial tiling support.
@@ -486,27 +486,6 @@ class AutoencoderKLHunyuanImage(ModelMixin, ConfigMixin, FromOriginalModelMixin)
self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor
self.tile_latent_min_size = self.tile_sample_min_size // self.config.spatial_compression_ratio
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def _encode(self, x: torch.Tensor):
batch_size, num_channels, height, width = x.shape

View File

@@ -26,7 +26,7 @@ from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import DecoderOutput, DiagonalGaussianDistribution
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -584,7 +584,7 @@ class HunyuanImageRefinerDecoder3D(nn.Module):
return hidden_states
class AutoencoderKLHunyuanImageRefiner(ModelMixin, ConfigMixin):
class AutoencoderKLHunyuanImageRefiner(ModelMixin, AutoencoderMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used for
HunyuanImage-2.1 Refiner.
@@ -685,27 +685,6 @@ class AutoencoderKLHunyuanImageRefiner(ModelMixin, ConfigMixin):
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def _encode(self, x: torch.Tensor) -> torch.Tensor:
_, _, _, height, width = x.shape

View File

@@ -26,7 +26,7 @@ from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import DecoderOutput, DiagonalGaussianDistribution
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -625,7 +625,7 @@ class HunyuanVideo15Decoder3D(nn.Module):
return hidden_states
class AutoencoderKLHunyuanVideo15(ModelMixin, ConfigMixin):
class AutoencoderKLHunyuanVideo15(ModelMixin, AutoencoderMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used for
HunyuanVideo-1.5.
@@ -723,27 +723,6 @@ class AutoencoderKLHunyuanVideo15(ModelMixin, ConfigMixin):
self.tile_latent_min_width = tile_latent_min_width or self.tile_latent_min_width
self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def _encode(self, x: torch.Tensor) -> torch.Tensor:
_, _, _, height, width = x.shape

View File

@@ -671,46 +671,44 @@ class TorchAoSerializationTest(unittest.TestCase):
class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):
@property
def quantization_config(self):
from torchao.quantization import Int8WeightOnlyConfig
return PipelineQuantizationConfig(
quant_mapping={
"transformer": TorchAoConfig(Int8WeightOnlyConfig()),
"transformer": TorchAoConfig(quant_type="int8_weight_only"),
},
)
# @unittest.skip(
# "Changing the device of AQT tensor with module._apply (called from doing module.to() in accelerate) does not work "
# "when compiling."
# )
# def test_torch_compile_with_cpu_offload(self):
# # RuntimeError: _apply(): Couldn't swap Linear.weight
# super().test_torch_compile_with_cpu_offload()
@unittest.skip(
"Changing the device of AQT tensor with module._apply (called from doing module.to() in accelerate) does not work "
"when compiling."
)
def test_torch_compile_with_cpu_offload(self):
# RuntimeError: _apply(): Couldn't swap Linear.weight
super().test_torch_compile_with_cpu_offload()
# @parameterized.expand([False, True])
# @unittest.skip(
# """
# For `use_stream=False`:
# - Changing the device of AQT tensor, with `param.data = param.data.to(device)` as done in group offloading implementation
# is unsupported in TorchAO. When compiling, FakeTensor device mismatch causes failure.
# For `use_stream=True`:
# Using non-default stream requires ability to pin tensors. AQT does not seem to support this yet in TorchAO.
# """
# )
# def test_torch_compile_with_group_offload_leaf(self, use_stream):
# # For use_stream=False:
# # If we run group offloading without compilation, we will see:
# # RuntimeError: Attempted to set the storage of a tensor on device "cpu" to a storage on different device "cuda:0". This is no longer allowed; the devices must match.
# # When running with compilation, the error ends up being different:
# # Dynamo failed to run FX node with fake tensors: call_function <built-in function linear>(*(FakeTensor(..., device='cuda:0', size=(s0, 256), dtype=torch.bfloat16), AffineQuantizedTensor(tensor_impl=PlainAQTTensorImpl(data=FakeTensor(..., size=(1536, 256), dtype=torch.int8)... , scale=FakeTensor(..., size=(1536,), dtype=torch.bfloat16)... , zero_point=FakeTensor(..., size=(1536,), dtype=torch.int64)... , _layout=PlainLayout()), block_size=(1, 256), shape=torch.Size([1536, 256]), device=cpu, dtype=torch.bfloat16, requires_grad=False), Parameter(FakeTensor(..., device='cuda:0', size=(1536,), dtype=torch.bfloat16,
# # requires_grad=True))), **{}): got RuntimeError('Unhandled FakeTensor Device Propagation for aten.mm.default, found two different devices cuda:0, cpu')
# # Looks like something that will have to be looked into upstream.
# # for linear layers, weight.tensor_impl shows cuda... but:
# # weight.tensor_impl.{data,scale,zero_point}.device will be cpu
@parameterized.expand([False, True])
@unittest.skip(
"""
For `use_stream=False`:
- Changing the device of AQT tensor, with `param.data = param.data.to(device)` as done in group offloading implementation
is unsupported in TorchAO. When compiling, FakeTensor device mismatch causes failure.
For `use_stream=True`:
Using non-default stream requires ability to pin tensors. AQT does not seem to support this yet in TorchAO.
"""
)
def test_torch_compile_with_group_offload_leaf(self, use_stream):
# For use_stream=False:
# If we run group offloading without compilation, we will see:
# RuntimeError: Attempted to set the storage of a tensor on device "cpu" to a storage on different device "cuda:0". This is no longer allowed; the devices must match.
# When running with compilation, the error ends up being different:
# Dynamo failed to run FX node with fake tensors: call_function <built-in function linear>(*(FakeTensor(..., device='cuda:0', size=(s0, 256), dtype=torch.bfloat16), AffineQuantizedTensor(tensor_impl=PlainAQTTensorImpl(data=FakeTensor(..., size=(1536, 256), dtype=torch.int8)... , scale=FakeTensor(..., size=(1536,), dtype=torch.bfloat16)... , zero_point=FakeTensor(..., size=(1536,), dtype=torch.int64)... , _layout=PlainLayout()), block_size=(1, 256), shape=torch.Size([1536, 256]), device=cpu, dtype=torch.bfloat16, requires_grad=False), Parameter(FakeTensor(..., device='cuda:0', size=(1536,), dtype=torch.bfloat16,
# requires_grad=True))), **{}): got RuntimeError('Unhandled FakeTensor Device Propagation for aten.mm.default, found two different devices cuda:0, cpu')
# Looks like something that will have to be looked into upstream.
# for linear layers, weight.tensor_impl shows cuda... but:
# weight.tensor_impl.{data,scale,zero_point}.device will be cpu
# # For use_stream=True:
# # NotImplementedError: AffineQuantizedTensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.is_pinned', overload='default')>, types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), arg_types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), kwarg_types={}
# super()._test_torch_compile_with_group_offload_leaf(use_stream=use_stream)
# For use_stream=True:
# NotImplementedError: AffineQuantizedTensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.is_pinned', overload='default')>, types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), arg_types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), kwarg_types={}
super()._test_torch_compile_with_group_offload_leaf(use_stream=use_stream)
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners