Compare commits

..

1 Commit

Author: sayakpaul
SHA1: cc605aa33e
Message: remove syncs before denoising in Kontext
Date: 2025-06-27 10:52:57 +05:30
32 changed files with 599 additions and 811 deletions

View File

@@ -75,6 +75,10 @@ jobs:
- diffusers-pytorch-cuda
- diffusers-pytorch-xformers-cuda
- diffusers-pytorch-minimum-cuda
- diffusers-flax-cpu
- diffusers-flax-tpu
- diffusers-onnxruntime-cpu
- diffusers-onnxruntime-cuda
- diffusers-doc-builder
steps:

View File

@@ -321,6 +321,55 @@ jobs:
name: torch_minimum_version_cuda_test_reports
path: reports
run_nightly_onnx_tests:
name: Nightly ONNXRuntime CUDA tests on Ubuntu
runs-on:
group: aws-g4dn-2xlarge
container:
image: diffusers/diffusers-onnxruntime-cuda
options: --gpus 0 --shm-size "16gb" --ipc host
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: NVIDIA-SMI
run: nvidia-smi
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test]
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
python -m uv pip install pytest-reportlog
- name: Environment
run: python utils/print_env.py
- name: Run Nightly ONNXRuntime CUDA tests
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "Onnx" \
--make-reports=tests_onnx_cuda \
--report-log=tests_onnx_cuda.log \
tests/
- name: Failure short reports
if: ${{ failure() }}
run: |
cat reports/tests_onnx_cuda_stats.txt
cat reports/tests_onnx_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
with:
name: tests_onnx_cuda_reports
path: reports
run_nightly_quantization_tests:
name: Torch quantization nightly tests
strategy:
@@ -436,6 +485,57 @@ jobs:
name: torch_cuda_pipeline_level_quant_reports
path: reports
run_flax_tpu_tests:
name: Nightly Flax TPU Tests
runs-on:
group: gcp-ct5lp-hightpu-8t
if: github.event_name == 'schedule'
container:
image: diffusers/diffusers-flax-tpu
options: --shm-size "16gb" --ipc host --privileged ${{ vars.V5_LITEPOD_8_ENV}} -v /mnt/hf_cache:/mnt/hf_cache
defaults:
run:
shell: bash
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test]
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
python -m uv pip install pytest-reportlog
- name: Environment
run: python utils/print_env.py
- name: Run nightly Flax TPU tests
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
python -m pytest -n 0 \
-s -v -k "Flax" \
--make-reports=tests_flax_tpu \
--report-log=tests_flax_tpu.log \
tests/
- name: Failure short reports
if: ${{ failure() }}
run: |
cat reports/tests_flax_tpu_stats.txt
cat reports/tests_flax_tpu_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
with:
name: flax_tpu_test_reports
path: reports
generate_consolidated_report:
name: Generate Consolidated Test Report
needs: [
@@ -445,9 +545,9 @@ jobs:
run_big_gpu_torch_tests,
run_nightly_quantization_tests,
run_nightly_pipeline_level_quantization_tests,
# run_nightly_onnx_tests,
run_nightly_onnx_tests,
torch_minimum_version_cuda_tests,
# run_flax_tpu_tests
run_flax_tpu_tests
]
if: always()
runs-on:

View File

@@ -87,6 +87,11 @@ jobs:
runner: aws-general-8-plus
image: diffusers/diffusers-pytorch-cpu
report: torch_cpu_models_schedulers
- name: Fast Flax CPU tests
framework: flax
runner: aws-general-8-plus
image: diffusers/diffusers-flax-cpu
report: flax_cpu
- name: PyTorch Example CPU tests
framework: pytorch_examples
runner: aws-general-8-plus
@@ -142,6 +147,15 @@ jobs:
--make-reports=tests_${{ matrix.config.report }} \
tests/models tests/schedulers tests/others
- name: Run fast Flax TPU tests
if: ${{ matrix.config.framework == 'flax' }}
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
-s -v -k "Flax" \
--make-reports=tests_${{ matrix.config.report }} \
tests
- name: Run example PyTorch CPU tests
if: ${{ matrix.config.framework == 'pytorch_examples' }}
run: |

View File

@@ -159,6 +159,102 @@ jobs:
name: torch_cuda_test_reports_${{ matrix.module }}
path: reports
flax_tpu_tests:
name: Flax TPU Tests
runs-on:
group: gcp-ct5lp-hightpu-8t
container:
image: diffusers/diffusers-flax-tpu
options: --shm-size "16gb" --ipc host --privileged ${{ vars.V5_LITEPOD_8_ENV}} -v /mnt/hf_cache:/mnt/hf_cache
defaults:
run:
shell: bash
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test]
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
- name: Environment
run: |
python utils/print_env.py
- name: Run Flax TPU tests
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
python -m pytest -n 0 \
-s -v -k "Flax" \
--make-reports=tests_flax_tpu \
tests/
- name: Failure short reports
if: ${{ failure() }}
run: |
cat reports/tests_flax_tpu_stats.txt
cat reports/tests_flax_tpu_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
with:
name: flax_tpu_test_reports
path: reports
onnx_cuda_tests:
name: ONNX CUDA Tests
runs-on:
group: aws-g4dn-2xlarge
container:
image: diffusers/diffusers-onnxruntime-cuda
options: --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/ --gpus 0
defaults:
run:
shell: bash
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test]
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
- name: Environment
run: |
python utils/print_env.py
- name: Run ONNXRuntime CUDA tests
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "Onnx" \
--make-reports=tests_onnx_cuda \
tests/
- name: Failure short reports
if: ${{ failure() }}
run: |
cat reports/tests_onnx_cuda_stats.txt
cat reports/tests_onnx_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
with:
name: onnx_cuda_test_reports
path: reports
run_torch_compile_tests:
name: PyTorch Compile CUDA tests

View File

@@ -33,6 +33,16 @@ jobs:
runner: aws-general-8-plus
image: diffusers/diffusers-pytorch-cpu
report: torch_cpu
- name: Fast Flax CPU tests on Ubuntu
framework: flax
runner: aws-general-8-plus
image: diffusers/diffusers-flax-cpu
report: flax_cpu
- name: Fast ONNXRuntime CPU tests on Ubuntu
framework: onnxruntime
runner: aws-general-8-plus
image: diffusers/diffusers-onnxruntime-cpu
report: onnx_cpu
- name: PyTorch Example CPU tests on Ubuntu
framework: pytorch_examples
runner: aws-general-8-plus
@@ -77,6 +87,24 @@ jobs:
--make-reports=tests_${{ matrix.config.report }} \
tests/
- name: Run fast Flax TPU tests
if: ${{ matrix.config.framework == 'flax' }}
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
-s -v -k "Flax" \
--make-reports=tests_${{ matrix.config.report }} \
tests/
- name: Run fast ONNXRuntime CPU tests
if: ${{ matrix.config.framework == 'onnxruntime' }}
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
-s -v -k "Onnx" \
--make-reports=tests_${{ matrix.config.report }} \
tests/
- name: Run example PyTorch CPU tests
if: ${{ matrix.config.framework == 'pytorch_examples' }}
run: |

View File

@@ -1,7 +1,12 @@
name: Fast mps tests on main
on:
workflow_dispatch:
push:
branches:
- main
paths:
- "src/diffusers/**.py"
- "tests/**.py"
env:
DIFFUSERS_IS_CI: yes

View File

@@ -213,6 +213,101 @@ jobs:
with:
name: torch_minimum_version_cuda_test_reports
path: reports
flax_tpu_tests:
name: Flax TPU Tests
runs-on: docker-tpu
container:
image: diffusers/diffusers-flax-tpu
options: --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/ --privileged
defaults:
run:
shell: bash
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test]
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
- name: Environment
run: |
python utils/print_env.py
- name: Run slow Flax TPU tests
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
python -m pytest -n 0 \
-s -v -k "Flax" \
--make-reports=tests_flax_tpu \
tests/
- name: Failure short reports
if: ${{ failure() }}
run: |
cat reports/tests_flax_tpu_stats.txt
cat reports/tests_flax_tpu_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
with:
name: flax_tpu_test_reports
path: reports
onnx_cuda_tests:
name: ONNX CUDA Tests
runs-on:
group: aws-g4dn-2xlarge
container:
image: diffusers/diffusers-onnxruntime-cuda
options: --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/ --gpus 0
defaults:
run:
shell: bash
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test]
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
- name: Environment
run: |
python utils/print_env.py
- name: Run slow ONNXRuntime CUDA tests
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "Onnx" \
--make-reports=tests_onnx_cuda \
tests/
- name: Failure short reports
if: ${{ failure() }}
run: |
cat reports/tests_onnx_cuda_stats.txt
cat reports/tests_onnx_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
with:
name: onnx_cuda_test_reports
path: reports
run_torch_compile_tests:
name: PyTorch Compile CUDA tests

View File

@@ -24,31 +24,6 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)
</Tip>
## Loading original format checkpoints
Original format checkpoints that have not been converted to diffusers-expected format can be loaded using the `from_single_file` method.
```python
import torch
from diffusers import Cosmos2TextToImagePipeline, CosmosTransformer3DModel
model_id = "nvidia/Cosmos-Predict2-2B-Text2Image"
transformer = CosmosTransformer3DModel.from_single_file(
"https://huggingface.co/nvidia/Cosmos-Predict2-2B-Text2Image/blob/main/model.pt",
torch_dtype=torch.bfloat16,
).to("cuda")
pipe = Cosmos2TextToImagePipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
pipe.to("cuda")
prompt = "A close-up shot captures a vibrant yellow scrubber vigorously working on a grimy plate, its bristles moving in circular motions to lift stubborn grease and food residue. The dish, once covered in remnants of a hearty meal, gradually reveals its original glossy surface. Suds form and bubble around the scrubber, creating a satisfying visual of cleanliness in progress. The sound of scrubbing fills the air, accompanied by the gentle clinking of the dish against the sink. As the scrubber continues its task, the dish transforms, gleaming under the bright kitchen lights, symbolizing the triumph of cleanliness over mess."
negative_prompt = "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality."
output = pipe(
prompt=prompt, negative_prompt=negative_prompt, generator=torch.Generator().manual_seed(1)
).images[0]
output.save("output.png")
```
## CosmosTextToWorldPipeline
[[autodoc]] CosmosTextToWorldPipeline

View File

@@ -302,12 +302,12 @@ The general rule of thumb to keep in mind when preparing inputs for the VACE pip
```py
# pip install ftfy
import torch
from diffusers import WanPipeline, WanTransformer3DModel, AutoencoderKLWan
from diffusers import WanPipeline, AutoModel
vae = AutoencoderKLWan.from_single_file(
vae = AutoModel.from_single_file(
"https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/vae/wan_2.1_vae.safetensors"
)
transformer = WanTransformer3DModel.from_single_file(
transformer = AutoModel.from_single_file(
"https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/diffusion_models/wan2.1_t2v_1.3B_bf16.safetensors",
torch_dtype=torch.bfloat16
)

View File

@@ -315,8 +315,6 @@ pipeline.load_lora_weights(
> [!TIP]
> Move your code inside the `with torch._dynamo.config.patch(error_on_recompile=True)` context manager to detect if a model was recompiled. If a model is recompiled despite following all the steps above, please open an [issue](https://github.com/huggingface/diffusers/issues) with a reproducible example.
If you expect varied resolutions during inference with this feature, make sure to set `dynamic=True` during compilation. Refer to [this document](../optimization/fp16#dynamic-shape-compilation) for more details.
There are still scenarios where recompilation is unavoidable, such as when the hotswapped LoRA targets more layers than the initial adapter. Try to load the LoRA that targets the most layers *first*. For more details about this limitation, refer to the PEFT [hotswapping](https://huggingface.co/docs/peft/main/en/package_reference/hotswap#peft.utils.hotswap.hotswap_adapter) docs.
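To make the guidance above concrete, here is a minimal sketch of hotswapping with recompilation detection. It assumes two hypothetical LoRA repositories ("user/lora-a" and "user/lora-b") and that `load_lora_weights(..., hotswap=True)` is available as described in this guide; it is an illustration, not the canonical recipe.

```python
# Minimal sketch: hotswap a LoRA under torch.compile and detect unwanted recompiles.
# "user/lora-a" and "user/lora-b" are hypothetical repo ids.
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Load the first adapter, then compile once. dynamic=True avoids recompiles when
# the inference resolution varies between calls.
pipeline.load_lora_weights("user/lora-a", adapter_name="default")
pipeline.transformer.compile(dynamic=True)
_ = pipeline("warmup prompt", num_inference_steps=2).images[0]

# Hotswap the second adapter; error_on_recompile raises if the swap triggers a recompile.
with torch._dynamo.config.patch(error_on_recompile=True):
    pipeline.load_lora_weights("user/lora-b", adapter_name="default", hotswap=True)
    image = pipeline("a prompt", num_inference_steps=4).images[0]
```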
## Merge

View File

@@ -95,6 +95,7 @@ TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0 = {
"mlp.layer1": "ff.net.0.proj",
"mlp.layer2": "ff.net.2",
"x_embedder.proj.1": "patch_embed.proj",
# "extra_pos_embedder": "learnable_pos_embed",
"final_layer.adaln_modulation.1": "norm_out.linear_1",
"final_layer.adaln_modulation.2": "norm_out.linear_2",
"final_layer.linear": "proj_out",

View File

@@ -14,8 +14,6 @@
import os
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Optional, Set, Tuple, Union
import safetensors.torch
@@ -48,24 +46,6 @@ _SUPPORTED_PYTORCH_LAYERS = (
# fmt: on
class GroupOffloadingType(str, Enum):
BLOCK_LEVEL = "block_level"
LEAF_LEVEL = "leaf_level"
@dataclass
class GroupOffloadingConfig:
onload_device: torch.device
offload_device: torch.device
offload_type: GroupOffloadingType
non_blocking: bool
record_stream: bool
low_cpu_mem_usage: bool
num_blocks_per_group: Optional[int] = None
offload_to_disk_path: Optional[str] = None
stream: Optional[Union[torch.cuda.Stream, torch.Stream]] = None
class ModuleGroup:
def __init__(
self,
@@ -308,12 +288,9 @@ class GroupOffloadingHook(ModelHook):
_is_stateful = False
def __init__(
self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None, *, config: GroupOffloadingConfig
) -> None:
def __init__(self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None) -> None:
self.group = group
self.next_group = next_group
self.config = config
def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
if self.group.offload_leader == module:
@@ -459,7 +436,7 @@ def apply_group_offloading(
module: torch.nn.Module,
onload_device: torch.device,
offload_device: torch.device = torch.device("cpu"),
offload_type: Union[str, GroupOffloadingType] = "block_level",
offload_type: str = "block_level",
num_blocks_per_group: Optional[int] = None,
non_blocking: bool = False,
use_stream: bool = False,
@@ -501,7 +478,7 @@ def apply_group_offloading(
The device to which the group of modules are onloaded.
offload_device (`torch.device`, defaults to `torch.device("cpu")`):
The device to which the group of modules are offloaded. This should typically be the CPU. Default is CPU.
offload_type (`str` or `GroupOffloadingType`, defaults to "block_level"):
offload_type (`str`, defaults to "block_level"):
The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is
"block_level".
offload_to_disk_path (`str`, *optional*, defaults to `None`):
@@ -544,8 +521,6 @@ def apply_group_offloading(
```
"""
offload_type = GroupOffloadingType(offload_type)
stream = None
if use_stream:
if torch.cuda.is_available():
@@ -557,45 +532,84 @@ def apply_group_offloading(
if not use_stream and record_stream:
raise ValueError("`record_stream` cannot be True when `use_stream=False`.")
if offload_type == GroupOffloadingType.BLOCK_LEVEL and num_blocks_per_group is None:
raise ValueError("`num_blocks_per_group` must be provided when using `offload_type='block_level'.")
_raise_error_if_accelerate_model_or_sequential_hook_present(module)
config = GroupOffloadingConfig(
onload_device=onload_device,
offload_device=offload_device,
offload_type=offload_type,
num_blocks_per_group=num_blocks_per_group,
non_blocking=non_blocking,
stream=stream,
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
offload_to_disk_path=offload_to_disk_path,
)
_apply_group_offloading(module, config)
if offload_type == "block_level":
if num_blocks_per_group is None:
raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")
def _apply_group_offloading(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
if config.offload_type == GroupOffloadingType.BLOCK_LEVEL:
_apply_group_offloading_block_level(module, config)
elif config.offload_type == GroupOffloadingType.LEAF_LEVEL:
_apply_group_offloading_leaf_level(module, config)
_apply_group_offloading_block_level(
module=module,
num_blocks_per_group=num_blocks_per_group,
offload_device=offload_device,
onload_device=onload_device,
offload_to_disk_path=offload_to_disk_path,
non_blocking=non_blocking,
stream=stream,
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
)
elif offload_type == "leaf_level":
_apply_group_offloading_leaf_level(
module=module,
offload_device=offload_device,
onload_device=onload_device,
offload_to_disk_path=offload_to_disk_path,
non_blocking=non_blocking,
stream=stream,
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
)
else:
assert False
raise ValueError(f"Unsupported offload_type: {offload_type}")
def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
def _apply_group_offloading_block_level(
module: torch.nn.Module,
num_blocks_per_group: int,
offload_device: torch.device,
onload_device: torch.device,
non_blocking: bool,
stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
record_stream: Optional[bool] = False,
low_cpu_mem_usage: bool = False,
offload_to_disk_path: Optional[str] = None,
) -> None:
r"""
This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
the "leaf_level" offloading, which is more fine-grained, this offloading is done at the top-level blocks.
"""
if config.stream is not None and config.num_blocks_per_group != 1:
Args:
module (`torch.nn.Module`):
The module to which group offloading is applied.
offload_device (`torch.device`):
The device to which the group of modules are offloaded. This should typically be the CPU.
offload_to_disk_path (`str`, *optional*, defaults to `None`):
The path to the directory where parameters will be offloaded. Setting this option can be useful in limited
RAM environment settings where a reasonable speed-memory trade-off is desired.
onload_device (`torch.device`):
The device to which the group of modules are onloaded.
non_blocking (`bool`):
If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
and data transfer.
stream (`torch.cuda.Stream`or `torch.Stream`, *optional*):
If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
for overlapping computation and data transfer.
record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the
[PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) for more details.
low_cpu_mem_usage (`bool`, defaults to `False`):
If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
the CPU memory is a bottleneck but may counteract the benefits of using streams.
"""
if stream is not None and num_blocks_per_group != 1:
logger.warning(
f"Using streams is only supported for num_blocks_per_group=1. Got {config.num_blocks_per_group=}. Setting it to 1."
f"Using streams is only supported for num_blocks_per_group=1. Got {num_blocks_per_group=}. Setting it to 1."
)
config.num_blocks_per_group = 1
num_blocks_per_group = 1
# Create module groups for ModuleList and Sequential blocks
modules_with_group_offloading = set()
@@ -607,19 +621,19 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
modules_with_group_offloading.add(name)
continue
for i in range(0, len(submodule), config.num_blocks_per_group):
current_modules = submodule[i : i + config.num_blocks_per_group]
for i in range(0, len(submodule), num_blocks_per_group):
current_modules = submodule[i : i + num_blocks_per_group]
group = ModuleGroup(
modules=current_modules,
offload_device=config.offload_device,
onload_device=config.onload_device,
offload_to_disk_path=config.offload_to_disk_path,
offload_device=offload_device,
onload_device=onload_device,
offload_to_disk_path=offload_to_disk_path,
offload_leader=current_modules[-1],
onload_leader=current_modules[0],
non_blocking=config.non_blocking,
stream=config.stream,
record_stream=config.record_stream,
low_cpu_mem_usage=config.low_cpu_mem_usage,
non_blocking=non_blocking,
stream=stream,
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
onload_self=True,
)
matched_module_groups.append(group)
@@ -629,7 +643,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
# Apply group offloading hooks to the module groups
for i, group in enumerate(matched_module_groups):
for group_module in group.modules:
_apply_group_offloading_hook(group_module, group, None, config=config)
_apply_group_offloading_hook(group_module, group, None)
# Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
# when the forward pass of this module is called. This is because the top-level module is not
@@ -644,9 +658,9 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules]
unmatched_group = ModuleGroup(
modules=unmatched_modules,
offload_device=config.offload_device,
onload_device=config.onload_device,
offload_to_disk_path=config.offload_to_disk_path,
offload_device=offload_device,
onload_device=onload_device,
offload_to_disk_path=offload_to_disk_path,
offload_leader=module,
onload_leader=module,
parameters=parameters,
@@ -656,19 +670,54 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
record_stream=False,
onload_self=True,
)
if config.stream is None:
_apply_group_offloading_hook(module, unmatched_group, None, config=config)
if stream is None:
_apply_group_offloading_hook(module, unmatched_group, None)
else:
_apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)
_apply_lazy_group_offloading_hook(module, unmatched_group, None)
def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
def _apply_group_offloading_leaf_level(
module: torch.nn.Module,
offload_device: torch.device,
onload_device: torch.device,
non_blocking: bool,
stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
record_stream: Optional[bool] = False,
low_cpu_mem_usage: bool = False,
offload_to_disk_path: Optional[str] = None,
) -> None:
r"""
This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
requirements. However, it can be slower compared to other offloading methods due to the excessive number of device
synchronizations. When using devices that support streams to overlap data transfer and computation, this method can
reduce memory usage without any performance degradation.
Args:
module (`torch.nn.Module`):
The module to which group offloading is applied.
offload_device (`torch.device`):
The device to which the group of modules are offloaded. This should typically be the CPU.
onload_device (`torch.device`):
The device to which the group of modules are onloaded.
offload_to_disk_path (`str`, *optional*, defaults to `None`):
The path to the directory where parameters will be offloaded. Setting this option can be useful in limited
RAM environment settings where a reasonable speed-memory trade-off is desired.
non_blocking (`bool`):
If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
and data transfer.
stream (`torch.cuda.Stream` or `torch.Stream`, *optional*):
If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
for overlapping computation and data transfer.
record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the
[PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) for more details.
low_cpu_mem_usage (`bool`, defaults to `False`):
If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
the CPU memory is a bottleneck but may counteract the benefits of using streams.
"""
# Create module groups for leaf modules and apply group offloading hooks
modules_with_group_offloading = set()
for name, submodule in module.named_modules():
@@ -676,18 +725,18 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
continue
group = ModuleGroup(
modules=[submodule],
offload_device=config.offload_device,
onload_device=config.onload_device,
offload_to_disk_path=config.offload_to_disk_path,
offload_device=offload_device,
onload_device=onload_device,
offload_to_disk_path=offload_to_disk_path,
offload_leader=submodule,
onload_leader=submodule,
non_blocking=config.non_blocking,
stream=config.stream,
record_stream=config.record_stream,
low_cpu_mem_usage=config.low_cpu_mem_usage,
non_blocking=non_blocking,
stream=stream,
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
onload_self=True,
)
_apply_group_offloading_hook(submodule, group, None, config=config)
_apply_group_offloading_hook(submodule, group, None)
modules_with_group_offloading.add(name)
# Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass
@@ -718,32 +767,33 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
parameters = parent_to_parameters.get(name, [])
buffers = parent_to_buffers.get(name, [])
parent_module = module_dict[name]
assert getattr(parent_module, "_diffusers_hook", None) is None
group = ModuleGroup(
modules=[],
offload_device=config.offload_device,
onload_device=config.onload_device,
offload_device=offload_device,
onload_device=onload_device,
offload_leader=parent_module,
onload_leader=parent_module,
offload_to_disk_path=config.offload_to_disk_path,
offload_to_disk_path=offload_to_disk_path,
parameters=parameters,
buffers=buffers,
non_blocking=config.non_blocking,
stream=config.stream,
record_stream=config.record_stream,
low_cpu_mem_usage=config.low_cpu_mem_usage,
non_blocking=non_blocking,
stream=stream,
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
onload_self=True,
)
_apply_group_offloading_hook(parent_module, group, None, config=config)
_apply_group_offloading_hook(parent_module, group, None)
if config.stream is not None:
if stream is not None:
# When using streams, we need to know the layer execution order for applying prefetching (to overlap data transfer
# and computation). Since we don't know the order beforehand, we apply a lazy prefetching hook that will find the
# execution order and apply prefetching in the correct order.
unmatched_group = ModuleGroup(
modules=[],
offload_device=config.offload_device,
onload_device=config.onload_device,
offload_to_disk_path=config.offload_to_disk_path,
offload_device=offload_device,
onload_device=onload_device,
offload_to_disk_path=offload_to_disk_path,
offload_leader=module,
onload_leader=module,
parameters=None,
@@ -751,25 +801,23 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
non_blocking=False,
stream=None,
record_stream=False,
low_cpu_mem_usage=config.low_cpu_mem_usage,
low_cpu_mem_usage=low_cpu_mem_usage,
onload_self=True,
)
_apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)
_apply_lazy_group_offloading_hook(module, unmatched_group, None)
def _apply_group_offloading_hook(
module: torch.nn.Module,
group: ModuleGroup,
next_group: Optional[ModuleGroup] = None,
*,
config: GroupOffloadingConfig,
) -> None:
registry = HookRegistry.check_if_exists_or_initialize(module)
# We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
# is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
if registry.get_hook(_GROUP_OFFLOADING) is None:
hook = GroupOffloadingHook(group, next_group, config=config)
hook = GroupOffloadingHook(group, next_group)
registry.register_hook(hook, _GROUP_OFFLOADING)
@@ -777,15 +825,13 @@ def _apply_lazy_group_offloading_hook(
module: torch.nn.Module,
group: ModuleGroup,
next_group: Optional[ModuleGroup] = None,
*,
config: GroupOffloadingConfig,
) -> None:
registry = HookRegistry.check_if_exists_or_initialize(module)
# We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
# is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
if registry.get_hook(_GROUP_OFFLOADING) is None:
hook = GroupOffloadingHook(group, next_group, config=config)
hook = GroupOffloadingHook(group, next_group)
registry.register_hook(hook, _GROUP_OFFLOADING)
lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook()
@@ -852,48 +898,15 @@ def _raise_error_if_accelerate_model_or_sequential_hook_present(module: torch.nn
)
def _get_top_level_group_offload_hook(module: torch.nn.Module) -> Optional[GroupOffloadingHook]:
for submodule in module.modules():
if hasattr(submodule, "_diffusers_hook"):
group_offloading_hook = submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING)
if group_offloading_hook is not None:
return group_offloading_hook
return None
def _is_group_offload_enabled(module: torch.nn.Module) -> bool:
top_level_group_offload_hook = _get_top_level_group_offload_hook(module)
return top_level_group_offload_hook is not None
for submodule in module.modules():
if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None:
return True
return False
def _get_group_onload_device(module: torch.nn.Module) -> torch.device:
top_level_group_offload_hook = _get_top_level_group_offload_hook(module)
if top_level_group_offload_hook is not None:
return top_level_group_offload_hook.config.onload_device
for submodule in module.modules():
if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None:
return submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING).group.onload_device
raise ValueError("Group offloading is not enabled for the provided module.")
def _maybe_remove_and_reapply_group_offloading(module: torch.nn.Module) -> None:
r"""
Removes the group offloading hook from the module and re-applies it. This is useful when the module has been
modified in-place and the group offloading hook references-to-tensors needs to be updated. The in-place
modification can happen in a number of ways, for example, fusing QKV or unloading/loading LoRAs on-the-fly.
In this implementation, we make an assumption that group offloading has only been applied at the top-level module,
and therefore all submodules have the same onload and offload devices. If this assumption is not true, say in the
case where the user has applied group offloading at multiple levels, this function will not work as expected.
There is some performance penalty associated with doing this when non-default streams are used, because we need to
retrace the execution order of the layers with `LazyPrefetchGroupOffloadingHook`.
"""
top_level_group_offload_hook = _get_top_level_group_offload_hook(module)
if top_level_group_offload_hook is None:
return
registry = HookRegistry.check_if_exists_or_initialize(module)
registry.remove_hook(_GROUP_OFFLOADING, recurse=True)
registry.remove_hook(_LAYER_EXECUTION_TRACKER, recurse=True)
registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING, recurse=True)
_apply_group_offloading(module, top_level_group_offload_hook.config)
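For orientation, a minimal usage sketch of the public entry point touched by this diff, `apply_group_offloading`. It assumes the function is re-exported from `diffusers.hooks` (as the other files in this diff import it from `..hooks.group_offloading`) and uses a hypothetical model repository id.

```python
# Minimal sketch of block-level group offloading, assuming a CUDA machine.
# "some-org/some-model" is a placeholder repo id.
import torch
from diffusers import AutoModel
from diffusers.hooks import apply_group_offloading  # assumed re-export of this module's entry point

transformer = AutoModel.from_pretrained(
    "some-org/some-model", subfolder="transformer", torch_dtype=torch.bfloat16
)

apply_group_offloading(
    transformer,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",   # or "leaf_level" for finer-grained offloading
    num_blocks_per_group=1,       # required for block_level; forced to 1 when streams are used
    use_stream=True,              # overlap data transfer with compute on CUDA
    record_stream=False,
)
```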

View File

@@ -25,7 +25,6 @@ import torch.nn as nn
from huggingface_hub import model_info
from huggingface_hub.constants import HF_HUB_OFFLINE
from ..hooks.group_offloading import _is_group_offload_enabled, _maybe_remove_and_reapply_group_offloading
from ..models.modeling_utils import ModelMixin, load_state_dict
from ..utils import (
USE_PEFT_BACKEND,
@@ -392,9 +391,7 @@ def _load_lora_into_text_encoder(
adapter_name = get_adapter_name(text_encoder)
# <Unsafe code
is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = _func_optionally_disable_offloading(
_pipeline
)
is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
# inject LoRA layers and load the state dict
# in transformers we automatically check whether the adapter name is already in use or not
text_encoder.load_adapter(
@@ -413,10 +410,6 @@ def _load_lora_into_text_encoder(
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
elif is_group_offload:
for component in _pipeline.components.values():
if isinstance(component, torch.nn.Module):
_maybe_remove_and_reapply_group_offloading(component)
# Unsafe code />
if prefix is not None and not state_dict:
@@ -440,36 +433,30 @@ def _func_optionally_disable_offloading(_pipeline):
Returns:
tuple:
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` or `is_group_offload` is True.
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
"""
is_model_cpu_offload = False
is_sequential_cpu_offload = False
is_group_offload = False
if _pipeline is not None and _pipeline.hf_device_map is None:
for _, component in _pipeline.components.items():
if not isinstance(component, nn.Module):
continue
is_group_offload = is_group_offload or _is_group_offload_enabled(component)
if not hasattr(component, "_hf_hook"):
continue
is_model_cpu_offload = is_model_cpu_offload or isinstance(component._hf_hook, CpuOffload)
is_sequential_cpu_offload = is_sequential_cpu_offload or (
isinstance(component._hf_hook, AlignDevicesHook)
or hasattr(component._hf_hook, "hooks")
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
)
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
if not is_model_cpu_offload:
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
if not is_sequential_cpu_offload:
is_sequential_cpu_offload = (
isinstance(component._hf_hook, AlignDevicesHook)
or hasattr(component._hf_hook, "hooks")
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
)
if is_sequential_cpu_offload or is_model_cpu_offload:
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
for _, component in _pipeline.components.items():
if not isinstance(component, nn.Module) or not hasattr(component, "_hf_hook"):
continue
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
if is_sequential_cpu_offload or is_model_cpu_offload:
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
return (is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload)
return (is_model_cpu_offload, is_sequential_cpu_offload)
class LoraBaseMixin:

View File

@@ -22,7 +22,6 @@ from typing import Dict, List, Literal, Optional, Union
import safetensors
import torch
from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
from ..utils import (
MIN_PEFT_VERSION,
USE_PEFT_BACKEND,
@@ -244,29 +243,20 @@ class PeftAdapterMixin:
k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
}
# create LoraConfig
lora_config = _create_lora_config(state_dict, network_alphas, metadata, rank)
# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(self)
# create LoraConfig
lora_config = _create_lora_config(
state_dict,
network_alphas,
metadata,
rank,
model_state_dict=self.state_dict(),
adapter_name=adapter_name,
)
# <Unsafe code
# We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
# Now we remove any existing hooks to `_pipeline`.
# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
# otherwise loading LoRA weights will lead to an error.
is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading(
_pipeline
)
is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
peft_kwargs = {}
if is_peft_version(">=", "0.13.1"):
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
@@ -357,10 +347,6 @@ class PeftAdapterMixin:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
elif is_group_offload:
for component in _pipeline.components.values():
if isinstance(component, torch.nn.Module):
_maybe_remove_and_reapply_group_offloading(component)
# Unsafe code />
if prefix is not None and not state_dict:
@@ -700,10 +686,6 @@ class PeftAdapterMixin:
recurse_remove_peft_layers(self)
if hasattr(self, "peft_config"):
del self.peft_config
if hasattr(self, "_hf_peft_config_loaded"):
self._hf_peft_config_loaded = None
_maybe_remove_and_reapply_group_offloading(self)
def disable_lora(self):
"""

View File

@@ -31,7 +31,6 @@ from .single_file_utils import (
convert_autoencoder_dc_checkpoint_to_diffusers,
convert_chroma_transformer_checkpoint_to_diffusers,
convert_controlnet_checkpoint,
convert_cosmos_transformer_checkpoint_to_diffusers,
convert_flux_transformer_checkpoint_to_diffusers,
convert_hidream_transformer_to_diffusers,
convert_hunyuan_video_transformer_to_diffusers,
@@ -144,10 +143,6 @@ SINGLE_FILE_LOADABLE_CLASSES = {
"checkpoint_mapping_fn": convert_hidream_transformer_to_diffusers,
"default_subfolder": "transformer",
},
"CosmosTransformer3DModel": {
"checkpoint_mapping_fn": convert_cosmos_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
}

View File

@@ -127,16 +127,6 @@ CHECKPOINT_KEY_NAMES = {
"wan": ["model.diffusion_model.head.modulation", "head.modulation"],
"wan_vae": "decoder.middle.0.residual.0.gamma",
"hidream": "double_stream_blocks.0.block.adaLN_modulation.1.bias",
"cosmos-1.0": [
"net.x_embedder.proj.1.weight",
"net.blocks.block1.blocks.0.block.attn.to_q.0.weight",
"net.extra_pos_embedder.pos_emb_h",
],
"cosmos-2.0": [
"net.x_embedder.proj.1.weight",
"net.blocks.0.self_attn.q_proj.weight",
"net.pos_embedder.dim_spatial_range",
],
}
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -203,14 +193,6 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
"wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"},
"wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"},
"hidream": {"pretrained_model_name_or_path": "HiDream-ai/HiDream-I1-Dev"},
"cosmos-1.0-t2w-7B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-7B-Text2World"},
"cosmos-1.0-t2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-14B-Text2World"},
"cosmos-1.0-v2w-7B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-7B-Video2World"},
"cosmos-1.0-v2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-14B-Video2World"},
"cosmos-2.0-t2i-2B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-2B-Text2Image"},
"cosmos-2.0-t2i-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-14B-Text2Image"},
"cosmos-2.0-v2w-2B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-2B-Video2World"},
"cosmos-2.0-v2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-14B-Video2World"},
}
# Use to configure model sample size when original config is provided
@@ -722,32 +704,11 @@ def infer_diffusers_model_type(checkpoint):
model_type = "wan-t2v-14B"
else:
model_type = "wan-i2v-14B"
elif CHECKPOINT_KEY_NAMES["wan_vae"] in checkpoint:
# All Wan models use the same VAE so we can use the same default model repo to fetch the config
model_type = "wan-t2v-14B"
elif CHECKPOINT_KEY_NAMES["hidream"] in checkpoint:
model_type = "hidream"
elif all(key in checkpoint for key in CHECKPOINT_KEY_NAMES["cosmos-1.0"]):
x_embedder_shape = checkpoint[CHECKPOINT_KEY_NAMES["cosmos-1.0"][0]].shape
if x_embedder_shape[1] == 68:
model_type = "cosmos-1.0-t2w-7B" if x_embedder_shape[0] == 4096 else "cosmos-1.0-t2w-14B"
elif x_embedder_shape[1] == 72:
model_type = "cosmos-1.0-v2w-7B" if x_embedder_shape[0] == 4096 else "cosmos-1.0-v2w-14B"
else:
raise ValueError(f"Unexpected x_embedder shape: {x_embedder_shape} when loading Cosmos 1.0 model.")
elif all(key in checkpoint for key in CHECKPOINT_KEY_NAMES["cosmos-2.0"]):
x_embedder_shape = checkpoint[CHECKPOINT_KEY_NAMES["cosmos-2.0"][0]].shape
if x_embedder_shape[1] == 68:
model_type = "cosmos-2.0-t2i-2B" if x_embedder_shape[0] == 2048 else "cosmos-2.0-t2i-14B"
elif x_embedder_shape[1] == 72:
model_type = "cosmos-2.0-v2w-2B" if x_embedder_shape[0] == 2048 else "cosmos-2.0-v2w-14B"
else:
raise ValueError(f"Unexpected x_embedder shape: {x_embedder_shape} when loading Cosmos 2.0 model.")
else:
model_type = "v1"
@@ -3518,116 +3479,3 @@ def convert_chroma_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
return converted_state_dict
def convert_cosmos_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
def remove_keys_(key: str, state_dict):
state_dict.pop(key)
def rename_transformer_blocks_(key: str, state_dict):
block_index = int(key.split(".")[1].removeprefix("block"))
new_key = key
old_prefix = f"blocks.block{block_index}"
new_prefix = f"transformer_blocks.{block_index}"
new_key = new_prefix + new_key.removeprefix(old_prefix)
state_dict[new_key] = state_dict.pop(key)
TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0 = {
"t_embedder.1": "time_embed.t_embedder",
"affline_norm": "time_embed.norm",
".blocks.0.block.attn": ".attn1",
".blocks.1.block.attn": ".attn2",
".blocks.2.block": ".ff",
".blocks.0.adaLN_modulation.1": ".norm1.linear_1",
".blocks.0.adaLN_modulation.2": ".norm1.linear_2",
".blocks.1.adaLN_modulation.1": ".norm2.linear_1",
".blocks.1.adaLN_modulation.2": ".norm2.linear_2",
".blocks.2.adaLN_modulation.1": ".norm3.linear_1",
".blocks.2.adaLN_modulation.2": ".norm3.linear_2",
"to_q.0": "to_q",
"to_q.1": "norm_q",
"to_k.0": "to_k",
"to_k.1": "norm_k",
"to_v.0": "to_v",
"layer1": "net.0.proj",
"layer2": "net.2",
"proj.1": "proj",
"x_embedder": "patch_embed",
"extra_pos_embedder": "learnable_pos_embed",
"final_layer.adaLN_modulation.1": "norm_out.linear_1",
"final_layer.adaLN_modulation.2": "norm_out.linear_2",
"final_layer.linear": "proj_out",
}
TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0 = {
"blocks.block": rename_transformer_blocks_,
"logvar.0.freqs": remove_keys_,
"logvar.0.phases": remove_keys_,
"logvar.1.weight": remove_keys_,
"pos_embedder.seq": remove_keys_,
}
TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0 = {
"t_embedder.1": "time_embed.t_embedder",
"t_embedding_norm": "time_embed.norm",
"blocks": "transformer_blocks",
"adaln_modulation_self_attn.1": "norm1.linear_1",
"adaln_modulation_self_attn.2": "norm1.linear_2",
"adaln_modulation_cross_attn.1": "norm2.linear_1",
"adaln_modulation_cross_attn.2": "norm2.linear_2",
"adaln_modulation_mlp.1": "norm3.linear_1",
"adaln_modulation_mlp.2": "norm3.linear_2",
"self_attn": "attn1",
"cross_attn": "attn2",
"q_proj": "to_q",
"k_proj": "to_k",
"v_proj": "to_v",
"output_proj": "to_out.0",
"q_norm": "norm_q",
"k_norm": "norm_k",
"mlp.layer1": "ff.net.0.proj",
"mlp.layer2": "ff.net.2",
"x_embedder.proj.1": "patch_embed.proj",
"final_layer.adaln_modulation.1": "norm_out.linear_1",
"final_layer.adaln_modulation.2": "norm_out.linear_2",
"final_layer.linear": "proj_out",
}
TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0 = {
"accum_video_sample_counter": remove_keys_,
"accum_image_sample_counter": remove_keys_,
"accum_iteration": remove_keys_,
"accum_train_in_hours": remove_keys_,
"pos_embedder.seq": remove_keys_,
"pos_embedder.dim_spatial_range": remove_keys_,
"pos_embedder.dim_temporal_range": remove_keys_,
"_extra_state": remove_keys_,
}
PREFIX_KEY = "net."
if "net.blocks.block1.blocks.0.block.attn.to_q.0.weight" in checkpoint:
TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0
TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0
else:
TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
state_dict_keys = list(converted_state_dict.keys())
for key in state_dict_keys:
new_key = key[:]
if new_key.startswith(PREFIX_KEY):
new_key = new_key.removeprefix(PREFIX_KEY)
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
converted_state_dict[new_key] = converted_state_dict.pop(key)
state_dict_keys = list(converted_state_dict.keys())
for key in state_dict_keys:
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, converted_state_dict)
return converted_state_dict

View File

@@ -22,7 +22,6 @@ import torch
import torch.nn.functional as F
from huggingface_hub.utils import validate_hf_hub_args
from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
from ..models.embeddings import (
ImageProjection,
IPAdapterFaceIDImageProjection,
@@ -204,7 +203,6 @@ class UNet2DConditionLoadersMixin:
is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys())
is_model_cpu_offload = False
is_sequential_cpu_offload = False
is_group_offload = False
if is_lora:
deprecation_message = "Using the `load_attn_procs()` method has been deprecated and will be removed in a future version. Please use `load_lora_adapter()`."
@@ -213,7 +211,7 @@ class UNet2DConditionLoadersMixin:
if is_custom_diffusion:
attn_processors = self._process_custom_diffusion(state_dict=state_dict)
elif is_lora:
is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._process_lora(
is_model_cpu_offload, is_sequential_cpu_offload = self._process_lora(
state_dict=state_dict,
unet_identifier_key=self.unet_name,
network_alphas=network_alphas,
@@ -232,9 +230,7 @@ class UNet2DConditionLoadersMixin:
# For LoRA, the UNet is already offloaded at this stage as it is handled inside `_process_lora`.
if is_custom_diffusion and _pipeline is not None:
is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading(
_pipeline=_pipeline
)
is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline=_pipeline)
# only custom diffusion needs to set attn processors
self.set_attn_processor(attn_processors)
@@ -245,10 +241,6 @@ class UNet2DConditionLoadersMixin:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
elif is_group_offload:
for component in _pipeline.components.values():
if isinstance(component, torch.nn.Module):
_maybe_remove_and_reapply_group_offloading(component)
# Unsafe code />
def _process_custom_diffusion(self, state_dict):
@@ -315,7 +307,6 @@ class UNet2DConditionLoadersMixin:
is_model_cpu_offload = False
is_sequential_cpu_offload = False
is_group_offload = False
state_dict_to_be_used = unet_state_dict if len(unet_state_dict) > 0 else state_dict
if len(state_dict_to_be_used) > 0:
@@ -365,9 +356,7 @@ class UNet2DConditionLoadersMixin:
# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
# otherwise loading LoRA weights will lead to an error
is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading(
_pipeline
)
is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
peft_kwargs = {}
if is_peft_version(">=", "0.13.1"):
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
@@ -400,7 +389,7 @@ class UNet2DConditionLoadersMixin:
if warn_msg:
logger.warning(warn_msg)
return is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload
return is_model_cpu_offload, is_sequential_cpu_offload
@classmethod
# Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading

View File

@@ -14,8 +14,6 @@
import copy
from typing import TYPE_CHECKING, Dict, List, Union
from torch import nn
from ..utils import logging
@@ -54,7 +52,7 @@ def _maybe_expand_lora_scales(
weight_for_adapter,
blocks_with_transformer,
transformer_per_block,
model=unet,
unet.state_dict(),
default_scale=default_scale,
)
for weight_for_adapter in weight_scales
@@ -67,7 +65,7 @@ def _maybe_expand_lora_scales_for_one_adapter(
scales: Union[float, Dict],
blocks_with_transformer: Dict[str, int],
transformer_per_block: Dict[str, int],
model: nn.Module,
state_dict: None,
default_scale: float = 1.0,
):
"""
@@ -156,7 +154,6 @@ def _maybe_expand_lora_scales_for_one_adapter(
del scales[updown]
state_dict = model.state_dict()
for layer in scales.keys():
if not any(_translate_into_actual_layer_name(layer) in module for module in state_dict.keys()):
raise ValueError(

View File

@@ -20,7 +20,6 @@ import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin
from ...utils import is_torchvision_available
from ..attention import FeedForward
from ..attention_processor import Attention
@@ -378,7 +377,7 @@ class CosmosLearnablePositionalEmbed(nn.Module):
return (emb / norm).type_as(hidden_states)
class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
class CosmosTransformer3DModel(ModelMixin, ConfigMixin):
r"""
A Transformer model for video-like data used in [Cosmos](https://github.com/NVIDIA/Cosmos).

View File

@@ -168,6 +168,7 @@ class SCMScheduler(SchedulerMixin, ConfigMixin):
else:
# max_timesteps=arctan(80/0.5)=1.56454 is the default from sCM paper, we choose a different value here
self.timesteps = torch.linspace(max_timesteps, 0, num_inference_steps + 1, device=device).float()
print(f"Set timesteps: {self.timesteps}")
self._step_index = None
self._begin_index = None
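As a quick sanity check on the comment above (not part of the scheduler), the following sketch reproduces the arctan value and the shape of the timestep grid built by the `torch.linspace` call shown in this hunk:

```python
# Numerical check: arctan(80/0.5) ≈ 1.5645, and linspace(max_timesteps, 0, N + 1)
# yields N + 1 monotonically decreasing timesteps ending at 0.
import math
import torch

max_timesteps = math.atan(80 / 0.5)
print(round(max_timesteps, 4))   # 1.5645

num_inference_steps = 4
timesteps = torch.linspace(max_timesteps, 0, num_inference_steps + 1).float()
print(timesteps)                 # approx. tensor([1.5645, 1.1734, 0.7823, 0.3911, 0.0000])
```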

View File

@@ -150,9 +150,7 @@ def unscale_lora_layers(model, weight: Optional[float] = None):
module.set_scale(adapter_name, 1.0)
def get_peft_kwargs(
rank_dict, network_alpha_dict, peft_state_dict, is_unet=True, model_state_dict=None, adapter_name=None
):
def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True):
rank_pattern = {}
alpha_pattern = {}
r = lora_alpha = list(rank_dict.values())[0]
@@ -182,6 +180,7 @@ def get_peft_kwargs(
else:
lora_alpha = set(network_alpha_dict.values()).pop()
# layer names without the Diffusers specific
target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()})
use_dora = any("lora_magnitude_vector" in k for k in peft_state_dict)
# for now we know that the "bias" keys are only associated with `lora_B`.
@@ -196,21 +195,6 @@ def get_peft_kwargs(
"use_dora": use_dora,
"lora_bias": lora_bias,
}
# Example: try load FusionX LoRA into Wan VACE
exclude_modules = _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name)
if exclude_modules:
if not is_peft_version(">=", "0.14.0"):
msg = """
It seems like there are certain modules that need to be excluded when initializing `LoraConfig`. Your current `peft`
version doesn't support passing an `exclude_modules` to `LoraConfig`. Please update it by running `pip install -U
peft`. For most cases, this can be completely ignored. But if it seems unexpected, please file an issue -
https://github.com/huggingface/diffusers/issues/new
"""
logger.debug(msg)
else:
lora_config_kwargs.update({"exclude_modules": exclude_modules})
return lora_config_kwargs
@@ -310,7 +294,11 @@ def check_peft_version(min_version: str) -> None:
def _create_lora_config(
state_dict, network_alphas, metadata, rank_pattern_dict, is_unet=True, model_state_dict=None, adapter_name=None
state_dict,
network_alphas,
metadata,
rank_pattern_dict,
is_unet: bool = True,
):
from peft import LoraConfig
@@ -318,12 +306,7 @@ def _create_lora_config(
lora_config_kwargs = metadata
else:
lora_config_kwargs = get_peft_kwargs(
rank_pattern_dict,
network_alpha_dict=network_alphas,
peft_state_dict=state_dict,
is_unet=is_unet,
model_state_dict=model_state_dict,
adapter_name=adapter_name,
rank_pattern_dict, network_alpha_dict=network_alphas, peft_state_dict=state_dict, is_unet=is_unet
)
_maybe_raise_error_for_ambiguous_keys(lora_config_kwargs)
@@ -388,27 +371,3 @@ def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name):
if warn_msg:
logger.warning(warn_msg)
def _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name=None):
"""
Derives the modules to exclude while initializing `LoraConfig` through `exclude_modules`. It works by comparing the
`model_state_dict` and `peft_state_dict` and adds a module from `model_state_dict` to the exclusion set if it
doesn't exist in `peft_state_dict`.
"""
if model_state_dict is None:
return
all_modules = set()
string_to_replace = f"{adapter_name}." if adapter_name else ""
for name in model_state_dict.keys():
if string_to_replace:
name = name.replace(string_to_replace, "")
if "." in name:
module_name = name.rsplit(".", 1)[0]
all_modules.add(module_name)
target_modules_set = {name.split(".lora")[0] for name in peft_state_dict.keys()}
exclude_modules = list(all_modules - target_modules_set)
return exclude_modules
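To make the set-difference logic of `_derive_exclude_modules` concrete, here is a small illustrative sketch with made-up module names; it mirrors the suffix-stripping and set subtraction performed above and is not part of the library.

```python
# Illustrative only: toy state dicts with made-up module names.
model_state_dict = {
    "blocks.0.attn.to_q.weight": None,
    "blocks.0.attn.to_k.weight": None,
    "blocks.1.ff.net.0.proj.weight": None,
}
peft_state_dict = {
    "blocks.0.attn.to_q.lora_A.weight": None,
    "blocks.0.attn.to_q.lora_B.weight": None,
}

all_modules = {name.rsplit(".", 1)[0] for name in model_state_dict if "." in name}
target_modules = {name.split(".lora")[0] for name in peft_state_dict}
exclude_modules = sorted(all_modules - target_modules)
print(exclude_modules)  # ['blocks.0.attn.to_k', 'blocks.1.ff.net.0.proj']
```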

View File

@@ -16,7 +16,6 @@ import sys
import unittest
import torch
from parameterized import parameterized
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import (
@@ -29,7 +28,6 @@ from diffusers import (
from diffusers.utils.testing_utils import (
floats_tensor,
require_peft_backend,
require_torch_accelerator,
)
@@ -129,13 +127,6 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
def test_lora_scale_kwargs_match_fusion(self):
super().test_lora_scale_kwargs_match_fusion(expected_atol=9e-3, expected_rtol=9e-3)
@parameterized.expand([("block_level", True), ("leaf_level", False)])
@require_torch_accelerator
def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
# TODO: We don't run the (leaf_level, True) test here that is enabled for other models.
# The reason for this can be found here: https://github.com/huggingface/diffusers/pull/11804#issuecomment-3013325338
super()._test_group_offloading_inference_denoiser(offload_type, use_stream)
@unittest.skip("Not supported in CogVideoX.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass

View File

@@ -18,17 +18,10 @@ import unittest
import numpy as np
import torch
from parameterized import parameterized
from transformers import AutoTokenizer, GlmModel
from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
from diffusers.utils.testing_utils import (
floats_tensor,
require_peft_backend,
require_torch_accelerator,
skip_mps,
torch_device,
)
from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, skip_mps, torch_device
sys.path.append(".")
@@ -148,13 +141,6 @@ class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
"Loading from saved checkpoints should give same results.",
)
@parameterized.expand([("block_level", True), ("leaf_level", False)])
@require_torch_accelerator
def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
# TODO: We don't run the (leaf_level, True) test here that is enabled for other models.
# The reason for this can be found here: https://github.com/huggingface/diffusers/pull/11804#issuecomment-3013325338
super()._test_group_offloading_inference_denoiser(offload_type, use_stream)
@unittest.skip("Not supported in CogView4.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass

View File

@@ -24,11 +24,7 @@ from diffusers import (
WanPipeline,
WanTransformer3DModel,
)
from diffusers.utils.testing_utils import (
floats_tensor,
require_peft_backend,
skip_mps,
)
from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, skip_mps
sys.path.append(".")

View File

@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import inspect
import os
import re
@@ -40,7 +39,6 @@ from diffusers.utils.testing_utils import (
is_torch_version,
require_peft_backend,
require_peft_version_greater,
require_torch_accelerator,
require_transformers_version_greater,
skip_mps,
torch_device,
@@ -292,21 +290,9 @@ class PeftLoraLoaderMixinTests:
return modules_to_save
def _get_exclude_modules(self, pipe):
from diffusers.utils.peft_utils import _derive_exclude_modules
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
denoiser = "unet" if self.unet_kwargs is not None else "transformer"
modules_to_save = {k: v for k, v in modules_to_save.items() if k == denoiser}
denoiser_lora_state_dict = self._get_lora_state_dicts(modules_to_save)[f"{denoiser}_lora_layers"]
pipe.unload_lora_weights()
denoiser_state_dict = pipe.unet.state_dict() if self.unet_kwargs is not None else pipe.transformer.state_dict()
exclude_modules = _derive_exclude_modules(
denoiser_state_dict, denoiser_lora_state_dict, adapter_name="default"
)
return exclude_modules
def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"):
def check_if_adapters_added_correctly(
self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"
):
if text_lora_config is not None:
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, adapter_name=adapter_name)
@@ -358,7 +344,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
@@ -441,7 +427,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -497,7 +483,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
@@ -535,7 +521,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
pipe.fuse_lora()
# Fusing should still keep the LoRA layers
@@ -567,7 +553,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
pipe.unload_lora_weights()
# unloading should remove the LoRA layers
@@ -602,7 +588,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -653,7 +639,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
state_dict = {}
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -704,7 +690,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -747,7 +733,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -788,7 +774,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
@@ -832,7 +818,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
@@ -870,7 +856,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
pipe.unload_lora_weights()
# unloading should remove the LoRA layers
@@ -906,7 +892,7 @@ class PeftLoraLoaderMixinTests:
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
@@ -1023,7 +1009,7 @@ class PeftLoraLoaderMixinTests:
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
pipe, _ = self.add_adapters_to_pipeline(
pipe, _ = self.check_if_adapters_added_correctly(
pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name
)
@@ -1045,7 +1031,7 @@ class PeftLoraLoaderMixinTests:
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
pipe, _ = self.add_adapters_to_pipeline(
pipe, _ = self.check_if_adapters_added_correctly(
pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name
)
@@ -1772,7 +1758,7 @@ class PeftLoraLoaderMixinTests:
output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_dora_lora.shape == self.output_shape)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -1863,7 +1849,7 @@ class PeftLoraLoaderMixinTests:
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True)
@@ -1950,7 +1936,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
lora_scale = 0.5
attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}}
@@ -2132,7 +2118,7 @@ class PeftLoraLoaderMixinTests:
pipe = pipe.to(torch_device, dtype=compute_dtype)
pipe.set_progress_bar_config(disable=None)
pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
if storage_dtype is not None:
denoiser.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
@@ -2250,7 +2236,7 @@ class PeftLoraLoaderMixinTests:
)
pipe = self.pipeline_class(**components)
pipe, _ = self.add_adapters_to_pipeline(
pipe, _ = self.check_if_adapters_added_correctly(
pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
)
@@ -2303,7 +2289,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
pipe, _ = self.add_adapters_to_pipeline(
pipe, _ = self.check_if_adapters_added_correctly(
pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
)
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -2322,77 +2308,6 @@ class PeftLoraLoaderMixinTests:
np.allclose(output_lora, output_lora_pretrained, atol=1e-3, rtol=1e-3), "Lora outputs should match."
)
def test_lora_unload_add_adapter(self):
"""Tests if `unload_lora_weights()` -> `add_adapter()` works."""
scheduler_cls = self.scheduler_classes[0]
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components).to(torch_device)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
pipe, _ = self.add_adapters_to_pipeline(
pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
)
_ = pipe(**inputs, generator=torch.manual_seed(0))[0]
# unload and then add.
pipe.unload_lora_weights()
pipe, _ = self.add_adapters_to_pipeline(
pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
)
_ = pipe(**inputs, generator=torch.manual_seed(0))[0]
@require_peft_version_greater("0.13.2")
def test_lora_exclude_modules(self):
"""
Test to check that `exclude_modules` works. It works in the following way:
we first create a pipeline and insert a LoRA config into it. We then derive a `set`
of modules to exclude by inspecting the denoiser state dict and the denoiser LoRA
state dict.
We then create a new LoRA config that includes the `exclude_modules` and run the tests.
"""
scheduler_cls = self.scheduler_classes[0]
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components).to(torch_device)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
# currently only supported for the denoiser
pipe_cp = copy.deepcopy(pipe)
pipe_cp, _ = self.add_adapters_to_pipeline(
pipe_cp, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
)
denoiser_exclude_modules = self._get_exclude_modules(pipe_cp)
pipe_cp.to("cpu")
del pipe_cp
denoiser_lora_config.exclude_modules = denoiser_exclude_modules
pipe, _ = self.add_adapters_to_pipeline(
pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
)
output_lora_exclude_modules = pipe(**inputs, generator=torch.manual_seed(0))[0]
with tempfile.TemporaryDirectory() as tmpdir:
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
lora_metadatas = self._get_lora_adapter_metadata(modules_to_save)
self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas)
pipe.unload_lora_weights()
pipe.load_lora_weights(tmpdir)
output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
not np.allclose(output_no_lora, output_lora_exclude_modules, atol=1e-3, rtol=1e-3),
"LoRA should change outputs.",
)
self.assertTrue(
np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3),
"Lora outputs should match.",
)
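Outside the test harness, the same workflow reduces to building a `LoraConfig` that carries `exclude_modules` and attaching it to the denoiser. A minimal sketch, assuming an already-loaded pipeline `pipe` with a `transformer` denoiser and `peft>=0.14.0`; the module names are illustrative, not tied to any particular model.

from peft import LoraConfig

# Illustrative module names; real names depend on the denoiser architecture.
denoiser_lora_config = LoraConfig(
    r=4,
    lora_alpha=4,
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],
    exclude_modules=["proj_out"],  # modules that should not receive LoRA layers
)
pipe.transformer.add_adapter(denoiser_lora_config)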
def test_inference_load_delete_load_adapters(self):
"Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
for scheduler_cls in self.scheduler_classes:
@@ -2440,73 +2355,3 @@ class PeftLoraLoaderMixinTests:
pipe.load_lora_weights(tmpdirname)
output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(np.allclose(output_adapter_1, output_lora_loaded, atol=1e-3, rtol=1e-3))
def _test_group_offloading_inference_denoiser(self, offload_type, use_stream):
from diffusers.hooks.group_offloading import _get_top_level_group_offload_hook
onload_device = torch_device
offload_device = torch.device("cpu")
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
with tempfile.TemporaryDirectory() as tmpdirname:
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
self.pipeline_class.save_lora_weights(
save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
components, _, _ = self.get_dummy_components(self.scheduler_classes[0])
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
check_if_lora_correctly_set(denoiser)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
# Test group offloading with load_lora_weights
denoiser.enable_group_offload(
onload_device=onload_device,
offload_device=offload_device,
offload_type=offload_type,
num_blocks_per_group=1,
use_stream=use_stream,
)
group_offload_hook_1 = _get_top_level_group_offload_hook(denoiser)
self.assertTrue(group_offload_hook_1 is not None)
output_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
# Test group offloading after removing the lora
pipe.unload_lora_weights()
group_offload_hook_2 = _get_top_level_group_offload_hook(denoiser)
self.assertTrue(group_offload_hook_2 is not None)
output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] # noqa: F841
# Add the lora again and check if group offloading works
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
check_if_lora_correctly_set(denoiser)
group_offload_hook_3 = _get_top_level_group_offload_hook(denoiser)
self.assertTrue(group_offload_hook_3 is not None)
output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(np.allclose(output_1, output_3, atol=1e-3, rtol=1e-3))
@parameterized.expand([("block_level", True), ("leaf_level", False), ("leaf_level", True)])
@require_torch_accelerator
def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
for cls in inspect.getmro(self.__class__):
if "test_group_offloading_inference_denoiser" in cls.__dict__ and cls is not PeftLoraLoaderMixinTests:
# Skip this test if it is overridden by a child class. We need to do this because parameterized
# materializes the test methods at invocation time, and those cannot be overridden.
return
self._test_group_offloading_inference_denoiser(offload_type, use_stream)

View File

@@ -1350,6 +1350,7 @@ class ModelTesterMixin:
new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
# Making sure part of the model will actually end up offloaded
self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})
print(f" new_model.hf_device_map:{new_model.hf_device_map}")
self.check_device_map_is_respected(new_model, new_model.hf_device_map)
@@ -2018,8 +2019,6 @@ class LoraHotSwappingForModelTesterMixin:
"""
different_shapes_for_compilation = None
def tearDown(self):
# It is critical that the dynamo cache is reset for each test. Otherwise, if the test re-uses the same model,
# there will be recompilation errors, as torch caches the model when run in the same process.
@@ -2057,13 +2056,11 @@ class LoraHotSwappingForModelTesterMixin:
- hotswap the second adapter
- check that the outputs are correct
- optionally compile the model
- optionally check if recompilations happen on different shapes
Note: We set rank == alpha here because save_lora_adapter does not save the alpha scalings, thus the test would
fail if the values are different. Since rank != alpha does not matter for the purpose of this test, this is
fine.
"""
different_shapes = self.different_shapes_for_compilation
# create 2 adapters with different ranks and alphas
torch.manual_seed(0)
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -2113,30 +2110,19 @@ class LoraHotSwappingForModelTesterMixin:
model.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None)
if do_compile:
model = torch.compile(model, mode="reduce-overhead", dynamic=different_shapes is not None)
model = torch.compile(model, mode="reduce-overhead")
with torch.inference_mode():
# additionally check if dynamic compilation works.
if different_shapes is not None:
for height, width in different_shapes:
new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
_ = model(**new_inputs_dict)
else:
output0_after = model(**inputs_dict)["sample"]
assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol)
output0_after = model(**inputs_dict)["sample"]
assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol)
# hotswap the 2nd adapter
model.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None)
# we need to call forward to potentially trigger recompilation
with torch.inference_mode():
if different_shapes is not None:
for height, width in different_shapes:
new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
_ = model(**new_inputs_dict)
else:
output1_after = model(**inputs_dict)["sample"]
assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol)
output1_after = model(**inputs_dict)["sample"]
assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol)
# check error when not passing valid adapter name
name = "does-not-exist"
@@ -2254,23 +2240,3 @@ class LoraHotSwappingForModelTesterMixin:
do_compile=True, rank0=8, rank1=8, target_modules0=target_modules0, target_modules1=target_modules1
)
assert any("Hotswapping adapter0 was unsuccessful" in log for log in cm.output)
@parameterized.expand([(11, 11), (7, 13), (13, 7)])
@require_torch_version_greater("2.7.1")
def test_hotswapping_compile_on_different_shapes(self, rank0, rank1):
different_shapes_for_compilation = self.different_shapes_for_compilation
if different_shapes_for_compilation is None:
pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
# Setting `use_duck_shape=False` tells the compiler not to reuse the same symbolic
# variable for input sizes that happen to be equal. For more details,
# check out this [comment](https://github.com/huggingface/diffusers/pull/11327#discussion_r2047659790).
torch.fx.experimental._config.use_duck_shape = False
target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
with torch._dynamo.config.patch(error_on_recompile=True):
self.check_model_hotswap(
do_compile=True,
rank0=rank0,
rank1=rank1,
target_modules0=target_modules,
)
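For reference, the hotswap-under-compilation pattern this test exercises looks roughly like the sketch below. It is a simplified illustration, not the test itself: it assumes a diffusers model `model` exposing `load_lora_adapter`, an `inputs_dict` prepared for it, and two saved adapter files (placeholder paths) that share the same rank and target modules.

import torch

# Give equal input sizes distinct symbolic variables so new shapes don't force recompiles.
torch.fx.experimental._config.use_duck_shape = False

# `model`, `inputs_dict`, and the adapter paths are assumed from context.
model.load_lora_adapter("adapter0.safetensors", adapter_name="adapter0", prefix=None)
model = torch.compile(model, mode="reduce-overhead", dynamic=True)

with torch._dynamo.config.patch(error_on_recompile=True):
    _ = model(**inputs_dict)["sample"]
    # Hotswap the second adapter into the same slot; the compiled graph is reused without recompilation.
    model.load_lora_adapter("adapter1.safetensors", adapter_name="adapter0", hotswap=True, prefix=None)
    _ = model(**inputs_dict)["sample"]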

View File

@@ -186,10 +186,6 @@ class FluxTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
class FluxTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
model_class = FluxTransformer2DModel
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
def prepare_init_args_and_inputs_for_common(self):
return FluxTransformerTests().prepare_init_args_and_inputs_for_common()
def prepare_dummy_input(self, height, width):
return FluxTransformerTests().prepare_dummy_input(height=height, width=width)

View File

@@ -1378,6 +1378,7 @@ class PipelineTesterMixin:
for component in pipe_fp16.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe_fp16.to(torch_device, torch.float16)
pipe_fp16.set_progress_bar_config(disable=None)
@@ -1385,20 +1386,17 @@ class PipelineTesterMixin:
# Reset generator in case it is used inside dummy inputs
if "generator" in inputs:
inputs["generator"] = self.get_generator(0)
output = pipe(**inputs)[0]
fp16_inputs = self.get_dummy_inputs(torch_device)
# Reset generator in case it is used inside dummy inputs
if "generator" in fp16_inputs:
fp16_inputs["generator"] = self.get_generator(0)
output_fp16 = pipe_fp16(**fp16_inputs)[0]
if isinstance(output, torch.Tensor):
output = output.cpu()
output_fp16 = output_fp16.cpu()
max_diff = numpy_cosine_similarity_distance(output.flatten(), output_fp16.flatten())
assert max_diff < expected_max_diff
assert max_diff < 1e-2
@unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
@require_accelerator

View File

@@ -98,14 +98,7 @@ class Base4bitTests(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.is_deterministic_enabled = torch.are_deterministic_algorithms_enabled()
if not cls.is_deterministic_enabled:
torch.use_deterministic_algorithms(True)
@classmethod
def tearDownClass(cls):
if not cls.is_deterministic_enabled:
torch.use_deterministic_algorithms(False)
torch.use_deterministic_algorithms(True)
def get_dummy_inputs(self):
prompt_embeds = load_pt(
@@ -873,17 +866,15 @@ class ExtendedSerializationTest(BaseBnb4BitSerializationTests):
@require_torch_version_greater("2.7.1")
class Bnb4BitCompileTests(QuantCompileTests):
@property
def quantization_config(self):
return PipelineQuantizationConfig(
quant_backend="bitsandbytes_8bit",
quant_kwargs={
"load_in_4bit": True,
"bnb_4bit_quant_type": "nf4",
"bnb_4bit_compute_dtype": torch.bfloat16,
},
components_to_quantize=["transformer", "text_encoder_2"],
)
quantization_config = PipelineQuantizationConfig(
quant_backend="bitsandbytes_8bit",
quant_kwargs={
"load_in_4bit": True,
"bnb_4bit_quant_type": "nf4",
"bnb_4bit_compute_dtype": torch.bfloat16,
},
components_to_quantize=["transformer", "text_encoder_2"],
)
def test_torch_compile(self):
torch._dynamo.config.capture_dynamic_output_shape_ops = True
@@ -892,7 +883,5 @@ class Bnb4BitCompileTests(QuantCompileTests):
def test_torch_compile_with_cpu_offload(self):
super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)
def test_torch_compile_with_group_offload_leaf(self):
super()._test_torch_compile_with_group_offload_leaf(
quantization_config=self.quantization_config, use_stream=True
)
def test_torch_compile_with_group_offload(self):
super()._test_torch_compile_with_group_offload(quantization_config=self.quantization_config)

View File

@@ -99,14 +99,7 @@ class Base8bitTests(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.is_deterministic_enabled = torch.are_deterministic_algorithms_enabled()
if not cls.is_deterministic_enabled:
torch.use_deterministic_algorithms(True)
@classmethod
def tearDownClass(cls):
if not cls.is_deterministic_enabled:
torch.use_deterministic_algorithms(False)
torch.use_deterministic_algorithms(True)
def get_dummy_inputs(self):
prompt_embeds = load_pt(
@@ -838,13 +831,11 @@ class BaseBnb8bitSerializationTests(Base8bitTests):
@require_torch_version_greater_equal("2.6.0")
class Bnb8BitCompileTests(QuantCompileTests):
@property
def quantization_config(self):
return PipelineQuantizationConfig(
quant_backend="bitsandbytes_8bit",
quant_kwargs={"load_in_8bit": True},
components_to_quantize=["transformer", "text_encoder_2"],
)
quantization_config = PipelineQuantizationConfig(
quant_backend="bitsandbytes_8bit",
quant_kwargs={"load_in_8bit": True},
components_to_quantize=["transformer", "text_encoder_2"],
)
def test_torch_compile(self):
torch._dynamo.config.capture_dynamic_output_shape_ops = True
@@ -856,7 +847,7 @@ class Bnb8BitCompileTests(QuantCompileTests):
)
@pytest.mark.xfail(reason="Test fails because of an offloading problem from Accelerate with confusion in hooks.")
def test_torch_compile_with_group_offload_leaf(self):
super()._test_torch_compile_with_group_offload_leaf(
quantization_config=self.quantization_config, torch_dtype=torch.float16, use_stream=True
def test_torch_compile_with_group_offload(self):
super()._test_torch_compile_with_group_offload(
quantization_config=self.quantization_config, torch_dtype=torch.float16
)

View File

@@ -24,11 +24,7 @@ from diffusers.utils.testing_utils import backend_empty_cache, require_torch_gpu
@require_torch_gpu
@slow
class QuantCompileTests(unittest.TestCase):
@property
def quantization_config(self):
raise NotImplementedError(
"This property should be implemented in the subclass to return the appropriate quantization config."
)
quantization_config = None
def setUp(self):
super().setUp()
@@ -68,9 +64,7 @@ class QuantCompileTests(unittest.TestCase):
# small resolutions to ensure speedy execution.
pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
def _test_torch_compile_with_group_offload_leaf(
self, quantization_config, torch_dtype=torch.bfloat16, *, use_stream: bool = False
):
def _test_torch_compile_with_group_offload(self, quantization_config, torch_dtype=torch.bfloat16):
torch._dynamo.config.cache_size_limit = 10000
pipe = self._init_pipeline(quantization_config, torch_dtype)
@@ -78,7 +72,8 @@ class QuantCompileTests(unittest.TestCase):
"onload_device": torch.device("cuda"),
"offload_device": torch.device("cpu"),
"offload_type": "leaf_level",
"use_stream": use_stream,
"use_stream": True,
"non_blocking": True,
}
pipe.transformer.enable_group_offload(**group_offload_kwargs)
pipe.transformer.compile()

View File

@@ -19,7 +19,6 @@ import unittest
from typing import List
import numpy as np
from parameterized import parameterized
from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
from diffusers import (
@@ -30,7 +29,6 @@ from diffusers import (
TorchAoConfig,
)
from diffusers.models.attention_processor import Attention
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers.utils.testing_utils import (
backend_empty_cache,
backend_synchronize,
@@ -46,8 +44,6 @@ from diffusers.utils.testing_utils import (
torch_device,
)
from ..test_torch_compile_utils import QuantCompileTests
enable_full_determinism()
@@ -629,53 +625,6 @@ class TorchAoSerializationTest(unittest.TestCase):
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
@require_torchao_version_greater_or_equal("0.7.0")
class TorchAoCompileTest(QuantCompileTests):
@property
def quantization_config(self):
return PipelineQuantizationConfig(
quant_mapping={
"transformer": TorchAoConfig(quant_type="int8_weight_only"),
},
)
def test_torch_compile(self):
super()._test_torch_compile(quantization_config=self.quantization_config)
@unittest.skip(
"Changing the device of AQT tensor with module._apply (called from doing module.to() in accelerate) does not work "
"when compiling."
)
def test_torch_compile_with_cpu_offload(self):
# RuntimeError: _apply(): Couldn't swap Linear.weight
super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)
@unittest.skip(
"""
For `use_stream=False`:
- Changing the device of an AQT tensor with `param.data = param.data.to(device)`, as done in the group offloading
implementation, is unsupported in TorchAO. When compiling, a FakeTensor device mismatch causes the failure.
For `use_stream=True`:
- Using a non-default stream requires the ability to pin tensors. AQT does not seem to support this yet in TorchAO.
"""
)
@parameterized.expand([False, True])
def test_torch_compile_with_group_offload_leaf(self):
# For use_stream=False:
# If we run group offloading without compilation, we will see:
# RuntimeError: Attempted to set the storage of a tensor on device "cpu" to a storage on different device "cuda:0". This is no longer allowed; the devices must match.
# When running with compilation, the error ends up being different:
# Dynamo failed to run FX node with fake tensors: call_function <built-in function linear>(*(FakeTensor(..., device='cuda:0', size=(s0, 256), dtype=torch.bfloat16), AffineQuantizedTensor(tensor_impl=PlainAQTTensorImpl(data=FakeTensor(..., size=(1536, 256), dtype=torch.int8)... , scale=FakeTensor(..., size=(1536,), dtype=torch.bfloat16)... , zero_point=FakeTensor(..., size=(1536,), dtype=torch.int64)... , _layout=PlainLayout()), block_size=(1, 256), shape=torch.Size([1536, 256]), device=cpu, dtype=torch.bfloat16, requires_grad=False), Parameter(FakeTensor(..., device='cuda:0', size=(1536,), dtype=torch.bfloat16,
# requires_grad=True))), **{}): got RuntimeError('Unhandled FakeTensor Device Propagation for aten.mm.default, found two different devices cuda:0, cpu')
# Looks like something that will have to be looked into upstream.
# for linear layers, weight.tensor_impl shows cuda... but:
# weight.tensor_impl.{data,scale,zero_point}.device will be cpu
# For use_stream=True:
# NotImplementedError: AffineQuantizedTensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.is_pinned', overload='default')>, types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), arg_types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), kwarg_types={}
super()._test_torch_compile_with_group_offload_leaf(quantization_config=self.quantization_config)
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
@require_torch
@require_torch_accelerator