Compare commits

..

1 Commits

Author SHA1 Message Date
sayakpaul
d957cd816d up 2025-12-22 18:39:27 +05:30
5 changed files with 45 additions and 43 deletions

View File

@@ -21,8 +21,8 @@ from transformers import (
BertModel,
BertTokenizer,
CLIPImageProcessor,
MT5Tokenizer,
T5EncoderModel,
T5Tokenizer,
)
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
@@ -260,7 +260,7 @@ class HunyuanDiTDifferentialImg2ImgPipeline(DiffusionPipeline):
The HunyuanDiT model designed by Tencent Hunyuan.
text_encoder_2 (`T5EncoderModel`):
The mT5 embedder. Specifically, it is 't5-v1_1-xxl'.
tokenizer_2 (`T5Tokenizer`):
tokenizer_2 (`MT5Tokenizer`):
The tokenizer for the mT5 embedder.
scheduler ([`DDPMScheduler`]):
A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents.
@@ -295,7 +295,7 @@ class HunyuanDiTDifferentialImg2ImgPipeline(DiffusionPipeline):
feature_extractor: CLIPImageProcessor,
requires_safety_checker: bool = True,
text_encoder_2=T5EncoderModel,
tokenizer_2=T5Tokenizer,
tokenizer_2=MT5Tokenizer,
):
super().__init__()

View File

@@ -17,7 +17,7 @@ from typing import Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from transformers import BertModel, BertTokenizer, CLIPImageProcessor, T5EncoderModel, T5Tokenizer
from transformers import BertModel, BertTokenizer, CLIPImageProcessor, MT5Tokenizer, T5EncoderModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
@@ -185,7 +185,7 @@ class HunyuanDiTControlNetPipeline(DiffusionPipeline):
The HunyuanDiT model designed by Tencent Hunyuan.
text_encoder_2 (`T5EncoderModel`):
The mT5 embedder. Specifically, it is 't5-v1_1-xxl'.
tokenizer_2 (`T5Tokenizer`):
tokenizer_2 (`MT5Tokenizer`):
The tokenizer for the mT5 embedder.
scheduler ([`DDPMScheduler`]):
A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents.
@@ -229,7 +229,7 @@ class HunyuanDiTControlNetPipeline(DiffusionPipeline):
HunyuanDiT2DMultiControlNetModel,
],
text_encoder_2: Optional[T5EncoderModel] = None,
tokenizer_2: Optional[T5Tokenizer] = None,
tokenizer_2: Optional[MT5Tokenizer] = None,
requires_safety_checker: bool = True,
):
super().__init__()

View File

@@ -17,7 +17,7 @@ from typing import Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from transformers import BertModel, BertTokenizer, CLIPImageProcessor, T5EncoderModel, T5Tokenizer
from transformers import BertModel, BertTokenizer, CLIPImageProcessor, MT5Tokenizer, T5EncoderModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
@@ -169,7 +169,7 @@ class HunyuanDiTPipeline(DiffusionPipeline):
The HunyuanDiT model designed by Tencent Hunyuan.
text_encoder_2 (`T5EncoderModel`):
The mT5 embedder. Specifically, it is 't5-v1_1-xxl'.
tokenizer_2 (`T5Tokenizer`):
tokenizer_2 (`MT5Tokenizer`):
The tokenizer for the mT5 embedder.
scheduler ([`DDPMScheduler`]):
A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents.
@@ -204,7 +204,7 @@ class HunyuanDiTPipeline(DiffusionPipeline):
feature_extractor: CLIPImageProcessor,
requires_safety_checker: bool = True,
text_encoder_2: Optional[T5EncoderModel] = None,
tokenizer_2: Optional[T5Tokenizer] = None,
tokenizer_2: Optional[MT5Tokenizer] = None,
):
super().__init__()

View File

@@ -17,7 +17,7 @@ from typing import Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from transformers import BertModel, BertTokenizer, CLIPImageProcessor, T5EncoderModel, T5Tokenizer
from transformers import BertModel, BertTokenizer, CLIPImageProcessor, MT5Tokenizer, T5EncoderModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
@@ -173,7 +173,7 @@ class HunyuanDiTPAGPipeline(DiffusionPipeline, PAGMixin):
The HunyuanDiT model designed by Tencent Hunyuan.
text_encoder_2 (`T5EncoderModel`):
The mT5 embedder. Specifically, it is 't5-v1_1-xxl'.
tokenizer_2 (`T5Tokenizer`):
tokenizer_2 (`MT5Tokenizer`):
The tokenizer for the mT5 embedder.
scheduler ([`DDPMScheduler`]):
A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents.
@@ -208,7 +208,7 @@ class HunyuanDiTPAGPipeline(DiffusionPipeline, PAGMixin):
feature_extractor: Optional[CLIPImageProcessor] = None,
requires_safety_checker: bool = True,
text_encoder_2: Optional[T5EncoderModel] = None,
tokenizer_2: Optional[T5Tokenizer] = None,
tokenizer_2: Optional[MT5Tokenizer] = None,
pag_applied_layers: Union[str, List[str]] = "blocks.1", # "blocks.16.attn1", "blocks.16", "16", 16
):
super().__init__()

View File

@@ -671,44 +671,46 @@ class TorchAoSerializationTest(unittest.TestCase):
class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):
@property
def quantization_config(self):
from torchao.quantization import Int8WeightOnlyConfig
return PipelineQuantizationConfig(
quant_mapping={
"transformer": TorchAoConfig(quant_type="int8_weight_only"),
"transformer": TorchAoConfig(Int8WeightOnlyConfig()),
},
)
@unittest.skip(
"Changing the device of AQT tensor with module._apply (called from doing module.to() in accelerate) does not work "
"when compiling."
)
def test_torch_compile_with_cpu_offload(self):
# RuntimeError: _apply(): Couldn't swap Linear.weight
super().test_torch_compile_with_cpu_offload()
# @unittest.skip(
# "Changing the device of AQT tensor with module._apply (called from doing module.to() in accelerate) does not work "
# "when compiling."
# )
# def test_torch_compile_with_cpu_offload(self):
# # RuntimeError: _apply(): Couldn't swap Linear.weight
# super().test_torch_compile_with_cpu_offload()
@parameterized.expand([False, True])
@unittest.skip(
"""
For `use_stream=False`:
- Changing the device of AQT tensor, with `param.data = param.data.to(device)` as done in group offloading implementation
is unsupported in TorchAO. When compiling, FakeTensor device mismatch causes failure.
For `use_stream=True`:
Using non-default stream requires ability to pin tensors. AQT does not seem to support this yet in TorchAO.
"""
)
def test_torch_compile_with_group_offload_leaf(self, use_stream):
# For use_stream=False:
# If we run group offloading without compilation, we will see:
# RuntimeError: Attempted to set the storage of a tensor on device "cpu" to a storage on different device "cuda:0". This is no longer allowed; the devices must match.
# When running with compilation, the error ends up being different:
# Dynamo failed to run FX node with fake tensors: call_function <built-in function linear>(*(FakeTensor(..., device='cuda:0', size=(s0, 256), dtype=torch.bfloat16), AffineQuantizedTensor(tensor_impl=PlainAQTTensorImpl(data=FakeTensor(..., size=(1536, 256), dtype=torch.int8)... , scale=FakeTensor(..., size=(1536,), dtype=torch.bfloat16)... , zero_point=FakeTensor(..., size=(1536,), dtype=torch.int64)... , _layout=PlainLayout()), block_size=(1, 256), shape=torch.Size([1536, 256]), device=cpu, dtype=torch.bfloat16, requires_grad=False), Parameter(FakeTensor(..., device='cuda:0', size=(1536,), dtype=torch.bfloat16,
# requires_grad=True))), **{}): got RuntimeError('Unhandled FakeTensor Device Propagation for aten.mm.default, found two different devices cuda:0, cpu')
# Looks like something that will have to be looked into upstream.
# for linear layers, weight.tensor_impl shows cuda... but:
# weight.tensor_impl.{data,scale,zero_point}.device will be cpu
# @parameterized.expand([False, True])
# @unittest.skip(
# """
# For `use_stream=False`:
# - Changing the device of AQT tensor, with `param.data = param.data.to(device)` as done in group offloading implementation
# is unsupported in TorchAO. When compiling, FakeTensor device mismatch causes failure.
# For `use_stream=True`:
# Using non-default stream requires ability to pin tensors. AQT does not seem to support this yet in TorchAO.
# """
# )
# def test_torch_compile_with_group_offload_leaf(self, use_stream):
# # For use_stream=False:
# # If we run group offloading without compilation, we will see:
# # RuntimeError: Attempted to set the storage of a tensor on device "cpu" to a storage on different device "cuda:0". This is no longer allowed; the devices must match.
# # When running with compilation, the error ends up being different:
# # Dynamo failed to run FX node with fake tensors: call_function <built-in function linear>(*(FakeTensor(..., device='cuda:0', size=(s0, 256), dtype=torch.bfloat16), AffineQuantizedTensor(tensor_impl=PlainAQTTensorImpl(data=FakeTensor(..., size=(1536, 256), dtype=torch.int8)... , scale=FakeTensor(..., size=(1536,), dtype=torch.bfloat16)... , zero_point=FakeTensor(..., size=(1536,), dtype=torch.int64)... , _layout=PlainLayout()), block_size=(1, 256), shape=torch.Size([1536, 256]), device=cpu, dtype=torch.bfloat16, requires_grad=False), Parameter(FakeTensor(..., device='cuda:0', size=(1536,), dtype=torch.bfloat16,
# # requires_grad=True))), **{}): got RuntimeError('Unhandled FakeTensor Device Propagation for aten.mm.default, found two different devices cuda:0, cpu')
# # Looks like something that will have to be looked into upstream.
# # for linear layers, weight.tensor_impl shows cuda... but:
# # weight.tensor_impl.{data,scale,zero_point}.device will be cpu
# For use_stream=True:
# NotImplementedError: AffineQuantizedTensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.is_pinned', overload='default')>, types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), arg_types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), kwarg_types={}
super()._test_torch_compile_with_group_offload_leaf(use_stream=use_stream)
# # For use_stream=True:
# # NotImplementedError: AffineQuantizedTensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.is_pinned', overload='default')>, types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), arg_types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), kwarg_types={}
# super()._test_torch_compile_with_group_offload_leaf(use_stream=use_stream)
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners