Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-24 21:34:55 +08:00)

Compare commits: wan-sf-doc...kontext-re — 1 commit (cc605aa33e)
.github/workflows/build_docker_images.yml — 4 changes (vendored)
@@ -75,6 +75,10 @@ jobs:
          - diffusers-pytorch-cuda
          - diffusers-pytorch-xformers-cuda
          - diffusers-pytorch-minimum-cuda
          - diffusers-flax-cpu
          - diffusers-flax-tpu
          - diffusers-onnxruntime-cpu
          - diffusers-onnxruntime-cuda
          - diffusers-doc-builder

    steps:
.github/workflows/nightly_tests.yml — 104 changes (vendored)
@@ -321,6 +321,55 @@ jobs:
          name: torch_minimum_version_cuda_test_reports
          path: reports

  run_nightly_onnx_tests:
    name: Nightly ONNXRuntime CUDA tests on Ubuntu
    runs-on:
      group: aws-g4dn-2xlarge
    container:
      image: diffusers/diffusers-onnxruntime-cuda
      options: --gpus 0 --shm-size "16gb" --ipc host

    steps:
      - name: Checkout diffusers
        uses: actions/checkout@v3
        with:
          fetch-depth: 2

      - name: NVIDIA-SMI
        run: nvidia-smi

      - name: Install dependencies
        run: |
          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
          python -m uv pip install -e [quality,test]
          pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
          python -m uv pip install pytest-reportlog
      - name: Environment
        run: python utils/print_env.py

      - name: Run Nightly ONNXRuntime CUDA tests
        env:
          HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
        run: |
          python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
            -s -v -k "Onnx" \
            --make-reports=tests_onnx_cuda \
            --report-log=tests_onnx_cuda.log \
            tests/

      - name: Failure short reports
        if: ${{ failure() }}
        run: |
          cat reports/tests_onnx_cuda_stats.txt
          cat reports/tests_onnx_cuda_failures_short.txt

      - name: Test suite reports artifacts
        if: ${{ always() }}
        uses: actions/upload-artifact@v4
        with:
          name: tests_onnx_cuda_reports
          path: reports

  run_nightly_quantization_tests:
    name: Torch quantization nightly tests
    strategy:
@@ -436,6 +485,57 @@ jobs:
          name: torch_cuda_pipeline_level_quant_reports
          path: reports

  run_flax_tpu_tests:
    name: Nightly Flax TPU Tests
    runs-on:
      group: gcp-ct5lp-hightpu-8t
    if: github.event_name == 'schedule'

    container:
      image: diffusers/diffusers-flax-tpu
      options: --shm-size "16gb" --ipc host --privileged ${{ vars.V5_LITEPOD_8_ENV}} -v /mnt/hf_cache:/mnt/hf_cache
    defaults:
      run:
        shell: bash
    steps:
      - name: Checkout diffusers
        uses: actions/checkout@v3
        with:
          fetch-depth: 2

      - name: Install dependencies
        run: |
          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
          python -m uv pip install -e [quality,test]
          pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
          python -m uv pip install pytest-reportlog

      - name: Environment
        run: python utils/print_env.py

      - name: Run nightly Flax TPU tests
        env:
          HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
        run: |
          python -m pytest -n 0 \
            -s -v -k "Flax" \
            --make-reports=tests_flax_tpu \
            --report-log=tests_flax_tpu.log \
            tests/

      - name: Failure short reports
        if: ${{ failure() }}
        run: |
          cat reports/tests_flax_tpu_stats.txt
          cat reports/tests_flax_tpu_failures_short.txt

      - name: Test suite reports artifacts
        if: ${{ always() }}
        uses: actions/upload-artifact@v4
        with:
          name: flax_tpu_test_reports
          path: reports

  generate_consolidated_report:
    name: Generate Consolidated Test Report
    needs: [
@@ -445,9 +545,9 @@ jobs:
      run_big_gpu_torch_tests,
      run_nightly_quantization_tests,
      run_nightly_pipeline_level_quantization_tests,
      # run_nightly_onnx_tests,
      run_nightly_onnx_tests,
      torch_minimum_version_cuda_tests,
      # run_flax_tpu_tests
      run_flax_tpu_tests
    ]
    if: always()
    runs-on:
.github/workflows/pr_tests.yml — 14 changes (vendored)
@@ -87,6 +87,11 @@ jobs:
            runner: aws-general-8-plus
            image: diffusers/diffusers-pytorch-cpu
            report: torch_cpu_models_schedulers
          - name: Fast Flax CPU tests
            framework: flax
            runner: aws-general-8-plus
            image: diffusers/diffusers-flax-cpu
            report: flax_cpu
          - name: PyTorch Example CPU tests
            framework: pytorch_examples
            runner: aws-general-8-plus
@@ -142,6 +147,15 @@ jobs:
            --make-reports=tests_${{ matrix.config.report }} \
            tests/models tests/schedulers tests/others

      - name: Run fast Flax TPU tests
        if: ${{ matrix.config.framework == 'flax' }}
        run: |
          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
          python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
            -s -v -k "Flax" \
            --make-reports=tests_${{ matrix.config.report }} \
            tests

      - name: Run example PyTorch CPU tests
        if: ${{ matrix.config.framework == 'pytorch_examples' }}
        run: |
.github/workflows/push_tests.yml — 96 changes (vendored)
@@ -159,6 +159,102 @@ jobs:
          name: torch_cuda_test_reports_${{ matrix.module }}
          path: reports

  flax_tpu_tests:
    name: Flax TPU Tests
    runs-on:
      group: gcp-ct5lp-hightpu-8t
    container:
      image: diffusers/diffusers-flax-tpu
      options: --shm-size "16gb" --ipc host --privileged ${{ vars.V5_LITEPOD_8_ENV}} -v /mnt/hf_cache:/mnt/hf_cache
    defaults:
      run:
        shell: bash
    steps:
      - name: Checkout diffusers
        uses: actions/checkout@v3
        with:
          fetch-depth: 2

      - name: Install dependencies
        run: |
          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
          python -m uv pip install -e [quality,test]
          pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git

      - name: Environment
        run: |
          python utils/print_env.py

      - name: Run Flax TPU tests
        env:
          HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
        run: |
          python -m pytest -n 0 \
            -s -v -k "Flax" \
            --make-reports=tests_flax_tpu \
            tests/

      - name: Failure short reports
        if: ${{ failure() }}
        run: |
          cat reports/tests_flax_tpu_stats.txt
          cat reports/tests_flax_tpu_failures_short.txt

      - name: Test suite reports artifacts
        if: ${{ always() }}
        uses: actions/upload-artifact@v4
        with:
          name: flax_tpu_test_reports
          path: reports

  onnx_cuda_tests:
    name: ONNX CUDA Tests
    runs-on:
      group: aws-g4dn-2xlarge
    container:
      image: diffusers/diffusers-onnxruntime-cuda
      options: --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/ --gpus 0
    defaults:
      run:
        shell: bash
    steps:
      - name: Checkout diffusers
        uses: actions/checkout@v3
        with:
          fetch-depth: 2

      - name: Install dependencies
        run: |
          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
          python -m uv pip install -e [quality,test]
          pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git

      - name: Environment
        run: |
          python utils/print_env.py

      - name: Run ONNXRuntime CUDA tests
        env:
          HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
        run: |
          python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
            -s -v -k "Onnx" \
            --make-reports=tests_onnx_cuda \
            tests/

      - name: Failure short reports
        if: ${{ failure() }}
        run: |
          cat reports/tests_onnx_cuda_stats.txt
          cat reports/tests_onnx_cuda_failures_short.txt

      - name: Test suite reports artifacts
        if: ${{ always() }}
        uses: actions/upload-artifact@v4
        with:
          name: onnx_cuda_test_reports
          path: reports

  run_torch_compile_tests:
    name: PyTorch Compile CUDA tests
.github/workflows/push_tests_fast.yml — 28 changes (vendored)
@@ -33,6 +33,16 @@ jobs:
            runner: aws-general-8-plus
            image: diffusers/diffusers-pytorch-cpu
            report: torch_cpu
          - name: Fast Flax CPU tests on Ubuntu
            framework: flax
            runner: aws-general-8-plus
            image: diffusers/diffusers-flax-cpu
            report: flax_cpu
          - name: Fast ONNXRuntime CPU tests on Ubuntu
            framework: onnxruntime
            runner: aws-general-8-plus
            image: diffusers/diffusers-onnxruntime-cpu
            report: onnx_cpu
          - name: PyTorch Example CPU tests on Ubuntu
            framework: pytorch_examples
            runner: aws-general-8-plus
@@ -77,6 +87,24 @@ jobs:
            --make-reports=tests_${{ matrix.config.report }} \
            tests/

      - name: Run fast Flax TPU tests
        if: ${{ matrix.config.framework == 'flax' }}
        run: |
          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
          python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
            -s -v -k "Flax" \
            --make-reports=tests_${{ matrix.config.report }} \
            tests/

      - name: Run fast ONNXRuntime CPU tests
        if: ${{ matrix.config.framework == 'onnxruntime' }}
        run: |
          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
          python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
            -s -v -k "Onnx" \
            --make-reports=tests_${{ matrix.config.report }} \
            tests/

      - name: Run example PyTorch CPU tests
        if: ${{ matrix.config.framework == 'pytorch_examples' }}
        run: |
.github/workflows/push_tests_mps.yml — 7 changes (vendored)
@@ -1,7 +1,12 @@
name: Fast mps tests on main

on:
  workflow_dispatch:
  push:
    branches:
      - main
    paths:
      - "src/diffusers/**.py"
      - "tests/**.py"

env:
  DIFFUSERS_IS_CI: yes
.github/workflows/release_tests_fast.yml — 95 changes (vendored)
@@ -213,6 +213,101 @@ jobs:
        with:
          name: torch_minimum_version_cuda_test_reports
          path: reports

  flax_tpu_tests:
    name: Flax TPU Tests
    runs-on: docker-tpu
    container:
      image: diffusers/diffusers-flax-tpu
      options: --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/ --privileged
    defaults:
      run:
        shell: bash
    steps:
      - name: Checkout diffusers
        uses: actions/checkout@v3
        with:
          fetch-depth: 2

      - name: Install dependencies
        run: |
          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
          python -m uv pip install -e [quality,test]
          pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git

      - name: Environment
        run: |
          python utils/print_env.py

      - name: Run slow Flax TPU tests
        env:
          HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
        run: |
          python -m pytest -n 0 \
            -s -v -k "Flax" \
            --make-reports=tests_flax_tpu \
            tests/

      - name: Failure short reports
        if: ${{ failure() }}
        run: |
          cat reports/tests_flax_tpu_stats.txt
          cat reports/tests_flax_tpu_failures_short.txt

      - name: Test suite reports artifacts
        if: ${{ always() }}
        uses: actions/upload-artifact@v4
        with:
          name: flax_tpu_test_reports
          path: reports

  onnx_cuda_tests:
    name: ONNX CUDA Tests
    runs-on:
      group: aws-g4dn-2xlarge
    container:
      image: diffusers/diffusers-onnxruntime-cuda
      options: --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/ --gpus 0
    defaults:
      run:
        shell: bash
    steps:
      - name: Checkout diffusers
        uses: actions/checkout@v3
        with:
          fetch-depth: 2

      - name: Install dependencies
        run: |
          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
          python -m uv pip install -e [quality,test]
          pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git

      - name: Environment
        run: |
          python utils/print_env.py

      - name: Run slow ONNXRuntime CUDA tests
        env:
          HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
        run: |
          python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
            -s -v -k "Onnx" \
            --make-reports=tests_onnx_cuda \
            tests/

      - name: Failure short reports
        if: ${{ failure() }}
        run: |
          cat reports/tests_onnx_cuda_stats.txt
          cat reports/tests_onnx_cuda_failures_short.txt

      - name: Test suite reports artifacts
        if: ${{ always() }}
        uses: actions/upload-artifact@v4
        with:
          name: onnx_cuda_test_reports
          path: reports

  run_torch_compile_tests:
    name: PyTorch Compile CUDA tests
@@ -24,31 +24,6 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)

</Tip>

## Loading original format checkpoints

Checkpoints in their original format, i.e. those that have not been converted to the Diffusers-expected layout, can be loaded with the `from_single_file` method.

```python
import torch
from diffusers import Cosmos2TextToImagePipeline, CosmosTransformer3DModel

model_id = "nvidia/Cosmos-Predict2-2B-Text2Image"
transformer = CosmosTransformer3DModel.from_single_file(
    "https://huggingface.co/nvidia/Cosmos-Predict2-2B-Text2Image/blob/main/model.pt",
    torch_dtype=torch.bfloat16,
).to("cuda")
pipe = Cosmos2TextToImagePipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "A close-up shot captures a vibrant yellow scrubber vigorously working on a grimy plate, its bristles moving in circular motions to lift stubborn grease and food residue. The dish, once covered in remnants of a hearty meal, gradually reveals its original glossy surface. Suds form and bubble around the scrubber, creating a satisfying visual of cleanliness in progress. The sound of scrubbing fills the air, accompanied by the gentle clinking of the dish against the sink. As the scrubber continues its task, the dish transforms, gleaming under the bright kitchen lights, symbolizing the triumph of cleanliness over mess."
negative_prompt = "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality."

output = pipe(
    prompt=prompt, negative_prompt=negative_prompt, generator=torch.Generator().manual_seed(1)
).images[0]
output.save("output.png")
```

## CosmosTextToWorldPipeline

[[autodoc]] CosmosTextToWorldPipeline
@@ -302,12 +302,12 @@ The general rule of thumb to keep in mind when preparing inputs for the VACE pip

```py
# pip install ftfy
import torch
from diffusers import WanPipeline, WanTransformer3DModel, AutoencoderKLWan
from diffusers import WanPipeline, AutoModel

vae = AutoencoderKLWan.from_single_file(
vae = AutoModel.from_single_file(
    "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/vae/wan_2.1_vae.safetensors"
)
transformer = WanTransformer3DModel.from_single_file(
transformer = AutoModel.from_single_file(
    "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/diffusion_models/wan2.1_t2v_1.3B_bf16.safetensors",
    torch_dtype=torch.bfloat16
)
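The hunk ends before the pipeline itself is assembled. As a minimal sketch of how these single-file components are typically passed on (the repo id and prompt below are illustrative assumptions, not part of the diff):

```py
# Assumed continuation: hand the single-file VAE and transformer to the regular loader.
# "Wan-AI/Wan2.1-T2V-1.3B-Diffusers" is an assumed repo id for the matching converted checkpoint.
pipeline = WanPipeline.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
    vae=vae,
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
pipeline.to("cuda")

video = pipeline(prompt="A cat walks on the grass, realistic", num_frames=81).frames[0]
```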
@@ -315,8 +315,6 @@ pipeline.load_lora_weights(
> [!TIP]
> Move your code inside the `with torch._dynamo.config.patch(error_on_recompile=True)` context manager to detect if a model was recompiled. If a model is recompiled despite following all the steps above, please open an [issue](https://github.com/huggingface/diffusers/issues) with a reproducible example.

If you expect varied resolutions during inference with this feature, make sure to set `dynamic=True` during compilation. Refer to [this document](../optimization/fp16#dynamic-shape-compilation) for more details.

There are still scenarios where recompilation is unavoidable, such as when the hotswapped LoRA targets more layers than the initial adapter. Try to load the LoRA that targets the most layers *first*. For more details about this limitation, refer to the PEFT [hotswapping](https://huggingface.co/docs/peft/main/en/package_reference/hotswap#peft.utils.hotswap.hotswap_adapter) docs.
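A minimal sketch of the recompilation check described in the tip above, assuming `pipeline` is a compiled pipeline with a hotswapped LoRA already loaded:

```py
import torch

# Any recompilation inside this context raises an error instead of silently recompiling,
# which makes it easy to tell whether a hotswap invalidated the compiled graph.
with torch._dynamo.config.patch(error_on_recompile=True):
    image = pipeline("a photo of an astronaut riding a horse on mars").images[0]
```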

## Merge
@@ -95,6 +95,7 @@ TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0 = {
    "mlp.layer1": "ff.net.0.proj",
    "mlp.layer2": "ff.net.2",
    "x_embedder.proj.1": "patch_embed.proj",
    # "extra_pos_embedder": "learnable_pos_embed",
    "final_layer.adaln_modulation.1": "norm_out.linear_1",
    "final_layer.adaln_modulation.2": "norm_out.linear_2",
    "final_layer.linear": "proj_out",
@@ -14,8 +14,6 @@

import os
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Optional, Set, Tuple, Union

import safetensors.torch
@@ -48,24 +46,6 @@ _SUPPORTED_PYTORCH_LAYERS = (
# fmt: on


class GroupOffloadingType(str, Enum):
    BLOCK_LEVEL = "block_level"
    LEAF_LEVEL = "leaf_level"


@dataclass
class GroupOffloadingConfig:
    onload_device: torch.device
    offload_device: torch.device
    offload_type: GroupOffloadingType
    non_blocking: bool
    record_stream: bool
    low_cpu_mem_usage: bool
    num_blocks_per_group: Optional[int] = None
    offload_to_disk_path: Optional[str] = None
    stream: Optional[Union[torch.cuda.Stream, torch.Stream]] = None


class ModuleGroup:
    def __init__(
        self,
@@ -308,12 +288,9 @@ class GroupOffloadingHook(ModelHook):

    _is_stateful = False

    def __init__(
        self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None, *, config: GroupOffloadingConfig
    ) -> None:
    def __init__(self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None) -> None:
        self.group = group
        self.next_group = next_group
        self.config = config

    def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        if self.group.offload_leader == module:
@@ -459,7 +436,7 @@ def apply_group_offloading(
    module: torch.nn.Module,
    onload_device: torch.device,
    offload_device: torch.device = torch.device("cpu"),
    offload_type: Union[str, GroupOffloadingType] = "block_level",
    offload_type: str = "block_level",
    num_blocks_per_group: Optional[int] = None,
    non_blocking: bool = False,
    use_stream: bool = False,
@@ -501,7 +478,7 @@ def apply_group_offloading(
            The device to which the group of modules are onloaded.
        offload_device (`torch.device`, defaults to `torch.device("cpu")`):
            The device to which the group of modules are offloaded. This should typically be the CPU. Default is CPU.
        offload_type (`str` or `GroupOffloadingType`, defaults to "block_level"):
        offload_type (`str`, defaults to "block_level"):
            The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is
            "block_level".
        offload_to_disk_path (`str`, *optional*, defaults to `None`):
@@ -544,8 +521,6 @@ def apply_group_offloading(
        ```
    """

    offload_type = GroupOffloadingType(offload_type)

    stream = None
    if use_stream:
        if torch.cuda.is_available():
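For orientation, a minimal usage sketch of `apply_group_offloading` following the signature shown above (the model id mirrors the example used elsewhere in the library docs and is an assumption here):

```py
import torch
from diffusers import CogVideoXTransformer3DModel
from diffusers.hooks import apply_group_offloading

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Keep weights on CPU and stream one block at a time onto the GPU during the forward pass.
apply_group_offloading(
    transformer,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
    use_stream=True,
)
```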
@@ -557,45 +532,84 @@ def apply_group_offloading(

    if not use_stream and record_stream:
        raise ValueError("`record_stream` cannot be True when `use_stream=False`.")
    if offload_type == GroupOffloadingType.BLOCK_LEVEL and num_blocks_per_group is None:
        raise ValueError("`num_blocks_per_group` must be provided when using `offload_type='block_level'.")

    _raise_error_if_accelerate_model_or_sequential_hook_present(module)

    config = GroupOffloadingConfig(
        onload_device=onload_device,
        offload_device=offload_device,
        offload_type=offload_type,
        num_blocks_per_group=num_blocks_per_group,
        non_blocking=non_blocking,
        stream=stream,
        record_stream=record_stream,
        low_cpu_mem_usage=low_cpu_mem_usage,
        offload_to_disk_path=offload_to_disk_path,
    )
    _apply_group_offloading(module, config)
    if offload_type == "block_level":
        if num_blocks_per_group is None:
            raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")


def _apply_group_offloading(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
    if config.offload_type == GroupOffloadingType.BLOCK_LEVEL:
        _apply_group_offloading_block_level(module, config)
    elif config.offload_type == GroupOffloadingType.LEAF_LEVEL:
        _apply_group_offloading_leaf_level(module, config)
        _apply_group_offloading_block_level(
            module=module,
            num_blocks_per_group=num_blocks_per_group,
            offload_device=offload_device,
            onload_device=onload_device,
            offload_to_disk_path=offload_to_disk_path,
            non_blocking=non_blocking,
            stream=stream,
            record_stream=record_stream,
            low_cpu_mem_usage=low_cpu_mem_usage,
        )
    elif offload_type == "leaf_level":
        _apply_group_offloading_leaf_level(
            module=module,
            offload_device=offload_device,
            onload_device=onload_device,
            offload_to_disk_path=offload_to_disk_path,
            non_blocking=non_blocking,
            stream=stream,
            record_stream=record_stream,
            low_cpu_mem_usage=low_cpu_mem_usage,
        )
    else:
        assert False
        raise ValueError(f"Unsupported offload_type: {offload_type}")


def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
def _apply_group_offloading_block_level(
    module: torch.nn.Module,
    num_blocks_per_group: int,
    offload_device: torch.device,
    onload_device: torch.device,
    non_blocking: bool,
    stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
    record_stream: Optional[bool] = False,
    low_cpu_mem_usage: bool = False,
    offload_to_disk_path: Optional[str] = None,
) -> None:
    r"""
    This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
    the "leaf_level" offloading, which is more fine-grained, this offloading is done at the top-level blocks.
    """

    if config.stream is not None and config.num_blocks_per_group != 1:
    Args:
        module (`torch.nn.Module`):
            The module to which group offloading is applied.
        offload_device (`torch.device`):
            The device to which the group of modules are offloaded. This should typically be the CPU.
        offload_to_disk_path (`str`, *optional*, defaults to `None`):
            The path to the directory where parameters will be offloaded. Setting this option can be useful in limited
            RAM environment settings where a reasonable speed-memory trade-off is desired.
        onload_device (`torch.device`):
            The device to which the group of modules are onloaded.
        non_blocking (`bool`):
            If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
            and data transfer.
        stream (`torch.cuda.Stream` or `torch.Stream`, *optional*):
            If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
            for overlapping computation and data transfer.
        record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
            as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the
            [PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) for more
            details.
        low_cpu_mem_usage (`bool`, defaults to `False`):
            If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
            option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
            the CPU memory is a bottleneck but may counteract the benefits of using streams.
    """
    if stream is not None and num_blocks_per_group != 1:
        logger.warning(
            f"Using streams is only supported for num_blocks_per_group=1. Got {config.num_blocks_per_group=}. Setting it to 1."
            f"Using streams is only supported for num_blocks_per_group=1. Got {num_blocks_per_group=}. Setting it to 1."
        )
        config.num_blocks_per_group = 1
        num_blocks_per_group = 1

    # Create module groups for ModuleList and Sequential blocks
    modules_with_group_offloading = set()
@@ -607,19 +621,19 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
            modules_with_group_offloading.add(name)
            continue

        for i in range(0, len(submodule), config.num_blocks_per_group):
            current_modules = submodule[i : i + config.num_blocks_per_group]
        for i in range(0, len(submodule), num_blocks_per_group):
            current_modules = submodule[i : i + num_blocks_per_group]
            group = ModuleGroup(
                modules=current_modules,
                offload_device=config.offload_device,
                onload_device=config.onload_device,
                offload_to_disk_path=config.offload_to_disk_path,
                offload_device=offload_device,
                onload_device=onload_device,
                offload_to_disk_path=offload_to_disk_path,
                offload_leader=current_modules[-1],
                onload_leader=current_modules[0],
                non_blocking=config.non_blocking,
                stream=config.stream,
                record_stream=config.record_stream,
                low_cpu_mem_usage=config.low_cpu_mem_usage,
                non_blocking=non_blocking,
                stream=stream,
                record_stream=record_stream,
                low_cpu_mem_usage=low_cpu_mem_usage,
                onload_self=True,
            )
            matched_module_groups.append(group)
@@ -629,7 +643,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
    # Apply group offloading hooks to the module groups
    for i, group in enumerate(matched_module_groups):
        for group_module in group.modules:
            _apply_group_offloading_hook(group_module, group, None, config=config)
            _apply_group_offloading_hook(group_module, group, None)

    # Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
    # when the forward pass of this module is called. This is because the top-level module is not
@@ -644,9 +658,9 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
    unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules]
    unmatched_group = ModuleGroup(
        modules=unmatched_modules,
        offload_device=config.offload_device,
        onload_device=config.onload_device,
        offload_to_disk_path=config.offload_to_disk_path,
        offload_device=offload_device,
        onload_device=onload_device,
        offload_to_disk_path=offload_to_disk_path,
        offload_leader=module,
        onload_leader=module,
        parameters=parameters,
@@ -656,19 +670,54 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
        record_stream=False,
        onload_self=True,
    )
    if config.stream is None:
        _apply_group_offloading_hook(module, unmatched_group, None, config=config)
    if stream is None:
        _apply_group_offloading_hook(module, unmatched_group, None)
    else:
        _apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)
        _apply_lazy_group_offloading_hook(module, unmatched_group, None)


def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
def _apply_group_offloading_leaf_level(
    module: torch.nn.Module,
    offload_device: torch.device,
    onload_device: torch.device,
    non_blocking: bool,
    stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
    record_stream: Optional[bool] = False,
    low_cpu_mem_usage: bool = False,
    offload_to_disk_path: Optional[str] = None,
) -> None:
    r"""
    This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
    requirements. However, it can be slower compared to other offloading methods due to the excessive number of device
    synchronizations. When using devices that support streams to overlap data transfer and computation, this method can
    reduce memory usage without any performance degradation.

    Args:
        module (`torch.nn.Module`):
            The module to which group offloading is applied.
        offload_device (`torch.device`):
            The device to which the group of modules are offloaded. This should typically be the CPU.
        onload_device (`torch.device`):
            The device to which the group of modules are onloaded.
        offload_to_disk_path (`str`, *optional*, defaults to `None`):
            The path to the directory where parameters will be offloaded. Setting this option can be useful in limited
            RAM environment settings where a reasonable speed-memory trade-off is desired.
        non_blocking (`bool`):
            If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
            and data transfer.
        stream (`torch.cuda.Stream` or `torch.Stream`, *optional*):
            If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
            for overlapping computation and data transfer.
        record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
            as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the
            [PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) for more
            details.
        low_cpu_mem_usage (`bool`, defaults to `False`):
            If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
            option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
            the CPU memory is a bottleneck but may counteract the benefits of using streams.
    """

    # Create module groups for leaf modules and apply group offloading hooks
    modules_with_group_offloading = set()
    for name, submodule in module.named_modules():
@@ -676,18 +725,18 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
            continue
        group = ModuleGroup(
            modules=[submodule],
            offload_device=config.offload_device,
            onload_device=config.onload_device,
            offload_to_disk_path=config.offload_to_disk_path,
            offload_device=offload_device,
            onload_device=onload_device,
            offload_to_disk_path=offload_to_disk_path,
            offload_leader=submodule,
            onload_leader=submodule,
            non_blocking=config.non_blocking,
            stream=config.stream,
            record_stream=config.record_stream,
            low_cpu_mem_usage=config.low_cpu_mem_usage,
            non_blocking=non_blocking,
            stream=stream,
            record_stream=record_stream,
            low_cpu_mem_usage=low_cpu_mem_usage,
            onload_self=True,
        )
        _apply_group_offloading_hook(submodule, group, None, config=config)
        _apply_group_offloading_hook(submodule, group, None)
        modules_with_group_offloading.add(name)

    # Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass
@@ -718,32 +767,33 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
        parameters = parent_to_parameters.get(name, [])
        buffers = parent_to_buffers.get(name, [])
        parent_module = module_dict[name]
        assert getattr(parent_module, "_diffusers_hook", None) is None
        group = ModuleGroup(
            modules=[],
            offload_device=config.offload_device,
            onload_device=config.onload_device,
            offload_device=offload_device,
            onload_device=onload_device,
            offload_leader=parent_module,
            onload_leader=parent_module,
            offload_to_disk_path=config.offload_to_disk_path,
            offload_to_disk_path=offload_to_disk_path,
            parameters=parameters,
            buffers=buffers,
            non_blocking=config.non_blocking,
            stream=config.stream,
            record_stream=config.record_stream,
            low_cpu_mem_usage=config.low_cpu_mem_usage,
            non_blocking=non_blocking,
            stream=stream,
            record_stream=record_stream,
            low_cpu_mem_usage=low_cpu_mem_usage,
            onload_self=True,
        )
        _apply_group_offloading_hook(parent_module, group, None, config=config)
        _apply_group_offloading_hook(parent_module, group, None)

    if config.stream is not None:
    if stream is not None:
        # When using streams, we need to know the layer execution order for applying prefetching (to overlap data transfer
        # and computation). Since we don't know the order beforehand, we apply a lazy prefetching hook that will find the
        # execution order and apply prefetching in the correct order.
        unmatched_group = ModuleGroup(
            modules=[],
            offload_device=config.offload_device,
            onload_device=config.onload_device,
            offload_to_disk_path=config.offload_to_disk_path,
            offload_device=offload_device,
            onload_device=onload_device,
            offload_to_disk_path=offload_to_disk_path,
            offload_leader=module,
            onload_leader=module,
            parameters=None,
@@ -751,25 +801,23 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
            non_blocking=False,
            stream=None,
            record_stream=False,
            low_cpu_mem_usage=config.low_cpu_mem_usage,
            low_cpu_mem_usage=low_cpu_mem_usage,
            onload_self=True,
        )
        _apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)
        _apply_lazy_group_offloading_hook(module, unmatched_group, None)


def _apply_group_offloading_hook(
    module: torch.nn.Module,
    group: ModuleGroup,
    next_group: Optional[ModuleGroup] = None,
    *,
    config: GroupOffloadingConfig,
) -> None:
    registry = HookRegistry.check_if_exists_or_initialize(module)

    # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
    # is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
    if registry.get_hook(_GROUP_OFFLOADING) is None:
        hook = GroupOffloadingHook(group, next_group, config=config)
        hook = GroupOffloadingHook(group, next_group)
        registry.register_hook(hook, _GROUP_OFFLOADING)

@@ -777,15 +825,13 @@ def _apply_lazy_group_offloading_hook(
    module: torch.nn.Module,
    group: ModuleGroup,
    next_group: Optional[ModuleGroup] = None,
    *,
    config: GroupOffloadingConfig,
) -> None:
    registry = HookRegistry.check_if_exists_or_initialize(module)

    # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
    # is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
    if registry.get_hook(_GROUP_OFFLOADING) is None:
        hook = GroupOffloadingHook(group, next_group, config=config)
        hook = GroupOffloadingHook(group, next_group)
        registry.register_hook(hook, _GROUP_OFFLOADING)

    lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook()
@@ -852,48 +898,15 @@ def _raise_error_if_accelerate_model_or_sequential_hook_present(module: torch.nn
        )


def _get_top_level_group_offload_hook(module: torch.nn.Module) -> Optional[GroupOffloadingHook]:
    for submodule in module.modules():
        if hasattr(submodule, "_diffusers_hook"):
            group_offloading_hook = submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING)
            if group_offloading_hook is not None:
                return group_offloading_hook
    return None


def _is_group_offload_enabled(module: torch.nn.Module) -> bool:
    top_level_group_offload_hook = _get_top_level_group_offload_hook(module)
    return top_level_group_offload_hook is not None
    for submodule in module.modules():
        if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None:
            return True
    return False


def _get_group_onload_device(module: torch.nn.Module) -> torch.device:
    top_level_group_offload_hook = _get_top_level_group_offload_hook(module)
    if top_level_group_offload_hook is not None:
        return top_level_group_offload_hook.config.onload_device
    for submodule in module.modules():
        if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None:
            return submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING).group.onload_device
    raise ValueError("Group offloading is not enabled for the provided module.")


def _maybe_remove_and_reapply_group_offloading(module: torch.nn.Module) -> None:
    r"""
    Removes the group offloading hook from the module and re-applies it. This is useful when the module has been
    modified in-place and the group offloading hook's references to tensors need to be updated. The in-place
    modification can happen in a number of ways, for example, fusing QKV or unloading/loading LoRAs on-the-fly.

    In this implementation, we make an assumption that group offloading has only been applied at the top-level module,
    and therefore all submodules have the same onload and offload devices. If this assumption is not true, say in the
    case where the user has applied group offloading at multiple levels, this function will not work as expected.

    There is some performance penalty associated with doing this when non-default streams are used, because we need to
    retrace the execution order of the layers with `LazyPrefetchGroupOffloadingHook`.
    """
    top_level_group_offload_hook = _get_top_level_group_offload_hook(module)

    if top_level_group_offload_hook is None:
        return

    registry = HookRegistry.check_if_exists_or_initialize(module)
    registry.remove_hook(_GROUP_OFFLOADING, recurse=True)
    registry.remove_hook(_LAYER_EXECUTION_TRACKER, recurse=True)
    registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING, recurse=True)

    _apply_group_offloading(module, top_level_group_offload_hook.config)
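A rough sketch of the situation `_maybe_remove_and_reapply_group_offloading` exists for, assuming a pipeline `pipe` whose transformer already has group offloading enabled (the LoRA repo id is a placeholder; the loaders changed later in this diff call this helper internally):

```py
import torch
from diffusers.hooks.group_offloading import _maybe_remove_and_reapply_group_offloading

pipe.transformer.enable_group_offload(
    onload_device=torch.device("cuda"), offload_device=torch.device("cpu"), offload_type="leaf_level"
)

pipe.load_lora_weights("user/some-lora")  # placeholder repo id; modifies the transformer in place

# The existing hook still references the pre-LoRA tensors, so it is removed and re-applied.
_maybe_remove_and_reapply_group_offloading(pipe.transformer)
```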
@@ -25,7 +25,6 @@ import torch.nn as nn
from huggingface_hub import model_info
from huggingface_hub.constants import HF_HUB_OFFLINE

from ..hooks.group_offloading import _is_group_offload_enabled, _maybe_remove_and_reapply_group_offloading
from ..models.modeling_utils import ModelMixin, load_state_dict
from ..utils import (
    USE_PEFT_BACKEND,
@@ -392,9 +391,7 @@ def _load_lora_into_text_encoder(
        adapter_name = get_adapter_name(text_encoder)

    # <Unsafe code
    is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = _func_optionally_disable_offloading(
        _pipeline
    )
    is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
    # inject LoRA layers and load the state dict
    # in transformers we automatically check whether the adapter name is already in use or not
    text_encoder.load_adapter(
@@ -413,10 +410,6 @@ def _load_lora_into_text_encoder(
            _pipeline.enable_model_cpu_offload()
        elif is_sequential_cpu_offload:
            _pipeline.enable_sequential_cpu_offload()
        elif is_group_offload:
            for component in _pipeline.components.values():
                if isinstance(component, torch.nn.Module):
                    _maybe_remove_and_reapply_group_offloading(component)
    # Unsafe code />

    if prefix is not None and not state_dict:
@@ -440,36 +433,30 @@ def _func_optionally_disable_offloading(_pipeline):

    Returns:
        tuple:
            A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` or `is_group_offload` is True.
            A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
    """
    is_model_cpu_offload = False
    is_sequential_cpu_offload = False
    is_group_offload = False

    if _pipeline is not None and _pipeline.hf_device_map is None:
        for _, component in _pipeline.components.items():
            if not isinstance(component, nn.Module):
                continue
            is_group_offload = is_group_offload or _is_group_offload_enabled(component)
            if not hasattr(component, "_hf_hook"):
                continue
            is_model_cpu_offload = is_model_cpu_offload or isinstance(component._hf_hook, CpuOffload)
            is_sequential_cpu_offload = is_sequential_cpu_offload or (
                isinstance(component._hf_hook, AlignDevicesHook)
                or hasattr(component._hf_hook, "hooks")
                and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
            )
            if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
                if not is_model_cpu_offload:
                    is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
                if not is_sequential_cpu_offload:
                    is_sequential_cpu_offload = (
                        isinstance(component._hf_hook, AlignDevicesHook)
                        or hasattr(component._hf_hook, "hooks")
                        and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
                    )

        if is_sequential_cpu_offload or is_model_cpu_offload:
            logger.info(
                "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
            )
            for _, component in _pipeline.components.items():
                if not isinstance(component, nn.Module) or not hasattr(component, "_hf_hook"):
                    continue
                remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
                logger.info(
                    "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
                )
                if is_sequential_cpu_offload or is_model_cpu_offload:
                    remove_hook_from_module(component, recurse=is_sequential_cpu_offload)

    return (is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload)
    return (is_model_cpu_offload, is_sequential_cpu_offload)


class LoraBaseMixin:
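To see what `_func_optionally_disable_offloading` is probing for, a small hedged sketch (the checkpoint id is an assumption; requires a CUDA device and `accelerate`):

```py
import torch
from accelerate.hooks import CpuOffload
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()

# This is exactly the `is_model_cpu_offload` condition checked in the function above.
print(isinstance(pipe.unet._hf_hook, CpuOffload))  # True
```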
@@ -22,7 +22,6 @@ from typing import Dict, List, Literal, Optional, Union
import safetensors
import torch

from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
from ..utils import (
    MIN_PEFT_VERSION,
    USE_PEFT_BACKEND,
@@ -244,29 +243,20 @@ class PeftAdapterMixin:
                k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
            }

        # create LoraConfig
        lora_config = _create_lora_config(state_dict, network_alphas, metadata, rank)

        # adapter_name
        if adapter_name is None:
            adapter_name = get_adapter_name(self)

        # create LoraConfig
        lora_config = _create_lora_config(
            state_dict,
            network_alphas,
            metadata,
            rank,
            model_state_dict=self.state_dict(),
            adapter_name=adapter_name,
        )

        # <Unsafe code
        # We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
        # Now we remove any existing hooks to `_pipeline`.

        # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
        # otherwise loading LoRA weights will lead to an error.
        is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading(
            _pipeline
        )
        is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
        peft_kwargs = {}
        if is_peft_version(">=", "0.13.1"):
            peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
@@ -357,10 +347,6 @@ class PeftAdapterMixin:
            _pipeline.enable_model_cpu_offload()
        elif is_sequential_cpu_offload:
            _pipeline.enable_sequential_cpu_offload()
        elif is_group_offload:
            for component in _pipeline.components.values():
                if isinstance(component, torch.nn.Module):
                    _maybe_remove_and_reapply_group_offloading(component)
        # Unsafe code />

        if prefix is not None and not state_dict:
@@ -700,10 +686,6 @@ class PeftAdapterMixin:
        recurse_remove_peft_layers(self)
        if hasattr(self, "peft_config"):
            del self.peft_config
        if hasattr(self, "_hf_peft_config_loaded"):
            self._hf_peft_config_loaded = None

        _maybe_remove_and_reapply_group_offloading(self)

    def disable_lora(self):
        """
@@ -31,7 +31,6 @@ from .single_file_utils import (
    convert_autoencoder_dc_checkpoint_to_diffusers,
    convert_chroma_transformer_checkpoint_to_diffusers,
    convert_controlnet_checkpoint,
    convert_cosmos_transformer_checkpoint_to_diffusers,
    convert_flux_transformer_checkpoint_to_diffusers,
    convert_hidream_transformer_to_diffusers,
    convert_hunyuan_video_transformer_to_diffusers,
@@ -144,10 +143,6 @@ SINGLE_FILE_LOADABLE_CLASSES = {
        "checkpoint_mapping_fn": convert_hidream_transformer_to_diffusers,
        "default_subfolder": "transformer",
    },
    "CosmosTransformer3DModel": {
        "checkpoint_mapping_fn": convert_cosmos_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
}
@@ -127,16 +127,6 @@ CHECKPOINT_KEY_NAMES = {
    "wan": ["model.diffusion_model.head.modulation", "head.modulation"],
    "wan_vae": "decoder.middle.0.residual.0.gamma",
    "hidream": "double_stream_blocks.0.block.adaLN_modulation.1.bias",
    "cosmos-1.0": [
        "net.x_embedder.proj.1.weight",
        "net.blocks.block1.blocks.0.block.attn.to_q.0.weight",
        "net.extra_pos_embedder.pos_emb_h",
    ],
    "cosmos-2.0": [
        "net.x_embedder.proj.1.weight",
        "net.blocks.0.self_attn.q_proj.weight",
        "net.pos_embedder.dim_spatial_range",
    ],
}

DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -203,14 +193,6 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
    "wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"},
    "wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"},
    "hidream": {"pretrained_model_name_or_path": "HiDream-ai/HiDream-I1-Dev"},
    "cosmos-1.0-t2w-7B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-7B-Text2World"},
    "cosmos-1.0-t2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-14B-Text2World"},
    "cosmos-1.0-v2w-7B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-7B-Video2World"},
    "cosmos-1.0-v2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-14B-Video2World"},
    "cosmos-2.0-t2i-2B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-2B-Text2Image"},
    "cosmos-2.0-t2i-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-14B-Text2Image"},
    "cosmos-2.0-v2w-2B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-2B-Video2World"},
    "cosmos-2.0-v2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-14B-Video2World"},
}

# Use to configure model sample size when original config is provided
@@ -722,32 +704,11 @@ def infer_diffusers_model_type(checkpoint):
            model_type = "wan-t2v-14B"
        else:
            model_type = "wan-i2v-14B"

    elif CHECKPOINT_KEY_NAMES["wan_vae"] in checkpoint:
        # All Wan models use the same VAE so we can use the same default model repo to fetch the config
        model_type = "wan-t2v-14B"

    elif CHECKPOINT_KEY_NAMES["hidream"] in checkpoint:
        model_type = "hidream"

    elif all(key in checkpoint for key in CHECKPOINT_KEY_NAMES["cosmos-1.0"]):
        x_embedder_shape = checkpoint[CHECKPOINT_KEY_NAMES["cosmos-1.0"][0]].shape
        if x_embedder_shape[1] == 68:
            model_type = "cosmos-1.0-t2w-7B" if x_embedder_shape[0] == 4096 else "cosmos-1.0-t2w-14B"
        elif x_embedder_shape[1] == 72:
            model_type = "cosmos-1.0-v2w-7B" if x_embedder_shape[0] == 4096 else "cosmos-1.0-v2w-14B"
        else:
            raise ValueError(f"Unexpected x_embedder shape: {x_embedder_shape} when loading Cosmos 1.0 model.")

    elif all(key in checkpoint for key in CHECKPOINT_KEY_NAMES["cosmos-2.0"]):
        x_embedder_shape = checkpoint[CHECKPOINT_KEY_NAMES["cosmos-2.0"][0]].shape
        if x_embedder_shape[1] == 68:
            model_type = "cosmos-2.0-t2i-2B" if x_embedder_shape[0] == 2048 else "cosmos-2.0-t2i-14B"
        elif x_embedder_shape[1] == 72:
            model_type = "cosmos-2.0-v2w-2B" if x_embedder_shape[0] == 2048 else "cosmos-2.0-v2w-14B"
        else:
            raise ValueError(f"Unexpected x_embedder shape: {x_embedder_shape} when loading Cosmos 2.0 model.")

    else:
        model_type = "v1"
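To make the shape-based detection concrete, a small sketch with a fabricated checkpoint containing only the keys the detector inspects (sizes chosen to hit one branch above):

```py
import torch

checkpoint = {
    "net.x_embedder.proj.1.weight": torch.zeros(2048, 68),  # 68 -> text-to-image variant, 2048 -> 2B size
    "net.blocks.0.self_attn.q_proj.weight": torch.zeros(1, 1),
    "net.pos_embedder.dim_spatial_range": torch.zeros(1),
}

x_embedder_shape = checkpoint["net.x_embedder.proj.1.weight"].shape
# Following the branch above, this resolves to model_type = "cosmos-2.0-t2i-2B",
# which DIFFUSERS_DEFAULT_PIPELINE_PATHS maps to nvidia/Cosmos-Predict2-2B-Text2Image.
assert x_embedder_shape == (2048, 68)
```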
@@ -3518,116 +3479,3 @@ def convert_chroma_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def convert_cosmos_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
|
||||
|
||||
def remove_keys_(key: str, state_dict):
|
||||
state_dict.pop(key)
|
||||
|
||||
def rename_transformer_blocks_(key: str, state_dict):
|
||||
block_index = int(key.split(".")[1].removeprefix("block"))
|
||||
new_key = key
|
||||
old_prefix = f"blocks.block{block_index}"
|
||||
new_prefix = f"transformer_blocks.{block_index}"
|
||||
new_key = new_prefix + new_key.removeprefix(old_prefix)
|
||||
state_dict[new_key] = state_dict.pop(key)
|
||||
|
||||
TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0 = {
|
||||
"t_embedder.1": "time_embed.t_embedder",
|
||||
"affline_norm": "time_embed.norm",
|
||||
".blocks.0.block.attn": ".attn1",
|
||||
".blocks.1.block.attn": ".attn2",
|
||||
".blocks.2.block": ".ff",
|
||||
".blocks.0.adaLN_modulation.1": ".norm1.linear_1",
|
||||
".blocks.0.adaLN_modulation.2": ".norm1.linear_2",
|
||||
".blocks.1.adaLN_modulation.1": ".norm2.linear_1",
|
||||
".blocks.1.adaLN_modulation.2": ".norm2.linear_2",
|
||||
".blocks.2.adaLN_modulation.1": ".norm3.linear_1",
|
||||
".blocks.2.adaLN_modulation.2": ".norm3.linear_2",
|
||||
"to_q.0": "to_q",
|
||||
"to_q.1": "norm_q",
|
||||
"to_k.0": "to_k",
|
||||
"to_k.1": "norm_k",
|
||||
"to_v.0": "to_v",
|
||||
"layer1": "net.0.proj",
|
||||
"layer2": "net.2",
|
||||
"proj.1": "proj",
|
||||
"x_embedder": "patch_embed",
|
||||
"extra_pos_embedder": "learnable_pos_embed",
|
||||
"final_layer.adaLN_modulation.1": "norm_out.linear_1",
|
||||
"final_layer.adaLN_modulation.2": "norm_out.linear_2",
|
||||
"final_layer.linear": "proj_out",
|
||||
}
|
||||
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0 = {
|
||||
"blocks.block": rename_transformer_blocks_,
|
||||
"logvar.0.freqs": remove_keys_,
|
||||
"logvar.0.phases": remove_keys_,
|
||||
"logvar.1.weight": remove_keys_,
|
||||
"pos_embedder.seq": remove_keys_,
|
||||
}
|
||||
|
||||
TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0 = {
|
||||
"t_embedder.1": "time_embed.t_embedder",
|
||||
"t_embedding_norm": "time_embed.norm",
|
||||
"blocks": "transformer_blocks",
|
||||
"adaln_modulation_self_attn.1": "norm1.linear_1",
|
||||
"adaln_modulation_self_attn.2": "norm1.linear_2",
|
||||
"adaln_modulation_cross_attn.1": "norm2.linear_1",
|
||||
"adaln_modulation_cross_attn.2": "norm2.linear_2",
|
||||
"adaln_modulation_mlp.1": "norm3.linear_1",
|
||||
"adaln_modulation_mlp.2": "norm3.linear_2",
|
||||
"self_attn": "attn1",
|
||||
"cross_attn": "attn2",
|
||||
"q_proj": "to_q",
|
||||
"k_proj": "to_k",
|
||||
"v_proj": "to_v",
|
||||
"output_proj": "to_out.0",
|
||||
"q_norm": "norm_q",
|
||||
"k_norm": "norm_k",
|
||||
"mlp.layer1": "ff.net.0.proj",
|
||||
"mlp.layer2": "ff.net.2",
|
||||
"x_embedder.proj.1": "patch_embed.proj",
|
||||
"final_layer.adaln_modulation.1": "norm_out.linear_1",
|
||||
"final_layer.adaln_modulation.2": "norm_out.linear_2",
|
||||
"final_layer.linear": "proj_out",
|
||||
}
|
||||
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0 = {
|
||||
"accum_video_sample_counter": remove_keys_,
|
||||
"accum_image_sample_counter": remove_keys_,
|
||||
"accum_iteration": remove_keys_,
|
||||
"accum_train_in_hours": remove_keys_,
|
||||
"pos_embedder.seq": remove_keys_,
|
||||
"pos_embedder.dim_spatial_range": remove_keys_,
|
||||
"pos_embedder.dim_temporal_range": remove_keys_,
|
||||
"_extra_state": remove_keys_,
|
||||
}
|
||||
|
||||
PREFIX_KEY = "net."
|
||||
if "net.blocks.block1.blocks.0.block.attn.to_q.0.weight" in checkpoint:
|
||||
TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0
|
||||
else:
|
||||
TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
|
||||
|
||||
state_dict_keys = list(converted_state_dict.keys())
|
||||
for key in state_dict_keys:
|
||||
new_key = key[:]
|
||||
if new_key.startswith(PREFIX_KEY):
|
||||
new_key = new_key.removeprefix(PREFIX_KEY)
|
||||
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
converted_state_dict[new_key] = converted_state_dict.pop(key)
|
||||
|
||||
state_dict_keys = list(converted_state_dict.keys())
|
||||
for key in state_dict_keys:
|
||||
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, converted_state_dict)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
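Taken together, the conversion above is a two-pass rewrite of the checkpoint keys: strip the "net." prefix and apply plain substring renames, then let special-key handlers re-index transformer blocks or drop training-only entries. The following is a self-contained sketch of that pattern on a toy checkpoint; the keys and the tiny rename tables are illustrative, not the full Cosmos mapping.

# Toy checkpoint: one attention weight plus one training-only entry.
toy = {
    "net.blocks.block0.blocks.0.block.attn.to_q.0.weight": 0,
    "net.logvar.0.freqs": 0,  # dropped by the special-key pass
}

RENAME = {".blocks.0.block.attn": ".attn1", "to_q.0": "to_q"}


def remove_keys_(key, state_dict):
    state_dict.pop(key)


def rename_transformer_blocks_(key, state_dict):
    block_index = int(key.split(".")[1].removeprefix("block"))
    old_prefix = f"blocks.block{block_index}"
    new_prefix = f"transformer_blocks.{block_index}"
    state_dict[new_prefix + key.removeprefix(old_prefix)] = state_dict.pop(key)


SPECIAL = {"blocks.block": rename_transformer_blocks_, "logvar.0.freqs": remove_keys_}

# Pass 1: strip the "net." prefix and apply plain substring renames.
for key in list(toy.keys()):
    new_key = key.removeprefix("net.")
    for old, new in RENAME.items():
        new_key = new_key.replace(old, new)
    toy[new_key] = toy.pop(key)

# Pass 2: run the in-place handlers on any key that still matches a special pattern.
for key in list(toy.keys()):
    for pattern, handler in SPECIAL.items():
        if pattern in key:
            handler(key, toy)

print(sorted(toy))  # ['transformer_blocks.0.attn1.to_q.weight']
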
@@ -22,7 +22,6 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
|
||||
from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
|
||||
from ..models.embeddings import (
|
||||
ImageProjection,
|
||||
IPAdapterFaceIDImageProjection,
|
||||
@@ -204,7 +203,6 @@ class UNet2DConditionLoadersMixin:
|
||||
is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys())
|
||||
is_model_cpu_offload = False
|
||||
is_sequential_cpu_offload = False
|
||||
is_group_offload = False
|
||||
|
||||
if is_lora:
|
||||
deprecation_message = "Using the `load_attn_procs()` method has been deprecated and will be removed in a future version. Please use `load_lora_adapter()`."
|
||||
@@ -213,7 +211,7 @@ class UNet2DConditionLoadersMixin:
|
||||
if is_custom_diffusion:
|
||||
attn_processors = self._process_custom_diffusion(state_dict=state_dict)
|
||||
elif is_lora:
|
||||
is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._process_lora(
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = self._process_lora(
|
||||
state_dict=state_dict,
|
||||
unet_identifier_key=self.unet_name,
|
||||
network_alphas=network_alphas,
|
||||
@@ -232,9 +230,7 @@ class UNet2DConditionLoadersMixin:
|
||||
|
||||
# For LoRA, the UNet is already offloaded at this stage as it is handled inside `_process_lora`.
|
||||
if is_custom_diffusion and _pipeline is not None:
|
||||
is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading(
|
||||
_pipeline=_pipeline
|
||||
)
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline=_pipeline)
|
||||
|
||||
# only custom diffusion needs to set attn processors
|
||||
self.set_attn_processor(attn_processors)
|
||||
@@ -245,10 +241,6 @@ class UNet2DConditionLoadersMixin:
|
||||
_pipeline.enable_model_cpu_offload()
|
||||
elif is_sequential_cpu_offload:
|
||||
_pipeline.enable_sequential_cpu_offload()
|
||||
elif is_group_offload:
|
||||
for component in _pipeline.components.values():
|
||||
if isinstance(component, torch.nn.Module):
|
||||
_maybe_remove_and_reapply_group_offloading(component)
|
||||
# Unsafe code />
|
||||
|
||||
def _process_custom_diffusion(self, state_dict):
|
||||
@@ -315,7 +307,6 @@ class UNet2DConditionLoadersMixin:
|
||||
|
||||
is_model_cpu_offload = False
|
||||
is_sequential_cpu_offload = False
|
||||
is_group_offload = False
|
||||
state_dict_to_be_used = unet_state_dict if len(unet_state_dict) > 0 else state_dict
|
||||
|
||||
if len(state_dict_to_be_used) > 0:
|
||||
@@ -365,9 +356,7 @@ class UNet2DConditionLoadersMixin:
|
||||
|
||||
# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
|
||||
# otherwise loading LoRA weights will lead to an error
|
||||
is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading(
|
||||
_pipeline
|
||||
)
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
|
||||
peft_kwargs = {}
|
||||
if is_peft_version(">=", "0.13.1"):
|
||||
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
||||
@@ -400,7 +389,7 @@ class UNet2DConditionLoadersMixin:
|
||||
if warn_msg:
|
||||
logger.warning(warn_msg)
|
||||
|
||||
return is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload
|
||||
return is_model_cpu_offload, is_sequential_cpu_offload
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
|
||||
|
||||
@@ -14,8 +14,6 @@
|
||||
import copy
|
||||
from typing import TYPE_CHECKING, Dict, List, Union
|
||||
|
||||
from torch import nn
|
||||
|
||||
from ..utils import logging
|
||||
|
||||
|
||||
@@ -54,7 +52,7 @@ def _maybe_expand_lora_scales(
|
||||
weight_for_adapter,
|
||||
blocks_with_transformer,
|
||||
transformer_per_block,
|
||||
model=unet,
|
||||
unet.state_dict(),
|
||||
default_scale=default_scale,
|
||||
)
|
||||
for weight_for_adapter in weight_scales
|
||||
@@ -67,7 +65,7 @@ def _maybe_expand_lora_scales_for_one_adapter(
|
||||
scales: Union[float, Dict],
|
||||
blocks_with_transformer: Dict[str, int],
|
||||
transformer_per_block: Dict[str, int],
|
||||
model: nn.Module,
|
||||
state_dict: None,
|
||||
default_scale: float = 1.0,
|
||||
):
|
||||
"""
|
||||
@@ -156,7 +154,6 @@ def _maybe_expand_lora_scales_for_one_adapter(
|
||||
|
||||
del scales[updown]
|
||||
|
||||
state_dict = model.state_dict()
|
||||
for layer in scales.keys():
|
||||
if not any(_translate_into_actual_layer_name(layer) in module for module in state_dict.keys()):
|
||||
raise ValueError(
|
||||
|
||||
@@ -20,7 +20,6 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin
|
||||
from ...utils import is_torchvision_available
|
||||
from ..attention import FeedForward
|
||||
from ..attention_processor import Attention
|
||||
@@ -378,7 +377,7 @@ class CosmosLearnablePositionalEmbed(nn.Module):
|
||||
return (emb / norm).type_as(hidden_states)
|
||||
|
||||
|
||||
class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
class CosmosTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
A Transformer model for video-like data used in [Cosmos](https://github.com/NVIDIA/Cosmos).
|
||||
|
||||
|
||||
@@ -168,6 +168,7 @@ class SCMScheduler(SchedulerMixin, ConfigMixin):
|
||||
else:
|
||||
# max_timesteps=arctan(80/0.5)=1.56454 is the default from sCM paper, we choose a different value here
|
||||
self.timesteps = torch.linspace(max_timesteps, 0, num_inference_steps + 1, device=device).float()
|
||||
print(f"Set timesteps: {self.timesteps}")
|
||||
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
|
||||
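As a quick numeric check of the constant quoted in the comment above (arctan(80 / 0.5) ~ 1.56454, the sCM default), the resulting timestep schedule can be reproduced directly; this snippet is purely illustrative and uses num_inference_steps=4 as an arbitrary example.

import math

import torch

max_timesteps = math.atan(80 / 0.5)  # ~1.56454, the arctan(80/0.5) value from the comment
num_inference_steps = 4
timesteps = torch.linspace(max_timesteps, 0, num_inference_steps + 1).float()
print(timesteps)  # tensor([1.5645, 1.1734, 0.7823, 0.3911, 0.0000])
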
@@ -150,9 +150,7 @@ def unscale_lora_layers(model, weight: Optional[float] = None):
|
||||
module.set_scale(adapter_name, 1.0)
|
||||
|
||||
|
||||
def get_peft_kwargs(
|
||||
rank_dict, network_alpha_dict, peft_state_dict, is_unet=True, model_state_dict=None, adapter_name=None
|
||||
):
|
||||
def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True):
|
||||
rank_pattern = {}
|
||||
alpha_pattern = {}
|
||||
r = lora_alpha = list(rank_dict.values())[0]
|
||||
@@ -182,6 +180,7 @@ def get_peft_kwargs(
|
||||
else:
|
||||
lora_alpha = set(network_alpha_dict.values()).pop()
|
||||
|
||||
# layer names without the Diffusers specific
|
||||
target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()})
|
||||
use_dora = any("lora_magnitude_vector" in k for k in peft_state_dict)
|
||||
# for now we know that the "bias" keys are only associated with `lora_B`.
|
||||
@@ -196,21 +195,6 @@ def get_peft_kwargs(
|
||||
"use_dora": use_dora,
|
||||
"lora_bias": lora_bias,
|
||||
}
|
||||
|
||||
# Example: try load FusionX LoRA into Wan VACE
|
||||
exclude_modules = _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name)
|
||||
if exclude_modules:
|
||||
if not is_peft_version(">=", "0.14.0"):
|
||||
msg = """
|
||||
It seems like there are certain modules that need to be excluded when initializing `LoraConfig`. Your current `peft`
|
||||
version doesn't support passing an `exclude_modules` to `LoraConfig`. Please update it by running `pip install -U
|
||||
peft`. For most cases, this can be completely ignored. But if it seems unexpected, please file an issue -
|
||||
https://github.com/huggingface/diffusers/issues/new
|
||||
"""
|
||||
logger.debug(msg)
|
||||
else:
|
||||
lora_config_kwargs.update({"exclude_modules": exclude_modules})
|
||||
|
||||
return lora_config_kwargs
|
||||
|
||||
|
||||
@@ -310,7 +294,11 @@ def check_peft_version(min_version: str) -> None:
|
||||
|
||||
|
||||
def _create_lora_config(
|
||||
state_dict, network_alphas, metadata, rank_pattern_dict, is_unet=True, model_state_dict=None, adapter_name=None
|
||||
state_dict,
|
||||
network_alphas,
|
||||
metadata,
|
||||
rank_pattern_dict,
|
||||
is_unet: bool = True,
|
||||
):
|
||||
from peft import LoraConfig
|
||||
|
||||
@@ -318,12 +306,7 @@ def _create_lora_config(
|
||||
lora_config_kwargs = metadata
|
||||
else:
|
||||
lora_config_kwargs = get_peft_kwargs(
|
||||
rank_pattern_dict,
|
||||
network_alpha_dict=network_alphas,
|
||||
peft_state_dict=state_dict,
|
||||
is_unet=is_unet,
|
||||
model_state_dict=model_state_dict,
|
||||
adapter_name=adapter_name,
|
||||
rank_pattern_dict, network_alpha_dict=network_alphas, peft_state_dict=state_dict, is_unet=is_unet
|
||||
)
|
||||
|
||||
_maybe_raise_error_for_ambiguous_keys(lora_config_kwargs)
|
||||
@@ -388,27 +371,3 @@ def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name):
|
||||
|
||||
if warn_msg:
|
||||
logger.warning(warn_msg)
|
||||
def _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name=None):
    """
    Derives the modules to exclude while initializing `LoraConfig` through `exclude_modules`. It works by comparing the
    `model_state_dict` and `peft_state_dict` and adds a module from `model_state_dict` to the exclusion set if it
    doesn't exist in `peft_state_dict`.
    """
    if model_state_dict is None:
        return
    all_modules = set()
    string_to_replace = f"{adapter_name}." if adapter_name else ""

    for name in model_state_dict.keys():
        if string_to_replace:
            name = name.replace(string_to_replace, "")
        if "." in name:
            module_name = name.rsplit(".", 1)[0]
            all_modules.add(module_name)

    target_modules_set = {name.split(".lora")[0] for name in peft_state_dict.keys()}
    exclude_modules = list(all_modules - target_modules_set)

    return exclude_modules
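To illustrate what this derivation produces, here is a toy example that mirrors the same logic on hand-made state dicts; the module and key names are hypothetical, and the values are dummies so the snippet has no dependencies.

model_state_dict = {
    "transformer_blocks.0.attn.to_q.weight": 0,
    "transformer_blocks.0.attn.to_k.weight": 0,
    "proj_out.weight": 0,
}
peft_state_dict = {
    "transformer_blocks.0.attn.to_q.lora_A.default.weight": 0,
    "transformer_blocks.0.attn.to_k.lora_A.default.weight": 0,
}

all_modules = set()
for name in model_state_dict:
    name = name.replace("default.", "")  # strip the adapter name, as the helper does
    if "." in name:
        all_modules.add(name.rsplit(".", 1)[0])

target_modules = {name.split(".lora")[0] for name in peft_state_dict}
print(sorted(all_modules - target_modules))  # ['proj_out'] -> the only module the LoRA never touches
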
@@ -16,7 +16,6 @@ import sys
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from parameterized import parameterized
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import (
|
||||
@@ -29,7 +28,6 @@ from diffusers import (
|
||||
from diffusers.utils.testing_utils import (
|
||||
floats_tensor,
|
||||
require_peft_backend,
|
||||
require_torch_accelerator,
|
||||
)
|
||||
|
||||
|
||||
@@ -129,13 +127,6 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
def test_lora_scale_kwargs_match_fusion(self):
|
||||
super().test_lora_scale_kwargs_match_fusion(expected_atol=9e-3, expected_rtol=9e-3)
|
||||
|
||||
@parameterized.expand([("block_level", True), ("leaf_level", False)])
|
||||
@require_torch_accelerator
|
||||
def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
|
||||
# TODO: We don't run the (leaf_level, True) test here that is enabled for other models.
|
||||
# The reason for this can be found here: https://github.com/huggingface/diffusers/pull/11804#issuecomment-3013325338
|
||||
super()._test_group_offloading_inference_denoiser(offload_type, use_stream)
|
||||
|
||||
@unittest.skip("Not supported in CogVideoX.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale(self):
|
||||
pass
|
||||
|
||||
@@ -18,17 +18,10 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from parameterized import parameterized
|
||||
from transformers import AutoTokenizer, GlmModel
|
||||
|
||||
from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
|
||||
from diffusers.utils.testing_utils import (
|
||||
floats_tensor,
|
||||
require_peft_backend,
|
||||
require_torch_accelerator,
|
||||
skip_mps,
|
||||
torch_device,
|
||||
)
|
||||
from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, skip_mps, torch_device
|
||||
|
||||
|
||||
sys.path.append(".")
|
||||
@@ -148,13 +141,6 @@ class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
"Loading from saved checkpoints should give same results.",
|
||||
)
|
||||
|
||||
@parameterized.expand([("block_level", True), ("leaf_level", False)])
|
||||
@require_torch_accelerator
|
||||
def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
|
||||
# TODO: We don't run the (leaf_level, True) test here that is enabled for other models.
|
||||
# The reason for this can be found here: https://github.com/huggingface/diffusers/pull/11804#issuecomment-3013325338
|
||||
super()._test_group_offloading_inference_denoiser(offload_type, use_stream)
|
||||
|
||||
@unittest.skip("Not supported in CogView4.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale(self):
|
||||
pass
|
||||
|
||||
@@ -24,11 +24,7 @@ from diffusers import (
|
||||
WanPipeline,
|
||||
WanTransformer3DModel,
|
||||
)
|
||||
from diffusers.utils.testing_utils import (
|
||||
floats_tensor,
|
||||
require_peft_backend,
|
||||
skip_mps,
|
||||
)
|
||||
from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, skip_mps
|
||||
|
||||
|
||||
sys.path.append(".")
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import copy
|
||||
import inspect
|
||||
import os
|
||||
import re
|
||||
@@ -40,7 +39,6 @@ from diffusers.utils.testing_utils import (
|
||||
is_torch_version,
|
||||
require_peft_backend,
|
||||
require_peft_version_greater,
|
||||
require_torch_accelerator,
|
||||
require_transformers_version_greater,
|
||||
skip_mps,
|
||||
torch_device,
|
||||
@@ -292,21 +290,9 @@ class PeftLoraLoaderMixinTests:
|
||||
|
||||
return modules_to_save
|
||||
|
||||
def _get_exclude_modules(self, pipe):
|
||||
from diffusers.utils.peft_utils import _derive_exclude_modules
|
||||
|
||||
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
|
||||
denoiser = "unet" if self.unet_kwargs is not None else "transformer"
|
||||
modules_to_save = {k: v for k, v in modules_to_save.items() if k == denoiser}
|
||||
denoiser_lora_state_dict = self._get_lora_state_dicts(modules_to_save)[f"{denoiser}_lora_layers"]
|
||||
pipe.unload_lora_weights()
|
||||
denoiser_state_dict = pipe.unet.state_dict() if self.unet_kwargs is not None else pipe.transformer.state_dict()
|
||||
exclude_modules = _derive_exclude_modules(
|
||||
denoiser_state_dict, denoiser_lora_state_dict, adapter_name="default"
|
||||
)
|
||||
return exclude_modules
|
||||
|
||||
def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"):
|
||||
def check_if_adapters_added_correctly(
|
||||
self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"
|
||||
):
|
||||
if text_lora_config is not None:
|
||||
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
|
||||
pipe.text_encoder.add_adapter(text_lora_config, adapter_name=adapter_name)
|
||||
@@ -358,7 +344,7 @@ class PeftLoraLoaderMixinTests:
|
||||
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(output_no_lora.shape == self.output_shape)
|
||||
|
||||
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
|
||||
pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
|
||||
|
||||
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(
|
||||
@@ -441,7 +427,7 @@ class PeftLoraLoaderMixinTests:
|
||||
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(output_no_lora.shape == self.output_shape)
|
||||
|
||||
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
|
||||
pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
|
||||
|
||||
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
@@ -497,7 +483,7 @@ class PeftLoraLoaderMixinTests:
|
||||
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(output_no_lora.shape == self.output_shape)
|
||||
|
||||
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
|
||||
pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
|
||||
|
||||
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(
|
||||
@@ -535,7 +521,7 @@ class PeftLoraLoaderMixinTests:
|
||||
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(output_no_lora.shape == self.output_shape)
|
||||
|
||||
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
|
||||
pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
|
||||
|
||||
pipe.fuse_lora()
|
||||
# Fusing should still keep the LoRA layers
|
||||
@@ -567,7 +553,7 @@ class PeftLoraLoaderMixinTests:
|
||||
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(output_no_lora.shape == self.output_shape)
|
||||
|
||||
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
|
||||
pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
|
||||
|
||||
pipe.unload_lora_weights()
|
||||
# unloading should remove the LoRA layers
|
||||
@@ -602,7 +588,7 @@ class PeftLoraLoaderMixinTests:
|
||||
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(output_no_lora.shape == self.output_shape)
|
||||
|
||||
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
|
||||
pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
|
||||
|
||||
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
@@ -653,7 +639,7 @@ class PeftLoraLoaderMixinTests:
|
||||
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(output_no_lora.shape == self.output_shape)
|
||||
|
||||
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
|
||||
pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
|
||||
|
||||
state_dict = {}
|
||||
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
|
||||
@@ -704,7 +690,7 @@ class PeftLoraLoaderMixinTests:
|
||||
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(output_no_lora.shape == self.output_shape)
|
||||
|
||||
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
|
||||
pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
|
||||
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
@@ -747,7 +733,7 @@ class PeftLoraLoaderMixinTests:
|
||||
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(output_no_lora.shape == self.output_shape)
|
||||
|
||||
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
|
||||
pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
|
||||
|
||||
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
@@ -788,7 +774,7 @@ class PeftLoraLoaderMixinTests:
|
||||
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(output_no_lora.shape == self.output_shape)
|
||||
|
||||
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
|
||||
pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
|
||||
|
||||
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(
|
||||
@@ -832,7 +818,7 @@ class PeftLoraLoaderMixinTests:
|
||||
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(output_no_lora.shape == self.output_shape)
|
||||
|
||||
pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
|
||||
pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
|
||||
|
||||
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
|
||||
|
||||
@@ -870,7 +856,7 @@ class PeftLoraLoaderMixinTests:
|
||||
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(output_no_lora.shape == self.output_shape)
|
||||
|
||||
pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
|
||||
pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
|
||||
|
||||
pipe.unload_lora_weights()
|
||||
# unloading should remove the LoRA layers
|
||||
@@ -906,7 +892,7 @@ class PeftLoraLoaderMixinTests:
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
|
||||
pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
|
||||
|
||||
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
|
||||
self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
|
||||
@@ -1023,7 +1009,7 @@ class PeftLoraLoaderMixinTests:
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
pipe, _ = self.add_adapters_to_pipeline(
|
||||
pipe, _ = self.check_if_adapters_added_correctly(
|
||||
pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name
|
||||
)
|
||||
|
||||
@@ -1045,7 +1031,7 @@ class PeftLoraLoaderMixinTests:
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
pipe, _ = self.add_adapters_to_pipeline(
|
||||
pipe, _ = self.check_if_adapters_added_correctly(
|
||||
pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name
|
||||
)
|
||||
|
||||
@@ -1772,7 +1758,7 @@ class PeftLoraLoaderMixinTests:
|
||||
output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(output_no_dora_lora.shape == self.output_shape)
|
||||
|
||||
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
|
||||
pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
|
||||
|
||||
output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
@@ -1863,7 +1849,7 @@ class PeftLoraLoaderMixinTests:
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
|
||||
pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
|
||||
|
||||
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
||||
pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True)
|
||||
@@ -1950,7 +1936,7 @@ class PeftLoraLoaderMixinTests:
|
||||
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(output_no_lora.shape == self.output_shape)
|
||||
|
||||
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
|
||||
pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
|
||||
|
||||
lora_scale = 0.5
|
||||
attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}}
|
||||
@@ -2132,7 +2118,7 @@ class PeftLoraLoaderMixinTests:
|
||||
pipe = pipe.to(torch_device, dtype=compute_dtype)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
|
||||
pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
|
||||
|
||||
if storage_dtype is not None:
|
||||
denoiser.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
|
||||
@@ -2250,7 +2236,7 @@ class PeftLoraLoaderMixinTests:
|
||||
)
|
||||
pipe = self.pipeline_class(**components)
|
||||
|
||||
pipe, _ = self.add_adapters_to_pipeline(
|
||||
pipe, _ = self.check_if_adapters_added_correctly(
|
||||
pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
|
||||
)
|
||||
|
||||
@@ -2303,7 +2289,7 @@ class PeftLoraLoaderMixinTests:
|
||||
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(output_no_lora.shape == self.output_shape)
|
||||
|
||||
pipe, _ = self.add_adapters_to_pipeline(
|
||||
pipe, _ = self.check_if_adapters_added_correctly(
|
||||
pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
|
||||
)
|
||||
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
@@ -2322,77 +2308,6 @@ class PeftLoraLoaderMixinTests:
|
||||
np.allclose(output_lora, output_lora_pretrained, atol=1e-3, rtol=1e-3), "Lora outputs should match."
|
||||
)
|
||||
|
||||
def test_lora_unload_add_adapter(self):
|
||||
"""Tests if `unload_lora_weights()` -> `add_adapter()` works."""
|
||||
scheduler_cls = self.scheduler_classes[0]
|
||||
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components).to(torch_device)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
pipe, _ = self.add_adapters_to_pipeline(
|
||||
pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
|
||||
)
|
||||
_ = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
# unload and then add.
|
||||
pipe.unload_lora_weights()
|
||||
pipe, _ = self.add_adapters_to_pipeline(
|
||||
pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
|
||||
)
|
||||
_ = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
@require_peft_version_greater("0.13.2")
|
||||
def test_lora_exclude_modules(self):
|
||||
"""
|
||||
Test to check if `exclude_modules` works or not. It works in the following way:
|
||||
we first create a pipeline and insert LoRA config into it. We then derive a `set`
|
||||
of modules to exclude by investigating its denoiser state dict and denoiser LoRA
|
||||
state dict.
|
||||
|
||||
We then create a new LoRA config to include the `exclude_modules` and perform tests.
|
||||
"""
|
||||
scheduler_cls = self.scheduler_classes[0]
|
||||
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components).to(torch_device)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(output_no_lora.shape == self.output_shape)
|
||||
|
||||
# only supported for `denoiser` now
|
||||
pipe_cp = copy.deepcopy(pipe)
|
||||
pipe_cp, _ = self.add_adapters_to_pipeline(
|
||||
pipe_cp, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
|
||||
)
|
||||
denoiser_exclude_modules = self._get_exclude_modules(pipe_cp)
|
||||
pipe_cp.to("cpu")
|
||||
del pipe_cp
|
||||
|
||||
denoiser_lora_config.exclude_modules = denoiser_exclude_modules
|
||||
pipe, _ = self.add_adapters_to_pipeline(
|
||||
pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
|
||||
)
|
||||
output_lora_exclude_modules = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
|
||||
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
|
||||
lora_metadatas = self._get_lora_adapter_metadata(modules_to_save)
|
||||
self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas)
|
||||
pipe.unload_lora_weights()
|
||||
pipe.load_lora_weights(tmpdir)
|
||||
|
||||
output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
self.assertTrue(
|
||||
not np.allclose(output_no_lora, output_lora_exclude_modules, atol=1e-3, rtol=1e-3),
|
||||
"LoRA should change outputs.",
|
||||
)
|
||||
self.assertTrue(
|
||||
np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3),
|
||||
"Lora outputs should match.",
|
||||
)
|
||||
|
||||
def test_inference_load_delete_load_adapters(self):
|
||||
"Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
|
||||
for scheduler_cls in self.scheduler_classes:
|
||||
@@ -2440,73 +2355,3 @@ class PeftLoraLoaderMixinTests:
|
||||
pipe.load_lora_weights(tmpdirname)
|
||||
output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(np.allclose(output_adapter_1, output_lora_loaded, atol=1e-3, rtol=1e-3))
|
||||
|
||||
def _test_group_offloading_inference_denoiser(self, offload_type, use_stream):
|
||||
from diffusers.hooks.group_offloading import _get_top_level_group_offload_hook
|
||||
|
||||
onload_device = torch_device
|
||||
offload_device = torch.device("cpu")
|
||||
|
||||
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
|
||||
denoiser.add_adapter(denoiser_lora_config)
|
||||
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
|
||||
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
|
||||
self.pipeline_class.save_lora_weights(
|
||||
save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
|
||||
)
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
components, _, _ = self.get_dummy_components(self.scheduler_classes[0])
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
|
||||
|
||||
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
check_if_lora_correctly_set(denoiser)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
# Test group offloading with load_lora_weights
|
||||
denoiser.enable_group_offload(
|
||||
onload_device=onload_device,
|
||||
offload_device=offload_device,
|
||||
offload_type=offload_type,
|
||||
num_blocks_per_group=1,
|
||||
use_stream=use_stream,
|
||||
)
|
||||
group_offload_hook_1 = _get_top_level_group_offload_hook(denoiser)
|
||||
self.assertTrue(group_offload_hook_1 is not None)
|
||||
output_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
# Test group offloading after removing the lora
|
||||
pipe.unload_lora_weights()
|
||||
group_offload_hook_2 = _get_top_level_group_offload_hook(denoiser)
|
||||
self.assertTrue(group_offload_hook_2 is not None)
|
||||
output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] # noqa: F841
|
||||
|
||||
# Add the lora again and check if group offloading works
|
||||
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
check_if_lora_correctly_set(denoiser)
|
||||
group_offload_hook_3 = _get_top_level_group_offload_hook(denoiser)
|
||||
self.assertTrue(group_offload_hook_3 is not None)
|
||||
output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
self.assertTrue(np.allclose(output_1, output_3, atol=1e-3, rtol=1e-3))
|
||||
|
||||
@parameterized.expand([("block_level", True), ("leaf_level", False), ("leaf_level", True)])
|
||||
@require_torch_accelerator
|
||||
def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
|
||||
for cls in inspect.getmro(self.__class__):
|
||||
if "test_group_offloading_inference_denoiser" in cls.__dict__ and cls is not PeftLoraLoaderMixinTests:
|
||||
# Skip this test if it is overwritten by child class. We need to do this because parameterized
|
||||
# materializes the test methods on invocation which cannot be overridden.
|
||||
return
|
||||
self._test_group_offloading_inference_denoiser(offload_type, use_stream)
|
||||
|
||||
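The test above exercises group offloading together with LoRA loading and unloading through the public API. A hedged usage sketch of the same flow follows; the model id, LoRA path, and prompt are placeholders, and accessing pipe.transformer assumes a transformer-based pipeline.

import torch
from diffusers import DiffusionPipeline

# Placeholder identifiers; substitute a real pipeline and LoRA checkpoint.
pipe = DiffusionPipeline.from_pretrained("some/transformer-pipeline", torch_dtype=torch.bfloat16)
pipe.load_lora_weights("path/to/pytorch_lora_weights.safetensors")

# Same knobs the test passes to enable_group_offload.
pipe.transformer.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
    use_stream=False,
)
image = pipe("a prompt", num_inference_steps=2).images[0]

# Unloading the LoRA keeps the group offload hook in place, as the assertions above check.
pipe.unload_lora_weights()
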
@@ -1350,6 +1350,7 @@ class ModelTesterMixin:
|
||||
new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
|
||||
# Making sure part of the model will actually end up offloaded
|
||||
self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})
|
||||
print(f" new_model.hf_device_map:{new_model.hf_device_map}")
|
||||
|
||||
self.check_device_map_is_respected(new_model, new_model.hf_device_map)
|
||||
|
||||
@@ -2018,8 +2019,6 @@ class LoraHotSwappingForModelTesterMixin:
|
||||
|
||||
"""
|
||||
|
||||
different_shapes_for_compilation = None
|
||||
|
||||
def tearDown(self):
|
||||
# It is critical that the dynamo cache is reset for each test. Otherwise, if the test re-uses the same model,
|
||||
# there will be recompilation errors, as torch caches the model when run in the same process.
|
||||
@@ -2057,13 +2056,11 @@ class LoraHotSwappingForModelTesterMixin:
|
||||
- hotswap the second adapter
|
||||
- check that the outputs are correct
|
||||
- optionally compile the model
|
||||
- optionally check if recompilations happen on different shapes
|
||||
|
||||
Note: We set rank == alpha here because save_lora_adapter does not save the alpha scalings, thus the test would
|
||||
fail if the values are different. Since rank != alpha does not matter for the purpose of this test, this is
|
||||
fine.
|
||||
"""
|
||||
different_shapes = self.different_shapes_for_compilation
|
||||
# create 2 adapters with different ranks and alphas
|
||||
torch.manual_seed(0)
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
@@ -2113,30 +2110,19 @@ class LoraHotSwappingForModelTesterMixin:
|
||||
model.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None)
|
||||
|
||||
if do_compile:
|
||||
model = torch.compile(model, mode="reduce-overhead", dynamic=different_shapes is not None)
|
||||
model = torch.compile(model, mode="reduce-overhead")
|
||||
|
||||
with torch.inference_mode():
|
||||
# additionally check if dynamic compilation works.
|
||||
if different_shapes is not None:
|
||||
for height, width in different_shapes:
|
||||
new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
|
||||
_ = model(**new_inputs_dict)
|
||||
else:
|
||||
output0_after = model(**inputs_dict)["sample"]
|
||||
assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol)
|
||||
output0_after = model(**inputs_dict)["sample"]
|
||||
assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol)
|
||||
|
||||
# hotswap the 2nd adapter
|
||||
model.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None)
|
||||
|
||||
# we need to call forward to potentially trigger recompilation
|
||||
with torch.inference_mode():
|
||||
if different_shapes is not None:
|
||||
for height, width in different_shapes:
|
||||
new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
|
||||
_ = model(**new_inputs_dict)
|
||||
else:
|
||||
output1_after = model(**inputs_dict)["sample"]
|
||||
assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol)
|
||||
output1_after = model(**inputs_dict)["sample"]
|
||||
assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol)
|
||||
|
||||
# check error when not passing valid adapter name
|
||||
name = "does-not-exist"
|
||||
@@ -2254,23 +2240,3 @@ class LoraHotSwappingForModelTesterMixin:
|
||||
do_compile=True, rank0=8, rank1=8, target_modules0=target_modules0, target_modules1=target_modules1
|
||||
)
|
||||
assert any("Hotswapping adapter0 was unsuccessful" in log for log in cm.output)
|
||||
|
||||
@parameterized.expand([(11, 11), (7, 13), (13, 7)])
|
||||
@require_torch_version_greater("2.7.1")
|
||||
def test_hotswapping_compile_on_different_shapes(self, rank0, rank1):
|
||||
different_shapes_for_compilation = self.different_shapes_for_compilation
|
||||
if different_shapes_for_compilation is None:
|
||||
pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
|
||||
# Specifying `use_duck_shape=False` instructs the compiler if it should use the same symbolic
|
||||
# variable to represent input sizes that are the same. For more details,
|
||||
# check out this [comment](https://github.com/huggingface/diffusers/pull/11327#discussion_r2047659790).
|
||||
torch.fx.experimental._config.use_duck_shape = False
|
||||
|
||||
target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
|
||||
with torch._dynamo.config.patch(error_on_recompile=True):
|
||||
self.check_model_hotswap(
|
||||
do_compile=True,
|
||||
rank0=rank0,
|
||||
rank1=rank1,
|
||||
target_modules0=target_modules,
|
||||
)
|
||||
|
||||
@@ -186,10 +186,6 @@ class FluxTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
|
||||
|
||||
class FluxTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
|
||||
model_class = FluxTransformer2DModel
|
||||
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return FluxTransformerTests().prepare_init_args_and_inputs_for_common()
|
||||
|
||||
def prepare_dummy_input(self, height, width):
|
||||
return FluxTransformerTests().prepare_dummy_input(height=height, width=width)
|
||||
|
||||
@@ -1378,6 +1378,7 @@ class PipelineTesterMixin:
|
||||
for component in pipe_fp16.components.values():
|
||||
if hasattr(component, "set_default_attn_processor"):
|
||||
component.set_default_attn_processor()
|
||||
|
||||
pipe_fp16.to(torch_device, torch.float16)
|
||||
pipe_fp16.set_progress_bar_config(disable=None)
|
||||
|
||||
@@ -1385,20 +1386,17 @@ class PipelineTesterMixin:
|
||||
# Reset generator in case it is used inside dummy inputs
|
||||
if "generator" in inputs:
|
||||
inputs["generator"] = self.get_generator(0)
|
||||
|
||||
output = pipe(**inputs)[0]
|
||||
|
||||
fp16_inputs = self.get_dummy_inputs(torch_device)
|
||||
# Reset generator in case it is used inside dummy inputs
|
||||
if "generator" in fp16_inputs:
|
||||
fp16_inputs["generator"] = self.get_generator(0)
|
||||
|
||||
output_fp16 = pipe_fp16(**fp16_inputs)[0]
|
||||
|
||||
if isinstance(output, torch.Tensor):
|
||||
output = output.cpu()
|
||||
output_fp16 = output_fp16.cpu()
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(output.flatten(), output_fp16.flatten())
|
||||
assert max_diff < expected_max_diff
|
||||
assert max_diff < 1e-2
|
||||
|
||||
@unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
|
||||
@require_accelerator
|
||||
|
||||
@@ -98,14 +98,7 @@ class Base4bitTests(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.is_deterministic_enabled = torch.are_deterministic_algorithms_enabled()
|
||||
if not cls.is_deterministic_enabled:
|
||||
torch.use_deterministic_algorithms(True)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
if not cls.is_deterministic_enabled:
|
||||
torch.use_deterministic_algorithms(False)
|
||||
torch.use_deterministic_algorithms(True)
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
prompt_embeds = load_pt(
|
||||
@@ -873,17 +866,15 @@ class ExtendedSerializationTest(BaseBnb4BitSerializationTests):
|
||||
|
||||
@require_torch_version_greater("2.7.1")
|
||||
class Bnb4BitCompileTests(QuantCompileTests):
|
||||
@property
|
||||
def quantization_config(self):
|
||||
return PipelineQuantizationConfig(
|
||||
quant_backend="bitsandbytes_8bit",
|
||||
quant_kwargs={
|
||||
"load_in_4bit": True,
|
||||
"bnb_4bit_quant_type": "nf4",
|
||||
"bnb_4bit_compute_dtype": torch.bfloat16,
|
||||
},
|
||||
components_to_quantize=["transformer", "text_encoder_2"],
|
||||
)
|
||||
quantization_config = PipelineQuantizationConfig(
|
||||
quant_backend="bitsandbytes_8bit",
|
||||
quant_kwargs={
|
||||
"load_in_4bit": True,
|
||||
"bnb_4bit_quant_type": "nf4",
|
||||
"bnb_4bit_compute_dtype": torch.bfloat16,
|
||||
},
|
||||
components_to_quantize=["transformer", "text_encoder_2"],
|
||||
)
|
||||
|
||||
def test_torch_compile(self):
|
||||
torch._dynamo.config.capture_dynamic_output_shape_ops = True
|
||||
@@ -892,7 +883,5 @@ class Bnb4BitCompileTests(QuantCompileTests):
|
||||
def test_torch_compile_with_cpu_offload(self):
|
||||
super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)
|
||||
|
||||
def test_torch_compile_with_group_offload_leaf(self):
|
||||
super()._test_torch_compile_with_group_offload_leaf(
|
||||
quantization_config=self.quantization_config, use_stream=True
|
||||
)
|
||||
def test_torch_compile_with_group_offload(self):
|
||||
super()._test_torch_compile_with_group_offload(quantization_config=self.quantization_config)
|
||||
|
||||
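For context, the quantization_config property above builds a PipelineQuantizationConfig; outside the test harness such a config is normally handed straight to from_pretrained. A hedged sketch follows, with a placeholder model id and the 4-bit NF4 backend chosen to match the kwargs used in these tests.

import torch
from diffusers import DiffusionPipeline
from diffusers.quantizers import PipelineQuantizationConfig

# NF4 settings mirroring the test config above; "some/model-id" is a placeholder checkpoint.
pipeline_quant_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_4bit",
    quant_kwargs={
        "load_in_4bit": True,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_compute_dtype": torch.bfloat16,
    },
    components_to_quantize=["transformer", "text_encoder_2"],
)
pipe = DiffusionPipeline.from_pretrained(
    "some/model-id", quantization_config=pipeline_quant_config, torch_dtype=torch.bfloat16
)
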
@@ -99,14 +99,7 @@ class Base8bitTests(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.is_deterministic_enabled = torch.are_deterministic_algorithms_enabled()
|
||||
if not cls.is_deterministic_enabled:
|
||||
torch.use_deterministic_algorithms(True)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
if not cls.is_deterministic_enabled:
|
||||
torch.use_deterministic_algorithms(False)
|
||||
torch.use_deterministic_algorithms(True)
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
prompt_embeds = load_pt(
|
||||
@@ -838,13 +831,11 @@ class BaseBnb8bitSerializationTests(Base8bitTests):
|
||||
|
||||
@require_torch_version_greater_equal("2.6.0")
|
||||
class Bnb8BitCompileTests(QuantCompileTests):
|
||||
@property
|
||||
def quantization_config(self):
|
||||
return PipelineQuantizationConfig(
|
||||
quant_backend="bitsandbytes_8bit",
|
||||
quant_kwargs={"load_in_8bit": True},
|
||||
components_to_quantize=["transformer", "text_encoder_2"],
|
||||
)
|
||||
quantization_config = PipelineQuantizationConfig(
|
||||
quant_backend="bitsandbytes_8bit",
|
||||
quant_kwargs={"load_in_8bit": True},
|
||||
components_to_quantize=["transformer", "text_encoder_2"],
|
||||
)
|
||||
|
||||
def test_torch_compile(self):
|
||||
torch._dynamo.config.capture_dynamic_output_shape_ops = True
|
||||
@@ -856,7 +847,7 @@ class Bnb8BitCompileTests(QuantCompileTests):
|
||||
)
|
||||
|
||||
@pytest.mark.xfail(reason="Test fails because of an offloading problem from Accelerate with confusion in hooks.")
|
||||
def test_torch_compile_with_group_offload_leaf(self):
|
||||
super()._test_torch_compile_with_group_offload_leaf(
|
||||
quantization_config=self.quantization_config, torch_dtype=torch.float16, use_stream=True
|
||||
def test_torch_compile_with_group_offload(self):
|
||||
super()._test_torch_compile_with_group_offload(
|
||||
quantization_config=self.quantization_config, torch_dtype=torch.float16
|
||||
)
|
||||
|
||||
@@ -24,11 +24,7 @@ from diffusers.utils.testing_utils import backend_empty_cache, require_torch_gpu
|
||||
@require_torch_gpu
|
||||
@slow
|
||||
class QuantCompileTests(unittest.TestCase):
|
||||
@property
|
||||
def quantization_config(self):
|
||||
raise NotImplementedError(
|
||||
"This property should be implemented in the subclass to return the appropriate quantization config."
|
||||
)
|
||||
quantization_config = None
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
@@ -68,9 +64,7 @@ class QuantCompileTests(unittest.TestCase):
|
||||
# small resolutions to ensure speedy execution.
|
||||
pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
|
||||
|
||||
def _test_torch_compile_with_group_offload_leaf(
|
||||
self, quantization_config, torch_dtype=torch.bfloat16, *, use_stream: bool = False
|
||||
):
|
||||
def _test_torch_compile_with_group_offload(self, quantization_config, torch_dtype=torch.bfloat16):
|
||||
torch._dynamo.config.cache_size_limit = 10000
|
||||
|
||||
pipe = self._init_pipeline(quantization_config, torch_dtype)
|
||||
@@ -78,7 +72,8 @@ class QuantCompileTests(unittest.TestCase):
|
||||
"onload_device": torch.device("cuda"),
|
||||
"offload_device": torch.device("cpu"),
|
||||
"offload_type": "leaf_level",
|
||||
"use_stream": use_stream,
|
||||
"use_stream": True,
|
||||
"non_blocking": True,
|
||||
}
|
||||
pipe.transformer.enable_group_offload(**group_offload_kwargs)
|
||||
pipe.transformer.compile()
|
||||
|
||||
@@ -19,7 +19,6 @@ import unittest
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
from parameterized import parameterized
|
||||
from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import (
|
||||
@@ -30,7 +29,6 @@ from diffusers import (
|
||||
TorchAoConfig,
|
||||
)
|
||||
from diffusers.models.attention_processor import Attention
|
||||
from diffusers.quantizers import PipelineQuantizationConfig
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
backend_synchronize,
|
||||
@@ -46,8 +44,6 @@ from diffusers.utils.testing_utils import (
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..test_torch_compile_utils import QuantCompileTests
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
@@ -629,53 +625,6 @@ class TorchAoSerializationTest(unittest.TestCase):
|
||||
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
|
||||
|
||||
|
||||
@require_torchao_version_greater_or_equal("0.7.0")
|
||||
class TorchAoCompileTest(QuantCompileTests):
|
||||
@property
|
||||
def quantization_config(self):
|
||||
return PipelineQuantizationConfig(
|
||||
quant_mapping={
|
||||
"transformer": TorchAoConfig(quant_type="int8_weight_only"),
|
||||
},
|
||||
)
|
||||
|
||||
def test_torch_compile(self):
|
||||
super()._test_torch_compile(quantization_config=self.quantization_config)
|
||||
|
||||
@unittest.skip(
|
||||
"Changing the device of AQT tensor with module._apply (called from doing module.to() in accelerate) does not work "
|
||||
"when compiling."
|
||||
)
|
||||
def test_torch_compile_with_cpu_offload(self):
|
||||
# RuntimeError: _apply(): Couldn't swap Linear.weight
|
||||
super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)
|
||||
|
||||
@unittest.skip(
|
||||
"""
|
||||
For `use_stream=False`:
|
||||
- Changing the device of AQT tensor, with `param.data = param.data.to(device)` as done in group offloading implementation
|
||||
is unsupported in TorchAO. When compiling, FakeTensor device mismatch causes failure.
|
||||
For `use_stream=True`:
|
||||
Using non-default stream requires ability to pin tensors. AQT does not seem to support this yet in TorchAO.
|
||||
"""
|
||||
)
|
||||
@parameterized.expand([False, True])
|
||||
def test_torch_compile_with_group_offload_leaf(self):
|
||||
# For use_stream=False:
|
||||
# If we run group offloading without compilation, we will see:
|
||||
# RuntimeError: Attempted to set the storage of a tensor on device "cpu" to a storage on different device "cuda:0". This is no longer allowed; the devices must match.
|
||||
# When running with compilation, the error ends up being different:
|
||||
# Dynamo failed to run FX node with fake tensors: call_function <built-in function linear>(*(FakeTensor(..., device='cuda:0', size=(s0, 256), dtype=torch.bfloat16), AffineQuantizedTensor(tensor_impl=PlainAQTTensorImpl(data=FakeTensor(..., size=(1536, 256), dtype=torch.int8)... , scale=FakeTensor(..., size=(1536,), dtype=torch.bfloat16)... , zero_point=FakeTensor(..., size=(1536,), dtype=torch.int64)... , _layout=PlainLayout()), block_size=(1, 256), shape=torch.Size([1536, 256]), device=cpu, dtype=torch.bfloat16, requires_grad=False), Parameter(FakeTensor(..., device='cuda:0', size=(1536,), dtype=torch.bfloat16,
|
||||
# requires_grad=True))), **{}): got RuntimeError('Unhandled FakeTensor Device Propagation for aten.mm.default, found two different devices cuda:0, cpu')
|
||||
# Looks like something that will have to be looked into upstream.
|
||||
# for linear layers, weight.tensor_impl shows cuda... but:
|
||||
# weight.tensor_impl.{data,scale,zero_point}.device will be cpu
|
||||
|
||||
# For use_stream=True:
|
||||
# NotImplementedError: AffineQuantizedTensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.is_pinned', overload='default')>, types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), arg_types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), kwarg_types={}
|
||||
super()._test_torch_compile_with_group_offload_leaf(quantization_config=self.quantization_config)
|
||||
|
||||
|
||||
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
|
||||
@require_torch
|
||||
@require_torch_accelerator
|
||||
|
||||