Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-01 01:11:12 +08:00

Compare commits: 10 commits, adv-flux ... sage-kerne
| SHA1 |
|---|
| 23c173ea58 |
| 8abc7aeb71 |
| 693d8a3a52 |
| a9df12ab45 |
| a519272d97 |
| 3688c9d443 |
| d3441340b9 |
| 18c3e8ee0c |
| f630dab8a2 |
| e9ea1c5b2c |
.github/workflows/benchmark.yml (5 changes)

@@ -38,9 +38,8 @@ jobs:
         run: |
           apt update
           apt install -y libpq-dev postgresql-client
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-          python -m uv pip install -e [quality,test]
-          python -m uv pip install -r benchmarks/requirements.txt
+          uv pip install -e ".[quality]"
+          uv pip install -r benchmarks/requirements.txt
       - name: Environment
         run: |
           python utils/print_env.py

@@ -74,7 +74,7 @@ jobs:
           python-version: "3.10"
       - name: Install dependencies
         run: |
-          python -m pip install --upgrade pip
+          pip install --upgrade pip
           pip install --upgrade huggingface_hub

       # Check secret is set
.github/workflows/nightly_tests.yml (91 changes)

@@ -71,10 +71,9 @@ jobs:
         run: nvidia-smi
       - name: Install dependencies
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-          python -m uv pip install -e [quality,test]
-          pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
-          python -m uv pip install pytest-reportlog
+          uv pip install -e ".[quality]"
+          uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+          uv pip install pytest-reportlog
       - name: Environment
         run: |
           python utils/print_env.py

@@ -84,7 +83,7 @@ jobs:
           # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
           CUBLAS_WORKSPACE_CONFIG: :16:8
         run: |
-          python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+          pytest -n 1 --max-worker-restart=0 --dist=loadfile \
             -s -v -k "not Flax and not Onnx" \
             --make-reports=tests_pipeline_${{ matrix.module }}_cuda \
             --report-log=tests_pipeline_${{ matrix.module }}_cuda.log \

@@ -124,11 +123,10 @@ jobs:

       - name: Install dependencies
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-          python -m uv pip install -e [quality,test]
-          python -m uv pip install peft@git+https://github.com/huggingface/peft.git
-          pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
-          python -m uv pip install pytest-reportlog
+          uv pip install -e ".[quality]"
+          uv pip install peft@git+https://github.com/huggingface/peft.git
+          uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+          uv pip install pytest-reportlog
       - name: Environment
         run: python utils/print_env.py

@@ -139,7 +137,7 @@ jobs:
           # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
           CUBLAS_WORKSPACE_CONFIG: :16:8
         run: |
-          python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+          pytest -n 1 --max-worker-restart=0 --dist=loadfile \
             -s -v -k "not Flax and not Onnx" \
             --make-reports=tests_torch_${{ matrix.module }}_cuda \
             --report-log=tests_torch_${{ matrix.module }}_cuda.log \

@@ -152,7 +150,7 @@ jobs:
           # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
           CUBLAS_WORKSPACE_CONFIG: :16:8
         run: |
-          python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+          pytest -n 1 --max-worker-restart=0 --dist=loadfile \
             -s -v --make-reports=examples_torch_cuda \
             --report-log=examples_torch_cuda.log \
             examples/

@@ -191,8 +189,7 @@ jobs:
           nvidia-smi
       - name: Install dependencies
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-          python -m uv pip install -e [quality,test,training]
+          uv pip install -e ".[quality,training]"
       - name: Environment
         run: |
           python utils/print_env.py

@@ -201,7 +198,7 @@ jobs:
           HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
           RUN_COMPILE: yes
         run: |
-          python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
+          pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
       - name: Failure short reports
         if: ${{ failure() }}
         run: cat reports/tests_torch_compile_cuda_failures_short.txt

@@ -232,11 +229,10 @@ jobs:
         run: nvidia-smi
       - name: Install dependencies
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-          python -m uv pip install -e [quality,test]
-          python -m uv pip install peft@git+https://github.com/huggingface/peft.git
-          pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
-          python -m uv pip install pytest-reportlog
+          uv pip install -e ".[quality]"
+          uv pip install peft@git+https://github.com/huggingface/peft.git
+          uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+          uv pip install pytest-reportlog
       - name: Environment
         run: |
           python utils/print_env.py

@@ -247,7 +243,7 @@ jobs:
           CUBLAS_WORKSPACE_CONFIG: :16:8
           BIG_GPU_MEMORY: 40
         run: |
-          python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+          pytest -n 1 --max-worker-restart=0 --dist=loadfile \
             -m "big_accelerator" \
             --make-reports=tests_big_gpu_torch_cuda \
             --report-log=tests_big_gpu_torch_cuda.log \

@@ -282,10 +278,9 @@ jobs:

       - name: Install dependencies
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-          python -m uv pip install -e [quality,test]
-          python -m uv pip install peft@git+https://github.com/huggingface/peft.git
-          pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+          uv pip install -e ".[quality]"
+          uv pip install peft@git+https://github.com/huggingface/peft.git
+          uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git

       - name: Environment
         run: |

@@ -297,7 +292,7 @@ jobs:
           # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
           CUBLAS_WORKSPACE_CONFIG: :16:8
         run: |
-          python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+          pytest -n 1 --max-worker-restart=0 --dist=loadfile \
             -s -v -k "not Flax and not Onnx" \
             --make-reports=tests_torch_minimum_version_cuda \
             tests/models/test_modeling_common.py \

@@ -357,13 +352,12 @@ jobs:
         run: nvidia-smi
       - name: Install dependencies
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-          python -m uv pip install -e [quality,test]
-          python -m uv pip install -U ${{ matrix.config.backend }}
+          uv pip install -e ".[quality]"
+          uv pip install -U ${{ matrix.config.backend }}
           if [ "${{ join(matrix.config.additional_deps, ' ') }}" != "" ]; then
-            python -m uv pip install ${{ join(matrix.config.additional_deps, ' ') }}
+            uv pip install ${{ join(matrix.config.additional_deps, ' ') }}
           fi
-          python -m uv pip install pytest-reportlog
+          uv pip install pytest-reportlog
       - name: Environment
         run: |
           python utils/print_env.py

@@ -374,7 +368,7 @@ jobs:
           CUBLAS_WORKSPACE_CONFIG: :16:8
           BIG_GPU_MEMORY: 40
         run: |
-          python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+          pytest -n 1 --max-worker-restart=0 --dist=loadfile \
             --make-reports=tests_${{ matrix.config.backend }}_torch_cuda \
             --report-log=tests_${{ matrix.config.backend }}_torch_cuda.log \
             tests/quantization/${{ matrix.config.test_location }}

@@ -409,10 +403,9 @@ jobs:
         run: nvidia-smi
       - name: Install dependencies
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-          python -m uv pip install -e [quality,test]
-          python -m uv pip install -U bitsandbytes optimum_quanto
-          python -m uv pip install pytest-reportlog
+          uv pip install -e ".[quality]"
+          uv pip install -U bitsandbytes optimum_quanto
+          uv pip install pytest-reportlog
       - name: Environment
         run: |
           python utils/print_env.py

@@ -423,7 +416,7 @@ jobs:
           CUBLAS_WORKSPACE_CONFIG: :16:8
           BIG_GPU_MEMORY: 40
         run: |
-          python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+          pytest -n 1 --max-worker-restart=0 --dist=loadfile \
             --make-reports=tests_pipeline_level_quant_torch_cuda \
             --report-log=tests_pipeline_level_quant_torch_cuda.log \
             tests/quantization/test_pipeline_level_quantization.py

@@ -523,11 +516,11 @@ jobs:
       # - name: Install dependencies
       #   shell: arch -arch arm64 bash {0}
       #   run: |
-      #     ${CONDA_RUN} python -m pip install --upgrade pip uv
-      #     ${CONDA_RUN} python -m uv pip install -e [quality,test]
-      #     ${CONDA_RUN} python -m uv pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
-      #     ${CONDA_RUN} python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate
-      #     ${CONDA_RUN} python -m uv pip install pytest-reportlog
+      #     ${CONDA_RUN} pip install --upgrade pip uv
+      #     ${CONDA_RUN} uv pip install -e ".[quality]"
+      #     ${CONDA_RUN} uv pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
+      #     ${CONDA_RUN} uv pip install accelerate@git+https://github.com/huggingface/accelerate
+      #     ${CONDA_RUN} uv pip install pytest-reportlog
       # - name: Environment
       #   shell: arch -arch arm64 bash {0}
       #   run: |

@@ -538,7 +531,7 @@ jobs:
       #     HF_HOME: /System/Volumes/Data/mnt/cache
       #     HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
      #   run: |
-      #     ${CONDA_RUN} python -m pytest -n 1 -s -v --make-reports=tests_torch_mps \
+      #     ${CONDA_RUN} pytest -n 1 -s -v --make-reports=tests_torch_mps \
      #       --report-log=tests_torch_mps.log \
      #       tests/
      # - name: Failure short reports

@@ -579,11 +572,11 @@ jobs:
       # - name: Install dependencies
       #   shell: arch -arch arm64 bash {0}
       #   run: |
-      #     ${CONDA_RUN} python -m pip install --upgrade pip uv
-      #     ${CONDA_RUN} python -m uv pip install -e [quality,test]
-      #     ${CONDA_RUN} python -m uv pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
-      #     ${CONDA_RUN} python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate
-      #     ${CONDA_RUN} python -m uv pip install pytest-reportlog
+      #     ${CONDA_RUN} pip install --upgrade pip uv
+      #     ${CONDA_RUN} uv pip install -e ".[quality]"
+      #     ${CONDA_RUN} uv pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
+      #     ${CONDA_RUN} uv pip install accelerate@git+https://github.com/huggingface/accelerate
+      #     ${CONDA_RUN} uv pip install pytest-reportlog
       # - name: Environment
       #   shell: arch -arch arm64 bash {0}
       #   run: |

@@ -594,7 +587,7 @@ jobs:
       #     HF_HOME: /System/Volumes/Data/mnt/cache
       #     HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
      #   run: |
-      #     ${CONDA_RUN} python -m pytest -n 1 -s -v --make-reports=tests_torch_mps \
+      #     ${CONDA_RUN} pytest -n 1 -s -v --make-reports=tests_torch_mps \
      #       --report-log=tests_torch_mps.log \
      #       tests/
      # - name: Failure short reports
.github/workflows/pr_dependency_test.yml (9 changes)

@@ -25,11 +25,8 @@ jobs:
           python-version: "3.8"
       - name: Install dependencies
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-          python -m pip install --upgrade pip uv
-          python -m uv pip install -e .
-          python -m uv pip install pytest
+          pip install -e .
+          pip install pytest
       - name: Check for soft dependencies
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
           pytest tests/others/test_dependencies.py
.github/workflows/pr_modular_tests.yml (15 changes)

@@ -42,7 +42,7 @@ jobs:
           python-version: "3.10"
       - name: Install dependencies
         run: |
-          python -m pip install --upgrade pip
+          pip install --upgrade pip
           pip install .[quality]
       - name: Check quality
         run: make quality

@@ -62,7 +62,7 @@ jobs:
           python-version: "3.10"
       - name: Install dependencies
         run: |
-          python -m pip install --upgrade pip
+          pip install --upgrade pip
           pip install .[quality]
       - name: Check repo consistency
         run: |

@@ -108,21 +108,18 @@ jobs:

       - name: Install dependencies
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-          python -m uv pip install -e [quality,test]
-          pip uninstall transformers -y && pip uninstall huggingface_hub -y && python -m uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
-          pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
+          uv pip install -e ".[quality]"
+          uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
+          uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps

       - name: Environment
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
           python utils/print_env.py

       - name: Run fast PyTorch Pipeline CPU tests
         if: ${{ matrix.config.framework == 'pytorch_pipelines' }}
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-          python -m pytest -n 8 --max-worker-restart=0 --dist=loadfile \
+          pytest -n 8 --max-worker-restart=0 --dist=loadfile \
             -s -v -k "not Flax and not Onnx" \
             --make-reports=tests_${{ matrix.config.report }} \
             tests/modular_pipelines
.github/workflows/pr_test_fetcher.yml (19 changes)

@@ -33,8 +33,7 @@ jobs:
           fetch-depth: 0
       - name: Install dependencies
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-          python -m uv pip install -e [quality,test]
+          uv pip install -e ".[quality]"
       - name: Environment
         run: |
           python utils/print_env.py

@@ -90,19 +89,16 @@ jobs:

       - name: Install dependencies
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-          python -m pip install -e [quality,test]
-          python -m pip install accelerate
+          uv pip install -e ".[quality]"
+          uv pip install accelerate

       - name: Environment
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
           python utils/print_env.py

       - name: Run all selected tests on CPU
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-          python -m pytest -n 2 --dist=loadfile -v --make-reports=${{ matrix.modules }}_tests_cpu ${{ fromJson(needs.setup_pr_tests.outputs.test_map)[matrix.modules] }}
+          pytest -n 2 --dist=loadfile -v --make-reports=${{ matrix.modules }}_tests_cpu ${{ fromJson(needs.setup_pr_tests.outputs.test_map)[matrix.modules] }}

       - name: Failure short reports
         if: ${{ failure() }}

@@ -148,19 +144,16 @@ jobs:

       - name: Install dependencies
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-          python -m pip install -e [quality,test]
+          pip install -e [quality]

       - name: Environment
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
           python utils/print_env.py

       - name: Run Hub tests for models, schedulers, and pipelines on a staging env
         if: ${{ matrix.config.framework == 'hub_tests_pytorch' }}
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-          HUGGINGFACE_CO_STAGING=true python -m pytest \
+          HUGGINGFACE_CO_STAGING=true pytest \
             -m "is_staging_test" \
             --make-reports=tests_${{ matrix.config.report }} \
             tests
.github/workflows/pr_tests.yml (48 changes)

@@ -38,7 +38,7 @@ jobs:
           python-version: "3.8"
       - name: Install dependencies
         run: |
-          python -m pip install --upgrade pip
+          pip install --upgrade pip
           pip install .[quality]
       - name: Check quality
         run: make quality

@@ -58,7 +58,7 @@ jobs:
           python-version: "3.8"
       - name: Install dependencies
         run: |
-          python -m pip install --upgrade pip
+          pip install --upgrade pip
           pip install .[quality]
       - name: Check repo consistency
         run: |

@@ -114,21 +114,18 @@ jobs:

       - name: Install dependencies
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-          python -m uv pip install -e [quality,test]
-          pip uninstall transformers -y && pip uninstall huggingface_hub -y && python -m uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
-          pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
+          uv pip install -e ".[quality]"
+          uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
+          uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps

       - name: Environment
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
           python utils/print_env.py

       - name: Run fast PyTorch Pipeline CPU tests
         if: ${{ matrix.config.framework == 'pytorch_pipelines' }}
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-          python -m pytest -n 8 --max-worker-restart=0 --dist=loadfile \
+          pytest -n 8 --max-worker-restart=0 --dist=loadfile \
             -s -v -k "not Flax and not Onnx" \
             --make-reports=tests_${{ matrix.config.report }} \
             tests/pipelines

@@ -136,8 +133,7 @@ jobs:
       - name: Run fast PyTorch Model Scheduler CPU tests
         if: ${{ matrix.config.framework == 'pytorch_models' }}
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-          python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
+          pytest -n 4 --max-worker-restart=0 --dist=loadfile \
             -s -v -k "not Flax and not Onnx and not Dependency" \
             --make-reports=tests_${{ matrix.config.report }} \
             tests/models tests/schedulers tests/others

@@ -145,9 +141,8 @@ jobs:
       - name: Run example PyTorch CPU tests
         if: ${{ matrix.config.framework == 'pytorch_examples' }}
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-          python -m uv pip install peft timm
-          python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
+          uv pip install ".[training]"
+          pytest -n 4 --max-worker-restart=0 --dist=loadfile \
             --make-reports=tests_${{ matrix.config.report }} \
             examples

@@ -195,19 +190,16 @@ jobs:

       - name: Install dependencies
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-          python -m uv pip install -e [quality,test]
+          uv pip install -e ".[quality]"

       - name: Environment
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
           python utils/print_env.py

       - name: Run Hub tests for models, schedulers, and pipelines on a staging env
         if: ${{ matrix.config.framework == 'hub_tests_pytorch' }}
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-          HUGGINGFACE_CO_STAGING=true python -m pytest \
+          HUGGINGFACE_CO_STAGING=true pytest \
             -m "is_staging_test" \
             --make-reports=tests_${{ matrix.config.report }} \
             tests

@@ -249,27 +241,24 @@ jobs:

       - name: Install dependencies
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-          python -m uv pip install -e [quality,test]
+          uv pip install -e ".[quality]"
           # TODO (sayakpaul, DN6): revisit `--no-deps`
-          python -m pip install -U peft@git+https://github.com/huggingface/peft.git --no-deps
-          python -m uv pip install -U tokenizers
-          pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
-          pip uninstall transformers -y && pip uninstall huggingface_hub -y && python -m uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
+          uv pip install -U peft@git+https://github.com/huggingface/peft.git --no-deps
+          uv pip install -U tokenizers
+          uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
+          uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git

       - name: Environment
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
           python utils/print_env.py

       - name: Run fast PyTorch LoRA tests with PEFT
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-          python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
+          pytest -n 4 --max-worker-restart=0 --dist=loadfile \
             -s -v \
             --make-reports=tests_peft_main \
             tests/lora/
-          python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
+          pytest -n 4 --max-worker-restart=0 --dist=loadfile \
             -s -v \
             --make-reports=tests_models_lora_peft_main \
             tests/models/ -k "lora"

@@ -286,3 +275,4 @@ jobs:
         with:
           name: pr_main_test_reports
           path: reports
+
.github/workflows/pr_tests_gpu.yml (42 changes)

@@ -39,7 +39,7 @@ jobs:
           python-version: "3.8"
       - name: Install dependencies
         run: |
-          python -m pip install --upgrade pip
+          pip install --upgrade pip
           pip install .[quality]
       - name: Check quality
         run: make quality

@@ -59,7 +59,7 @@ jobs:
           python-version: "3.8"
       - name: Install dependencies
         run: |
-          python -m pip install --upgrade pip
+          pip install --upgrade pip
           pip install .[quality]
       - name: Check repo consistency
         run: |

@@ -88,8 +88,7 @@ jobs:
           fetch-depth: 2
       - name: Install dependencies
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-          python -m uv pip install -e [quality,test]
+          uv pip install -e ".[quality]"
       - name: Environment
         run: |
           python utils/print_env.py

@@ -130,10 +129,9 @@ jobs:
           nvidia-smi
       - name: Install dependencies
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-          python -m uv pip install -e [quality,test]
-          pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
-          pip uninstall transformers -y && pip uninstall huggingface_hub -y && python -m uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
+          uv pip install -e ".[quality]"
+          uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+          uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git

       - name: Environment
         run: |

@@ -152,13 +150,13 @@ jobs:
           CUBLAS_WORKSPACE_CONFIG: :16:8
         run: |
           if [ "${{ matrix.module }}" = "ip_adapters" ]; then
-            python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+            pytest -n 1 --max-worker-restart=0 --dist=loadfile \
               -s -v -k "not Flax and not Onnx" \
               --make-reports=tests_pipeline_${{ matrix.module }}_cuda \
               tests/pipelines/${{ matrix.module }}
           else
             pattern=$(cat ${{ steps.extract_tests.outputs.pattern_file }})
-            python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+            pytest -n 1 --max-worker-restart=0 --dist=loadfile \
               -s -v -k "not Flax and not Onnx and $pattern" \
               --make-reports=tests_pipeline_${{ matrix.module }}_cuda \
               tests/pipelines/${{ matrix.module }}

@@ -200,11 +198,10 @@ jobs:

       - name: Install dependencies
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-          python -m uv pip install -e [quality,test]
-          python -m uv pip install peft@git+https://github.com/huggingface/peft.git
-          pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
-          pip uninstall transformers -y && pip uninstall huggingface_hub -y && python -m uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
+          uv pip install -e ".[quality]"
+          uv pip install peft@git+https://github.com/huggingface/peft.git
+          uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+          uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git

       - name: Environment
         run: |

@@ -225,10 +222,10 @@ jobs:
         run: |
           pattern=$(cat ${{ steps.extract_tests.outputs.pattern_file }})
           if [ -z "$pattern" ]; then
-            python -m pytest -n 1 -sv --max-worker-restart=0 --dist=loadfile -k "not Flax and not Onnx" tests/${{ matrix.module }} \
+            pytest -n 1 -sv --max-worker-restart=0 --dist=loadfile -k "not Flax and not Onnx" tests/${{ matrix.module }} \
               --make-reports=tests_torch_cuda_${{ matrix.module }}
           else
-            python -m pytest -n 1 -sv --max-worker-restart=0 --dist=loadfile -k "not Flax and not Onnx and $pattern" tests/${{ matrix.module }} \
+            pytest -n 1 -sv --max-worker-restart=0 --dist=loadfile -k "not Flax and not Onnx and $pattern" tests/${{ matrix.module }} \
               --make-reports=tests_torch_cuda_${{ matrix.module }}
           fi

@@ -265,22 +262,19 @@ jobs:
           nvidia-smi
       - name: Install dependencies
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-          pip uninstall transformers -y && pip uninstall huggingface_hub -y && python -m uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
-          python -m uv pip install -e [quality,test,training]
+          uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
+          uv pip install -e ".[quality,training]"

       - name: Environment
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
           python utils/print_env.py

       - name: Run example tests on GPU
         env:
           HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-          python -m uv pip install timm
-          python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/
+          uv pip install ".[training]"
+          pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/

       - name: Failure short reports
         if: ${{ failure() }}
.github/workflows/pr_torch_dependency_test.yml (10 changes)

@@ -25,12 +25,8 @@ jobs:
           python-version: "3.8"
       - name: Install dependencies
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-          python -m pip install --upgrade pip uv
-          python -m uv pip install -e .
-          python -m uv pip install torch torchvision torchaudio
-          python -m uv pip install pytest
+          pip install -e .
+          pip install torch torchvision torchaudio pytest
       - name: Check for soft dependencies
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
           pytest tests/others/test_dependencies.py
.github/workflows/push_tests.yml (38 changes)

@@ -34,8 +34,7 @@ jobs:
           fetch-depth: 2
       - name: Install dependencies
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-          python -m uv pip install -e [quality,test]
+          uv pip install -e ".[quality]"
       - name: Environment
         run: |
           python utils/print_env.py

@@ -75,9 +74,8 @@ jobs:
           nvidia-smi
       - name: Install dependencies
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-          python -m uv pip install -e [quality,test]
-          pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+          uv pip install -e ".[quality]"
+          uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
       - name: Environment
         run: |
           python utils/print_env.py

@@ -87,7 +85,7 @@ jobs:
           # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
           CUBLAS_WORKSPACE_CONFIG: :16:8
         run: |
-          python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+          pytest -n 1 --max-worker-restart=0 --dist=loadfile \
             -s -v -k "not Flax and not Onnx" \
             --make-reports=tests_pipeline_${{ matrix.module }}_cuda \
             tests/pipelines/${{ matrix.module }}

@@ -126,10 +124,9 @@ jobs:

       - name: Install dependencies
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-          python -m uv pip install -e [quality,test]
-          python -m uv pip install peft@git+https://github.com/huggingface/peft.git
-          pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+          uv pip install -e ".[quality]"
+          uv pip install peft@git+https://github.com/huggingface/peft.git
+          uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git

       - name: Environment
         run: |

@@ -141,7 +138,7 @@ jobs:
           # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
           CUBLAS_WORKSPACE_CONFIG: :16:8
         run: |
-          python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+          pytest -n 1 --max-worker-restart=0 --dist=loadfile \
             -s -v -k "not Flax and not Onnx" \
             --make-reports=tests_torch_cuda_${{ matrix.module }} \
             tests/${{ matrix.module }}

@@ -180,8 +177,7 @@ jobs:
           nvidia-smi
       - name: Install dependencies
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-          python -m uv pip install -e [quality,test,training]
+          uv pip install -e ".[quality,training]"
       - name: Environment
         run: |
           python utils/print_env.py

@@ -190,7 +186,7 @@ jobs:
           HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
           RUN_COMPILE: yes
         run: |
-          python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
+          pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
       - name: Failure short reports
         if: ${{ failure() }}
         run: cat reports/tests_torch_compile_cuda_failures_short.txt

@@ -223,8 +219,7 @@ jobs:
           nvidia-smi
       - name: Install dependencies
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-          python -m uv pip install -e [quality,test,training]
+          uv pip install -e ".[quality,training]"
       - name: Environment
         run: |
           python utils/print_env.py

@@ -232,7 +227,7 @@ jobs:
         env:
           HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
         run: |
-          python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
+          pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
       - name: Failure short reports
         if: ${{ failure() }}
         run: cat reports/tests_torch_xformers_cuda_failures_short.txt

@@ -264,21 +259,18 @@ jobs:
           nvidia-smi
       - name: Install dependencies
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-          python -m uv pip install -e [quality,test,training]
+          uv pip install -e ".[quality,training]"

       - name: Environment
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
           python utils/print_env.py

       - name: Run example tests on GPU
         env:
           HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-          python -m uv pip install timm
-          python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/
+          uv pip install ".[training]"
+          pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/

       - name: Failure short reports
         if: ${{ failure() }}
.github/workflows/push_tests_fast.yml (12 changes)

@@ -60,19 +60,16 @@ jobs:

       - name: Install dependencies
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-          python -m uv pip install -e [quality,test]
+          uv pip install -e ".[quality]"

       - name: Environment
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
           python utils/print_env.py

       - name: Run fast PyTorch CPU tests
         if: ${{ matrix.config.framework == 'pytorch' }}
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-          python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
+          pytest -n 4 --max-worker-restart=0 --dist=loadfile \
            -s -v -k "not Flax and not Onnx" \
            --make-reports=tests_${{ matrix.config.report }} \
            tests/

@@ -80,9 +77,8 @@ jobs:
       - name: Run example PyTorch CPU tests
         if: ${{ matrix.config.framework == 'pytorch_examples' }}
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-          python -m uv pip install peft timm
-          python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
+          uv pip install ".[training]"
+          pytest -n 4 --max-worker-restart=0 --dist=loadfile \
            --make-reports=tests_${{ matrix.config.report }} \
            examples
.github/workflows/release_tests_fast.yml (47 changes)

@@ -32,8 +32,7 @@ jobs:
           fetch-depth: 2
       - name: Install dependencies
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-          python -m uv pip install -e [quality,test]
+          uv pip install -e ".[quality]"
       - name: Environment
         run: |
           python utils/print_env.py

@@ -73,9 +72,8 @@ jobs:
           nvidia-smi
       - name: Install dependencies
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-          python -m uv pip install -e [quality,test]
-          pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+          uv pip install -e ".[quality]"
+          uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
       - name: Environment
         run: |
           python utils/print_env.py

@@ -85,7 +83,7 @@ jobs:
           # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
           CUBLAS_WORKSPACE_CONFIG: :16:8
         run: |
-          python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+          pytest -n 1 --max-worker-restart=0 --dist=loadfile \
             -s -v -k "not Flax and not Onnx" \
             --make-reports=tests_pipeline_${{ matrix.module }}_cuda \
             tests/pipelines/${{ matrix.module }}

@@ -124,10 +122,9 @@ jobs:

       - name: Install dependencies
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-          python -m uv pip install -e [quality,test]
-          python -m uv pip install peft@git+https://github.com/huggingface/peft.git
-          pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+          uv pip install -e ".[quality]"
+          uv pip install peft@git+https://github.com/huggingface/peft.git
+          uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git

       - name: Environment
         run: |

@@ -139,7 +136,7 @@ jobs:
           # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
           CUBLAS_WORKSPACE_CONFIG: :16:8
         run: |
-          python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+          pytest -n 1 --max-worker-restart=0 --dist=loadfile \
             -s -v -k "not Flax and not Onnx" \
             --make-reports=tests_torch_${{ matrix.module }}_cuda \
             tests/${{ matrix.module }}

@@ -175,10 +172,9 @@ jobs:

       - name: Install dependencies
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-          python -m uv pip install -e [quality,test]
-          python -m uv pip install peft@git+https://github.com/huggingface/peft.git
-          pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+          uv pip install -e ".[quality]"
+          uv pip install peft@git+https://github.com/huggingface/peft.git
+          uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git

       - name: Environment
         run: |

@@ -190,7 +186,7 @@ jobs:
           # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
           CUBLAS_WORKSPACE_CONFIG: :16:8
         run: |
-          python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+          pytest -n 1 --max-worker-restart=0 --dist=loadfile \
             -s -v -k "not Flax and not Onnx" \
             --make-reports=tests_torch_minimum_cuda \
             tests/models/test_modeling_common.py \

@@ -235,8 +231,7 @@ jobs:
           nvidia-smi
       - name: Install dependencies
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-          python -m uv pip install -e [quality,test,training]
+          uv pip install -e ".[quality,training]"
       - name: Environment
         run: |
           python utils/print_env.py

@@ -245,7 +240,7 @@ jobs:
           HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
           RUN_COMPILE: yes
         run: |
-          python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
+          pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
       - name: Failure short reports
         if: ${{ failure() }}
         run: cat reports/tests_torch_compile_cuda_failures_short.txt

@@ -278,8 +273,7 @@ jobs:
           nvidia-smi
       - name: Install dependencies
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-          python -m uv pip install -e [quality,test,training]
+          uv pip install -e ".[quality,training]"
       - name: Environment
         run: |
           python utils/print_env.py

@@ -287,7 +281,7 @@ jobs:
         env:
           HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
         run: |
-          python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
+          pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
       - name: Failure short reports
         if: ${{ failure() }}
         run: cat reports/tests_torch_xformers_cuda_failures_short.txt

@@ -321,21 +315,18 @@ jobs:

       - name: Install dependencies
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-          python -m uv pip install -e [quality,test,training]
+          uv pip install -e ".[quality,training]"

       - name: Environment
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
           python utils/print_env.py

       - name: Run example tests on GPU
         env:
           HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-          python -m uv pip install timm
-          python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/
+          uv pip install ".[training]"
+          pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/

       - name: Failure short reports
         if: ${{ failure() }}
.github/workflows/run_tests_from_a_pr.yml (5 changes)

@@ -63,9 +63,8 @@ jobs:

       - name: Install pytest
         run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-          python -m uv pip install -e [quality,test]
-          python -m uv pip install peft
+          uv pip install -e ".[quality]"
+          uv pip install peft

       - name: Run tests
         env:
@@ -11,8 +11,11 @@ RUN apt-get -y update && apt-get install -y bash \
     git-lfs \
     curl \
     ca-certificates \
     libglib2.0-0 \
     libsndfile1-dev \
-    libgl1
+    libgl1 \
+    zip \
+    wget

 ENV UV_PYTHON=/usr/local/bin/python

@@ -11,6 +11,7 @@ RUN apt-get -y update && apt-get install -y bash \
     git-lfs \
     curl \
     ca-certificates \
     libglib2.0-0 \
     libsndfile1-dev \
     libgl1

@@ -16,6 +16,7 @@ RUN apt install -y bash \
     git-lfs \
     curl \
     ca-certificates \
     libglib2.0-0 \
     libsndfile1-dev \
     libgl1 \
     python3 \

@@ -19,6 +19,7 @@ RUN apt install -y bash \
     git-lfs \
     curl \
     ca-certificates \
     libglib2.0-0 \
     libsndfile1-dev \
     libgl1 \
     python3 \

@@ -16,6 +16,7 @@ RUN apt install -y bash \
     git-lfs \
     curl \
     ca-certificates \
     libglib2.0-0 \
     libsndfile1-dev \
     libgl1 \
     python3 \
@@ -75,7 +75,7 @@ The following is a summary of the recommended checkpoints, all of which produce
 | [prs-eth/marigold-depth-v1-1](https://huggingface.co/prs-eth/marigold-depth-v1-1) | Depth | Affine-invariant depth prediction assigns each pixel a value between 0 (near plane) and 1 (far plane), with both planes determined by the model during inference. |
 | [prs-eth/marigold-normals-v0-1](https://huggingface.co/prs-eth/marigold-normals-v0-1) | Normals | The surface normals predictions are unit-length 3D vectors in the screen space camera, with values in the range from -1 to 1. |
 | [prs-eth/marigold-iid-appearance-v1-1](https://huggingface.co/prs-eth/marigold-iid-appearance-v1-1) | Intrinsics | InteriorVerse decomposition is comprised of Albedo and two BRDF material properties: Roughness and Metallicity. |
-| [prs-eth/marigold-iid-lighting-v1-1](https://huggingface.co/prs-eth/marigold-iid-lighting-v1-1) | Intrinsics | HyperSim decomposition of an image \\(I\\) is comprised of Albedo \\(A\\), Diffuse shading \\(S\\), and Non-diffuse residual \\(R\\): \\(I = A*S+R\\). |
+| [prs-eth/marigold-iid-lighting-v1-1](https://huggingface.co/prs-eth/marigold-iid-lighting-v1-1) | Intrinsics | HyperSim decomposition of an image $I$ is comprised of Albedo $A$, Diffuse shading $S$, and Non-diffuse residual $R$: $I = A*S+R$. |

 > [!TIP]
 > Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff
setup.py (3 changes)

@@ -145,6 +145,7 @@ _deps = [
     "black",
     "phonemizer",
     "opencv-python",
+    "timm",
 ]

 # this is a lookup table with items like:

@@ -218,7 +219,7 @@ class DepsTableUpdateCommand(Command):
     extras = {}
     extras["quality"] = deps_list("urllib3", "isort", "ruff", "hf-doc-builder")
     extras["docs"] = deps_list("hf-doc-builder")
-    extras["training"] = deps_list("accelerate", "datasets", "protobuf", "tensorboard", "Jinja2", "peft")
+    extras["training"] = deps_list("accelerate", "datasets", "protobuf", "tensorboard", "Jinja2", "peft", "timm")
     extras["test"] = deps_list(
         "compel",
         "GitPython",
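Aside (not part of the diff): with `timm` folded into the `training` extra, the workflow lines above that switched to `uv pip install ".[training]"` no longer need to install it separately. A minimal sketch of how one might verify the extra from installed metadata, using only the standard library:

```python
from importlib.metadata import requires

# Requirements gated on the "training" extra; after this change the list
# should include timm alongside accelerate, datasets, peft, etc.
training_reqs = [r for r in requires("diffusers") if "extra ==" in r and "training" in r]
print(training_reqs)
```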
@@ -386,6 +386,8 @@ else:
     _import_structure["modular_pipelines"].extend(
         [
             "FluxAutoBlocks",
+            "FluxKontextAutoBlocks",
+            "FluxKontextModularPipeline",
             "FluxModularPipeline",
             "QwenImageAutoBlocks",
             "QwenImageEditAutoBlocks",

@@ -1050,6 +1052,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     else:
         from .modular_pipelines import (
             FluxAutoBlocks,
+            FluxKontextAutoBlocks,
+            FluxKontextModularPipeline,
             FluxModularPipeline,
             QwenImageAutoBlocks,
             QwenImageEditAutoBlocks,
@@ -52,4 +52,5 @@ deps = {
     "black": "black",
     "phonemizer": "phonemizer",
     "opencv-python": "opencv-python",
+    "timm": "timm",
 }
@@ -17,7 +17,8 @@ import functools
 import inspect
 import math
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+from functools import partial
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

 import torch

@@ -83,12 +84,20 @@ if DIFFUSERS_ENABLE_HUB_KERNELS:
         raise ImportError(
             "To use FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`."
         )
-    from ..utils.kernels_utils import _get_fa3_from_hub
+    from ..utils.kernels_utils import _DEFAULT_HUB_ID_FA3, _DEFAULT_HUB_ID_SAGE, _get_kernel_from_hub
+    from ..utils.sage_utils import _get_sage_attn_fn_for_device

-    flash_attn_interface_hub = _get_fa3_from_hub()
+    flash_attn_interface_hub = _get_kernel_from_hub(_DEFAULT_HUB_ID_FA3)
     flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func
+
+    sage_interface_hub = _get_kernel_from_hub(_DEFAULT_HUB_ID_SAGE)
+    sage_fn_with_kwargs = _get_sage_attn_fn_for_device()
+    sage_attn_func_hub = getattr(sage_interface_hub, sage_fn_with_kwargs["func"])
+    sage_attn_func_hub = partial(sage_attn_func_hub, **sage_fn_with_kwargs["kwargs"])
 else:
     flash_attn_3_func_hub = None
+    sage_attn_func_hub = None

 if _CAN_USE_SAGE_ATTN:
     from sageattention import (

@@ -162,10 +171,6 @@ logger = get_logger(__name__)  # pylint: disable=invalid-name
 # - CP with sage attention, flex, xformers, other missing backends
 # - Add support for normal and CP training with backends that don't support it yet

-_SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"]
-_SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"]
-_SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"]
-

 class AttentionBackendName(str, Enum):
     # EAGER = "eager"

@@ -190,6 +195,7 @@ class AttentionBackendName(str, Enum):

     # `sageattention`
     SAGE = "sage"
+    SAGE_HUB = "sage_hub"
     SAGE_VARLEN = "sage_varlen"
     _SAGE_QK_INT8_PV_FP8_CUDA = "_sage_qk_int8_pv_fp8_cuda"
     _SAGE_QK_INT8_PV_FP8_CUDA_SM90 = "_sage_qk_int8_pv_fp8_cuda_sm90"

@@ -1756,6 +1762,31 @@ def _sage_attention(
     return (out, lse) if return_lse else out


+@_AttentionBackendRegistry.register(
+    AttentionBackendName.SAGE_HUB,
+    constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+    supports_context_parallel=False,
+)
+def _sage_attention_hub(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    return_lse: bool = False,
+    _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+    lse = None
+    if _parallel_config is None:
+        out = sage_attn_func_hub(q=query, k=key, v=value)
+        if return_lse:
+            out, lse, *_ = out
+    else:
+        raise NotImplementedError("SAGE attention doesn't yet support parallelism.")
+
+    return (out, lse) if return_lse else out
+
+
 @_AttentionBackendRegistry.register(
     AttentionBackendName.SAGE_VARLEN,
     constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
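For context (not part of the diff): the new `SAGE_HUB` backend is selected by name like any other registered backend. A minimal usage sketch, assuming `DIFFUSERS_ENABLE_HUB_KERNELS` gates the Hub-kernel branch as in the guard above and that diffusers models expose a `set_attention_backend` helper that dispatches through `_AttentionBackendRegistry`; the checkpoint id is illustrative:

```python
import os

# Must be set before importing diffusers so the Hub-kernel branch above runs
# (the exact accepted truthy values are an assumption here).
os.environ["DIFFUSERS_ENABLE_HUB_KERNELS"] = "yes"

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# "sage_hub" maps to AttentionBackendName.SAGE_HUB, i.e. _sage_attention_hub,
# which calls the sageattention kernel fetched from the Hub.
pipe.transformer.set_attention_backend("sage_hub")

image = pipe("a photo of a cat", num_inference_steps=28).images[0]
```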
@@ -46,7 +46,12 @@ else:
     ]
     _import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"]
     _import_structure["wan"] = ["WanAutoBlocks", "WanModularPipeline"]
-    _import_structure["flux"] = ["FluxAutoBlocks", "FluxModularPipeline"]
+    _import_structure["flux"] = [
+        "FluxAutoBlocks",
+        "FluxModularPipeline",
+        "FluxKontextAutoBlocks",
+        "FluxKontextModularPipeline",
+    ]
     _import_structure["qwenimage"] = [
         "QwenImageAutoBlocks",
         "QwenImageModularPipeline",

@@ -65,7 +70,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     from ..utils.dummy_pt_objects import *  # noqa F403
 else:
     from .components_manager import ComponentsManager
-    from .flux import FluxAutoBlocks, FluxModularPipeline
+    from .flux import FluxAutoBlocks, FluxKontextAutoBlocks, FluxKontextModularPipeline, FluxModularPipeline
     from .modular_pipeline import (
         AutoPipelineBlocks,
         BlockState,
@@ -25,14 +25,18 @@ else:
     _import_structure["modular_blocks"] = [
         "ALL_BLOCKS",
         "AUTO_BLOCKS",
+        "AUTO_BLOCKS_KONTEXT",
+        "FLUX_KONTEXT_BLOCKS",
         "TEXT2IMAGE_BLOCKS",
         "FluxAutoBeforeDenoiseStep",
         "FluxAutoBlocks",
-        "FluxAutoBlocks",
         "FluxAutoDecodeStep",
         "FluxAutoDenoiseStep",
+        "FluxKontextAutoBlocks",
+        "FluxKontextAutoDenoiseStep",
+        "FluxKontextBeforeDenoiseStep",
     ]
-    _import_structure["modular_pipeline"] = ["FluxModularPipeline"]
+    _import_structure["modular_pipeline"] = ["FluxKontextModularPipeline", "FluxModularPipeline"]

 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     try:

@@ -45,13 +49,18 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     from .modular_blocks import (
         ALL_BLOCKS,
         AUTO_BLOCKS,
+        AUTO_BLOCKS_KONTEXT,
+        FLUX_KONTEXT_BLOCKS,
         TEXT2IMAGE_BLOCKS,
         FluxAutoBeforeDenoiseStep,
         FluxAutoBlocks,
         FluxAutoDecodeStep,
         FluxAutoDenoiseStep,
+        FluxKontextAutoBlocks,
+        FluxKontextAutoDenoiseStep,
+        FluxKontextBeforeDenoiseStep,
     )
-    from .modular_pipeline import FluxModularPipeline
+    from .modular_pipeline import FluxKontextModularPipeline, FluxModularPipeline
 else:
     import sys
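A rough usage sketch for the newly exported symbols (not part of the diff), assuming the bootstrap pattern of the other modular pipelines — `init_pipeline` on a blocks instance followed by `load_default_components`; the repository id is illustrative:

```python
import torch
from diffusers.modular_pipelines import FluxKontextAutoBlocks

blocks = FluxKontextAutoBlocks()

# init_pipeline/load_default_components mirror how the other AutoBlocks
# classes are used elsewhere; both names are assumptions here.
pipe = blocks.init_pipeline("black-forest-labs/FLUX.1-Kontext-dev")
pipe.load_default_components(torch_dtype=torch.bfloat16)
pipe.to("cuda")
```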
@@ -118,15 +118,6 @@ def retrieve_latents(
     raise AttributeError("Could not access latents of provided encoder_output")


-# TODO: align this with Qwen patchifier
-def _pack_latents(latents, batch_size, num_channels_latents, height, width):
-    latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
-    latents = latents.permute(0, 2, 4, 1, 3, 5)
-    latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
-
-    return latents
-
-
 def _get_initial_timesteps_and_optionals(
     transformer,
     scheduler,

@@ -398,16 +389,15 @@ class FluxPrepareLatentsStep(ModularPipelineBlocks):
                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
             )

-        # TODO: move packing latents code to a patchifier
+        # TODO: move packing latents code to a patchifier similar to Qwen
         latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-        latents = _pack_latents(latents, batch_size, num_channels_latents, height, width)
+        latents = FluxPipeline._pack_latents(latents, batch_size, num_channels_latents, height, width)

         return latents

     @torch.no_grad()
     def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)

         block_state.height = block_state.height or components.default_height
         block_state.width = block_state.width or components.default_width
         block_state.device = components._execution_device
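For orientation (not part of the diff): the packing that now goes through `FluxPipeline._pack_latents` is the standard Flux 2x2 patchification — each 2x2 spatial patch of the latent grid becomes one token carrying 4x the channel count. A self-contained sketch of the same reshape sequence:

```python
import torch

batch_size, num_channels_latents, height, width = 1, 16, 64, 64
latents = torch.randn(batch_size, num_channels_latents, height, width)

# Identical sequence of ops to the removed _pack_latents helper.
packed = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
packed = packed.permute(0, 2, 4, 1, 3, 5)
packed = packed.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)

print(packed.shape)  # torch.Size([1, 1024, 64]): one token per 2x2 patch, 16 * 4 channels
```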
@@ -557,3 +547,73 @@ class FluxRoPEInputsStep(ModularPipelineBlocks):
        self.set_block_state(state, block_state)

        return components, state


class FluxKontextRoPEInputsStep(ModularPipelineBlocks):
    model_name = "flux-kontext"

    @property
    def description(self) -> str:
        return "Step that prepares the RoPE inputs for the denoising process of Flux Kontext. Should be placed after the text encoder and latent preparation steps."

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(name="image_height"),
            InputParam(name="image_width"),
            InputParam(name="height"),
            InputParam(name="width"),
            InputParam(name="prompt_embeds"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                name="txt_ids",
                kwargs_type="denoiser_input_fields",
                type_hint=List[int],
                description="The sequence lengths of the prompt embeds, used for RoPE calculation.",
            ),
            OutputParam(
                name="img_ids",
                kwargs_type="denoiser_input_fields",
                type_hint=List[int],
                description="The sequence lengths of the image latents, used for RoPE calculation.",
            ),
        ]

    def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        prompt_embeds = block_state.prompt_embeds
        device, dtype = prompt_embeds.device, prompt_embeds.dtype
        block_state.txt_ids = torch.zeros(prompt_embeds.shape[1], 3).to(
            device=prompt_embeds.device, dtype=prompt_embeds.dtype
        )

        img_ids = None
        if (
            getattr(block_state, "image_height", None) is not None
            and getattr(block_state, "image_width", None) is not None
        ):
            image_latent_height = 2 * (int(block_state.image_height) // (components.vae_scale_factor * 2))
            image_latent_width = 2 * (int(block_state.image_width) // (components.vae_scale_factor * 2))
            img_ids = FluxPipeline._prepare_latent_image_ids(
                None, image_latent_height // 2, image_latent_width // 2, device, dtype
            )
            # image ids are the same as latent ids with the first dimension set to 1 instead of 0
            img_ids[..., 0] = 1

        height = 2 * (int(block_state.height) // (components.vae_scale_factor * 2))
        width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2))
        latent_ids = FluxPipeline._prepare_latent_image_ids(None, height // 2, width // 2, device, dtype)

        if img_ids is not None:
            latent_ids = torch.cat([latent_ids, img_ids], dim=0)

        block_state.img_ids = latent_ids

        self.set_block_state(state, block_state)

        return components, state
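Editor's note (illustrative sketch, not part of the diff): the RoPE id layout this step produces. `txt_ids` is a zero tensor with one (t, y, x) triplet per text token; reference-image ids reuse the latent grid ids but set the first coordinate to 1 so the transformer can tell image tokens from noise-latent tokens:

import torch

seq_len = 512
txt_ids = torch.zeros(seq_len, 3)            # (num_text_tokens, 3)
latent_ids = torch.zeros(32 * 32, 3)         # ids for a 32x32 packed latent grid
img_ids = latent_ids.clone()
img_ids[..., 0] = 1                          # tag reference-image tokens
combined = torch.cat([latent_ids, img_ids], dim=0)  # what lands in block_state.img_ids
print(txt_ids.shape, combined.shape)         # torch.Size([512, 3]) torch.Size([2048, 3])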
@@ -109,6 +109,96 @@ class FluxLoopDenoiser(ModularPipelineBlocks):
        return components, block_state


class FluxKontextLoopDenoiser(ModularPipelineBlocks):
    model_name = "flux-kontext"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [ComponentSpec("transformer", FluxTransformer2DModel)]

    @property
    def description(self) -> str:
        return (
            "Step within the denoising loop that denoises the latents for Flux Kontext. "
            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
            "object (e.g. `FluxDenoiseLoopWrapper`)"
        )

    @property
    def inputs(self) -> List[Tuple[str, Any]]:
        return [
            InputParam("joint_attention_kwargs"),
            InputParam(
                "latents",
                required=True,
                type_hint=torch.Tensor,
                description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
            ),
            InputParam(
                "image_latents",
                type_hint=torch.Tensor,
                description="Image latents to use for the denoising process. Can be generated in prepare_latent step.",
            ),
            InputParam(
                "guidance",
                required=True,
                type_hint=torch.Tensor,
                description="Guidance scale as a tensor",
            ),
            InputParam(
                "prompt_embeds",
                required=True,
                type_hint=torch.Tensor,
                description="Prompt embeddings",
            ),
            InputParam(
                "pooled_prompt_embeds",
                required=True,
                type_hint=torch.Tensor,
                description="Pooled prompt embeddings",
            ),
            InputParam(
                "txt_ids",
                required=True,
                type_hint=torch.Tensor,
                description="IDs computed from text sequence needed for RoPE",
            ),
            InputParam(
                "img_ids",
                required=True,
                type_hint=torch.Tensor,
                description="IDs computed from latent sequence needed for RoPE",
            ),
        ]

    @torch.no_grad()
    def __call__(
        self, components: FluxModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
    ) -> PipelineState:
        latents = block_state.latents
        latent_model_input = latents
        image_latents = block_state.image_latents
        if image_latents is not None:
            latent_model_input = torch.cat([latent_model_input, image_latents], dim=1)

        timestep = t.expand(latents.shape[0]).to(latents.dtype)
        noise_pred = components.transformer(
            hidden_states=latent_model_input,
            timestep=timestep / 1000,
            guidance=block_state.guidance,
            encoder_hidden_states=block_state.prompt_embeds,
            pooled_projections=block_state.pooled_prompt_embeds,
            joint_attention_kwargs=block_state.joint_attention_kwargs,
            txt_ids=block_state.txt_ids,
            img_ids=block_state.img_ids,
            return_dict=False,
        )[0]
        noise_pred = noise_pred[:, : latents.size(1)]
        block_state.noise_pred = noise_pred

        return components, block_state


class FluxLoopAfterDenoiser(ModularPipelineBlocks):
    model_name = "flux"
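Editor's note (illustrative sketch, not part of the diff): how Kontext conditions on a reference image inside the loop. Reference-image latents are appended along the sequence dimension, the transformer runs over the combined sequence, and the prediction is truncated back to the noise-latent positions:

import torch

latents = torch.randn(1, 1024, 64)        # packed noise latents
image_latents = torch.randn(1, 1024, 64)  # packed reference-image latents
model_input = torch.cat([latents, image_latents], dim=1)  # (1, 2048, 64)
noise_pred = torch.randn_like(model_input)  # stand-in for the transformer output
noise_pred = noise_pred[:, : latents.size(1)]  # keep only the noise-latent positions
print(noise_pred.shape)  # torch.Size([1, 1024, 64])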
@@ -221,3 +311,20 @@ class FluxDenoiseStep(FluxDenoiseLoopWrapper):
            " - `FluxLoopAfterDenoiser`\n"
            "This block supports both text2image and img2img tasks."
        )


class FluxKontextDenoiseStep(FluxDenoiseLoopWrapper):
    model_name = "flux-kontext"
    block_classes = [FluxKontextLoopDenoiser, FluxLoopAfterDenoiser]
    block_names = ["denoiser", "after_denoiser"]

    @property
    def description(self) -> str:
        return (
            "Denoise step that iteratively denoises the latents. \n"
            "Its loop logic is defined in the `FluxDenoiseLoopWrapper.__call__` method. \n"
            "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
            " - `FluxKontextLoopDenoiser`\n"
            " - `FluxLoopAfterDenoiser`\n"
            "This block supports both text2image and img2img tasks."
        )
@@ -20,7 +20,7 @@ import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor
from ...image_processor import VaeImageProcessor, is_valid_image, is_valid_image_imagelist
from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL
from ...utils import USE_PEFT_BACKEND, is_ftfy_available, logging, scale_lora_layers, unscale_lora_layers
@@ -83,11 +83,11 @@ def encode_vae_image(vae: AutoencoderKL, image: torch.Tensor, generator: torch.G


class FluxProcessImagesInputStep(ModularPipelineBlocks):
    model_name = "Flux"
    model_name = "flux"

    @property
    def description(self) -> str:
        return "Image Preprocess step. Resizing is needed in Flux Kontext (will be implemented later.)"
        return "Image Preprocess step."

    @property
    def expected_components(self) -> List[ComponentSpec]:
@@ -106,9 +106,7 @@ class FluxProcessImagesInputStep(ModularPipelineBlocks):

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(name="processed_image"),
        ]
        return [OutputParam(name="processed_image")]

    @staticmethod
    def check_inputs(height, width, vae_scale_factor):
@@ -142,13 +140,80 @@ class FluxProcessImagesInputStep(ModularPipelineBlocks):
        return components, state


class FluxKontextProcessImagesInputStep(ModularPipelineBlocks):
    model_name = "flux-kontext"

    def __init__(self, _auto_resize=True):
        self._auto_resize = _auto_resize
        super().__init__()

    @property
    def description(self) -> str:
        return (
            "Image preprocess step for Flux Kontext. The preprocessed image goes to the VAE.\n"
            "Kontext also works as a T2I model when no input image is provided."
        )

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec(
                "image_processor",
                VaeImageProcessor,
                config=FrozenDict({"vae_scale_factor": 16}),
                default_creation_method="from_config",
            ),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [InputParam("image")]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [OutputParam(name="processed_image")]

    @torch.no_grad()
    def __call__(self, components: FluxModularPipeline, state: PipelineState):
        from ...pipelines.flux.pipeline_flux_kontext import PREFERRED_KONTEXT_RESOLUTIONS

        block_state = self.get_block_state(state)
        images = block_state.image

        if images is None:
            block_state.processed_image = None
        else:
            multiple_of = components.image_processor.config.vae_scale_factor

            if not is_valid_image_imagelist(images):
                raise ValueError(f"Images must be image or list of images but are {type(images)}")

            if is_valid_image(images):
                images = [images]

            img = images[0]
            image_height, image_width = components.image_processor.get_default_height_width(img)
            aspect_ratio = image_width / image_height
            if self._auto_resize:
                # Kontext is trained on specific resolutions, using one of them is recommended
                _, image_width, image_height = min(
                    (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
                )
            image_width = image_width // multiple_of * multiple_of
            image_height = image_height // multiple_of * multiple_of
            images = components.image_processor.resize(images, image_height, image_width)
            block_state.processed_image = components.image_processor.preprocess(images, image_height, image_width)

        self.set_block_state(state, block_state)
        return components, state

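Editor's note (illustrative sketch, not part of the diff): how the preferred Kontext resolution is chosen. The `PREFERRED_KONTEXT_RESOLUTIONS` list below is a stand-in; the real constant lives in `pipelines/flux/pipeline_flux_kontext.py`:

PREFERRED_KONTEXT_RESOLUTIONS = [(672, 1568), (1024, 1024), (1568, 672)]  # stand-in values

image_width, image_height = 1920, 1080
aspect_ratio = image_width / image_height  # ~1.78
# Pick the preferred (w, h) whose aspect ratio is closest to the input's.
_, image_width, image_height = min(
    (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
)
print(image_width, image_height)  # 1568 672: closest ratio among the stand-ins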
class FluxVaeEncoderDynamicStep(ModularPipelineBlocks):
    model_name = "flux"

    def __init__(
        self,
        input_name: str = "processed_image",
        output_name: str = "image_latents",
        self, input_name: str = "processed_image", output_name: str = "image_latents", sample_mode: str = "sample"
    ):
        """Initialize a VAE encoder step for converting images to latent representations.
@@ -160,6 +225,7 @@ class FluxVaeEncoderDynamicStep(ModularPipelineBlocks):
                Examples: "processed_image" or "processed_control_image"
            output_name (str, optional): Name of the output latent tensor. Defaults to "image_latents".
                Examples: "image_latents" or "control_image_latents"
            sample_mode (str, optional): Sampling mode to be used. Defaults to "sample".

        Examples:
            # Basic usage with default settings (includes image processor):
            # FluxImageVaeEncoderDynamicStep()
@@ -170,6 +236,7 @@ class FluxVaeEncoderDynamicStep(ModularPipelineBlocks):
        """
        self._image_input_name = input_name
        self._image_latents_output_name = output_name
        self.sample_mode = sample_mode
        super().__init__()

    @property
@@ -183,7 +250,7 @@ class FluxVaeEncoderDynamicStep(ModularPipelineBlocks):

    @property
    def inputs(self) -> List[InputParam]:
        inputs = [InputParam(self._image_input_name, required=True), InputParam("generator")]
        inputs = [InputParam(self._image_input_name), InputParam("generator")]
        return inputs

    @property
@@ -199,16 +266,20 @@ class FluxVaeEncoderDynamicStep(ModularPipelineBlocks):
    @torch.no_grad()
    def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        device = components._execution_device
        dtype = components.vae.dtype

        image = getattr(block_state, self._image_input_name)
        image = image.to(device=device, dtype=dtype)

        # Encode image into latents
        image_latents = encode_vae_image(image=image, vae=components.vae, generator=block_state.generator)
        setattr(block_state, self._image_latents_output_name, image_latents)
        if image is None:
            setattr(block_state, self._image_latents_output_name, None)
        else:
            device = components._execution_device
            dtype = components.vae.dtype
            image = image.to(device=device, dtype=dtype)

            # Encode image into latents
            image_latents = encode_vae_image(
                image=image, vae=components.vae, generator=block_state.generator, sample_mode=self.sample_mode
            )
            setattr(block_state, self._image_latents_output_name, image_latents)

        self.set_block_state(state, block_state)
@@ -17,6 +17,7 @@ from typing import List
import torch

from ...pipelines import FluxPipeline
from ...utils import logging
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import InputParam, OutputParam
@@ -25,6 +26,9 @@ from ..qwenimage.inputs import calculate_dimension_from_latents, repeat_tensor_t
from .modular_pipeline import FluxModularPipeline


logger = logging.get_logger(__name__)


class FluxTextInputStep(ModularPipelineBlocks):
    model_name = "flux"
@@ -234,3 +238,122 @@ class FluxInputsDynamicStep(ModularPipelineBlocks):

        self.set_block_state(state, block_state)
        return components, state


class FluxKontextInputsDynamicStep(FluxInputsDynamicStep):
    model_name = "flux-kontext"

    def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        # Process image latent inputs (height/width calculation, patchify, and batch expansion)
        for image_latent_input_name in self._image_latent_inputs:
            image_latent_tensor = getattr(block_state, image_latent_input_name)
            if image_latent_tensor is None:
                continue

            # 1. Calculate height/width from latents
            # Unlike `FluxInputsDynamicStep`, we don't overwrite `block_state.height` and `block_state.width`
            height, width = calculate_dimension_from_latents(image_latent_tensor, components.vae_scale_factor)
            if not hasattr(block_state, "image_height"):
                block_state.image_height = height
            if not hasattr(block_state, "image_width"):
                block_state.image_width = width

            # 2. Patchify the image latent tensor
            # TODO: Implement patchifier for Flux.
            latent_height, latent_width = image_latent_tensor.shape[2:]
            image_latent_tensor = FluxPipeline._pack_latents(
                image_latent_tensor, block_state.batch_size, image_latent_tensor.shape[1], latent_height, latent_width
            )

            # 3. Expand batch size
            image_latent_tensor = repeat_tensor_to_batch_size(
                input_name=image_latent_input_name,
                input_tensor=image_latent_tensor,
                num_images_per_prompt=block_state.num_images_per_prompt,
                batch_size=block_state.batch_size,
            )

            setattr(block_state, image_latent_input_name, image_latent_tensor)

        # Process additional batch inputs (only batch expansion)
        for input_name in self._additional_batch_inputs:
            input_tensor = getattr(block_state, input_name)
            if input_tensor is None:
                continue

            # Only expand batch size
            input_tensor = repeat_tensor_to_batch_size(
                input_name=input_name,
                input_tensor=input_tensor,
                num_images_per_prompt=block_state.num_images_per_prompt,
                batch_size=block_state.batch_size,
            )

            setattr(block_state, input_name, input_tensor)

        self.set_block_state(state, block_state)
        return components, state
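Editor's note (illustrative sketch, not part of the diff): the batch-expansion helper is imported from `..qwenimage.inputs`; the body below is an assumed approximation of its behavior, not the actual implementation:

import torch

def repeat_tensor_to_batch_size(input_name, input_tensor, num_images_per_prompt, batch_size):
    # A tensor with batch dim 1 is first repeated to batch_size, then each sample is
    # repeated num_images_per_prompt times along the batch dimension.
    if input_tensor.shape[0] == 1:
        input_tensor = input_tensor.repeat(batch_size, *([1] * (input_tensor.ndim - 1)))
    return input_tensor.repeat_interleave(num_images_per_prompt, dim=0)

image_latents = torch.randn(1, 1024, 64)
expanded = repeat_tensor_to_batch_size("image_latents", image_latents, num_images_per_prompt=2, batch_size=2)
print(expanded.shape)  # torch.Size([4, 1024, 64])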

class FluxKontextSetResolutionStep(ModularPipelineBlocks):
    model_name = "flux-kontext"

    @property
    def description(self):
        return (
            "Determines the height and width to be used during the subsequent computations.\n"
            "It should always be placed _before_ the latent preparation step."
        )

    @property
    def inputs(self) -> List[InputParam]:
        inputs = [
            InputParam(name="height"),
            InputParam(name="width"),
            InputParam(name="max_area", type_hint=int, default=1024**2),
        ]
        return inputs

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(name="height", type_hint=int, description="The height of the initial noisy latents"),
            OutputParam(name="width", type_hint=int, description="The width of the initial noisy latents"),
        ]

    @staticmethod
    def check_inputs(height, width, vae_scale_factor):
        if height is not None and height % (vae_scale_factor * 2) != 0:
            raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}")

        if width is not None and width % (vae_scale_factor * 2) != 0:
            raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}")

    def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        height = block_state.height or components.default_height
        width = block_state.width or components.default_width
        self.check_inputs(height, width, components.vae_scale_factor)

        original_height, original_width = height, width
        max_area = block_state.max_area
        aspect_ratio = width / height
        width = round((max_area * aspect_ratio) ** 0.5)
        height = round((max_area / aspect_ratio) ** 0.5)

        multiple_of = components.vae_scale_factor * 2
        width = width // multiple_of * multiple_of
        height = height // multiple_of * multiple_of

        if height != original_height or width != original_width:
            logger.warning(
                f"Generation `height` and `width` have been adjusted to {height} and {width} to fit the model requirements."
            )

        block_state.height = height
        block_state.width = width

        self.set_block_state(state, block_state)
        return components, state
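Editor's note (illustrative, not part of the diff): the max-area adjustment above, worked through for a 1536x1536 request with the default `max_area` of 1024**2 and a `vae_scale_factor` of 16 (so `multiple_of` = 32):

max_area = 1024**2
height = width = 1536
aspect_ratio = width / height                      # 1.0
width = round((max_area * aspect_ratio) ** 0.5)    # 1024
height = round((max_area / aspect_ratio) ** 0.5)   # 1024
multiple_of = 32
width = width // multiple_of * multiple_of         # 1024 (already aligned)
height = height // multiple_of * multiple_of       # 1024
print(height, width)  # 1024 1024: the request is scaled down to fit max_area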
@@ -18,14 +18,25 @@ from ..modular_pipeline_utils import InsertableDict
from .before_denoise import (
    FluxImg2ImgPrepareLatentsStep,
    FluxImg2ImgSetTimestepsStep,
    FluxKontextRoPEInputsStep,
    FluxPrepareLatentsStep,
    FluxRoPEInputsStep,
    FluxSetTimestepsStep,
)
from .decoders import FluxDecodeStep
from .denoise import FluxDenoiseStep
from .encoders import FluxProcessImagesInputStep, FluxTextEncoderStep, FluxVaeEncoderDynamicStep
from .inputs import FluxInputsDynamicStep, FluxTextInputStep
from .denoise import FluxDenoiseStep, FluxKontextDenoiseStep
from .encoders import (
    FluxKontextProcessImagesInputStep,
    FluxProcessImagesInputStep,
    FluxTextEncoderStep,
    FluxVaeEncoderDynamicStep,
)
from .inputs import (
    FluxInputsDynamicStep,
    FluxKontextInputsDynamicStep,
    FluxKontextSetResolutionStep,
    FluxTextInputStep,
)


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -33,10 +44,7 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


# vae encoder (run before before_denoise)
FluxImg2ImgVaeEncoderBlocks = InsertableDict(
    [
        ("preprocess", FluxProcessImagesInputStep()),
        ("encode", FluxVaeEncoderDynamicStep()),
    ]
    [("preprocess", FluxProcessImagesInputStep()), ("encode", FluxVaeEncoderDynamicStep())]
)
@@ -66,6 +74,39 @@ class FluxAutoVaeEncoderStep(AutoPipelineBlocks):
        )


# Flux Kontext vae encoder (run before before_denoise)

FluxKontextVaeEncoderBlocks = InsertableDict(
    [("preprocess", FluxKontextProcessImagesInputStep()), ("encode", FluxVaeEncoderDynamicStep(sample_mode="argmax"))]
)


class FluxKontextVaeEncoderStep(SequentialPipelineBlocks):
    model_name = "flux-kontext"

    block_classes = FluxKontextVaeEncoderBlocks.values()
    block_names = FluxKontextVaeEncoderBlocks.keys()

    @property
    def description(self) -> str:
        return "Vae encoder step that preprocesses and encodes the image inputs into their latent representations."


class FluxKontextAutoVaeEncoderStep(AutoPipelineBlocks):
    block_classes = [FluxKontextVaeEncoderStep]
    block_names = ["img2img"]
    block_trigger_inputs = ["image"]

    @property
    def description(self):
        return (
            "Vae encoder step that encodes the image inputs into their latent representations.\n"
            + "This is an auto pipeline block that works for img2img tasks.\n"
            + " - `FluxKontextVaeEncoderStep` (img2img) is used when only `image` is provided.\n"
            + " - if `image` is not provided, the step will be skipped."
        )


# before_denoise: text2img
FluxBeforeDenoiseBlocks = InsertableDict(
    [
@@ -107,6 +148,7 @@ class FluxImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks):

# before_denoise: all task (text2img, img2img)
class FluxAutoBeforeDenoiseStep(AutoPipelineBlocks):
    model_name = "flux-kontext"
    block_classes = [FluxImg2ImgBeforeDenoiseStep, FluxBeforeDenoiseStep]
    block_names = ["img2img", "text2image"]
    block_trigger_inputs = ["image_latents", None]
@@ -121,6 +163,44 @@ class FluxAutoBeforeDenoiseStep(AutoPipelineBlocks):
        )


# before_denoise: FluxKontext

FluxKontextBeforeDenoiseBlocks = InsertableDict(
    [
        ("prepare_latents", FluxPrepareLatentsStep()),
        ("set_timesteps", FluxSetTimestepsStep()),
        ("prepare_rope_inputs", FluxKontextRoPEInputsStep()),
    ]
)


class FluxKontextBeforeDenoiseStep(SequentialPipelineBlocks):
    block_classes = FluxKontextBeforeDenoiseBlocks.values()
    block_names = FluxKontextBeforeDenoiseBlocks.keys()

    @property
    def description(self):
        return (
            "Before denoise step that prepares the inputs for the denoise step\n"
            "for the img2img/text2img tasks for Flux Kontext."
        )


class FluxKontextAutoBeforeDenoiseStep(AutoPipelineBlocks):
    block_classes = [FluxKontextBeforeDenoiseStep, FluxBeforeDenoiseStep]
    block_names = ["img2img", "text2image"]
    block_trigger_inputs = ["image_latents", None]

    @property
    def description(self):
        return (
            "Before denoise step that prepares the inputs for the denoise step.\n"
            + "This is an auto pipeline block that works for text2image.\n"
            + " - `FluxBeforeDenoiseStep` (text2image) is used.\n"
            + " - `FluxKontextBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n"
        )


# denoise: text2image
class FluxAutoDenoiseStep(AutoPipelineBlocks):
    block_classes = [FluxDenoiseStep]
@@ -136,6 +216,23 @@ class FluxAutoDenoiseStep(AutoPipelineBlocks):
        )


# denoise: Flux Kontext


class FluxKontextAutoDenoiseStep(AutoPipelineBlocks):
    block_classes = [FluxKontextDenoiseStep]
    block_names = ["denoise"]
    block_trigger_inputs = [None]

    @property
    def description(self) -> str:
        return (
            "Denoise step that iteratively denoises the latents for Flux Kontext. "
            "This is an auto pipeline block that works for text2image and img2img tasks."
            " - `FluxKontextDenoiseStep` (denoise) for text2image and img2img tasks."
        )


# decode: all task (text2img, img2img)
class FluxAutoDecodeStep(AutoPipelineBlocks):
    block_classes = [FluxDecodeStep]
@@ -165,7 +262,7 @@ class FluxImg2ImgInputStep(SequentialPipelineBlocks):
            " - update height/width based on `image_latents`, patchify `image_latents`."


class FluxImageAutoInputStep(AutoPipelineBlocks):
class FluxAutoInputStep(AutoPipelineBlocks):
    block_classes = [FluxImg2ImgInputStep, FluxTextInputStep]
    block_names = ["img2img", "text2image"]
    block_trigger_inputs = ["image_latents", None]
@@ -180,16 +277,59 @@ class FluxImageAutoInputStep(AutoPipelineBlocks):
        )


# inputs: Flux Kontext

FluxKontextBlocks = InsertableDict(
    [
        ("set_resolution", FluxKontextSetResolutionStep()),
        ("text_inputs", FluxTextInputStep()),
        ("additional_inputs", FluxKontextInputsDynamicStep()),
    ]
)


class FluxKontextInputStep(SequentialPipelineBlocks):
    model_name = "flux-kontext"
    block_classes = FluxKontextBlocks.values()
    block_names = FluxKontextBlocks.keys()

    @property
    def description(self):
        return (
            "Input step that prepares the inputs for both the text2img and img2img denoising steps. It:\n"
            " - makes sure the text embeddings and the additional inputs (`image_latents`) have a consistent batch size.\n"
            " - updates height/width based on `image_latents`, patchifies `image_latents`."
        )


class FluxKontextAutoInputStep(AutoPipelineBlocks):
    block_classes = [FluxKontextInputStep, FluxTextInputStep]
    block_names = ["img2img", "text2img"]
    block_trigger_inputs = ["image_latents", None]

    @property
    def description(self):
        return (
            "Input step that standardizes the inputs for the denoising step, e.g. makes sure inputs have a consistent batch size and are patchified. \n"
            " This is an auto pipeline block that works for text2image/img2img tasks.\n"
            + " - `FluxKontextInputStep` (img2img) is used when `image_latents` is provided.\n"
            + " - `FluxKontextInputStep` is also capable of handling the text2image task when `image_latents` isn't present."
        )


class FluxCoreDenoiseStep(SequentialPipelineBlocks):
    model_name = "flux"
    block_classes = [FluxImageAutoInputStep, FluxAutoBeforeDenoiseStep, FluxAutoDenoiseStep]
    block_classes = [FluxAutoInputStep, FluxAutoBeforeDenoiseStep, FluxAutoDenoiseStep]
    block_names = ["input", "before_denoise", "denoise"]

    @property
    def description(self):
        return (
            "Core step that performs the denoising process. \n"
            + " - `FluxImageAutoInputStep` (input) standardizes the inputs for the denoising step.\n"
            + " - `FluxAutoInputStep` (input) standardizes the inputs for the denoising step.\n"
            + " - `FluxAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n"
            + " - `FluxAutoDenoiseStep` (denoise) iteratively denoises the latents.\n"
            + "This step supports text-to-image and image-to-image tasks for Flux:\n"
@@ -198,6 +338,24 @@ class FluxCoreDenoiseStep(SequentialPipelineBlocks):
        )


class FluxKontextCoreDenoiseStep(SequentialPipelineBlocks):
    model_name = "flux-kontext"
    block_classes = [FluxKontextAutoInputStep, FluxKontextAutoBeforeDenoiseStep, FluxKontextAutoDenoiseStep]
    block_names = ["input", "before_denoise", "denoise"]

    @property
    def description(self):
        return (
            "Core step that performs the denoising process. \n"
            + " - `FluxKontextAutoInputStep` (input) standardizes the inputs for the denoising step.\n"
            + " - `FluxKontextAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n"
            + " - `FluxKontextAutoDenoiseStep` (denoise) iteratively denoises the latents.\n"
            + "This step supports text-to-image and image-to-image tasks for Flux:\n"
            + " - for image-to-image generation, you need to provide `image_latents`\n"
            + " - for text-to-image generation, all you need to provide is prompt embeddings."
        )


# Auto blocks (text2image and img2img)
AUTO_BLOCKS = InsertableDict(
    [
@@ -208,6 +366,15 @@ AUTO_BLOCKS = InsertableDict(
    ]
)

AUTO_BLOCKS_KONTEXT = InsertableDict(
    [
        ("text_encoder", FluxTextEncoderStep()),
        ("image_encoder", FluxKontextAutoVaeEncoderStep()),
        ("denoise", FluxKontextCoreDenoiseStep()),
        ("decode", FluxDecodeStep()),
    ]
)


class FluxAutoBlocks(SequentialPipelineBlocks):
    model_name = "flux"
@@ -224,6 +391,13 @@ class FluxAutoBlocks(SequentialPipelineBlocks):
        )


class FluxKontextAutoBlocks(FluxAutoBlocks):
    model_name = "flux-kontext"

    block_classes = AUTO_BLOCKS_KONTEXT.values()
    block_names = AUTO_BLOCKS_KONTEXT.keys()


TEXT2IMAGE_BLOCKS = InsertableDict(
    [
        ("text_encoder", FluxTextEncoderStep()),
@@ -250,4 +424,23 @@ IMAGE2IMAGE_BLOCKS = InsertableDict(
    ]
)

ALL_BLOCKS = {"text2image": TEXT2IMAGE_BLOCKS, "img2img": IMAGE2IMAGE_BLOCKS, "auto": AUTO_BLOCKS}
FLUX_KONTEXT_BLOCKS = InsertableDict(
    [
        ("text_encoder", FluxTextEncoderStep()),
        ("vae_encoder", FluxVaeEncoderDynamicStep(sample_mode="argmax")),
        ("input", FluxKontextInputStep()),
        ("prepare_latents", FluxPrepareLatentsStep()),
        ("set_timesteps", FluxSetTimestepsStep()),
        ("prepare_rope_inputs", FluxKontextRoPEInputsStep()),
        ("denoise", FluxKontextDenoiseStep()),
        ("decode", FluxDecodeStep()),
    ]
)

ALL_BLOCKS = {
    "text2image": TEXT2IMAGE_BLOCKS,
    "img2img": IMAGE2IMAGE_BLOCKS,
    "auto": AUTO_BLOCKS,
    "auto_kontext": AUTO_BLOCKS_KONTEXT,
    "kontext": FLUX_KONTEXT_BLOCKS,
}
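Editor's note (illustrative sketch, not part of the diff): assembling a runnable block sequence from the "kontext" preset. `SequentialPipelineBlocks.from_blocks_dict` is assumed here from the modular-diffusers API; treat this as a sketch rather than a verified call:

from diffusers.modular_pipelines import SequentialPipelineBlocks

blocks = SequentialPipelineBlocks.from_blocks_dict(ALL_BLOCKS["kontext"])
print(blocks.block_names)  # text_encoder, vae_encoder, input, prepare_latents, ...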
@@ -55,3 +55,13 @@ class FluxModularPipeline(ModularPipeline, FluxLoraLoaderMixin, TextualInversion
        if getattr(self, "transformer", None):
            num_channels_latents = self.transformer.config.in_channels // 4
        return num_channels_latents


class FluxKontextModularPipeline(FluxModularPipeline):
    """
    A ModularPipeline for Flux Kontext.

    > [!WARNING]
    > This is an experimental feature and is likely to change in the future.
    """

    default_blocks_name = "FluxKontextAutoBlocks"
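Editor's note (illustrative sketch, not part of the diff): turning the auto blocks into a loaded pipeline. The repo id below is a placeholder, and `init_pipeline`/`load_components` are assumed from the modular-diffusers API as used elsewhere in the docs:

import torch
from diffusers.modular_pipelines import FluxKontextAutoBlocks

blocks = FluxKontextAutoBlocks()
pipe = blocks.init_pipeline("black-forest-labs/FLUX.1-Kontext-dev")  # placeholder repo id
pipe.load_components(torch_dtype=torch.bfloat16)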
@@ -57,6 +57,7 @@ MODULAR_PIPELINE_MAPPING = OrderedDict(
        ("stable-diffusion-xl", "StableDiffusionXLModularPipeline"),
        ("wan", "WanModularPipeline"),
        ("flux", "FluxModularPipeline"),
        ("flux-kontext", "FluxKontextModularPipeline"),
        ("qwenimage", "QwenImageModularPipeline"),
        ("qwenimage-edit", "QwenImageEditModularPipeline"),
        ("qwenimage-edit-plus", "QwenImageEditPlusModularPipeline"),
@@ -86,15 +86,14 @@ class MarigoldDepthOutput(BaseOutput):

    Args:
        prediction (`np.ndarray`, `torch.Tensor`):
            Predicted depth maps with values in the range [0, 1]. The shape is $numimages \times 1 \times height \times
            width$ for `torch.Tensor` or $numimages \times height \times width \times 1$ for `np.ndarray`.
            Predicted depth maps with values in the range [0, 1]. The shape is `numimages × 1 × height × width` for
            `torch.Tensor` or `numimages × height × width × 1` for `np.ndarray`.
        uncertainty (`None`, `np.ndarray`, `torch.Tensor`):
            Uncertainty maps computed from the ensemble, with values in the range [0, 1]. The shape is $numimages
            \times 1 \times height \times width$ for `torch.Tensor` or $numimages \times height \times width \times 1$
            for `np.ndarray`.
            Uncertainty maps computed from the ensemble, with values in the range [0, 1]. The shape is `numimages × 1 ×
            height × width` for `torch.Tensor` or `numimages × height × width × 1` for `np.ndarray`.
        latent (`None`, `torch.Tensor`):
            Latent features corresponding to the predictions, compatible with the `latents` argument of the pipeline.
            The shape is $numimages * numensemble \times 4 \times latentheight \times latentwidth$.
            The shape is `numimages * numensemble × 4 × latentheight × latentwidth`.
    """

    prediction: Union[np.ndarray, torch.Tensor]
@@ -99,17 +99,17 @@ class MarigoldIntrinsicsOutput(BaseOutput):

    Args:
        prediction (`np.ndarray`, `torch.Tensor`):
            Predicted image intrinsics with values in the range [0, 1]. The shape is $(numimages * numtargets) \times 3
            \times height \times width$ for `torch.Tensor` or $(numimages * numtargets) \times height \times width
            \times 3$ for `np.ndarray`, where `numtargets` corresponds to the number of predicted target modalities of
            the intrinsic image decomposition.
            Predicted image intrinsics with values in the range [0, 1]. The shape is `(numimages * numtargets) × 3 ×
            height × width` for `torch.Tensor` or `(numimages * numtargets) × height × width × 3` for `np.ndarray`,
            where `numtargets` corresponds to the number of predicted target modalities of the intrinsic image
            decomposition.
        uncertainty (`None`, `np.ndarray`, `torch.Tensor`):
            Uncertainty maps computed from the ensemble, with values in the range [0, 1]. The shape is $(numimages *
            numtargets) \times 3 \times height \times width$ for `torch.Tensor` or $(numimages * numtargets) \times
            height \times width \times 3$ for `np.ndarray`.
            Uncertainty maps computed from the ensemble, with values in the range [0, 1]. The shape is `(numimages *
            numtargets) × 3 × height × width` for `torch.Tensor` or `(numimages * numtargets) × height × width × 3` for
            `np.ndarray`.
        latent (`None`, `torch.Tensor`):
            Latent features corresponding to the predictions, compatible with the `latents` argument of the pipeline.
            The shape is $(numimages * numensemble) \times (numtargets * 4) \times latentheight \times latentwidth$.
            The shape is `(numimages * numensemble) × (numtargets * 4) × latentheight × latentwidth`.
    """

    prediction: Union[np.ndarray, torch.Tensor]
@@ -81,15 +81,14 @@ class MarigoldNormalsOutput(BaseOutput):

    Args:
        prediction (`np.ndarray`, `torch.Tensor`):
            Predicted normals with values in the range [-1, 1]. The shape is $numimages \times 3 \times height \times
            width$ for `torch.Tensor` or $numimages \times height \times width \times 3$ for `np.ndarray`.
            Predicted normals with values in the range [-1, 1]. The shape is `numimages × 3 × height × width` for
            `torch.Tensor` or `numimages × height × width × 3` for `np.ndarray`.
        uncertainty (`None`, `np.ndarray`, `torch.Tensor`):
            Uncertainty maps computed from the ensemble, with values in the range [0, 1]. The shape is $numimages
            \times 1 \times height \times width$ for `torch.Tensor` or $numimages \times height \times width \times 1$
            for `np.ndarray`.
            Uncertainty maps computed from the ensemble, with values in the range [0, 1]. The shape is `numimages × 1 ×
            height × width` for `torch.Tensor` or `numimages × height × width × 1` for `np.ndarray`.
        latent (`None`, `torch.Tensor`):
            Latent features corresponding to the predictions, compatible with the `latents` argument of the pipeline.
            The shape is $numimages * numensemble \times 4 \times latentheight \times latentwidth$.
            The shape is `numimages * numensemble × 4 × latentheight × latentwidth`.
    """

    prediction: Union[np.ndarray, torch.Tensor]
@@ -17,6 +17,36 @@ class FluxAutoBlocks(metaclass=DummyObject):
        requires_backends(cls, ["torch", "transformers"])


class FluxKontextAutoBlocks(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class FluxKontextModularPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class FluxModularPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]
@@ -6,18 +6,25 @@ logger = get_logger(__name__)


_DEFAULT_HUB_ID_FA3 = "kernels-community/flash-attn3"
_DEFAULT_HUB_ID_SAGE = "kernels-community/sage_attention"
_KERNEL_REVISION = {
    # TODO: temporary revision for now. Remove when merged upstream into `main`.
    _DEFAULT_HUB_ID_FA3: "fake-ops-return-probs",
    _DEFAULT_HUB_ID_SAGE: "compile",
}


def _get_fa3_from_hub():
def _get_kernel_from_hub(kernel_id):
    if not is_kernels_available():
        return None
    else:
        from kernels import get_kernel

        try:
            # TODO: temporary revision for now. Remove when merged upstream into `main`.
            flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3, revision="fake-ops-return-probs")
            return flash_attn_3_hub
            if kernel_id not in _KERNEL_REVISION:
                raise NotImplementedError(f"{kernel_id} is not implemented in Diffusers.")
            kernel_hub = get_kernel(kernel_id, revision=_KERNEL_REVISION.get(kernel_id))
            return kernel_hub
        except Exception as e:
            logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}")
            logger.error(f"An error occurred while fetching kernel '{kernel_id}' from the Hub: {e}")
            raise
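Editor's note (illustrative sketch, not part of the diff): fetching the SageAttention kernel through the generalized helper. Requires the `kernels` package and network access to the Hub:

sage_kernel = _get_kernel_from_hub(_DEFAULT_HUB_ID_SAGE)
if sage_kernel is not None:
    # The returned module exposes the compiled attention ops, e.g. the functions
    # named in SAGE_ATTENTION_DISPATCH below.
    print(type(sage_kernel))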
137
src/diffusers/utils/sage_utils.py
Normal file
@@ -0,0 +1,137 @@
"""
Copyright (c) 2024 by SageAttention, The HuggingFace team.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the
License. You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
"""

"""
Modified from
https://github.com/thu-ml/SageAttention/blob/68de3797d163b89d28f9a38026c3b7313f6940d2/sageattention/core.py
"""


import torch  # noqa


SAGE_ATTENTION_DISPATCH = {
    "sm80": {
        "func": "sageattn_qk_int8_pv_fp16_cuda",
        "kwargs": {
            "tensor_layout": "NHD",
            "is_causal": False,
            "sm_scale": None,
            "return_lse": False,
            "pv_accum_dtype": "fp32",
        },
    },
    "sm89": {
        "func": "sageattn_qk_int8_pv_fp8_cuda",
        "kwargs": {
            "tensor_layout": "NHD",
            "is_causal": False,
            "sm_scale": None,
            "return_lse": False,
            "pv_accum_dtype": "fp32+fp16",
        },
    },
    "sm90": {
        "func": "sageattn_qk_int8_pv_fp8_cuda_sm90",
        "kwargs": {
            "tensor_layout": "NHD",
            "is_causal": False,
            "sm_scale": None,
            "return_lse": False,
            "pv_accum_dtype": "fp32+fp32",
        },
    },
    "sm120": {
        "func": "sageattn_qk_int8_pv_fp8_cuda",
        "kwargs": {
            "tensor_layout": "NHD",
            "is_causal": False,
            "qk_quant_gran": "per_warp",
            "sm_scale": None,
            "return_lse": False,
            "pv_accum_dtype": "fp32+fp16",
        },
    },
}
def get_cuda_version():
    if torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability()
        return major, minor
    else:
        raise EnvironmentError("CUDA not found.")


def get_cuda_arch_versions():
    if not torch.cuda.is_available():
        raise EnvironmentError("CUDA not found.")
    cuda_archs = []
    for i in range(torch.cuda.device_count()):
        major, minor = torch.cuda.get_device_capability(i)
        cuda_archs.append(f"sm{major}{minor}")
    return cuda_archs


# Unlike the actual implementation, we just maintain function names rather than actual
# implementations.
def _get_sage_attn_fn_for_device():
    """
    Automatically selects the appropriate implementation of the SageAttention kernel based on the GPU compute
    capability.

    Parameters
    ----------
    q : torch.Tensor
        The query tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.

    k : torch.Tensor
        The key tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.

    v : torch.Tensor
        The value tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.

    tensor_layout : str
        The tensor layout, either "HND" or "NHD". Default: "HND".

    is_causal : bool
        Whether to apply a causal mask to the attention matrix. Only applicable when qo_len == kv_len. Default: False.

    sm_scale : Optional[float]
        The scale used in softmax; if not provided, will be set to ``1.0 / sqrt(head_dim)``.

    return_lse : bool
        Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
        Default: False.

    Returns
    -------
    torch.Tensor
        The output tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.

    torch.Tensor
        The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). Shape:
        ``[batch_size, num_qo_heads, qo_len]``. Only returned if `return_lse` is True.

    Note
    ----
    - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
    - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``.
    - All tensors must be on the same cuda device.
    """
    device_index = torch.cuda.current_device()
    arch = get_cuda_arch_versions()[device_index]
    return SAGE_ATTENTION_DISPATCH[arch]
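Editor's note (illustrative sketch, not part of the diff): resolving the dispatch entry to a callable on the kernel module fetched from the Hub. Requires a CUDA device; `sage_kernel` is assumed to be the module returned by `_get_kernel_from_hub(_DEFAULT_HUB_ID_SAGE)` above:

entry = _get_sage_attn_fn_for_device()  # e.g. {"func": "sageattn_qk_int8_pv_fp8_cuda_sm90", ...} on sm90
print(entry["func"], entry["kwargs"]["pv_accum_dtype"])
# The actual call would go through the Hub kernel module, roughly:
# sage_attn = getattr(sage_kernel, entry["func"])
# out = sage_attn(q, k, v, **entry["kwargs"])  # q/k/v fp16/bf16, "NHD" layout, same CUDA device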