mirror of
https://github.com/huggingface/diffusers.git
synced 2026-03-27 19:07:39 +08:00
Compare commits
20 Commits
glmimage-r
...
unet-model
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bfebba420a | ||
|
|
1fe2125802 | ||
|
|
7298f5be93 | ||
|
|
b757035df6 | ||
|
|
17257b5d68 | ||
|
|
ecfd3b4f99 | ||
|
|
5f8303fe3c | ||
|
|
99de4ceab8 | ||
|
|
c6e6992cdd | ||
|
|
ecbaed793d | ||
|
|
0411da7739 | ||
|
|
ffb254a273 | ||
|
|
ea08148bbd | ||
|
|
3a610814a3 | ||
|
|
ca4a7b0649 | ||
|
|
3371560f1d | ||
|
|
46d44b73d8 | ||
|
|
2b67fb65ef | ||
|
|
0e42a3ff93 | ||
|
|
14439ab793 |
1
.github/workflows/claude_review.yml
vendored
1
.github/workflows/claude_review.yml
vendored
@@ -10,6 +10,7 @@ permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
issues: read
|
||||
id-token: write
|
||||
|
||||
jobs:
|
||||
claude-review:
|
||||
|
||||
@@ -41,16 +41,15 @@ The quantized CogVideoX 5B model below requires ~16GB of VRAM.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import CogVideoXPipeline, AutoModel
|
||||
from diffusers import CogVideoXPipeline, AutoModel, TorchAoConfig
|
||||
from diffusers.quantizers import PipelineQuantizationConfig
|
||||
from diffusers.hooks import apply_group_offloading
|
||||
from diffusers.utils import export_to_video
|
||||
from torchao.quantization import Int8WeightOnlyConfig
|
||||
|
||||
# quantize weights to int8 with torchao
|
||||
pipeline_quant_config = PipelineQuantizationConfig(
|
||||
quant_backend="torchao",
|
||||
quant_kwargs={"quant_type": "int8wo"},
|
||||
components_to_quantize="transformer"
|
||||
quant_mapping={"transformer": TorchAoConfig(Int8WeightOnlyConfig())}
|
||||
)
|
||||
|
||||
# fp8 layerwise weight-casting
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
|
||||
</div>
|
||||
|
||||
LTX-2 is a DiT-based audio-video foundation model designed to generate synchronized video and audio within a single model. It brings together the core building blocks of modern video generation, with open weights and a focus on practical, local execution.
|
||||
[LTX-2](https://hf.co/papers/2601.03233) is a DiT-based foundation model designed to generate synchronized video and audio within a single model. It brings together the core building blocks of modern video generation, with open weights and a focus on practical, local execution.
|
||||
|
||||
You can find all the original LTX-Video checkpoints under the [Lightricks](https://huggingface.co/Lightricks) organization.
|
||||
|
||||
@@ -293,6 +293,7 @@ import torch
|
||||
from diffusers import LTX2ConditionPipeline
|
||||
from diffusers.pipelines.ltx2.pipeline_ltx2_condition import LTX2VideoCondition
|
||||
from diffusers.pipelines.ltx2.export_utils import encode_video
|
||||
from diffusers.pipelines.ltx2.utils import DEFAULT_NEGATIVE_PROMPT
|
||||
from diffusers.utils import load_image, load_video
|
||||
|
||||
device = "cuda"
|
||||
@@ -315,19 +316,6 @@ prompt = (
|
||||
"landscape is characterized by rugged terrain and a river visible in the distance. The scene captures the "
|
||||
"solitude and beauty of a winter drive through a mountainous region."
|
||||
)
|
||||
negative_prompt = (
|
||||
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
|
||||
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
|
||||
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
|
||||
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
|
||||
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
|
||||
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
|
||||
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
|
||||
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
|
||||
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
|
||||
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
|
||||
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||
)
|
||||
|
||||
cond_video = load_video(
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input-vid.mp4"
|
||||
@@ -343,7 +331,7 @@ frame_rate = 24.0
|
||||
video, audio = pipe(
|
||||
conditions=conditions,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
negative_prompt=DEFAULT_NEGATIVE_PROMPT,
|
||||
width=width,
|
||||
height=height,
|
||||
num_frames=121,
|
||||
@@ -366,6 +354,154 @@ encode_video(
|
||||
|
||||
Because the conditioning is done via latent frames, the 8 data space frames corresponding to the specified latent frame for an image condition will tend to be static.
|
||||
|
||||
## Multimodal Guidance
|
||||
|
||||
LTX-2.X pipelines support multimodal guidance. It is composed of three terms, all using a CFG-style update rule:
|
||||
|
||||
1. Classifier-Free Guidance (CFG): standard [CFG](https://huggingface.co/papers/2207.12598) where the perturbed ("weaker") output is generated using the negative prompt.
|
||||
2. Spatio-Temporal Guidance (STG): [STG](https://huggingface.co/papers/2411.18664) moves away from a perturbed output created from short-cutting self-attention operations and substitutes in the attention values instead. The idea is that this creates sharper videos and better spatiotemporal consistency.
|
||||
3. Modality Isolation Guidance: moves away from a perturbed output created from disabling cross-modality (audio-to-video and video-to-audio) cross attention. This guidance is more specific to [LTX-2.X](https://huggingface.co/papers/2601.03233) models, with the idea that this produces better consistency between the generated audio and video.
|
||||
|
||||
These are controlled by the `guidance_scale`, `stg_scale`, and `modality_scale` arguments and can be set separately for video and audio. Additionally, for STG the transformer block indices where self-attention is skipped needs to be specified via the `spatio_temporal_guidance_blocks` argument. The LTX-2.X pipelines also support [guidance rescaling](https://huggingface.co/papers/2305.08891) to help reduce over-exposure, which can be a problem when the guidance scales are set to high values.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import LTX2ImageToVideoPipeline
|
||||
from diffusers.pipelines.ltx2.export_utils import encode_video
|
||||
from diffusers.pipelines.ltx2.utils import DEFAULT_NEGATIVE_PROMPT
|
||||
from diffusers.utils import load_image
|
||||
|
||||
device = "cuda"
|
||||
width = 768
|
||||
height = 512
|
||||
random_seed = 42
|
||||
frame_rate = 24.0
|
||||
generator = torch.Generator(device).manual_seed(random_seed)
|
||||
model_path = "dg845/LTX-2.3-Diffusers"
|
||||
|
||||
pipe = LTX2ImageToVideoPipeline.from_pretrained(model_path, torch_dtype=torch.bfloat16)
|
||||
pipe.enable_sequential_cpu_offload(device=device)
|
||||
pipe.vae.enable_tiling()
|
||||
|
||||
prompt = (
|
||||
"An astronaut hatches from a fragile egg on the surface of the Moon, the shell cracking and peeling apart in "
|
||||
"gentle low-gravity motion. Fine lunar dust lifts and drifts outward with each movement, floating in slow arcs "
|
||||
"before settling back onto the ground. The astronaut pushes free in a deliberate, weightless motion, small "
|
||||
"fragments of the egg tumbling and spinning through the air. In the background, the deep darkness of space subtly "
|
||||
"shifts as stars glide with the camera's movement, emphasizing vast depth and scale. The camera performs a "
|
||||
"smooth, cinematic slow push-in, with natural parallax between the foreground dust, the astronaut, and the "
|
||||
"distant starfield. Ultra-realistic detail, physically accurate low-gravity motion, cinematic lighting, and a "
|
||||
"breath-taking, movie-like shot."
|
||||
)
|
||||
|
||||
image = load_image(
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg",
|
||||
)
|
||||
|
||||
video, audio = pipe(
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
negative_prompt=DEFAULT_NEGATIVE_PROMPT,
|
||||
width=width,
|
||||
height=height,
|
||||
num_frames=121,
|
||||
frame_rate=frame_rate,
|
||||
num_inference_steps=30,
|
||||
guidance_scale=3.0, # Recommended LTX-2.3 guidance parameters
|
||||
stg_scale=1.0, # Note that 0.0 (not 1.0) means that STG is disabled (all other guidance is disabled at 1.0)
|
||||
modality_scale=3.0,
|
||||
guidance_rescale=0.7,
|
||||
audio_guidance_scale=7.0, # Note that a higher CFG guidance scale is recommended for audio
|
||||
audio_stg_scale=1.0,
|
||||
audio_modality_scale=3.0,
|
||||
audio_guidance_rescale=0.7,
|
||||
spatio_temporal_guidance_blocks=[28],
|
||||
use_cross_timestep=True,
|
||||
generator=generator,
|
||||
output_type="np",
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
encode_video(
|
||||
video[0],
|
||||
fps=frame_rate,
|
||||
audio=audio[0].float().cpu(),
|
||||
audio_sample_rate=pipe.vocoder.config.output_sampling_rate,
|
||||
output_path="ltx2_3_i2v_stage_1.mp4",
|
||||
)
|
||||
```
|
||||
|
||||
## Prompt Enhancement
|
||||
|
||||
The LTX-2.X models are sensitive to prompting style. Refer to the [official prompting guide](https://ltx.io/model/model-blog/prompting-guide-for-ltx-2) for recommendations on how to write a good prompt. Using prompt enhancement, where the supplied prompts are enhanced using the pipeline's text encoder (by default a [Gemma 3](https://huggingface.co/google/gemma-3-12b-it-qat-q4_0-unquantized) model) given a system prompt, can also improve sample quality. The optional `processor` pipeline component needs to be present to use prompt enhancement. Enable prompt enhancement by supplying a `system_prompt` argument:
|
||||
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import Gemma3Processor
|
||||
from diffusers import LTX2Pipeline
|
||||
from diffusers.pipelines.ltx2.export_utils import encode_video
|
||||
from diffusers.pipelines.ltx2.utils import DEFAULT_NEGATIVE_PROMPT, T2V_DEFAULT_SYSTEM_PROMPT
|
||||
|
||||
device = "cuda"
|
||||
width = 768
|
||||
height = 512
|
||||
random_seed = 42
|
||||
frame_rate = 24.0
|
||||
generator = torch.Generator(device).manual_seed(random_seed)
|
||||
model_path = "dg845/LTX-2.3-Diffusers"
|
||||
|
||||
pipe = LTX2Pipeline.from_pretrained(model_path, torch_dtype=torch.bfloat16)
|
||||
pipe.enable_model_cpu_offload(device=device)
|
||||
pipe.vae.enable_tiling()
|
||||
if getattr(pipe, "processor", None) is None:
|
||||
processor = Gemma3Processor.from_pretrained("google/gemma-3-12b-it-qat-q4_0-unquantized")
|
||||
pipe.processor = processor
|
||||
|
||||
prompt = (
|
||||
"An astronaut hatches from a fragile egg on the surface of the Moon, the shell cracking and peeling apart in "
|
||||
"gentle low-gravity motion. Fine lunar dust lifts and drifts outward with each movement, floating in slow arcs "
|
||||
"before settling back onto the ground. The astronaut pushes free in a deliberate, weightless motion, small "
|
||||
"fragments of the egg tumbling and spinning through the air. In the background, the deep darkness of space subtly "
|
||||
"shifts as stars glide with the camera's movement, emphasizing vast depth and scale. The camera performs a "
|
||||
"smooth, cinematic slow push-in, with natural parallax between the foreground dust, the astronaut, and the "
|
||||
"distant starfield. Ultra-realistic detail, physically accurate low-gravity motion, cinematic lighting, and a "
|
||||
"breath-taking, movie-like shot."
|
||||
)
|
||||
|
||||
video, audio = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=DEFAULT_NEGATIVE_PROMPT,
|
||||
width=width,
|
||||
height=height,
|
||||
num_frames=121,
|
||||
frame_rate=frame_rate,
|
||||
num_inference_steps=30,
|
||||
guidance_scale=3.0,
|
||||
stg_scale=1.0,
|
||||
modality_scale=3.0,
|
||||
guidance_rescale=0.7,
|
||||
audio_guidance_scale=7.0,
|
||||
audio_stg_scale=1.0,
|
||||
audio_modality_scale=3.0,
|
||||
audio_guidance_rescale=0.7,
|
||||
spatio_temporal_guidance_blocks=[28],
|
||||
use_cross_timestep=True,
|
||||
system_prompt=T2V_DEFAULT_SYSTEM_PROMPT,
|
||||
generator=generator,
|
||||
output_type="np",
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
encode_video(
|
||||
video[0],
|
||||
fps=frame_rate,
|
||||
audio=audio[0].float().cpu(),
|
||||
audio_sample_rate=pipe.vocoder.config.output_sampling_rate,
|
||||
output_path="ltx2_3_t2v_stage_1.mp4",
|
||||
)
|
||||
```
|
||||
|
||||
## LTX2Pipeline
|
||||
|
||||
[[autodoc]] LTX2Pipeline
|
||||
|
||||
@@ -29,24 +29,7 @@ from diffusers import DiffusionPipeline, PipelineQuantizationConfig, TorchAoConf
|
||||
from torchao.quantization import Int8WeightOnlyConfig
|
||||
|
||||
pipeline_quant_config = PipelineQuantizationConfig(
|
||||
quant_mapping={"transformer": TorchAoConfig(Int8WeightOnlyConfig(group_size=128)))}
|
||||
)
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
quantization_config=pipeline_quant_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="cuda"
|
||||
)
|
||||
```
|
||||
|
||||
For simple use cases, you could also provide a string identifier in [`TorchAo`] as shown below.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline, PipelineQuantizationConfig, TorchAoConfig
|
||||
|
||||
pipeline_quant_config = PipelineQuantizationConfig(
|
||||
quant_mapping={"transformer": TorchAoConfig("int8wo")}
|
||||
quant_mapping={"transformer": TorchAoConfig(Int8WeightOnlyConfig(group_size=128, version=2))}
|
||||
)
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
@@ -91,18 +74,15 @@ Weight-only quantization stores the model weights in a specific low-bit data typ
|
||||
|
||||
Dynamic activation quantization stores the model weights in a low-bit dtype, while also quantizing the activations on-the-fly to save additional memory. This lowers the memory requirements from model weights, while also lowering the memory overhead from activation computations. However, this may come at a quality tradeoff at times, so it is recommended to test different models thoroughly.
|
||||
|
||||
The quantization methods supported are as follows:
|
||||
Refer to the [official torchao documentation](https://docs.pytorch.org/ao/stable/index.html) for a better understanding of the available quantization methods. An exhaustive list of configuration options are available [here](https://docs.pytorch.org/ao/main/workflows/inference.html#inference-workflows).
|
||||
|
||||
| **Category** | **Full Function Names** | **Shorthands** |
|
||||
|--------------|-------------------------|----------------|
|
||||
| **Integer quantization** | `int4_weight_only`, `int8_dynamic_activation_int4_weight`, `int8_weight_only`, `int8_dynamic_activation_int8_weight` | `int4wo`, `int4dq`, `int8wo`, `int8dq` |
|
||||
| **Floating point 8-bit quantization** | `float8_weight_only`, `float8_dynamic_activation_float8_weight`, `float8_static_activation_float8_weight` | `float8wo`, `float8wo_e5m2`, `float8wo_e4m3`, `float8dq`, `float8dq_e4m3`, `float8dq_e4m3_tensor`, `float8dq_e4m3_row` |
|
||||
| **Floating point X-bit quantization** | `fpx_weight_only` | `fpX_eAwB` where `X` is the number of bits (1-7), `A` is exponent bits, and `B` is mantissa bits. Constraint: `X == A + B + 1` |
|
||||
| **Unsigned Integer quantization** | `uintx_weight_only` | `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo` |
|
||||
Some example popular quantization configurations are as follows:
|
||||
|
||||
Some quantization methods are aliases (for example, `int8wo` is the commonly used shorthand for `int8_weight_only`). This allows using the quantization methods described in the torchao docs as-is, while also making it convenient to remember their shorthand notations.
|
||||
|
||||
Refer to the [official torchao documentation](https://docs.pytorch.org/ao/stable/index.html) for a better understanding of the available quantization methods and the exhaustive list of configuration options available.
|
||||
| **Category** | **Configuration Classes** |
|
||||
|---|---|
|
||||
| **Integer quantization** | [`Int4WeightOnlyConfig`](https://docs.pytorch.org/ao/stable/api_reference/generated/torchao.quantization.Int4WeightOnlyConfig.html), [`Int8WeightOnlyConfig`](https://docs.pytorch.org/ao/stable/api_reference/generated/torchao.quantization.Int8WeightOnlyConfig.html), [`Int8DynamicActivationInt8WeightConfig`](https://docs.pytorch.org/ao/stable/api_reference/generated/torchao.quantization.Int8DynamicActivationInt8WeightConfig.html) |
|
||||
| **Floating point 8-bit quantization** | [`Float8WeightOnlyConfig`](https://docs.pytorch.org/ao/stable/api_reference/generated/torchao.quantization.Float8WeightOnlyConfig.html), [`Float8DynamicActivationFloat8WeightConfig`](https://docs.pytorch.org/ao/stable/api_reference/generated/torchao.quantization.Float8DynamicActivationFloat8WeightConfig.html) |
|
||||
| **Unsigned integer quantization** | [`IntxWeightOnlyConfig`](https://docs.pytorch.org/ao/stable/api_reference/generated/torchao.quantization.IntxWeightOnlyConfig.html) |
|
||||
|
||||
## Serializing and Deserializing quantized models
|
||||
|
||||
@@ -111,8 +91,9 @@ To serialize a quantized model in a given dtype, first load the model with the d
|
||||
```python
|
||||
import torch
|
||||
from diffusers import AutoModel, TorchAoConfig
|
||||
from torchao.quantization import Int8WeightOnlyConfig
|
||||
|
||||
quantization_config = TorchAoConfig("int8wo")
|
||||
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
|
||||
transformer = AutoModel.from_pretrained(
|
||||
"black-forest-labs/Flux.1-Dev",
|
||||
subfolder="transformer",
|
||||
@@ -137,18 +118,19 @@ image = pipe(prompt, num_inference_steps=30, guidance_scale=7.0).images[0]
|
||||
image.save("output.png")
|
||||
```
|
||||
|
||||
If you are using `torch<=2.6.0`, some quantization methods, such as `uint4wo`, cannot be loaded directly and may result in an `UnpicklingError` when trying to load the models, but work as expected when saving them. In order to work around this, one can load the state dict manually into the model. Note, however, that this requires using `weights_only=False` in `torch.load`, so it should be run only if the weights were obtained from a trustable source.
|
||||
If you are using `torch<=2.6.0`, some quantization methods, such as `uint4` weight-only, cannot be loaded directly and may result in an `UnpicklingError` when trying to load the models, but work as expected when saving them. In order to work around this, one can load the state dict manually into the model. Note, however, that this requires using `weights_only=False` in `torch.load`, so it should be run only if the weights were obtained from a trustable source.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from diffusers import FluxPipeline, AutoModel, TorchAoConfig
|
||||
from torchao.quantization import IntxWeightOnlyConfig
|
||||
|
||||
# Serialize the model
|
||||
transformer = AutoModel.from_pretrained(
|
||||
"black-forest-labs/Flux.1-Dev",
|
||||
subfolder="transformer",
|
||||
quantization_config=TorchAoConfig("uint4wo"),
|
||||
quantization_config=TorchAoConfig(IntxWeightOnlyConfig(dtype=torch.uint4)),
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
transformer.save_pretrained("/path/to/flux_uint4wo", safe_serialization=False, max_shard_size="50GB")
|
||||
|
||||
@@ -1,6 +1,155 @@
|
||||
# Copyright 2026 Lightricks and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Pre-trained sigma values for distilled model are taken from
|
||||
# https://github.com/Lightricks/LTX-2/blob/main/packages/ltx-pipelines/src/ltx_pipelines/utils/constants.py
|
||||
DISTILLED_SIGMA_VALUES = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875]
|
||||
|
||||
# Reduced schedule for super-resolution stage 2 (subset of distilled values)
|
||||
STAGE_2_DISTILLED_SIGMA_VALUES = [0.909375, 0.725, 0.421875]
|
||||
|
||||
|
||||
# Default negative prompt from
|
||||
# https://github.com/Lightricks/LTX-2/blob/ae855f8538843825f9015a419cf4ba5edaf5eec2/packages/ltx-pipelines/src/ltx_pipelines/utils/constants.py#L131-L143
|
||||
DEFAULT_NEGATIVE_PROMPT = (
|
||||
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
|
||||
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
|
||||
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
|
||||
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
|
||||
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
|
||||
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
|
||||
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
|
||||
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
|
||||
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
|
||||
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
|
||||
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||
)
|
||||
|
||||
|
||||
# System prompts for prompt enhancement
|
||||
# https://github.com/Lightricks/LTX-2/blob/ae855f8538843825f9015a419cf4ba5edaf5eec2/packages/ltx-core/src/ltx_core/text_encoders/gemma/encoders/prompts/gemma_t2v_system_prompt.txt#L1
|
||||
# Disable line-too-long rule in ruff to keep the prompts exactly the same (e.g. in terms of newlines)
|
||||
# Supported in ruff>=0.15.0
|
||||
# ruff: disable[E501]
|
||||
T2V_DEFAULT_SYSTEM_PROMPT = """
|
||||
You are a Creative Assistant. Given a user's raw input prompt describing a scene or concept, expand it into a detailed
|
||||
video generation prompt with specific visuals and integrated audio to guide a text-to-video model.
|
||||
|
||||
#### Guidelines
|
||||
- Strictly follow all aspects of the user's raw input: include every element requested (style, visuals, motions,
|
||||
actions, camera movement, audio).
|
||||
- If the input is vague, invent concrete details: lighting, textures, materials, scene settings, etc.
|
||||
- For characters: describe gender, clothing, hair, expressions. DO NOT invent unrequested characters.
|
||||
- Use active language: present-progressive verbs ("is walking," "speaking"). If no action specified, describe natural
|
||||
movements.
|
||||
- Maintain chronological flow: use temporal connectors ("as," "then," "while").
|
||||
- Audio layer: Describe complete soundscape (background audio, ambient sounds, SFX, speech/music when requested).
|
||||
Integrate sounds chronologically alongside actions. Be specific (e.g., "soft footsteps on tile"), not vague (e.g.,
|
||||
"ambient sound is present").
|
||||
- Speech (only when requested):
|
||||
- For ANY speech-related input (talking, conversation, singing, etc.), ALWAYS include exact words in quotes with
|
||||
voice characteristics (e.g., "The man says in an excited voice: 'You won't believe what I just saw!'").
|
||||
- Specify language if not English and accent if relevant.
|
||||
- Style: Include visual style at the beginning: "Style: <style>, <rest of prompt>." Default to cinematic-realistic if
|
||||
unspecified. Omit if unclear.
|
||||
- Visual and audio only: NO non-visual/auditory senses (smell, taste, touch).
|
||||
- Restrained language: Avoid dramatic/exaggerated terms. Use mild, natural phrasing.
|
||||
- Colors: Use plain terms ("red dress"), not intensified ("vibrant blue," "bright red").
|
||||
- Lighting: Use neutral descriptions ("soft overhead light"), not harsh ("blinding light").
|
||||
- Facial features: Use delicate modifiers for subtle features (i.e., "subtle freckles").
|
||||
|
||||
#### Important notes:
|
||||
- Analyze the user's raw input carefully. In cases of FPV or POV, exclude the description of the subject whose POV is
|
||||
requested.
|
||||
- Camera motion: DO NOT invent camera motion unless requested by the user.
|
||||
- Speech: DO NOT modify user-provided character dialogue unless it's a typo.
|
||||
- No timestamps or cuts: DO NOT use timestamps or describe scene cuts unless explicitly requested.
|
||||
- Format: DO NOT use phrases like "The scene opens with...". Start directly with Style (optional) and chronological
|
||||
scene description.
|
||||
- Format: DO NOT start your response with special characters.
|
||||
- DO NOT invent dialogue unless the user mentions speech/talking/singing/conversation.
|
||||
- If the user's raw input prompt is highly detailed, chronological and in the requested format: DO NOT make major edits
|
||||
or introduce new elements. Add/enhance audio descriptions if missing.
|
||||
|
||||
#### Output Format (Strict):
|
||||
- Single continuous paragraph in natural language (English).
|
||||
- NO titles, headings, prefaces, code fences, or Markdown.
|
||||
- If unsafe/invalid, return original user prompt. Never ask questions or clarifications.
|
||||
|
||||
Your output quality is CRITICAL. Generate visually rich, dynamic prompts with integrated audio for high-quality video
|
||||
generation.
|
||||
|
||||
#### Example Input: "A woman at a coffee shop talking on the phone" Output: Style: realistic with cinematic lighting.
|
||||
In a medium close-up, a woman in her early 30s with shoulder-length brown hair sits at a small wooden table by the
|
||||
window. She wears a cream-colored turtleneck sweater, holding a white ceramic coffee cup in one hand and a smartphone
|
||||
to her ear with the other. Ambient cafe sounds fill the space—espresso machine hiss, quiet conversations, gentle
|
||||
clinking of cups. The woman listens intently, nodding slightly, then takes a sip of her coffee and sets it down with a
|
||||
soft clink. Her face brightens into a warm smile as she speaks in a clear, friendly voice, 'That sounds perfect! I'd
|
||||
love to meet up this weekend. How about Saturday afternoon?' She laughs softly—a genuine chuckle—and shifts in her
|
||||
chair. Behind her, other patrons move subtly in and out of focus. 'Great, I'll see you then,' she concludes cheerfully,
|
||||
lowering the phone.
|
||||
"""
|
||||
# ruff: enable[E501]
|
||||
|
||||
# ruff: disable[E501]
|
||||
I2V_DEFAULT_SYSTEM_PROMPT = """
|
||||
You are a Creative Assistant writing concise, action-focused image-to-video prompts. Given an image (first frame) and
|
||||
user Raw Input Prompt, generate a prompt to guide video generation from that image.
|
||||
|
||||
#### Guidelines:
|
||||
- Analyze the Image: Identify Subject, Setting, Elements, Style and Mood.
|
||||
- Follow user Raw Input Prompt: Include all requested motion, actions, camera movements, audio, and details. If in
|
||||
conflict with the image, prioritize user request while maintaining visual consistency (describe transition from image
|
||||
to user's scene).
|
||||
- Describe only changes from the image: Don't reiterate established visual details. Inaccurate descriptions may cause
|
||||
scene cuts.
|
||||
- Active language: Use present-progressive verbs ("is walking," "speaking"). If no action specified, describe natural
|
||||
movements.
|
||||
- Chronological flow: Use temporal connectors ("as," "then," "while").
|
||||
- Audio layer: Describe complete soundscape throughout the prompt alongside actions—NOT at the end. Align audio
|
||||
intensity with action tempo. Include natural background audio, ambient sounds, effects, speech or music (when
|
||||
requested). Be specific (e.g., "soft footsteps on tile") not vague (e.g., "ambient sound").
|
||||
- Speech (only when requested): Provide exact words in quotes with character's visual/voice characteristics (e.g., "The
|
||||
tall man speaks in a low, gravelly voice"), language if not English and accent if relevant. If general conversation
|
||||
mentioned without text, generate contextual quoted dialogue. (i.e., "The man is talking" input -> the output should
|
||||
include exact spoken words, like: "The man is talking in an excited voice saying: 'You won't believe what I just
|
||||
saw!' His hands gesture expressively as he speaks, eyebrows raised with enthusiasm. The ambient sound of a quiet room
|
||||
underscores his animated speech.")
|
||||
- Style: Include visual style at beginning: "Style: <style>, <rest of prompt>." If unclear, omit to avoid conflicts.
|
||||
- Visual and audio only: Describe only what is seen and heard. NO smell, taste, or tactile sensations.
|
||||
- Restrained language: Avoid dramatic terms. Use mild, natural, understated phrasing.
|
||||
|
||||
#### Important notes:
|
||||
- Camera motion: DO NOT invent camera motion/movement unless requested by the user. Make sure to include camera motion
|
||||
only if specified in the input.
|
||||
- Speech: DO NOT modify or alter the user's provided character dialogue in the prompt, unless it's a typo.
|
||||
- No timestamps or cuts: DO NOT use timestamps or describe scene cuts unless explicitly requested.
|
||||
- Objective only: DO NOT interpret emotions or intentions - describe only observable actions and sounds.
|
||||
- Format: DO NOT use phrases like "The scene opens with..." / "The video starts...". Start directly with Style
|
||||
(optional) and chronological scene description.
|
||||
- Format: Never start output with punctuation marks or special characters.
|
||||
- DO NOT invent dialogue unless the user mentions speech/talking/singing/conversation.
|
||||
- Your performance is CRITICAL. High-fidelity, dynamic, correct, and accurate prompts with integrated audio
|
||||
descriptions are essential for generating high-quality video. Your goal is flawless execution of these rules.
|
||||
|
||||
#### Output Format (Strict):
|
||||
- Single concise paragraph in natural English. NO titles, headings, prefaces, sections, code fences, or Markdown.
|
||||
- If unsafe/invalid, return original user prompt. Never ask questions or clarifications.
|
||||
|
||||
#### Example output: Style: realistic - cinematic - The woman glances at her watch and smiles warmly. She speaks in a
|
||||
cheerful, friendly voice, "I think we're right on time!" In the background, a café barista prepares drinks at the
|
||||
counter. The barista calls out in a clear, upbeat tone, "Two cappuccinos ready!" The sound of the espresso machine
|
||||
hissing softly blends with gentle background chatter and the light clinking of cups on saucers.
|
||||
"""
|
||||
# ruff: enable[E501]
|
||||
|
||||
@@ -23,20 +23,17 @@ https://github.com/huggingface/transformers/blob/52cb4034ada381fe1ffe8d428a1076e
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import dataclasses
|
||||
import importlib.metadata
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
from dataclasses import dataclass, is_dataclass
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from typing import Any, Callable
|
||||
|
||||
from packaging import version
|
||||
|
||||
from ..utils import deprecate, is_torch_available, is_torchao_available, is_torchao_version, logging
|
||||
from ..utils import deprecate, is_torch_available, is_torchao_version, logging
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -53,16 +50,6 @@ class QuantizationMethod(str, Enum):
|
||||
MODELOPT = "modelopt"
|
||||
|
||||
|
||||
if is_torchao_available():
|
||||
from torchao.quantization.quant_primitives import MappingType
|
||||
|
||||
class TorchAoJSONEncoder(json.JSONEncoder):
|
||||
def default(self, obj):
|
||||
if isinstance(obj, MappingType):
|
||||
return obj.name
|
||||
return super().default(obj)
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuantizationConfigMixin:
|
||||
"""
|
||||
@@ -446,49 +433,21 @@ class TorchAoConfig(QuantizationConfigMixin):
|
||||
"""This is a config class for torchao quantization/sparsity techniques.
|
||||
|
||||
Args:
|
||||
quant_type (`str` | AOBaseConfig):
|
||||
The type of quantization we want to use, currently supporting:
|
||||
- **Integer quantization:**
|
||||
- Full function names: `int4_weight_only`, `int8_dynamic_activation_int4_weight`,
|
||||
`int8_weight_only`, `int8_dynamic_activation_int8_weight`
|
||||
- Shorthands: `int4wo`, `int4dq`, `int8wo`, `int8dq`
|
||||
|
||||
- **Floating point 8-bit quantization:**
|
||||
- Full function names: `float8_weight_only`, `float8_dynamic_activation_float8_weight`,
|
||||
`float8_static_activation_float8_weight`
|
||||
- Shorthands: `float8wo`, `float8wo_e5m2`, `float8wo_e4m3`, `float8dq`, `float8dq_e4m3`,
|
||||
`float8_e4m3_tensor`, `float8_e4m3_row`,
|
||||
|
||||
- **Floating point X-bit quantization:** (in torchao <= 0.14.1, not supported in torchao >= 0.15.0)
|
||||
- Full function names: `fpx_weight_only`
|
||||
- Shorthands: `fpX_eAwB`, where `X` is the number of bits (between `1` to `7`), `A` is the number
|
||||
of exponent bits and `B` is the number of mantissa bits. The constraint of `X == A + B + 1` must
|
||||
be satisfied for a given shorthand notation.
|
||||
|
||||
- **Unsigned Integer quantization:**
|
||||
- Full function names: `uintx_weight_only`
|
||||
- Shorthands: `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo`
|
||||
- An AOBaseConfig instance: for more advanced configuration options.
|
||||
quant_type (`AOBaseConfig`):
|
||||
An `AOBaseConfig` subclass instance specifying the quantization type. See the [torchao
|
||||
documentation](https://docs.pytorch.org/ao/main/api_ref_quantization.html#inference-apis-for-quantize) for
|
||||
available config classes (e.g. `Int4WeightOnlyConfig`, `Int8WeightOnlyConfig`, `Float8WeightOnlyConfig`,
|
||||
`Float8DynamicActivationFloat8WeightConfig`, etc.).
|
||||
modules_to_not_convert (`list[str]`, *optional*, default to `None`):
|
||||
The list of modules to not quantize, useful for quantizing models that explicitly require to have some
|
||||
modules left in their original precision.
|
||||
kwargs (`dict[str, Any]`, *optional*):
|
||||
The keyword arguments for the chosen type of quantization, for example, int4_weight_only quantization
|
||||
supports two keyword arguments `group_size` and `inner_k_tiles` currently. More API examples and
|
||||
documentation of arguments can be found in
|
||||
https://github.com/pytorch/ao/tree/main/torchao/quantization#other-available-quantization-techniques
|
||||
|
||||
Example:
|
||||
```python
|
||||
from diffusers import FluxTransformer2DModel, TorchAoConfig
|
||||
|
||||
# AOBaseConfig-based configuration
|
||||
from torchao.quantization import Int8WeightOnlyConfig
|
||||
|
||||
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
|
||||
|
||||
# String-based config
|
||||
quantization_config = TorchAoConfig("int8wo")
|
||||
transformer = FluxTransformer2DModel.from_pretrained(
|
||||
"black-forest-labs/Flux.1-Dev",
|
||||
subfolder="transformer",
|
||||
@@ -500,7 +459,7 @@ class TorchAoConfig(QuantizationConfigMixin):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quant_type: str | "AOBaseConfig", # noqa: F821
|
||||
quant_type: "AOBaseConfig", # noqa: F821
|
||||
modules_to_not_convert: list[str] | None = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
@@ -508,89 +467,28 @@ class TorchAoConfig(QuantizationConfigMixin):
|
||||
self.quant_type = quant_type
|
||||
self.modules_to_not_convert = modules_to_not_convert
|
||||
|
||||
# When we load from serialized config, "quant_type_kwargs" will be the key
|
||||
if "quant_type_kwargs" in kwargs:
|
||||
self.quant_type_kwargs = kwargs["quant_type_kwargs"]
|
||||
else:
|
||||
self.quant_type_kwargs = kwargs
|
||||
|
||||
self.post_init()
|
||||
|
||||
def post_init(self):
|
||||
if not isinstance(self.quant_type, str):
|
||||
if is_torchao_version("<=", "0.9.0"):
|
||||
raise ValueError(
|
||||
f"torchao <= 0.9.0 only supports string quant_type, got {type(self.quant_type).__name__}. "
|
||||
f"Upgrade to torchao > 0.9.0 to use AOBaseConfig."
|
||||
)
|
||||
if is_torchao_version("<=", "0.9.0"):
|
||||
raise ValueError("TorchAoConfig requires torchao > 0.9.0. Please upgrade with `pip install -U torchao`.")
|
||||
|
||||
from torchao.quantization.quant_api import AOBaseConfig
|
||||
from torchao.quantization.quant_api import AOBaseConfig
|
||||
|
||||
if not isinstance(self.quant_type, AOBaseConfig):
|
||||
raise TypeError(f"quant_type must be a AOBaseConfig instance, got {type(self.quant_type).__name__}")
|
||||
|
||||
elif isinstance(self.quant_type, str):
|
||||
TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
|
||||
|
||||
if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys():
|
||||
is_floatx_quant_type = self.quant_type.startswith("fp")
|
||||
is_float_quant_type = self.quant_type.startswith("float") or is_floatx_quant_type
|
||||
if is_float_quant_type and not self._is_xpu_or_cuda_capability_atleast_8_9():
|
||||
raise ValueError(
|
||||
f"Requested quantization type: {self.quant_type} is not supported on GPUs with CUDA capability <= 8.9. You "
|
||||
f"can check the CUDA capability of your GPU using `torch.cuda.get_device_capability()`."
|
||||
)
|
||||
elif is_floatx_quant_type and not is_torchao_version("<=", "0.14.1"):
|
||||
raise ValueError(
|
||||
f"Requested quantization type: {self.quant_type} is only supported in torchao <= 0.14.1. "
|
||||
f"Please downgrade to torchao <= 0.14.1 to use this quantization type."
|
||||
)
|
||||
|
||||
raise ValueError(
|
||||
f"Requested quantization type: {self.quant_type} is not supported or is an incorrect `quant_type` name. If you think the "
|
||||
f"provided quantization type should be supported, please open an issue at https://github.com/huggingface/diffusers/issues."
|
||||
)
|
||||
|
||||
method = TORCHAO_QUANT_TYPE_METHODS[self.quant_type]
|
||||
signature = inspect.signature(method)
|
||||
all_kwargs = {
|
||||
param.name
|
||||
for param in signature.parameters.values()
|
||||
if param.kind in [inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD]
|
||||
}
|
||||
unsupported_kwargs = list(self.quant_type_kwargs.keys() - all_kwargs)
|
||||
|
||||
if len(unsupported_kwargs) > 0:
|
||||
raise ValueError(
|
||||
f'The quantization method "{self.quant_type}" does not support the following keyword arguments: '
|
||||
f"{unsupported_kwargs}. The following keywords arguments are supported: {all_kwargs}."
|
||||
)
|
||||
if not isinstance(self.quant_type, AOBaseConfig):
|
||||
raise TypeError(f"quant_type must be an AOBaseConfig instance, got {type(self.quant_type).__name__}")
|
||||
|
||||
def to_dict(self):
|
||||
"""Convert configuration to a dictionary."""
|
||||
d = super().to_dict()
|
||||
|
||||
if isinstance(self.quant_type, str):
|
||||
# Handle layout serialization if present
|
||||
if "quant_type_kwargs" in d and "layout" in d["quant_type_kwargs"]:
|
||||
if is_dataclass(d["quant_type_kwargs"]["layout"]):
|
||||
d["quant_type_kwargs"]["layout"] = [
|
||||
d["quant_type_kwargs"]["layout"].__class__.__name__,
|
||||
dataclasses.asdict(d["quant_type_kwargs"]["layout"]),
|
||||
]
|
||||
if isinstance(d["quant_type_kwargs"]["layout"], list):
|
||||
assert len(d["quant_type_kwargs"]["layout"]) == 2, "layout saves layout name and layout kwargs"
|
||||
assert isinstance(d["quant_type_kwargs"]["layout"][0], str), "layout name must be a string"
|
||||
assert isinstance(d["quant_type_kwargs"]["layout"][1], dict), "layout kwargs must be a dict"
|
||||
else:
|
||||
raise ValueError("layout must be a list")
|
||||
else:
|
||||
# Handle AOBaseConfig serialization
|
||||
from torchao.core.config import config_to_dict
|
||||
# Handle AOBaseConfig serialization
|
||||
from torchao.core.config import config_to_dict
|
||||
|
||||
# For now we assume there is 1 config per Transformer, however in the future
|
||||
# We may want to support a config per fqn.
|
||||
d["quant_type"] = {"default": config_to_dict(self.quant_type)}
|
||||
# For now we assume there is 1 config per Transformer, however in the future
|
||||
# we may want to support a config per fqn.
|
||||
# See: https://docs.pytorch.org/ao/stable/api_reference/generated/torchao.quantization.quantize_.html
|
||||
d["quant_type"] = {"default": config_to_dict(self.quant_type)}
|
||||
|
||||
return d
|
||||
|
||||
@@ -602,8 +500,6 @@ class TorchAoConfig(QuantizationConfigMixin):
|
||||
config_dict = config_dict.copy()
|
||||
quant_type = config_dict.pop("quant_type")
|
||||
|
||||
if isinstance(quant_type, str):
|
||||
return cls(quant_type=quant_type, **config_dict)
|
||||
# Check if we only have one key which is "default"
|
||||
# In the future we may update this
|
||||
assert len(quant_type) == 1 and "default" in quant_type, (
|
||||
@@ -618,210 +514,13 @@ class TorchAoConfig(QuantizationConfigMixin):
|
||||
|
||||
return cls(quant_type=quant_type, **config_dict)
|
||||
|
||||
@classmethod
|
||||
def _get_torchao_quant_type_to_method(cls):
|
||||
r"""
|
||||
Returns supported torchao quantization types with all commonly used notations.
|
||||
"""
|
||||
|
||||
if is_torchao_available():
|
||||
# TODO(aryan): Support sparsify
|
||||
from torchao.quantization import (
|
||||
float8_dynamic_activation_float8_weight,
|
||||
float8_static_activation_float8_weight,
|
||||
float8_weight_only,
|
||||
int4_weight_only,
|
||||
int8_dynamic_activation_int4_weight,
|
||||
int8_dynamic_activation_int8_weight,
|
||||
int8_weight_only,
|
||||
uintx_weight_only,
|
||||
)
|
||||
|
||||
if is_torchao_version("<=", "0.14.1"):
|
||||
from torchao.quantization import fpx_weight_only
|
||||
# TODO(aryan): Add a note on how to use PerAxis and PerGroup observers
|
||||
from torchao.quantization.observer import PerRow, PerTensor
|
||||
|
||||
def generate_float8dq_types(dtype: torch.dtype):
|
||||
name = "e5m2" if dtype == torch.float8_e5m2 else "e4m3"
|
||||
types = {}
|
||||
|
||||
for granularity_cls in [PerTensor, PerRow]:
|
||||
# Note: Activation and Weights cannot have different granularities
|
||||
granularity_name = "tensor" if granularity_cls is PerTensor else "row"
|
||||
types[f"float8dq_{name}_{granularity_name}"] = partial(
|
||||
float8_dynamic_activation_float8_weight,
|
||||
activation_dtype=dtype,
|
||||
weight_dtype=dtype,
|
||||
granularity=(granularity_cls(), granularity_cls()),
|
||||
)
|
||||
|
||||
return types
|
||||
|
||||
def generate_fpx_quantization_types(bits: int):
|
||||
if is_torchao_version("<=", "0.14.1"):
|
||||
types = {}
|
||||
|
||||
for ebits in range(1, bits):
|
||||
mbits = bits - ebits - 1
|
||||
types[f"fp{bits}_e{ebits}m{mbits}"] = partial(fpx_weight_only, ebits=ebits, mbits=mbits)
|
||||
|
||||
non_sign_bits = bits - 1
|
||||
default_ebits = (non_sign_bits + 1) // 2
|
||||
default_mbits = non_sign_bits - default_ebits
|
||||
types[f"fp{bits}"] = partial(fpx_weight_only, ebits=default_ebits, mbits=default_mbits)
|
||||
|
||||
return types
|
||||
else:
|
||||
raise ValueError("Floating point X-bit quantization is not supported in torchao >= 0.15.0")
|
||||
|
||||
INT4_QUANTIZATION_TYPES = {
|
||||
# int4 weight + bfloat16/float16 activation
|
||||
"int4wo": int4_weight_only,
|
||||
"int4_weight_only": int4_weight_only,
|
||||
# int4 weight + int8 activation
|
||||
"int4dq": int8_dynamic_activation_int4_weight,
|
||||
"int8_dynamic_activation_int4_weight": int8_dynamic_activation_int4_weight,
|
||||
}
|
||||
|
||||
INT8_QUANTIZATION_TYPES = {
|
||||
# int8 weight + bfloat16/float16 activation
|
||||
"int8wo": int8_weight_only,
|
||||
"int8_weight_only": int8_weight_only,
|
||||
# int8 weight + int8 activation
|
||||
"int8dq": int8_dynamic_activation_int8_weight,
|
||||
"int8_dynamic_activation_int8_weight": int8_dynamic_activation_int8_weight,
|
||||
}
|
||||
|
||||
# TODO(aryan): handle torch 2.2/2.3
|
||||
FLOATX_QUANTIZATION_TYPES = {
|
||||
# float8_e5m2 weight + bfloat16/float16 activation
|
||||
"float8wo": partial(float8_weight_only, weight_dtype=torch.float8_e5m2),
|
||||
"float8_weight_only": float8_weight_only,
|
||||
"float8wo_e5m2": partial(float8_weight_only, weight_dtype=torch.float8_e5m2),
|
||||
# float8_e4m3 weight + bfloat16/float16 activation
|
||||
"float8wo_e4m3": partial(float8_weight_only, weight_dtype=torch.float8_e4m3fn),
|
||||
# float8_e5m2 weight + float8 activation (dynamic)
|
||||
"float8dq": float8_dynamic_activation_float8_weight,
|
||||
"float8_dynamic_activation_float8_weight": float8_dynamic_activation_float8_weight,
|
||||
# ===== Matrix multiplication is not supported in float8_e5m2 so the following errors out.
|
||||
# However, changing activation_dtype=torch.float8_e4m3 might work here =====
|
||||
# "float8dq_e5m2": partial(
|
||||
# float8_dynamic_activation_float8_weight,
|
||||
# activation_dtype=torch.float8_e5m2,
|
||||
# weight_dtype=torch.float8_e5m2,
|
||||
# ),
|
||||
# **generate_float8dq_types(torch.float8_e5m2),
|
||||
# ===== =====
|
||||
# float8_e4m3 weight + float8 activation (dynamic)
|
||||
"float8dq_e4m3": partial(
|
||||
float8_dynamic_activation_float8_weight,
|
||||
activation_dtype=torch.float8_e4m3fn,
|
||||
weight_dtype=torch.float8_e4m3fn,
|
||||
),
|
||||
**generate_float8dq_types(torch.float8_e4m3fn),
|
||||
# float8 weight + float8 activation (static)
|
||||
"float8_static_activation_float8_weight": float8_static_activation_float8_weight,
|
||||
}
|
||||
|
||||
if is_torchao_version("<=", "0.14.1"):
|
||||
FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(3))
|
||||
FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(4))
|
||||
FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(5))
|
||||
FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(6))
|
||||
FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(7))
|
||||
|
||||
UINTX_QUANTIZATION_DTYPES = {
|
||||
"uintx_weight_only": uintx_weight_only,
|
||||
"uint1wo": partial(uintx_weight_only, dtype=torch.uint1),
|
||||
"uint2wo": partial(uintx_weight_only, dtype=torch.uint2),
|
||||
"uint3wo": partial(uintx_weight_only, dtype=torch.uint3),
|
||||
"uint4wo": partial(uintx_weight_only, dtype=torch.uint4),
|
||||
"uint5wo": partial(uintx_weight_only, dtype=torch.uint5),
|
||||
"uint6wo": partial(uintx_weight_only, dtype=torch.uint6),
|
||||
"uint7wo": partial(uintx_weight_only, dtype=torch.uint7),
|
||||
# "uint8wo": partial(uintx_weight_only, dtype=torch.uint8), # uint8 quantization is not supported
|
||||
}
|
||||
|
||||
QUANTIZATION_TYPES = {}
|
||||
QUANTIZATION_TYPES.update(INT4_QUANTIZATION_TYPES)
|
||||
QUANTIZATION_TYPES.update(INT8_QUANTIZATION_TYPES)
|
||||
QUANTIZATION_TYPES.update(UINTX_QUANTIZATION_DTYPES)
|
||||
|
||||
if cls._is_xpu_or_cuda_capability_atleast_8_9():
|
||||
QUANTIZATION_TYPES.update(FLOATX_QUANTIZATION_TYPES)
|
||||
|
||||
return QUANTIZATION_TYPES
|
||||
else:
|
||||
raise ValueError(
|
||||
"TorchAoConfig requires torchao to be installed, please install with `pip install torchao`"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _is_xpu_or_cuda_capability_atleast_8_9() -> bool:
|
||||
if torch.cuda.is_available():
|
||||
major, minor = torch.cuda.get_device_capability()
|
||||
if major == 8:
|
||||
return minor >= 9
|
||||
return major >= 9
|
||||
elif torch.xpu.is_available():
|
||||
return True
|
||||
else:
|
||||
raise RuntimeError("TorchAO requires a CUDA compatible GPU or Intel XPU and installation of PyTorch.")
|
||||
|
||||
def get_apply_tensor_subclass(self):
|
||||
"""Create the appropriate quantization method based on configuration."""
|
||||
if not isinstance(self.quant_type, str):
|
||||
return self.quant_type
|
||||
else:
|
||||
methods = self._get_torchao_quant_type_to_method()
|
||||
quant_type_kwargs = self.quant_type_kwargs.copy()
|
||||
if (
|
||||
not torch.cuda.is_available()
|
||||
and is_torchao_available()
|
||||
and self.quant_type == "int4_weight_only"
|
||||
and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0")
|
||||
and quant_type_kwargs.get("layout", None) is None
|
||||
):
|
||||
if torch.xpu.is_available():
|
||||
if version.parse(importlib.metadata.version("torchao")) >= version.parse(
|
||||
"0.11.0"
|
||||
) and version.parse(importlib.metadata.version("torch")) > version.parse("2.7.9"):
|
||||
from torchao.dtypes import Int4XPULayout
|
||||
from torchao.quantization.quant_primitives import ZeroPointDomain
|
||||
|
||||
quant_type_kwargs["layout"] = Int4XPULayout()
|
||||
quant_type_kwargs["zero_point_domain"] = ZeroPointDomain.INT
|
||||
else:
|
||||
raise ValueError(
|
||||
"TorchAoConfig requires torchao >= 0.11.0 and torch >= 2.8.0 for XPU support. Please upgrade the version or use run on CPU with the cpu version pytorch."
|
||||
)
|
||||
else:
|
||||
from torchao.dtypes import Int4CPULayout
|
||||
|
||||
quant_type_kwargs["layout"] = Int4CPULayout()
|
||||
|
||||
return methods[self.quant_type](**quant_type_kwargs)
|
||||
return self.quant_type
|
||||
|
||||
def __repr__(self):
|
||||
r"""
|
||||
Example of how this looks for `TorchAoConfig("uint4wo", group_size=32)`:
|
||||
|
||||
```
|
||||
TorchAoConfig {
|
||||
"modules_to_not_convert": null,
|
||||
"quant_method": "torchao",
|
||||
"quant_type": "uint4wo",
|
||||
"quant_type_kwargs": {
|
||||
"group_size": 32
|
||||
}
|
||||
}
|
||||
```
|
||||
"""
|
||||
config_dict = self.to_dict()
|
||||
return (
|
||||
f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True, cls=TorchAoJSONEncoder)}\n"
|
||||
)
|
||||
return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -20,7 +20,6 @@ https://github.com/huggingface/transformers/blob/3a8eb74668e9c2cc563b2f5c62fac17
|
||||
import importlib
|
||||
import re
|
||||
import types
|
||||
from fnmatch import fnmatch
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from packaging import version
|
||||
@@ -199,13 +198,13 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
|
||||
)
|
||||
|
||||
def update_torch_dtype(self, torch_dtype):
|
||||
quant_type = self.quantization_config.quant_type
|
||||
if isinstance(quant_type, str) and (quant_type.startswith("int") or quant_type.startswith("uint")):
|
||||
if torch_dtype is not None and torch_dtype != torch.bfloat16:
|
||||
logger.warning(
|
||||
f"You are trying to set torch_dtype to {torch_dtype} for int4/int8/uintx quantization, but "
|
||||
f"only bfloat16 is supported right now. Please set `torch_dtype=torch.bfloat16`."
|
||||
)
|
||||
config_name = self.quantization_config.quant_type.__class__.__name__
|
||||
is_int_quant = config_name.startswith("Int") or config_name.startswith("Uint")
|
||||
if is_int_quant and torch_dtype is not None and torch_dtype != torch.bfloat16:
|
||||
logger.warning(
|
||||
f"You are trying to set torch_dtype to {torch_dtype} for integer quantization, but "
|
||||
f"only bfloat16 is supported right now. Please set `torch_dtype=torch.bfloat16`."
|
||||
)
|
||||
|
||||
if torch_dtype is None:
|
||||
# We need to set the torch_dtype, otherwise we have dtype mismatch when performing the quantized linear op
|
||||
@@ -219,45 +218,16 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
|
||||
return torch_dtype
|
||||
|
||||
def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
|
||||
quant_type = self.quantization_config.quant_type
|
||||
from accelerate.utils import CustomDtype
|
||||
|
||||
if isinstance(quant_type, str):
|
||||
if quant_type.startswith("int8"):
|
||||
# Note that int4 weights are created by packing into torch.int8, but since there is no torch.int4, we use torch.int8
|
||||
return torch.int8
|
||||
elif quant_type.startswith("int4"):
|
||||
return CustomDtype.INT4
|
||||
elif quant_type == "uintx_weight_only":
|
||||
return self.quantization_config.quant_type_kwargs.get("dtype", torch.uint8)
|
||||
elif quant_type.startswith("uint"):
|
||||
return {
|
||||
1: torch.uint1,
|
||||
2: torch.uint2,
|
||||
3: torch.uint3,
|
||||
4: torch.uint4,
|
||||
5: torch.uint5,
|
||||
6: torch.uint6,
|
||||
7: torch.uint7,
|
||||
}[int(quant_type[4])]
|
||||
elif quant_type.startswith("float") or quant_type.startswith("fp"):
|
||||
return torch.bfloat16
|
||||
quant_type = self.quantization_config.quant_type
|
||||
config_name = quant_type.__class__.__name__
|
||||
size_digit = fuzzy_match_size(config_name)
|
||||
|
||||
elif is_torchao_version(">", "0.9.0"):
|
||||
from torchao.core.config import AOBaseConfig
|
||||
|
||||
quant_type = self.quantization_config.quant_type
|
||||
if isinstance(quant_type, AOBaseConfig):
|
||||
# Extract size digit using fuzzy match on the class name
|
||||
config_name = quant_type.__class__.__name__
|
||||
size_digit = fuzzy_match_size(config_name)
|
||||
|
||||
# Map the extracted digit to appropriate dtype
|
||||
if size_digit == "4":
|
||||
return CustomDtype.INT4
|
||||
else:
|
||||
# Default to int8
|
||||
return torch.int8
|
||||
if size_digit == "4":
|
||||
return CustomDtype.INT4
|
||||
else:
|
||||
return torch.int8
|
||||
|
||||
if isinstance(target_dtype, SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION):
|
||||
return target_dtype
|
||||
@@ -337,29 +307,14 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
|
||||
- Use a division factor of 8 for int4 weights
|
||||
- Use a division factor of 4 for int8 weights
|
||||
"""
|
||||
# Original mapping for non-AOBaseConfig types
|
||||
# For the uint types, this is a best guess. Once these types become more used
|
||||
# we can look into their nuances.
|
||||
if is_torchao_version(">", "0.9.0"):
|
||||
from torchao.core.config import AOBaseConfig
|
||||
|
||||
quant_type = self.quantization_config.quant_type
|
||||
if isinstance(quant_type, AOBaseConfig):
|
||||
# Extract size digit using fuzzy match on the class name
|
||||
config_name = quant_type.__class__.__name__
|
||||
size_digit = fuzzy_match_size(config_name)
|
||||
|
||||
if size_digit == "4":
|
||||
return 8
|
||||
else:
|
||||
return 4
|
||||
|
||||
map_to_target_dtype = {"int4_*": 8, "int8_*": 4, "uint*": 8, "float8*": 4}
|
||||
quant_type = self.quantization_config.quant_type
|
||||
for pattern, target_dtype in map_to_target_dtype.items():
|
||||
if fnmatch(quant_type, pattern):
|
||||
return target_dtype
|
||||
raise ValueError(f"Unsupported quant_type: {quant_type!r}")
|
||||
config_name = quant_type.__class__.__name__
|
||||
size_digit = fuzzy_match_size(config_name)
|
||||
|
||||
if size_digit == "4":
|
||||
return 8
|
||||
else:
|
||||
return 4
|
||||
|
||||
def _process_model_before_weight_loading(
|
||||
self,
|
||||
@@ -415,9 +370,17 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
|
||||
|
||||
return _is_torchao_serializable
|
||||
|
||||
_TRAINABLE_QUANTIZATION_CONFIGS = (
|
||||
"Int8WeightOnlyConfig",
|
||||
"Int8DynamicActivationInt8WeightConfig",
|
||||
"Int8StaticActivationInt8WeightConfig",
|
||||
"Float8WeightOnlyConfig",
|
||||
"Float8DynamicActivationFloat8WeightConfig",
|
||||
)
|
||||
|
||||
@property
|
||||
def is_trainable(self):
|
||||
return self.quantization_config.quant_type.startswith("int8")
|
||||
return self.quantization_config.quant_type.__class__.__name__ in self._TRAINABLE_QUANTIZATION_CONFIGS
|
||||
|
||||
@property
|
||||
def is_compileable(self) -> bool:
|
||||
|
||||
@@ -465,7 +465,8 @@ class UNetTesterMixin:
|
||||
def test_forward_with_norm_groups(self):
|
||||
if not self._accepts_norm_num_groups(self.model_class):
|
||||
pytest.skip(f"Test not supported for {self.model_class.__name__}")
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
init_dict["norm_num_groups"] = 16
|
||||
init_dict["block_out_channels"] = (16, 32)
|
||||
@@ -480,9 +481,9 @@ class UNetTesterMixin:
|
||||
if isinstance(output, dict):
|
||||
output = output.to_tuple()[0]
|
||||
|
||||
self.assertIsNotNone(output)
|
||||
assert output is not None
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
assert output.shape == expected_shape, "Input and output shapes do not match"
|
||||
|
||||
|
||||
class ModelTesterMixin:
|
||||
|
||||
@@ -287,8 +287,9 @@ class ModelTesterMixin:
|
||||
f"Parameter shape mismatch for {param_name}. Original: {param_1.shape}, loaded: {param_2.shape}"
|
||||
)
|
||||
|
||||
image = model(**self.get_dummy_inputs(), return_dict=False)[0]
|
||||
new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
image = model(**inputs_dict, return_dict=False)[0]
|
||||
new_image = new_model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")
|
||||
|
||||
@@ -308,8 +309,9 @@ class ModelTesterMixin:
|
||||
|
||||
new_model.to(torch_device)
|
||||
|
||||
image = model(**self.get_dummy_inputs(), return_dict=False)[0]
|
||||
new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
image = model(**inputs_dict, return_dict=False)[0]
|
||||
new_image = new_model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")
|
||||
|
||||
@@ -337,8 +339,9 @@ class ModelTesterMixin:
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
first = model(**self.get_dummy_inputs(), return_dict=False)[0]
|
||||
second = model(**self.get_dummy_inputs(), return_dict=False)[0]
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
first = model(**inputs_dict, return_dict=False)[0]
|
||||
second = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
first_flat = first.flatten()
|
||||
second_flat = second.flatten()
|
||||
@@ -395,8 +398,9 @@ class ModelTesterMixin:
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
outputs_dict = model(**self.get_dummy_inputs())
|
||||
outputs_tuple = model(**self.get_dummy_inputs(), return_dict=False)
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
outputs_dict = model(**inputs_dict)
|
||||
outputs_tuple = model(**inputs_dict, return_dict=False)
|
||||
|
||||
recursive_check(outputs_tuple, outputs_dict)
|
||||
|
||||
@@ -523,8 +527,10 @@ class ModelTesterMixin:
|
||||
new_model = new_model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
inputs_dict_new = self.get_dummy_inputs()
|
||||
new_output = new_model(**inputs_dict_new, return_dict=False)[0]
|
||||
# Re-create inputs only if they contain a generator (which needs to be reset)
|
||||
if "generator" in inputs_dict:
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
new_output = new_model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
assert_tensors_close(
|
||||
base_output, new_output, atol=atol, rtol=rtol, msg="Output should match after sharded save/load"
|
||||
@@ -563,8 +569,10 @@ class ModelTesterMixin:
|
||||
new_model = new_model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
inputs_dict_new = self.get_dummy_inputs()
|
||||
new_output = new_model(**inputs_dict_new, return_dict=False)[0]
|
||||
# Re-create inputs only if they contain a generator (which needs to be reset)
|
||||
if "generator" in inputs_dict:
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
new_output = new_model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
assert_tensors_close(
|
||||
base_output, new_output, atol=atol, rtol=rtol, msg="Output should match after variant sharded save/load"
|
||||
@@ -614,8 +622,10 @@ class ModelTesterMixin:
|
||||
model_parallel = model_parallel.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
inputs_dict_parallel = self.get_dummy_inputs()
|
||||
output_parallel = model_parallel(**inputs_dict_parallel, return_dict=False)[0]
|
||||
# Re-create inputs only if they contain a generator (which needs to be reset)
|
||||
if "generator" in inputs_dict:
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
output_parallel = model_parallel(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
assert_tensors_close(
|
||||
base_output, output_parallel, atol=atol, rtol=rtol, msg="Output should match with parallel loading"
|
||||
|
||||
@@ -92,9 +92,6 @@ class TorchCompileTesterMixin:
|
||||
model.eval()
|
||||
model.compile_repeated_blocks(fullgraph=True)
|
||||
|
||||
if self.model_class.__name__ == "UNet2DConditionModel":
|
||||
recompile_limit = 2
|
||||
|
||||
with (
|
||||
torch._inductor.utils.fresh_inductor_cache(),
|
||||
torch._dynamo.config.patch(recompile_limit=recompile_limit),
|
||||
|
||||
@@ -25,7 +25,6 @@ from diffusers.utils.import_utils import (
|
||||
is_nvidia_modelopt_available,
|
||||
is_optimum_quanto_available,
|
||||
is_torchao_available,
|
||||
is_torchao_version,
|
||||
)
|
||||
|
||||
from ...testing_utils import (
|
||||
@@ -63,8 +62,7 @@ if is_gguf_available():
|
||||
pass
|
||||
|
||||
if is_torchao_available():
|
||||
if is_torchao_version(">=", "0.9.0"):
|
||||
pass
|
||||
import torchao.quantization as _torchao_quantization
|
||||
|
||||
|
||||
class LoRALayer(torch.nn.Module):
|
||||
@@ -806,9 +804,9 @@ class TorchAoConfigMixin:
|
||||
"""
|
||||
|
||||
TORCHAO_QUANT_TYPES = {
|
||||
"int4wo": {"quant_type": "int4_weight_only"},
|
||||
"int8wo": {"quant_type": "int8_weight_only"},
|
||||
"int8dq": {"quant_type": "int8_dynamic_activation_int8_weight"},
|
||||
"int4wo": "Int4WeightOnlyConfig",
|
||||
"int8wo": "Int8WeightOnlyConfig",
|
||||
"int8dq": "Int8DynamicActivationInt8WeightConfig",
|
||||
}
|
||||
|
||||
TORCHAO_EXPECTED_MEMORY_REDUCTIONS = {
|
||||
@@ -817,8 +815,13 @@ class TorchAoConfigMixin:
|
||||
"int8dq": 1.5,
|
||||
}
|
||||
|
||||
def _create_quantized_model(self, config_kwargs, **extra_kwargs):
|
||||
config = TorchAoConfig(**config_kwargs)
|
||||
@staticmethod
|
||||
def _get_quant_config(config_name):
|
||||
config_cls = getattr(_torchao_quantization, config_name)
|
||||
return TorchAoConfig(config_cls())
|
||||
|
||||
def _create_quantized_model(self, config_name, **extra_kwargs):
|
||||
config = self._get_quant_config(config_name)
|
||||
kwargs = getattr(self, "pretrained_model_kwargs", {}).copy()
|
||||
kwargs["quantization_config"] = config
|
||||
kwargs["device_map"] = str(torch_device)
|
||||
|
||||
@@ -1,99 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import GlmImageTransformer2DModel
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..testing_utils import (
|
||||
BaseModelTesterConfig,
|
||||
ModelTesterMixin,
|
||||
TorchCompileTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class GlmImageTransformerTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return GlmImageTransformer2DModel
|
||||
|
||||
@property
|
||||
def main_input_name(self) -> str:
|
||||
return "hidden_states"
|
||||
|
||||
@property
|
||||
def output_shape(self) -> tuple:
|
||||
return (16, 8, 8)
|
||||
|
||||
@property
|
||||
def input_shape(self) -> tuple:
|
||||
return (4, 8, 8)
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self) -> dict:
|
||||
return {
|
||||
"patch_size": 2,
|
||||
"in_channels": 4,
|
||||
"out_channels": 4,
|
||||
"num_layers": 1,
|
||||
"attention_head_dim": 8,
|
||||
"num_attention_heads": 2,
|
||||
"text_embed_dim": 32,
|
||||
"time_embed_dim": 16,
|
||||
"condition_dim": 8,
|
||||
"prior_vq_quantizer_codebook_size": 64,
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
|
||||
num_channels = 4
|
||||
height = width = 8
|
||||
sequence_length = 12
|
||||
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(batch_size, num_channels, height, width), generator=self.generator, device=torch_device
|
||||
),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(batch_size, sequence_length, 32), generator=self.generator, device=torch_device
|
||||
),
|
||||
"prior_token_id": torch.randint(0, 64, size=(batch_size,), generator=self.generator).to(torch_device),
|
||||
"prior_token_drop": torch.zeros(batch_size, dtype=torch.bool, device=torch_device),
|
||||
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
|
||||
"target_size": torch.tensor([[height, width]] * batch_size, dtype=torch.float32).to(torch_device),
|
||||
"crop_coords": torch.tensor([[0, 0]] * batch_size, dtype=torch.float32).to(torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestGlmImageTransformer(GlmImageTransformerTesterConfig, ModelTesterMixin):
|
||||
pass
|
||||
|
||||
|
||||
class TestGlmImageTransformerTraining(GlmImageTransformerTesterConfig, TrainingTesterMixin):
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"GlmImageTransformer2DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class TestGlmImageTransformerCompile(GlmImageTransformerTesterConfig, TorchCompileTesterMixin):
|
||||
pass
|
||||
@@ -13,8 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
@@ -26,64 +24,39 @@ from ...testing_utils import (
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
from ..test_modeling_common import UNetTesterMixin
|
||||
from ..testing_utils import (
|
||||
BaseModelTesterConfig,
|
||||
MemoryTesterMixin,
|
||||
ModelTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = UNet1DModel
|
||||
main_input_name = "sample"
|
||||
_LAYERWISE_CASTING_XFAIL_REASON = (
|
||||
"RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
|
||||
"not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
|
||||
"1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
|
||||
"2. Unskip this test."
|
||||
)
|
||||
|
||||
|
||||
class UNet1DTesterConfig(BaseModelTesterConfig):
|
||||
"""Base configuration for UNet1DModel testing (standard variant)."""
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 4
|
||||
num_features = 14
|
||||
seq_len = 16
|
||||
|
||||
noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device)
|
||||
time_step = torch.tensor([10] * batch_size).to(torch_device)
|
||||
|
||||
return {"sample": noise, "timestep": time_step}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (4, 14, 16)
|
||||
def model_class(self):
|
||||
return UNet1DModel
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (4, 14, 16)
|
||||
return (14, 16)
|
||||
|
||||
@unittest.skip("Test not supported.")
|
||||
def test_ema_training(self):
|
||||
pass
|
||||
@property
|
||||
def main_input_name(self):
|
||||
return "sample"
|
||||
|
||||
@unittest.skip("Test not supported.")
|
||||
def test_training(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Test not supported.")
|
||||
def test_layerwise_casting_training(self):
|
||||
pass
|
||||
|
||||
def test_determinism(self):
|
||||
super().test_determinism()
|
||||
|
||||
def test_outputs_equivalence(self):
|
||||
super().test_outputs_equivalence()
|
||||
|
||||
def test_from_save_pretrained(self):
|
||||
super().test_from_save_pretrained()
|
||||
|
||||
def test_from_save_pretrained_variant(self):
|
||||
super().test_from_save_pretrained_variant()
|
||||
|
||||
def test_model_from_pretrained(self):
|
||||
super().test_model_from_pretrained()
|
||||
|
||||
def test_output(self):
|
||||
super().test_output()
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
def get_init_dict(self):
|
||||
return {
|
||||
"block_out_channels": (8, 8, 16, 16),
|
||||
"in_channels": 14,
|
||||
"out_channels": 14,
|
||||
@@ -97,18 +70,40 @@ class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
"up_block_types": ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D"),
|
||||
"act_fn": "swish",
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
batch_size = 4
|
||||
num_features = 14
|
||||
seq_len = 16
|
||||
|
||||
return {
|
||||
"sample": floats_tensor((batch_size, num_features, seq_len)).to(torch_device),
|
||||
"timestep": torch.tensor([10] * batch_size).to(torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestUNet1D(UNet1DTesterConfig, ModelTesterMixin, UNetTesterMixin):
|
||||
@pytest.mark.skip("Not implemented yet for this UNet")
|
||||
def test_forward_with_norm_groups(self):
|
||||
pass
|
||||
|
||||
|
||||
class TestUNet1DMemory(UNet1DTesterConfig, MemoryTesterMixin):
|
||||
@pytest.mark.xfail(reason=_LAYERWISE_CASTING_XFAIL_REASON)
|
||||
def test_layerwise_casting_memory(self):
|
||||
super().test_layerwise_casting_memory()
|
||||
|
||||
|
||||
class TestUNet1DHubLoading(UNet1DTesterConfig):
|
||||
def test_from_pretrained_hub(self):
|
||||
model, loading_info = UNet1DModel.from_pretrained(
|
||||
"bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="unet"
|
||||
)
|
||||
self.assertIsNotNone(model)
|
||||
self.assertEqual(len(loading_info["missing_keys"]), 0)
|
||||
assert model is not None
|
||||
assert len(loading_info["missing_keys"]) == 0
|
||||
|
||||
model.to(torch_device)
|
||||
image = model(**self.dummy_input)
|
||||
image = model(**self.get_dummy_inputs())
|
||||
|
||||
assert image is not None, "Make sure output is not None"
|
||||
|
||||
@@ -131,12 +126,7 @@ class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
# fmt: off
|
||||
expected_output_slice = torch.tensor([-2.137172, 1.1426016, 0.3688687, -0.766922, 0.7303146, 0.11038864, -0.4760633, 0.13270172, 0.02591348])
|
||||
# fmt: on
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))
|
||||
|
||||
@unittest.skip("Test not supported.")
|
||||
def test_forward_with_norm_groups(self):
|
||||
# Not implemented yet for this UNet
|
||||
pass
|
||||
assert torch.allclose(output_slice, expected_output_slice, rtol=1e-3)
|
||||
|
||||
@slow
|
||||
def test_unet_1d_maestro(self):
|
||||
@@ -157,98 +147,29 @@ class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
assert (output_sum - 224.0896).abs() < 0.5
|
||||
assert (output_max - 0.0607).abs() < 4e-4
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason=(
|
||||
"RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
|
||||
"not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
|
||||
"1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
|
||||
"2. Unskip this test."
|
||||
),
|
||||
)
|
||||
def test_layerwise_casting_inference(self):
|
||||
super().test_layerwise_casting_inference()
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason=(
|
||||
"RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
|
||||
"not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
|
||||
"1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
|
||||
"2. Unskip this test."
|
||||
),
|
||||
)
|
||||
def test_layerwise_casting_memory(self):
|
||||
pass
|
||||
# =============================================================================
|
||||
# UNet1D RL (Value Function) Model Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = UNet1DModel
|
||||
main_input_name = "sample"
|
||||
class UNet1DRLTesterConfig(BaseModelTesterConfig):
|
||||
"""Base configuration for UNet1DModel testing (RL value function variant)."""
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 4
|
||||
num_features = 14
|
||||
seq_len = 16
|
||||
|
||||
noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device)
|
||||
time_step = torch.tensor([10] * batch_size).to(torch_device)
|
||||
|
||||
return {"sample": noise, "timestep": time_step}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (4, 14, 16)
|
||||
def model_class(self):
|
||||
return UNet1DModel
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (4, 14, 1)
|
||||
return (1,)
|
||||
|
||||
def test_determinism(self):
|
||||
super().test_determinism()
|
||||
@property
|
||||
def main_input_name(self):
|
||||
return "sample"
|
||||
|
||||
def test_outputs_equivalence(self):
|
||||
super().test_outputs_equivalence()
|
||||
|
||||
def test_from_save_pretrained(self):
|
||||
super().test_from_save_pretrained()
|
||||
|
||||
def test_from_save_pretrained_variant(self):
|
||||
super().test_from_save_pretrained_variant()
|
||||
|
||||
def test_model_from_pretrained(self):
|
||||
super().test_model_from_pretrained()
|
||||
|
||||
def test_output(self):
|
||||
# UNetRL is a value-function is different output shape
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)
|
||||
|
||||
if isinstance(output, dict):
|
||||
output = output.sample
|
||||
|
||||
self.assertIsNotNone(output)
|
||||
expected_shape = torch.Size((inputs_dict["sample"].shape[0], 1))
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
|
||||
@unittest.skip("Test not supported.")
|
||||
def test_ema_training(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Test not supported.")
|
||||
def test_training(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Test not supported.")
|
||||
def test_layerwise_casting_training(self):
|
||||
pass
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
def get_init_dict(self):
|
||||
return {
|
||||
"in_channels": 14,
|
||||
"out_channels": 14,
|
||||
"down_block_types": ["DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"],
|
||||
@@ -264,18 +185,54 @@ class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
"time_embedding_type": "positional",
|
||||
"act_fn": "mish",
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
batch_size = 4
|
||||
num_features = 14
|
||||
seq_len = 16
|
||||
|
||||
return {
|
||||
"sample": floats_tensor((batch_size, num_features, seq_len)).to(torch_device),
|
||||
"timestep": torch.tensor([10] * batch_size).to(torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestUNet1DRL(UNet1DRLTesterConfig, ModelTesterMixin, UNetTesterMixin):
|
||||
@pytest.mark.skip("Not implemented yet for this UNet")
|
||||
def test_forward_with_norm_groups(self):
|
||||
pass
|
||||
|
||||
@torch.no_grad()
|
||||
def test_output(self):
|
||||
# UNetRL is a value-function with different output shape (batch, 1)
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
output = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
assert output is not None
|
||||
expected_shape = torch.Size((inputs_dict["sample"].shape[0], 1))
|
||||
assert output.shape == expected_shape, "Input and output shapes do not match"
|
||||
|
||||
|
||||
class TestUNet1DRLMemory(UNet1DRLTesterConfig, MemoryTesterMixin):
|
||||
@pytest.mark.xfail(reason=_LAYERWISE_CASTING_XFAIL_REASON)
|
||||
def test_layerwise_casting_memory(self):
|
||||
super().test_layerwise_casting_memory()
|
||||
|
||||
|
||||
class TestUNet1DRLHubLoading(UNet1DRLTesterConfig):
|
||||
def test_from_pretrained_hub(self):
|
||||
value_function, vf_loading_info = UNet1DModel.from_pretrained(
|
||||
"bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function"
|
||||
)
|
||||
self.assertIsNotNone(value_function)
|
||||
self.assertEqual(len(vf_loading_info["missing_keys"]), 0)
|
||||
assert value_function is not None
|
||||
assert len(vf_loading_info["missing_keys"]) == 0
|
||||
|
||||
value_function.to(torch_device)
|
||||
image = value_function(**self.dummy_input)
|
||||
image = value_function(**self.get_dummy_inputs())
|
||||
|
||||
assert image is not None, "Make sure output is not None"
|
||||
|
||||
@@ -299,31 +256,4 @@ class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
# fmt: off
|
||||
expected_output_slice = torch.tensor([165.25] * seq_len)
|
||||
# fmt: on
|
||||
self.assertTrue(torch.allclose(output, expected_output_slice, rtol=1e-3))
|
||||
|
||||
@unittest.skip("Test not supported.")
|
||||
def test_forward_with_norm_groups(self):
|
||||
# Not implemented yet for this UNet
|
||||
pass
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason=(
|
||||
"RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
|
||||
"not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
|
||||
"1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
|
||||
"2. Unskip this test."
|
||||
),
|
||||
)
|
||||
def test_layerwise_casting_inference(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason=(
|
||||
"RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
|
||||
"not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
|
||||
"1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
|
||||
"2. Unskip this test."
|
||||
),
|
||||
)
|
||||
def test_layerwise_casting_memory(self):
|
||||
pass
|
||||
assert torch.allclose(output, expected_output_slice, rtol=1e-3)
|
||||
|
||||
@@ -15,12 +15,11 @@
|
||||
|
||||
import gc
|
||||
import math
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from diffusers import UNet2DModel
|
||||
from diffusers.utils import logging
|
||||
|
||||
from ...testing_utils import (
|
||||
backend_empty_cache,
|
||||
@@ -31,39 +30,40 @@ from ...testing_utils import (
|
||||
torch_all_close,
|
||||
torch_device,
|
||||
)
|
||||
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
from ..test_modeling_common import UNetTesterMixin
|
||||
from ..testing_utils import (
|
||||
BaseModelTesterConfig,
|
||||
MemoryTesterMixin,
|
||||
ModelTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = UNet2DModel
|
||||
main_input_name = "sample"
|
||||
# =============================================================================
|
||||
# Standard UNet2D Model Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class UNet2DTesterConfig(BaseModelTesterConfig):
|
||||
"""Base configuration for standard UNet2DModel testing."""
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 4
|
||||
num_channels = 3
|
||||
sizes = (32, 32)
|
||||
|
||||
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
|
||||
time_step = torch.tensor([10]).to(torch_device)
|
||||
|
||||
return {"sample": noise, "timestep": time_step}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (3, 32, 32)
|
||||
def model_class(self):
|
||||
return UNet2DModel
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
@property
|
||||
def main_input_name(self):
|
||||
return "sample"
|
||||
|
||||
def get_init_dict(self):
|
||||
return {
|
||||
"block_out_channels": (4, 8),
|
||||
"norm_num_groups": 2,
|
||||
"down_block_types": ("DownBlock2D", "AttnDownBlock2D"),
|
||||
@@ -74,11 +74,22 @@ class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
"layers_per_block": 2,
|
||||
"sample_size": 32,
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
batch_size = 4
|
||||
num_channels = 3
|
||||
sizes = (32, 32)
|
||||
|
||||
return {
|
||||
"sample": floats_tensor((batch_size, num_channels) + sizes).to(torch_device),
|
||||
"timestep": torch.tensor([10]).to(torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestUNet2D(UNet2DTesterConfig, ModelTesterMixin, UNetTesterMixin):
|
||||
def test_mid_block_attn_groups(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
init_dict["add_attention"] = True
|
||||
init_dict["attn_norm_num_groups"] = 4
|
||||
@@ -87,13 +98,11 @@ class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
self.assertIsNotNone(
|
||||
model.mid_block.attentions[0].group_norm, "Mid block Attention group norm should exist but does not."
|
||||
assert model.mid_block.attentions[0].group_norm is not None, (
|
||||
"Mid block Attention group norm should exist but does not."
|
||||
)
|
||||
self.assertEqual(
|
||||
model.mid_block.attentions[0].group_norm.num_groups,
|
||||
init_dict["attn_norm_num_groups"],
|
||||
"Mid block Attention group norm does not have the expected number of groups.",
|
||||
assert model.mid_block.attentions[0].group_norm.num_groups == init_dict["attn_norm_num_groups"], (
|
||||
"Mid block Attention group norm does not have the expected number of groups."
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
@@ -102,13 +111,15 @@ class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
if isinstance(output, dict):
|
||||
output = output.to_tuple()[0]
|
||||
|
||||
self.assertIsNotNone(output)
|
||||
assert output is not None
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
assert output.shape == expected_shape, "Input and output shapes do not match"
|
||||
|
||||
def test_mid_block_none(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
mid_none_init_dict, mid_none_inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
mid_none_init_dict = self.get_init_dict()
|
||||
mid_none_inputs_dict = self.get_dummy_inputs()
|
||||
mid_none_init_dict["mid_block_type"] = None
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
@@ -119,7 +130,7 @@ class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
mid_none_model.to(torch_device)
|
||||
mid_none_model.eval()
|
||||
|
||||
self.assertIsNone(mid_none_model.mid_block, "Mid block should not exist.")
|
||||
assert mid_none_model.mid_block is None, "Mid block should not exist."
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)
|
||||
@@ -133,8 +144,10 @@ class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
if isinstance(mid_none_output, dict):
|
||||
mid_none_output = mid_none_output.to_tuple()[0]
|
||||
|
||||
self.assertFalse(torch.allclose(output, mid_none_output, rtol=1e-3), "outputs should be different.")
|
||||
assert not torch.allclose(output, mid_none_output, rtol=1e-3), "outputs should be different."
|
||||
|
||||
|
||||
class TestUNet2DTraining(UNet2DTesterConfig, TrainingTesterMixin):
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {
|
||||
"AttnUpBlock2D",
|
||||
@@ -143,41 +156,32 @@ class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
"UpBlock2D",
|
||||
"DownBlock2D",
|
||||
}
|
||||
|
||||
# NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim`
|
||||
attention_head_dim = 8
|
||||
block_out_channels = (16, 32)
|
||||
|
||||
super().test_gradient_checkpointing_is_applied(
|
||||
expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels
|
||||
)
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = UNet2DModel
|
||||
main_input_name = "sample"
|
||||
# =============================================================================
|
||||
# UNet2D LDM Model Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class UNet2DLDMTesterConfig(BaseModelTesterConfig):
|
||||
"""Base configuration for UNet2DModel LDM variant testing."""
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 4
|
||||
num_channels = 4
|
||||
sizes = (32, 32)
|
||||
|
||||
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
|
||||
time_step = torch.tensor([10]).to(torch_device)
|
||||
|
||||
return {"sample": noise, "timestep": time_step}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (4, 32, 32)
|
||||
def model_class(self):
|
||||
return UNet2DModel
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (4, 32, 32)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
@property
|
||||
def main_input_name(self):
|
||||
return "sample"
|
||||
|
||||
def get_init_dict(self):
|
||||
return {
|
||||
"sample_size": 32,
|
||||
"in_channels": 4,
|
||||
"out_channels": 4,
|
||||
@@ -187,17 +191,34 @@ class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
"down_block_types": ("DownBlock2D", "DownBlock2D"),
|
||||
"up_block_types": ("UpBlock2D", "UpBlock2D"),
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
batch_size = 4
|
||||
num_channels = 4
|
||||
sizes = (32, 32)
|
||||
|
||||
return {
|
||||
"sample": floats_tensor((batch_size, num_channels) + sizes).to(torch_device),
|
||||
"timestep": torch.tensor([10]).to(torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestUNet2DLDMTraining(UNet2DLDMTesterConfig, TrainingTesterMixin):
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"DownBlock2D", "UNetMidBlock2D", "UpBlock2D"}
|
||||
# NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim`
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class TestUNet2DLDMHubLoading(UNet2DLDMTesterConfig):
|
||||
def test_from_pretrained_hub(self):
|
||||
model, loading_info = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
|
||||
|
||||
self.assertIsNotNone(model)
|
||||
self.assertEqual(len(loading_info["missing_keys"]), 0)
|
||||
assert model is not None
|
||||
assert len(loading_info["missing_keys"]) == 0
|
||||
|
||||
model.to(torch_device)
|
||||
image = model(**self.dummy_input).sample
|
||||
image = model(**self.get_dummy_inputs()).sample
|
||||
|
||||
assert image is not None, "Make sure output is not None"
|
||||
|
||||
@@ -205,7 +226,7 @@ class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
def test_from_pretrained_accelerate(self):
|
||||
model, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
|
||||
model.to(torch_device)
|
||||
image = model(**self.dummy_input).sample
|
||||
image = model(**self.get_dummy_inputs()).sample
|
||||
|
||||
assert image is not None, "Make sure output is not None"
|
||||
|
||||
@@ -265,44 +286,31 @@ class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
expected_output_slice = torch.tensor([-13.3258, -20.1100, -15.9873, -17.6617, -23.0596, -17.9419, -13.3675, -16.1889, -12.3800])
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-3))
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"DownBlock2D", "UNetMidBlock2D", "UpBlock2D"}
|
||||
|
||||
# NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim`
|
||||
attention_head_dim = 32
|
||||
block_out_channels = (32, 64)
|
||||
|
||||
super().test_gradient_checkpointing_is_applied(
|
||||
expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels
|
||||
)
|
||||
assert torch_all_close(output_slice, expected_output_slice, rtol=1e-3)
|
||||
|
||||
|
||||
class NCSNppModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = UNet2DModel
|
||||
main_input_name = "sample"
|
||||
# =============================================================================
|
||||
# NCSN++ Model Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class NCSNppTesterConfig(BaseModelTesterConfig):
|
||||
"""Base configuration for UNet2DModel NCSN++ variant testing."""
|
||||
|
||||
@property
|
||||
def dummy_input(self, sizes=(32, 32)):
|
||||
batch_size = 4
|
||||
num_channels = 3
|
||||
|
||||
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
|
||||
time_step = torch.tensor(batch_size * [10]).to(dtype=torch.int32, device=torch_device)
|
||||
|
||||
return {"sample": noise, "timestep": time_step}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (3, 32, 32)
|
||||
def model_class(self):
|
||||
return UNet2DModel
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
@property
|
||||
def main_input_name(self):
|
||||
return "sample"
|
||||
|
||||
def get_init_dict(self):
|
||||
return {
|
||||
"block_out_channels": [32, 64, 64, 64],
|
||||
"in_channels": 3,
|
||||
"layers_per_block": 1,
|
||||
@@ -324,17 +332,71 @@ class NCSNppModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
"SkipUpBlock2D",
|
||||
],
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
batch_size = 4
|
||||
num_channels = 3
|
||||
sizes = (32, 32)
|
||||
|
||||
return {
|
||||
"sample": floats_tensor((batch_size, num_channels) + sizes).to(torch_device),
|
||||
"timestep": torch.tensor(batch_size * [10]).to(dtype=torch.int32, device=torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestNCSNpp(NCSNppTesterConfig, ModelTesterMixin, UNetTesterMixin):
|
||||
@pytest.mark.skip("Test not supported.")
|
||||
def test_forward_with_norm_groups(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip(
|
||||
"To make layerwise casting work with this model, we will have to update the implementation. "
|
||||
"Due to potentially low usage, we don't support it here."
|
||||
)
|
||||
def test_keep_in_fp32_modules(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip(
|
||||
"To make layerwise casting work with this model, we will have to update the implementation. "
|
||||
"Due to potentially low usage, we don't support it here."
|
||||
)
|
||||
def test_from_save_pretrained_dtype_inference(self):
|
||||
pass
|
||||
|
||||
|
||||
class TestNCSNppMemory(NCSNppTesterConfig, MemoryTesterMixin):
|
||||
@pytest.mark.skip(
|
||||
"To make layerwise casting work with this model, we will have to update the implementation. "
|
||||
"Due to potentially low usage, we don't support it here."
|
||||
)
|
||||
def test_layerwise_casting_memory(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip(
|
||||
"To make layerwise casting work with this model, we will have to update the implementation. "
|
||||
"Due to potentially low usage, we don't support it here."
|
||||
)
|
||||
def test_layerwise_casting_training(self):
|
||||
pass
|
||||
|
||||
|
||||
class TestNCSNppTraining(NCSNppTesterConfig, TrainingTesterMixin):
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {
|
||||
"UNetMidBlock2D",
|
||||
}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class TestNCSNppHubLoading(NCSNppTesterConfig):
|
||||
@slow
|
||||
def test_from_pretrained_hub(self):
|
||||
model, loading_info = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256", output_loading_info=True)
|
||||
self.assertIsNotNone(model)
|
||||
self.assertEqual(len(loading_info["missing_keys"]), 0)
|
||||
assert model is not None
|
||||
assert len(loading_info["missing_keys"]) == 0
|
||||
|
||||
model.to(torch_device)
|
||||
inputs = self.dummy_input
|
||||
inputs = self.get_dummy_inputs()
|
||||
noise = floats_tensor((4, 3) + (256, 256)).to(torch_device)
|
||||
inputs["sample"] = noise
|
||||
image = model(**inputs)
|
||||
@@ -361,7 +423,7 @@ class NCSNppModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
expected_output_slice = torch.tensor([-4836.2178, -6487.1470, -3816.8196, -7964.9302, -10966.3037, -20043.5957, 8137.0513, 2340.3328, 544.6056])
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
|
||||
assert torch_all_close(output_slice, expected_output_slice, rtol=1e-2)
|
||||
|
||||
def test_output_pretrained_ve_large(self):
|
||||
model = UNet2DModel.from_pretrained("fusing/ncsnpp-ffhq-ve-dummy-update")
|
||||
@@ -382,35 +444,4 @@ class NCSNppModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
expected_output_slice = torch.tensor([-0.0325, -0.0900, -0.0869, -0.0332, -0.0725, -0.0270, -0.0101, 0.0227, 0.0256])
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
|
||||
|
||||
@unittest.skip("Test not supported.")
|
||||
def test_forward_with_norm_groups(self):
|
||||
# not required for this model
|
||||
pass
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {
|
||||
"UNetMidBlock2D",
|
||||
}
|
||||
|
||||
block_out_channels = (32, 64, 64, 64)
|
||||
|
||||
super().test_gradient_checkpointing_is_applied(
|
||||
expected_set=expected_set, block_out_channels=block_out_channels
|
||||
)
|
||||
|
||||
def test_effective_gradient_checkpointing(self):
|
||||
super().test_effective_gradient_checkpointing(skip={"time_proj.weight"})
|
||||
|
||||
@unittest.skip(
|
||||
"To make layerwise casting work with this model, we will have to update the implementation. Due to potentially low usage, we don't support it here."
|
||||
)
|
||||
def test_layerwise_casting_inference(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
"To make layerwise casting work with this model, we will have to update the implementation. Due to potentially low usage, we don't support it here."
|
||||
)
|
||||
def test_layerwise_casting_memory(self):
|
||||
pass
|
||||
assert torch_all_close(output_slice, expected_output_slice, rtol=1e-2)
|
||||
|
||||
@@ -20,6 +20,7 @@ import tempfile
|
||||
import unittest
|
||||
from collections import OrderedDict
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from parameterized import parameterized
|
||||
@@ -52,17 +53,24 @@ from ...testing_utils import (
|
||||
torch_all_close,
|
||||
torch_device,
|
||||
)
|
||||
from ..test_modeling_common import (
|
||||
from ..test_modeling_common import UNetTesterMixin
|
||||
from ..testing_utils import (
|
||||
AttentionTesterMixin,
|
||||
BaseModelTesterConfig,
|
||||
IPAdapterTesterMixin,
|
||||
LoraHotSwappingForModelTesterMixin,
|
||||
LoraTesterMixin,
|
||||
MemoryTesterMixin,
|
||||
ModelTesterMixin,
|
||||
TorchCompileTesterMixin,
|
||||
UNetTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
if is_peft_available():
|
||||
from peft import LoraConfig
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
from ..testing_utils.lora import check_if_lora_correctly_set
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -82,16 +90,6 @@ def get_unet_lora_config():
|
||||
return unet_lora_config
|
||||
|
||||
|
||||
def check_if_lora_correctly_set(model) -> bool:
|
||||
"""
|
||||
Checks if the LoRA layers are correctly set with peft
|
||||
"""
|
||||
for module in model.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def create_ip_adapter_state_dict(model):
|
||||
# "ip_adapter" (cross-attention weights)
|
||||
ip_cross_attn_state_dict = {}
|
||||
@@ -354,34 +352,28 @@ def create_custom_diffusion_layers(model, mock_weights: bool = True):
|
||||
return custom_diffusion_attn_procs
|
||||
|
||||
|
||||
class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = UNet2DConditionModel
|
||||
main_input_name = "sample"
|
||||
# We override the items here because the unet under consideration is small.
|
||||
model_split_percents = [0.5, 0.34, 0.4]
|
||||
class UNet2DConditionTesterConfig(BaseModelTesterConfig):
|
||||
"""Base configuration for UNet2DConditionModel testing."""
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 4
|
||||
num_channels = 4
|
||||
sizes = (16, 16)
|
||||
|
||||
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
|
||||
time_step = torch.tensor([10]).to(torch_device)
|
||||
encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)
|
||||
|
||||
return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
|
||||
def model_class(self):
|
||||
return UNet2DConditionModel
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
def output_shape(self) -> tuple[int, int, int]:
|
||||
return (4, 16, 16)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (4, 16, 16)
|
||||
def model_split_percents(self) -> list[float]:
|
||||
return [0.5, 0.34, 0.4]
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
@property
|
||||
def main_input_name(self) -> str:
|
||||
return "sample"
|
||||
|
||||
def get_init_dict(self) -> dict:
|
||||
"""Return UNet2D model initialization arguments."""
|
||||
return {
|
||||
"block_out_channels": (4, 8),
|
||||
"norm_num_groups": 4,
|
||||
"down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
|
||||
@@ -393,26 +385,24 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
"layers_per_block": 1,
|
||||
"sample_size": 16,
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_xformers_available(),
|
||||
reason="XFormers attention is only available with CUDA and `xformers` installed",
|
||||
)
|
||||
def test_xformers_enable_works(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
|
||||
"""Return dummy inputs for UNet2D model."""
|
||||
batch_size = 4
|
||||
num_channels = 4
|
||||
sizes = (16, 16)
|
||||
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
return {
|
||||
"sample": floats_tensor((batch_size, num_channels) + sizes).to(torch_device),
|
||||
"timestep": torch.tensor([10]).to(torch_device),
|
||||
"encoder_hidden_states": floats_tensor((batch_size, 4, 8)).to(torch_device),
|
||||
}
|
||||
|
||||
assert (
|
||||
model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
|
||||
== "XFormersAttnProcessor"
|
||||
), "xformers is not enabled"
|
||||
|
||||
class TestUNet2DCondition(UNet2DConditionTesterConfig, ModelTesterMixin, UNetTesterMixin):
|
||||
def test_model_with_attention_head_dim_tuple(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
init_dict["block_out_channels"] = (16, 32)
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
@@ -427,12 +417,13 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
if isinstance(output, dict):
|
||||
output = output.sample
|
||||
|
||||
self.assertIsNotNone(output)
|
||||
assert output is not None
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
assert output.shape == expected_shape, "Input and output shapes do not match"
|
||||
|
||||
def test_model_with_use_linear_projection(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
init_dict["use_linear_projection"] = True
|
||||
|
||||
@@ -446,12 +437,13 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
if isinstance(output, dict):
|
||||
output = output.sample
|
||||
|
||||
self.assertIsNotNone(output)
|
||||
assert output is not None
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
assert output.shape == expected_shape, "Input and output shapes do not match"
|
||||
|
||||
def test_model_with_cross_attention_dim_tuple(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
init_dict["cross_attention_dim"] = (8, 8)
|
||||
|
||||
@@ -465,12 +457,13 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
if isinstance(output, dict):
|
||||
output = output.sample
|
||||
|
||||
self.assertIsNotNone(output)
|
||||
assert output is not None
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
assert output.shape == expected_shape, "Input and output shapes do not match"
|
||||
|
||||
def test_model_with_simple_projection(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
batch_size, _, _, sample_size = inputs_dict["sample"].shape
|
||||
|
||||
@@ -489,12 +482,13 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
if isinstance(output, dict):
|
||||
output = output.sample
|
||||
|
||||
self.assertIsNotNone(output)
|
||||
assert output is not None
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
assert output.shape == expected_shape, "Input and output shapes do not match"
|
||||
|
||||
def test_model_with_class_embeddings_concat(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
batch_size, _, _, sample_size = inputs_dict["sample"].shape
|
||||
|
||||
@@ -514,12 +508,287 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
if isinstance(output, dict):
|
||||
output = output.sample
|
||||
|
||||
self.assertIsNotNone(output)
|
||||
assert output is not None
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
assert output.shape == expected_shape, "Input and output shapes do not match"
|
||||
|
||||
# see diffusers.models.attention_processor::Attention#prepare_attention_mask
|
||||
# note: we may not need to fix mask padding to work for stable-diffusion cross-attn masks.
|
||||
# since the use-case (somebody passes in a too-short cross-attn mask) is pretty small,
|
||||
# maybe it's fine that this only works for the unclip use-case.
|
||||
@mark.skip(
|
||||
reason="we currently pad mask by target_length tokens (what unclip needs), whereas stable-diffusion's cross-attn needs to instead pad by remaining_length."
|
||||
)
|
||||
def test_model_xattn_padding(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
model = self.model_class(**{**init_dict, "attention_head_dim": (8, 16)})
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
cond = inputs_dict["encoder_hidden_states"]
|
||||
with torch.no_grad():
|
||||
full_cond_out = model(**inputs_dict).sample
|
||||
assert full_cond_out is not None
|
||||
|
||||
batch, tokens, _ = cond.shape
|
||||
keeplast_mask = (torch.arange(tokens) == tokens - 1).expand(batch, -1).to(cond.device, torch.bool)
|
||||
keeplast_out = model(**{**inputs_dict, "encoder_attention_mask": keeplast_mask}).sample
|
||||
assert not keeplast_out.allclose(full_cond_out), "a 'keep last token' mask should change the result"
|
||||
|
||||
trunc_mask = torch.zeros(batch, tokens - 1, device=cond.device, dtype=torch.bool)
|
||||
trunc_mask_out = model(**{**inputs_dict, "encoder_attention_mask": trunc_mask}).sample
|
||||
assert trunc_mask_out.allclose(keeplast_out), (
|
||||
"a mask with fewer tokens than condition, will be padded with 'keep' tokens. a 'discard-all' mask missing the final token is thus equivalent to a 'keep last' mask."
|
||||
)
|
||||
|
||||
def test_pickle(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
init_dict["block_out_channels"] = (16, 32)
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
sample = model(**inputs_dict).sample
|
||||
|
||||
sample_copy = copy.copy(sample)
|
||||
|
||||
assert (sample - sample_copy).abs().max() < 1e-4
|
||||
|
||||
def test_asymmetrical_unet(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
# Add asymmetry to configs
|
||||
init_dict["transformer_layers_per_block"] = [[3, 2], 1]
|
||||
init_dict["reverse_transformer_layers_per_block"] = [[3, 4], 1]
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
output = model(**inputs_dict).sample
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
|
||||
# Check if input and output shapes are the same
|
||||
assert output.shape == expected_shape, "Input and output shapes do not match"
|
||||
|
||||
|
||||
class TestUNet2DConditionHubLoading(UNet2DConditionTesterConfig):
|
||||
"""Hub checkpoint loading tests for UNet2DConditionModel."""
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
("hf-internal-testing/unet2d-sharded-dummy", None),
|
||||
("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"),
|
||||
]
|
||||
)
|
||||
@require_torch_accelerator
|
||||
def test_load_sharded_checkpoint_from_hub(self, repo_id, variant):
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
loaded_model = self.model_class.from_pretrained(repo_id, variant=variant)
|
||||
loaded_model = loaded_model.to(torch_device)
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
("hf-internal-testing/unet2d-sharded-dummy-subfolder", None),
|
||||
("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"),
|
||||
]
|
||||
)
|
||||
@require_torch_accelerator
|
||||
def test_load_sharded_checkpoint_from_hub_subfolder(self, repo_id, variant):
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
loaded_model = self.model_class.from_pretrained(repo_id, subfolder="unet", variant=variant)
|
||||
loaded_model = loaded_model.to(torch_device)
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_accelerator
|
||||
def test_load_sharded_checkpoint_from_hub_local(self):
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
|
||||
loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True)
|
||||
loaded_model = loaded_model.to(torch_device)
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_accelerator
|
||||
def test_load_sharded_checkpoint_from_hub_local_subfolder(self):
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
|
||||
loaded_model = self.model_class.from_pretrained(ckpt_path, subfolder="unet", local_files_only=True)
|
||||
loaded_model = loaded_model.to(torch_device)
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_accelerator
|
||||
@parameterized.expand(
|
||||
[
|
||||
("hf-internal-testing/unet2d-sharded-dummy", None),
|
||||
("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"),
|
||||
]
|
||||
)
|
||||
def test_load_sharded_checkpoint_device_map_from_hub(self, repo_id, variant):
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
loaded_model = self.model_class.from_pretrained(repo_id, variant=variant, device_map="auto")
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_accelerator
|
||||
@parameterized.expand(
|
||||
[
|
||||
("hf-internal-testing/unet2d-sharded-dummy-subfolder", None),
|
||||
("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"),
|
||||
]
|
||||
)
|
||||
def test_load_sharded_checkpoint_device_map_from_hub_subfolder(self, repo_id, variant):
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
loaded_model = self.model_class.from_pretrained(repo_id, variant=variant, subfolder="unet", device_map="auto")
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_accelerator
|
||||
def test_load_sharded_checkpoint_device_map_from_hub_local(self):
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
|
||||
loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True, device_map="auto")
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_accelerator
|
||||
def test_load_sharded_checkpoint_device_map_from_hub_local_subfolder(self):
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
|
||||
loaded_model = self.model_class.from_pretrained(
|
||||
ckpt_path, local_files_only=True, subfolder="unet", device_map="auto"
|
||||
)
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
|
||||
class TestUNet2DConditionLoRA(UNet2DConditionTesterConfig, LoraTesterMixin):
|
||||
"""LoRA adapter tests for UNet2DConditionModel."""
|
||||
|
||||
@require_peft_backend
|
||||
def test_load_attn_procs_raise_warning(self):
|
||||
"""Test that deprecated load_attn_procs method raises FutureWarning."""
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
# forward pass without LoRA
|
||||
with torch.no_grad():
|
||||
non_lora_sample = model(**inputs_dict).sample
|
||||
|
||||
unet_lora_config = get_unet_lora_config()
|
||||
model.add_adapter(unet_lora_config)
|
||||
|
||||
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
|
||||
|
||||
# forward pass with LoRA
|
||||
with torch.no_grad():
|
||||
lora_sample_1 = model(**inputs_dict).sample
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_attn_procs(tmpdirname)
|
||||
model.unload_lora()
|
||||
|
||||
with pytest.warns(FutureWarning, match="Using the `load_attn_procs\\(\\)` method has been deprecated"):
|
||||
model.load_attn_procs(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
|
||||
# import to still check for the rest of the stuff.
|
||||
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
|
||||
|
||||
with torch.no_grad():
|
||||
lora_sample_2 = model(**inputs_dict).sample
|
||||
|
||||
assert not torch.allclose(non_lora_sample, lora_sample_1, atol=1e-4, rtol=1e-4), (
|
||||
"LoRA injected UNet should produce different results."
|
||||
)
|
||||
assert torch.allclose(lora_sample_1, lora_sample_2, atol=1e-4, rtol=1e-4), (
|
||||
"Loading from a saved checkpoint should produce identical results."
|
||||
)
|
||||
|
||||
@require_peft_backend
|
||||
def test_save_attn_procs_raise_warning(self):
|
||||
"""Test that deprecated save_attn_procs method raises FutureWarning."""
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
unet_lora_config = get_unet_lora_config()
|
||||
model.add_adapter(unet_lora_config)
|
||||
|
||||
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
with pytest.warns(FutureWarning, match="Using the `save_attn_procs\\(\\)` method has been deprecated"):
|
||||
model.save_attn_procs(os.path.join(tmpdirname))
|
||||
|
||||
|
||||
class TestUNet2DConditionMemory(UNet2DConditionTesterConfig, MemoryTesterMixin):
|
||||
"""Memory optimization tests for UNet2DConditionModel."""
|
||||
|
||||
|
||||
class TestUNet2DConditionTraining(UNet2DConditionTesterConfig, TrainingTesterMixin):
|
||||
"""Training tests for UNet2DConditionModel."""
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"UNetMidBlock2DCrossAttn",
|
||||
"UpBlock2D",
|
||||
"Transformer2DModel",
|
||||
"DownBlock2D",
|
||||
}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class TestUNet2DConditionAttention(UNet2DConditionTesterConfig, AttentionTesterMixin):
|
||||
"""Attention processor tests for UNet2DConditionModel."""
|
||||
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_xformers_available(),
|
||||
reason="XFormers attention is only available with CUDA and `xformers` installed",
|
||||
)
|
||||
def test_xformers_enable_works(self):
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
|
||||
assert (
|
||||
model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
|
||||
== "XFormersAttnProcessor"
|
||||
), "xformers is not enabled"
|
||||
|
||||
def test_model_attention_slicing(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
init_dict["block_out_channels"] = (16, 32)
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
@@ -544,7 +813,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
assert output is not None
|
||||
|
||||
def test_model_sliceable_head_dim(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
init_dict = self.get_init_dict()
|
||||
|
||||
init_dict["block_out_channels"] = (16, 32)
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
@@ -562,21 +831,6 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
for module in model.children():
|
||||
check_sliceable_dim_attr(module)
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"UNetMidBlock2DCrossAttn",
|
||||
"UpBlock2D",
|
||||
"Transformer2DModel",
|
||||
"DownBlock2D",
|
||||
}
|
||||
attention_head_dim = (8, 16)
|
||||
block_out_channels = (16, 32)
|
||||
super().test_gradient_checkpointing_is_applied(
|
||||
expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels
|
||||
)
|
||||
|
||||
def test_special_attn_proc(self):
|
||||
class AttnEasyProc(torch.nn.Module):
|
||||
def __init__(self, num):
|
||||
@@ -618,7 +872,8 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
return hidden_states
|
||||
|
||||
# enable deterministic behavior for gradient checkpointing
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
init_dict["block_out_channels"] = (16, 32)
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
@@ -645,7 +900,8 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
]
|
||||
)
|
||||
def test_model_xattn_mask(self, mask_dtype):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
model = self.model_class(**{**init_dict, "attention_head_dim": (8, 16), "block_out_channels": (16, 32)})
|
||||
model.to(torch_device)
|
||||
@@ -675,39 +931,13 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
"masking the last token from our cond should be equivalent to truncating that token out of the condition"
|
||||
)
|
||||
|
||||
# see diffusers.models.attention_processor::Attention#prepare_attention_mask
|
||||
# note: we may not need to fix mask padding to work for stable-diffusion cross-attn masks.
|
||||
# since the use-case (somebody passes in a too-short cross-attn mask) is pretty esoteric.
|
||||
# maybe it's fine that this only works for the unclip use-case.
|
||||
@mark.skip(
|
||||
reason="we currently pad mask by target_length tokens (what unclip needs), whereas stable-diffusion's cross-attn needs to instead pad by remaining_length."
|
||||
)
|
||||
def test_model_xattn_padding(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
model = self.model_class(**{**init_dict, "attention_head_dim": (8, 16)})
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
cond = inputs_dict["encoder_hidden_states"]
|
||||
with torch.no_grad():
|
||||
full_cond_out = model(**inputs_dict).sample
|
||||
assert full_cond_out is not None
|
||||
|
||||
batch, tokens, _ = cond.shape
|
||||
keeplast_mask = (torch.arange(tokens) == tokens - 1).expand(batch, -1).to(cond.device, torch.bool)
|
||||
keeplast_out = model(**{**inputs_dict, "encoder_attention_mask": keeplast_mask}).sample
|
||||
assert not keeplast_out.allclose(full_cond_out), "a 'keep last token' mask should change the result"
|
||||
|
||||
trunc_mask = torch.zeros(batch, tokens - 1, device=cond.device, dtype=torch.bool)
|
||||
trunc_mask_out = model(**{**inputs_dict, "encoder_attention_mask": trunc_mask}).sample
|
||||
assert trunc_mask_out.allclose(keeplast_out), (
|
||||
"a mask with fewer tokens than condition, will be padded with 'keep' tokens. a 'discard-all' mask missing the final token is thus equivalent to a 'keep last' mask."
|
||||
)
|
||||
class TestUNet2DConditionCustomDiffusion(UNet2DConditionTesterConfig):
|
||||
"""Custom Diffusion processor tests for UNet2DConditionModel."""
|
||||
|
||||
def test_custom_diffusion_processors(self):
|
||||
# enable deterministic behavior for gradient checkpointing
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
init_dict["block_out_channels"] = (16, 32)
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
@@ -733,8 +963,8 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
assert (sample1 - sample2).abs().max() < 3e-3
|
||||
|
||||
def test_custom_diffusion_save_load(self):
|
||||
# enable deterministic behavior for gradient checkpointing
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
init_dict["block_out_channels"] = (16, 32)
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
@@ -754,7 +984,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_attn_procs(tmpdirname, safe_serialization=False)
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_custom_diffusion_weights.bin")))
|
||||
assert os.path.isfile(os.path.join(tmpdirname, "pytorch_custom_diffusion_weights.bin"))
|
||||
torch.manual_seed(0)
|
||||
new_model = self.model_class(**init_dict)
|
||||
new_model.load_attn_procs(tmpdirname, weight_name="pytorch_custom_diffusion_weights.bin")
|
||||
@@ -773,8 +1003,8 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
reason="XFormers attention is only available with CUDA and `xformers` installed",
|
||||
)
|
||||
def test_custom_diffusion_xformers_on_off(self):
|
||||
# enable deterministic behavior for gradient checkpointing
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
init_dict["block_out_channels"] = (16, 32)
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
@@ -798,41 +1028,28 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
assert (sample - on_sample).abs().max() < 1e-4
|
||||
assert (sample - off_sample).abs().max() < 1e-4
|
||||
|
||||
def test_pickle(self):
|
||||
# enable deterministic behavior for gradient checkpointing
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["block_out_channels"] = (16, 32)
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
class TestUNet2DConditionIPAdapter(UNet2DConditionTesterConfig, IPAdapterTesterMixin):
|
||||
"""IP Adapter tests for UNet2DConditionModel."""
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
@property
|
||||
def ip_adapter_processor_cls(self):
|
||||
return (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)
|
||||
|
||||
with torch.no_grad():
|
||||
sample = model(**inputs_dict).sample
|
||||
def create_ip_adapter_state_dict(self, model):
|
||||
return create_ip_adapter_state_dict(model)
|
||||
|
||||
sample_copy = copy.copy(sample)
|
||||
|
||||
assert (sample - sample_copy).abs().max() < 1e-4
|
||||
|
||||
def test_asymmetrical_unet(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
# Add asymmetry to configs
|
||||
init_dict["transformer_layers_per_block"] = [[3, 2], 1]
|
||||
init_dict["reverse_transformer_layers_per_block"] = [[3, 4], 1]
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
output = model(**inputs_dict).sample
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
|
||||
# Check if input and output shapes are the same
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
def modify_inputs_for_ip_adapter(self, model, inputs_dict):
|
||||
batch_size = inputs_dict["encoder_hidden_states"].shape[0]
|
||||
# for ip-adapter image_embeds has shape [batch_size, num_image, embed_dim]
|
||||
cross_attention_dim = getattr(model.config, "cross_attention_dim", 8)
|
||||
image_embeds = floats_tensor((batch_size, 1, cross_attention_dim)).to(torch_device)
|
||||
inputs_dict["added_cond_kwargs"] = {"image_embeds": [image_embeds]}
|
||||
return inputs_dict
|
||||
|
||||
def test_ip_adapter(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
init_dict["block_out_channels"] = (16, 32)
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
@@ -905,7 +1122,8 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4)
|
||||
|
||||
def test_ip_adapter_plus(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
init_dict["block_out_channels"] = (16, 32)
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
@@ -977,185 +1195,16 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
assert sample2.allclose(sample5, atol=1e-4, rtol=1e-4)
|
||||
assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4)
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
("hf-internal-testing/unet2d-sharded-dummy", None),
|
||||
("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"),
|
||||
]
|
||||
)
|
||||
@require_torch_accelerator
|
||||
def test_load_sharded_checkpoint_from_hub(self, repo_id, variant):
|
||||
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
loaded_model = self.model_class.from_pretrained(repo_id, variant=variant)
|
||||
loaded_model = loaded_model.to(torch_device)
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
class TestUNet2DConditionModelCompile(UNet2DConditionTesterConfig, TorchCompileTesterMixin):
|
||||
"""Torch compile tests for UNet2DConditionModel."""
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
("hf-internal-testing/unet2d-sharded-dummy-subfolder", None),
|
||||
("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"),
|
||||
]
|
||||
)
|
||||
@require_torch_accelerator
|
||||
def test_load_sharded_checkpoint_from_hub_subfolder(self, repo_id, variant):
|
||||
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
loaded_model = self.model_class.from_pretrained(repo_id, subfolder="unet", variant=variant)
|
||||
loaded_model = loaded_model.to(torch_device)
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_accelerator
|
||||
def test_load_sharded_checkpoint_from_hub_local(self):
|
||||
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
|
||||
loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True)
|
||||
loaded_model = loaded_model.to(torch_device)
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_accelerator
|
||||
def test_load_sharded_checkpoint_from_hub_local_subfolder(self):
|
||||
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
|
||||
loaded_model = self.model_class.from_pretrained(ckpt_path, subfolder="unet", local_files_only=True)
|
||||
loaded_model = loaded_model.to(torch_device)
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_accelerator
|
||||
@parameterized.expand(
|
||||
[
|
||||
("hf-internal-testing/unet2d-sharded-dummy", None),
|
||||
("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"),
|
||||
]
|
||||
)
|
||||
def test_load_sharded_checkpoint_device_map_from_hub(self, repo_id, variant):
|
||||
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
loaded_model = self.model_class.from_pretrained(repo_id, variant=variant, device_map="auto")
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_accelerator
|
||||
@parameterized.expand(
|
||||
[
|
||||
("hf-internal-testing/unet2d-sharded-dummy-subfolder", None),
|
||||
("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"),
|
||||
]
|
||||
)
|
||||
def test_load_sharded_checkpoint_device_map_from_hub_subfolder(self, repo_id, variant):
|
||||
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
loaded_model = self.model_class.from_pretrained(repo_id, variant=variant, subfolder="unet", device_map="auto")
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_accelerator
|
||||
def test_load_sharded_checkpoint_device_map_from_hub_local(self):
|
||||
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
|
||||
loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True, device_map="auto")
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_accelerator
|
||||
def test_load_sharded_checkpoint_device_map_from_hub_local_subfolder(self):
|
||||
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
|
||||
loaded_model = self.model_class.from_pretrained(
|
||||
ckpt_path, local_files_only=True, subfolder="unet", device_map="auto"
|
||||
)
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_peft_backend
|
||||
def test_load_attn_procs_raise_warning(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
# forward pass without LoRA
|
||||
with torch.no_grad():
|
||||
non_lora_sample = model(**inputs_dict).sample
|
||||
|
||||
unet_lora_config = get_unet_lora_config()
|
||||
model.add_adapter(unet_lora_config)
|
||||
|
||||
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
|
||||
|
||||
# forward pass with LoRA
|
||||
with torch.no_grad():
|
||||
lora_sample_1 = model(**inputs_dict).sample
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_attn_procs(tmpdirname)
|
||||
model.unload_lora()
|
||||
|
||||
with self.assertWarns(FutureWarning) as warning:
|
||||
model.load_attn_procs(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
|
||||
warning_message = str(warning.warnings[0].message)
|
||||
assert "Using the `load_attn_procs()` method has been deprecated" in warning_message
|
||||
|
||||
# import to still check for the rest of the stuff.
|
||||
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
|
||||
|
||||
with torch.no_grad():
|
||||
lora_sample_2 = model(**inputs_dict).sample
|
||||
|
||||
assert not torch.allclose(non_lora_sample, lora_sample_1, atol=1e-4, rtol=1e-4), (
|
||||
"LoRA injected UNet should produce different results."
|
||||
)
|
||||
assert torch.allclose(lora_sample_1, lora_sample_2, atol=1e-4, rtol=1e-4), (
|
||||
"Loading from a saved checkpoint should produce identical results."
|
||||
)
|
||||
|
||||
@require_peft_backend
|
||||
def test_save_attn_procs_raise_warning(self):
|
||||
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
unet_lora_config = get_unet_lora_config()
|
||||
model.add_adapter(unet_lora_config)
|
||||
|
||||
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
with self.assertWarns(FutureWarning) as warning:
|
||||
model.save_attn_procs(tmpdirname)
|
||||
|
||||
warning_message = str(warning.warnings[0].message)
|
||||
assert "Using the `save_attn_procs()` method has been deprecated" in warning_message
|
||||
def test_torch_compile_repeated_blocks(self):
|
||||
return super().test_torch_compile_repeated_blocks(recompile_limit=2)
|
||||
|
||||
|
||||
class UNet2DConditionModelCompileTests(TorchCompileTesterMixin, unittest.TestCase):
|
||||
model_class = UNet2DConditionModel
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return UNet2DConditionModelTests().prepare_init_args_and_inputs_for_common()
|
||||
|
||||
|
||||
class UNet2DConditionModelLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
|
||||
model_class = UNet2DConditionModel
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return UNet2DConditionModelTests().prepare_init_args_and_inputs_for_common()
|
||||
class TestUNet2DConditionModelLoRAHotSwap(UNet2DConditionTesterConfig, LoraHotSwappingForModelTesterMixin):
|
||||
"""LoRA hot-swapping tests for UNet2DConditionModel."""
|
||||
|
||||
|
||||
@slow
|
||||
|
||||
@@ -18,47 +18,44 @@ import unittest
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers.models import ModelMixin, UNet3DConditionModel
|
||||
from diffusers.utils import logging
|
||||
from diffusers import UNet3DConditionModel
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
from ...testing_utils import enable_full_determinism, floats_tensor, skip_mps, torch_device
|
||||
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
from ...testing_utils import (
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
skip_mps,
|
||||
torch_device,
|
||||
)
|
||||
from ..test_modeling_common import UNetTesterMixin
|
||||
from ..testing_utils import (
|
||||
AttentionTesterMixin,
|
||||
BaseModelTesterConfig,
|
||||
ModelTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@skip_mps
|
||||
class UNet3DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = UNet3DConditionModel
|
||||
main_input_name = "sample"
|
||||
class UNet3DConditionTesterConfig(BaseModelTesterConfig):
|
||||
"""Base configuration for UNet3DConditionModel testing."""
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 4
|
||||
num_channels = 4
|
||||
num_frames = 4
|
||||
sizes = (16, 16)
|
||||
|
||||
noise = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
|
||||
time_step = torch.tensor([10]).to(torch_device)
|
||||
encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)
|
||||
|
||||
return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (4, 4, 16, 16)
|
||||
def model_class(self):
|
||||
return UNet3DConditionModel
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (4, 4, 16, 16)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
@property
|
||||
def main_input_name(self):
|
||||
return "sample"
|
||||
|
||||
def get_init_dict(self):
|
||||
return {
|
||||
"block_out_channels": (4, 8),
|
||||
"norm_num_groups": 4,
|
||||
"down_block_types": (
|
||||
@@ -73,27 +70,25 @@ class UNet3DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
"layers_per_block": 1,
|
||||
"sample_size": 16,
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_xformers_available(),
|
||||
reason="XFormers attention is only available with CUDA and `xformers` installed",
|
||||
)
|
||||
def test_xformers_enable_works(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
def get_dummy_inputs(self):
|
||||
batch_size = 4
|
||||
num_channels = 4
|
||||
num_frames = 4
|
||||
sizes = (16, 16)
|
||||
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
return {
|
||||
"sample": floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device),
|
||||
"timestep": torch.tensor([10]).to(torch_device),
|
||||
"encoder_hidden_states": floats_tensor((batch_size, 4, 8)).to(torch_device),
|
||||
}
|
||||
|
||||
assert (
|
||||
model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
|
||||
== "XFormersAttnProcessor"
|
||||
), "xformers is not enabled"
|
||||
|
||||
class TestUNet3DCondition(UNet3DConditionTesterConfig, ModelTesterMixin, UNetTesterMixin):
|
||||
# Overriding to set `norm_num_groups` needs to be different for this model.
|
||||
def test_forward_with_norm_groups(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
init_dict["block_out_channels"] = (32, 64)
|
||||
init_dict["norm_num_groups"] = 32
|
||||
|
||||
@@ -107,39 +102,74 @@ class UNet3DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
if isinstance(output, dict):
|
||||
output = output.sample
|
||||
|
||||
self.assertIsNotNone(output)
|
||||
assert output is not None
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
assert output.shape == expected_shape, "Input and output shapes do not match"
|
||||
|
||||
# Overriding since the UNet3D outputs a different structure.
|
||||
@torch.no_grad()
|
||||
def test_determinism(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
# Warmup pass when using mps (see #372)
|
||||
if torch_device == "mps" and isinstance(model, ModelMixin):
|
||||
model(**self.dummy_input)
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
first = model(**inputs_dict)
|
||||
if isinstance(first, dict):
|
||||
first = first.sample
|
||||
first = model(**inputs_dict)
|
||||
if isinstance(first, dict):
|
||||
first = first.sample
|
||||
|
||||
second = model(**inputs_dict)
|
||||
if isinstance(second, dict):
|
||||
second = second.sample
|
||||
second = model(**inputs_dict)
|
||||
if isinstance(second, dict):
|
||||
second = second.sample
|
||||
|
||||
out_1 = first.cpu().numpy()
|
||||
out_2 = second.cpu().numpy()
|
||||
out_1 = out_1[~np.isnan(out_1)]
|
||||
out_2 = out_2[~np.isnan(out_2)]
|
||||
max_diff = np.amax(np.abs(out_1 - out_2))
|
||||
self.assertLessEqual(max_diff, 1e-5)
|
||||
assert max_diff <= 1e-5
|
||||
|
||||
def test_feed_forward_chunking(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
init_dict["block_out_channels"] = (32, 64)
|
||||
init_dict["norm_num_groups"] = 32
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)[0]
|
||||
|
||||
model.enable_forward_chunking()
|
||||
with torch.no_grad():
|
||||
output_2 = model(**inputs_dict)[0]
|
||||
|
||||
assert output.shape == output_2.shape, "Shape doesn't match"
|
||||
assert np.abs(output.cpu() - output_2.cpu()).max() < 1e-2
|
||||
|
||||
|
||||
class TestUNet3DConditionAttention(UNet3DConditionTesterConfig, AttentionTesterMixin):
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_xformers_available(),
|
||||
reason="XFormers attention is only available with CUDA and `xformers` installed",
|
||||
)
|
||||
def test_xformers_enable_works(self):
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
|
||||
assert (
|
||||
model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
|
||||
== "XFormersAttnProcessor"
|
||||
), "xformers is not enabled"
|
||||
|
||||
def test_model_attention_slicing(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
init_dict["block_out_channels"] = (16, 32)
|
||||
init_dict["attention_head_dim"] = 8
|
||||
@@ -162,22 +192,3 @@ class UNet3DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)
|
||||
assert output is not None
|
||||
|
||||
def test_feed_forward_chunking(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
init_dict["block_out_channels"] = (32, 64)
|
||||
init_dict["norm_num_groups"] = 32
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)[0]
|
||||
|
||||
model.enable_forward_chunking()
|
||||
with torch.no_grad():
|
||||
output_2 = model(**inputs_dict)[0]
|
||||
|
||||
self.assertEqual(output.shape, output_2.shape, "Shape doesn't match")
|
||||
assert np.abs(output.cpu() - output_2.cpu()).max() < 1e-2
|
||||
|
||||
@@ -13,59 +13,42 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from diffusers import ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel
|
||||
from diffusers.utils import logging
|
||||
|
||||
from ...testing_utils import enable_full_determinism, floats_tensor, is_flaky, torch_device
|
||||
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
from ..test_modeling_common import UNetTesterMixin
|
||||
from ..testing_utils import (
|
||||
BaseModelTesterConfig,
|
||||
ModelTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = UNetControlNetXSModel
|
||||
main_input_name = "sample"
|
||||
class UNetControlNetXSTesterConfig(BaseModelTesterConfig):
|
||||
"""Base configuration for UNetControlNetXSModel testing."""
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 4
|
||||
num_channels = 4
|
||||
sizes = (16, 16)
|
||||
conditioning_image_size = (3, 32, 32) # size of additional, unprocessed image for control-conditioning
|
||||
|
||||
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
|
||||
time_step = torch.tensor([10]).to(torch_device)
|
||||
encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)
|
||||
controlnet_cond = floats_tensor((batch_size, *conditioning_image_size)).to(torch_device)
|
||||
conditioning_scale = 1
|
||||
|
||||
return {
|
||||
"sample": noise,
|
||||
"timestep": time_step,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"controlnet_cond": controlnet_cond,
|
||||
"conditioning_scale": conditioning_scale,
|
||||
}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (4, 16, 16)
|
||||
def model_class(self):
|
||||
return UNetControlNetXSModel
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (4, 16, 16)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
@property
|
||||
def main_input_name(self):
|
||||
return "sample"
|
||||
|
||||
def get_init_dict(self):
|
||||
return {
|
||||
"sample_size": 16,
|
||||
"down_block_types": ("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
"up_block_types": ("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||
@@ -80,11 +63,23 @@ class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Tes
|
||||
"ctrl_max_norm_num_groups": 2,
|
||||
"ctrl_conditioning_embedding_out_channels": (2, 2),
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
batch_size = 4
|
||||
num_channels = 4
|
||||
sizes = (16, 16)
|
||||
conditioning_image_size = (3, 32, 32)
|
||||
|
||||
return {
|
||||
"sample": floats_tensor((batch_size, num_channels) + sizes).to(torch_device),
|
||||
"timestep": torch.tensor([10]).to(torch_device),
|
||||
"encoder_hidden_states": floats_tensor((batch_size, 4, 8)).to(torch_device),
|
||||
"controlnet_cond": floats_tensor((batch_size, *conditioning_image_size)).to(torch_device),
|
||||
"conditioning_scale": 1,
|
||||
}
|
||||
|
||||
def get_dummy_unet(self):
|
||||
"""For some tests we also need the underlying UNet. For these, we'll build the UNetControlNetXSModel from the UNet and ControlNetXS-Adapter"""
|
||||
"""Build the underlying UNet for tests that construct UNetControlNetXSModel from UNet + Adapter."""
|
||||
return UNet2DConditionModel(
|
||||
block_out_channels=(4, 8),
|
||||
layers_per_block=2,
|
||||
@@ -99,10 +94,16 @@ class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Tes
|
||||
)
|
||||
|
||||
def get_dummy_controlnet_from_unet(self, unet, **kwargs):
|
||||
"""For some tests we also need the underlying ControlNetXS-Adapter. For these, we'll build the UNetControlNetXSModel from the UNet and ControlNetXS-Adapter"""
|
||||
# size_ratio and conditioning_embedding_out_channels chosen to keep model small
|
||||
"""Build the ControlNetXS-Adapter from a UNet."""
|
||||
return ControlNetXSAdapter.from_unet(unet, size_ratio=1, conditioning_embedding_out_channels=(2, 2), **kwargs)
|
||||
|
||||
|
||||
class TestUNetControlNetXS(UNetControlNetXSTesterConfig, ModelTesterMixin, UNetTesterMixin):
|
||||
@pytest.mark.skip("Test not supported.")
|
||||
def test_forward_with_norm_groups(self):
|
||||
# UNetControlNetXSModel only supports SD/SDXL with norm_num_groups=32
|
||||
pass
|
||||
|
||||
def test_from_unet(self):
|
||||
unet = self.get_dummy_unet()
|
||||
controlnet = self.get_dummy_controlnet_from_unet(unet)
|
||||
@@ -115,7 +116,7 @@ class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Tes
|
||||
assert torch.equal(model_state_dict[weight_dict_prefix + "." + param_name], param_value)
|
||||
|
||||
# # check unet
|
||||
# everything expect down,mid,up blocks
|
||||
# everything except down,mid,up blocks
|
||||
modules_from_unet = [
|
||||
"time_embedding",
|
||||
"conv_in",
|
||||
@@ -152,7 +153,7 @@ class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Tes
|
||||
assert_equal_weights(u.upsamplers[0], f"up_blocks.{i}.upsamplers")
|
||||
|
||||
# # check controlnet
|
||||
# everything expect down,mid,up blocks
|
||||
# everything except down,mid,up blocks
|
||||
modules_from_controlnet = {
|
||||
"controlnet_cond_embedding": "controlnet_cond_embedding",
|
||||
"conv_in": "ctrl_conv_in",
|
||||
@@ -193,12 +194,12 @@ class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Tes
|
||||
for p in module.parameters():
|
||||
assert p.requires_grad
|
||||
|
||||
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
|
||||
init_dict = self.get_init_dict()
|
||||
model = UNetControlNetXSModel(**init_dict)
|
||||
model.freeze_unet_params()
|
||||
|
||||
# # check unet
|
||||
# everything expect down,mid,up blocks
|
||||
# everything except down,mid,up blocks
|
||||
modules_from_unet = [
|
||||
model.base_time_embedding,
|
||||
model.base_conv_in,
|
||||
@@ -236,7 +237,7 @@ class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Tes
|
||||
assert_frozen(u.upsamplers)
|
||||
|
||||
# # check controlnet
|
||||
# everything expect down,mid,up blocks
|
||||
# everything except down,mid,up blocks
|
||||
modules_from_controlnet = [
|
||||
model.controlnet_cond_embedding,
|
||||
model.ctrl_conv_in,
|
||||
@@ -267,16 +268,6 @@ class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Tes
|
||||
for u in model.up_blocks:
|
||||
assert_unfrozen(u.ctrl_to_base)
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {
|
||||
"Transformer2DModel",
|
||||
"UNetMidBlock2DCrossAttn",
|
||||
"ControlNetXSCrossAttnDownBlock2D",
|
||||
"ControlNetXSCrossAttnMidBlock2D",
|
||||
"ControlNetXSCrossAttnUpBlock2D",
|
||||
}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
@is_flaky
|
||||
def test_forward_no_control(self):
|
||||
unet = self.get_dummy_unet()
|
||||
@@ -287,7 +278,7 @@ class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Tes
|
||||
unet = unet.to(torch_device)
|
||||
model = model.to(torch_device)
|
||||
|
||||
input_ = self.dummy_input
|
||||
input_ = self.get_dummy_inputs()
|
||||
|
||||
control_specific_input = ["controlnet_cond", "conditioning_scale"]
|
||||
input_for_unet = {k: v for k, v in input_.items() if k not in control_specific_input}
|
||||
@@ -312,7 +303,7 @@ class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Tes
|
||||
model = model.to(torch_device)
|
||||
model_mix_time = model_mix_time.to(torch_device)
|
||||
|
||||
input_ = self.dummy_input
|
||||
input_ = self.get_dummy_inputs()
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**input_).sample
|
||||
@@ -320,7 +311,14 @@ class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Tes
|
||||
|
||||
assert output.shape == output_mix_time.shape
|
||||
|
||||
@unittest.skip("Test not supported.")
|
||||
def test_forward_with_norm_groups(self):
|
||||
# UNetControlNetXSModel currently only supports StableDiffusion and StableDiffusion-XL, both of which have norm_num_groups fixed at 32. So we don't need to test different values for norm_num_groups.
|
||||
pass
|
||||
|
||||
class TestUNetControlNetXSTraining(UNetControlNetXSTesterConfig, TrainingTesterMixin):
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {
|
||||
"Transformer2DModel",
|
||||
"UNetMidBlock2DCrossAttn",
|
||||
"ControlNetXSCrossAttnDownBlock2D",
|
||||
"ControlNetXSCrossAttnMidBlock2D",
|
||||
"ControlNetXSCrossAttnUpBlock2D",
|
||||
}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
@@ -16,10 +16,10 @@
|
||||
import copy
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from diffusers import UNetSpatioTemporalConditionModel
|
||||
from diffusers.utils import logging
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
from ...testing_utils import (
|
||||
@@ -28,45 +28,34 @@ from ...testing_utils import (
|
||||
skip_mps,
|
||||
torch_device,
|
||||
)
|
||||
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
from ..test_modeling_common import UNetTesterMixin
|
||||
from ..testing_utils import (
|
||||
AttentionTesterMixin,
|
||||
BaseModelTesterConfig,
|
||||
ModelTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
@skip_mps
|
||||
class UNetSpatioTemporalConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = UNetSpatioTemporalConditionModel
|
||||
main_input_name = "sample"
|
||||
class UNetSpatioTemporalTesterConfig(BaseModelTesterConfig):
|
||||
"""Base configuration for UNetSpatioTemporalConditionModel testing."""
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 2
|
||||
num_frames = 2
|
||||
num_channels = 4
|
||||
sizes = (32, 32)
|
||||
|
||||
noise = floats_tensor((batch_size, num_frames, num_channels) + sizes).to(torch_device)
|
||||
time_step = torch.tensor([10]).to(torch_device)
|
||||
encoder_hidden_states = floats_tensor((batch_size, 1, 32)).to(torch_device)
|
||||
|
||||
return {
|
||||
"sample": noise,
|
||||
"timestep": time_step,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"added_time_ids": self._get_add_time_ids(),
|
||||
}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (2, 2, 4, 32, 32)
|
||||
def model_class(self):
|
||||
return UNetSpatioTemporalConditionModel
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (4, 32, 32)
|
||||
|
||||
@property
|
||||
def main_input_name(self):
|
||||
return "sample"
|
||||
|
||||
@property
|
||||
def fps(self):
|
||||
return 6
|
||||
@@ -83,8 +72,8 @@ class UNetSpatioTemporalConditionModelTests(ModelTesterMixin, UNetTesterMixin, u
|
||||
def addition_time_embed_dim(self):
|
||||
return 32
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
def get_init_dict(self):
|
||||
return {
|
||||
"block_out_channels": (32, 64),
|
||||
"down_block_types": (
|
||||
"CrossAttnDownBlockSpatioTemporal",
|
||||
@@ -103,8 +92,23 @@ class UNetSpatioTemporalConditionModelTests(ModelTesterMixin, UNetTesterMixin, u
|
||||
"projection_class_embeddings_input_dim": self.addition_time_embed_dim * 3,
|
||||
"addition_time_embed_dim": self.addition_time_embed_dim,
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
batch_size = 2
|
||||
num_frames = 2
|
||||
num_channels = 4
|
||||
sizes = (32, 32)
|
||||
|
||||
noise = floats_tensor((batch_size, num_frames, num_channels) + sizes).to(torch_device)
|
||||
time_step = torch.tensor([10]).to(torch_device)
|
||||
encoder_hidden_states = floats_tensor((batch_size, 1, 32)).to(torch_device)
|
||||
|
||||
return {
|
||||
"sample": noise,
|
||||
"timestep": time_step,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"added_time_ids": self._get_add_time_ids(),
|
||||
}
|
||||
|
||||
def _get_add_time_ids(self, do_classifier_free_guidance=True):
|
||||
add_time_ids = [self.fps, self.motion_bucket_id, self.noise_aug_strength]
|
||||
@@ -124,43 +128,15 @@ class UNetSpatioTemporalConditionModelTests(ModelTesterMixin, UNetTesterMixin, u
|
||||
|
||||
return add_time_ids
|
||||
|
||||
@unittest.skip("Number of Norm Groups is not configurable")
|
||||
|
||||
class TestUNetSpatioTemporal(UNetSpatioTemporalTesterConfig, ModelTesterMixin, UNetTesterMixin):
|
||||
@pytest.mark.skip("Number of Norm Groups is not configurable")
|
||||
def test_forward_with_norm_groups(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Deprecated functionality")
|
||||
def test_model_attention_slicing(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported")
|
||||
def test_model_with_use_linear_projection(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported")
|
||||
def test_model_with_simple_projection(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported")
|
||||
def test_model_with_class_embeddings_concat(self):
|
||||
pass
|
||||
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_xformers_available(),
|
||||
reason="XFormers attention is only available with CUDA and `xformers` installed",
|
||||
)
|
||||
def test_xformers_enable_works(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
|
||||
assert (
|
||||
model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
|
||||
== "XFormersAttnProcessor"
|
||||
), "xformers is not enabled"
|
||||
|
||||
def test_model_with_num_attention_heads_tuple(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
init_dict["num_attention_heads"] = (8, 16)
|
||||
model = self.model_class(**init_dict)
|
||||
@@ -173,12 +149,13 @@ class UNetSpatioTemporalConditionModelTests(ModelTesterMixin, UNetTesterMixin, u
|
||||
if isinstance(output, dict):
|
||||
output = output.sample
|
||||
|
||||
self.assertIsNotNone(output)
|
||||
assert output is not None
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
assert output.shape == expected_shape, "Input and output shapes do not match"
|
||||
|
||||
def test_model_with_cross_attention_dim_tuple(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
init_dict["cross_attention_dim"] = (32, 32)
|
||||
|
||||
@@ -192,27 +169,13 @@ class UNetSpatioTemporalConditionModelTests(ModelTesterMixin, UNetTesterMixin, u
|
||||
if isinstance(output, dict):
|
||||
output = output.sample
|
||||
|
||||
self.assertIsNotNone(output)
|
||||
assert output is not None
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {
|
||||
"TransformerSpatioTemporalModel",
|
||||
"CrossAttnDownBlockSpatioTemporal",
|
||||
"DownBlockSpatioTemporal",
|
||||
"UpBlockSpatioTemporal",
|
||||
"CrossAttnUpBlockSpatioTemporal",
|
||||
"UNetMidBlockSpatioTemporal",
|
||||
}
|
||||
num_attention_heads = (8, 16)
|
||||
super().test_gradient_checkpointing_is_applied(
|
||||
expected_set=expected_set, num_attention_heads=num_attention_heads
|
||||
)
|
||||
assert output.shape == expected_shape, "Input and output shapes do not match"
|
||||
|
||||
def test_pickle(self):
|
||||
# enable deterministic behavior for gradient checkpointing
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
init_dict["num_attention_heads"] = (8, 16)
|
||||
|
||||
@@ -225,3 +188,33 @@ class UNetSpatioTemporalConditionModelTests(ModelTesterMixin, UNetTesterMixin, u
|
||||
sample_copy = copy.copy(sample)
|
||||
|
||||
assert (sample - sample_copy).abs().max() < 1e-4
|
||||
|
||||
|
||||
class TestUNetSpatioTemporalAttention(UNetSpatioTemporalTesterConfig, AttentionTesterMixin):
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_xformers_available(),
|
||||
reason="XFormers attention is only available with CUDA and `xformers` installed",
|
||||
)
|
||||
def test_xformers_enable_works(self):
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
|
||||
assert (
|
||||
model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
|
||||
== "XFormersAttnProcessor"
|
||||
), "xformers is not enabled"
|
||||
|
||||
|
||||
class TestUNetSpatioTemporalTraining(UNetSpatioTemporalTesterConfig, TrainingTesterMixin):
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {
|
||||
"TransformerSpatioTemporalModel",
|
||||
"CrossAttnDownBlockSpatioTemporal",
|
||||
"DownBlockSpatioTemporal",
|
||||
"UpBlockSpatioTemporal",
|
||||
"CrossAttnUpBlockSpatioTemporal",
|
||||
"UNetMidBlockSpatioTemporal",
|
||||
}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
@@ -55,6 +55,20 @@ from ..test_torch_compile_utils import QuantCompileTests
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
def _is_xpu_or_cuda_capability_atleast_8_9() -> bool:
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if torch.cuda.is_available():
|
||||
major, minor = torch.cuda.get_device_capability()
|
||||
if major == 8:
|
||||
return minor >= 9
|
||||
return major >= 9
|
||||
elif torch.xpu.is_available():
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -64,12 +78,17 @@ if is_torch_available():
|
||||
|
||||
if is_torchao_available():
|
||||
from torchao.dtypes import AffineQuantizedTensor
|
||||
from torchao.quantization import (
|
||||
Float8WeightOnlyConfig,
|
||||
Int4WeightOnlyConfig,
|
||||
Int8DynamicActivationInt8WeightConfig,
|
||||
Int8WeightOnlyConfig,
|
||||
)
|
||||
from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
|
||||
from torchao.quantization.quant_primitives import MappingType
|
||||
from torchao.utils import get_model_size_in_bytes
|
||||
|
||||
if version.parse(importlib.metadata.version("torchao")) >= version.Version("0.9.0"):
|
||||
from torchao.quantization import Int8WeightOnlyConfig
|
||||
if version.parse(importlib.metadata.version("torchao")) >= version.Version("0.10.0"):
|
||||
from torchao.quantization import Int8DynamicActivationIntxWeightConfig, IntxWeightOnlyConfig
|
||||
|
||||
|
||||
@require_torch
|
||||
@@ -80,53 +99,30 @@ class TorchAoConfigTest(unittest.TestCase):
|
||||
"""
|
||||
Makes sure the config format is properly set
|
||||
"""
|
||||
quantization_config = TorchAoConfig("int4_weight_only")
|
||||
quantization_config = TorchAoConfig(Int4WeightOnlyConfig(version=2))
|
||||
torchao_orig_config = quantization_config.to_dict()
|
||||
|
||||
for key in torchao_orig_config:
|
||||
self.assertEqual(getattr(quantization_config, key), torchao_orig_config[key])
|
||||
self.assertIn("quant_type", torchao_orig_config)
|
||||
self.assertIn("quant_method", torchao_orig_config)
|
||||
|
||||
def test_post_init_check(self):
|
||||
"""
|
||||
Test kwargs validations in TorchAoConfig
|
||||
Test that non-AOBaseConfig types are rejected
|
||||
"""
|
||||
_ = TorchAoConfig("int4_weight_only")
|
||||
with self.assertRaisesRegex(ValueError, "is not supported"):
|
||||
_ = TorchAoConfig("uint8")
|
||||
_ = TorchAoConfig(Int4WeightOnlyConfig())
|
||||
with self.assertRaises(TypeError):
|
||||
_ = TorchAoConfig("int4_weight_only")
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "does not support the following keyword arguments"):
|
||||
_ = TorchAoConfig("int4_weight_only", group_size1=32)
|
||||
with self.assertRaises(TypeError):
|
||||
_ = TorchAoConfig(42)
|
||||
|
||||
def test_repr(self):
|
||||
"""
|
||||
Check that there is no error in the repr
|
||||
"""
|
||||
quantization_config = TorchAoConfig("int4_weight_only", modules_to_not_convert=["conv"], group_size=8)
|
||||
expected_repr = """TorchAoConfig {
|
||||
"modules_to_not_convert": [
|
||||
"conv"
|
||||
],
|
||||
"quant_method": "torchao",
|
||||
"quant_type": "int4_weight_only",
|
||||
"quant_type_kwargs": {
|
||||
"group_size": 8
|
||||
}
|
||||
}""".replace(" ", "").replace("\n", "")
|
||||
quantization_repr = repr(quantization_config).replace(" ", "").replace("\n", "")
|
||||
self.assertEqual(quantization_repr, expected_repr)
|
||||
|
||||
quantization_config = TorchAoConfig("int4dq", group_size=64, act_mapping_type=MappingType.SYMMETRIC)
|
||||
expected_repr = """TorchAoConfig {
|
||||
"modules_to_not_convert": null,
|
||||
"quant_method": "torchao",
|
||||
"quant_type": "int4dq",
|
||||
"quant_type_kwargs": {
|
||||
"act_mapping_type": "SYMMETRIC",
|
||||
"group_size": 64
|
||||
}
|
||||
}""".replace(" ", "").replace("\n", "")
|
||||
quantization_repr = repr(quantization_config).replace(" ", "").replace("\n", "")
|
||||
self.assertEqual(quantization_repr, expected_repr)
|
||||
quantization_config = TorchAoConfig(Int8WeightOnlyConfig(version=2), modules_to_not_convert=["conv"])
|
||||
quantization_repr = repr(quantization_config)
|
||||
self.assertIn("TorchAoConfig", quantization_repr)
|
||||
self.assertIn("torchao", quantization_repr)
|
||||
|
||||
|
||||
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
|
||||
@@ -234,79 +230,30 @@ class TorchAoTest(unittest.TestCase):
|
||||
for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]:
|
||||
# fmt: off
|
||||
QUANTIZATION_TYPES_TO_TEST = [
|
||||
("int4wo", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6445, 0.4336, 0.4531, 0.5625])),
|
||||
("int4dq", np.array([0.4688, 0.5195, 0.5547, 0.418, 0.4414, 0.6406, 0.4336, 0.4531, 0.5625])),
|
||||
("int8wo", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
|
||||
("int8dq", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
|
||||
("uint4wo", np.array([0.4609, 0.5234, 0.5508, 0.4199, 0.4336, 0.6406, 0.4316, 0.4531, 0.5625])),
|
||||
("uint7wo", np.array([0.4648, 0.5195, 0.5547, 0.4219, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
|
||||
(Int4WeightOnlyConfig(version=2), np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6445, 0.4336, 0.4531, 0.5625])),
|
||||
(Int8DynamicActivationIntxWeightConfig(version=2), np.array([0.4688, 0.5195, 0.5547, 0.418, 0.4414, 0.6406, 0.4336, 0.4531, 0.5625])),
|
||||
(Int8WeightOnlyConfig(version=2), np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
|
||||
(Int8DynamicActivationInt8WeightConfig(version=2), np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
|
||||
(IntxWeightOnlyConfig(dtype=torch.uint4, group_size=16, version=2), np.array([0.4609, 0.5234, 0.5508, 0.4199, 0.4336, 0.6406, 0.4316, 0.4531, 0.5625])),
|
||||
(IntxWeightOnlyConfig(dtype=torch.uint7, group_size=16, version=2), np.array([0.4648, 0.5195, 0.5547, 0.4219, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
|
||||
]
|
||||
|
||||
if TorchAoConfig._is_xpu_or_cuda_capability_atleast_8_9():
|
||||
if _is_xpu_or_cuda_capability_atleast_8_9():
|
||||
QUANTIZATION_TYPES_TO_TEST.extend([
|
||||
("float8wo_e5m2", np.array([0.4590, 0.5273, 0.5547, 0.4219, 0.4375, 0.6406, 0.4316, 0.4512, 0.5625])),
|
||||
("float8wo_e4m3", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6406, 0.4316, 0.4531, 0.5625])),
|
||||
# =====
|
||||
# The following lead to an internal torch error:
|
||||
# RuntimeError: mat2 shape (32x4 must be divisible by 16
|
||||
# Skip these for now; TODO(aryan): investigate later
|
||||
# ("float8dq_e4m3", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
|
||||
# ("float8dq_e4m3_tensor", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
|
||||
# =====
|
||||
# Cutlass fails to initialize for below
|
||||
# ("float8dq_e4m3_row", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
|
||||
# =====
|
||||
(Float8WeightOnlyConfig(weight_dtype=torch.float8_e5m2), np.array([0.4590, 0.5273, 0.5547, 0.4219, 0.4375, 0.6406, 0.4316, 0.4512, 0.5625])),
|
||||
(Float8WeightOnlyConfig(weight_dtype=torch.float8_e4m3fn), np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6406, 0.4316, 0.4531, 0.5625])),
|
||||
])
|
||||
if version.parse(importlib.metadata.version("torchao")) <= version.Version("0.14.1"):
|
||||
QUANTIZATION_TYPES_TO_TEST.extend([
|
||||
("fp4", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
|
||||
("fp6", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
|
||||
])
|
||||
# fmt: on
|
||||
|
||||
for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST:
|
||||
quant_kwargs = {}
|
||||
if quantization_name in ["uint4wo", "uint7wo"]:
|
||||
# The dummy flux model that we use has smaller dimensions. This imposes some restrictions on group_size here
|
||||
quant_kwargs.update({"group_size": 16})
|
||||
quantization_config = TorchAoConfig(
|
||||
quant_type=quantization_name, modules_to_not_convert=["x_embedder"], **quant_kwargs
|
||||
)
|
||||
for quant_config, expected_slice in QUANTIZATION_TYPES_TO_TEST:
|
||||
quantization_config = TorchAoConfig(quant_type=quant_config, modules_to_not_convert=["x_embedder"])
|
||||
self._test_quant_type(quantization_config, expected_slice, model_id)
|
||||
|
||||
@unittest.skip("Skipping floatx quantization tests")
|
||||
def test_floatx_quantization(self):
|
||||
for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]:
|
||||
if TorchAoConfig._is_xpu_or_cuda_capability_atleast_8_9():
|
||||
if version.parse(importlib.metadata.version("torchao")) <= version.Version("0.14.1"):
|
||||
quantization_config = TorchAoConfig(quant_type="fp4", modules_to_not_convert=["x_embedder"])
|
||||
self._test_quant_type(
|
||||
quantization_config,
|
||||
np.array(
|
||||
[
|
||||
0.4648,
|
||||
0.5195,
|
||||
0.5547,
|
||||
0.4180,
|
||||
0.4434,
|
||||
0.6445,
|
||||
0.4316,
|
||||
0.4531,
|
||||
0.5625,
|
||||
]
|
||||
),
|
||||
model_id,
|
||||
)
|
||||
else:
|
||||
# Make sure the correct error is thrown
|
||||
with self.assertRaisesRegex(ValueError, "Please downgrade"):
|
||||
quantization_config = TorchAoConfig(quant_type="fp4", modules_to_not_convert=["x_embedder"])
|
||||
|
||||
def test_int4wo_quant_bfloat16_conversion(self):
|
||||
"""
|
||||
Tests whether the dtype of model will be modified to bfloat16 for int4 weight-only quantization.
|
||||
"""
|
||||
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
|
||||
quantization_config = TorchAoConfig(Int4WeightOnlyConfig(group_size=64))
|
||||
quantized_model = FluxTransformer2DModel.from_pretrained(
|
||||
"hf-internal-testing/tiny-flux-pipe",
|
||||
subfolder="transformer",
|
||||
@@ -361,7 +308,7 @@ class TorchAoTest(unittest.TestCase):
|
||||
else:
|
||||
expected_slice = expected_slice_offload
|
||||
with tempfile.TemporaryDirectory() as offload_folder:
|
||||
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
|
||||
quantization_config = TorchAoConfig(Int4WeightOnlyConfig(group_size=64))
|
||||
quantized_model = FluxTransformer2DModel.from_pretrained(
|
||||
"hf-internal-testing/tiny-flux-pipe",
|
||||
subfolder="transformer",
|
||||
@@ -385,7 +332,7 @@ class TorchAoTest(unittest.TestCase):
|
||||
self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 2e-3)
|
||||
|
||||
with tempfile.TemporaryDirectory() as offload_folder:
|
||||
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
|
||||
quantization_config = TorchAoConfig(Int4WeightOnlyConfig(group_size=64))
|
||||
quantized_model = FluxTransformer2DModel.from_pretrained(
|
||||
"hf-internal-testing/tiny-flux-sharded",
|
||||
subfolder="transformer",
|
||||
@@ -406,7 +353,7 @@ class TorchAoTest(unittest.TestCase):
|
||||
self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 2e-3)
|
||||
|
||||
def test_modules_to_not_convert(self):
|
||||
quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"])
|
||||
quantization_config = TorchAoConfig(Int8WeightOnlyConfig(), modules_to_not_convert=["transformer_blocks.0"])
|
||||
quantized_model_with_not_convert = FluxTransformer2DModel.from_pretrained(
|
||||
"hf-internal-testing/tiny-flux-pipe",
|
||||
subfolder="transformer",
|
||||
@@ -422,7 +369,7 @@ class TorchAoTest(unittest.TestCase):
|
||||
quantized_layer = quantized_model_with_not_convert.proj_out
|
||||
self.assertTrue(isinstance(quantized_layer.weight, AffineQuantizedTensor))
|
||||
|
||||
quantization_config = TorchAoConfig("int8_weight_only")
|
||||
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
|
||||
quantized_model = FluxTransformer2DModel.from_pretrained(
|
||||
"hf-internal-testing/tiny-flux-pipe",
|
||||
subfolder="transformer",
|
||||
@@ -436,7 +383,7 @@ class TorchAoTest(unittest.TestCase):
|
||||
self.assertTrue(size_quantized < size_quantized_with_not_convert)
|
||||
|
||||
def test_training(self):
|
||||
quantization_config = TorchAoConfig("int8_weight_only")
|
||||
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
|
||||
quantized_model = FluxTransformer2DModel.from_pretrained(
|
||||
"hf-internal-testing/tiny-flux-pipe",
|
||||
subfolder="transformer",
|
||||
@@ -470,7 +417,7 @@ class TorchAoTest(unittest.TestCase):
|
||||
def test_torch_compile(self):
|
||||
r"""Test that verifies if torch.compile works with torchao quantization."""
|
||||
for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]:
|
||||
quantization_config = TorchAoConfig("int8_weight_only")
|
||||
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
|
||||
components = self.get_dummy_components(quantization_config, model_id=model_id)
|
||||
pipe = FluxPipeline(**components)
|
||||
pipe.to(device=torch_device)
|
||||
@@ -491,11 +438,15 @@ class TorchAoTest(unittest.TestCase):
|
||||
memory footprint of the converted model and the class type of the linear layers of the converted models
|
||||
"""
|
||||
for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]:
|
||||
transformer_int4wo = self.get_dummy_components(TorchAoConfig("int4wo"), model_id=model_id)["transformer"]
|
||||
transformer_int4wo = self.get_dummy_components(TorchAoConfig(Int4WeightOnlyConfig()), model_id=model_id)[
|
||||
"transformer"
|
||||
]
|
||||
transformer_int4wo_gs32 = self.get_dummy_components(
|
||||
TorchAoConfig("int4wo", group_size=32), model_id=model_id
|
||||
TorchAoConfig(Int4WeightOnlyConfig(group_size=32)), model_id=model_id
|
||||
)["transformer"]
|
||||
transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"), model_id=model_id)["transformer"]
|
||||
transformer_int8wo = self.get_dummy_components(TorchAoConfig(Int8WeightOnlyConfig()), model_id=model_id)[
|
||||
"transformer"
|
||||
]
|
||||
transformer_bf16 = self.get_dummy_components(None, model_id=model_id)["transformer"]
|
||||
|
||||
# Will not quantized all the layers by default due to the model weights shapes not being divisible by group_size=64
|
||||
@@ -553,20 +504,22 @@ class TorchAoTest(unittest.TestCase):
|
||||
unquantized_model_memory = get_memory_consumption_stat(transformer_bf16, inputs)
|
||||
del transformer_bf16
|
||||
|
||||
transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"), model_id=model_id)["transformer"]
|
||||
transformer_int8wo = self.get_dummy_components(TorchAoConfig(Int8WeightOnlyConfig()), model_id=model_id)[
|
||||
"transformer"
|
||||
]
|
||||
transformer_int8wo.to(torch_device)
|
||||
quantized_model_memory = get_memory_consumption_stat(transformer_int8wo, inputs)
|
||||
assert unquantized_model_memory / quantized_model_memory >= expected_memory_saving_ratio
|
||||
|
||||
def test_wrong_config(self):
|
||||
with self.assertRaises(ValueError):
|
||||
with self.assertRaises(TypeError):
|
||||
self.get_dummy_components(TorchAoConfig("int42"))
|
||||
|
||||
def test_sequential_cpu_offload(self):
|
||||
r"""
|
||||
A test that checks if inference runs as expected when sequential cpu offloading is enabled.
|
||||
"""
|
||||
quantization_config = TorchAoConfig("int8wo")
|
||||
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
|
||||
components = self.get_dummy_components(quantization_config)
|
||||
pipe = FluxPipeline(**components)
|
||||
pipe.enable_sequential_cpu_offload()
|
||||
@@ -595,8 +548,8 @@ class TorchAoSerializationTest(unittest.TestCase):
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def get_dummy_model(self, quant_method, quant_method_kwargs, device=None):
|
||||
quantization_config = TorchAoConfig(quant_method, **quant_method_kwargs)
|
||||
def get_dummy_model(self, quant_type, device=None):
|
||||
quantization_config = TorchAoConfig(quant_type)
|
||||
quantized_model = FluxTransformer2DModel.from_pretrained(
|
||||
self.model_name,
|
||||
subfolder="transformer",
|
||||
@@ -632,8 +585,8 @@ class TorchAoSerializationTest(unittest.TestCase):
|
||||
"timestep": timestep,
|
||||
}
|
||||
|
||||
def _test_original_model_expected_slice(self, quant_method, quant_method_kwargs, expected_slice):
|
||||
quantized_model = self.get_dummy_model(quant_method, quant_method_kwargs, torch_device)
|
||||
def _test_original_model_expected_slice(self, quant_type, expected_slice):
|
||||
quantized_model = self.get_dummy_model(quant_type, torch_device)
|
||||
inputs = self.get_dummy_tensor_inputs(torch_device)
|
||||
output = quantized_model(**inputs)[0]
|
||||
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
|
||||
@@ -641,8 +594,8 @@ class TorchAoSerializationTest(unittest.TestCase):
|
||||
self.assertTrue(isinstance(weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor)))
|
||||
self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
|
||||
|
||||
def _check_serialization_expected_slice(self, quant_method, quant_method_kwargs, expected_slice, device):
|
||||
quantized_model = self.get_dummy_model(quant_method, quant_method_kwargs, device)
|
||||
def _check_serialization_expected_slice(self, quant_type, expected_slice, device):
|
||||
quantized_model = self.get_dummy_model(quant_type, device)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
quantized_model.save_pretrained(tmp_dir, safe_serialization=False)
|
||||
@@ -662,40 +615,39 @@ class TorchAoSerializationTest(unittest.TestCase):
|
||||
self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
|
||||
|
||||
def test_int_a8w8_accelerator(self):
|
||||
quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
|
||||
quant_type = Int8DynamicActivationInt8WeightConfig()
|
||||
expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
|
||||
device = torch_device
|
||||
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
|
||||
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
|
||||
self._test_original_model_expected_slice(quant_type, expected_slice)
|
||||
self._check_serialization_expected_slice(quant_type, expected_slice, device)
|
||||
|
||||
def test_int_a16w8_accelerator(self):
|
||||
quant_method, quant_method_kwargs = "int8_weight_only", {}
|
||||
quant_type = Int8WeightOnlyConfig()
|
||||
expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
|
||||
device = torch_device
|
||||
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
|
||||
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
|
||||
self._test_original_model_expected_slice(quant_type, expected_slice)
|
||||
self._check_serialization_expected_slice(quant_type, expected_slice, device)
|
||||
|
||||
def test_int_a8w8_cpu(self):
|
||||
quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
|
||||
quant_type = Int8DynamicActivationInt8WeightConfig()
|
||||
expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
|
||||
device = "cpu"
|
||||
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
|
||||
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
|
||||
self._test_original_model_expected_slice(quant_type, expected_slice)
|
||||
self._check_serialization_expected_slice(quant_type, expected_slice, device)
|
||||
|
||||
def test_int_a16w8_cpu(self):
|
||||
quant_method, quant_method_kwargs = "int8_weight_only", {}
|
||||
quant_type = Int8WeightOnlyConfig()
|
||||
expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
|
||||
device = "cpu"
|
||||
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
|
||||
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
|
||||
self._test_original_model_expected_slice(quant_type, expected_slice)
|
||||
self._check_serialization_expected_slice(quant_type, expected_slice, device)
|
||||
|
||||
@require_torchao_version_greater_or_equal("0.9.0")
|
||||
def test_aobase_config(self):
|
||||
quant_method, quant_method_kwargs = Int8WeightOnlyConfig(), {}
|
||||
quant_type = Int8WeightOnlyConfig()
|
||||
expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
|
||||
device = torch_device
|
||||
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
|
||||
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
|
||||
self._test_original_model_expected_slice(quant_type, expected_slice)
|
||||
self._check_serialization_expected_slice(quant_type, expected_slice, device)
|
||||
|
||||
|
||||
@require_torchao_version_greater_or_equal("0.14.0")
|
||||
@@ -817,29 +769,25 @@ class SlowTorchAoTests(unittest.TestCase):
|
||||
def test_quantization(self):
|
||||
# fmt: off
|
||||
QUANTIZATION_TYPES_TO_TEST = [
|
||||
("int8wo", np.array([0.0505, 0.0742, 0.1367, 0.0429, 0.0585, 0.1386, 0.0585, 0.0703, 0.1367, 0.0566, 0.0703, 0.1464, 0.0546, 0.0703, 0.1425, 0.0546, 0.3535, 0.7578, 0.5000, 0.4062, 0.7656, 0.5117, 0.4121, 0.7656, 0.5117, 0.3984, 0.7578, 0.5234, 0.4023, 0.7382, 0.5390, 0.4570])),
|
||||
("int8dq", np.array([0.0546, 0.0761, 0.1386, 0.0488, 0.0644, 0.1425, 0.0605, 0.0742, 0.1406, 0.0625, 0.0722, 0.1523, 0.0625, 0.0742, 0.1503, 0.0605, 0.3886, 0.7968, 0.5507, 0.4492, 0.7890, 0.5351, 0.4316, 0.8007, 0.5390, 0.4179, 0.8281, 0.5820, 0.4531, 0.7812, 0.5703, 0.4921])),
|
||||
(Int8WeightOnlyConfig(), np.array([0.0505, 0.0742, 0.1367, 0.0429, 0.0585, 0.1386, 0.0585, 0.0703, 0.1367, 0.0566, 0.0703, 0.1464, 0.0546, 0.0703, 0.1425, 0.0546, 0.3535, 0.7578, 0.5000, 0.4062, 0.7656, 0.5117, 0.4121, 0.7656, 0.5117, 0.3984, 0.7578, 0.5234, 0.4023, 0.7382, 0.5390, 0.4570])),
|
||||
(Int8DynamicActivationInt8WeightConfig(), np.array([0.0546, 0.0761, 0.1386, 0.0488, 0.0644, 0.1425, 0.0605, 0.0742, 0.1406, 0.0625, 0.0722, 0.1523, 0.0625, 0.0742, 0.1503, 0.0605, 0.3886, 0.7968, 0.5507, 0.4492, 0.7890, 0.5351, 0.4316, 0.8007, 0.5390, 0.4179, 0.8281, 0.5820, 0.4531, 0.7812, 0.5703, 0.4921])),
|
||||
]
|
||||
|
||||
if TorchAoConfig._is_xpu_or_cuda_capability_atleast_8_9():
|
||||
if _is_xpu_or_cuda_capability_atleast_8_9():
|
||||
QUANTIZATION_TYPES_TO_TEST.extend([
|
||||
("float8wo_e4m3", np.array([0.0546, 0.0722, 0.1328, 0.0468, 0.0585, 0.1367, 0.0605, 0.0703, 0.1328, 0.0625, 0.0703, 0.1445, 0.0585, 0.0703, 0.1406, 0.0605, 0.3496, 0.7109, 0.4843, 0.4042, 0.7226, 0.5000, 0.4160, 0.7031, 0.4824, 0.3886, 0.6757, 0.4667, 0.3710, 0.6679, 0.4902, 0.4238])),
|
||||
(Float8WeightOnlyConfig(weight_dtype=torch.float8_e4m3fn), np.array([0.0546, 0.0722, 0.1328, 0.0468, 0.0585, 0.1367, 0.0605, 0.0703, 0.1328, 0.0625, 0.0703, 0.1445, 0.0585, 0.0703, 0.1406, 0.0605, 0.3496, 0.7109, 0.4843, 0.4042, 0.7226, 0.5000, 0.4160, 0.7031, 0.4824, 0.3886, 0.6757, 0.4667, 0.3710, 0.6679, 0.4902, 0.4238])),
|
||||
])
|
||||
if version.parse(importlib.metadata.version("torchao")) <= version.Version("0.14.1"):
|
||||
QUANTIZATION_TYPES_TO_TEST.extend([
|
||||
("fp5_e3m1", np.array([0.0527, 0.0762, 0.1309, 0.0449, 0.0645, 0.1328, 0.0566, 0.0723, 0.125, 0.0566, 0.0703, 0.1328, 0.0566, 0.0742, 0.1348, 0.0566, 0.3633, 0.7617, 0.5273, 0.4277, 0.7891, 0.5469, 0.4375, 0.8008, 0.5586, 0.4336, 0.7383, 0.5156, 0.3906, 0.6992, 0.5156, 0.4375])),
|
||||
])
|
||||
# fmt: on
|
||||
|
||||
for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST:
|
||||
quantization_config = TorchAoConfig(quant_type=quantization_name, modules_to_not_convert=["x_embedder"])
|
||||
for quant_config, expected_slice in QUANTIZATION_TYPES_TO_TEST:
|
||||
quantization_config = TorchAoConfig(quant_type=quant_config, modules_to_not_convert=["x_embedder"])
|
||||
self._test_quant_type(quantization_config, expected_slice)
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
backend_synchronize(torch_device)
|
||||
|
||||
def test_serialization_int8wo(self):
|
||||
quantization_config = TorchAoConfig("int8wo")
|
||||
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
|
||||
components = self.get_dummy_components(quantization_config)
|
||||
pipe = FluxPipeline(**components)
|
||||
pipe.enable_model_cpu_offload()
|
||||
@@ -876,7 +824,7 @@ class SlowTorchAoTests(unittest.TestCase):
|
||||
def test_memory_footprint_int4wo(self):
|
||||
# The original checkpoints are in bf16 and about 24 GB
|
||||
expected_memory_in_gb = 6.0
|
||||
quantization_config = TorchAoConfig("int4wo")
|
||||
quantization_config = TorchAoConfig(Int4WeightOnlyConfig())
|
||||
cache_dir = None
|
||||
transformer = FluxTransformer2DModel.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
@@ -891,7 +839,7 @@ class SlowTorchAoTests(unittest.TestCase):
|
||||
def test_memory_footprint_int8wo(self):
|
||||
# The original checkpoints are in bf16 and about 24 GB
|
||||
expected_memory_in_gb = 12.0
|
||||
quantization_config = TorchAoConfig("int8wo")
|
||||
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
|
||||
cache_dir = None
|
||||
transformer = FluxTransformer2DModel.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
|
||||
Reference in New Issue
Block a user