mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-14 08:24:32 +08:00
Compare commits
18 Commits
version-ch
...
attn-refac
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
da34261cc2 | ||
|
|
14cfbab078 | ||
|
|
a69b1e06fc | ||
|
|
345864eb85 | ||
|
|
35e538d46a | ||
|
|
2dc31677e1 | ||
|
|
1066de8c69 | ||
|
|
2d69bacb00 | ||
|
|
0974b4c606 | ||
|
|
cf4b97b233 | ||
|
|
77c4e0932c | ||
|
|
fc322ed052 | ||
|
|
fed2c46482 | ||
|
|
66320f031a | ||
|
|
86a1290e51 | ||
|
|
57f374b87b | ||
|
|
3b2e85d853 | ||
|
|
12b4edc2fe |
1
.github/workflows/build_docker_images.yml
vendored
1
.github/workflows/build_docker_images.yml
vendored
@@ -72,7 +72,6 @@ jobs:
|
||||
image-name:
|
||||
- diffusers-pytorch-cpu
|
||||
- diffusers-pytorch-cuda
|
||||
- diffusers-pytorch-cuda
|
||||
- diffusers-pytorch-xformers-cuda
|
||||
- diffusers-pytorch-minimum-cuda
|
||||
- diffusers-doc-builder
|
||||
|
||||
1
.github/workflows/pr_tests.yml
vendored
1
.github/workflows/pr_tests.yml
vendored
@@ -286,4 +286,3 @@ jobs:
|
||||
with:
|
||||
name: pr_main_test_reports
|
||||
path: reports
|
||||
|
||||
|
||||
@@ -1,56 +1,42 @@
|
||||
FROM ubuntu:20.04
|
||||
FROM python:3.10-slim
|
||||
ENV PYTHONDONTWRITEBYTECODE=1
|
||||
LABEL maintainer="Hugging Face"
|
||||
LABEL repository="diffusers"
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
RUN apt-get -y update \
|
||||
&& apt-get install -y software-properties-common \
|
||||
&& add-apt-repository ppa:deadsnakes/ppa
|
||||
RUN apt-get -y update && apt-get install -y bash \
|
||||
build-essential \
|
||||
git \
|
||||
git-lfs \
|
||||
curl \
|
||||
ca-certificates \
|
||||
libsndfile1-dev \
|
||||
libgl1
|
||||
|
||||
RUN apt install -y bash \
|
||||
build-essential \
|
||||
git \
|
||||
git-lfs \
|
||||
curl \
|
||||
ca-certificates \
|
||||
libsndfile1-dev \
|
||||
python3.10 \
|
||||
python3-pip \
|
||||
libgl1 \
|
||||
zip \
|
||||
wget \
|
||||
python3.10-venv && \
|
||||
rm -rf /var/lib/apt/lists
|
||||
|
||||
# make sure to use venv
|
||||
RUN python3.10 -m venv /opt/venv
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
ENV UV_PYTHON=/usr/local/bin/python
|
||||
|
||||
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
|
||||
RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
|
||||
python3.10 -m uv pip install --no-cache-dir \
|
||||
torch \
|
||||
torchvision \
|
||||
torchaudio \
|
||||
invisible_watermark \
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu && \
|
||||
python3.10 -m uv pip install --no-cache-dir \
|
||||
accelerate \
|
||||
datasets \
|
||||
hf-doc-builder \
|
||||
huggingface-hub \
|
||||
Jinja2 \
|
||||
librosa \
|
||||
numpy==1.26.4 \
|
||||
scipy \
|
||||
tensorboard \
|
||||
transformers \
|
||||
matplotlib \
|
||||
setuptools==69.5.1 \
|
||||
bitsandbytes \
|
||||
torchao \
|
||||
gguf \
|
||||
optimum-quanto
|
||||
RUN pip install uv
|
||||
RUN uv pip install --no-cache-dir \
|
||||
torch \
|
||||
torchvision \
|
||||
torchaudio \
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
|
||||
RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/diffusers.git@main#egg=diffusers[test]"
|
||||
|
||||
# Extra dependencies
|
||||
RUN uv pip install --no-cache-dir \
|
||||
accelerate \
|
||||
numpy==1.26.4 \
|
||||
hf_transfer \
|
||||
setuptools==69.5.1 \
|
||||
bitsandbytes \
|
||||
torchao \
|
||||
gguf \
|
||||
optimum-quanto
|
||||
|
||||
RUN apt-get clean && rm -rf /var/lib/apt/lists/* && apt-get autoremove && apt-get autoclean
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
|
||||
@@ -1,50 +1,37 @@
|
||||
FROM ubuntu:20.04
|
||||
FROM python:3.10-slim
|
||||
ENV PYTHONDONTWRITEBYTECODE=1
|
||||
LABEL maintainer="Hugging Face"
|
||||
LABEL repository="diffusers"
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
RUN apt-get -y update \
|
||||
&& apt-get install -y software-properties-common \
|
||||
&& add-apt-repository ppa:deadsnakes/ppa
|
||||
RUN apt-get -y update && apt-get install -y bash \
|
||||
build-essential \
|
||||
git \
|
||||
git-lfs \
|
||||
curl \
|
||||
ca-certificates \
|
||||
libsndfile1-dev \
|
||||
libgl1
|
||||
|
||||
RUN apt install -y bash \
|
||||
build-essential \
|
||||
git \
|
||||
git-lfs \
|
||||
curl \
|
||||
ca-certificates \
|
||||
libsndfile1-dev \
|
||||
python3.10 \
|
||||
python3.10-dev \
|
||||
python3-pip \
|
||||
libgl1 \
|
||||
python3.10-venv && \
|
||||
rm -rf /var/lib/apt/lists
|
||||
|
||||
# make sure to use venv
|
||||
RUN python3.10 -m venv /opt/venv
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
ENV UV_PYTHON=/usr/local/bin/python
|
||||
|
||||
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
|
||||
RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
|
||||
python3.10 -m uv pip install --no-cache-dir \
|
||||
torch \
|
||||
torchvision \
|
||||
torchaudio \
|
||||
invisible_watermark \
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu && \
|
||||
python3.10 -m uv pip install --no-cache-dir \
|
||||
accelerate \
|
||||
datasets \
|
||||
hf-doc-builder \
|
||||
huggingface-hub \
|
||||
Jinja2 \
|
||||
librosa \
|
||||
numpy==1.26.4 \
|
||||
scipy \
|
||||
tensorboard \
|
||||
transformers matplotlib \
|
||||
hf_transfer
|
||||
RUN pip install uv
|
||||
RUN uv pip install --no-cache-dir \
|
||||
torch \
|
||||
torchvision \
|
||||
torchaudio \
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
|
||||
RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/diffusers.git@main#egg=diffusers[test]"
|
||||
|
||||
# Extra dependencies
|
||||
RUN uv pip install --no-cache-dir \
|
||||
accelerate \
|
||||
numpy==1.26.4 \
|
||||
hf_transfer
|
||||
|
||||
RUN apt-get clean && rm -rf /var/lib/apt/lists/* && apt-get autoremove && apt-get autoclean
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
|
||||
@@ -2,11 +2,13 @@ FROM nvidia/cuda:12.1.0-runtime-ubuntu20.04
|
||||
LABEL maintainer="Hugging Face"
|
||||
LABEL repository="diffusers"
|
||||
|
||||
ARG PYTHON_VERSION=3.12
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
RUN apt-get -y update \
|
||||
&& apt-get install -y software-properties-common \
|
||||
&& add-apt-repository ppa:deadsnakes/ppa
|
||||
&& add-apt-repository ppa:deadsnakes/ppa && \
|
||||
apt-get update
|
||||
|
||||
RUN apt install -y bash \
|
||||
build-essential \
|
||||
@@ -16,36 +18,31 @@ RUN apt install -y bash \
|
||||
ca-certificates \
|
||||
libsndfile1-dev \
|
||||
libgl1 \
|
||||
python3.10 \
|
||||
python3.10-dev \
|
||||
python3 \
|
||||
python3-pip \
|
||||
python3.10-venv && \
|
||||
rm -rf /var/lib/apt/lists
|
||||
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# make sure to use venv
|
||||
RUN python3.10 -m venv /opt/venv
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
RUN curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
ENV PATH="/root/.local/bin:$PATH"
|
||||
ENV VIRTUAL_ENV="/opt/venv"
|
||||
ENV UV_PYTHON_INSTALL_DIR=/opt/uv/python
|
||||
RUN uv venv --python ${PYTHON_VERSION} --seed ${VIRTUAL_ENV}
|
||||
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
|
||||
|
||||
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
|
||||
RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
|
||||
python3.10 -m uv pip install --no-cache-dir \
|
||||
RUN uv pip install --no-cache-dir \
|
||||
torch \
|
||||
torchvision \
|
||||
torchaudio \
|
||||
invisible_watermark && \
|
||||
python3.10 -m pip install --no-cache-dir \
|
||||
torchaudio
|
||||
|
||||
RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/diffusers.git@main#egg=diffusers[test]"
|
||||
|
||||
# Extra dependencies
|
||||
RUN uv pip install --no-cache-dir \
|
||||
accelerate \
|
||||
datasets \
|
||||
hf-doc-builder \
|
||||
huggingface-hub \
|
||||
hf_transfer \
|
||||
Jinja2 \
|
||||
librosa \
|
||||
numpy==1.26.4 \
|
||||
scipy \
|
||||
tensorboard \
|
||||
transformers \
|
||||
pytorch-lightning \
|
||||
pytorch-lightning \
|
||||
hf_transfer
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
|
||||
@@ -2,6 +2,7 @@ FROM nvidia/cuda:12.1.0-runtime-ubuntu20.04
|
||||
LABEL maintainer="Hugging Face"
|
||||
LABEL repository="diffusers"
|
||||
|
||||
ARG PYTHON_VERSION=3.10
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ENV MINIMUM_SUPPORTED_TORCH_VERSION="2.1.0"
|
||||
ENV MINIMUM_SUPPORTED_TORCHVISION_VERSION="0.16.0"
|
||||
@@ -9,7 +10,8 @@ ENV MINIMUM_SUPPORTED_TORCHAUDIO_VERSION="2.1.0"
|
||||
|
||||
RUN apt-get -y update \
|
||||
&& apt-get install -y software-properties-common \
|
||||
&& add-apt-repository ppa:deadsnakes/ppa
|
||||
&& add-apt-repository ppa:deadsnakes/ppa && \
|
||||
apt-get update
|
||||
|
||||
RUN apt install -y bash \
|
||||
build-essential \
|
||||
@@ -19,35 +21,31 @@ RUN apt install -y bash \
|
||||
ca-certificates \
|
||||
libsndfile1-dev \
|
||||
libgl1 \
|
||||
python3.10 \
|
||||
python3.10-dev \
|
||||
python3 \
|
||||
python3-pip \
|
||||
python3.10-venv && \
|
||||
rm -rf /var/lib/apt/lists
|
||||
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# make sure to use venv
|
||||
RUN python3.10 -m venv /opt/venv
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
RUN curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
ENV PATH="/root/.local/bin:$PATH"
|
||||
ENV VIRTUAL_ENV="/opt/venv"
|
||||
ENV UV_PYTHON_INSTALL_DIR=/opt/uv/python
|
||||
RUN uv venv --python ${PYTHON_VERSION} --seed ${VIRTUAL_ENV}
|
||||
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
|
||||
|
||||
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
|
||||
RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
|
||||
python3.10 -m uv pip install --no-cache-dir \
|
||||
RUN uv pip install --no-cache-dir \
|
||||
torch==$MINIMUM_SUPPORTED_TORCH_VERSION \
|
||||
torchvision==$MINIMUM_SUPPORTED_TORCHVISION_VERSION \
|
||||
torchaudio==$MINIMUM_SUPPORTED_TORCHAUDIO_VERSION \
|
||||
invisible_watermark && \
|
||||
python3.10 -m pip install --no-cache-dir \
|
||||
torchaudio==$MINIMUM_SUPPORTED_TORCHAUDIO_VERSION
|
||||
|
||||
RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/diffusers.git@main#egg=diffusers[test]"
|
||||
|
||||
# Extra dependencies
|
||||
RUN uv pip install --no-cache-dir \
|
||||
accelerate \
|
||||
datasets \
|
||||
hf-doc-builder \
|
||||
huggingface-hub \
|
||||
hf_transfer \
|
||||
Jinja2 \
|
||||
librosa \
|
||||
numpy==1.26.4 \
|
||||
scipy \
|
||||
tensorboard \
|
||||
transformers \
|
||||
pytorch-lightning \
|
||||
hf_transfer
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
|
||||
@@ -2,50 +2,48 @@ FROM nvidia/cuda:12.1.0-runtime-ubuntu20.04
|
||||
LABEL maintainer="Hugging Face"
|
||||
LABEL repository="diffusers"
|
||||
|
||||
ARG PYTHON_VERSION=3.12
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
RUN apt-get -y update \
|
||||
&& apt-get install -y software-properties-common \
|
||||
&& add-apt-repository ppa:deadsnakes/ppa
|
||||
&& add-apt-repository ppa:deadsnakes/ppa && \
|
||||
apt-get update
|
||||
|
||||
RUN apt install -y bash \
|
||||
build-essential \
|
||||
git \
|
||||
git-lfs \
|
||||
curl \
|
||||
ca-certificates \
|
||||
libsndfile1-dev \
|
||||
libgl1 \
|
||||
python3.10 \
|
||||
python3.10-dev \
|
||||
python3-pip \
|
||||
python3.10-venv && \
|
||||
rm -rf /var/lib/apt/lists
|
||||
build-essential \
|
||||
git \
|
||||
git-lfs \
|
||||
curl \
|
||||
ca-certificates \
|
||||
libsndfile1-dev \
|
||||
libgl1 \
|
||||
python3 \
|
||||
python3-pip \
|
||||
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# make sure to use venv
|
||||
RUN python3.10 -m venv /opt/venv
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
RUN curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
ENV PATH="/root/.local/bin:$PATH"
|
||||
ENV VIRTUAL_ENV="/opt/venv"
|
||||
ENV UV_PYTHON_INSTALL_DIR=/opt/uv/python
|
||||
RUN uv venv --python ${PYTHON_VERSION} --seed ${VIRTUAL_ENV}
|
||||
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
|
||||
|
||||
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
|
||||
RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
|
||||
python3.10 -m pip install --no-cache-dir \
|
||||
torch \
|
||||
torchvision \
|
||||
torchaudio \
|
||||
invisible_watermark && \
|
||||
python3.10 -m uv pip install --no-cache-dir \
|
||||
accelerate \
|
||||
datasets \
|
||||
hf-doc-builder \
|
||||
huggingface-hub \
|
||||
hf_transfer \
|
||||
Jinja2 \
|
||||
librosa \
|
||||
numpy==1.26.4 \
|
||||
scipy \
|
||||
tensorboard \
|
||||
transformers \
|
||||
xformers \
|
||||
hf_transfer
|
||||
RUN uv pip install --no-cache-dir \
|
||||
torch \
|
||||
torchvision \
|
||||
torchaudio
|
||||
|
||||
RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/diffusers.git@main#egg=diffusers[test]"
|
||||
|
||||
# Extra dependencies
|
||||
RUN uv pip install --no-cache-dir \
|
||||
accelerate \
|
||||
numpy==1.26.4 \
|
||||
pytorch-lightning \
|
||||
hf_transfer \
|
||||
xformers
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
|
||||
@@ -14,51 +14,47 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
## 서문 [[preamble]]
|
||||
|
||||
[Diffusers](https://huggingface.co/docs/diffusers/index)는 사전 훈련된 diffusion 모델을 제공하며 추론 및 훈련을 위한 모듈식 툴박스로 사용됩니다.
|
||||
[Diffusers](https://huggingface.co/docs/diffusers/index)는 사전 훈련된 diffusion 모델을 제공하며, 추론과 훈련을 위한 모듈형 툴박스로 활용됩니다.
|
||||
|
||||
이 기술의 실제 적용과 사회에 미칠 수 있는 부정적인 영향을 고려하여 Diffusers 라이브러리의 개발, 사용자 기여 및 사용에 윤리 지침을 제공하는 것이 중요하다고 생각합니다.
|
||||
|
||||
이이 기술을 사용함에 따른 위험은 여전히 검토 중이지만, 몇 가지 예를 들면: 예술가들에 대한 저작권 문제; 딥 페이크의 악용; 부적절한 맥락에서의 성적 콘텐츠 생성; 동의 없는 사칭; 소수자 집단의 억압을 영속화하는 유해한 사회적 편견 등이 있습니다.
|
||||
|
||||
우리는 위험을 지속적으로 추적하고 커뮤니티의 응답과 소중한 피드백에 따라 다음 지침을 조정할 것입니다.
|
||||
이 기술의 실제 적용 사례와 사회에 미칠 수 있는 잠재적 부정적 영향을 고려할 때, Diffusers 라이브러리의 개발, 사용자 기여, 사용에 윤리 지침을 제공하는 것이 중요하다고 생각합니다.
|
||||
|
||||
이 기술 사용과 관련된 위험은 여전히 검토 중이지만, 예를 들면: 예술가의 저작권 문제, 딥페이크 악용, 부적절한 맥락에서의 성적 콘텐츠 생성, 비동의 사칭, 소수자 집단 억압을 영속화하는 유해한 사회적 편견 등이 있습니다.
|
||||
우리는 이러한 위험을 지속적으로 추적하고, 커뮤니티의 반응과 소중한 피드백에 따라 아래 지침을 조정할 것입니다.
|
||||
|
||||
## 범위 [[scope]]
|
||||
|
||||
Diffusers 커뮤니티는 프로젝트의 개발에 다음과 같은 윤리 지침을 적용하며, 특히 윤리적 문제와 관련된 민감한 주제에 대한 커뮤니티의 기여를 조정하는 데 도움을 줄 것입니다.
|
||||
|
||||
Diffusers 커뮤니티는 프로젝트 개발에 다음 윤리 지침을 적용하며, 특히 윤리적 문제와 관련된 민감한 주제에 대해 커뮤니티의 기여를 조정하는 데 도움을 줄 것입니다.
|
||||
|
||||
## 윤리 지침 [[ethical-guidelines]]
|
||||
|
||||
다음 윤리 지침은 일반적으로 적용되지만, 민감한 윤리적 문제와 관련하여 기술적 선택을 할 때 이를 우선적으로 적용할 것입니다. 나아가, 해당 기술의 최신 동향과 관련된 새로운 위험이 발생함에 따라 이러한 윤리 원칙을 조정할 것을 약속드립니다.
|
||||
다음 윤리 지침은 일반적으로 적용되지만, 윤리적으로 민감한 문제와 관련된 기술적 선택을 할 때 우선적으로 적용됩니다. 또한, 해당 기술의 최신 동향과 관련된 새로운 위험이 발생함에 따라 이러한 윤리 원칙을 지속적으로 조정할 것을 약속합니다.
|
||||
|
||||
- **투명성**: 우리는 PR을 관리하고, 사용자에게 우리의 선택을 설명하며, 기술적 의사결정을 내릴 때 투명성을 유지할 것을 약속합니다.
|
||||
- **투명성**: 우리는 PR 관리, 사용자에게 선택의 이유 설명, 기술적 의사결정 과정에서 투명성을 유지할 것을 약속합니다.
|
||||
|
||||
- **일관성**: 우리는 프로젝트 관리에서 사용자들에게 동일한 수준의 관심을 보장하고 기술적으로 안정되고 일관된 상태를 유지할 것을 약속합니다.
|
||||
- **일관성**: 프로젝트 관리에서 모든 사용자에게 동일한 수준의 관심을 보장하고, 기술적으로 안정적이고 일관된 상태를 유지할 것을 약속합니다.
|
||||
|
||||
- **간결성**: Diffusers 라이브러리를 사용하고 활용하기 쉽게 만들기 위해, 프로젝트의 목표를 간결하고 일관성 있게 유지할 것을 약속합니다.
|
||||
- **간결성**: Diffusers 라이브러리를 쉽게 사용하고 활용할 수 있도록, 프로젝트의 목표를 간결하고 일관성 있게 유지할 것을 약속합니다.
|
||||
|
||||
- **접근성**: Diffusers 프로젝트는 기술적 전문 지식 없어도 프로젝트 운영에 참여할 수 있는 기여자의 진입장벽을 낮춥니다. 이를 통해 연구 결과물이 커뮤니티에 더 잘 접근할 수 있게 됩니다.
|
||||
- **접근성**: Diffusers 프로젝트는 기술적 전문지식이 없어도 기여할 수 있도록 진입장벽을 낮춥니다. 이를 통해 연구 결과물이 커뮤니티에 더 잘 접근될 수 있습니다.
|
||||
|
||||
- **재현성**: 우리는 Diffusers 라이브러리를 통해 제공되는 업스트림(upstream) 코드, 모델 및 데이터셋의 재현성에 대해 투명하게 공개할 것을 목표로 합니다.
|
||||
|
||||
- **책임**: 우리는 커뮤니티와 팀워크를 통해, 이 기술의 잠재적인 위험과 위험을 예측하고 완화하는 데 대한 공동 책임을 가지고 있습니다.
|
||||
- **재현성**: 우리는 Diffusers 라이브러리를 통해 제공되는 업스트림 코드, 모델, 데이터셋의 재현성에 대해 투명하게 공개하는 것을 목표로 합니다.
|
||||
|
||||
- **책임**: 커뮤니티와 팀워크를 통해, 이 기술의 잠재적 위험을 예측하고 완화하는 데 공동 책임을 집니다.
|
||||
|
||||
## 구현 사례: 안전 기능과 메커니즘 [[examples-of-implementations-safety-features-and-mechanisms]]
|
||||
|
||||
팀은 diffusion 기술과 관련된 잠재적인 윤리 및 사회적 위험에 대처하기 위한 기술적 및 비기술적 도구를 제공하고자 하고 있습니다. 또한, 커뮤니티의 참여는 이러한 기능의 구현하고 우리와 함께 인식을 높이는 데 매우 중요합니다.
|
||||
팀은 diffusion 기술과 관련된 잠재적 윤리 및 사회적 위험에 대응하기 위해 기술적·비기술적 도구를 제공하고자 노력하고 있습니다. 또한, 커뮤니티의 참여는 이러한 기능 구현과 인식 제고에 매우 중요합니다.
|
||||
|
||||
- [**커뮤니티 탭**](https://huggingface.co/docs/hub/repositories-pull-requests-discussions): 이를 통해 커뮤니티는 프로젝트에 대해 토론하고 더 나은 협력을 할 수 있습니다.
|
||||
- [**커뮤니티 탭**](https://huggingface.co/docs/hub/repositories-pull-requests-discussions): 커뮤니티가 프로젝트에 대해 토론하고 더 나은 협업을 할 수 있도록 지원합니다.
|
||||
|
||||
- **편향 탐색 및 평가**: Hugging Face 팀은 Stable Diffusion 모델의 편향성을 대화형으로 보여주는 [space](https://huggingface.co/spaces/society-ethics/DiffusionBiasExplorer)을 제공합니다. 이런 의미에서, 우리는 편향 탐색 및 평가를 지원하고 장려합니다.
|
||||
- **편향 탐색 및 평가**: Hugging Face 팀은 Stable Diffusion 모델의 편향성을 대화형으로 보여주는 [space](https://huggingface.co/spaces/society-ethics/DiffusionBiasExplorer)를 제공합니다. 우리는 이러한 편향 탐색과 평가를 지원하고 장려합니다.
|
||||
|
||||
- **배포에서의 안전 유도**
|
||||
|
||||
- [**안전한 Stable Diffusion**](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/stable_diffusion_safe): 이는 필터되지 않은 웹 크롤링 데이터셋으로 훈련된 Stable Diffusion과 같은 모델이 부적절한 변질에 취약한 문제를 완화합니다. 관련 논문: [Safe Latent Diffusion: Mitigating Inappropriate Degeneration in Diffusion Models](https://huggingface.co/papers/2211.05105).
|
||||
- [**안전한 Stable Diffusion**](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/stable_diffusion_safe): 필터링되지 않은 웹 크롤링 데이터셋으로 훈련된 Stable Diffusion과 같은 모델이 부적절하게 변질되는 문제를 완화합니다. 관련 논문: [Safe Latent Diffusion: Mitigating Inappropriate Degeneration in Diffusion Models](https://huggingface.co/papers/2211.05105).
|
||||
|
||||
- [**안전 검사기**](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py): 이미지가 생성된 후에 이미자가 임베딩 공간에서 일련의 하드코딩된 유해 개념의 클래스일 확률을 확인하고 비교합니다. 유해 개념은 역공학을 방지하기 위해 의도적으로 숨겨져 있습니다.
|
||||
- [**안전 검사기**](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py): 생성된 이미지가 임베딩 공간에서 하드코딩된 유해 개념 클래스와 일치할 확률을 확인하고 비교합니다. 유해 개념은 역공학을 방지하기 위해 의도적으로 숨겨져 있습니다.
|
||||
|
||||
- **Hub에서의 단계적인 배포**: 특히 민감한 상황에서는 일부 리포지토리에 대한 접근을 제한해야 합니다. 이 단계적인 배포는 중간 단계로, 리포지토리 작성자가 사용에 대한 더 많은 통제력을 갖게 합니다.
|
||||
- **Hub에서의 단계적 배포**: 특히 민감한 상황에서는 일부 리포지토리에 대한 접근을 제한할 수 있습니다. 단계적 배포는 리포지토리 작성자가 사용에 대해 더 많은 통제권을 갖도록 하는 중간 단계입니다.
|
||||
|
||||
- **라이선싱**: [OpenRAILs](https://huggingface.co/blog/open_rail)와 같은 새로운 유형의 라이선싱을 통해 자유로운 접근을 보장하면서도 더 책임 있는 사용을 위한 일련의 제한을 둘 수 있습니다.
|
||||
- **라이선싱**: [OpenRAILs](https://huggingface.co/blog/open_rail)와 같은 새로운 유형의 라이선스를 통해 자유로운 접근을 보장하면서도 보다 책임 있는 사용을 위한 일련의 제한을 둘 수 있습니다.
|
||||
|
||||
@@ -1338,7 +1338,7 @@ def main(args):
|
||||
batch["pixel_values"] = batch["pixel_values"].to(
|
||||
accelerator.device, non_blocking=True, dtype=vae.dtype
|
||||
)
|
||||
latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
|
||||
latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
|
||||
if train_dataset.custom_instance_prompts:
|
||||
with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
|
||||
prompt_embeds, prompt_embeds_mask = compute_text_embeddings(
|
||||
|
||||
@@ -151,8 +151,8 @@ def _register_attention_processors_metadata():
|
||||
|
||||
|
||||
def _register_transformer_blocks_metadata():
|
||||
from ..models.attention import BasicTransformerBlock
|
||||
from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock
|
||||
from ..models.transformers.transformer_2d import BasicTransformerBlock
|
||||
from ..models.transformers.transformer_bria import BriaTransformerBlock
|
||||
from ..models.transformers.transformer_cogview4 import CogView4TransformerBlock
|
||||
from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
|
||||
|
||||
@@ -17,7 +17,10 @@ from dataclasses import dataclass
|
||||
from typing import Dict, List, Type, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed._functional_collectives as funcol
|
||||
|
||||
|
||||
if torch.distributed.is_available():
|
||||
import torch.distributed._functional_collectives as funcol
|
||||
|
||||
from ..models._modeling_parallel import (
|
||||
ContextParallelConfig,
|
||||
|
||||
@@ -21,10 +21,8 @@ import torch.nn.functional as F
|
||||
from ..utils import deprecate, logging
|
||||
from ..utils.import_utils import is_torch_npu_available, is_torch_xla_available, is_xformers_available
|
||||
from ..utils.torch_utils import maybe_allow_in_graph
|
||||
from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU
|
||||
from .attention_processor import Attention, AttentionProcessor, JointAttnProcessor2_0
|
||||
from .attention_processor import Attention, AttentionProcessor
|
||||
from .embeddings import SinusoidalPositionalEmbedding
|
||||
from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
|
||||
|
||||
|
||||
if is_xformers_available():
|
||||
@@ -505,19 +503,16 @@ class AttentionModuleMixin:
|
||||
return encoder_hidden_states
|
||||
|
||||
|
||||
def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
|
||||
# "feed_forward_chunk_size" can be used to save memory
|
||||
if hidden_states.shape[chunk_dim] % chunk_size != 0:
|
||||
raise ValueError(
|
||||
f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
|
||||
)
|
||||
|
||||
num_chunks = hidden_states.shape[chunk_dim] // chunk_size
|
||||
ff_output = torch.cat(
|
||||
[ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
|
||||
dim=chunk_dim,
|
||||
def _chunked_feed_forward(*args, **kwargs):
|
||||
deprecate(
|
||||
"_chunked_feed_forward",
|
||||
"1.0.0",
|
||||
"Importing `_chunked_feed_forward` from `diffusers.models.attention` is deprecated. Please use `from diffusers.models.transformers.modeling_common import _chunked_feed_forward` instead.",
|
||||
standard_warn=False,
|
||||
)
|
||||
return ff_output
|
||||
from .transformers.modeling_common import _chunked_feed_forward
|
||||
|
||||
return _chunked_feed_forward(*args, **kwargs)
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
@@ -577,161 +572,16 @@ class JointTransformerBlock(nn.Module):
|
||||
processing of `context` conditions.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
context_pre_only: bool = False,
|
||||
qk_norm: Optional[str] = None,
|
||||
use_dual_attention: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.use_dual_attention = use_dual_attention
|
||||
self.context_pre_only = context_pre_only
|
||||
context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"
|
||||
|
||||
if use_dual_attention:
|
||||
self.norm1 = SD35AdaLayerNormZeroX(dim)
|
||||
else:
|
||||
self.norm1 = AdaLayerNormZero(dim)
|
||||
|
||||
if context_norm_type == "ada_norm_continous":
|
||||
self.norm1_context = AdaLayerNormContinuous(
|
||||
dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm"
|
||||
)
|
||||
elif context_norm_type == "ada_norm_zero":
|
||||
self.norm1_context = AdaLayerNormZero(dim)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
|
||||
)
|
||||
|
||||
if hasattr(F, "scaled_dot_product_attention"):
|
||||
processor = JointAttnProcessor2_0()
|
||||
else:
|
||||
raise ValueError(
|
||||
"The current PyTorch version does not support the `scaled_dot_product_attention` function."
|
||||
)
|
||||
|
||||
self.attn = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=None,
|
||||
added_kv_proj_dim=dim,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
out_dim=dim,
|
||||
context_pre_only=context_pre_only,
|
||||
bias=True,
|
||||
processor=processor,
|
||||
qk_norm=qk_norm,
|
||||
eps=1e-6,
|
||||
def __new__(cls, *args, **kwargs):
|
||||
deprecate(
|
||||
"JointTransformerBlock",
|
||||
"1.0.0",
|
||||
"Importing `JointTransformerBlock` from `diffusers.models.attention` is deprecated. Please use `from diffusers.models.transformers.transformer_sd3 import SD3TransformerBlock` instead.",
|
||||
standard_warn=False,
|
||||
)
|
||||
from .transformers.transformer_sd3 import SD3TransformerBlock
|
||||
|
||||
if use_dual_attention:
|
||||
self.attn2 = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=None,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
out_dim=dim,
|
||||
bias=True,
|
||||
processor=processor,
|
||||
qk_norm=qk_norm,
|
||||
eps=1e-6,
|
||||
)
|
||||
else:
|
||||
self.attn2 = None
|
||||
|
||||
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
||||
|
||||
if not context_pre_only:
|
||||
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
||||
else:
|
||||
self.norm2_context = None
|
||||
self.ff_context = None
|
||||
|
||||
# let chunk size default to None
|
||||
self._chunk_size = None
|
||||
self._chunk_dim = 0
|
||||
|
||||
# Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
|
||||
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
|
||||
# Sets chunk feed-forward
|
||||
self._chunk_size = chunk_size
|
||||
self._chunk_dim = dim
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor,
|
||||
temb: torch.FloatTensor,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
joint_attention_kwargs = joint_attention_kwargs or {}
|
||||
if self.use_dual_attention:
|
||||
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
|
||||
hidden_states, emb=temb
|
||||
)
|
||||
else:
|
||||
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
|
||||
|
||||
if self.context_pre_only:
|
||||
norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
|
||||
else:
|
||||
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
|
||||
encoder_hidden_states, emb=temb
|
||||
)
|
||||
|
||||
# Attention.
|
||||
attn_output, context_attn_output = self.attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
**joint_attention_kwargs,
|
||||
)
|
||||
|
||||
# Process attention outputs for the `hidden_states`.
|
||||
attn_output = gate_msa.unsqueeze(1) * attn_output
|
||||
hidden_states = hidden_states + attn_output
|
||||
|
||||
if self.use_dual_attention:
|
||||
attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **joint_attention_kwargs)
|
||||
attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
|
||||
hidden_states = hidden_states + attn_output2
|
||||
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
||||
if self._chunk_size is not None:
|
||||
# "feed_forward_chunk_size" can be used to save memory
|
||||
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
|
||||
else:
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
||||
|
||||
hidden_states = hidden_states + ff_output
|
||||
|
||||
# Process attention outputs for the `encoder_hidden_states`.
|
||||
if self.context_pre_only:
|
||||
encoder_hidden_states = None
|
||||
else:
|
||||
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
|
||||
encoder_hidden_states = encoder_hidden_states + context_attn_output
|
||||
|
||||
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
|
||||
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
|
||||
if self._chunk_size is not None:
|
||||
# "feed_forward_chunk_size" can be used to save memory
|
||||
context_ff_output = _chunked_feed_forward(
|
||||
self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size
|
||||
)
|
||||
else:
|
||||
context_ff_output = self.ff_context(norm_encoder_hidden_states)
|
||||
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
|
||||
|
||||
return encoder_hidden_states, hidden_states
|
||||
return SD3TransformerBlock(*args, **kwargs)
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
@@ -770,300 +620,16 @@ class BasicTransformerBlock(nn.Module):
|
||||
The maximum number of positional embeddings to apply.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
dropout=0.0,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
activation_fn: str = "geglu",
|
||||
num_embeds_ada_norm: Optional[int] = None,
|
||||
attention_bias: bool = False,
|
||||
only_cross_attention: bool = False,
|
||||
double_self_attention: bool = False,
|
||||
upcast_attention: bool = False,
|
||||
norm_elementwise_affine: bool = True,
|
||||
norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
|
||||
norm_eps: float = 1e-5,
|
||||
final_dropout: bool = False,
|
||||
attention_type: str = "default",
|
||||
positional_embeddings: Optional[str] = None,
|
||||
num_positional_embeddings: Optional[int] = None,
|
||||
ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
|
||||
ada_norm_bias: Optional[int] = None,
|
||||
ff_inner_dim: Optional[int] = None,
|
||||
ff_bias: bool = True,
|
||||
attention_out_bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
self.dropout = dropout
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
self.activation_fn = activation_fn
|
||||
self.attention_bias = attention_bias
|
||||
self.double_self_attention = double_self_attention
|
||||
self.norm_elementwise_affine = norm_elementwise_affine
|
||||
self.positional_embeddings = positional_embeddings
|
||||
self.num_positional_embeddings = num_positional_embeddings
|
||||
self.only_cross_attention = only_cross_attention
|
||||
|
||||
# We keep these boolean flags for backward-compatibility.
|
||||
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
|
||||
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
|
||||
self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
|
||||
self.use_layer_norm = norm_type == "layer_norm"
|
||||
self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
|
||||
|
||||
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
|
||||
raise ValueError(
|
||||
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
|
||||
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
|
||||
)
|
||||
|
||||
self.norm_type = norm_type
|
||||
self.num_embeds_ada_norm = num_embeds_ada_norm
|
||||
|
||||
if positional_embeddings and (num_positional_embeddings is None):
|
||||
raise ValueError(
|
||||
"If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
|
||||
)
|
||||
|
||||
if positional_embeddings == "sinusoidal":
|
||||
self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
|
||||
else:
|
||||
self.pos_embed = None
|
||||
|
||||
# Define 3 blocks. Each block has its own normalization layer.
|
||||
# 1. Self-Attn
|
||||
if norm_type == "ada_norm":
|
||||
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
|
||||
elif norm_type == "ada_norm_zero":
|
||||
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
|
||||
elif norm_type == "ada_norm_continuous":
|
||||
self.norm1 = AdaLayerNormContinuous(
|
||||
dim,
|
||||
ada_norm_continous_conditioning_embedding_dim,
|
||||
norm_elementwise_affine,
|
||||
norm_eps,
|
||||
ada_norm_bias,
|
||||
"rms_norm",
|
||||
)
|
||||
else:
|
||||
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
|
||||
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
|
||||
upcast_attention=upcast_attention,
|
||||
out_bias=attention_out_bias,
|
||||
def __new__(cls, *args, **kwargs):
|
||||
deprecate(
|
||||
"BasicTransformerBlock",
|
||||
"1.0.0",
|
||||
"Importing `BasicTransformerBlock` from `diffusers.models.attention` is deprecated. Please use `from diffusers.models.transformers.transformer_2d import BasicTransformerBlock` instead.",
|
||||
standard_warn=False,
|
||||
)
|
||||
from .transformers.transformer_2d import BasicTransformerBlock
|
||||
|
||||
# 2. Cross-Attn
|
||||
if cross_attention_dim is not None or double_self_attention:
|
||||
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
|
||||
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
|
||||
# the second cross attention block.
|
||||
if norm_type == "ada_norm":
|
||||
self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
|
||||
elif norm_type == "ada_norm_continuous":
|
||||
self.norm2 = AdaLayerNormContinuous(
|
||||
dim,
|
||||
ada_norm_continous_conditioning_embedding_dim,
|
||||
norm_elementwise_affine,
|
||||
norm_eps,
|
||||
ada_norm_bias,
|
||||
"rms_norm",
|
||||
)
|
||||
else:
|
||||
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
||||
|
||||
self.attn2 = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
upcast_attention=upcast_attention,
|
||||
out_bias=attention_out_bias,
|
||||
) # is self-attn if encoder_hidden_states is none
|
||||
else:
|
||||
if norm_type == "ada_norm_single": # For Latte
|
||||
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
||||
else:
|
||||
self.norm2 = None
|
||||
self.attn2 = None
|
||||
|
||||
# 3. Feed-forward
|
||||
if norm_type == "ada_norm_continuous":
|
||||
self.norm3 = AdaLayerNormContinuous(
|
||||
dim,
|
||||
ada_norm_continous_conditioning_embedding_dim,
|
||||
norm_elementwise_affine,
|
||||
norm_eps,
|
||||
ada_norm_bias,
|
||||
"layer_norm",
|
||||
)
|
||||
|
||||
elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]:
|
||||
self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
||||
elif norm_type == "layer_norm_i2vgen":
|
||||
self.norm3 = None
|
||||
|
||||
self.ff = FeedForward(
|
||||
dim,
|
||||
dropout=dropout,
|
||||
activation_fn=activation_fn,
|
||||
final_dropout=final_dropout,
|
||||
inner_dim=ff_inner_dim,
|
||||
bias=ff_bias,
|
||||
)
|
||||
|
||||
# 4. Fuser
|
||||
if attention_type == "gated" or attention_type == "gated-text-image":
|
||||
self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
|
||||
|
||||
# 5. Scale-shift for PixArt-Alpha.
|
||||
if norm_type == "ada_norm_single":
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
|
||||
|
||||
# let chunk size default to None
|
||||
self._chunk_size = None
|
||||
self._chunk_dim = 0
|
||||
|
||||
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
|
||||
# Sets chunk feed-forward
|
||||
self._chunk_size = chunk_size
|
||||
self._chunk_dim = dim
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
cross_attention_kwargs: Dict[str, Any] = None,
|
||||
class_labels: Optional[torch.LongTensor] = None,
|
||||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
if cross_attention_kwargs is not None:
|
||||
if cross_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
|
||||
|
||||
# Notice that normalization is always applied before the real computation in the following blocks.
|
||||
# 0. Self-Attention
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
if self.norm_type == "ada_norm":
|
||||
norm_hidden_states = self.norm1(hidden_states, timestep)
|
||||
elif self.norm_type == "ada_norm_zero":
|
||||
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
|
||||
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
elif self.norm_type == "ada_norm_continuous":
|
||||
norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
|
||||
elif self.norm_type == "ada_norm_single":
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
||||
self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
|
||||
).chunk(6, dim=1)
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
|
||||
else:
|
||||
raise ValueError("Incorrect norm used")
|
||||
|
||||
if self.pos_embed is not None:
|
||||
norm_hidden_states = self.pos_embed(norm_hidden_states)
|
||||
|
||||
# 1. Prepare GLIGEN inputs
|
||||
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
|
||||
gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
|
||||
|
||||
attn_output = self.attn1(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
|
||||
attention_mask=attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
|
||||
if self.norm_type == "ada_norm_zero":
|
||||
attn_output = gate_msa.unsqueeze(1) * attn_output
|
||||
elif self.norm_type == "ada_norm_single":
|
||||
attn_output = gate_msa * attn_output
|
||||
|
||||
hidden_states = attn_output + hidden_states
|
||||
if hidden_states.ndim == 4:
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
# 1.2 GLIGEN Control
|
||||
if gligen_kwargs is not None:
|
||||
hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
|
||||
|
||||
# 3. Cross-Attention
|
||||
if self.attn2 is not None:
|
||||
if self.norm_type == "ada_norm":
|
||||
norm_hidden_states = self.norm2(hidden_states, timestep)
|
||||
elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
elif self.norm_type == "ada_norm_single":
|
||||
# For PixArt norm2 isn't applied here:
|
||||
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
|
||||
norm_hidden_states = hidden_states
|
||||
elif self.norm_type == "ada_norm_continuous":
|
||||
norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
|
||||
else:
|
||||
raise ValueError("Incorrect norm")
|
||||
|
||||
if self.pos_embed is not None and self.norm_type != "ada_norm_single":
|
||||
norm_hidden_states = self.pos_embed(norm_hidden_states)
|
||||
|
||||
attn_output = self.attn2(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
# 4. Feed-forward
|
||||
# i2vgen doesn't have this norm 🤷♂️
|
||||
if self.norm_type == "ada_norm_continuous":
|
||||
norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
|
||||
elif not self.norm_type == "ada_norm_single":
|
||||
norm_hidden_states = self.norm3(hidden_states)
|
||||
|
||||
if self.norm_type == "ada_norm_zero":
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
||||
|
||||
if self.norm_type == "ada_norm_single":
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
|
||||
|
||||
if self._chunk_size is not None:
|
||||
# "feed_forward_chunk_size" can be used to save memory
|
||||
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
|
||||
else:
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
|
||||
if self.norm_type == "ada_norm_zero":
|
||||
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
||||
elif self.norm_type == "ada_norm_single":
|
||||
ff_output = gate_mlp * ff_output
|
||||
|
||||
hidden_states = ff_output + hidden_states
|
||||
if hidden_states.ndim == 4:
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
return hidden_states
|
||||
return BasicTransformerBlock(*args, **kwargs)
|
||||
|
||||
|
||||
class LuminaFeedForward(nn.Module):
|
||||
@@ -1081,38 +647,16 @@ class LuminaFeedForward(nn.Module):
|
||||
dimension. Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
inner_dim: int,
|
||||
multiple_of: Optional[int] = 256,
|
||||
ffn_dim_multiplier: Optional[float] = None,
|
||||
):
|
||||
super().__init__()
|
||||
# custom hidden_size factor multiplier
|
||||
if ffn_dim_multiplier is not None:
|
||||
inner_dim = int(ffn_dim_multiplier * inner_dim)
|
||||
inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)
|
||||
def __new__(cls, *args, **kwargs):
|
||||
deprecate(
|
||||
"LuminaFeedForward",
|
||||
"1.0.0",
|
||||
"Importing `LuminaFeedForward` from `diffusers.models.attention` is deprecated. Please use `from diffusers.models.transformers.transformer_lumina2 import LuminaFeedForward` instead.",
|
||||
standard_warn=False,
|
||||
)
|
||||
from .transformers.transformer_lumina2 import LuminaFeedForward
|
||||
|
||||
self.linear_1 = nn.Linear(
|
||||
dim,
|
||||
inner_dim,
|
||||
bias=False,
|
||||
)
|
||||
self.linear_2 = nn.Linear(
|
||||
inner_dim,
|
||||
dim,
|
||||
bias=False,
|
||||
)
|
||||
self.linear_3 = nn.Linear(
|
||||
dim,
|
||||
inner_dim,
|
||||
bias=False,
|
||||
)
|
||||
self.silu = FP32SiLU()
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear_2(self.silu(self.linear_1(x)) * self.linear_3(x))
|
||||
return LuminaFeedForward(*args, **kwargs)
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
@@ -1128,193 +672,29 @@ class TemporalBasicTransformerBlock(nn.Module):
|
||||
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
time_mix_inner_dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.is_res = dim == time_mix_inner_dim
|
||||
|
||||
self.norm_in = nn.LayerNorm(dim)
|
||||
|
||||
# Define 3 blocks. Each block has its own normalization layer.
|
||||
# 1. Self-Attn
|
||||
self.ff_in = FeedForward(
|
||||
dim,
|
||||
dim_out=time_mix_inner_dim,
|
||||
activation_fn="geglu",
|
||||
def __new__(cls, *args, **kwargs):
|
||||
deprecate(
|
||||
"TemporalBasicTransformerBlock",
|
||||
"1.0.0",
|
||||
"Importing `TemporalBasicTransformerBlock` from `diffusers.models.attention` is deprecated. Please use `from diffusers.models.transformers.transformer_temporal import TemporalBasicTransformerBlock` instead.",
|
||||
standard_warn=False,
|
||||
)
|
||||
from .transformers.transformer_temporal import TemporalBasicTransformerBlock
|
||||
|
||||
self.norm1 = nn.LayerNorm(time_mix_inner_dim)
|
||||
self.attn1 = Attention(
|
||||
query_dim=time_mix_inner_dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
cross_attention_dim=None,
|
||||
)
|
||||
|
||||
# 2. Cross-Attn
|
||||
if cross_attention_dim is not None:
|
||||
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
|
||||
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
|
||||
# the second cross attention block.
|
||||
self.norm2 = nn.LayerNorm(time_mix_inner_dim)
|
||||
self.attn2 = Attention(
|
||||
query_dim=time_mix_inner_dim,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
) # is self-attn if encoder_hidden_states is none
|
||||
else:
|
||||
self.norm2 = None
|
||||
self.attn2 = None
|
||||
|
||||
# 3. Feed-forward
|
||||
self.norm3 = nn.LayerNorm(time_mix_inner_dim)
|
||||
self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu")
|
||||
|
||||
# let chunk size default to None
|
||||
self._chunk_size = None
|
||||
self._chunk_dim = None
|
||||
|
||||
def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs):
|
||||
# Sets chunk feed-forward
|
||||
self._chunk_size = chunk_size
|
||||
# chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off
|
||||
self._chunk_dim = 1
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
num_frames: int,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
# Notice that normalization is always applied before the real computation in the following blocks.
|
||||
# 0. Self-Attention
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
batch_frames, seq_length, channels = hidden_states.shape
|
||||
batch_size = batch_frames // num_frames
|
||||
|
||||
hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels)
|
||||
hidden_states = hidden_states.permute(0, 2, 1, 3)
|
||||
hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels)
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.norm_in(hidden_states)
|
||||
|
||||
if self._chunk_size is not None:
|
||||
hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size)
|
||||
else:
|
||||
hidden_states = self.ff_in(hidden_states)
|
||||
|
||||
if self.is_res:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
# 3. Cross-Attention
|
||||
if self.attn2 is not None:
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
# 4. Feed-forward
|
||||
norm_hidden_states = self.norm3(hidden_states)
|
||||
|
||||
if self._chunk_size is not None:
|
||||
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
|
||||
else:
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
|
||||
if self.is_res:
|
||||
hidden_states = ff_output + hidden_states
|
||||
else:
|
||||
hidden_states = ff_output
|
||||
|
||||
hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels)
|
||||
hidden_states = hidden_states.permute(0, 2, 1, 3)
|
||||
hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels)
|
||||
|
||||
return hidden_states
|
||||
return TemporalBasicTransformerBlock(*args, **kwargs)
|
||||
|
||||
|
||||
class SkipFFTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
kv_input_dim: int,
|
||||
kv_input_dim_proj_use_bias: bool,
|
||||
dropout=0.0,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
attention_bias: bool = False,
|
||||
attention_out_bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
if kv_input_dim != dim:
|
||||
self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias)
|
||||
else:
|
||||
self.kv_mapper = None
|
||||
|
||||
self.norm1 = RMSNorm(dim, 1e-06)
|
||||
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
out_bias=attention_out_bias,
|
||||
def __new__(cls, *args, **kwargs):
|
||||
deprecate(
|
||||
"SkipFFTransformerBlock",
|
||||
"1.0.0",
|
||||
"Importing `SkipFFTransformerBlock` from `diffusers.models.attention` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.uvit_2d import SkipFFTransformerBlock` instead.",
|
||||
standard_warn=False,
|
||||
)
|
||||
from .unets.uvit_2d import SkipFFTransformerBlock
|
||||
|
||||
self.norm2 = RMSNorm(dim, 1e-06)
|
||||
|
||||
self.attn2 = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
out_bias=attention_out_bias,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs):
|
||||
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
|
||||
|
||||
if self.kv_mapper is not None:
|
||||
encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states))
|
||||
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
|
||||
attn_output = self.attn1(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
|
||||
attn_output = self.attn2(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
return hidden_states
|
||||
return SkipFFTransformerBlock(*args, **kwargs)
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
@@ -1679,50 +1059,13 @@ class FeedForward(nn.Module):
|
||||
bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
dim_out: Optional[int] = None,
|
||||
mult: int = 4,
|
||||
dropout: float = 0.0,
|
||||
activation_fn: str = "geglu",
|
||||
final_dropout: bool = False,
|
||||
inner_dim=None,
|
||||
bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
if inner_dim is None:
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = dim_out if dim_out is not None else dim
|
||||
def __new__(cls, *args, **kwargs):
|
||||
deprecate(
|
||||
"FeedForward",
|
||||
"1.0.0",
|
||||
"Importing `FeedForward` from `diffusers.models.attention` is deprecated. Please use `from diffusers.models.transformers.modeling_common import FeedForward` instead.",
|
||||
standard_warn=False,
|
||||
)
|
||||
from .transformers.modeling_common import FeedForward
|
||||
|
||||
if activation_fn == "gelu":
|
||||
act_fn = GELU(dim, inner_dim, bias=bias)
|
||||
if activation_fn == "gelu-approximate":
|
||||
act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
|
||||
elif activation_fn == "geglu":
|
||||
act_fn = GEGLU(dim, inner_dim, bias=bias)
|
||||
elif activation_fn == "geglu-approximate":
|
||||
act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
|
||||
elif activation_fn == "swiglu":
|
||||
act_fn = SwiGLU(dim, inner_dim, bias=bias)
|
||||
elif activation_fn == "linear-silu":
|
||||
act_fn = LinearActivation(dim, inner_dim, bias=bias, activation="silu")
|
||||
|
||||
self.net = nn.ModuleList([])
|
||||
# project in
|
||||
self.net.append(act_fn)
|
||||
# project dropout
|
||||
self.net.append(nn.Dropout(dropout))
|
||||
# project out
|
||||
self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
|
||||
# FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
|
||||
if final_dropout:
|
||||
self.net.append(nn.Dropout(dropout))
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
for module in self.net:
|
||||
hidden_states = module(hidden_states)
|
||||
return hidden_states
|
||||
return FeedForward(*args, **kwargs)
|
||||
|
||||
@@ -24,8 +24,8 @@ from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, u
|
||||
from ..attention_processor import AttentionProcessor
|
||||
from ..controlnets.controlnet import ControlNetConditioningEmbedding, zero_module
|
||||
from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..transformers.modeling_common import Transformer2DModelOutput
|
||||
from ..transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
|
||||
|
||||
|
||||
|
||||
@@ -24,8 +24,8 @@ from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, u
|
||||
from ..attention_processor import AttentionProcessor
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..controlnets.controlnet import zero_module
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..transformers.modeling_common import Transformer2DModelOutput
|
||||
from ..transformers.transformer_qwenimage import (
|
||||
QwenEmbedRope,
|
||||
QwenImageTransformerBlock,
|
||||
|
||||
@@ -23,9 +23,9 @@ from ...loaders import PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention_processor import AttentionProcessor
|
||||
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormSingle, RMSNorm
|
||||
from ..transformers.modeling_common import Transformer2DModelOutput
|
||||
from ..transformers.sana_transformer import SanaTransformerBlock
|
||||
from .controlnet import zero_module
|
||||
|
||||
|
||||
@@ -22,12 +22,11 @@ import torch.nn as nn
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention import JointTransformerBlock
|
||||
from ..attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
|
||||
from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..transformers.transformer_sd3 import SD3SingleTransformerBlock
|
||||
from ..transformers.modeling_common import Transformer2DModelOutput
|
||||
from ..transformers.transformer_sd3 import SD3SingleTransformerBlock, SD3TransformerBlock
|
||||
from .controlnet import BaseOutput, zero_module
|
||||
|
||||
|
||||
@@ -132,7 +131,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
|
||||
# It needs to crafted when we get the actual checkpoints.
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
JointTransformerBlock(
|
||||
SD3TransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
|
||||
@@ -1530,7 +1530,7 @@ class ImageProjection(nn.Module):
|
||||
class IPAdapterFullImageProjection(nn.Module):
|
||||
def __init__(self, image_embed_dim=1024, cross_attention_dim=1024):
|
||||
super().__init__()
|
||||
from .attention import FeedForward
|
||||
from .transformers.modeling_common import FeedForward
|
||||
|
||||
self.ff = FeedForward(image_embed_dim, cross_attention_dim, mult=1, activation_fn="gelu")
|
||||
self.norm = nn.LayerNorm(cross_attention_dim)
|
||||
@@ -1542,7 +1542,7 @@ class IPAdapterFullImageProjection(nn.Module):
|
||||
class IPAdapterFaceIDImageProjection(nn.Module):
|
||||
def __init__(self, image_embed_dim=1024, cross_attention_dim=1024, mult=1, num_tokens=1):
|
||||
super().__init__()
|
||||
from .attention import FeedForward
|
||||
from .transformers.modeling_common import FeedForward
|
||||
|
||||
self.num_tokens = num_tokens
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
@@ -2219,7 +2219,7 @@ class IPAdapterPlusImageProjectionBlock(nn.Module):
|
||||
ffn_ratio: float = 4,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
from .attention import FeedForward
|
||||
from .transformers.modeling_common import FeedForward
|
||||
|
||||
self.ln0 = nn.LayerNorm(embed_dims)
|
||||
self.ln1 = nn.LayerNorm(embed_dims)
|
||||
@@ -2334,7 +2334,7 @@ class IPAdapterFaceIDPlusImageProjection(nn.Module):
|
||||
ffproj_ratio: int = 2,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
from .attention import FeedForward
|
||||
from .transformers.modeling_common import FeedForward
|
||||
|
||||
self.num_tokens = num_tokens
|
||||
self.embed_dim = embed_dims
|
||||
@@ -2404,7 +2404,7 @@ class IPAdapterTimeImageProjectionBlock(nn.Module):
|
||||
ffn_ratio: int = 4,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
from .attention import FeedForward
|
||||
from .transformers.modeling_common import FeedForward
|
||||
|
||||
self.ln0 = nn.LayerNorm(hidden_dim)
|
||||
self.ln1 = nn.LayerNorm(hidden_dim)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..utils import BaseOutput
|
||||
from ..utils import BaseOutput, deprecate
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -17,8 +17,7 @@ class AutoencoderKLOutput(BaseOutput):
|
||||
latent_dist: "DiagonalGaussianDistribution" # noqa: F821
|
||||
|
||||
|
||||
@dataclass
|
||||
class Transformer2DModelOutput(BaseOutput):
|
||||
class Transformer2DModelOutput:
|
||||
"""
|
||||
The output of [`Transformer2DModel`].
|
||||
|
||||
@@ -28,4 +27,13 @@ class Transformer2DModelOutput(BaseOutput):
|
||||
distributions for the unnoised latent pixels.
|
||||
"""
|
||||
|
||||
sample: "torch.Tensor" # noqa: F821
|
||||
def __new__(cls, *args, **kwargs):
|
||||
deprecate(
|
||||
"Transformer2DModelOutput",
|
||||
"1.0.0",
|
||||
"Importing `Transformer2DModelOutput` from `diffusers.models.modeling_outputs` is deprecated. Please use `from diffusers.models.transformers.modeling_common import Transformer2DModelOutput` instead.",
|
||||
standard_warn=False,
|
||||
)
|
||||
from .transformers.modeling_common import Transformer2DModelOutput
|
||||
|
||||
return Transformer2DModelOutput(*args, **kwargs)
|
||||
|
||||
@@ -30,9 +30,9 @@ from ..attention_processor import (
|
||||
FusedAuraFlowAttnProcessor2_0,
|
||||
)
|
||||
from ..embeddings import TimestepEmbedding, Timesteps
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormZero, FP32LayerNorm
|
||||
from .modeling_common import Transformer2DModelOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -194,7 +194,8 @@ class AuraFlowSingleTransformerBlock(nn.Module):
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class AuraFlowJointTransformerBlock(nn.Module):
|
||||
# Copied from diffusers.models.transformers.transformer_sd3.SD3TransformerBlock with SD3->AuraFlow
|
||||
class AuraFlowTransformerBlock(nn.Module):
|
||||
r"""
|
||||
Transformer block for Aura Flow. Similar to SD3 MMDiT. Differences (non-exhaustive):
|
||||
|
||||
@@ -337,7 +338,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
|
||||
|
||||
self.joint_transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
AuraFlowJointTransformerBlock(
|
||||
AuraFlowTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=self.config.num_attention_heads,
|
||||
attention_head_dim=self.config.attention_head_dim,
|
||||
|
||||
@@ -22,13 +22,13 @@ from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import Attention, FeedForward
|
||||
from ..attention import Attention
|
||||
from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero
|
||||
from .modeling_common import FeedForward, Transformer2DModelOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -22,12 +22,12 @@ from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import Attention, FeedForward
|
||||
from ..attention import Attention
|
||||
from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0
|
||||
from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero
|
||||
from .modeling_common import FeedForward, Transformer2DModelOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -19,15 +19,348 @@ from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import logging
|
||||
from ..attention import BasicTransformerBlock
|
||||
from ..embeddings import PatchEmbed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..attention import Attention, GatedSelfAttentionDense
|
||||
from ..embeddings import PatchEmbed, SinusoidalPositionalEmbedding
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero
|
||||
from .modeling_common import FeedForward, Transformer2DModelOutput, _chunked_feed_forward
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# Copied from diffusers.models.transformers.transformer_2d.BasicTransformerBlock
|
||||
class DiTTransformerBlock(nn.Module):
|
||||
r"""
|
||||
A basic Transformer block.
|
||||
|
||||
Parameters:
|
||||
dim (`int`): The number of channels in the input and output.
|
||||
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`): The number of channels in each head.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
||||
num_embeds_ada_norm (:
|
||||
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
|
||||
attention_bias (:
|
||||
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
|
||||
only_cross_attention (`bool`, *optional*):
|
||||
Whether to use only cross-attention layers. In this case two cross attention layers are used.
|
||||
double_self_attention (`bool`, *optional*):
|
||||
Whether to use two self-attention layers. In this case no cross attention layers are used.
|
||||
upcast_attention (`bool`, *optional*):
|
||||
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
|
||||
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use learnable elementwise affine parameters for normalization.
|
||||
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
|
||||
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
|
||||
final_dropout (`bool` *optional*, defaults to False):
|
||||
Whether to apply a final dropout after the last feed-forward layer.
|
||||
attention_type (`str`, *optional*, defaults to `"default"`):
|
||||
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
|
||||
positional_embeddings (`str`, *optional*, defaults to `None`):
|
||||
The type of positional embeddings to apply to.
|
||||
num_positional_embeddings (`int`, *optional*, defaults to `None`):
|
||||
The maximum number of positional embeddings to apply.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
dropout=0.0,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
activation_fn: str = "geglu",
|
||||
num_embeds_ada_norm: Optional[int] = None,
|
||||
attention_bias: bool = False,
|
||||
only_cross_attention: bool = False,
|
||||
double_self_attention: bool = False,
|
||||
upcast_attention: bool = False,
|
||||
norm_elementwise_affine: bool = True,
|
||||
norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
|
||||
norm_eps: float = 1e-5,
|
||||
final_dropout: bool = False,
|
||||
attention_type: str = "default",
|
||||
positional_embeddings: Optional[str] = None,
|
||||
num_positional_embeddings: Optional[int] = None,
|
||||
ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
|
||||
ada_norm_bias: Optional[int] = None,
|
||||
ff_inner_dim: Optional[int] = None,
|
||||
ff_bias: bool = True,
|
||||
attention_out_bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
self.dropout = dropout
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
self.activation_fn = activation_fn
|
||||
self.attention_bias = attention_bias
|
||||
self.double_self_attention = double_self_attention
|
||||
self.norm_elementwise_affine = norm_elementwise_affine
|
||||
self.positional_embeddings = positional_embeddings
|
||||
self.num_positional_embeddings = num_positional_embeddings
|
||||
self.only_cross_attention = only_cross_attention
|
||||
|
||||
# We keep these boolean flags for backward-compatibility.
|
||||
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
|
||||
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
|
||||
self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
|
||||
self.use_layer_norm = norm_type == "layer_norm"
|
||||
self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
|
||||
|
||||
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
|
||||
raise ValueError(
|
||||
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
|
||||
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
|
||||
)
|
||||
|
||||
self.norm_type = norm_type
|
||||
self.num_embeds_ada_norm = num_embeds_ada_norm
|
||||
|
||||
if positional_embeddings and (num_positional_embeddings is None):
|
||||
raise ValueError(
|
||||
"If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
|
||||
)
|
||||
|
||||
if positional_embeddings == "sinusoidal":
|
||||
self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
|
||||
else:
|
||||
self.pos_embed = None
|
||||
|
||||
# Define 3 blocks. Each block has its own normalization layer.
|
||||
# 1. Self-Attn
|
||||
if norm_type == "ada_norm":
|
||||
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
|
||||
elif norm_type == "ada_norm_zero":
|
||||
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
|
||||
elif norm_type == "ada_norm_continuous":
|
||||
self.norm1 = AdaLayerNormContinuous(
|
||||
dim,
|
||||
ada_norm_continous_conditioning_embedding_dim,
|
||||
norm_elementwise_affine,
|
||||
norm_eps,
|
||||
ada_norm_bias,
|
||||
"rms_norm",
|
||||
)
|
||||
else:
|
||||
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
|
||||
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
|
||||
upcast_attention=upcast_attention,
|
||||
out_bias=attention_out_bias,
|
||||
)
|
||||
|
||||
# 2. Cross-Attn
|
||||
if cross_attention_dim is not None or double_self_attention:
|
||||
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
|
||||
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
|
||||
# the second cross attention block.
|
||||
if norm_type == "ada_norm":
|
||||
self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
|
||||
elif norm_type == "ada_norm_continuous":
|
||||
self.norm2 = AdaLayerNormContinuous(
|
||||
dim,
|
||||
ada_norm_continous_conditioning_embedding_dim,
|
||||
norm_elementwise_affine,
|
||||
norm_eps,
|
||||
ada_norm_bias,
|
||||
"rms_norm",
|
||||
)
|
||||
else:
|
||||
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
||||
|
||||
self.attn2 = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
upcast_attention=upcast_attention,
|
||||
out_bias=attention_out_bias,
|
||||
) # is self-attn if encoder_hidden_states is none
|
||||
else:
|
||||
if norm_type == "ada_norm_single":
|
||||
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
||||
else:
|
||||
self.norm2 = None
|
||||
self.attn2 = None
|
||||
|
||||
# 3. Feed-forward
|
||||
if norm_type == "ada_norm_continuous":
|
||||
self.norm3 = AdaLayerNormContinuous(
|
||||
dim,
|
||||
ada_norm_continous_conditioning_embedding_dim,
|
||||
norm_elementwise_affine,
|
||||
norm_eps,
|
||||
ada_norm_bias,
|
||||
"layer_norm",
|
||||
)
|
||||
|
||||
elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]:
|
||||
self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
||||
elif norm_type == "layer_norm_i2vgen":
|
||||
self.norm3 = None
|
||||
|
||||
self.ff = FeedForward(
|
||||
dim,
|
||||
dropout=dropout,
|
||||
activation_fn=activation_fn,
|
||||
final_dropout=final_dropout,
|
||||
inner_dim=ff_inner_dim,
|
||||
bias=ff_bias,
|
||||
)
|
||||
|
||||
# 4. Fuser
|
||||
if attention_type == "gated" or attention_type == "gated-text-image":
|
||||
self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
|
||||
|
||||
# 5. Scale-shift for PixArt-Alpha.
|
||||
if norm_type == "ada_norm_single":
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
|
||||
|
||||
# let chunk size default to None
|
||||
self._chunk_size = None
|
||||
self._chunk_dim = 0
|
||||
|
||||
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
|
||||
# Sets chunk feed-forward
|
||||
self._chunk_size = chunk_size
|
||||
self._chunk_dim = dim
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
cross_attention_kwargs: Dict[str, Any] = None,
|
||||
class_labels: Optional[torch.LongTensor] = None,
|
||||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
if cross_attention_kwargs is not None:
|
||||
if cross_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
|
||||
|
||||
# Notice that normalization is always applied before the real computation in the following blocks.
|
||||
# 0. Self-Attention
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
if self.norm_type == "ada_norm":
|
||||
norm_hidden_states = self.norm1(hidden_states, timestep)
|
||||
elif self.norm_type == "ada_norm_zero":
|
||||
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
|
||||
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
elif self.norm_type == "ada_norm_continuous":
|
||||
norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
|
||||
elif self.norm_type == "ada_norm_single":
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
||||
self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
|
||||
).chunk(6, dim=1)
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
|
||||
else:
|
||||
raise ValueError("Incorrect norm used")
|
||||
|
||||
if self.pos_embed is not None:
|
||||
norm_hidden_states = self.pos_embed(norm_hidden_states)
|
||||
|
||||
# 1. Prepare GLIGEN inputs
|
||||
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
|
||||
gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
|
||||
|
||||
attn_output = self.attn1(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
|
||||
attention_mask=attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
|
||||
if self.norm_type == "ada_norm_zero":
|
||||
attn_output = gate_msa.unsqueeze(1) * attn_output
|
||||
elif self.norm_type == "ada_norm_single":
|
||||
attn_output = gate_msa * attn_output
|
||||
|
||||
hidden_states = attn_output + hidden_states
|
||||
if hidden_states.ndim == 4:
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
# 1.2 GLIGEN Control
|
||||
if gligen_kwargs is not None:
|
||||
hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
|
||||
|
||||
# 3. Cross-Attention
|
||||
if self.attn2 is not None:
|
||||
if self.norm_type == "ada_norm":
|
||||
norm_hidden_states = self.norm2(hidden_states, timestep)
|
||||
elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
elif self.norm_type == "ada_norm_single":
|
||||
# For PixArt norm2 isn't applied here:
|
||||
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
|
||||
norm_hidden_states = hidden_states
|
||||
elif self.norm_type == "ada_norm_continuous":
|
||||
norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
|
||||
else:
|
||||
raise ValueError("Incorrect norm")
|
||||
|
||||
if self.pos_embed is not None and self.norm_type != "ada_norm_single":
|
||||
norm_hidden_states = self.pos_embed(norm_hidden_states)
|
||||
|
||||
attn_output = self.attn2(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
# 4. Feed-forward
|
||||
# i2vgen doesn't have this norm 🤷♂️
|
||||
if self.norm_type == "ada_norm_continuous":
|
||||
norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
|
||||
elif not self.norm_type == "ada_norm_single":
|
||||
norm_hidden_states = self.norm3(hidden_states)
|
||||
|
||||
if self.norm_type == "ada_norm_zero":
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
||||
|
||||
if self.norm_type == "ada_norm_single":
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
|
||||
|
||||
if self._chunk_size is not None:
|
||||
# "feed_forward_chunk_size" can be used to save memory
|
||||
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
|
||||
else:
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
|
||||
if self.norm_type == "ada_norm_zero":
|
||||
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
||||
elif self.norm_type == "ada_norm_single":
|
||||
ff_output = gate_mlp * ff_output
|
||||
|
||||
hidden_states = ff_output + hidden_states
|
||||
if hidden_states.ndim == 4:
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DiTTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
A 2D Transformer model as introduced in DiT (https://huggingface.co/papers/2212.09748).
|
||||
@@ -121,7 +454,7 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
DiTTransformerBlock(
|
||||
self.inner_dim,
|
||||
self.config.num_attention_heads,
|
||||
self.config.attention_head_dim,
|
||||
|
||||
@@ -15,7 +15,7 @@ from typing import Optional
|
||||
|
||||
from torch import nn
|
||||
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from .modeling_common import Transformer2DModelOutput
|
||||
from .transformer_2d import Transformer2DModel
|
||||
|
||||
|
||||
|
||||
@@ -19,16 +19,15 @@ from torch import nn
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import FeedForward
|
||||
from ..attention_processor import Attention, AttentionProcessor, FusedHunyuanAttnProcessor2_0, HunyuanAttnProcessor2_0
|
||||
from ..embeddings import (
|
||||
HunyuanCombinedTimestepTextSizeStyleEmbedding,
|
||||
PatchEmbed,
|
||||
PixArtAlphaTextProjection,
|
||||
)
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormContinuous, FP32LayerNorm
|
||||
from .modeling_common import FeedForward, Transformer2DModelOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -12,18 +12,359 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ..attention import BasicTransformerBlock
|
||||
from ...utils import logging
|
||||
from ..attention import Attention, GatedSelfAttentionDense
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..embeddings import (
|
||||
PatchEmbed,
|
||||
PixArtAlphaTextProjection,
|
||||
SinusoidalPositionalEmbedding,
|
||||
get_1d_sincos_pos_embed_from_grid,
|
||||
)
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormSingle
|
||||
from ..normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormSingle, AdaLayerNormZero
|
||||
from .modeling_common import FeedForward, Transformer2DModelOutput, _chunked_feed_forward
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# Copied from diffusers.models.transformers.transformer_2d.BasicTransformerBlock
|
||||
class LatteTransformerBlock(nn.Module):
|
||||
r"""
|
||||
A basic Transformer block.
|
||||
|
||||
Parameters:
|
||||
dim (`int`): The number of channels in the input and output.
|
||||
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`): The number of channels in each head.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
||||
num_embeds_ada_norm (:
|
||||
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
|
||||
attention_bias (:
|
||||
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
|
||||
only_cross_attention (`bool`, *optional*):
|
||||
Whether to use only cross-attention layers. In this case two cross attention layers are used.
|
||||
double_self_attention (`bool`, *optional*):
|
||||
Whether to use two self-attention layers. In this case no cross attention layers are used.
|
||||
upcast_attention (`bool`, *optional*):
|
||||
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
|
||||
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use learnable elementwise affine parameters for normalization.
|
||||
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
|
||||
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
|
||||
final_dropout (`bool` *optional*, defaults to False):
|
||||
Whether to apply a final dropout after the last feed-forward layer.
|
||||
attention_type (`str`, *optional*, defaults to `"default"`):
|
||||
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
|
||||
positional_embeddings (`str`, *optional*, defaults to `None`):
|
||||
The type of positional embeddings to apply to.
|
||||
num_positional_embeddings (`int`, *optional*, defaults to `None`):
|
||||
The maximum number of positional embeddings to apply.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
dropout=0.0,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
activation_fn: str = "geglu",
|
||||
num_embeds_ada_norm: Optional[int] = None,
|
||||
attention_bias: bool = False,
|
||||
only_cross_attention: bool = False,
|
||||
double_self_attention: bool = False,
|
||||
upcast_attention: bool = False,
|
||||
norm_elementwise_affine: bool = True,
|
||||
norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
|
||||
norm_eps: float = 1e-5,
|
||||
final_dropout: bool = False,
|
||||
attention_type: str = "default",
|
||||
positional_embeddings: Optional[str] = None,
|
||||
num_positional_embeddings: Optional[int] = None,
|
||||
ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
|
||||
ada_norm_bias: Optional[int] = None,
|
||||
ff_inner_dim: Optional[int] = None,
|
||||
ff_bias: bool = True,
|
||||
attention_out_bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
self.dropout = dropout
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
self.activation_fn = activation_fn
|
||||
self.attention_bias = attention_bias
|
||||
self.double_self_attention = double_self_attention
|
||||
self.norm_elementwise_affine = norm_elementwise_affine
|
||||
self.positional_embeddings = positional_embeddings
|
||||
self.num_positional_embeddings = num_positional_embeddings
|
||||
self.only_cross_attention = only_cross_attention
|
||||
|
||||
# We keep these boolean flags for backward-compatibility.
|
||||
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
|
||||
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
|
||||
self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
|
||||
self.use_layer_norm = norm_type == "layer_norm"
|
||||
self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
|
||||
|
||||
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
|
||||
raise ValueError(
|
||||
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
|
||||
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
|
||||
)
|
||||
|
||||
self.norm_type = norm_type
|
||||
self.num_embeds_ada_norm = num_embeds_ada_norm
|
||||
|
||||
if positional_embeddings and (num_positional_embeddings is None):
|
||||
raise ValueError(
|
||||
"If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
|
||||
)
|
||||
|
||||
if positional_embeddings == "sinusoidal":
|
||||
self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
|
||||
else:
|
||||
self.pos_embed = None
|
||||
|
||||
# Define 3 blocks. Each block has its own normalization layer.
|
||||
# 1. Self-Attn
|
||||
if norm_type == "ada_norm":
|
||||
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
|
||||
elif norm_type == "ada_norm_zero":
|
||||
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
|
||||
elif norm_type == "ada_norm_continuous":
|
||||
self.norm1 = AdaLayerNormContinuous(
|
||||
dim,
|
||||
ada_norm_continous_conditioning_embedding_dim,
|
||||
norm_elementwise_affine,
|
||||
norm_eps,
|
||||
ada_norm_bias,
|
||||
"rms_norm",
|
||||
)
|
||||
else:
|
||||
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
|
||||
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
|
||||
upcast_attention=upcast_attention,
|
||||
out_bias=attention_out_bias,
|
||||
)
|
||||
|
||||
# 2. Cross-Attn
|
||||
if cross_attention_dim is not None or double_self_attention:
|
||||
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
|
||||
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
|
||||
# the second cross attention block.
|
||||
if norm_type == "ada_norm":
|
||||
self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
|
||||
elif norm_type == "ada_norm_continuous":
|
||||
self.norm2 = AdaLayerNormContinuous(
|
||||
dim,
|
||||
ada_norm_continous_conditioning_embedding_dim,
|
||||
norm_elementwise_affine,
|
||||
norm_eps,
|
||||
ada_norm_bias,
|
||||
"rms_norm",
|
||||
)
|
||||
else:
|
||||
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
||||
|
||||
self.attn2 = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
upcast_attention=upcast_attention,
|
||||
out_bias=attention_out_bias,
|
||||
) # is self-attn if encoder_hidden_states is none
|
||||
else:
|
||||
if norm_type == "ada_norm_single":
|
||||
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
||||
else:
|
||||
self.norm2 = None
|
||||
self.attn2 = None
|
||||
|
||||
# 3. Feed-forward
|
||||
if norm_type == "ada_norm_continuous":
|
||||
self.norm3 = AdaLayerNormContinuous(
|
||||
dim,
|
||||
ada_norm_continous_conditioning_embedding_dim,
|
||||
norm_elementwise_affine,
|
||||
norm_eps,
|
||||
ada_norm_bias,
|
||||
"layer_norm",
|
||||
)
|
||||
|
||||
elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]:
|
||||
self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
||||
elif norm_type == "layer_norm_i2vgen":
|
||||
self.norm3 = None
|
||||
|
||||
self.ff = FeedForward(
|
||||
dim,
|
||||
dropout=dropout,
|
||||
activation_fn=activation_fn,
|
||||
final_dropout=final_dropout,
|
||||
inner_dim=ff_inner_dim,
|
||||
bias=ff_bias,
|
||||
)
|
||||
|
||||
# 4. Fuser
|
||||
if attention_type == "gated" or attention_type == "gated-text-image":
|
||||
self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
|
||||
|
||||
# 5. Scale-shift for PixArt-Alpha.
|
||||
if norm_type == "ada_norm_single":
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
|
||||
|
||||
# let chunk size default to None
|
||||
self._chunk_size = None
|
||||
self._chunk_dim = 0
|
||||
|
||||
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
|
||||
# Sets chunk feed-forward
|
||||
self._chunk_size = chunk_size
|
||||
self._chunk_dim = dim
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
cross_attention_kwargs: Dict[str, Any] = None,
|
||||
class_labels: Optional[torch.LongTensor] = None,
|
||||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
if cross_attention_kwargs is not None:
|
||||
if cross_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
|
||||
|
||||
# Notice that normalization is always applied before the real computation in the following blocks.
|
||||
# 0. Self-Attention
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
if self.norm_type == "ada_norm":
|
||||
norm_hidden_states = self.norm1(hidden_states, timestep)
|
||||
elif self.norm_type == "ada_norm_zero":
|
||||
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
|
||||
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
elif self.norm_type == "ada_norm_continuous":
|
||||
norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
|
||||
elif self.norm_type == "ada_norm_single":
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
||||
self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
|
||||
).chunk(6, dim=1)
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
|
||||
else:
|
||||
raise ValueError("Incorrect norm used")
|
||||
|
||||
if self.pos_embed is not None:
|
||||
norm_hidden_states = self.pos_embed(norm_hidden_states)
|
||||
|
||||
# 1. Prepare GLIGEN inputs
|
||||
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
|
||||
gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
|
||||
|
||||
attn_output = self.attn1(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
|
||||
attention_mask=attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
|
||||
if self.norm_type == "ada_norm_zero":
|
||||
attn_output = gate_msa.unsqueeze(1) * attn_output
|
||||
elif self.norm_type == "ada_norm_single":
|
||||
attn_output = gate_msa * attn_output
|
||||
|
||||
hidden_states = attn_output + hidden_states
|
||||
if hidden_states.ndim == 4:
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
# 1.2 GLIGEN Control
|
||||
if gligen_kwargs is not None:
|
||||
hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
|
||||
|
||||
# 3. Cross-Attention
|
||||
if self.attn2 is not None:
|
||||
if self.norm_type == "ada_norm":
|
||||
norm_hidden_states = self.norm2(hidden_states, timestep)
|
||||
elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
elif self.norm_type == "ada_norm_single":
|
||||
# For PixArt norm2 isn't applied here:
|
||||
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
|
||||
norm_hidden_states = hidden_states
|
||||
elif self.norm_type == "ada_norm_continuous":
|
||||
norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
|
||||
else:
|
||||
raise ValueError("Incorrect norm")
|
||||
|
||||
if self.pos_embed is not None and self.norm_type != "ada_norm_single":
|
||||
norm_hidden_states = self.pos_embed(norm_hidden_states)
|
||||
|
||||
attn_output = self.attn2(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
# 4. Feed-forward
|
||||
# i2vgen doesn't have this norm 🤷♂️
|
||||
if self.norm_type == "ada_norm_continuous":
|
||||
norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
|
||||
elif not self.norm_type == "ada_norm_single":
|
||||
norm_hidden_states = self.norm3(hidden_states)
|
||||
|
||||
if self.norm_type == "ada_norm_zero":
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
||||
|
||||
if self.norm_type == "ada_norm_single":
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
|
||||
|
||||
if self._chunk_size is not None:
|
||||
# "feed_forward_chunk_size" can be used to save memory
|
||||
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
|
||||
else:
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
|
||||
if self.norm_type == "ada_norm_zero":
|
||||
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
||||
elif self.norm_type == "ada_norm_single":
|
||||
ff_output = gate_mlp * ff_output
|
||||
|
||||
hidden_states = ff_output + hidden_states
|
||||
if hidden_states.ndim == 4:
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
|
||||
@@ -110,7 +451,7 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
|
||||
# 2. Define spatial transformers blocks
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
LatteTransformerBlock(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
@@ -130,7 +471,7 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
|
||||
# 3. Define temporal transformers blocks
|
||||
self.temporal_transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
LatteTransformerBlock(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
|
||||
@@ -19,15 +19,15 @@ import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import logging
|
||||
from ..attention import LuminaFeedForward
|
||||
from ..attention_processor import Attention, LuminaAttnProcessor2_0
|
||||
from ..embeddings import (
|
||||
LuminaCombinedTimestepCaptionEmbedding,
|
||||
LuminaPatchEmbed,
|
||||
)
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import LuminaLayerNormContinuous, LuminaRMSNormZero, RMSNorm
|
||||
from .modeling_common import Transformer2DModelOutput
|
||||
from .transformer_lumina2 import LuminaFeedForward
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
114
src/diffusers/models/transformers/modeling_common.py
Normal file
114
src/diffusers/models/transformers/modeling_common.py
Normal file
@@ -0,0 +1,114 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ...utils import BaseOutput, deprecate
|
||||
from ..activations import GEGLU, GELU, ApproximateGELU, LinearActivation, SwiGLU
|
||||
|
||||
|
||||
def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
|
||||
# "feed_forward_chunk_size" can be used to save memory
|
||||
if hidden_states.shape[chunk_dim] % chunk_size != 0:
|
||||
raise ValueError(
|
||||
f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
|
||||
)
|
||||
|
||||
num_chunks = hidden_states.shape[chunk_dim] // chunk_size
|
||||
ff_output = torch.cat(
|
||||
[ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
|
||||
dim=chunk_dim,
|
||||
)
|
||||
return ff_output
|
||||
|
||||
|
||||
@dataclass
|
||||
class Transformer2DModelOutput(BaseOutput):
|
||||
"""
|
||||
The output of [`Transformer2DModel`].
|
||||
|
||||
Args:
|
||||
sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
|
||||
The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
|
||||
distributions for the unnoised latent pixels.
|
||||
"""
|
||||
|
||||
sample: "torch.Tensor" # noqa: F821
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
r"""
|
||||
A feed-forward layer.
|
||||
|
||||
Parameters:
|
||||
dim (`int`): The number of channels in the input.
|
||||
dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
|
||||
mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
||||
final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
|
||||
bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
dim_out: Optional[int] = None,
|
||||
mult: int = 4,
|
||||
dropout: float = 0.0,
|
||||
activation_fn: str = "geglu",
|
||||
final_dropout: bool = False,
|
||||
inner_dim=None,
|
||||
bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
if inner_dim is None:
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = dim_out if dim_out is not None else dim
|
||||
|
||||
if activation_fn == "gelu":
|
||||
act_fn = GELU(dim, inner_dim, bias=bias)
|
||||
if activation_fn == "gelu-approximate":
|
||||
act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
|
||||
elif activation_fn == "geglu":
|
||||
act_fn = GEGLU(dim, inner_dim, bias=bias)
|
||||
elif activation_fn == "geglu-approximate":
|
||||
act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
|
||||
elif activation_fn == "swiglu":
|
||||
act_fn = SwiGLU(dim, inner_dim, bias=bias)
|
||||
elif activation_fn == "linear-silu":
|
||||
act_fn = LinearActivation(dim, inner_dim, bias=bias, activation="silu")
|
||||
|
||||
self.net = nn.ModuleList([])
|
||||
# project in
|
||||
self.net.append(act_fn)
|
||||
# project dropout
|
||||
self.net.append(nn.Dropout(dropout))
|
||||
# project out
|
||||
self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
|
||||
# FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
|
||||
if final_dropout:
|
||||
self.net.append(nn.Dropout(dropout))
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
for module in self.net:
|
||||
hidden_states = module(hidden_states)
|
||||
return hidden_states
|
||||
@@ -18,17 +18,349 @@ from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import logging
|
||||
from ..attention import BasicTransformerBlock
|
||||
from ..attention import GatedSelfAttentionDense
|
||||
from ..attention_processor import Attention, AttentionProcessor, AttnProcessor, FusedAttnProcessor2_0
|
||||
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, SinusoidalPositionalEmbedding
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormSingle
|
||||
from ..normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormSingle, AdaLayerNormZero
|
||||
from .modeling_common import FeedForward, Transformer2DModelOutput, _chunked_feed_forward
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# Copied from diffusers.models.transformers.transformer_2d.BasicTransformerBlock
|
||||
class PixArtTransformerBlock(nn.Module):
|
||||
r"""
|
||||
A basic Transformer block.
|
||||
|
||||
Parameters:
|
||||
dim (`int`): The number of channels in the input and output.
|
||||
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`): The number of channels in each head.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
||||
num_embeds_ada_norm (:
|
||||
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
|
||||
attention_bias (:
|
||||
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
|
||||
only_cross_attention (`bool`, *optional*):
|
||||
Whether to use only cross-attention layers. In this case two cross attention layers are used.
|
||||
double_self_attention (`bool`, *optional*):
|
||||
Whether to use two self-attention layers. In this case no cross attention layers are used.
|
||||
upcast_attention (`bool`, *optional*):
|
||||
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
|
||||
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use learnable elementwise affine parameters for normalization.
|
||||
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
|
||||
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
|
||||
final_dropout (`bool` *optional*, defaults to False):
|
||||
Whether to apply a final dropout after the last feed-forward layer.
|
||||
attention_type (`str`, *optional*, defaults to `"default"`):
|
||||
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
|
||||
positional_embeddings (`str`, *optional*, defaults to `None`):
|
||||
The type of positional embeddings to apply to.
|
||||
num_positional_embeddings (`int`, *optional*, defaults to `None`):
|
||||
The maximum number of positional embeddings to apply.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
dropout=0.0,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
activation_fn: str = "geglu",
|
||||
num_embeds_ada_norm: Optional[int] = None,
|
||||
attention_bias: bool = False,
|
||||
only_cross_attention: bool = False,
|
||||
double_self_attention: bool = False,
|
||||
upcast_attention: bool = False,
|
||||
norm_elementwise_affine: bool = True,
|
||||
norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
|
||||
norm_eps: float = 1e-5,
|
||||
final_dropout: bool = False,
|
||||
attention_type: str = "default",
|
||||
positional_embeddings: Optional[str] = None,
|
||||
num_positional_embeddings: Optional[int] = None,
|
||||
ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
|
||||
ada_norm_bias: Optional[int] = None,
|
||||
ff_inner_dim: Optional[int] = None,
|
||||
ff_bias: bool = True,
|
||||
attention_out_bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
self.dropout = dropout
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
self.activation_fn = activation_fn
|
||||
self.attention_bias = attention_bias
|
||||
self.double_self_attention = double_self_attention
|
||||
self.norm_elementwise_affine = norm_elementwise_affine
|
||||
self.positional_embeddings = positional_embeddings
|
||||
self.num_positional_embeddings = num_positional_embeddings
|
||||
self.only_cross_attention = only_cross_attention
|
||||
|
||||
# We keep these boolean flags for backward-compatibility.
|
||||
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
|
||||
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
|
||||
self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
|
||||
self.use_layer_norm = norm_type == "layer_norm"
|
||||
self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
|
||||
|
||||
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
|
||||
raise ValueError(
|
||||
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
|
||||
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
|
||||
)
|
||||
|
||||
self.norm_type = norm_type
|
||||
self.num_embeds_ada_norm = num_embeds_ada_norm
|
||||
|
||||
if positional_embeddings and (num_positional_embeddings is None):
|
||||
raise ValueError(
|
||||
"If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
|
||||
)
|
||||
|
||||
if positional_embeddings == "sinusoidal":
|
||||
self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
|
||||
else:
|
||||
self.pos_embed = None
|
||||
|
||||
# Define 3 blocks. Each block has its own normalization layer.
|
||||
# 1. Self-Attn
|
||||
if norm_type == "ada_norm":
|
||||
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
|
||||
elif norm_type == "ada_norm_zero":
|
||||
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
|
||||
elif norm_type == "ada_norm_continuous":
|
||||
self.norm1 = AdaLayerNormContinuous(
|
||||
dim,
|
||||
ada_norm_continous_conditioning_embedding_dim,
|
||||
norm_elementwise_affine,
|
||||
norm_eps,
|
||||
ada_norm_bias,
|
||||
"rms_norm",
|
||||
)
|
||||
else:
|
||||
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
|
||||
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
|
||||
upcast_attention=upcast_attention,
|
||||
out_bias=attention_out_bias,
|
||||
)
|
||||
|
||||
# 2. Cross-Attn
|
||||
if cross_attention_dim is not None or double_self_attention:
|
||||
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
|
||||
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
|
||||
# the second cross attention block.
|
||||
if norm_type == "ada_norm":
|
||||
self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
|
||||
elif norm_type == "ada_norm_continuous":
|
||||
self.norm2 = AdaLayerNormContinuous(
|
||||
dim,
|
||||
ada_norm_continous_conditioning_embedding_dim,
|
||||
norm_elementwise_affine,
|
||||
norm_eps,
|
||||
ada_norm_bias,
|
||||
"rms_norm",
|
||||
)
|
||||
else:
|
||||
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
||||
|
||||
self.attn2 = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
upcast_attention=upcast_attention,
|
||||
out_bias=attention_out_bias,
|
||||
) # is self-attn if encoder_hidden_states is none
|
||||
else:
|
||||
if norm_type == "ada_norm_single":
|
||||
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
||||
else:
|
||||
self.norm2 = None
|
||||
self.attn2 = None
|
||||
|
||||
# 3. Feed-forward
|
||||
if norm_type == "ada_norm_continuous":
|
||||
self.norm3 = AdaLayerNormContinuous(
|
||||
dim,
|
||||
ada_norm_continous_conditioning_embedding_dim,
|
||||
norm_elementwise_affine,
|
||||
norm_eps,
|
||||
ada_norm_bias,
|
||||
"layer_norm",
|
||||
)
|
||||
|
||||
elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]:
|
||||
self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
||||
elif norm_type == "layer_norm_i2vgen":
|
||||
self.norm3 = None
|
||||
|
||||
self.ff = FeedForward(
|
||||
dim,
|
||||
dropout=dropout,
|
||||
activation_fn=activation_fn,
|
||||
final_dropout=final_dropout,
|
||||
inner_dim=ff_inner_dim,
|
||||
bias=ff_bias,
|
||||
)
|
||||
|
||||
# 4. Fuser
|
||||
if attention_type == "gated" or attention_type == "gated-text-image":
|
||||
self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
|
||||
|
||||
# 5. Scale-shift for PixArt-Alpha.
|
||||
if norm_type == "ada_norm_single":
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
|
||||
|
||||
# let chunk size default to None
|
||||
self._chunk_size = None
|
||||
self._chunk_dim = 0
|
||||
|
||||
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
|
||||
# Sets chunk feed-forward
|
||||
self._chunk_size = chunk_size
|
||||
self._chunk_dim = dim
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
cross_attention_kwargs: Dict[str, Any] = None,
|
||||
class_labels: Optional[torch.LongTensor] = None,
|
||||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
if cross_attention_kwargs is not None:
|
||||
if cross_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
|
||||
|
||||
# Notice that normalization is always applied before the real computation in the following blocks.
|
||||
# 0. Self-Attention
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
if self.norm_type == "ada_norm":
|
||||
norm_hidden_states = self.norm1(hidden_states, timestep)
|
||||
elif self.norm_type == "ada_norm_zero":
|
||||
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
|
||||
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
elif self.norm_type == "ada_norm_continuous":
|
||||
norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
|
||||
elif self.norm_type == "ada_norm_single":
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
||||
self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
|
||||
).chunk(6, dim=1)
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
|
||||
else:
|
||||
raise ValueError("Incorrect norm used")
|
||||
|
||||
if self.pos_embed is not None:
|
||||
norm_hidden_states = self.pos_embed(norm_hidden_states)
|
||||
|
||||
# 1. Prepare GLIGEN inputs
|
||||
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
|
||||
gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
|
||||
|
||||
attn_output = self.attn1(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
|
||||
attention_mask=attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
|
||||
if self.norm_type == "ada_norm_zero":
|
||||
attn_output = gate_msa.unsqueeze(1) * attn_output
|
||||
elif self.norm_type == "ada_norm_single":
|
||||
attn_output = gate_msa * attn_output
|
||||
|
||||
hidden_states = attn_output + hidden_states
|
||||
if hidden_states.ndim == 4:
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
# 1.2 GLIGEN Control
|
||||
if gligen_kwargs is not None:
|
||||
hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
|
||||
|
||||
# 3. Cross-Attention
|
||||
if self.attn2 is not None:
|
||||
if self.norm_type == "ada_norm":
|
||||
norm_hidden_states = self.norm2(hidden_states, timestep)
|
||||
elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
elif self.norm_type == "ada_norm_single":
|
||||
# For PixArt norm2 isn't applied here:
|
||||
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
|
||||
norm_hidden_states = hidden_states
|
||||
elif self.norm_type == "ada_norm_continuous":
|
||||
norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
|
||||
else:
|
||||
raise ValueError("Incorrect norm")
|
||||
|
||||
if self.pos_embed is not None and self.norm_type != "ada_norm_single":
|
||||
norm_hidden_states = self.pos_embed(norm_hidden_states)
|
||||
|
||||
attn_output = self.attn2(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
# 4. Feed-forward
|
||||
# i2vgen doesn't have this norm 🤷♂️
|
||||
if self.norm_type == "ada_norm_continuous":
|
||||
norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
|
||||
elif not self.norm_type == "ada_norm_single":
|
||||
norm_hidden_states = self.norm3(hidden_states)
|
||||
|
||||
if self.norm_type == "ada_norm_zero":
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
||||
|
||||
if self.norm_type == "ada_norm_single":
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
|
||||
|
||||
if self._chunk_size is not None:
|
||||
# "feed_forward_chunk_size" can be used to save memory
|
||||
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
|
||||
else:
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
|
||||
if self.norm_type == "ada_norm_zero":
|
||||
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
||||
elif self.norm_type == "ada_norm_single":
|
||||
ff_output = gate_mlp * ff_output
|
||||
|
||||
hidden_states = ff_output + hidden_states
|
||||
if hidden_states.ndim == 4:
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
A 2D Transformer model as introduced in PixArt family of models (https://huggingface.co/papers/2310.00426,
|
||||
@@ -151,7 +483,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
PixArtTransformerBlock(
|
||||
self.inner_dim,
|
||||
self.config.num_attention_heads,
|
||||
self.config.attention_head_dim,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Union
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -7,8 +7,8 @@ from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
|
||||
from ...utils import BaseOutput
|
||||
from ..attention import BasicTransformerBlock
|
||||
from ...utils import BaseOutput, logging
|
||||
from ..attention import Attention, GatedSelfAttentionDense
|
||||
from ..attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
CROSS_ATTENTION_PROCESSORS,
|
||||
@@ -16,8 +16,345 @@ from ..attention_processor import (
|
||||
AttnAddedKVProcessor,
|
||||
AttnProcessor,
|
||||
)
|
||||
from ..embeddings import TimestepEmbedding, Timesteps
|
||||
from ..embeddings import SinusoidalPositionalEmbedding, TimestepEmbedding, Timesteps
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero
|
||||
from .modeling_common import FeedForward, _chunked_feed_forward
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# Copied from diffusers.models.transformers.transformer_2d.BasicTransformerBlock
|
||||
class PriorTransformerBlock(nn.Module):
|
||||
r"""
|
||||
A basic Transformer block.
|
||||
|
||||
Parameters:
|
||||
dim (`int`): The number of channels in the input and output.
|
||||
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`): The number of channels in each head.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
||||
num_embeds_ada_norm (:
|
||||
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
|
||||
attention_bias (:
|
||||
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
|
||||
only_cross_attention (`bool`, *optional*):
|
||||
Whether to use only cross-attention layers. In this case two cross attention layers are used.
|
||||
double_self_attention (`bool`, *optional*):
|
||||
Whether to use two self-attention layers. In this case no cross attention layers are used.
|
||||
upcast_attention (`bool`, *optional*):
|
||||
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
|
||||
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use learnable elementwise affine parameters for normalization.
|
||||
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
|
||||
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
|
||||
final_dropout (`bool` *optional*, defaults to False):
|
||||
Whether to apply a final dropout after the last feed-forward layer.
|
||||
attention_type (`str`, *optional*, defaults to `"default"`):
|
||||
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
|
||||
positional_embeddings (`str`, *optional*, defaults to `None`):
|
||||
The type of positional embeddings to apply to.
|
||||
num_positional_embeddings (`int`, *optional*, defaults to `None`):
|
||||
The maximum number of positional embeddings to apply.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
dropout=0.0,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
activation_fn: str = "geglu",
|
||||
num_embeds_ada_norm: Optional[int] = None,
|
||||
attention_bias: bool = False,
|
||||
only_cross_attention: bool = False,
|
||||
double_self_attention: bool = False,
|
||||
upcast_attention: bool = False,
|
||||
norm_elementwise_affine: bool = True,
|
||||
norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
|
||||
norm_eps: float = 1e-5,
|
||||
final_dropout: bool = False,
|
||||
attention_type: str = "default",
|
||||
positional_embeddings: Optional[str] = None,
|
||||
num_positional_embeddings: Optional[int] = None,
|
||||
ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
|
||||
ada_norm_bias: Optional[int] = None,
|
||||
ff_inner_dim: Optional[int] = None,
|
||||
ff_bias: bool = True,
|
||||
attention_out_bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
self.dropout = dropout
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
self.activation_fn = activation_fn
|
||||
self.attention_bias = attention_bias
|
||||
self.double_self_attention = double_self_attention
|
||||
self.norm_elementwise_affine = norm_elementwise_affine
|
||||
self.positional_embeddings = positional_embeddings
|
||||
self.num_positional_embeddings = num_positional_embeddings
|
||||
self.only_cross_attention = only_cross_attention
|
||||
|
||||
# We keep these boolean flags for backward-compatibility.
|
||||
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
|
||||
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
|
||||
self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
|
||||
self.use_layer_norm = norm_type == "layer_norm"
|
||||
self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
|
||||
|
||||
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
|
||||
raise ValueError(
|
||||
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
|
||||
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
|
||||
)
|
||||
|
||||
self.norm_type = norm_type
|
||||
self.num_embeds_ada_norm = num_embeds_ada_norm
|
||||
|
||||
if positional_embeddings and (num_positional_embeddings is None):
|
||||
raise ValueError(
|
||||
"If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
|
||||
)
|
||||
|
||||
if positional_embeddings == "sinusoidal":
|
||||
self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
|
||||
else:
|
||||
self.pos_embed = None
|
||||
|
||||
# Define 3 blocks. Each block has its own normalization layer.
|
||||
# 1. Self-Attn
|
||||
if norm_type == "ada_norm":
|
||||
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
|
||||
elif norm_type == "ada_norm_zero":
|
||||
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
|
||||
elif norm_type == "ada_norm_continuous":
|
||||
self.norm1 = AdaLayerNormContinuous(
|
||||
dim,
|
||||
ada_norm_continous_conditioning_embedding_dim,
|
||||
norm_elementwise_affine,
|
||||
norm_eps,
|
||||
ada_norm_bias,
|
||||
"rms_norm",
|
||||
)
|
||||
else:
|
||||
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
|
||||
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
|
||||
upcast_attention=upcast_attention,
|
||||
out_bias=attention_out_bias,
|
||||
)
|
||||
|
||||
# 2. Cross-Attn
|
||||
if cross_attention_dim is not None or double_self_attention:
|
||||
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
|
||||
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
|
||||
# the second cross attention block.
|
||||
if norm_type == "ada_norm":
|
||||
self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
|
||||
elif norm_type == "ada_norm_continuous":
|
||||
self.norm2 = AdaLayerNormContinuous(
|
||||
dim,
|
||||
ada_norm_continous_conditioning_embedding_dim,
|
||||
norm_elementwise_affine,
|
||||
norm_eps,
|
||||
ada_norm_bias,
|
||||
"rms_norm",
|
||||
)
|
||||
else:
|
||||
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
||||
|
||||
self.attn2 = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
upcast_attention=upcast_attention,
|
||||
out_bias=attention_out_bias,
|
||||
) # is self-attn if encoder_hidden_states is none
|
||||
else:
|
||||
if norm_type == "ada_norm_single":
|
||||
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
||||
else:
|
||||
self.norm2 = None
|
||||
self.attn2 = None
|
||||
|
||||
# 3. Feed-forward
|
||||
if norm_type == "ada_norm_continuous":
|
||||
self.norm3 = AdaLayerNormContinuous(
|
||||
dim,
|
||||
ada_norm_continous_conditioning_embedding_dim,
|
||||
norm_elementwise_affine,
|
||||
norm_eps,
|
||||
ada_norm_bias,
|
||||
"layer_norm",
|
||||
)
|
||||
|
||||
elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]:
|
||||
self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
||||
elif norm_type == "layer_norm_i2vgen":
|
||||
self.norm3 = None
|
||||
|
||||
self.ff = FeedForward(
|
||||
dim,
|
||||
dropout=dropout,
|
||||
activation_fn=activation_fn,
|
||||
final_dropout=final_dropout,
|
||||
inner_dim=ff_inner_dim,
|
||||
bias=ff_bias,
|
||||
)
|
||||
|
||||
# 4. Fuser
|
||||
if attention_type == "gated" or attention_type == "gated-text-image":
|
||||
self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
|
||||
|
||||
# 5. Scale-shift for PixArt-Alpha.
|
||||
if norm_type == "ada_norm_single":
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
|
||||
|
||||
# let chunk size default to None
|
||||
self._chunk_size = None
|
||||
self._chunk_dim = 0
|
||||
|
||||
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
|
||||
# Sets chunk feed-forward
|
||||
self._chunk_size = chunk_size
|
||||
self._chunk_dim = dim
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
cross_attention_kwargs: Dict[str, Any] = None,
|
||||
class_labels: Optional[torch.LongTensor] = None,
|
||||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
if cross_attention_kwargs is not None:
|
||||
if cross_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
|
||||
|
||||
# Notice that normalization is always applied before the real computation in the following blocks.
|
||||
# 0. Self-Attention
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
if self.norm_type == "ada_norm":
|
||||
norm_hidden_states = self.norm1(hidden_states, timestep)
|
||||
elif self.norm_type == "ada_norm_zero":
|
||||
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
|
||||
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
elif self.norm_type == "ada_norm_continuous":
|
||||
norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
|
||||
elif self.norm_type == "ada_norm_single":
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
||||
self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
|
||||
).chunk(6, dim=1)
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
|
||||
else:
|
||||
raise ValueError("Incorrect norm used")
|
||||
|
||||
if self.pos_embed is not None:
|
||||
norm_hidden_states = self.pos_embed(norm_hidden_states)
|
||||
|
||||
# 1. Prepare GLIGEN inputs
|
||||
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
|
||||
gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
|
||||
|
||||
attn_output = self.attn1(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
|
||||
attention_mask=attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
|
||||
if self.norm_type == "ada_norm_zero":
|
||||
attn_output = gate_msa.unsqueeze(1) * attn_output
|
||||
elif self.norm_type == "ada_norm_single":
|
||||
attn_output = gate_msa * attn_output
|
||||
|
||||
hidden_states = attn_output + hidden_states
|
||||
if hidden_states.ndim == 4:
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
# 1.2 GLIGEN Control
|
||||
if gligen_kwargs is not None:
|
||||
hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
|
||||
|
||||
# 3. Cross-Attention
|
||||
if self.attn2 is not None:
|
||||
if self.norm_type == "ada_norm":
|
||||
norm_hidden_states = self.norm2(hidden_states, timestep)
|
||||
elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
elif self.norm_type == "ada_norm_single":
|
||||
# For PixArt norm2 isn't applied here:
|
||||
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
|
||||
norm_hidden_states = hidden_states
|
||||
elif self.norm_type == "ada_norm_continuous":
|
||||
norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
|
||||
else:
|
||||
raise ValueError("Incorrect norm")
|
||||
|
||||
if self.pos_embed is not None and self.norm_type != "ada_norm_single":
|
||||
norm_hidden_states = self.pos_embed(norm_hidden_states)
|
||||
|
||||
attn_output = self.attn2(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
# 4. Feed-forward
|
||||
# i2vgen doesn't have this norm 🤷♂️
|
||||
if self.norm_type == "ada_norm_continuous":
|
||||
norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
|
||||
elif not self.norm_type == "ada_norm_single":
|
||||
norm_hidden_states = self.norm3(hidden_states)
|
||||
|
||||
if self.norm_type == "ada_norm_zero":
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
||||
|
||||
if self.norm_type == "ada_norm_single":
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
|
||||
|
||||
if self._chunk_size is not None:
|
||||
# "feed_forward_chunk_size" can be used to save memory
|
||||
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
|
||||
else:
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
|
||||
if self.norm_type == "ada_norm_zero":
|
||||
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
||||
elif self.norm_type == "ada_norm_single":
|
||||
ff_output = gate_mlp * ff_output
|
||||
|
||||
hidden_states = ff_output + hidden_states
|
||||
if hidden_states.ndim == 4:
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -133,7 +470,7 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Pef
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
PriorTransformerBlock(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
|
||||
@@ -27,9 +27,9 @@ from ..attention_processor import (
|
||||
SanaLinearAttnProcessor2_0,
|
||||
)
|
||||
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, TimestepEmbedding, Timesteps
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormSingle, RMSNorm
|
||||
from .modeling_common import Transformer2DModelOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -23,10 +23,10 @@ import torch.utils.checkpoint
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import FeedForward
|
||||
from ..attention_processor import Attention, AttentionProcessor, StableAudioAttnProcessor2_0
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..transformers.transformer_2d import Transformer2DModelOutput
|
||||
from .modeling_common import FeedForward
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -19,16 +19,352 @@ from torch import nn
|
||||
|
||||
from ...configuration_utils import LegacyConfigMixin, register_to_config
|
||||
from ...utils import deprecate, logging
|
||||
from ..attention import BasicTransformerBlock
|
||||
from ..embeddings import ImagePositionalEmbeddings, PatchEmbed, PixArtAlphaTextProjection
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..attention import Attention, GatedSelfAttentionDense
|
||||
from ..embeddings import (
|
||||
ImagePositionalEmbeddings,
|
||||
PatchEmbed,
|
||||
PixArtAlphaTextProjection,
|
||||
SinusoidalPositionalEmbedding,
|
||||
)
|
||||
from ..modeling_utils import LegacyModelMixin
|
||||
from ..normalization import AdaLayerNormSingle
|
||||
from ..normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormSingle, AdaLayerNormZero
|
||||
from .modeling_common import FeedForward, Transformer2DModelOutput, _chunked_feed_forward
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
r"""
|
||||
A basic Transformer block.
|
||||
|
||||
Parameters:
|
||||
dim (`int`): The number of channels in the input and output.
|
||||
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`): The number of channels in each head.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
||||
num_embeds_ada_norm (:
|
||||
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
|
||||
attention_bias (:
|
||||
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
|
||||
only_cross_attention (`bool`, *optional*):
|
||||
Whether to use only cross-attention layers. In this case two cross attention layers are used.
|
||||
double_self_attention (`bool`, *optional*):
|
||||
Whether to use two self-attention layers. In this case no cross attention layers are used.
|
||||
upcast_attention (`bool`, *optional*):
|
||||
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
|
||||
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use learnable elementwise affine parameters for normalization.
|
||||
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
|
||||
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
|
||||
final_dropout (`bool` *optional*, defaults to False):
|
||||
Whether to apply a final dropout after the last feed-forward layer.
|
||||
attention_type (`str`, *optional*, defaults to `"default"`):
|
||||
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
|
||||
positional_embeddings (`str`, *optional*, defaults to `None`):
|
||||
The type of positional embeddings to apply to.
|
||||
num_positional_embeddings (`int`, *optional*, defaults to `None`):
|
||||
The maximum number of positional embeddings to apply.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
dropout=0.0,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
activation_fn: str = "geglu",
|
||||
num_embeds_ada_norm: Optional[int] = None,
|
||||
attention_bias: bool = False,
|
||||
only_cross_attention: bool = False,
|
||||
double_self_attention: bool = False,
|
||||
upcast_attention: bool = False,
|
||||
norm_elementwise_affine: bool = True,
|
||||
norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
|
||||
norm_eps: float = 1e-5,
|
||||
final_dropout: bool = False,
|
||||
attention_type: str = "default",
|
||||
positional_embeddings: Optional[str] = None,
|
||||
num_positional_embeddings: Optional[int] = None,
|
||||
ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
|
||||
ada_norm_bias: Optional[int] = None,
|
||||
ff_inner_dim: Optional[int] = None,
|
||||
ff_bias: bool = True,
|
||||
attention_out_bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
self.dropout = dropout
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
self.activation_fn = activation_fn
|
||||
self.attention_bias = attention_bias
|
||||
self.double_self_attention = double_self_attention
|
||||
self.norm_elementwise_affine = norm_elementwise_affine
|
||||
self.positional_embeddings = positional_embeddings
|
||||
self.num_positional_embeddings = num_positional_embeddings
|
||||
self.only_cross_attention = only_cross_attention
|
||||
|
||||
# We keep these boolean flags for backward-compatibility.
|
||||
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
|
||||
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
|
||||
self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
|
||||
self.use_layer_norm = norm_type == "layer_norm"
|
||||
self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
|
||||
|
||||
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
|
||||
raise ValueError(
|
||||
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
|
||||
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
|
||||
)
|
||||
|
||||
self.norm_type = norm_type
|
||||
self.num_embeds_ada_norm = num_embeds_ada_norm
|
||||
|
||||
if positional_embeddings and (num_positional_embeddings is None):
|
||||
raise ValueError(
|
||||
"If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
|
||||
)
|
||||
|
||||
if positional_embeddings == "sinusoidal":
|
||||
self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
|
||||
else:
|
||||
self.pos_embed = None
|
||||
|
||||
# Define 3 blocks. Each block has its own normalization layer.
|
||||
# 1. Self-Attn
|
||||
if norm_type == "ada_norm":
|
||||
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
|
||||
elif norm_type == "ada_norm_zero":
|
||||
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
|
||||
elif norm_type == "ada_norm_continuous":
|
||||
self.norm1 = AdaLayerNormContinuous(
|
||||
dim,
|
||||
ada_norm_continous_conditioning_embedding_dim,
|
||||
norm_elementwise_affine,
|
||||
norm_eps,
|
||||
ada_norm_bias,
|
||||
"rms_norm",
|
||||
)
|
||||
else:
|
||||
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
|
||||
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
|
||||
upcast_attention=upcast_attention,
|
||||
out_bias=attention_out_bias,
|
||||
)
|
||||
|
||||
# 2. Cross-Attn
|
||||
if cross_attention_dim is not None or double_self_attention:
|
||||
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
|
||||
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
|
||||
# the second cross attention block.
|
||||
if norm_type == "ada_norm":
|
||||
self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
|
||||
elif norm_type == "ada_norm_continuous":
|
||||
self.norm2 = AdaLayerNormContinuous(
|
||||
dim,
|
||||
ada_norm_continous_conditioning_embedding_dim,
|
||||
norm_elementwise_affine,
|
||||
norm_eps,
|
||||
ada_norm_bias,
|
||||
"rms_norm",
|
||||
)
|
||||
else:
|
||||
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
||||
|
||||
self.attn2 = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
upcast_attention=upcast_attention,
|
||||
out_bias=attention_out_bias,
|
||||
) # is self-attn if encoder_hidden_states is none
|
||||
else:
|
||||
if norm_type == "ada_norm_single":
|
||||
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
||||
else:
|
||||
self.norm2 = None
|
||||
self.attn2 = None
|
||||
|
||||
# 3. Feed-forward
|
||||
if norm_type == "ada_norm_continuous":
|
||||
self.norm3 = AdaLayerNormContinuous(
|
||||
dim,
|
||||
ada_norm_continous_conditioning_embedding_dim,
|
||||
norm_elementwise_affine,
|
||||
norm_eps,
|
||||
ada_norm_bias,
|
||||
"layer_norm",
|
||||
)
|
||||
|
||||
elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]:
|
||||
self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
||||
elif norm_type == "layer_norm_i2vgen":
|
||||
self.norm3 = None
|
||||
|
||||
self.ff = FeedForward(
|
||||
dim,
|
||||
dropout=dropout,
|
||||
activation_fn=activation_fn,
|
||||
final_dropout=final_dropout,
|
||||
inner_dim=ff_inner_dim,
|
||||
bias=ff_bias,
|
||||
)
|
||||
|
||||
# 4. Fuser
|
||||
if attention_type == "gated" or attention_type == "gated-text-image":
|
||||
self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
|
||||
|
||||
# 5. Scale-shift for PixArt-Alpha.
|
||||
if norm_type == "ada_norm_single":
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
|
||||
|
||||
# let chunk size default to None
|
||||
self._chunk_size = None
|
||||
self._chunk_dim = 0
|
||||
|
||||
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
|
||||
# Sets chunk feed-forward
|
||||
self._chunk_size = chunk_size
|
||||
self._chunk_dim = dim
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
cross_attention_kwargs: Dict[str, Any] = None,
|
||||
class_labels: Optional[torch.LongTensor] = None,
|
||||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
if cross_attention_kwargs is not None:
|
||||
if cross_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
|
||||
|
||||
# Notice that normalization is always applied before the real computation in the following blocks.
|
||||
# 0. Self-Attention
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
if self.norm_type == "ada_norm":
|
||||
norm_hidden_states = self.norm1(hidden_states, timestep)
|
||||
elif self.norm_type == "ada_norm_zero":
|
||||
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
|
||||
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
elif self.norm_type == "ada_norm_continuous":
|
||||
norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
|
||||
elif self.norm_type == "ada_norm_single":
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
||||
self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
|
||||
).chunk(6, dim=1)
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
|
||||
else:
|
||||
raise ValueError("Incorrect norm used")
|
||||
|
||||
if self.pos_embed is not None:
|
||||
norm_hidden_states = self.pos_embed(norm_hidden_states)
|
||||
|
||||
# 1. Prepare GLIGEN inputs
|
||||
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
|
||||
gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
|
||||
|
||||
attn_output = self.attn1(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
|
||||
attention_mask=attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
|
||||
if self.norm_type == "ada_norm_zero":
|
||||
attn_output = gate_msa.unsqueeze(1) * attn_output
|
||||
elif self.norm_type == "ada_norm_single":
|
||||
attn_output = gate_msa * attn_output
|
||||
|
||||
hidden_states = attn_output + hidden_states
|
||||
if hidden_states.ndim == 4:
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
# 1.2 GLIGEN Control
|
||||
if gligen_kwargs is not None:
|
||||
hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
|
||||
|
||||
# 3. Cross-Attention
|
||||
if self.attn2 is not None:
|
||||
if self.norm_type == "ada_norm":
|
||||
norm_hidden_states = self.norm2(hidden_states, timestep)
|
||||
elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
elif self.norm_type == "ada_norm_single":
|
||||
# For PixArt norm2 isn't applied here:
|
||||
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
|
||||
norm_hidden_states = hidden_states
|
||||
elif self.norm_type == "ada_norm_continuous":
|
||||
norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
|
||||
else:
|
||||
raise ValueError("Incorrect norm")
|
||||
|
||||
if self.pos_embed is not None and self.norm_type != "ada_norm_single":
|
||||
norm_hidden_states = self.pos_embed(norm_hidden_states)
|
||||
|
||||
attn_output = self.attn2(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
# 4. Feed-forward
|
||||
# i2vgen doesn't have this norm 🤷♂️
|
||||
if self.norm_type == "ada_norm_continuous":
|
||||
norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
|
||||
elif not self.norm_type == "ada_norm_single":
|
||||
norm_hidden_states = self.norm3(hidden_states)
|
||||
|
||||
if self.norm_type == "ada_norm_zero":
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
||||
|
||||
if self.norm_type == "ada_norm_single":
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
|
||||
|
||||
if self._chunk_size is not None:
|
||||
# "feed_forward_chunk_size" can be used to save memory
|
||||
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
|
||||
else:
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
|
||||
if self.norm_type == "ada_norm_zero":
|
||||
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
||||
elif self.norm_type == "ada_norm_single":
|
||||
ff_output = gate_mlp * ff_output
|
||||
|
||||
hidden_states = ff_output + hidden_states
|
||||
if hidden_states.ndim == 4:
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Transformer2DModelOutput(Transformer2DModelOutput):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `Transformer2DModelOutput` from `diffusers.models.transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.modeling_outputs import Transformer2DModelOutput`, instead."
|
||||
|
||||
@@ -22,13 +22,12 @@ import torch.nn.functional as F
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import FeedForward
|
||||
from ..attention_processor import AllegroAttnProcessor2_0, Attention
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormSingle
|
||||
from .modeling_common import FeedForward, Transformer2DModelOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
@@ -10,13 +10,13 @@ from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import AttentionModuleMixin, FeedForward
|
||||
from ..attention import AttentionModuleMixin
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import TimestepEmbedding, apply_rotary_emb, get_timestep_embedding
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
|
||||
from .modeling_common import FeedForward, Transformer2DModelOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -24,12 +24,12 @@ from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, Pe
|
||||
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.import_utils import is_torch_npu_available
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import AttentionMixin, FeedForward
|
||||
from ..attention import AttentionMixin
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import FluxPosEmbed, PixArtAlphaTextProjection, Timesteps, get_timestep_embedding
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import CombinedTimestepLabelEmbeddings, FP32LayerNorm, RMSNorm
|
||||
from .modeling_common import FeedForward, Transformer2DModelOutput
|
||||
from .transformer_flux import FluxAttention, FluxAttnProcessor
|
||||
|
||||
|
||||
|
||||
@@ -20,12 +20,11 @@ import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import logging
|
||||
from ..attention import FeedForward
|
||||
from ..attention_processor import Attention, AttentionProcessor, CogVideoXAttnProcessor2_0
|
||||
from ..embeddings import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormContinuous, CogView3PlusAdaLayerNormZeroTextImage
|
||||
from .modeling_common import FeedForward, Transformer2DModelOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -22,13 +22,12 @@ from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import FeedForward
|
||||
from ..attention_processor import Attention
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import CogView3CombinedTimestepSizeEmbeddings
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import LayerNorm, RMSNorm
|
||||
from .modeling_common import FeedForward, Transformer2DModelOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -22,12 +22,11 @@ import torch.nn.functional as F
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin
|
||||
from ...utils import is_torchvision_available
|
||||
from ..attention import FeedForward
|
||||
from ..attention_processor import Attention
|
||||
from ..embeddings import Timesteps
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import RMSNorm
|
||||
from .modeling_common import FeedForward, Transformer2DModelOutput
|
||||
|
||||
|
||||
if is_torchvision_available():
|
||||
|
||||
@@ -22,11 +22,11 @@ from torch import nn
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import Attention, FeedForward
|
||||
from ..attention import Attention
|
||||
from ..embeddings import TimestepEmbedding, Timesteps, get_3d_rotary_pos_embed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNorm, FP32LayerNorm, RMSNorm
|
||||
from .modeling_common import FeedForward, Transformer2DModelOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -25,7 +25,7 @@ from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, Pe
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import (
|
||||
@@ -34,9 +34,9 @@ from ..embeddings import (
|
||||
apply_rotary_emb,
|
||||
get_1d_rotary_pos_embed,
|
||||
)
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
|
||||
from .modeling_common import FeedForward, Transformer2DModelOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -6,8 +6,8 @@ import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...models.modeling_outputs import Transformer2DModelOutput
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
from ...models.transformers.modeling_common import Transformer2DModelOutput
|
||||
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import Attention
|
||||
|
||||
@@ -23,7 +23,6 @@ from diffusers.loaders import FromOriginalModelMixin
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention import FeedForward
|
||||
from ..attention_processor import Attention, AttentionProcessor
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import (
|
||||
@@ -33,9 +32,9 @@ from ..embeddings import (
|
||||
Timesteps,
|
||||
get_1d_rotary_pos_embed,
|
||||
)
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, FP32LayerNorm
|
||||
from .modeling_common import FeedForward, Transformer2DModelOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -23,9 +23,9 @@ from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, get_logger, scale_lora_layers, unscale_lora_layers
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import get_1d_rotary_pos_embed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormContinuous
|
||||
from .modeling_common import Transformer2DModelOutput
|
||||
from .transformer_hunyuan_video import (
|
||||
HunyuanVideoConditionEmbedding,
|
||||
HunyuanVideoPatchEmbed,
|
||||
|
||||
@@ -25,13 +25,13 @@ from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, deprecate, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import PixArtAlphaTextProjection
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormSingle, RMSNorm
|
||||
from .modeling_common import FeedForward, Transformer2DModelOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -23,17 +23,66 @@ from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...loaders.single_file_model import FromOriginalModelMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention import LuminaFeedForward
|
||||
from ..activations import FP32SiLU
|
||||
from ..attention_processor import Attention
|
||||
from ..embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import LuminaLayerNormContinuous, LuminaRMSNormZero, RMSNorm
|
||||
from .modeling_common import Transformer2DModelOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class LuminaFeedForward(nn.Module):
|
||||
r"""
|
||||
A feed-forward layer.
|
||||
|
||||
Parameters:
|
||||
hidden_size (`int`):
|
||||
The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
|
||||
hidden representations.
|
||||
intermediate_size (`int`): The intermediate dimension of the feedforward layer.
|
||||
multiple_of (`int`, *optional*): Value to ensure hidden dimension is a multiple
|
||||
of this value.
|
||||
ffn_dim_multiplier (float, *optional*): Custom multiplier for hidden
|
||||
dimension. Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
inner_dim: int,
|
||||
multiple_of: Optional[int] = 256,
|
||||
ffn_dim_multiplier: Optional[float] = None,
|
||||
):
|
||||
super().__init__()
|
||||
# custom hidden_size factor multiplier
|
||||
if ffn_dim_multiplier is not None:
|
||||
inner_dim = int(ffn_dim_multiplier * inner_dim)
|
||||
inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)
|
||||
|
||||
self.linear_1 = nn.Linear(
|
||||
dim,
|
||||
inner_dim,
|
||||
bias=False,
|
||||
)
|
||||
self.linear_2 = nn.Linear(
|
||||
inner_dim,
|
||||
dim,
|
||||
bias=False,
|
||||
)
|
||||
self.linear_3 = nn.Linear(
|
||||
dim,
|
||||
inner_dim,
|
||||
bias=False,
|
||||
)
|
||||
self.silu = FP32SiLU()
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear_2(self.silu(self.linear_1(x)) * self.linear_3(x))
|
||||
|
||||
|
||||
class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -23,13 +23,12 @@ from ...loaders import PeftAdapterMixin
|
||||
from ...loaders.single_file_model import FromOriginalModelMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import FeedForward
|
||||
from ..attention_processor import MochiAttention, MochiAttnProcessor2_0
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormContinuous, RMSNorm
|
||||
from .modeling_common import FeedForward, Transformer2DModelOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -23,9 +23,9 @@ from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import logging
|
||||
from ..attention_processor import Attention
|
||||
from ..embeddings import TimestepEmbedding, Timesteps, get_2d_sincos_pos_embed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNorm, RMSNorm
|
||||
from .modeling_common import Transformer2DModelOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -26,14 +26,14 @@ from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
|
||||
from ..attention import AttentionMixin, FeedForward
|
||||
from ..attention import AttentionMixin
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..attention_processor import Attention
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import TimestepEmbedding, Timesteps
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormContinuous, RMSNorm
|
||||
from .modeling_common import FeedForward, Transformer2DModelOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -20,7 +20,6 @@ from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import FeedForward, JointTransformerBlock
|
||||
from ..attention_processor import (
|
||||
Attention,
|
||||
AttentionProcessor,
|
||||
@@ -28,14 +27,184 @@ from ..attention_processor import (
|
||||
JointAttnProcessor2_0,
|
||||
)
|
||||
from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero
|
||||
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, SD35AdaLayerNormZeroX
|
||||
from .modeling_common import FeedForward, Transformer2DModelOutput, _chunked_feed_forward
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class SD3TransformerBlock(nn.Module):
|
||||
r"""
|
||||
A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
|
||||
|
||||
Reference: https://huggingface.co/papers/2403.03206
|
||||
|
||||
Parameters:
|
||||
dim (`int`): The number of channels in the input and output.
|
||||
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`): The number of channels in each head.
|
||||
context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
|
||||
processing of `context` conditions.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
context_pre_only: bool = False,
|
||||
qk_norm: Optional[str] = None,
|
||||
use_dual_attention: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.use_dual_attention = use_dual_attention
|
||||
self.context_pre_only = context_pre_only
|
||||
context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"
|
||||
|
||||
if use_dual_attention:
|
||||
self.norm1 = SD35AdaLayerNormZeroX(dim)
|
||||
else:
|
||||
self.norm1 = AdaLayerNormZero(dim)
|
||||
|
||||
if context_norm_type == "ada_norm_continous":
|
||||
self.norm1_context = AdaLayerNormContinuous(
|
||||
dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm"
|
||||
)
|
||||
elif context_norm_type == "ada_norm_zero":
|
||||
self.norm1_context = AdaLayerNormZero(dim)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
|
||||
)
|
||||
|
||||
if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
|
||||
processor = JointAttnProcessor2_0()
|
||||
else:
|
||||
raise ValueError(
|
||||
"The current PyTorch version does not support the `scaled_dot_product_attention` function."
|
||||
)
|
||||
|
||||
self.attn = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=None,
|
||||
added_kv_proj_dim=dim,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
out_dim=dim,
|
||||
context_pre_only=context_pre_only,
|
||||
bias=True,
|
||||
processor=processor,
|
||||
qk_norm=qk_norm,
|
||||
eps=1e-6,
|
||||
)
|
||||
|
||||
if use_dual_attention:
|
||||
self.attn2 = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=None,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
out_dim=dim,
|
||||
bias=True,
|
||||
processor=processor,
|
||||
qk_norm=qk_norm,
|
||||
eps=1e-6,
|
||||
)
|
||||
else:
|
||||
self.attn2 = None
|
||||
|
||||
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
||||
|
||||
if not context_pre_only:
|
||||
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
||||
else:
|
||||
self.norm2_context = None
|
||||
self.ff_context = None
|
||||
|
||||
# let chunk size default to None
|
||||
self._chunk_size = None
|
||||
self._chunk_dim = 0
|
||||
|
||||
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
|
||||
# Sets chunk feed-forward
|
||||
self._chunk_size = chunk_size
|
||||
self._chunk_dim = dim
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor,
|
||||
temb: torch.FloatTensor,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
joint_attention_kwargs = joint_attention_kwargs or {}
|
||||
if self.use_dual_attention:
|
||||
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
|
||||
hidden_states, emb=temb
|
||||
)
|
||||
else:
|
||||
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
|
||||
|
||||
if self.context_pre_only:
|
||||
norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
|
||||
else:
|
||||
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
|
||||
encoder_hidden_states, emb=temb
|
||||
)
|
||||
|
||||
# Attention.
|
||||
attn_output, context_attn_output = self.attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
**joint_attention_kwargs,
|
||||
)
|
||||
|
||||
# Process attention outputs for the `hidden_states`.
|
||||
attn_output = gate_msa.unsqueeze(1) * attn_output
|
||||
hidden_states = hidden_states + attn_output
|
||||
|
||||
if self.use_dual_attention:
|
||||
attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **joint_attention_kwargs)
|
||||
attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
|
||||
hidden_states = hidden_states + attn_output2
|
||||
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
||||
if self._chunk_size is not None:
|
||||
# "feed_forward_chunk_size" can be used to save memory
|
||||
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
|
||||
else:
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
||||
|
||||
hidden_states = hidden_states + ff_output
|
||||
|
||||
# Process attention outputs for the `encoder_hidden_states`.
|
||||
if self.context_pre_only:
|
||||
encoder_hidden_states = None
|
||||
else:
|
||||
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
|
||||
encoder_hidden_states = encoder_hidden_states + context_attn_output
|
||||
|
||||
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
|
||||
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
|
||||
if self._chunk_size is not None:
|
||||
# "feed_forward_chunk_size" can be used to save memory
|
||||
context_ff_output = _chunked_feed_forward(
|
||||
self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size
|
||||
)
|
||||
else:
|
||||
context_ff_output = self.ff_context(norm_encoder_hidden_states)
|
||||
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
|
||||
|
||||
return encoder_hidden_states, hidden_states
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class SD3SingleTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
@@ -155,7 +324,7 @@ class SD3Transformer2DModel(
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
JointTransformerBlock(
|
||||
SD3TransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
|
||||
@@ -23,7 +23,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import (
|
||||
@@ -32,9 +32,9 @@ from ..embeddings import (
|
||||
get_1d_rotary_pos_embed,
|
||||
get_1d_sincos_pos_embed_from_grid,
|
||||
)
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin, get_parameter_dtype
|
||||
from ..normalization import FP32LayerNorm
|
||||
from .modeling_common import FeedForward, Transformer2DModelOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -19,10 +19,13 @@ from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import BaseOutput
|
||||
from ..attention import BasicTransformerBlock, TemporalBasicTransformerBlock
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention_processor import Attention
|
||||
from ..embeddings import TimestepEmbedding, Timesteps
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..resnet import AlphaBlender
|
||||
from .modeling_common import FeedForward, _chunked_feed_forward
|
||||
from .transformer_2d import BasicTransformerBlock
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -38,6 +41,136 @@ class TransformerTemporalModelOutput(BaseOutput):
|
||||
sample: torch.Tensor
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class TemporalBasicTransformerBlock(nn.Module):
|
||||
r"""
|
||||
A basic Transformer block for video like data.
|
||||
|
||||
Parameters:
|
||||
dim (`int`): The number of channels in the input and output.
|
||||
time_mix_inner_dim (`int`): The number of channels for temporal attention.
|
||||
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`): The number of channels in each head.
|
||||
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
time_mix_inner_dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.is_res = dim == time_mix_inner_dim
|
||||
|
||||
self.norm_in = nn.LayerNorm(dim)
|
||||
|
||||
# Define 3 blocks. Each block has its own normalization layer.
|
||||
# 1. Self-Attn
|
||||
self.ff_in = FeedForward(
|
||||
dim,
|
||||
dim_out=time_mix_inner_dim,
|
||||
activation_fn="geglu",
|
||||
)
|
||||
|
||||
self.norm1 = nn.LayerNorm(time_mix_inner_dim)
|
||||
self.attn1 = Attention(
|
||||
query_dim=time_mix_inner_dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
cross_attention_dim=None,
|
||||
)
|
||||
|
||||
# 2. Cross-Attn
|
||||
if cross_attention_dim is not None:
|
||||
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
|
||||
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
|
||||
# the second cross attention block.
|
||||
self.norm2 = nn.LayerNorm(time_mix_inner_dim)
|
||||
self.attn2 = Attention(
|
||||
query_dim=time_mix_inner_dim,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
) # is self-attn if encoder_hidden_states is none
|
||||
else:
|
||||
self.norm2 = None
|
||||
self.attn2 = None
|
||||
|
||||
# 3. Feed-forward
|
||||
self.norm3 = nn.LayerNorm(time_mix_inner_dim)
|
||||
self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu")
|
||||
|
||||
# let chunk size default to None
|
||||
self._chunk_size = None
|
||||
self._chunk_dim = None
|
||||
|
||||
def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs):
|
||||
# Sets chunk feed-forward
|
||||
self._chunk_size = chunk_size
|
||||
# chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off
|
||||
self._chunk_dim = 1
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
num_frames: int,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
# Notice that normalization is always applied before the real computation in the following blocks.
|
||||
# 0. Self-Attention
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
batch_frames, seq_length, channels = hidden_states.shape
|
||||
batch_size = batch_frames // num_frames
|
||||
|
||||
hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels)
|
||||
hidden_states = hidden_states.permute(0, 2, 1, 3)
|
||||
hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels)
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.norm_in(hidden_states)
|
||||
|
||||
if self._chunk_size is not None:
|
||||
hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size)
|
||||
else:
|
||||
hidden_states = self.ff_in(hidden_states)
|
||||
|
||||
if self.is_res:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
# 3. Cross-Attention
|
||||
if self.attn2 is not None:
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
# 4. Feed-forward
|
||||
norm_hidden_states = self.norm3(hidden_states)
|
||||
|
||||
if self._chunk_size is not None:
|
||||
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
|
||||
else:
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
|
||||
if self.is_res:
|
||||
hidden_states = ff_output + hidden_states
|
||||
else:
|
||||
hidden_states = ff_output
|
||||
|
||||
hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels)
|
||||
hidden_states = hidden_states.permute(0, 2, 1, 3)
|
||||
hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class TransformerTemporalModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
A Transformer model for video-like data.
|
||||
|
||||
@@ -24,13 +24,13 @@ from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import FP32LayerNorm
|
||||
from .modeling_common import FeedForward, Transformer2DModelOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -21,11 +21,11 @@ import torch.nn as nn
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention import AttentionMixin, FeedForward
|
||||
from ..attention import AttentionMixin
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import FP32LayerNorm
|
||||
from .modeling_common import FeedForward, Transformer2DModelOutput
|
||||
from .transformer_wan import (
|
||||
WanAttention,
|
||||
WanAttnProcessor,
|
||||
|
||||
@@ -22,7 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import UNet2DConditionLoadersMixin
|
||||
from ...utils import logging
|
||||
from ..activations import get_activation
|
||||
from ..attention import Attention, FeedForward
|
||||
from ..attention import Attention
|
||||
from ..attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
CROSS_ATTENTION_PROCESSORS,
|
||||
@@ -33,6 +33,7 @@ from ..attention_processor import (
|
||||
)
|
||||
from ..embeddings import TimestepEmbedding, Timesteps
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..transformers.modeling_common import FeedForward
|
||||
from ..transformers.transformer_temporal import TransformerTemporalModel
|
||||
from .unet_3d_blocks import (
|
||||
UNetMidBlock3DCrossAttn,
|
||||
|
||||
@@ -24,7 +24,6 @@ from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin
|
||||
from ...utils import BaseOutput, deprecate, logging
|
||||
from ...utils.torch_utils import apply_freeu
|
||||
from ..attention import BasicTransformerBlock
|
||||
from ..attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
CROSS_ATTENTION_PROCESSORS,
|
||||
@@ -41,7 +40,7 @@ from ..embeddings import TimestepEmbedding, Timesteps
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..resnet import Downsample2D, ResnetBlock2D, Upsample2D
|
||||
from ..transformers.dual_transformer_2d import DualTransformer2DModel
|
||||
from ..transformers.transformer_2d import Transformer2DModel
|
||||
from ..transformers.transformer_2d import BasicTransformerBlock, Transformer2DModel
|
||||
from .unet_2d_blocks import UNetMidBlock2DCrossAttn
|
||||
from .unet_2d_condition import UNet2DConditionModel
|
||||
|
||||
|
||||
@@ -22,10 +22,10 @@ from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ..attention import BasicTransformerBlock, SkipFFTransformerBlock
|
||||
from ..attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
CROSS_ATTENTION_PROCESSORS,
|
||||
Attention,
|
||||
AttentionProcessor,
|
||||
AttnAddedKVProcessor,
|
||||
AttnProcessor,
|
||||
@@ -34,6 +34,79 @@ from ..embeddings import TimestepEmbedding, get_timestep_embedding
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import GlobalResponseNorm, RMSNorm
|
||||
from ..resnet import Downsample2D, Upsample2D
|
||||
from ..transformers.transformer_2d import BasicTransformerBlock
|
||||
|
||||
|
||||
class SkipFFTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
kv_input_dim: int,
|
||||
kv_input_dim_proj_use_bias: bool,
|
||||
dropout=0.0,
|
||||
cross_attention_dim: int = None,
|
||||
attention_bias: bool = False,
|
||||
attention_out_bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
if kv_input_dim != dim:
|
||||
self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias)
|
||||
else:
|
||||
self.kv_mapper = None
|
||||
|
||||
self.norm1 = RMSNorm(dim, 1e-06)
|
||||
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
out_bias=attention_out_bias,
|
||||
)
|
||||
|
||||
self.norm2 = RMSNorm(dim, 1e-06)
|
||||
|
||||
self.attn2 = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
out_bias=attention_out_bias,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs):
|
||||
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
|
||||
|
||||
if self.kv_mapper is not None:
|
||||
encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states))
|
||||
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
|
||||
attn_output = self.attn1(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
|
||||
attn_output = self.attn2(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
|
||||
@@ -13,12 +13,12 @@
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ...models import AutoencoderKL
|
||||
from ...pipelines import FluxPipeline
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
@@ -104,48 +104,6 @@ def calculate_shift(
|
||||
return mu
|
||||
|
||||
|
||||
# Adapted from the original implementation.
|
||||
def prepare_latents_img2img(
|
||||
vae, scheduler, image, timestep, batch_size, num_channels_latents, height, width, dtype, device, generator
|
||||
):
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
|
||||
latent_channels = vae.config.latent_channels
|
||||
|
||||
# VAE applies 8x compression on images but we must also account for packing which requires
|
||||
# latent height and width to be divisible by 2.
|
||||
height = 2 * (int(height) // (vae_scale_factor * 2))
|
||||
width = 2 * (int(width) // (vae_scale_factor * 2))
|
||||
shape = (batch_size, num_channels_latents, height, width)
|
||||
latent_image_ids = _prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
if image.shape[1] != latent_channels:
|
||||
image_latents = _encode_vae_image(image=image, generator=generator)
|
||||
else:
|
||||
image_latents = image
|
||||
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
|
||||
# expand init_latents for batch_size
|
||||
additional_image_per_prompt = batch_size // image_latents.shape[0]
|
||||
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
|
||||
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
|
||||
raise ValueError(
|
||||
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
|
||||
)
|
||||
else:
|
||||
image_latents = torch.cat([image_latents], dim=0)
|
||||
|
||||
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = scheduler.scale_noise(image_latents, timestep, noise)
|
||||
latents = _pack_latents(latents, batch_size, num_channels_latents, height, width)
|
||||
return latents, latent_image_ids
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
@@ -160,6 +118,7 @@ def retrieve_latents(
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
# TODO: align this with Qwen patchifier
|
||||
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
||||
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
||||
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
||||
@@ -168,35 +127,6 @@ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
||||
return latents
|
||||
|
||||
|
||||
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
||||
latent_image_ids = torch.zeros(height, width, 3)
|
||||
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
|
||||
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
|
||||
|
||||
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
||||
|
||||
latent_image_ids = latent_image_ids.reshape(
|
||||
latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
||||
)
|
||||
|
||||
return latent_image_ids.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
# Cannot use "# Copied from" because it introduces weird indentation errors.
|
||||
def _encode_vae_image(vae, image: torch.Tensor, generator: torch.Generator):
|
||||
if isinstance(generator, list):
|
||||
image_latents = [
|
||||
retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(image.shape[0])
|
||||
]
|
||||
image_latents = torch.cat(image_latents, dim=0)
|
||||
else:
|
||||
image_latents = retrieve_latents(vae.encode(image), generator=generator)
|
||||
|
||||
image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor
|
||||
|
||||
return image_latents
|
||||
|
||||
|
||||
def _get_initial_timesteps_and_optionals(
|
||||
transformer,
|
||||
scheduler,
|
||||
@@ -231,96 +161,6 @@ def _get_initial_timesteps_and_optionals(
|
||||
return timesteps, num_inference_steps, sigmas, guidance
|
||||
|
||||
|
||||
class FluxInputStep(ModularPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Input processing step that:\n"
|
||||
" 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n"
|
||||
" 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt`\n\n"
|
||||
"All input tensors are expected to have either batch_size=1 or match the batch_size\n"
|
||||
"of prompt_embeds. The tensors will be duplicated across the batch dimension to\n"
|
||||
"have a final batch_size of batch_size * num_images_per_prompt."
|
||||
)
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
InputParam(
|
||||
"prompt_embeds",
|
||||
required=True,
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=torch.Tensor,
|
||||
description="Pre-generated text embeddings. Can be generated from text_encoder step.",
|
||||
),
|
||||
InputParam(
|
||||
"pooled_prompt_embeds",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=torch.Tensor,
|
||||
description="Pre-generated pooled text embeddings. Can be generated from text_encoder step.",
|
||||
),
|
||||
# TODO: support negative embeddings?
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[str]:
|
||||
return [
|
||||
OutputParam(
|
||||
"batch_size",
|
||||
type_hint=int,
|
||||
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt",
|
||||
),
|
||||
OutputParam(
|
||||
"dtype",
|
||||
type_hint=torch.dtype,
|
||||
description="Data type of model tensor inputs (determined by `prompt_embeds`)",
|
||||
),
|
||||
OutputParam(
|
||||
"prompt_embeds",
|
||||
type_hint=torch.Tensor,
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description="text embeddings used to guide the image generation",
|
||||
),
|
||||
OutputParam(
|
||||
"pooled_prompt_embeds",
|
||||
type_hint=torch.Tensor,
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description="pooled text embeddings used to guide the image generation",
|
||||
),
|
||||
# TODO: support negative embeddings?
|
||||
]
|
||||
|
||||
def check_inputs(self, components, block_state):
|
||||
if block_state.prompt_embeds is not None and block_state.pooled_prompt_embeds is not None:
|
||||
if block_state.prompt_embeds.shape[0] != block_state.pooled_prompt_embeds.shape[0]:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `pooled_prompt_embeds` must have the same batch size when passed directly, but"
|
||||
f" got: `prompt_embeds` {block_state.prompt_embeds.shape} != `pooled_prompt_embeds`"
|
||||
f" {block_state.pooled_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
|
||||
# TODO: consider adding negative embeddings?
|
||||
block_state = self.get_block_state(state)
|
||||
self.check_inputs(components, block_state)
|
||||
|
||||
block_state.batch_size = block_state.prompt_embeds.shape[0]
|
||||
block_state.dtype = block_state.prompt_embeds.dtype
|
||||
|
||||
_, seq_len, _ = block_state.prompt_embeds.shape
|
||||
block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1)
|
||||
block_state.prompt_embeds = block_state.prompt_embeds.view(
|
||||
block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
|
||||
)
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
|
||||
class FluxSetTimestepsStep(ModularPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
@@ -389,6 +229,10 @@ class FluxSetTimestepsStep(ModularPipelineBlocks):
|
||||
block_state.sigmas = sigmas
|
||||
block_state.guidance = guidance
|
||||
|
||||
# We set the index here to remove DtoH sync, helpful especially during compilation.
|
||||
# Check out more details here: https://github.com/huggingface/diffusers/pull/11696
|
||||
components.scheduler.set_begin_index(0)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
@@ -432,11 +276,6 @@ class FluxImg2ImgSetTimestepsStep(ModularPipelineBlocks):
|
||||
type_hint=int,
|
||||
description="The number of denoising steps to perform at inference time",
|
||||
),
|
||||
OutputParam(
|
||||
"latent_timestep",
|
||||
type_hint=torch.Tensor,
|
||||
description="The timestep that represents the initial noise level for image-to-image generation",
|
||||
),
|
||||
OutputParam("guidance", type_hint=torch.Tensor, description="Optional guidance to be used."),
|
||||
]
|
||||
|
||||
@@ -484,8 +323,6 @@ class FluxImg2ImgSetTimestepsStep(ModularPipelineBlocks):
|
||||
block_state.sigmas = sigmas
|
||||
block_state.guidance = guidance
|
||||
|
||||
block_state.latent_timestep = timesteps[:1].repeat(batch_size)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
@@ -524,11 +361,6 @@ class FluxPrepareLatentsStep(ModularPipelineBlocks):
|
||||
OutputParam(
|
||||
"latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
|
||||
),
|
||||
OutputParam(
|
||||
"latent_image_ids",
|
||||
type_hint=torch.Tensor,
|
||||
description="IDs computed from the image sequence needed for RoPE",
|
||||
),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
@@ -552,20 +384,13 @@ class FluxPrepareLatentsStep(ModularPipelineBlocks):
|
||||
generator,
|
||||
latents=None,
|
||||
):
|
||||
# Couldn't use the `prepare_latents` method directly from Flux because I decided to copy over
|
||||
# the packing methods here. So, for example, `comp._pack_latents()` won't work if we were
|
||||
# to go with the "# Copied from ..." approach. Or maybe there's a way?
|
||||
|
||||
# VAE applies 8x compression on images but we must also account for packing which requires
|
||||
# latent height and width to be divisible by 2.
|
||||
height = 2 * (int(height) // (comp.vae_scale_factor * 2))
|
||||
width = 2 * (int(width) // (comp.vae_scale_factor * 2))
|
||||
|
||||
shape = (batch_size, num_channels_latents, height, width)
|
||||
|
||||
if latents is not None:
|
||||
latent_image_ids = _prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
|
||||
return latents.to(device=device, dtype=dtype), latent_image_ids
|
||||
return latents.to(device=device, dtype=dtype)
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
@@ -573,12 +398,11 @@ class FluxPrepareLatentsStep(ModularPipelineBlocks):
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
# TODO: move packing latents code to a patchifier
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = _pack_latents(latents, batch_size, num_channels_latents, height, width)
|
||||
|
||||
latent_image_ids = _prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
|
||||
|
||||
return latents, latent_image_ids
|
||||
return latents
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
|
||||
@@ -587,12 +411,11 @@ class FluxPrepareLatentsStep(ModularPipelineBlocks):
|
||||
block_state.height = block_state.height or components.default_height
|
||||
block_state.width = block_state.width or components.default_width
|
||||
block_state.device = components._execution_device
|
||||
block_state.dtype = torch.bfloat16 # TODO: okay to hardcode this?
|
||||
block_state.num_channels_latents = components.num_channels_latents
|
||||
|
||||
self.check_inputs(components, block_state)
|
||||
batch_size = block_state.batch_size * block_state.num_images_per_prompt
|
||||
block_state.latents, block_state.latent_image_ids = self.prepare_latents(
|
||||
block_state.latents = self.prepare_latents(
|
||||
components,
|
||||
batch_size,
|
||||
block_state.num_channels_latents,
|
||||
@@ -612,82 +435,124 @@ class FluxPrepareLatentsStep(ModularPipelineBlocks):
|
||||
class FluxImg2ImgPrepareLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [ComponentSpec("vae", AutoencoderKL), ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Step that prepares the latents for the image-to-image generation process"
|
||||
return "Step that adds noise to image latents for image-to-image. Should be run after `set_timesteps`,"
|
||||
" `prepare_latents`. Both noise and image latents should already be patchified."
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("height", type_hint=int),
|
||||
InputParam("width", type_hint=int),
|
||||
InputParam("latents", type_hint=Optional[torch.Tensor]),
|
||||
InputParam("num_images_per_prompt", type_hint=int, default=1),
|
||||
InputParam("generator"),
|
||||
InputParam(
|
||||
"image_latents",
|
||||
name="latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step.",
|
||||
description="The initial random noised, can be generated in prepare latent step.",
|
||||
),
|
||||
InputParam(
|
||||
"latent_timestep",
|
||||
name="image_latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step.",
|
||||
description="The image latents to use for the denoising process. Can be generated in vae encoder and packed in input step.",
|
||||
),
|
||||
InputParam(
|
||||
"batch_size",
|
||||
name="timesteps",
|
||||
required=True,
|
||||
type_hint=int,
|
||||
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
|
||||
type_hint=torch.Tensor,
|
||||
description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
|
||||
),
|
||||
InputParam("dtype", required=True, type_hint=torch.dtype, description="The dtype of the model inputs"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
"latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
|
||||
),
|
||||
OutputParam(
|
||||
"latent_image_ids",
|
||||
name="initial_noise",
|
||||
type_hint=torch.Tensor,
|
||||
description="IDs computed from the image sequence needed for RoPE",
|
||||
description="The initial random noised used for inpainting denoising.",
|
||||
),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def check_inputs(image_latents, latents):
|
||||
if image_latents.shape[0] != latents.shape[0]:
|
||||
raise ValueError(
|
||||
f"`image_latents` must have have same batch size as `latents`, but got {image_latents.shape[0]} and {latents.shape[0]}"
|
||||
)
|
||||
|
||||
if image_latents.ndim != 3:
|
||||
raise ValueError(f"`image_latents` must have 3 dimensions (patchified), but got {image_latents.ndim}")
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
block_state.device = components._execution_device
|
||||
block_state.dtype = torch.bfloat16 # TODO: okay to hardcode this?
|
||||
block_state.num_channels_latents = components.num_channels_latents
|
||||
block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
|
||||
block_state.device = components._execution_device
|
||||
self.check_inputs(image_latents=block_state.image_latents, latents=block_state.latents)
|
||||
|
||||
# TODO: implement `check_inputs`
|
||||
batch_size = block_state.batch_size * block_state.num_images_per_prompt
|
||||
if block_state.latents is None:
|
||||
block_state.latents, block_state.latent_image_ids = prepare_latents_img2img(
|
||||
components.vae,
|
||||
components.scheduler,
|
||||
block_state.image_latents,
|
||||
block_state.latent_timestep,
|
||||
batch_size,
|
||||
block_state.num_channels_latents,
|
||||
block_state.height,
|
||||
block_state.width,
|
||||
block_state.dtype,
|
||||
block_state.device,
|
||||
block_state.generator,
|
||||
)
|
||||
# prepare latent timestep
|
||||
latent_timestep = block_state.timesteps[:1].repeat(block_state.latents.shape[0])
|
||||
|
||||
# make copy of initial_noise
|
||||
block_state.initial_noise = block_state.latents
|
||||
|
||||
# scale noise
|
||||
block_state.latents = components.scheduler.scale_noise(
|
||||
block_state.image_latents, latent_timestep, block_state.latents
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
|
||||
class FluxRoPEInputsStep(ModularPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Step that prepares the RoPE inputs for the denoising process. Should be placed after text encoder and latent preparation steps."
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(name="height", required=True),
|
||||
InputParam(name="width", required=True),
|
||||
InputParam(name="prompt_embeds"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
name="txt_ids",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=List[int],
|
||||
description="The sequence lengths of the prompt embeds, used for RoPE calculation.",
|
||||
),
|
||||
OutputParam(
|
||||
name="img_ids",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=List[int],
|
||||
description="The sequence lengths of the image latents, used for RoPE calculation.",
|
||||
),
|
||||
]
|
||||
|
||||
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
prompt_embeds = block_state.prompt_embeds
|
||||
device, dtype = prompt_embeds.device, prompt_embeds.dtype
|
||||
block_state.txt_ids = torch.zeros(prompt_embeds.shape[1], 3).to(
|
||||
device=prompt_embeds.device, dtype=prompt_embeds.dtype
|
||||
)
|
||||
|
||||
height = 2 * (int(block_state.height) // (components.vae_scale_factor * 2))
|
||||
width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2))
|
||||
block_state.img_ids = FluxPipeline._prepare_latent_image_ids(None, height // 2, width // 2, device, dtype)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
|
||||
@@ -76,18 +76,17 @@ class FluxLoopDenoiser(ModularPipelineBlocks):
|
||||
description="Pooled prompt embeddings",
|
||||
),
|
||||
InputParam(
|
||||
"text_ids",
|
||||
"txt_ids",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="IDs computed from text sequence needed for RoPE",
|
||||
),
|
||||
InputParam(
|
||||
"latent_image_ids",
|
||||
"img_ids",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="IDs computed from image sequence needed for RoPE",
|
||||
),
|
||||
# TODO: guidance
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -101,8 +100,8 @@ class FluxLoopDenoiser(ModularPipelineBlocks):
|
||||
encoder_hidden_states=block_state.prompt_embeds,
|
||||
pooled_projections=block_state.pooled_prompt_embeds,
|
||||
joint_attention_kwargs=block_state.joint_attention_kwargs,
|
||||
txt_ids=block_state.text_ids,
|
||||
img_ids=block_state.latent_image_ids,
|
||||
txt_ids=block_state.txt_ids,
|
||||
img_ids=block_state.img_ids,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
block_state.noise_pred = noise_pred
|
||||
@@ -195,9 +194,6 @@ class FluxDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
|
||||
block_state.num_warmup_steps = max(
|
||||
len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0
|
||||
)
|
||||
# We set the index here to remove DtoH sync, helpful especially during compilation.
|
||||
# Check out more details here: https://github.com/huggingface/diffusers/pull/11696
|
||||
components.scheduler.set_begin_index(0)
|
||||
with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(block_state.timesteps):
|
||||
components, block_state = self.loop_step(components, block_state, i=i, t=t)
|
||||
|
||||
@@ -25,7 +25,7 @@ from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL
|
||||
from ...utils import USE_PEFT_BACKEND, is_ftfy_available, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import FluxModularPipeline
|
||||
|
||||
|
||||
@@ -67,89 +67,148 @@ def retrieve_latents(
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
class FluxVaeEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "flux"
|
||||
def encode_vae_image(vae: AutoencoderKL, image: torch.Tensor, generator: torch.Generator, sample_mode="sample"):
|
||||
if isinstance(generator, list):
|
||||
image_latents = [
|
||||
retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i], sample_mode=sample_mode)
|
||||
for i in range(image.shape[0])
|
||||
]
|
||||
image_latents = torch.cat(image_latents, dim=0)
|
||||
else:
|
||||
image_latents = retrieve_latents(vae.encode(image), generator=generator, sample_mode=sample_mode)
|
||||
|
||||
image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor
|
||||
|
||||
return image_latents
|
||||
|
||||
|
||||
class FluxProcessImagesInputStep(ModularPipelineBlocks):
|
||||
model_name = "Flux"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Vae Encoder step that encode the input image into a latent representation"
|
||||
return "Image Preprocess step. Resizing is needed in Flux Kontext (will be implemented later.)"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("vae", AutoencoderKL),
|
||||
ComponentSpec(
|
||||
"image_processor",
|
||||
VaeImageProcessor,
|
||||
config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 16}),
|
||||
config=FrozenDict({"vae_scale_factor": 16}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [InputParam("resized_image"), InputParam("image"), InputParam("height"), InputParam("width")]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
InputParam("image", required=True),
|
||||
InputParam("height"),
|
||||
InputParam("width"),
|
||||
InputParam("generator"),
|
||||
InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
|
||||
InputParam(
|
||||
"preprocess_kwargs",
|
||||
type_hint=Optional[dict],
|
||||
description="A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]",
|
||||
),
|
||||
OutputParam(name="processed_image"),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def check_inputs(height, width, vae_scale_factor):
|
||||
if height is not None and height % (vae_scale_factor * 2) != 0:
|
||||
raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}")
|
||||
|
||||
if width is not None and width % (vae_scale_factor * 2) != 0:
|
||||
raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}")
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: FluxModularPipeline, state: PipelineState):
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
if block_state.resized_image is None and block_state.image is None:
|
||||
raise ValueError("`resized_image` and `image` cannot be None at the same time")
|
||||
|
||||
if block_state.resized_image is None:
|
||||
image = block_state.image
|
||||
self.check_inputs(
|
||||
height=block_state.height, width=block_state.width, vae_scale_factor=components.vae_scale_factor
|
||||
)
|
||||
height = block_state.height or components.default_height
|
||||
width = block_state.width or components.default_width
|
||||
else:
|
||||
width, height = block_state.resized_image[0].size
|
||||
image = block_state.resized_image
|
||||
|
||||
block_state.processed_image = components.image_processor.preprocess(image=image, height=height, width=width)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class FluxVaeEncoderDynamicStep(ModularPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_name: str = "processed_image",
|
||||
output_name: str = "image_latents",
|
||||
):
|
||||
"""Initialize a VAE encoder step for converting images to latent representations.
|
||||
|
||||
Both the input and output names are configurable so this block can be configured to process to different image
|
||||
inputs (e.g., "processed_image" -> "image_latents", "processed_control_image" -> "control_image_latents").
|
||||
|
||||
Args:
|
||||
input_name (str, optional): Name of the input image tensor. Defaults to "processed_image".
|
||||
Examples: "processed_image" or "processed_control_image"
|
||||
output_name (str, optional): Name of the output latent tensor. Defaults to "image_latents".
|
||||
Examples: "image_latents" or "control_image_latents"
|
||||
|
||||
Examples:
|
||||
# Basic usage with default settings (includes image processor): # FluxImageVaeEncoderDynamicStep()
|
||||
|
||||
# Custom input/output names for control image: # FluxImageVaeEncoderDynamicStep(
|
||||
input_name="processed_control_image", output_name="control_image_latents"
|
||||
)
|
||||
"""
|
||||
self._image_input_name = input_name
|
||||
self._image_latents_output_name = output_name
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return f"Dynamic VAE Encoder step that converts {self._image_input_name} into latent representations {self._image_latents_output_name}.\n"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
components = [ComponentSpec("vae", AutoencoderKL)]
|
||||
return components
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
inputs = [InputParam(self._image_input_name, required=True), InputParam("generator")]
|
||||
return inputs
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
"image_latents",
|
||||
self._image_latents_output_name,
|
||||
type_hint=torch.Tensor,
|
||||
description="The latents representing the reference image for image-to-image/inpainting generation",
|
||||
description="The latents representing the reference image",
|
||||
)
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image with self.vae->vae
|
||||
def _encode_vae_image(vae, image: torch.Tensor, generator: torch.Generator):
|
||||
if isinstance(generator, list):
|
||||
image_latents = [
|
||||
retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(image.shape[0])
|
||||
]
|
||||
image_latents = torch.cat(image_latents, dim=0)
|
||||
else:
|
||||
image_latents = retrieve_latents(vae.encode(image), generator=generator)
|
||||
|
||||
image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor
|
||||
|
||||
return image_latents
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
block_state.preprocess_kwargs = block_state.preprocess_kwargs or {}
|
||||
block_state.device = components._execution_device
|
||||
block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
|
||||
|
||||
block_state.image = components.image_processor.preprocess(
|
||||
block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs
|
||||
)
|
||||
block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype)
|
||||
device = components._execution_device
|
||||
dtype = components.vae.dtype
|
||||
|
||||
block_state.batch_size = block_state.image.shape[0]
|
||||
image = getattr(block_state, self._image_input_name)
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
|
||||
# if generator is a list, make sure the length of it matches the length of images (both should be batch_size)
|
||||
if isinstance(block_state.generator, list) and len(block_state.generator) != block_state.batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch"
|
||||
f" size of {block_state.batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
block_state.image_latents = self._encode_vae_image(
|
||||
components.vae, image=block_state.image, generator=block_state.generator
|
||||
)
|
||||
# Encode image into latents
|
||||
image_latents = encode_vae_image(image=image, vae=components.vae, generator=block_state.generator)
|
||||
setattr(block_state, self._image_latents_output_name, image_latents)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
@@ -161,7 +220,7 @@ class FluxTextEncoderStep(ModularPipelineBlocks):
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Text Encoder step that generate text_embeddings to guide the video generation"
|
||||
return "Text Encoder step that generate text_embeddings to guide the image generation"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
@@ -172,10 +231,6 @@ class FluxTextEncoderStep(ModularPipelineBlocks):
|
||||
ComponentSpec("tokenizer_2", T5TokenizerFast),
|
||||
]
|
||||
|
||||
@property
|
||||
def expected_configs(self) -> List[ConfigSpec]:
|
||||
return []
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
@@ -200,12 +255,6 @@ class FluxTextEncoderStep(ModularPipelineBlocks):
|
||||
type_hint=torch.Tensor,
|
||||
description="pooled text embeddings used to guide the image generation",
|
||||
),
|
||||
OutputParam(
|
||||
"text_ids",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=torch.Tensor,
|
||||
description="ids from the text sequence for RoPE",
|
||||
),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
@@ -216,16 +265,10 @@ class FluxTextEncoderStep(ModularPipelineBlocks):
|
||||
|
||||
@staticmethod
|
||||
def _get_t5_prompt_embeds(
|
||||
components,
|
||||
prompt: Union[str, List[str]],
|
||||
num_images_per_prompt: int,
|
||||
max_sequence_length: int,
|
||||
device: torch.device,
|
||||
components, prompt: Union[str, List[str]], max_sequence_length: int, device: torch.device
|
||||
):
|
||||
dtype = components.text_encoder_2.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
if isinstance(components, TextualInversionLoaderMixin):
|
||||
prompt = components.maybe_convert_prompt(prompt, components.tokenizer_2)
|
||||
@@ -251,23 +294,11 @@ class FluxTextEncoderStep(ModularPipelineBlocks):
|
||||
|
||||
prompt_embeds = components.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
|
||||
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
@staticmethod
|
||||
def _get_clip_prompt_embeds(
|
||||
components,
|
||||
prompt: Union[str, List[str]],
|
||||
num_images_per_prompt: int,
|
||||
device: torch.device,
|
||||
):
|
||||
def _get_clip_prompt_embeds(components, prompt: Union[str, List[str]], device: torch.device):
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
if isinstance(components, TextualInversionLoaderMixin):
|
||||
prompt = components.maybe_convert_prompt(prompt, components.tokenizer)
|
||||
@@ -297,10 +328,6 @@ class FluxTextEncoderStep(ModularPipelineBlocks):
|
||||
prompt_embeds = prompt_embeds.pooler_output
|
||||
prompt_embeds = prompt_embeds.to(dtype=components.text_encoder.dtype, device=device)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
@staticmethod
|
||||
@@ -309,34 +336,11 @@ class FluxTextEncoderStep(ModularPipelineBlocks):
|
||||
prompt: Union[str, List[str]],
|
||||
prompt_2: Union[str, List[str]],
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
max_sequence_length: int = 512,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
prompt_2 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
||||
used in all text-encoders
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
||||
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
||||
lora_scale (`float`, *optional*):
|
||||
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
||||
"""
|
||||
device = device or components._execution_device
|
||||
|
||||
# set lora scale so that monkey patched LoRA
|
||||
@@ -361,12 +365,10 @@ class FluxTextEncoderStep(ModularPipelineBlocks):
|
||||
components,
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
)
|
||||
prompt_embeds = FluxTextEncoderStep._get_t5_prompt_embeds(
|
||||
components,
|
||||
prompt=prompt_2,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
@@ -381,10 +383,7 @@ class FluxTextEncoderStep(ModularPipelineBlocks):
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(components.text_encoder_2, lora_scale)
|
||||
|
||||
dtype = components.text_encoder.dtype if components.text_encoder is not None else torch.bfloat16
|
||||
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds, text_ids
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
|
||||
@@ -400,14 +399,13 @@ class FluxTextEncoderStep(ModularPipelineBlocks):
|
||||
if block_state.joint_attention_kwargs is not None
|
||||
else None
|
||||
)
|
||||
(block_state.prompt_embeds, block_state.pooled_prompt_embeds, block_state.text_ids) = self.encode_prompt(
|
||||
block_state.prompt_embeds, block_state.pooled_prompt_embeds = self.encode_prompt(
|
||||
components,
|
||||
prompt=block_state.prompt,
|
||||
prompt_2=None,
|
||||
prompt_embeds=None,
|
||||
pooled_prompt_embeds=None,
|
||||
device=block_state.device,
|
||||
num_images_per_prompt=1, # TODO: hardcoded for now.
|
||||
max_sequence_length=block_state.max_sequence_length,
|
||||
lora_scale=block_state.text_encoder_lora_scale,
|
||||
)
|
||||
|
||||
236
src/diffusers/modular_pipelines/flux/inputs.py
Normal file
236
src/diffusers/modular_pipelines/flux/inputs.py
Normal file
@@ -0,0 +1,236 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
from ...pipelines import FluxPipeline
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import InputParam, OutputParam
|
||||
|
||||
# TODO: consider making these common utilities for modular if they are not pipeline-specific.
|
||||
from ..qwenimage.inputs import calculate_dimension_from_latents, repeat_tensor_to_batch_size
|
||||
from .modular_pipeline import FluxModularPipeline
|
||||
|
||||
|
||||
class FluxTextInputStep(ModularPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Text input processing step that standardizes text embeddings for the pipeline.\n"
|
||||
"This step:\n"
|
||||
" 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n"
|
||||
" 2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt)"
|
||||
)
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
InputParam(
|
||||
"prompt_embeds",
|
||||
required=True,
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=torch.Tensor,
|
||||
description="Pre-generated text embeddings. Can be generated from text_encoder step.",
|
||||
),
|
||||
InputParam(
|
||||
"pooled_prompt_embeds",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=torch.Tensor,
|
||||
description="Pre-generated pooled text embeddings. Can be generated from text_encoder step.",
|
||||
),
|
||||
# TODO: support negative embeddings?
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[str]:
|
||||
return [
|
||||
OutputParam(
|
||||
"batch_size",
|
||||
type_hint=int,
|
||||
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt",
|
||||
),
|
||||
OutputParam(
|
||||
"dtype",
|
||||
type_hint=torch.dtype,
|
||||
description="Data type of model tensor inputs (determined by `prompt_embeds`)",
|
||||
),
|
||||
OutputParam(
|
||||
"prompt_embeds",
|
||||
type_hint=torch.Tensor,
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description="text embeddings used to guide the image generation",
|
||||
),
|
||||
OutputParam(
|
||||
"pooled_prompt_embeds",
|
||||
type_hint=torch.Tensor,
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description="pooled text embeddings used to guide the image generation",
|
||||
),
|
||||
# TODO: support negative embeddings?
|
||||
]
|
||||
|
||||
def check_inputs(self, components, block_state):
|
||||
if block_state.prompt_embeds is not None and block_state.pooled_prompt_embeds is not None:
|
||||
if block_state.prompt_embeds.shape[0] != block_state.pooled_prompt_embeds.shape[0]:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `pooled_prompt_embeds` must have the same batch size when passed directly, but"
|
||||
f" got: `prompt_embeds` {block_state.prompt_embeds.shape} != `pooled_prompt_embeds`"
|
||||
f" {block_state.pooled_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
|
||||
# TODO: consider adding negative embeddings?
|
||||
block_state = self.get_block_state(state)
|
||||
self.check_inputs(components, block_state)
|
||||
|
||||
block_state.batch_size = block_state.prompt_embeds.shape[0]
|
||||
block_state.dtype = block_state.prompt_embeds.dtype
|
||||
|
||||
_, seq_len, _ = block_state.prompt_embeds.shape
|
||||
block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1)
|
||||
block_state.prompt_embeds = block_state.prompt_embeds.view(
|
||||
block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
|
||||
)
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
|
||||
# Adapted from `QwenImageInputsDynamicStep`
|
||||
class FluxInputsDynamicStep(ModularPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_latent_inputs: List[str] = ["image_latents"],
|
||||
additional_batch_inputs: List[str] = [],
|
||||
):
|
||||
if not isinstance(image_latent_inputs, list):
|
||||
image_latent_inputs = [image_latent_inputs]
|
||||
if not isinstance(additional_batch_inputs, list):
|
||||
additional_batch_inputs = [additional_batch_inputs]
|
||||
|
||||
self._image_latent_inputs = image_latent_inputs
|
||||
self._additional_batch_inputs = additional_batch_inputs
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
# Functionality section
|
||||
summary_section = (
|
||||
"Input processing step that:\n"
|
||||
" 1. For image latent inputs: Updates height/width if None, patchifies latents, and expands batch size\n"
|
||||
" 2. For additional batch inputs: Expands batch dimensions to match final batch size"
|
||||
)
|
||||
|
||||
# Inputs info
|
||||
inputs_info = ""
|
||||
if self._image_latent_inputs or self._additional_batch_inputs:
|
||||
inputs_info = "\n\nConfigured inputs:"
|
||||
if self._image_latent_inputs:
|
||||
inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}"
|
||||
if self._additional_batch_inputs:
|
||||
inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}"
|
||||
|
||||
# Placement guidance
|
||||
placement_section = "\n\nThis block should be placed after the encoder steps and the text input step."
|
||||
|
||||
return summary_section + inputs_info + placement_section
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
inputs = [
|
||||
InputParam(name="num_images_per_prompt", default=1),
|
||||
InputParam(name="batch_size", required=True),
|
||||
InputParam(name="height"),
|
||||
InputParam(name="width"),
|
||||
]
|
||||
|
||||
# Add image latent inputs
|
||||
for image_latent_input_name in self._image_latent_inputs:
|
||||
inputs.append(InputParam(name=image_latent_input_name))
|
||||
|
||||
# Add additional batch inputs
|
||||
for input_name in self._additional_batch_inputs:
|
||||
inputs.append(InputParam(name=input_name))
|
||||
|
||||
return inputs
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(name="image_height", type_hint=int, description="The height of the image latents"),
|
||||
OutputParam(name="image_width", type_hint=int, description="The width of the image latents"),
|
||||
]
|
||||
|
||||
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
# Process image latent inputs (height/width calculation, patchify, and batch expansion)
|
||||
for image_latent_input_name in self._image_latent_inputs:
|
||||
image_latent_tensor = getattr(block_state, image_latent_input_name)
|
||||
if image_latent_tensor is None:
|
||||
continue
|
||||
|
||||
# 1. Calculate height/width from latents
|
||||
height, width = calculate_dimension_from_latents(image_latent_tensor, components.vae_scale_factor)
|
||||
block_state.height = block_state.height or height
|
||||
block_state.width = block_state.width or width
|
||||
|
||||
if not hasattr(block_state, "image_height"):
|
||||
block_state.image_height = height
|
||||
if not hasattr(block_state, "image_width"):
|
||||
block_state.image_width = width
|
||||
|
||||
# 2. Patchify the image latent tensor
|
||||
# TODO: Implement patchifier for Flux.
|
||||
latent_height, latent_width = image_latent_tensor.shape[2:]
|
||||
image_latent_tensor = FluxPipeline._pack_latents(
|
||||
image_latent_tensor, block_state.batch_size, image_latent_tensor.shape[1], latent_height, latent_width
|
||||
)
|
||||
|
||||
# 3. Expand batch size
|
||||
image_latent_tensor = repeat_tensor_to_batch_size(
|
||||
input_name=image_latent_input_name,
|
||||
input_tensor=image_latent_tensor,
|
||||
num_images_per_prompt=block_state.num_images_per_prompt,
|
||||
batch_size=block_state.batch_size,
|
||||
)
|
||||
|
||||
setattr(block_state, image_latent_input_name, image_latent_tensor)
|
||||
|
||||
# Process additional batch inputs (only batch expansion)
|
||||
for input_name in self._additional_batch_inputs:
|
||||
input_tensor = getattr(block_state, input_name)
|
||||
if input_tensor is None:
|
||||
continue
|
||||
|
||||
# Only expand batch size
|
||||
input_tensor = repeat_tensor_to_batch_size(
|
||||
input_name=input_name,
|
||||
input_tensor=input_tensor,
|
||||
num_images_per_prompt=block_state.num_images_per_prompt,
|
||||
batch_size=block_state.batch_size,
|
||||
)
|
||||
|
||||
setattr(block_state, input_name, input_tensor)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
@@ -18,21 +18,41 @@ from ..modular_pipeline_utils import InsertableDict
|
||||
from .before_denoise import (
|
||||
FluxImg2ImgPrepareLatentsStep,
|
||||
FluxImg2ImgSetTimestepsStep,
|
||||
FluxInputStep,
|
||||
FluxPrepareLatentsStep,
|
||||
FluxRoPEInputsStep,
|
||||
FluxSetTimestepsStep,
|
||||
)
|
||||
from .decoders import FluxDecodeStep
|
||||
from .denoise import FluxDenoiseStep
|
||||
from .encoders import FluxTextEncoderStep, FluxVaeEncoderStep
|
||||
from .encoders import FluxProcessImagesInputStep, FluxTextEncoderStep, FluxVaeEncoderDynamicStep
|
||||
from .inputs import FluxInputsDynamicStep, FluxTextInputStep
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# vae encoder (run before before_denoise)
|
||||
FluxImg2ImgVaeEncoderBlocks = InsertableDict(
|
||||
[
|
||||
("preprocess", FluxProcessImagesInputStep()),
|
||||
("encode", FluxVaeEncoderDynamicStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class FluxImg2ImgVaeEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
block_classes = FluxImg2ImgVaeEncoderBlocks.values()
|
||||
block_names = FluxImg2ImgVaeEncoderBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Vae encoder step that preprocess andencode the image inputs into their latent representations."
|
||||
|
||||
|
||||
class FluxAutoVaeEncoderStep(AutoPipelineBlocks):
|
||||
block_classes = [FluxVaeEncoderStep]
|
||||
block_classes = [FluxImg2ImgVaeEncoderStep]
|
||||
block_names = ["img2img"]
|
||||
block_trigger_inputs = ["image"]
|
||||
|
||||
@@ -41,45 +61,48 @@ class FluxAutoVaeEncoderStep(AutoPipelineBlocks):
|
||||
return (
|
||||
"Vae encoder step that encode the image inputs into their latent representations.\n"
|
||||
+ "This is an auto pipeline block that works for img2img tasks.\n"
|
||||
+ " - `FluxVaeEncoderStep` (img2img) is used when only `image` is provided."
|
||||
+ " - if `image` is provided, step will be skipped."
|
||||
+ " - `FluxImg2ImgVaeEncoderStep` (img2img) is used when only `image` is provided."
|
||||
+ " - if `image` is not provided, step will be skipped."
|
||||
)
|
||||
|
||||
|
||||
# before_denoise: text2img, img2img
|
||||
class FluxBeforeDenoiseStep(SequentialPipelineBlocks):
|
||||
block_classes = [
|
||||
FluxInputStep,
|
||||
FluxPrepareLatentsStep,
|
||||
FluxSetTimestepsStep,
|
||||
# before_denoise: text2img
|
||||
FluxBeforeDenoiseBlocks = InsertableDict(
|
||||
[
|
||||
("prepare_latents", FluxPrepareLatentsStep()),
|
||||
("set_timesteps", FluxSetTimestepsStep()),
|
||||
("prepare_rope_inputs", FluxRoPEInputsStep()),
|
||||
]
|
||||
block_names = ["input", "prepare_latents", "set_timesteps"]
|
||||
)
|
||||
|
||||
|
||||
class FluxBeforeDenoiseStep(SequentialPipelineBlocks):
|
||||
block_classes = FluxBeforeDenoiseBlocks.values()
|
||||
block_names = FluxBeforeDenoiseBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Before denoise step that prepare the inputs for the denoise step.\n"
|
||||
+ "This is a sequential pipeline blocks:\n"
|
||||
+ " - `FluxInputStep` is used to adjust the batch size of the model inputs\n"
|
||||
+ " - `FluxPrepareLatentsStep` is used to prepare the latents\n"
|
||||
+ " - `FluxSetTimestepsStep` is used to set the timesteps\n"
|
||||
)
|
||||
return "Before denoise step that prepares the inputs for the denoise step in text-to-image generation."
|
||||
|
||||
|
||||
# before_denoise: img2img
|
||||
FluxImg2ImgBeforeDenoiseBlocks = InsertableDict(
|
||||
[
|
||||
("prepare_latents", FluxPrepareLatentsStep()),
|
||||
("set_timesteps", FluxImg2ImgSetTimestepsStep()),
|
||||
("prepare_img2img_latents", FluxImg2ImgPrepareLatentsStep()),
|
||||
("prepare_rope_inputs", FluxRoPEInputsStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class FluxImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks):
|
||||
block_classes = [FluxInputStep, FluxImg2ImgSetTimestepsStep, FluxImg2ImgPrepareLatentsStep]
|
||||
block_names = ["input", "set_timesteps", "prepare_latents"]
|
||||
block_classes = FluxImg2ImgBeforeDenoiseBlocks.values()
|
||||
block_names = FluxImg2ImgBeforeDenoiseBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Before denoise step that prepare the inputs for the denoise step for img2img task.\n"
|
||||
+ "This is a sequential pipeline blocks:\n"
|
||||
+ " - `FluxInputStep` is used to adjust the batch size of the model inputs\n"
|
||||
+ " - `FluxImg2ImgSetTimestepsStep` is used to set the timesteps\n"
|
||||
+ " - `FluxImg2ImgPrepareLatentsStep` is used to prepare the latents\n"
|
||||
)
|
||||
return "Before denoise step that prepare the inputs for the denoise step for img2img task."
|
||||
|
||||
|
||||
# before_denoise: all task (text2img, img2img)
|
||||
@@ -113,7 +136,7 @@ class FluxAutoDenoiseStep(AutoPipelineBlocks):
|
||||
)
|
||||
|
||||
|
||||
# decode: all task (text2img, img2img, inpainting)
|
||||
# decode: all task (text2img, img2img)
|
||||
class FluxAutoDecodeStep(AutoPipelineBlocks):
|
||||
block_classes = [FluxDecodeStep]
|
||||
block_names = ["non-inpaint"]
|
||||
@@ -124,32 +147,73 @@ class FluxAutoDecodeStep(AutoPipelineBlocks):
|
||||
return "Decode step that decode the denoised latents into image outputs.\n - `FluxDecodeStep`"
|
||||
|
||||
|
||||
# inputs: text2image/img2img
|
||||
FluxImg2ImgBlocks = InsertableDict(
|
||||
[("text_inputs", FluxTextInputStep()), ("additional_inputs", FluxInputsDynamicStep())]
|
||||
)
|
||||
|
||||
|
||||
class FluxImg2ImgInputStep(SequentialPipelineBlocks):
|
||||
model_name = "flux"
|
||||
block_classes = FluxImg2ImgBlocks.values()
|
||||
block_names = FluxImg2ImgBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Input step that prepares the inputs for the img2img denoising step. It:\n"
|
||||
" - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents`).\n"
|
||||
" - update height/width based `image_latents`, patchify `image_latents`."
|
||||
|
||||
|
||||
class FluxImageAutoInputStep(AutoPipelineBlocks):
|
||||
block_classes = [FluxImg2ImgInputStep, FluxTextInputStep]
|
||||
block_names = ["img2img", "text2image"]
|
||||
block_trigger_inputs = ["image_latents", None]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Input step that standardize the inputs for the denoising step, e.g. make sure inputs have consistent batch size, and patchified. \n"
|
||||
" This is an auto pipeline block that works for text2image/img2img tasks.\n"
|
||||
+ " - `FluxImg2ImgInputStep` (img2img) is used when `image_latents` is provided.\n"
|
||||
+ " - `FluxTextInputStep` (text2image) is used when `image_latents` are not provided.\n"
|
||||
)
|
||||
|
||||
|
||||
class FluxCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
block_classes = [FluxInputStep, FluxAutoBeforeDenoiseStep, FluxAutoDenoiseStep]
|
||||
model_name = "flux"
|
||||
block_classes = [FluxImageAutoInputStep, FluxAutoBeforeDenoiseStep, FluxAutoDenoiseStep]
|
||||
block_names = ["input", "before_denoise", "denoise"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Core step that performs the denoising process. \n"
|
||||
+ " - `FluxInputStep` (input) standardizes the inputs for the denoising step.\n"
|
||||
+ " - `FluxImageAutoInputStep` (input) standardizes the inputs for the denoising step.\n"
|
||||
+ " - `FluxAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n"
|
||||
+ " - `FluxAutoDenoiseStep` (denoise) iteratively denoises the latents.\n"
|
||||
+ "This step support text-to-image and image-to-image tasks for Flux:\n"
|
||||
+ "This step supports text-to-image and image-to-image tasks for Flux:\n"
|
||||
+ " - for image-to-image generation, you need to provide `image_latents`\n"
|
||||
+ " - for text-to-image generation, all you need to provide is prompt embeddings"
|
||||
+ " - for text-to-image generation, all you need to provide is prompt embeddings."
|
||||
)
|
||||
|
||||
|
||||
# text2image
|
||||
class FluxAutoBlocks(SequentialPipelineBlocks):
|
||||
block_classes = [
|
||||
FluxTextEncoderStep,
|
||||
FluxAutoVaeEncoderStep,
|
||||
FluxCoreDenoiseStep,
|
||||
FluxAutoDecodeStep,
|
||||
# Auto blocks (text2image and img2img)
|
||||
AUTO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", FluxTextEncoderStep()),
|
||||
("image_encoder", FluxAutoVaeEncoderStep()),
|
||||
("denoise", FluxCoreDenoiseStep()),
|
||||
("decode", FluxDecodeStep()),
|
||||
]
|
||||
block_names = ["text_encoder", "image_encoder", "denoise", "decode"]
|
||||
)
|
||||
|
||||
|
||||
class FluxAutoBlocks(SequentialPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
block_classes = AUTO_BLOCKS.values()
|
||||
block_names = AUTO_BLOCKS.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
@@ -162,35 +226,28 @@ class FluxAutoBlocks(SequentialPipelineBlocks):
|
||||
|
||||
TEXT2IMAGE_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", FluxTextEncoderStep),
|
||||
("input", FluxInputStep),
|
||||
("prepare_latents", FluxPrepareLatentsStep),
|
||||
("set_timesteps", FluxSetTimestepsStep),
|
||||
("denoise", FluxDenoiseStep),
|
||||
("decode", FluxDecodeStep),
|
||||
("text_encoder", FluxTextEncoderStep()),
|
||||
("input", FluxTextInputStep()),
|
||||
("prepare_latents", FluxPrepareLatentsStep()),
|
||||
("set_timesteps", FluxSetTimestepsStep()),
|
||||
("prepare_rope_inputs", FluxRoPEInputsStep()),
|
||||
("denoise", FluxDenoiseStep()),
|
||||
("decode", FluxDecodeStep()),
|
||||
]
|
||||
)
|
||||
|
||||
IMAGE2IMAGE_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", FluxTextEncoderStep),
|
||||
("image_encoder", FluxVaeEncoderStep),
|
||||
("input", FluxInputStep),
|
||||
("set_timesteps", FluxImg2ImgSetTimestepsStep),
|
||||
("prepare_latents", FluxImg2ImgPrepareLatentsStep),
|
||||
("denoise", FluxDenoiseStep),
|
||||
("decode", FluxDecodeStep),
|
||||
("text_encoder", FluxTextEncoderStep()),
|
||||
("vae_encoder", FluxVaeEncoderDynamicStep()),
|
||||
("input", FluxImg2ImgInputStep()),
|
||||
("prepare_latents", FluxPrepareLatentsStep()),
|
||||
("set_timesteps", FluxImg2ImgSetTimestepsStep()),
|
||||
("prepare_img2img_latents", FluxImg2ImgPrepareLatentsStep()),
|
||||
("prepare_rope_inputs", FluxRoPEInputsStep()),
|
||||
("denoise", FluxDenoiseStep()),
|
||||
("decode", FluxDecodeStep()),
|
||||
]
|
||||
)
|
||||
|
||||
AUTO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", FluxTextEncoderStep),
|
||||
("image_encoder", FluxAutoVaeEncoderStep),
|
||||
("denoise", FluxCoreDenoiseStep),
|
||||
("decode", FluxAutoDecodeStep),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
ALL_BLOCKS = {"text2image": TEXT2IMAGE_BLOCKS, "img2img": IMAGE2IMAGE_BLOCKS, "auto": AUTO_BLOCKS}
|
||||
|
||||
@@ -838,6 +838,9 @@ def load_sub_model(
|
||||
else:
|
||||
loading_kwargs["low_cpu_mem_usage"] = False
|
||||
|
||||
if is_transformers_model and is_transformers_version(">=", "4.57.0"):
|
||||
loading_kwargs.pop("offload_state_dict")
|
||||
|
||||
if (
|
||||
quantization_config is not None
|
||||
and isinstance(quantization_config, PipelineQuantizationConfig)
|
||||
|
||||
@@ -21,6 +21,7 @@ import operator as op
|
||||
import os
|
||||
import sys
|
||||
from collections import OrderedDict, defaultdict
|
||||
from functools import lru_cache as cache
|
||||
from itertools import chain
|
||||
from types import ModuleType
|
||||
from typing import Any, Tuple, Union
|
||||
@@ -673,6 +674,7 @@ def compare_versions(library_or_version: Union[str, Version], operation: str, re
|
||||
|
||||
|
||||
# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L338
|
||||
@cache
|
||||
def is_torch_version(operation: str, version: str):
|
||||
"""
|
||||
Compares the current PyTorch version to a given reference with an operation.
|
||||
@@ -686,6 +688,7 @@ def is_torch_version(operation: str, version: str):
|
||||
return compare_versions(parse(_torch_version), operation, version)
|
||||
|
||||
|
||||
@cache
|
||||
def is_torch_xla_version(operation: str, version: str):
|
||||
"""
|
||||
Compares the current torch_xla version to a given reference with an operation.
|
||||
@@ -701,6 +704,7 @@ def is_torch_xla_version(operation: str, version: str):
|
||||
return compare_versions(parse(_torch_xla_version), operation, version)
|
||||
|
||||
|
||||
@cache
|
||||
def is_transformers_version(operation: str, version: str):
|
||||
"""
|
||||
Compares the current Transformers version to a given reference with an operation.
|
||||
@@ -716,6 +720,7 @@ def is_transformers_version(operation: str, version: str):
|
||||
return compare_versions(parse(_transformers_version), operation, version)
|
||||
|
||||
|
||||
@cache
|
||||
def is_hf_hub_version(operation: str, version: str):
|
||||
"""
|
||||
Compares the current Hugging Face Hub version to a given reference with an operation.
|
||||
@@ -731,6 +736,7 @@ def is_hf_hub_version(operation: str, version: str):
|
||||
return compare_versions(parse(_hf_hub_version), operation, version)
|
||||
|
||||
|
||||
@cache
|
||||
def is_accelerate_version(operation: str, version: str):
|
||||
"""
|
||||
Compares the current Accelerate version to a given reference with an operation.
|
||||
@@ -746,6 +752,7 @@ def is_accelerate_version(operation: str, version: str):
|
||||
return compare_versions(parse(_accelerate_version), operation, version)
|
||||
|
||||
|
||||
@cache
|
||||
def is_peft_version(operation: str, version: str):
|
||||
"""
|
||||
Compares the current PEFT version to a given reference with an operation.
|
||||
@@ -761,6 +768,7 @@ def is_peft_version(operation: str, version: str):
|
||||
return compare_versions(parse(_peft_version), operation, version)
|
||||
|
||||
|
||||
@cache
|
||||
def is_bitsandbytes_version(operation: str, version: str):
|
||||
"""
|
||||
Args:
|
||||
@@ -775,6 +783,7 @@ def is_bitsandbytes_version(operation: str, version: str):
|
||||
return compare_versions(parse(_bitsandbytes_version), operation, version)
|
||||
|
||||
|
||||
@cache
|
||||
def is_gguf_version(operation: str, version: str):
|
||||
"""
|
||||
Compares the current Accelerate version to a given reference with an operation.
|
||||
@@ -790,6 +799,7 @@ def is_gguf_version(operation: str, version: str):
|
||||
return compare_versions(parse(_gguf_version), operation, version)
|
||||
|
||||
|
||||
@cache
|
||||
def is_torchao_version(operation: str, version: str):
|
||||
"""
|
||||
Compares the current torchao version to a given reference with an operation.
|
||||
@@ -805,6 +815,7 @@ def is_torchao_version(operation: str, version: str):
|
||||
return compare_versions(parse(_torchao_version), operation, version)
|
||||
|
||||
|
||||
@cache
|
||||
def is_k_diffusion_version(operation: str, version: str):
|
||||
"""
|
||||
Compares the current k-diffusion version to a given reference with an operation.
|
||||
@@ -820,6 +831,7 @@ def is_k_diffusion_version(operation: str, version: str):
|
||||
return compare_versions(parse(_k_diffusion_version), operation, version)
|
||||
|
||||
|
||||
@cache
|
||||
def is_optimum_quanto_version(operation: str, version: str):
|
||||
"""
|
||||
Compares the current Accelerate version to a given reference with an operation.
|
||||
@@ -835,6 +847,7 @@ def is_optimum_quanto_version(operation: str, version: str):
|
||||
return compare_versions(parse(_optimum_quanto_version), operation, version)
|
||||
|
||||
|
||||
@cache
|
||||
def is_nvidia_modelopt_version(operation: str, version: str):
|
||||
"""
|
||||
Compares the current Nvidia ModelOpt version to a given reference with an operation.
|
||||
@@ -850,6 +863,7 @@ def is_nvidia_modelopt_version(operation: str, version: str):
|
||||
return compare_versions(parse(_nvidia_modelopt_version), operation, version)
|
||||
|
||||
|
||||
@cache
|
||||
def is_xformers_version(operation: str, version: str):
|
||||
"""
|
||||
Compares the current xformers version to a given reference with an operation.
|
||||
@@ -865,6 +879,7 @@ def is_xformers_version(operation: str, version: str):
|
||||
return compare_versions(parse(_xformers_version), operation, version)
|
||||
|
||||
|
||||
@cache
|
||||
def is_sageattention_version(operation: str, version: str):
|
||||
"""
|
||||
Compares the current sageattention version to a given reference with an operation.
|
||||
@@ -880,6 +895,7 @@ def is_sageattention_version(operation: str, version: str):
|
||||
return compare_versions(parse(_sageattention_version), operation, version)
|
||||
|
||||
|
||||
@cache
|
||||
def is_flash_attn_version(operation: str, version: str):
|
||||
"""
|
||||
Compares the current flash-attention version to a given reference with an operation.
|
||||
|
||||
@@ -1,147 +0,0 @@
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLLTXVideo,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
LTXPipeline,
|
||||
LTXVideoTransformer3DModel,
|
||||
)
|
||||
|
||||
from ..testing_utils import floats_tensor, require_peft_backend
|
||||
|
||||
|
||||
sys.path.append(".")
|
||||
|
||||
from .utils import PeftLoraLoaderMixinTests # noqa: E402
|
||||
|
||||
|
||||
@require_peft_backend
|
||||
class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
pipeline_class = LTXPipeline
|
||||
scheduler_cls = FlowMatchEulerDiscreteScheduler
|
||||
scheduler_kwargs = {}
|
||||
|
||||
transformer_kwargs = {
|
||||
"in_channels": 8,
|
||||
"out_channels": 8,
|
||||
"patch_size": 1,
|
||||
"patch_size_t": 1,
|
||||
"num_attention_heads": 4,
|
||||
"attention_head_dim": 8,
|
||||
"cross_attention_dim": 32,
|
||||
"num_layers": 1,
|
||||
"caption_channels": 32,
|
||||
}
|
||||
transformer_cls = LTXVideoTransformer3DModel
|
||||
vae_kwargs = {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 8,
|
||||
"block_out_channels": (8, 8, 8, 8),
|
||||
"decoder_block_out_channels": (8, 8, 8, 8),
|
||||
"layers_per_block": (1, 1, 1, 1, 1),
|
||||
"decoder_layers_per_block": (1, 1, 1, 1, 1),
|
||||
"spatio_temporal_scaling": (True, True, False, False),
|
||||
"decoder_spatio_temporal_scaling": (True, True, False, False),
|
||||
"decoder_inject_noise": (False, False, False, False, False),
|
||||
"upsample_residual": (False, False, False, False),
|
||||
"upsample_factor": (1, 1, 1, 1),
|
||||
"timestep_conditioning": False,
|
||||
"patch_size": 1,
|
||||
"patch_size_t": 1,
|
||||
"encoder_causal": True,
|
||||
"decoder_causal": False,
|
||||
}
|
||||
vae_cls = AutoencoderKLLTXVideo
|
||||
tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5"
|
||||
text_encoder_cls, text_encoder_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5"
|
||||
|
||||
text_encoder_target_modules = ["q", "k", "v", "o"]
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (1, 9, 32, 32, 3)
|
||||
|
||||
def get_dummy_inputs(self, with_generator=True):
|
||||
batch_size = 1
|
||||
sequence_length = 16
|
||||
num_channels = 8
|
||||
num_frames = 9
|
||||
num_latent_frames = 3 # (num_frames - 1) // temporal_compression_ratio + 1
|
||||
latent_height = 8
|
||||
latent_width = 8
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
noise = floats_tensor((batch_size, num_latent_frames, num_channels, latent_height, latent_width))
|
||||
input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
|
||||
|
||||
pipeline_inputs = {
|
||||
"prompt": "dance monkey",
|
||||
"num_frames": num_frames,
|
||||
"num_inference_steps": 4,
|
||||
"guidance_scale": 6.0,
|
||||
"height": 32,
|
||||
"width": 32,
|
||||
"max_sequence_length": sequence_length,
|
||||
"output_type": "np",
|
||||
}
|
||||
if with_generator:
|
||||
pipeline_inputs.update({"generator": generator})
|
||||
|
||||
return noise, input_ids, pipeline_inputs
|
||||
|
||||
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
|
||||
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
|
||||
|
||||
def test_simple_inference_with_text_denoiser_lora_unfused(self):
|
||||
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
|
||||
|
||||
@unittest.skip("Not supported in LTXVideo.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported in LTXVideo.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported in LTXVideo.")
|
||||
def test_modify_padding_mode(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
|
||||
def test_simple_inference_with_partial_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
|
||||
def test_simple_inference_with_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
|
||||
def test_simple_inference_with_text_lora_and_scale(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
|
||||
def test_simple_inference_with_text_lora_fused(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
|
||||
def test_simple_inference_with_text_lora_save_load(self):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user