Add OmniGen (#10148)

* OmniGen model.py

* update OmniGenTransformerModel

* omnigen pipeline

* omnigen pipeline

* update omnigen_pipeline

* test case for omnigen

* update omnigenpipeline

* update docs

* update docs

* offload_transformer

* enable_transformer_block_cpu_offload

* update docs

* reformat

* reformat

* reformat

* update docs

* update docs

* make style

* make style

* Update docs/source/en/api/models/omnigen_transformer.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/using-diffusers/omnigen.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/using-diffusers/omnigen.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* update docs

* revert changes to examples/

* update OmniGen2DModel

* make style

* update test cases

* Update docs/source/en/api/pipelines/omnigen.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/using-diffusers/omnigen.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/using-diffusers/omnigen.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/using-diffusers/omnigen.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/using-diffusers/omnigen.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/using-diffusers/omnigen.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* update docs

* typo

* Update src/diffusers/models/embeddings.py

Co-authored-by: hlky <hlky@hlky.ac>

* Update src/diffusers/models/attention.py

Co-authored-by: hlky <hlky@hlky.ac>

* Update src/diffusers/models/transformers/transformer_omnigen.py

Co-authored-by: hlky <hlky@hlky.ac>

* Update src/diffusers/models/transformers/transformer_omnigen.py

Co-authored-by: hlky <hlky@hlky.ac>

* Update src/diffusers/models/transformers/transformer_omnigen.py

Co-authored-by: hlky <hlky@hlky.ac>

* Update src/diffusers/pipelines/omnigen/pipeline_omnigen.py

Co-authored-by: hlky <hlky@hlky.ac>

* Update src/diffusers/pipelines/omnigen/pipeline_omnigen.py

Co-authored-by: hlky <hlky@hlky.ac>

* Update src/diffusers/pipelines/omnigen/pipeline_omnigen.py

Co-authored-by: hlky <hlky@hlky.ac>

* Update tests/pipelines/omnigen/test_pipeline_omnigen.py

Co-authored-by: hlky <hlky@hlky.ac>

* Update tests/pipelines/omnigen/test_pipeline_omnigen.py

Co-authored-by: hlky <hlky@hlky.ac>

* Update src/diffusers/pipelines/omnigen/pipeline_omnigen.py

Co-authored-by: hlky <hlky@hlky.ac>

* Update src/diffusers/pipelines/omnigen/pipeline_omnigen.py

Co-authored-by: hlky <hlky@hlky.ac>

* Update src/diffusers/pipelines/omnigen/pipeline_omnigen.py

Co-authored-by: hlky <hlky@hlky.ac>

* consistent attention processor

* update

* update

* check_inputs

* make style

* update testpipeline

* update testpipeline

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: hlky <hlky@hlky.ac>
Co-authored-by: Aryan <aryan@huggingface.co>
Commit 798e17187d (parent ed4b75229f), authored by Shitao Xiao and committed by GitHub on 2025-02-12 04:46:38 +08:00.
20 changed files with 2543 additions and 4 deletions.


@@ -89,6 +89,8 @@
title: Kandinsky
- local: using-diffusers/ip_adapter
title: IP-Adapter
- local: using-diffusers/omnigen
title: OmniGen
- local: using-diffusers/pag
title: PAG
- local: using-diffusers/controlnet
@@ -292,6 +294,8 @@
title: LTXVideoTransformer3DModel
- local: api/models/mochi_transformer3d
title: MochiTransformer3DModel
- local: api/models/omnigen_transformer
title: OmniGenTransformer2DModel
- local: api/models/pixart_transformer2d
title: PixArtTransformer2DModel
- local: api/models/prior_transformer
@@ -448,6 +452,8 @@
title: MultiDiffusion
- local: api/pipelines/musicldm
title: MusicLDM
- local: api/pipelines/omnigen
title: OmniGen
- local: api/pipelines/pag
title: PAG
- local: api/pipelines/paint_by_example


@@ -0,0 +1,19 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# OmniGenTransformer2DModel
A Transformer model that accepts multimodal instructions to generate images for [OmniGen](https://github.com/VectorSpaceLab/OmniGen/).
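For reference, a minimal sketch of loading only the transformer component is shown below; it assumes the converted weights live in a `transformer` subfolder of the `Shitao/OmniGen-v1-diffusers` repository, which is the standard layout produced when the full pipeline is saved.
```py
import torch
from diffusers import OmniGenTransformer2DModel

# Load just the transformer in bfloat16; the subfolder name is an assumption based on
# the usual diffusers pipeline layout.
transformer = OmniGenTransformer2DModel.from_pretrained(
    "Shitao/OmniGen-v1-diffusers", subfolder="transformer", torch_dtype=torch.bfloat16
)
```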
## OmniGenTransformer2DModel
[[autodoc]] OmniGenTransformer2DModel


@@ -0,0 +1,106 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-->
# OmniGen
[OmniGen: Unified Image Generation](https://arxiv.org/pdf/2409.11340) from BAAI, by Shitao Xiao, Yueze Wang, Junjie Zhou, Huaying Yuan, Xingrun Xing, Ruiran Yan, Chaofan Li, Shuting Wang, Tiejun Huang, Zheng Liu.
The abstract from the paper is:
*The emergence of Large Language Models (LLMs) has unified language generation tasks and revolutionized human-machine interaction. However, in the realm of image generation, a unified model capable of handling various tasks within a single framework remains largely unexplored. In this work, we introduce OmniGen, a new diffusion model for unified image generation. OmniGen is characterized by the following features: 1) Unification: OmniGen not only demonstrates text-to-image generation capabilities but also inherently supports various downstream tasks, such as image editing, subject-driven generation, and visual conditional generation. 2) Simplicity: The architecture of OmniGen is highly simplified, eliminating the need for additional plugins. Moreover, compared to existing diffusion models, it is more user-friendly and can complete complex tasks end-to-end through instructions without the need for extra intermediate steps, greatly simplifying the image generation workflow. 3) Knowledge Transfer: Benefiting from learning in a unified format, OmniGen effectively transfers knowledge across different tasks, manages unseen tasks and domains, and exhibits novel capabilities. We also explore the model's reasoning capabilities and potential applications of the chain-of-thought mechanism. This work represents the first attempt at a general-purpose image generation model, and we will release our resources at https://github.com/VectorSpaceLab/OmniGen to foster future advancements.*
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
This pipeline was contributed by [staoxiao](https://github.com/staoxiao). The original codebase can be found [here](https://github.com/VectorSpaceLab/OmniGen). The original weights can be found under [hf.co/shitao](https://huggingface.co/Shitao/OmniGen-v1).
## Inference
First, load the pipeline:
```python
import torch
from diffusers import OmniGenPipeline
pipe = OmniGenPipeline.from_pretrained(
"Shitao/OmniGen-v1-diffusers",
torch_dtype=torch.bfloat16
)
pipe.to("cuda")
```
For text-to-image, pass a text prompt. By default, OmniGen generates a 1024x1024 image.
You can try setting the `height` and `width` parameters to generate images of a different size.
```py
prompt = "Realistic photo. A young woman sits on a sofa, holding a book and facing the camera. She wears delicate silver hoop earrings adorned with tiny, sparkling diamonds that catch the light, with her long chestnut hair cascading over her shoulders. Her eyes are focused and gentle, framed by long, dark lashes. She is dressed in a cozy cream sweater, which complements her warm, inviting smile. Behind her, there is a table with a cup of water in a sleek, minimalist blue mug. The background is a serene indoor setting with soft natural light filtering through a window, adorned with tasteful art and flowers, creating a cozy and peaceful ambiance. 4K, HD."
image = pipe(
prompt=prompt,
height=1024,
width=1024,
guidance_scale=3,
generator=torch.Generator(device="cpu").manual_seed(111),
).images[0]
image
```
OmniGen supports multimodal inputs.
When the input includes an image, you need to add a placeholder `<img><|image_1|></img>` in the text prompt to represent the image.
It is recommended to enable `use_input_image_size_as_output` to keep the edited image the same size as the original image.
```py
prompt="<img><|image_1|></img> Remove the woman's earrings. Replace the mug with a clear glass filled with sparkling iced cola."
input_images=[load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/t2i_woman_with_book.png")]
image = pipe(
prompt=prompt,
input_images=input_images,
guidance_scale=2,
img_guidance_scale=1.6,
use_input_image_size_as_output=True,
generator=torch.Generator(device="cpu").manual_seed(222)).images[0]
image
```
## OmniGenPipeline
[[autodoc]] OmniGenPipeline
- all
- __call__


@@ -0,0 +1,314 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# OmniGen
OmniGen is an image generation model. Unlike existing text-to-image models, OmniGen is a single model designed to handle a variety of tasks (e.g., text-to-image, image editing, controllable generation). It has the following features:
- Minimalist model architecture, consisting of only a VAE and a transformer module, for joint modeling of text and images.
- Support for multimodal inputs. It can process any text-image mixed data as instructions for image generation, rather than relying solely on text.
For more information, please refer to the [paper](https://arxiv.org/pdf/2409.11340).
This guide will walk you through using OmniGen for various tasks and use cases.
## Load model checkpoints
Model weights may be stored in separate subfolders on the Hub or locally, in which case, you should use the [`~DiffusionPipeline.from_pretrained`] method.
```py
import torch
from diffusers import OmniGenPipeline
pipe = OmniGenPipeline.from_pretrained(
"Shitao/OmniGen-v1-diffusers",
torch_dtype=torch.bfloat16
)
```
## Text-to-image
For text-to-image, pass a text prompt. By default, OmniGen generates a 1024x1024 image.
You can try setting the `height` and `width` parameters to generate images of a different size.
```py
import torch
from diffusers import OmniGenPipeline
pipe = OmniGenPipeline.from_pretrained(
"Shitao/OmniGen-v1-diffusers",
torch_dtype=torch.bfloat16
)
pipe.to("cuda")
prompt = "Realistic photo. A young woman sits on a sofa, holding a book and facing the camera. She wears delicate silver hoop earrings adorned with tiny, sparkling diamonds that catch the light, with her long chestnut hair cascading over her shoulders. Her eyes are focused and gentle, framed by long, dark lashes. She is dressed in a cozy cream sweater, which complements her warm, inviting smile. Behind her, there is a table with a cup of water in a sleek, minimalist blue mug. The background is a serene indoor setting with soft natural light filtering through a window, adorned with tasteful art and flowers, creating a cozy and peaceful ambiance. 4K, HD."
image = pipe(
prompt=prompt,
height=1024,
width=1024,
guidance_scale=3,
generator=torch.Generator(device="cpu").manual_seed(111),
).images[0]
image
```
<div class="flex justify-center">
<img src="https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/t2i_woman_with_book.png" alt="generated image"/>
</div>
## Image editing
OmniGen supports multimodal inputs.
When the input includes an image, you need to add a placeholder `<img><|image_1|></img>` in the text prompt to represent the image.
It is recommended to enable `use_input_image_size_as_output` to keep the edited image the same size as the original image.
```py
import torch
from diffusers import OmniGenPipeline
from diffusers.utils import load_image
pipe = OmniGenPipeline.from_pretrained(
"Shitao/OmniGen-v1-diffusers",
torch_dtype=torch.bfloat16
)
pipe.to("cuda")
prompt="<img><|image_1|></img> Remove the woman's earrings. Replace the mug with a clear glass filled with sparkling iced cola."
input_images=[load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/t2i_woman_with_book.png")]
image = pipe(
prompt=prompt,
input_images=input_images,
guidance_scale=2,
img_guidance_scale=1.6,
use_input_image_size_as_output=True,
generator=torch.Generator(device="cpu").manual_seed(222)).images[0]
image
```
<div class="flex flex-row gap-4">
<div class="flex-1">
<img class="rounded-xl" src="https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/t2i_woman_with_book.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">original image</figcaption>
</div>
<div class="flex-1">
<img class="rounded-xl" src="https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/edit.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">edited image</figcaption>
</div>
</div>
OmniGen has some interesting features, such as visual reasoning, as shown in the example below.
```py
prompt="If the woman is thirsty, what should she take? Find it in the image and highlight it in blue. <img><|image_1|></img>"
input_images=[load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/edit.png")]
image = pipe(
prompt=prompt,
input_images=input_images,
guidance_scale=2,
img_guidance_scale=1.6,
use_input_image_size_as_output=True,
generator=torch.Generator(device="cpu").manual_seed(0)).images[0]
image
```
<div class="flex justify-center">
<img src="https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/reasoning.png" alt="generated image"/>
</div>
## Controllable generation
OmniGen can handle several classic computer vision tasks.
As shown below, OmniGen can detect human skeletons in input images, which can be used as control conditions to generate new images.
```py
import torch
from diffusers import OmniGenPipeline
from diffusers.utils import load_image
pipe = OmniGenPipeline.from_pretrained(
"Shitao/OmniGen-v1-diffusers",
torch_dtype=torch.bfloat16
)
pipe.to("cuda")
prompt="Detect the skeleton of human in this image: <img><|image_1|></img>"
input_images=[load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/edit.png")]
image1 = pipe(
prompt=prompt,
input_images=input_images,
guidance_scale=2,
img_guidance_scale=1.6,
use_input_image_size_as_output=True,
generator=torch.Generator(device="cpu").manual_seed(333)).images[0]
image1
prompt="Generate a new photo using the following picture and text as conditions: <img><|image_1|></img>\n A young boy is sitting on a sofa in the library, holding a book. His hair is neatly combed, and a faint smile plays on his lips, with a few freckles scattered across his cheeks. The library is quiet, with rows of shelves filled with books stretching out behind him."
input_images=[load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/skeletal.png")]
image2 = pipe(
prompt=prompt,
input_images=input_images,
guidance_scale=2,
img_guidance_scale=1.6,
use_input_image_size_as_output=True,
generator=torch.Generator(device="cpu").manual_seed(333)).images[0]
image2
```
<div class="flex flex-row gap-4">
<div class="flex-1">
<img class="rounded-xl" src="https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/edit.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">original image</figcaption>
</div>
<div class="flex-1">
<img class="rounded-xl" src="https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/skeletal.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">detected skeleton</figcaption>
</div>
<div class="flex-1">
<img class="rounded-xl" src="https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/skeletal2img.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">skeleton to image</figcaption>
</div>
</div>
OmniGen can also directly use relevant information from input images to generate new images.
```py
import torch
from diffusers import OmniGenPipeline
from diffusers.utils import load_image
pipe = OmniGenPipeline.from_pretrained(
"Shitao/OmniGen-v1-diffusers",
torch_dtype=torch.bfloat16
)
pipe.to("cuda")
prompt="Following the pose of this image <img><|image_1|></img>, generate a new photo: A young boy is sitting on a sofa in the library, holding a book. His hair is neatly combed, and a faint smile plays on his lips, with a few freckles scattered across his cheeks. The library is quiet, with rows of shelves filled with books stretching out behind him."
input_images=[load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/edit.png")]
image = pipe(
prompt=prompt,
input_images=input_images,
guidance_scale=2,
img_guidance_scale=1.6,
use_input_image_size_as_output=True,
generator=torch.Generator(device="cpu").manual_seed(0)).images[0]
image
```
<div class="flex flex-row gap-4">
<div class="flex-1">
<img class="rounded-xl" src="https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/same_pose.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">generated image</figcaption>
</div>
</div>
## ID and object preserving
OmniGen can generate multiple images based on the people and objects in the input image and supports inputting multiple images simultaneously.
Additionally, OmniGen can extract desired objects from an image containing multiple objects based on instructions.
```py
import torch
from diffusers import OmniGenPipeline
from diffusers.utils import load_image
pipe = OmniGenPipeline.from_pretrained(
"Shitao/OmniGen-v1-diffusers",
torch_dtype=torch.bfloat16
)
pipe.to("cuda")
prompt="A man and a woman are sitting at a classroom desk. The man is the man with yellow hair in <img><|image_1|></img>. The woman is the woman on the left of <img><|image_2|></img>"
input_image_1 = load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/3.png")
input_image_2 = load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/4.png")
input_images=[input_image_1, input_image_2]
image = pipe(
prompt=prompt,
input_images=input_images,
height=1024,
width=1024,
guidance_scale=2.5,
img_guidance_scale=1.6,
generator=torch.Generator(device="cpu").manual_seed(666)).images[0]
image
```
<div class="flex flex-row gap-4">
<div class="flex-1">
<img class="rounded-xl" src="https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/3.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">input_image_1</figcaption>
</div>
<div class="flex-1">
<img class="rounded-xl" src="https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/4.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">input_image_2</figcaption>
</div>
<div class="flex-1">
<img class="rounded-xl" src="https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/id2.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">generated image</figcaption>
</div>
</div>
```py
import torch
from diffusers import OmniGenPipeline
from diffusers.utils import load_image
pipe = OmniGenPipeline.from_pretrained(
"Shitao/OmniGen-v1-diffusers",
torch_dtype=torch.bfloat16
)
pipe.to("cuda")
prompt="A woman is walking down the street, wearing a white long-sleeve blouse with lace details on the sleeves, paired with a blue pleated skirt. The woman is <img><|image_1|></img>. The long-sleeve blouse and a pleated skirt are <img><|image_2|></img>."
input_image_1 = load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/emma.jpeg")
input_image_2 = load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/dress.jpg")
input_images=[input_image_1, input_image_2]
image = pipe(
prompt=prompt,
input_images=input_images,
height=1024,
width=1024,
guidance_scale=2.5,
img_guidance_scale=1.6,
generator=torch.Generator(device="cpu").manual_seed(666)).images[0]
image
```
<div class="flex flex-row gap-4">
<div class="flex-1">
<img class="rounded-xl" src="https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/emma.jpeg"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">person image</figcaption>
</div>
<div class="flex-1">
<img class="rounded-xl" src="https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/dress.jpg"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">clothe image</figcaption>
</div>
<div class="flex-1">
<img class="rounded-xl" src="https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/tryon.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">generated image</figcaption>
</div>
</div>
## Optimization when inputting multiple images
For the text-to-image task, OmniGen requires minimal memory and time (9GB of memory and 31 seconds for a 1024x1024 image on an A800 GPU).
However, the computational cost increases when input images are used.
Here are some guidelines to help you reduce the computational cost when passing multiple input images. The experiments below were conducted on an A800 GPU with two input images.
Like other pipelines, you can reduce memory usage by offloading the model with `pipe.enable_model_cpu_offload()` or `pipe.enable_sequential_cpu_offload()`.
In OmniGen, you can also decrease computational overhead by reducing `max_input_image_size`, as shown in the sketch after the table below.
The memory consumption for different image sizes is shown in the table below:
| Method | Memory Usage |
|---------------------------|--------------|
| max_input_image_size=1024 | 40GB |
| max_input_image_size=512 | 17GB |
| max_input_image_size=256 | 14GB |
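The snippet below is a minimal sketch combining both options: model offloading plus a smaller `max_input_image_size`. It assumes `max_input_image_size` is accepted directly by the pipeline call, mirroring the settings in the table above; adjust the value to trade quality for memory.
```py
import torch
from diffusers import OmniGenPipeline
from diffusers.utils import load_image

pipe = OmniGenPipeline.from_pretrained(
    "Shitao/OmniGen-v1-diffusers",
    torch_dtype=torch.bfloat16,
)
# Offload idle components to the CPU to reduce peak VRAM usage.
pipe.enable_model_cpu_offload()

prompt = "<img><|image_1|></img> Remove the woman's earrings."
input_images = [load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/t2i_woman_with_book.png")]
image = pipe(
    prompt=prompt,
    input_images=input_images,
    guidance_scale=2,
    img_guidance_scale=1.6,
    max_input_image_size=512,  # downscale input images to shorten the conditioning sequence (assumed call argument)
    use_input_image_size_as_output=True,
).images[0]
image
```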


@@ -0,0 +1,203 @@
import argparse
import os
import torch
from huggingface_hub import snapshot_download
from safetensors.torch import load_file
from transformers import AutoTokenizer
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, OmniGenPipeline, OmniGenTransformer2DModel
def main(args):
# checkpoint from https://huggingface.co/Shitao/OmniGen-v1
if not os.path.exists(args.origin_ckpt_path):
print("Model not found, downloading...")
cache_folder = os.getenv("HF_HUB_CACHE")
args.origin_ckpt_path = snapshot_download(
repo_id=args.origin_ckpt_path,
cache_dir=cache_folder,
ignore_patterns=["flax_model.msgpack", "rust_model.ot", "tf_model.h5", "model.pt"],
)
print(f"Downloaded model to {args.origin_ckpt_path}")
ckpt = os.path.join(args.origin_ckpt_path, "model.safetensors")
ckpt = load_file(ckpt, device="cpu")
mapping_dict = {
"pos_embed": "patch_embedding.pos_embed",
"x_embedder.proj.weight": "patch_embedding.output_image_proj.weight",
"x_embedder.proj.bias": "patch_embedding.output_image_proj.bias",
"input_x_embedder.proj.weight": "patch_embedding.input_image_proj.weight",
"input_x_embedder.proj.bias": "patch_embedding.input_image_proj.bias",
"final_layer.adaLN_modulation.1.weight": "norm_out.linear.weight",
"final_layer.adaLN_modulation.1.bias": "norm_out.linear.bias",
"final_layer.linear.weight": "proj_out.weight",
"final_layer.linear.bias": "proj_out.bias",
"time_token.mlp.0.weight": "time_token.linear_1.weight",
"time_token.mlp.0.bias": "time_token.linear_1.bias",
"time_token.mlp.2.weight": "time_token.linear_2.weight",
"time_token.mlp.2.bias": "time_token.linear_2.bias",
"t_embedder.mlp.0.weight": "t_embedder.linear_1.weight",
"t_embedder.mlp.0.bias": "t_embedder.linear_1.bias",
"t_embedder.mlp.2.weight": "t_embedder.linear_2.weight",
"t_embedder.mlp.2.bias": "t_embedder.linear_2.bias",
"llm.embed_tokens.weight": "embed_tokens.weight",
}
converted_state_dict = {}
for k, v in ckpt.items():
if k in mapping_dict:
converted_state_dict[mapping_dict[k]] = v
elif "qkv" in k:
to_q, to_k, to_v = v.chunk(3)
converted_state_dict[f"layers.{k.split('.')[2]}.self_attn.to_q.weight"] = to_q
converted_state_dict[f"layers.{k.split('.')[2]}.self_attn.to_k.weight"] = to_k
converted_state_dict[f"layers.{k.split('.')[2]}.self_attn.to_v.weight"] = to_v
elif "o_proj" in k:
converted_state_dict[f"layers.{k.split('.')[2]}.self_attn.to_out.0.weight"] = v
else:
converted_state_dict[k[4:]] = v
transformer = OmniGenTransformer2DModel(
rope_scaling={
"long_factor": [
1.0299999713897705,
1.0499999523162842,
1.0499999523162842,
1.0799999237060547,
1.2299998998641968,
1.2299998998641968,
1.2999999523162842,
1.4499999284744263,
1.5999999046325684,
1.6499998569488525,
1.8999998569488525,
2.859999895095825,
3.68999981880188,
5.419999599456787,
5.489999771118164,
5.489999771118164,
9.09000015258789,
11.579999923706055,
15.65999984741211,
15.769999504089355,
15.789999961853027,
18.360000610351562,
21.989999771118164,
23.079999923706055,
30.009998321533203,
32.35000228881836,
32.590003967285156,
35.56000518798828,
39.95000457763672,
53.840003967285156,
56.20000457763672,
57.95000457763672,
59.29000473022461,
59.77000427246094,
59.920005798339844,
61.190006256103516,
61.96000671386719,
62.50000762939453,
63.3700065612793,
63.48000717163086,
63.48000717163086,
63.66000747680664,
63.850006103515625,
64.08000946044922,
64.760009765625,
64.80001068115234,
64.81001281738281,
64.81001281738281,
],
"short_factor": [
1.05,
1.05,
1.05,
1.1,
1.1,
1.1,
1.2500000000000002,
1.2500000000000002,
1.4000000000000004,
1.4500000000000004,
1.5500000000000005,
1.8500000000000008,
1.9000000000000008,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.1000000000000005,
2.1000000000000005,
2.2,
2.3499999999999996,
2.3499999999999996,
2.3499999999999996,
2.3499999999999996,
2.3999999999999995,
2.3999999999999995,
2.6499999999999986,
2.6999999999999984,
2.8999999999999977,
2.9499999999999975,
3.049999999999997,
3.049999999999997,
3.049999999999997,
],
"type": "su",
},
patch_size=2,
in_channels=4,
pos_embed_max_size=192,
)
transformer.load_state_dict(converted_state_dict, strict=True)
transformer.to(torch.bfloat16)
num_model_params = sum(p.numel() for p in transformer.parameters())
print(f"Total number of transformer parameters: {num_model_params}")
scheduler = FlowMatchEulerDiscreteScheduler(invert_sigmas=True, num_train_timesteps=1)
vae = AutoencoderKL.from_pretrained(os.path.join(args.origin_ckpt_path, "vae"), torch_dtype=torch.float32)
tokenizer = AutoTokenizer.from_pretrained(args.origin_ckpt_path)
pipeline = OmniGenPipeline(tokenizer=tokenizer, transformer=transformer, vae=vae, scheduler=scheduler)
pipeline.save_pretrained(args.dump_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--origin_ckpt_path",
default="Shitao/OmniGen-v1",
type=str,
required=False,
help="Path to the checkpoint to convert.",
)
parser.add_argument(
"--dump_path", default="OmniGen-v1-diffusers", type=str, required=False, help="Path to the output pipeline."
)
args = parser.parse_args()
main(args)


@@ -124,6 +124,7 @@ else:
"MotionAdapter",
"MultiAdapter",
"MultiControlNetModel",
"OmniGenTransformer2DModel",
"PixArtTransformer2DModel",
"PriorTransformer",
"SanaTransformer2DModel",
@@ -342,6 +343,7 @@ else:
"MarigoldNormalsPipeline",
"MochiPipeline",
"MusicLDMPipeline",
"OmniGenPipeline",
"PaintByExamplePipeline",
"PIAPipeline",
"PixArtAlphaPipeline",
@@ -638,6 +640,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
MotionAdapter,
MultiAdapter,
MultiControlNetModel,
OmniGenTransformer2DModel,
PixArtTransformer2DModel,
PriorTransformer,
SanaTransformer2DModel,
@@ -835,6 +838,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
MarigoldNormalsPipeline,
MochiPipeline,
MusicLDMPipeline,
OmniGenPipeline,
PaintByExamplePipeline,
PIAPipeline,
PixArtAlphaPipeline,


@@ -73,6 +73,7 @@ if is_torch_available():
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
_import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
_import_structure["unets.unet_1d"] = ["UNet1DModel"]
@@ -142,6 +143,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LTXVideoTransformer3DModel,
LuminaNextDiT2DModel,
MochiTransformer3DModel,
OmniGenTransformer2DModel,
PixArtTransformer2DModel,
PriorTransformer,
SanaTransformer2DModel,


@@ -71,7 +71,7 @@ class AdaLayerNorm(nn.Module):
if self.chunk_dim == 1:
# This is a bit weird why we have the order of "shift, scale" here and "scale, shift" in the
# other if-branch. This branch is specific to CogVideoX for now.
# other if-branch. This branch is specific to CogVideoX and OmniGen for now.
shift, scale = temb.chunk(2, dim=1)
shift = shift[:, None, :]
scale = scale[:, None, :]


@@ -22,5 +22,6 @@ if is_torch_available():
from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
from .transformer_ltx import LTXVideoTransformer3DModel
from .transformer_mochi import MochiTransformer3DModel
from .transformer_omnigen import OmniGenTransformer2DModel
from .transformer_sd3 import SD3Transformer2DModel
from .transformer_temporal import TransformerTemporalModel


@@ -0,0 +1,699 @@
# Copyright 2024 OmniGen team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers
from ..attention_processor import Attention, AttentionProcessor
from ..embeddings import TimestepEmbedding, Timesteps, get_2d_sincos_pos_embed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNorm, RMSNorm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class OmniGenFeedForward(nn.Module):
r"""
A feed-forward layer for OmniGen.
Parameters:
hidden_size (`int`):
The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
hidden representations.
intermediate_size (`int`): The intermediate dimension of the feedforward layer.
"""
def __init__(
self,
hidden_size: int,
intermediate_size: int,
):
super().__init__()
self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
self.activation_fn = nn.SiLU()
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
up_states = self.gate_up_proj(hidden_states)
gate, up_states = up_states.chunk(2, dim=-1)
up_states = up_states * self.activation_fn(gate)
return self.down_proj(up_states)
class OmniGenPatchEmbed(nn.Module):
"""2D Image to Patch Embedding with support for OmniGen."""
def __init__(
self,
patch_size: int = 2,
in_channels: int = 4,
embed_dim: int = 768,
bias: bool = True,
interpolation_scale: float = 1,
pos_embed_max_size: int = 192,
base_size: int = 64,
):
super().__init__()
self.output_image_proj = nn.Conv2d(
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
)
self.input_image_proj = nn.Conv2d(
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
)
self.patch_size = patch_size
self.interpolation_scale = interpolation_scale
self.pos_embed_max_size = pos_embed_max_size
pos_embed = get_2d_sincos_pos_embed(
embed_dim,
self.pos_embed_max_size,
base_size=base_size,
interpolation_scale=self.interpolation_scale,
output_type="pt",
)
self.register_buffer("pos_embed", pos_embed.float().unsqueeze(0), persistent=True)
def cropped_pos_embed(self, height, width):
"""Crops positional embeddings for SD3 compatibility."""
if self.pos_embed_max_size is None:
raise ValueError("`pos_embed_max_size` must be set for cropping.")
height = height // self.patch_size
width = width // self.patch_size
if height > self.pos_embed_max_size:
raise ValueError(
f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
)
if width > self.pos_embed_max_size:
raise ValueError(
f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
)
top = (self.pos_embed_max_size - height) // 2
left = (self.pos_embed_max_size - width) // 2
spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :]
spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
return spatial_pos_embed
def patch_embeddings(self, latent, is_input_image: bool):
if is_input_image:
latent = self.input_image_proj(latent)
else:
latent = self.output_image_proj(latent)
latent = latent.flatten(2).transpose(1, 2)
return latent
def forward(self, latent: torch.Tensor, is_input_image: bool, padding_latent: torch.Tensor = None):
"""
Args:
latent: encoded image latents
is_input_image: use input_image_proj or output_image_proj
padding_latent:
When sizes of target images are inconsistent, use `padding_latent` to maintain consistent sequence
length.
Returns: torch.Tensor
"""
if isinstance(latent, list):
if padding_latent is None:
padding_latent = [None] * len(latent)
patched_latents = []
for sub_latent, padding in zip(latent, padding_latent):
height, width = sub_latent.shape[-2:]
sub_latent = self.patch_embeddings(sub_latent, is_input_image)
pos_embed = self.cropped_pos_embed(height, width)
sub_latent = sub_latent + pos_embed
if padding is not None:
sub_latent = torch.cat([sub_latent, padding.to(sub_latent.device)], dim=-2)
patched_latents.append(sub_latent)
else:
height, width = latent.shape[-2:]
pos_embed = self.cropped_pos_embed(height, width)
latent = self.patch_embeddings(latent, is_input_image)
patched_latents = latent + pos_embed
return patched_latents
class OmniGenSuScaledRotaryEmbedding(nn.Module):
def __init__(
self, dim, max_position_embeddings=131072, original_max_position_embeddings=4096, base=10000, rope_scaling=None
):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
self.short_factor = rope_scaling["short_factor"]
self.long_factor = rope_scaling["long_factor"]
self.original_max_position_embeddings = original_max_position_embeddings
@torch.no_grad()
def forward(self, x, position_ids):
seq_len = torch.max(position_ids) + 1
if seq_len > self.original_max_position_embeddings:
ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
else:
ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
scale = self.max_position_embeddings / self.original_max_position_embeddings
if scale <= 1.0:
scaling_factor = 1.0
else:
scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
cos = emb.cos() * scaling_factor
sin = emb.sin() * scaling_factor
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
def apply_rotary_emb(
x: torch.Tensor,
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
) -> torch.Tensor:
"""
Apply rotary embeddings to an input tensor using the given cosine and sine frequency tensors. The rotation uses the
rotate-half formulation: half of the hidden dimensions are rotated and recombined with the input, and the
frequencies are broadcast across the batch and head dimensions.
Args:
x (`torch.Tensor`):
Query or key tensor of shape `[B, H, S, D]` to which rotary embeddings are applied.
freqs_cis (`Tuple[torch.Tensor]`): Precomputed cosine and sine frequency tensors, each of shape `[S, D]`.
Returns:
`torch.Tensor`: The tensor with rotary embeddings applied, returned in the input dtype.
"""
cos, sin = freqs_cis # [S, D]
if len(cos.shape) == 2:
cos = cos[None, None]
sin = sin[None, None]
elif len(cos.shape) == 3:
cos = cos[:, None]
sin = sin[:, None]
cos, sin = cos.to(x.device), sin.to(x.device)
# Rotate half the hidden dims of the input. This rotation is widely used in LLMs, e.g. Llama, Phi-3, etc.
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
x_rotated = torch.cat((-x2, x1), dim=-1)
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
return out
class OmniGenAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
used in the OmniGen model.
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
batch_size, sequence_length, _ = hidden_states.shape
# Get Query-Key-Value Pair
query = attn.to_q(hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
bsz, q_len, query_dim = query.size()
inner_dim = key.shape[-1]
head_dim = query_dim // attn.heads
dtype = query.dtype
# Get key-value heads
kv_heads = inner_dim // head_dim
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
# Apply RoPE if needed
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
query, key = query.to(dtype), key.to(dtype)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
hidden_states = hidden_states.transpose(1, 2).to(dtype)
hidden_states = hidden_states.reshape(bsz, q_len, attn.out_dim)
hidden_states = attn.to_out[0](hidden_states)
return hidden_states
class OmniGenBlock(nn.Module):
"""
A transformer block used in [`OmniGenTransformer2DModel`].
Parameters:
hidden_size (`int`): Embedding dimension of the input features.
num_attention_heads (`int`): Number of attention heads.
num_key_value_heads (`int`):
Number of attention heads in key and value features (if using GQA), or set to None for the same as query.
intermediate_size (`int`): size of intermediate layer.
rms_norm_eps (`float`): The eps for norm layer.
"""
def __init__(
self,
hidden_size: int,
num_attention_heads: int,
num_key_value_heads: int,
intermediate_size: int,
rms_norm_eps: float,
) -> None:
super().__init__()
self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)
self.self_attn = Attention(
query_dim=hidden_size,
cross_attention_dim=hidden_size,
dim_head=hidden_size // num_attention_heads,
heads=num_attention_heads,
kv_heads=num_key_value_heads,
bias=False,
out_dim=hidden_size,
out_bias=False,
processor=OmniGenAttnProcessor2_0(),
)
self.post_attention_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)
self.mlp = OmniGenFeedForward(hidden_size, intermediate_size)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
image_rotary_emb: torch.Tensor,
):
"""
Perform a forward pass through the OmniGenBlock.
Parameters:
hidden_states (`torch.Tensor`): The input hidden states for the OmniGenBlock.
attention_mask (`torch.Tensor`): The attention mask corresponding to the input hidden states.
image_rotary_emb (`torch.Tensor`): Precomputed cosine and sine frequencies.
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
attn_outputs = self.self_attn(
hidden_states=hidden_states,
encoder_hidden_states=hidden_states,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
)
hidden_states = residual + attn_outputs
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class OmniGenTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
"""
The Transformer model introduced in OmniGen.
Reference: https://arxiv.org/pdf/2409.11340
Parameters:
hidden_size (`int`, *optional*, defaults to 3072):
The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
hidden representations.
rms_norm_eps (`float`, *optional*, defaults to 1e-5): eps for RMSNorm layer.
num_attention_heads (`int`, *optional*, defaults to 32):
The number of attention heads in each attention layer. This parameter specifies how many separate attention
mechanisms are used.
num_key_value_heads (`int`, *optional*, defaults to 32):
The number of key-value heads in the attention mechanism, if different from the number of attention heads.
If None, it defaults to num_attention_heads.
intermediate_size (`int`, *optional*, defaults to 8192): dimension of the intermediate layer in FFN
num_layers (`int`, *optional*, defaults to 32):
The number of layers in the model. This defines the depth of the neural network.
pad_token_id (`int`, *optional*, defaults to 32000):
The id of the padding token.
vocab_size (`int`, *optional*, defaults to 32064):
The size of the vocabulary.
patch_size (`int`, defaults to 2): Patch size to turn the input data into small patches.
in_channels (`int`, *optional*, defaults to 4): The number of channels in the input.
pos_embed_max_size (`int`, *optional*, defaults to 192): The maximum size of the positional embeddings.
"""
_supports_gradient_checkpointing = True
_no_split_modules = ["OmniGenBlock"]
@register_to_config
def __init__(
self,
hidden_size: int = 3072,
rms_norm_eps: float = 1e-05,
num_attention_heads: int = 32,
num_key_value_heads: int = 32,
intermediate_size: int = 8192,
num_layers: int = 32,
pad_token_id: int = 32000,
vocab_size: int = 32064,
max_position_embeddings: int = 131072,
original_max_position_embeddings: int = 4096,
rope_base: int = 10000,
rope_scaling: Dict = None,
patch_size=2,
in_channels=4,
pos_embed_max_size: int = 192,
time_step_dim: int = 256,
flip_sin_to_cos: bool = True,
downscale_freq_shift: int = 0,
timestep_activation_fn: str = "silu",
):
super().__init__()
self.in_channels = in_channels
self.out_channels = in_channels
self.patch_size = patch_size
self.pos_embed_max_size = pos_embed_max_size
self.patch_embedding = OmniGenPatchEmbed(
patch_size=patch_size,
in_channels=in_channels,
embed_dim=hidden_size,
pos_embed_max_size=pos_embed_max_size,
)
self.time_proj = Timesteps(time_step_dim, flip_sin_to_cos, downscale_freq_shift)
self.time_token = TimestepEmbedding(time_step_dim, hidden_size, timestep_activation_fn)
self.t_embedder = TimestepEmbedding(time_step_dim, hidden_size, timestep_activation_fn)
self.norm_out = AdaLayerNorm(hidden_size, norm_elementwise_affine=False, norm_eps=1e-6, chunk_dim=1)
self.proj_out = nn.Linear(hidden_size, patch_size * patch_size * self.out_channels, bias=True)
self.embed_tokens = nn.Embedding(vocab_size, hidden_size, pad_token_id)
self.rotary_emb = OmniGenSuScaledRotaryEmbedding(
hidden_size // num_attention_heads,
max_position_embeddings=max_position_embeddings,
original_max_position_embeddings=original_max_position_embeddings,
base=rope_base,
rope_scaling=rope_scaling,
)
self.layers = nn.ModuleList(
[
OmniGenBlock(
hidden_size,
num_attention_heads,
num_key_value_heads,
intermediate_size,
rms_norm_eps,
)
for _ in range(num_layers)
]
)
self.norm = RMSNorm(hidden_size, eps=rms_norm_eps)
self.gradient_checkpointing = False
def unpatchify(self, x, h, w):
"""
x: (N, T, patch_size**2 * C) -> imgs: (N, C, H, W)
"""
c = self.out_channels
x = x.reshape(
shape=(x.shape[0], h // self.patch_size, w // self.patch_size, self.patch_size, self.patch_size, c)
)
x = torch.einsum("nhwpqc->nchpwq", x)
imgs = x.reshape(shape=(x.shape[0], c, h, w))
return imgs
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[OmniGenAttnProcessor2_0, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def get_multimodal_embeddings(
self,
input_ids: torch.Tensor,
input_img_latents: List[torch.Tensor],
input_image_sizes: Dict,
):
"""
Get the multimodal conditional embeddings.
Args:
input_ids: a sequence of text token ids
input_img_latents: VAE-encoded latents of the input images
input_image_sizes: the start and end indices of each input image placeholder in the input_ids sequence
Returns: torch.Tensor
"""
input_img_latents = [x.to(self.dtype) for x in input_img_latents]
condition_tokens = None
if input_ids is not None:
condition_tokens = self.embed_tokens(input_ids)
input_img_inx = 0
if input_img_latents is not None:
input_image_tokens = self.patch_embedding(input_img_latents, is_input_image=True)
for b_inx in input_image_sizes.keys():
for start_inx, end_inx in input_image_sizes[b_inx]:
# replace the placeholder in text tokens with the image embedding.
condition_tokens[b_inx, start_inx:end_inx] = input_image_tokens[input_img_inx].to(
condition_tokens.dtype
)
input_img_inx += 1
return condition_tokens
def forward(
self,
hidden_states: torch.Tensor,
timestep: Union[int, float, torch.FloatTensor],
input_ids: torch.Tensor,
input_img_latents: List[torch.Tensor],
input_image_sizes: Dict[int, List[int]],
attention_mask: torch.Tensor,
position_ids: torch.Tensor,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
):
"""
The [`OmniGenTransformer2DModel`] forward method.
Args:
hidden_states (`torch.Tensor` of shape `(batch size, channel, height, width)`):
Input `hidden_states`.
timestep (`torch.FloatTensor`):
Used to indicate denoising step.
input_ids (`torch.LongTensor`):
Token ids of the text prompt.
input_img_latents (`torch.Tensor`):
Latents of the input images, encoded by the VAE.
input_image_sizes (`dict`):
The start and end indices of the input_img_latents in the input_ids sequence.
attention_mask (`torch.Tensor`):
Mask used for self-attention.
position_ids (`torch.LongTensor`):
Position ids of the token sequence.
attention_kwargs: (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.modeling_outputs.Transformer2DModelOutput`] instead of a plain tuple.
Returns:
If `return_dict` is True, a [`~models.modeling_outputs.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
height, width = hidden_states.size()[-2:]
hidden_states = self.patch_embedding(hidden_states, is_input_image=False)
num_tokens_for_output_image = hidden_states.size(1)
time_token = self.time_token(self.time_proj(timestep).to(hidden_states.dtype)).unsqueeze(1)
condition_tokens = self.get_multimodal_embeddings(
input_ids=input_ids,
input_img_latents=input_img_latents,
input_image_sizes=input_image_sizes,
)
if condition_tokens is not None:
inputs_embeds = torch.cat([condition_tokens, time_token, hidden_states], dim=1)
else:
inputs_embeds = torch.cat([time_token, hidden_states], dim=1)
batch_size, seq_length = inputs_embeds.shape[:2]
position_ids = position_ids.view(-1, seq_length).long()
if attention_mask is not None and attention_mask.dim() == 3:
dtype = inputs_embeds.dtype
min_dtype = torch.finfo(dtype).min
attention_mask = (1 - attention_mask) * min_dtype
attention_mask = attention_mask.unsqueeze(1).to(inputs_embeds.dtype)
else:
raise Exception("attention_mask parameter was unavailable or invalid")
hidden_states = inputs_embeds
image_rotary_emb = self.rotary_emb(hidden_states, position_ids)
for decoder_layer in self.layers:
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(
decoder_layer, hidden_states, attention_mask, image_rotary_emb
)
else:
hidden_states = decoder_layer(
hidden_states, attention_mask=attention_mask, image_rotary_emb=image_rotary_emb
)
hidden_states = self.norm(hidden_states)
hidden_states = hidden_states[:, -num_tokens_for_output_image:]
timestep_proj = self.time_proj(timestep)
temb = self.t_embedder(timestep_proj.type_as(hidden_states))
hidden_states = self.norm_out(hidden_states, temb=temb)
hidden_states = self.proj_out(hidden_states)
output = self.unpatchify(hidden_states, height, width)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)


@@ -264,6 +264,7 @@ else:
)
_import_structure["mochi"] = ["MochiPipeline"]
_import_structure["musicldm"] = ["MusicLDMPipeline"]
_import_structure["omnigen"] = ["OmniGenPipeline"]
_import_structure["paint_by_example"] = ["PaintByExamplePipeline"]
_import_structure["pia"] = ["PIAPipeline"]
_import_structure["pixart_alpha"] = ["PixArtAlphaPipeline", "PixArtSigmaPipeline"]
@@ -602,6 +603,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
)
from .mochi import MochiPipeline
from .musicldm import MusicLDMPipeline
from .omnigen import OmniGenPipeline
from .pag import (
AnimateDiffPAGPipeline,
HunyuanDiTPAGPipeline,


@@ -48,9 +48,14 @@ EXAMPLE_DOC_STRING = """
>>> from huggingface_hub import snapshot_download
>>> snapshot_download(repo_id="BestWishYsh/ConsisID-preview", local_dir="BestWishYsh/ConsisID-preview")
>>> face_helper_1, face_helper_2, face_clip_model, face_main_model, eva_transform_mean, eva_transform_std = (
... prepare_face_models("BestWishYsh/ConsisID-preview", device="cuda", dtype=torch.bfloat16)
... )
>>> (
... face_helper_1,
... face_helper_2,
... face_clip_model,
... face_main_model,
... eva_transform_mean,
... eva_transform_std,
... ) = prepare_face_models("BestWishYsh/ConsisID-preview", device="cuda", dtype=torch.bfloat16)
>>> pipe = ConsisIDPipeline.from_pretrained("BestWishYsh/ConsisID-preview", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")


@@ -0,0 +1,50 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_omnigen"] = ["OmniGenPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_omnigen import OmniGenPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)

View File

@@ -0,0 +1,530 @@
# Copyright 2024 OmniGen team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch
from transformers import LlamaTokenizer
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import OmniGenTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
is_torch_xla_available,
logging,
replace_example_docstring,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .processor_omnigen import OmniGenMultiModalProcessor
if is_torch_xla_available():
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import OmniGenPipeline
>>> pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1-diffusers", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> prompt = "A cat holding a sign that says hello world"
>>> # Depending on the variant being used, the pipeline call will slightly vary.
>>> # Refer to the pipeline documentation for more details.
>>> image = pipe(prompt, num_inference_steps=50, guidance_scale=2.5).images[0]
>>> image.save("t2i.png")
```
"""
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
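As a usage note (not part of the diff), `retrieve_timesteps` above is driven either by `num_inference_steps` alone or by a custom `timesteps`/`sigmas` schedule; the OmniGen pipeline below passes linearly spaced sigmas. A hedged sketch using a standalone scheduler instance:

```py
import numpy as np
from diffusers import FlowMatchEulerDiscreteScheduler

scheduler = FlowMatchEulerDiscreteScheduler()
num_inference_steps = 50
# Same schedule the pipeline builds in its "Prepare timesteps" step.
sigmas = np.linspace(1, 0, num_inference_steps + 1)[:num_inference_steps]
timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps, sigmas=sigmas)
```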
class OmniGenPipeline(
DiffusionPipeline,
):
r"""
The OmniGen pipeline for multimodal-to-image generation.
Reference: https://arxiv.org/pdf/2409.11340
Args:
transformer ([`OmniGenTransformer2DModel`]):
Autoregressive Transformer architecture for OmniGen.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
tokenizer (`LlamaTokenizer`):
Text tokenizer of class
[LlamaTokenizer](https://huggingface.co/docs/transformers/main/model_doc/llama#transformers.LlamaTokenizer).
"""
model_cpu_offload_seq = "transformer->vae"
_optional_components = []
_callback_tensor_inputs = ["latents"]
def __init__(
self,
transformer: OmniGenTransformer2DModel,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKL,
tokenizer: LlamaTokenizer,
):
super().__init__()
self.register_modules(
vae=vae,
tokenizer=tokenizer,
transformer=transformer,
scheduler=scheduler,
)
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) is not None else 8
)
# OmniGen latents are turned into 2x2 patches and packed. This means the latent width and height have to be divisible
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
self.multimodal_processor = OmniGenMultiModalProcessor(tokenizer, max_image_size=1024)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 120000
)
self.default_sample_size = 128
def encode_input_images(
self,
input_pixel_values: List[torch.Tensor],
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
"""
get the continue embedding of input images by VAE
Args:
input_pixel_values: normlized pixel of input images
device:
Returns: torch.Tensor
"""
device = device or self._execution_device
dtype = dtype or self.vae.dtype
input_img_latents = []
for img in input_pixel_values:
img = self.vae.encode(img.to(device, dtype)).latent_dist.sample().mul_(self.vae.config.scaling_factor)
input_img_latents.append(img)
return input_img_latents
def check_inputs(
self,
prompt,
input_images,
height,
width,
use_input_image_size_as_output,
callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
):
if input_images is not None:
if len(input_images) != len(prompt):
raise ValueError(
f"The number of prompts: {len(prompt)} does not match the number of input images: {len(input_images)}."
)
for i in range(len(input_images)):
if input_images[i] is not None:
if not all(f"<img><|image_{k + 1}|></img>" in prompt[i] for k in range(len(input_images[i]))):
raise ValueError(
f"prompt `{prompt[i]}` doesn't have enough placeholders for the input images `{input_images[i]}`"
)
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
logger.warning(
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
)
if use_input_image_size_as_output:
if input_images is None or input_images[0] is None:
raise ValueError(
"`use_input_image_size_as_output` is set to True, but no input image was found. If you are performing a text-to-image task, please set it to False."
)
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
def enable_vae_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.vae.enable_slicing()
def disable_vae_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_slicing()
def enable_vae_tiling(self):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.vae.enable_tiling()
def disable_vae_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_tiling()
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_latents
def prepare_latents(
self,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
):
if latents is not None:
return latents.to(device=device, dtype=dtype)
shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
return latents
@property
def guidance_scale(self):
return self._guidance_scale
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]],
input_images: Union[PipelineImageInput, List[PipelineImageInput]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
max_input_image_size: int = 1024,
timesteps: List[int] = None,
guidance_scale: float = 2.5,
img_guidance_scale: float = 1.6,
use_input_image_size_as_output: bool = False,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 120000,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If the input includes images, add one
`<img><|image_i|></img>` placeholder per image to the prompt to indicate where the i-th image belongs.
input_images (`PipelineImageInput` or `List[PipelineImageInput]`, *optional*):
The list of input images. Each `<img><|image_i|></img>` placeholder in the prompt is replaced with the
i-th image in the list.
height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
max_input_image_size (`int`, *optional*, defaults to 1024):
The maximum size of the input images; larger input images are resized and cropped down to this size.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 2.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
img_guidance_scale (`float`, *optional*, defaults to 1.6):
Defined as equation 3 in [InstructPix2Pix](https://arxiv.org/pdf/2211.09800).
use_input_image_size_as_output (`bool`, *optional*, defaults to `False`):
Whether to use the input image size as the output image size. Useful for single-image inputs,
e.g., image editing tasks.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return an [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that is called at the end of each denoising step during inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int`, *optional*, defaults to 120000): Maximum sequence length to use with the `prompt`.
Examples:
Returns: [`~pipelines.ImagePipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is returned
where the first element is a list with the generated images.
"""
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
num_cfg = 2 if input_images is not None else 1
use_img_cfg = True if input_images is not None else False
if isinstance(prompt, str):
prompt = [prompt]
input_images = [input_images]
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
input_images,
height,
width,
use_input_image_size_as_output,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
)
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
self._interrupt = False
# 2. Define call parameters
batch_size = len(prompt)
device = self._execution_device
# 3. process multi-modal instructions
if max_input_image_size != self.multimodal_processor.max_image_size:
self.multimodal_processor.reset_max_image_size(max_image_size=max_input_image_size)
processed_data = self.multimodal_processor(
prompt,
input_images,
height=height,
width=width,
use_img_cfg=use_img_cfg,
use_input_image_size_as_output=use_input_image_size_as_output,
num_images_per_prompt=num_images_per_prompt,
)
processed_data["input_ids"] = processed_data["input_ids"].to(device)
processed_data["attention_mask"] = processed_data["attention_mask"].to(device)
processed_data["position_ids"] = processed_data["position_ids"].to(device)
# 4. Encode input images
input_img_latents = self.encode_input_images(processed_data["input_pixel_values"], device=device)
# 5. Prepare timesteps
sigmas = np.linspace(1, 0, num_inference_steps + 1)[:num_inference_steps]
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, timesteps, sigmas=sigmas
)
self._num_timesteps = len(timesteps)
# 6. Prepare latents.
if use_input_image_size_as_output:
height, width = processed_data["input_pixel_values"][0].shape[-2:]
latent_channels = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
latent_channels,
height,
width,
self.transformer.dtype,
device,
generator,
latents,
)
# 7. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * (num_cfg + 1))
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
input_ids=processed_data["input_ids"],
input_img_latents=input_img_latents,
input_image_sizes=processed_data["input_image_sizes"],
attention_mask=processed_data["attention_mask"],
position_ids=processed_data["position_ids"],
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
if num_cfg == 2:
cond, uncond, img_cond = torch.split(noise_pred, len(noise_pred) // 3, dim=0)
noise_pred = uncond + img_guidance_scale * (img_cond - uncond) + guidance_scale * (cond - img_cond)
else:
cond, uncond = torch.split(noise_pred, len(noise_pred) // 2, dim=0)
noise_pred = uncond + guidance_scale * (cond - uncond)
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
progress_bar.update()
if not output_type == "latent":
latents = latents.to(self.vae.dtype)
latents = latents / self.vae.config.scaling_factor
image = self.vae.decode(latents, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
else:
image = latents
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)
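To complement the text-to-image example in `EXAMPLE_DOC_STRING`, here is a hedged sketch (not part of the diff) of the image-conditioned path: the prompt carries one `<img><|image_i|></img>` placeholder per entry in `input_images`, and `img_guidance_scale` weights the image-conditioning branch. The image path below is a placeholder.

```py
import torch
from diffusers import OmniGenPipeline
from diffusers.utils import load_image

pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1-diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# One placeholder per input image; the processor maps <|image_1|> to input_images[0].
prompt = "The woman in <img><|image_1|></img> waves her hand happily in the crowd"
input_images = [load_image("person.png")]  # hypothetical local file; a URL or PIL.Image also works

image = pipe(
    prompt,
    input_images=input_images,
    guidance_scale=2.5,
    img_guidance_scale=1.6,
    use_input_image_size_as_output=True,
    num_inference_steps=50,
).images[0]
image.save("edited.png")
```

With both text and image conditioning, the transformer is run on three stacked branches (conditional, unconditional, image-conditional), which the denoising loop above recombines using `guidance_scale` and `img_guidance_scale`.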

View File

@@ -0,0 +1,327 @@
# Copyright 2024 OmniGen team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import Dict, List
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
def crop_image(pil_image, max_image_size):
"""
Crop the image so that its height and width do not exceed `max_image_size`, while ensuring both the height and
width are multiples of 16.
"""
while min(*pil_image.size) >= 2 * max_image_size:
pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)
if max(*pil_image.size) > max_image_size:
scale = max_image_size / max(*pil_image.size)
pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
if min(*pil_image.size) < 16:
scale = 16 / min(*pil_image.size)
pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
arr = np.array(pil_image)
crop_y1 = (arr.shape[0] % 16) // 2
crop_y2 = arr.shape[0] % 16 - crop_y1
crop_x1 = (arr.shape[1] % 16) // 2
crop_x2 = arr.shape[1] % 16 - crop_x1
arr = arr[crop_y1 : arr.shape[0] - crop_y2, crop_x1 : arr.shape[1] - crop_x2]
return Image.fromarray(arr)
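A quick illustration (not part of the diff) of what `crop_image` above produces: the image is downscaled until its longer edge fits within `max_image_size`, then center-cropped so both edges are multiples of 16 (tiny images are first scaled up so the shorter edge is at least 16).

```py
from PIL import Image

img = Image.new("RGB", (2500, 1000))        # toy input
out = crop_image(img, max_image_size=1024)  # uses the function defined above
print(out.size)                             # (1024, 400): longer edge scaled to 1024, both edges multiples of 16
```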
class OmniGenMultiModalProcessor:
def __init__(self, text_tokenizer, max_image_size: int = 1024):
self.text_tokenizer = text_tokenizer
self.max_image_size = max_image_size
self.image_transform = transforms.Compose(
[
transforms.Lambda(lambda pil_image: crop_image(pil_image, max_image_size)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
self.collator = OmniGenCollator()
def reset_max_image_size(self, max_image_size):
self.max_image_size = max_image_size
self.image_transform = transforms.Compose(
[
transforms.Lambda(lambda pil_image: crop_image(pil_image, max_image_size)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
def process_image(self, image):
if isinstance(image, str):
image = Image.open(image).convert("RGB")
return self.image_transform(image)
def process_multi_modal_prompt(self, text, input_images):
text = self.add_prefix_instruction(text)
if input_images is None or len(input_images) == 0:
model_inputs = self.text_tokenizer(text)
return {"input_ids": model_inputs.input_ids, "pixel_values": None, "image_sizes": None}
pattern = r"<\|image_\d+\|>"
prompt_chunks = [self.text_tokenizer(chunk).input_ids for chunk in re.split(pattern, text)]
for i in range(1, len(prompt_chunks)):
if prompt_chunks[i][0] == 1:
prompt_chunks[i] = prompt_chunks[i][1:]
image_tags = re.findall(pattern, text)
image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags]
unique_image_ids = sorted(set(image_ids))
assert unique_image_ids == list(
range(1, len(unique_image_ids) + 1)
), f"image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be {unique_image_ids}"
# total images must be the same as the number of image tags
assert (
len(unique_image_ids) == len(input_images)
), f"total images must be the same as the number of image tags, got {len(unique_image_ids)} image tags and {len(input_images)} images"
input_images = [input_images[x - 1] for x in image_ids]
all_input_ids = []
img_inx = []
for i in range(len(prompt_chunks)):
all_input_ids.extend(prompt_chunks[i])
if i != len(prompt_chunks) - 1:
start_inx = len(all_input_ids)
size = input_images[i].size(-2) * input_images[i].size(-1) // 16 // 16
img_inx.append([start_inx, start_inx + size])
all_input_ids.extend([0] * size)
return {"input_ids": all_input_ids, "pixel_values": input_images, "image_sizes": img_inx}
def add_prefix_instruction(self, prompt):
user_prompt = "<|user|>\n"
generation_prompt = "Generate an image according to the following instructions\n"
assistant_prompt = "<|assistant|>\n<|diffusion|>"
prompt_suffix = "<|end|>\n"
prompt = f"{user_prompt}{generation_prompt}{prompt}{prompt_suffix}{assistant_prompt}"
return prompt
def __call__(
self,
instructions: List[str],
input_images: List[List[str]] = None,
height: int = 1024,
width: int = 1024,
negative_prompt: str = "low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers.",
use_img_cfg: bool = True,
separate_cfg_input: bool = False,
use_input_image_size_as_output: bool = False,
num_images_per_prompt: int = 1,
) -> Dict:
if isinstance(instructions, str):
instructions = [instructions]
input_images = [input_images]
input_data = []
for i in range(len(instructions)):
cur_instruction = instructions[i]
cur_input_images = None if input_images is None else input_images[i]
if cur_input_images is not None and len(cur_input_images) > 0:
cur_input_images = [self.process_image(x) for x in cur_input_images]
else:
cur_input_images = None
assert "<img><|image_1|></img>" not in cur_instruction
mllm_input = self.process_multi_modal_prompt(cur_instruction, cur_input_images)
neg_mllm_input, img_cfg_mllm_input = None, None
neg_mllm_input = self.process_multi_modal_prompt(negative_prompt, None)
if use_img_cfg:
if cur_input_images is not None and len(cur_input_images) >= 1:
img_cfg_prompt = [f"<img><|image_{i + 1}|></img>" for i in range(len(cur_input_images))]
img_cfg_mllm_input = self.process_multi_modal_prompt(" ".join(img_cfg_prompt), cur_input_images)
else:
img_cfg_mllm_input = neg_mllm_input
for _ in range(num_images_per_prompt):
if use_input_image_size_as_output:
input_data.append(
(
mllm_input,
neg_mllm_input,
img_cfg_mllm_input,
[mllm_input["pixel_values"][0].size(-2), mllm_input["pixel_values"][0].size(-1)],
)
)
else:
input_data.append((mllm_input, neg_mllm_input, img_cfg_mllm_input, [height, width]))
return self.collator(input_data)
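For reference (not part of the diff), `add_prefix_instruction` above wraps the user prompt in OmniGen's chat template before tokenization; a minimal sketch of the resulting string:

```py
prompt = "A cat holding a sign that says hello world"
templated = (
    "<|user|>\n"
    "Generate an image according to the following instructions\n"
    f"{prompt}<|end|>\n"
    "<|assistant|>\n<|diffusion|>"
)
print(templated)
```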
class OmniGenCollator:
def __init__(self, pad_token_id=2, hidden_size=3072):
self.pad_token_id = pad_token_id
self.hidden_size = hidden_size
def create_position(self, attention_mask, num_tokens_for_output_images):
position_ids = []
text_length = attention_mask.size(-1)
img_length = max(num_tokens_for_output_images)
for mask in attention_mask:
temp_l = torch.sum(mask)
temp_position = [0] * (text_length - temp_l) + list(
range(temp_l + img_length + 1)
) # we add a time embedding into the sequence, so add one more token
position_ids.append(temp_position)
return torch.LongTensor(position_ids)
def create_mask(self, attention_mask, num_tokens_for_output_images):
"""
OmniGen applies causal attention to each element in the sequence, but applies bidirectional attention within
each image sequence. References: [OmniGen](https://arxiv.org/pdf/2409.11340)
"""
extended_mask = []
padding_images = []
text_length = attention_mask.size(-1)
img_length = max(num_tokens_for_output_images)
seq_len = text_length + img_length + 1 # we add a time embedding into the sequence, so add one more token
inx = 0
for mask in attention_mask:
temp_l = torch.sum(mask)
pad_l = text_length - temp_l
temp_mask = torch.tril(torch.ones(size=(temp_l + 1, temp_l + 1)))
image_mask = torch.zeros(size=(temp_l + 1, img_length))
temp_mask = torch.cat([temp_mask, image_mask], dim=-1)
image_mask = torch.ones(size=(img_length, temp_l + img_length + 1))
temp_mask = torch.cat([temp_mask, image_mask], dim=0)
if pad_l > 0:
pad_mask = torch.zeros(size=(temp_l + 1 + img_length, pad_l))
temp_mask = torch.cat([pad_mask, temp_mask], dim=-1)
pad_mask = torch.ones(size=(pad_l, seq_len))
temp_mask = torch.cat([pad_mask, temp_mask], dim=0)
true_img_length = num_tokens_for_output_images[inx]
pad_img_length = img_length - true_img_length
if pad_img_length > 0:
temp_mask[:, -pad_img_length:] = 0
temp_padding_imgs = torch.zeros(size=(1, pad_img_length, self.hidden_size))
else:
temp_padding_imgs = None
extended_mask.append(temp_mask.unsqueeze(0))
padding_images.append(temp_padding_imgs)
inx += 1
return torch.cat(extended_mask, dim=0), padding_images
def adjust_attention_for_input_images(self, attention_mask, image_sizes):
for b_inx in image_sizes.keys():
for start_inx, end_inx in image_sizes[b_inx]:
attention_mask[b_inx][start_inx:end_inx, start_inx:end_inx] = 1
return attention_mask
def pad_input_ids(self, input_ids, image_sizes):
max_l = max([len(x) for x in input_ids])
padded_ids = []
attention_mask = []
for i in range(len(input_ids)):
temp_ids = input_ids[i]
temp_l = len(temp_ids)
pad_l = max_l - temp_l
if pad_l == 0:
attention_mask.append([1] * max_l)
padded_ids.append(temp_ids)
else:
attention_mask.append([0] * pad_l + [1] * temp_l)
padded_ids.append([self.pad_token_id] * pad_l + temp_ids)
if i in image_sizes:
new_inx = []
for old_inx in image_sizes[i]:
new_inx.append([x + pad_l for x in old_inx])
image_sizes[i] = new_inx
return torch.LongTensor(padded_ids), torch.LongTensor(attention_mask), image_sizes
def process_mllm_input(self, mllm_inputs, target_img_size):
num_tokens_for_output_images = []
for img_size in target_img_size:
num_tokens_for_output_images.append(img_size[0] * img_size[1] // 16 // 16)
pixel_values, image_sizes = [], {}
b_inx = 0
for x in mllm_inputs:
if x["pixel_values"] is not None:
pixel_values.extend(x["pixel_values"])
for size in x["image_sizes"]:
if b_inx not in image_sizes:
image_sizes[b_inx] = [size]
else:
image_sizes[b_inx].append(size)
b_inx += 1
pixel_values = [x.unsqueeze(0) for x in pixel_values]
input_ids = [x["input_ids"] for x in mllm_inputs]
padded_input_ids, attention_mask, image_sizes = self.pad_input_ids(input_ids, image_sizes)
position_ids = self.create_position(attention_mask, num_tokens_for_output_images)
attention_mask, padding_images = self.create_mask(attention_mask, num_tokens_for_output_images)
attention_mask = self.adjust_attention_for_input_images(attention_mask, image_sizes)
return padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes
def __call__(self, features):
mllm_inputs = [f[0] for f in features]
cfg_mllm_inputs = [f[1] for f in features]
img_cfg_mllm_input = [f[2] for f in features]
target_img_size = [f[3] for f in features]
if img_cfg_mllm_input[0] is not None:
mllm_inputs = mllm_inputs + cfg_mllm_inputs + img_cfg_mllm_input
target_img_size = target_img_size + target_img_size + target_img_size
else:
mllm_inputs = mllm_inputs + cfg_mllm_inputs
target_img_size = target_img_size + target_img_size
(
all_padded_input_ids,
all_position_ids,
all_attention_mask,
all_padding_images,
all_pixel_values,
all_image_sizes,
) = self.process_mllm_input(mllm_inputs, target_img_size)
data = {
"input_ids": all_padded_input_ids,
"attention_mask": all_attention_mask,
"position_ids": all_position_ids,
"input_pixel_values": all_pixel_values,
"input_image_sizes": all_image_sizes,
}
return data
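For intuition (not part of the diff), a minimal sketch of the attention pattern `create_mask` builds for one unpadded sample: text tokens plus the trailing time token attend causally, that block cannot attend to the output-image tokens, and output-image tokens attend bidirectionally to everything. Input-image token blocks are later made bidirectional by `adjust_attention_for_input_images`.

```py
import torch

text_len, img_len = 3, 4  # toy lengths; the +1 below is the time-embedding token
causal = torch.tril(torch.ones(text_len + 1, text_len + 1))
text_rows = torch.cat([causal, torch.zeros(text_len + 1, img_len)], dim=-1)  # text/time cannot see the output image
image_rows = torch.ones(img_len, text_len + 1 + img_len)                     # output image attends to everything
mask = torch.cat([text_rows, image_rows], dim=0)                             # square mask of size text_len + 1 + img_len
print(mask)
```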

View File

@@ -621,6 +621,21 @@ class MultiControlNetModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class OmniGenTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class PixArtTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]

View File

@@ -1217,6 +1217,21 @@ class MusicLDMPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class OmniGenPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class PaintByExamplePipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

@@ -0,0 +1,88 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import OmniGenTransformer2DModel
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
class OmniGenTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = OmniGenTransformer2DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
@property
def dummy_input(self):
batch_size = 2
num_channels = 4
height = 8
width = 8
sequence_length = 24
hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
timestep = torch.rand(size=(batch_size,), dtype=hidden_states.dtype).to(torch_device)
input_ids = torch.randint(0, 10, (batch_size, sequence_length)).to(torch_device)
input_img_latents = [torch.randn((1, num_channels, height, width)).to(torch_device)]
input_image_sizes = {0: [[0, 0 + height * width // 2 // 2]]}
attn_seq_length = sequence_length + 1 + height * width // 2 // 2
attention_mask = torch.ones((batch_size, attn_seq_length, attn_seq_length)).to(torch_device)
position_ids = torch.LongTensor([list(range(attn_seq_length))] * batch_size).to(torch_device)
return {
"hidden_states": hidden_states,
"timestep": timestep,
"input_ids": input_ids,
"input_img_latents": input_img_latents,
"input_image_sizes": input_image_sizes,
"attention_mask": attention_mask,
"position_ids": position_ids,
}
@property
def input_shape(self):
return (4, 8, 8)
@property
def output_shape(self):
return (4, 8, 8)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"hidden_size": 16,
"num_attention_heads": 4,
"num_key_value_heads": 4,
"intermediate_size": 32,
"num_layers": 1,
"pad_token_id": 0,
"vocab_size": 100,
"in_channels": 4,
"time_step_dim": 4,
"rope_scaling": {"long_factor": list(range(1, 3)), "short_factor": list(range(1, 3))},
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {"OmniGenTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

View File

@@ -0,0 +1,153 @@
import gc
import unittest
import numpy as np
import torch
from transformers import AutoTokenizer
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, OmniGenPipeline, OmniGenTransformer2DModel
from diffusers.utils.testing_utils import (
numpy_cosine_similarity_distance,
require_torch_gpu,
slow,
torch_device,
)
from ..test_pipelines_common import PipelineTesterMixin
class OmniGenPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
pipeline_class = OmniGenPipeline
params = frozenset(
[
"prompt",
"guidance_scale",
]
)
batch_params = frozenset(
[
"prompt",
]
)
def get_dummy_components(self):
torch.manual_seed(0)
transformer = OmniGenTransformer2DModel(
hidden_size=16,
num_attention_heads=4,
num_key_value_heads=4,
intermediate_size=32,
num_layers=1,
in_channels=4,
time_step_dim=4,
rope_scaling={"long_factor": list(range(1, 3)), "short_factor": list(range(1, 3))},
)
torch.manual_seed(0)
vae = AutoencoderKL(
sample_size=32,
in_channels=3,
out_channels=3,
block_out_channels=(4, 4, 4, 4),
layers_per_block=1,
latent_channels=4,
norm_num_groups=1,
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"],
)
scheduler = FlowMatchEulerDiscreteScheduler(invert_sigmas=True, num_train_timesteps=1)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
components = {
"transformer": transformer,
"vae": vae,
"scheduler": scheduler,
"tokenizer": tokenizer,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device="cpu").manual_seed(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"generator": generator,
"num_inference_steps": 1,
"guidance_scale": 3.0,
"output_type": "np",
"height": 16,
"width": 16,
}
return inputs
def test_inference(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
generated_image = pipe(**inputs).images[0]
self.assertEqual(generated_image.shape, (16, 16, 3))
@slow
@require_torch_gpu
class OmniGenPipelineSlowTests(unittest.TestCase):
pipeline_class = OmniGenPipeline
repo_id = "shitao/OmniGen-v1-diffusers"
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def get_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device="cpu").manual_seed(seed)
return {
"prompt": "A photo of a cat",
"num_inference_steps": 2,
"guidance_scale": 2.5,
"output_type": "np",
"generator": generator,
}
def test_omnigen_inference(self):
pipe = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()
inputs = self.get_inputs(torch_device)
image = pipe(**inputs).images[0]
image_slice = image[0, :10, :10]
expected_slice = np.array(
[
[0.1783447, 0.16772744, 0.14339337],
[0.17066911, 0.15521264, 0.13757327],
[0.17072496, 0.15531206, 0.13524258],
[0.16746324, 0.1564025, 0.13794944],
[0.16490817, 0.15258026, 0.13697758],
[0.16971767, 0.15826806, 0.13928896],
[0.16782972, 0.15547255, 0.13783783],
[0.16464645, 0.15281534, 0.13522372],
[0.16535294, 0.15301755, 0.13526791],
[0.16365296, 0.15092957, 0.13443318],
],
dtype=np.float32,
)
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
assert max_diff < 1e-4