mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-23 21:04:56 +08:00
Compare commits
1 Commits
torchao-co
...
remove-unn
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
10dfa9b722 |
@@ -27,7 +27,7 @@ from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..activations import get_activation
|
||||
from ..modeling_outputs import AutoencoderKLOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .vae import DecoderOutput, DiagonalGaussianDistribution
|
||||
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -410,7 +410,7 @@ class HunyuanImageDecoder2D(nn.Module):
|
||||
return h
|
||||
|
||||
|
||||
class AutoencoderKLHunyuanImage(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
class AutoencoderKLHunyuanImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
r"""
|
||||
A VAE model for 2D images with spatial tiling support.
|
||||
|
||||
@@ -486,27 +486,6 @@ class AutoencoderKLHunyuanImage(ModelMixin, ConfigMixin, FromOriginalModelMixin)
|
||||
self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor
|
||||
self.tile_latent_min_size = self.tile_sample_min_size // self.config.spatial_compression_ratio
|
||||
|
||||
def disable_tiling(self) -> None:
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_tiling = False
|
||||
|
||||
def enable_slicing(self) -> None:
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.use_slicing = True
|
||||
|
||||
def disable_slicing(self) -> None:
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_slicing = False
|
||||
|
||||
def _encode(self, x: torch.Tensor):
|
||||
|
||||
batch_size, num_channels, height, width = x.shape
|
||||
|
||||
@@ -26,7 +26,7 @@ from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..activations import get_activation
|
||||
from ..modeling_outputs import AutoencoderKLOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .vae import DecoderOutput, DiagonalGaussianDistribution
|
||||
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -584,7 +584,7 @@ class HunyuanImageRefinerDecoder3D(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AutoencoderKLHunyuanImageRefiner(ModelMixin, ConfigMixin):
|
||||
class AutoencoderKLHunyuanImageRefiner(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
r"""
|
||||
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used for
|
||||
HunyuanImage-2.1 Refiner.
|
||||
@@ -685,27 +685,6 @@ class AutoencoderKLHunyuanImageRefiner(ModelMixin, ConfigMixin):
|
||||
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
|
||||
self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor
|
||||
|
||||
def disable_tiling(self) -> None:
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_tiling = False
|
||||
|
||||
def enable_slicing(self) -> None:
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.use_slicing = True
|
||||
|
||||
def disable_slicing(self) -> None:
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_slicing = False
|
||||
|
||||
def _encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
_, _, _, height, width = x.shape
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..activations import get_activation
|
||||
from ..modeling_outputs import AutoencoderKLOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .vae import DecoderOutput, DiagonalGaussianDistribution
|
||||
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -625,7 +625,7 @@ class HunyuanVideo15Decoder3D(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AutoencoderKLHunyuanVideo15(ModelMixin, ConfigMixin):
|
||||
class AutoencoderKLHunyuanVideo15(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
r"""
|
||||
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used for
|
||||
HunyuanVideo-1.5.
|
||||
@@ -723,27 +723,6 @@ class AutoencoderKLHunyuanVideo15(ModelMixin, ConfigMixin):
|
||||
self.tile_latent_min_width = tile_latent_min_width or self.tile_latent_min_width
|
||||
self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor
|
||||
|
||||
def disable_tiling(self) -> None:
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_tiling = False
|
||||
|
||||
def enable_slicing(self) -> None:
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.use_slicing = True
|
||||
|
||||
def disable_slicing(self) -> None:
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_slicing = False
|
||||
|
||||
def _encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
_, _, _, height, width = x.shape
|
||||
|
||||
|
||||
@@ -671,46 +671,44 @@ class TorchAoSerializationTest(unittest.TestCase):
|
||||
class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):
|
||||
@property
|
||||
def quantization_config(self):
|
||||
from torchao.quantization import Int8WeightOnlyConfig
|
||||
|
||||
return PipelineQuantizationConfig(
|
||||
quant_mapping={
|
||||
"transformer": TorchAoConfig(Int8WeightOnlyConfig()),
|
||||
"transformer": TorchAoConfig(quant_type="int8_weight_only"),
|
||||
},
|
||||
)
|
||||
|
||||
# @unittest.skip(
|
||||
# "Changing the device of AQT tensor with module._apply (called from doing module.to() in accelerate) does not work "
|
||||
# "when compiling."
|
||||
# )
|
||||
# def test_torch_compile_with_cpu_offload(self):
|
||||
# # RuntimeError: _apply(): Couldn't swap Linear.weight
|
||||
# super().test_torch_compile_with_cpu_offload()
|
||||
@unittest.skip(
|
||||
"Changing the device of AQT tensor with module._apply (called from doing module.to() in accelerate) does not work "
|
||||
"when compiling."
|
||||
)
|
||||
def test_torch_compile_with_cpu_offload(self):
|
||||
# RuntimeError: _apply(): Couldn't swap Linear.weight
|
||||
super().test_torch_compile_with_cpu_offload()
|
||||
|
||||
# @parameterized.expand([False, True])
|
||||
# @unittest.skip(
|
||||
# """
|
||||
# For `use_stream=False`:
|
||||
# - Changing the device of AQT tensor, with `param.data = param.data.to(device)` as done in group offloading implementation
|
||||
# is unsupported in TorchAO. When compiling, FakeTensor device mismatch causes failure.
|
||||
# For `use_stream=True`:
|
||||
# Using non-default stream requires ability to pin tensors. AQT does not seem to support this yet in TorchAO.
|
||||
# """
|
||||
# )
|
||||
# def test_torch_compile_with_group_offload_leaf(self, use_stream):
|
||||
# # For use_stream=False:
|
||||
# # If we run group offloading without compilation, we will see:
|
||||
# # RuntimeError: Attempted to set the storage of a tensor on device "cpu" to a storage on different device "cuda:0". This is no longer allowed; the devices must match.
|
||||
# # When running with compilation, the error ends up being different:
|
||||
# # Dynamo failed to run FX node with fake tensors: call_function <built-in function linear>(*(FakeTensor(..., device='cuda:0', size=(s0, 256), dtype=torch.bfloat16), AffineQuantizedTensor(tensor_impl=PlainAQTTensorImpl(data=FakeTensor(..., size=(1536, 256), dtype=torch.int8)... , scale=FakeTensor(..., size=(1536,), dtype=torch.bfloat16)... , zero_point=FakeTensor(..., size=(1536,), dtype=torch.int64)... , _layout=PlainLayout()), block_size=(1, 256), shape=torch.Size([1536, 256]), device=cpu, dtype=torch.bfloat16, requires_grad=False), Parameter(FakeTensor(..., device='cuda:0', size=(1536,), dtype=torch.bfloat16,
|
||||
# # requires_grad=True))), **{}): got RuntimeError('Unhandled FakeTensor Device Propagation for aten.mm.default, found two different devices cuda:0, cpu')
|
||||
# # Looks like something that will have to be looked into upstream.
|
||||
# # for linear layers, weight.tensor_impl shows cuda... but:
|
||||
# # weight.tensor_impl.{data,scale,zero_point}.device will be cpu
|
||||
@parameterized.expand([False, True])
|
||||
@unittest.skip(
|
||||
"""
|
||||
For `use_stream=False`:
|
||||
- Changing the device of AQT tensor, with `param.data = param.data.to(device)` as done in group offloading implementation
|
||||
is unsupported in TorchAO. When compiling, FakeTensor device mismatch causes failure.
|
||||
For `use_stream=True`:
|
||||
Using non-default stream requires ability to pin tensors. AQT does not seem to support this yet in TorchAO.
|
||||
"""
|
||||
)
|
||||
def test_torch_compile_with_group_offload_leaf(self, use_stream):
|
||||
# For use_stream=False:
|
||||
# If we run group offloading without compilation, we will see:
|
||||
# RuntimeError: Attempted to set the storage of a tensor on device "cpu" to a storage on different device "cuda:0". This is no longer allowed; the devices must match.
|
||||
# When running with compilation, the error ends up being different:
|
||||
# Dynamo failed to run FX node with fake tensors: call_function <built-in function linear>(*(FakeTensor(..., device='cuda:0', size=(s0, 256), dtype=torch.bfloat16), AffineQuantizedTensor(tensor_impl=PlainAQTTensorImpl(data=FakeTensor(..., size=(1536, 256), dtype=torch.int8)... , scale=FakeTensor(..., size=(1536,), dtype=torch.bfloat16)... , zero_point=FakeTensor(..., size=(1536,), dtype=torch.int64)... , _layout=PlainLayout()), block_size=(1, 256), shape=torch.Size([1536, 256]), device=cpu, dtype=torch.bfloat16, requires_grad=False), Parameter(FakeTensor(..., device='cuda:0', size=(1536,), dtype=torch.bfloat16,
|
||||
# requires_grad=True))), **{}): got RuntimeError('Unhandled FakeTensor Device Propagation for aten.mm.default, found two different devices cuda:0, cpu')
|
||||
# Looks like something that will have to be looked into upstream.
|
||||
# for linear layers, weight.tensor_impl shows cuda... but:
|
||||
# weight.tensor_impl.{data,scale,zero_point}.device will be cpu
|
||||
|
||||
# # For use_stream=True:
|
||||
# # NotImplementedError: AffineQuantizedTensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.is_pinned', overload='default')>, types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), arg_types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), kwarg_types={}
|
||||
# super()._test_torch_compile_with_group_offload_leaf(use_stream=use_stream)
|
||||
# For use_stream=True:
|
||||
# NotImplementedError: AffineQuantizedTensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.is_pinned', overload='default')>, types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), arg_types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), kwarg_types={}
|
||||
super()._test_torch_compile_with_group_offload_leaf(use_stream=use_stream)
|
||||
|
||||
|
||||
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
|
||||
|
||||
Reference in New Issue
Block a user