Compare commits

..

2 Commits

Author SHA1 Message Date
sayakpaul
1f67e4e7f6 add a test for checking effective custom gc. 2025-01-29 11:04:12 +05:30
Dimitri Barbot
196aef5a6f Fix pipeline dtype unexpected change when using SDXL reference community pipelines in float16 mode (#10670)
Fix pipeline dtype unexpected change when using SDXL reference community pipelines
2025-01-28 10:46:41 -03:00
17 changed files with 139 additions and 95 deletions

View File

@@ -36,8 +36,8 @@ jobs:
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install --prerelease=allow -e [quality,test]
python -m uv pip install --prerelease=allow pandas peft
python -m uv pip install -e [quality,test]
python -m uv pip install pandas peft
- name: Environment
run: |
python utils/print_env.py

View File

@@ -71,9 +71,9 @@ jobs:
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install --prerelease=allow -e [quality,test]
pip uninstall accelerate -y && python -m uv pip install --prerelease=allow -U accelerate@git+https://github.com/huggingface/accelerate.git
python -m uv pip install --prerelease=allow pytest-reportlog
python -m uv pip install -e [quality,test]
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
python -m uv pip install pytest-reportlog
- name: Environment
run: |
python utils/print_env.py
@@ -129,10 +129,10 @@ jobs:
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install --prerelease=allow -e [quality,test]
python -m uv pip install --prerelease=allow peft@git+https://github.com/huggingface/peft.git
pip uninstall accelerate -y && python -m uv pip install --prerelease=allow -U accelerate@git+https://github.com/huggingface/accelerate.git
python -m uv pip install --prerelease=allow pytest-reportlog
python -m uv pip install -e [quality,test]
python -m uv pip install peft@git+https://github.com/huggingface/peft.git
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
python -m uv pip install pytest-reportlog
- name: Environment
run: python utils/print_env.py
@@ -200,10 +200,10 @@ jobs:
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install --prerelease=allow -e [quality,test]
python -m uv pip install --prerelease=allow peft@git+https://github.com/huggingface/peft.git
pip uninstall accelerate -y && python -m uv pip install --prerelease=allow -U accelerate@git+https://github.com/huggingface/accelerate.git
python -m uv pip install --prerelease=allow pytest-reportlog
python -m uv pip install -e [quality,test]
python -m uv pip install peft@git+https://github.com/huggingface/peft.git
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
python -m uv pip install pytest-reportlog
- name: Environment
run: |
python utils/print_env.py
@@ -255,9 +255,9 @@ jobs:
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install --prerelease=allow -e [quality,test]
python -m uv pip install --prerelease=allow peft@git+https://github.com/huggingface/peft.git
pip uninstall accelerate -y && python -m uv pip install --prerelease=allow -U accelerate@git+https://github.com/huggingface/accelerate.git
python -m uv pip install -e [quality,test]
python -m uv pip install peft@git+https://github.com/huggingface/peft.git
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
- name: Environment
run: |
@@ -314,9 +314,9 @@ jobs:
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install --prerelease=allow -e [quality,test]
pip uninstall accelerate -y && python -m uv pip install --prerelease=allow -U accelerate@git+https://github.com/huggingface/accelerate.git
python -m uv pip install --prerelease=allow pytest-reportlog
python -m uv pip install -e [quality,test]
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
python -m uv pip install pytest-reportlog
- name: Environment
run: python utils/print_env.py
@@ -370,9 +370,9 @@ jobs:
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install --prerelease=allow -e [quality,test]
pip uninstall accelerate -y && python -m uv pip install --prerelease=allow -U accelerate@git+https://github.com/huggingface/accelerate.git
python -m uv pip install --prerelease=allow pytest-reportlog
python -m uv pip install -e [quality,test]
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
python -m uv pip install pytest-reportlog
- name: Environment
run: python utils/print_env.py
@@ -433,9 +433,9 @@ jobs:
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install --prerelease=allow -e [quality,test]
python -m uv pip install --prerelease=allow -U ${{ matrix.config.backend }}
python -m uv pip install --prerelease=allow pytest-reportlog
python -m uv pip install -e [quality,test]
python -m uv pip install -U ${{ matrix.config.backend }}
python -m uv pip install pytest-reportlog
- name: Environment
run: |
python utils/print_env.py
@@ -493,10 +493,10 @@ jobs:
# shell: arch -arch arm64 bash {0}
# run: |
# ${CONDA_RUN} python -m pip install --upgrade pip uv
# ${CONDA_RUN} python -m uv pip install --prerelease=allow -e [quality,test]
# ${CONDA_RUN} python -m uv pip install --prerelease=allow torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
# ${CONDA_RUN} python -m uv pip install --prerelease=allow accelerate@git+https://github.com/huggingface/accelerate
# ${CONDA_RUN} python -m uv pip install --prerelease=allow pytest-reportlog
# ${CONDA_RUN} python -m uv pip install -e [quality,test]
# ${CONDA_RUN} python -m uv pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
# ${CONDA_RUN} python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate
# ${CONDA_RUN} python -m uv pip install pytest-reportlog
# - name: Environment
# shell: arch -arch arm64 bash {0}
# run: |
@@ -549,10 +549,10 @@ jobs:
# shell: arch -arch arm64 bash {0}
# run: |
# ${CONDA_RUN} python -m pip install --upgrade pip uv
# ${CONDA_RUN} python -m uv pip install --prerelease=allow -e [quality,test]
# ${CONDA_RUN} python -m uv pip install --prerelease=allow torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
# ${CONDA_RUN} python -m uv pip install --prerelease=allow accelerate@git+https://github.com/huggingface/accelerate
# ${CONDA_RUN} python -m uv pip install --prerelease=allow pytest-reportlog
# ${CONDA_RUN} python -m uv pip install -e [quality,test]
# ${CONDA_RUN} python -m uv pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
# ${CONDA_RUN} python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate
# ${CONDA_RUN} python -m uv pip install pytest-reportlog
# - name: Environment
# shell: arch -arch arm64 bash {0}
# run: |

View File

@@ -27,8 +27,8 @@ jobs:
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m pip install --upgrade pip uv
python -m uv pip install --prerelease=allow -e .
python -m uv pip install --prerelease=allow pytest
python -m uv pip install -e .
python -m uv pip install pytest
- name: Check for soft dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"

View File

@@ -27,11 +27,11 @@ jobs:
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m pip install --upgrade pip uv
python -m uv pip install --prerelease=allow -e .
python -m uv pip install --prerelease=allow "jax[cpu]>=0.2.16,!=0.3.2"
python -m uv pip install --prerelease=allow "flax>=0.4.1"
python -m uv pip install --prerelease=allow "jaxlib>=0.1.65"
python -m uv pip install --prerelease=allow pytest
python -m uv pip install -e .
python -m uv pip install "jax[cpu]>=0.2.16,!=0.3.2"
python -m uv pip install "flax>=0.4.1"
python -m uv pip install "jaxlib>=0.1.65"
python -m uv pip install pytest
- name: Check for soft dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"

View File

@@ -34,7 +34,7 @@ jobs:
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install --prerelease=allow -e [quality,test]
python -m uv pip install -e [quality,test]
- name: Environment
run: |
python utils/print_env.py

View File

@@ -119,8 +119,8 @@ jobs:
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install --prerelease=allow -e [quality,test]
python -m uv pip install --prerelease=allow accelerate
python -m uv pip install -e [quality,test]
python -m uv pip install accelerate
- name: Environment
run: |
@@ -158,7 +158,7 @@ jobs:
if: ${{ matrix.config.framework == 'pytorch_examples' }}
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install --prerelease=allow peft timm
python -m uv pip install peft timm
python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
--make-reports=tests_${{ matrix.config.report }} \
examples
@@ -208,7 +208,7 @@ jobs:
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install --prerelease=allow -e [quality,test]
python -m uv pip install -e [quality,test]
- name: Environment
run: |
@@ -262,12 +262,12 @@ jobs:
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install --prerelease=allow -e [quality,test]
python -m uv pip install -e [quality,test]
# TODO (sayakpaul, DN6): revisit `--no-deps`
python -m pip install -U peft@git+https://github.com/huggingface/peft.git --no-deps
python -m uv pip install --prerelease=allow -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
python -m uv pip install --prerelease=allow -U tokenizers
pip uninstall accelerate -y && python -m uv pip install --prerelease=allow -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
python -m uv pip install -U tokenizers
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
- name: Environment
run: |

View File

@@ -27,9 +27,9 @@ jobs:
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m pip install --upgrade pip uv
python -m uv pip install --prerelease=allow -e .
python -m uv pip install --prerelease=allow torch torchvision torchaudio
python -m uv pip install --prerelease=allow pytest
python -m uv pip install -e .
python -m uv pip install torch torchvision torchaudio
python -m uv pip install pytest
- name: Check for soft dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"

View File

@@ -35,7 +35,7 @@ jobs:
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install --prerelease=allow -e [quality,test]
python -m uv pip install -e [quality,test]
- name: Environment
run: |
python utils/print_env.py
@@ -76,8 +76,8 @@ jobs:
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install --prerelease=allow -e [quality,test]
pip uninstall accelerate -y && python -m uv pip install --prerelease=allow -U accelerate@git+https://github.com/huggingface/accelerate.git
python -m uv pip install -e [quality,test]
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
- name: Environment
run: |
python utils/print_env.py
@@ -127,9 +127,9 @@ jobs:
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install --prerelease=allow -e [quality,test]
python -m uv pip install --prerelease=allow peft@git+https://github.com/huggingface/peft.git
pip uninstall accelerate -y && python -m uv pip install --prerelease=allow -U accelerate@git+https://github.com/huggingface/accelerate.git
python -m uv pip install -e [quality,test]
python -m uv pip install peft@git+https://github.com/huggingface/peft.git
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
- name: Environment
run: |
@@ -178,8 +178,8 @@ jobs:
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install --prerelease=allow -e [quality,test]
pip uninstall accelerate -y && python -m uv pip install --prerelease=allow -U accelerate@git+https://github.com/huggingface/accelerate.git
python -m uv pip install -e [quality,test]
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
- name: Environment
run: |
@@ -226,8 +226,8 @@ jobs:
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install --prerelease=allow -e [quality,test]
pip uninstall accelerate -y && python -m uv pip install --prerelease=allow -U accelerate@git+https://github.com/huggingface/accelerate.git
python -m uv pip install -e [quality,test]
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
- name: Environment
run: |
@@ -277,7 +277,7 @@ jobs:
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install --prerelease=allow -e [quality,test,training]
python -m uv pip install -e [quality,test,training]
- name: Environment
run: |
python utils/print_env.py
@@ -320,7 +320,7 @@ jobs:
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install --prerelease=allow -e [quality,test,training]
python -m uv pip install -e [quality,test,training]
- name: Environment
run: |
python utils/print_env.py
@@ -363,7 +363,7 @@ jobs:
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install --prerelease=allow -e [quality,test,training]
python -m uv pip install -e [quality,test,training]
- name: Environment
run: |
@@ -375,7 +375,7 @@ jobs:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install --prerelease=allow timm
python -m uv pip install timm
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/
- name: Failure short reports

View File

@@ -71,7 +71,7 @@ jobs:
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install --prerelease=allow -e [quality,test]
python -m uv pip install -e [quality,test]
- name: Environment
run: |
@@ -109,7 +109,7 @@ jobs:
if: ${{ matrix.config.framework == 'pytorch_examples' }}
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install --prerelease=allow peft timm
python -m uv pip install peft timm
python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
--make-reports=tests_${{ matrix.config.report }} \
examples

View File

@@ -46,10 +46,10 @@ jobs:
shell: arch -arch arm64 bash {0}
run: |
${CONDA_RUN} python -m pip install --upgrade pip uv
${CONDA_RUN} python -m uv pip install --prerelease=allow -e ".[quality,test]"
${CONDA_RUN} python -m uv pip install --prerelease=allow torch torchvision torchaudio
${CONDA_RUN} python -m uv pip install --prerelease=allow accelerate@git+https://github.com/huggingface/accelerate.git
${CONDA_RUN} python -m uv pip install --prerelease=allow transformers --upgrade
${CONDA_RUN} python -m uv pip install -e ".[quality,test]"
${CONDA_RUN} python -m uv pip install torch torchvision torchaudio
${CONDA_RUN} python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate.git
${CONDA_RUN} python -m uv pip install transformers --upgrade
- name: Environment
shell: arch -arch arm64 bash {0}

View File

@@ -33,7 +33,7 @@ jobs:
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install --prerelease=allow -e [quality,test]
python -m uv pip install -e [quality,test]
- name: Environment
run: |
python utils/print_env.py
@@ -74,8 +74,8 @@ jobs:
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install --prerelease=allow -e [quality,test]
pip uninstall accelerate -y && python -m uv pip install --prerelease=allow -U accelerate@git+https://github.com/huggingface/accelerate.git
python -m uv pip install -e [quality,test]
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
- name: Environment
run: |
python utils/print_env.py
@@ -125,9 +125,9 @@ jobs:
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install --prerelease=allow -e [quality,test]
python -m uv pip install --prerelease=allow peft@git+https://github.com/huggingface/peft.git
pip uninstall accelerate -y && python -m uv pip install --prerelease=allow -U accelerate@git+https://github.com/huggingface/accelerate.git
python -m uv pip install -e [quality,test]
python -m uv pip install peft@git+https://github.com/huggingface/peft.git
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
- name: Environment
run: |
@@ -176,9 +176,9 @@ jobs:
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install --prerelease=allow -e [quality,test]
python -m uv pip install --prerelease=allow peft@git+https://github.com/huggingface/peft.git
pip uninstall accelerate -y && python -m uv pip install --prerelease=allow -U accelerate@git+https://github.com/huggingface/accelerate.git
python -m uv pip install -e [quality,test]
python -m uv pip install peft@git+https://github.com/huggingface/peft.git
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
- name: Environment
run: |
@@ -232,8 +232,8 @@ jobs:
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install --prerelease=allow -e [quality,test]
pip uninstall accelerate -y && python -m uv pip install --prerelease=allow -U accelerate@git+https://github.com/huggingface/accelerate.git
python -m uv pip install -e [quality,test]
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
- name: Environment
run: |
@@ -280,8 +280,8 @@ jobs:
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install --prerelease=allow -e [quality,test]
pip uninstall accelerate -y && python -m uv pip install --prerelease=allow -U accelerate@git+https://github.com/huggingface/accelerate.git
python -m uv pip install -e [quality,test]
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
- name: Environment
run: |
@@ -331,7 +331,7 @@ jobs:
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install --prerelease=allow -e [quality,test,training]
python -m uv pip install -e [quality,test,training]
- name: Environment
run: |
python utils/print_env.py
@@ -374,7 +374,7 @@ jobs:
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install --prerelease=allow -e [quality,test,training]
python -m uv pip install -e [quality,test,training]
- name: Environment
run: |
python utils/print_env.py
@@ -417,7 +417,7 @@ jobs:
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install --prerelease=allow -e [quality,test,training]
python -m uv pip install -e [quality,test,training]
- name: Environment
run: |
@@ -429,7 +429,7 @@ jobs:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install --prerelease=allow timm
python -m uv pip install timm
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/
- name: Failure short reports

View File

@@ -64,8 +64,8 @@ jobs:
- name: Install pytest
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install --prerelease=allow -e [quality,test]
python -m uv pip install --prerelease=allow peft
python -m uv pip install -e [quality,test]
python -m uv pip install peft
- name: Run tests
env:

View File

@@ -193,7 +193,8 @@ class StableDiffusionXLControlNetReferencePipeline(StableDiffusionXLControlNetPi
def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance):
refimage = refimage.to(device=device)
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
if needs_upcasting:
self.upcast_vae()
refimage = refimage.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
if refimage.dtype != self.vae.dtype:
@@ -223,6 +224,11 @@ class StableDiffusionXLControlNetReferencePipeline(StableDiffusionXLControlNetPi
# aligning device to prevent device errors when concating it with the latent model input
ref_image_latents = ref_image_latents.to(device=device, dtype=dtype)
# cast back to fp16 if needed
if needs_upcasting:
self.vae.to(dtype=torch.float16)
return ref_image_latents
def prepare_ref_image(

View File

@@ -139,7 +139,8 @@ def retrieve_timesteps(
class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance):
refimage = refimage.to(device=device)
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
if needs_upcasting:
self.upcast_vae()
refimage = refimage.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
if refimage.dtype != self.vae.dtype:
@@ -169,6 +170,11 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
# aligning device to prevent device errors when concating it with the latent model input
ref_image_latents = ref_image_latents.to(device=device, dtype=dtype)
# cast back to fp16 if needed
if needs_upcasting:
self.vae.to(dtype=torch.float16)
return ref_image_latents
def prepare_ref_image(

View File

@@ -101,7 +101,7 @@ _deps = [
"filelock",
"flax>=0.4.1",
"hf-doc-builder>=0.3.0",
"huggingface-hub==v0.28.0.rc0",
"huggingface-hub>=0.27.0",
"requests-mock==1.10.0",
"importlib_metadata",
"invisible-watermark>=0.2.0",

View File

@@ -9,7 +9,7 @@ deps = {
"filelock": "filelock",
"flax": "flax>=0.4.1",
"hf-doc-builder": "hf-doc-builder>=0.3.0",
"huggingface-hub": "huggingface-hub==v0.28.0.rc0",
"huggingface-hub": "huggingface-hub>=0.27.0",
"requests-mock": "requests-mock==1.10.0",
"importlib_metadata": "importlib_metadata",
"invisible-watermark": "invisible-watermark>=0.2.0",

View File

@@ -966,6 +966,38 @@ class ModelTesterMixin:
assert set(modules_with_gc_enabled.keys()) == expected_set
assert all(modules_with_gc_enabled.values()), "All modules should be enabled"
@require_torch_accelerator_with_training
def test_apply_gradient_checkpointing_every_n_block(self, block_num=2):
# Skip test if model does not support gradient checkpointing
if not self.model_class._supports_gradient_checkpointing:
return
# For now, we only test for transformer models.
if "transformer" not in self.model_class.__name__.lower():
return
# enable deterministic behavior for gradient checkpointing
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
torch.manual_seed(0)
model = self.model_class(**init_dict)
def gradient_checkpointing_func(model, *args):
if model.layer_index % block_num == 0:
return torch.utils.checkpoint.checkpoint(model.__call__, *args, use_reentrant=False)
return model(*args)
if getattr(model, "transformer_blocks", None) is not None:
for index, layer in enumerate(model.transformer_blocks):
layer.layer_index = index
model.enable_gradient_checkpointing(gradient_checkpointing_func)
assert model.training
for index, layer in enumerate(model.transformer_blocks):
if model.layer_index % block_num == 0:
assert layer.is_gradient_checkpointing
else:
assert not layer.is_gradient_checkpointing
def test_deprecated_kwargs(self):
has_kwarg_in_model_class = "kwargs" in inspect.signature(self.model_class.__init__).parameters
has_deprecated_kwarg = len(self.model_class._deprecated_kwargs) > 0