Compare commits

..

56 Commits

Author SHA1 Message Date
sayakpaul
c92cf75dfe up 2025-12-12 13:41:01 +05:30
Wang, Yi
218b17040f support CP in native flash attention (#12829)
Signed-off-by: Wang, Yi <yi.a.wang@intel.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-12-12 13:40:50 +05:30
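Context parallelism here hinges on attention backends returning the log-sum-exp (`lse`) alongside the output so that per-shard results can be merged (see the `return_lse` handling in the attention diff further down). A minimal single-process sketch of that merge rule, illustrative only and not the diffusers implementation (all helper names are made up):

```python
import torch

def attention_with_lse(q, k, v):
    # q, k, v: (batch, heads, seq, head_dim)
    scale = q.shape[-1] ** -0.5
    scores = torch.einsum("bhqd,bhkd->bhqk", q, k) * scale
    lse = torch.logsumexp(scores, dim=-1)  # (batch, heads, q_len)
    out = torch.einsum("bhqk,bhkd->bhqd", scores.softmax(dim=-1), v)
    return out, lse

def merge_shards(out_a, lse_a, out_b, lse_b):
    # Weight each partial output by its share of the total softmax normalizer.
    lse = torch.logaddexp(lse_a, lse_b)
    return out_a * (lse_a - lse).exp().unsqueeze(-1) + out_b * (lse_b - lse).exp().unsqueeze(-1)

q, k, v = (torch.randn(1, 2, 8, 16) for _ in range(3))
full, _ = attention_with_lse(q, k, v)
out_a, lse_a = attention_with_lse(q, k[:, :, :4], v[:, :, :4])  # attention over the first key/value shard
out_b, lse_b = attention_with_lse(q, k[:, :, 4:], v[:, :, 4:])  # attention over the second key/value shard
print(torch.allclose(full, merge_shards(out_a, lse_a, out_b, lse_b), atol=1e-5))  # True
```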
Sayak Paul
a7c7a270f6 [lora] Remove lora docs unneeded and add " # Copied from ..." (#12824)
* remove unneeded docs on load_lora_weights().

* remove more.

* up[

* up

* up
2025-12-12 13:40:50 +05:30
Sayak Paul
5455dd58e8 Update distributed_inference.md to correct syntax (#12827) 2025-12-12 13:40:50 +05:30
Sayak Paul
07084ef036 post release 0.36.0 (#12804)
* post release 0.36.0

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-12-12 13:40:50 +05:30
Sayak Paul
9e15c576cb [docs] improve distributed inference cp docs. (#12810)
* improve distributed inference cp docs.

* Apply suggestions from code review

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2025-12-12 13:40:50 +05:30
Dhruv Nair
8ddba1e082 [WIP] Add Flux2 modular (#12763)
* update

* update

* update

* update

* update

* update

* update

* update

* update

* update
2025-12-12 13:40:50 +05:30
Sayak Paul
d1b8202e42 Fix Qwen Edit Plus modular for multi-image input (#12601)
* try to fix qwen edit plus multi images (modular)

* up

* up

* test

* up

* up
2025-12-12 13:40:49 +05:30
YiYi Xu
f7439c30c9 [Modular]z-image (#12808)
* initiL

* up up

* fix: z_image -> z-image

* style

* copy

* fix more

* some docstring fix
2025-12-12 13:40:49 +05:30
David El Malih
b53bd8372b Improve docstrings and type hints in scheduling_dpmsolver_singlestep.py (#12798)
feat: add flow sigmas, dynamic shifting, and refine type hints in DPMSolverSinglestepScheduler
2025-12-12 13:40:49 +05:30
David Lacalle Castillo
a73981fe17 [PRX] Improve model compilation (#12787)
* Reimplement img2seq & seq2img in PRX to enable ONNX build without Col2Im (incompatible with TensorRT).

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-12-12 13:40:49 +05:30
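The change above avoids the Col2Im op (unsupported by TensorRT) by reworking img2seq/seq2img. A generic sketch of the idea only, not the PRX code (function names are hypothetical): for non-overlapping patches, patchify/unpatchify reduces to reshape/permute, so ONNX export does not emit the Col2Im op that `F.fold` would produce.

```python
import torch

def img2seq(img, p):
    # img: (b, c, h, w) -> (b, num_patches, p*p*c), non-overlapping p x p patches
    b, c, h, w = img.shape
    x = img.reshape(b, c, h // p, p, w // p, p)
    x = x.permute(0, 2, 4, 3, 5, 1)  # (b, h//p, w//p, p, p, c)
    return x.reshape(b, (h // p) * (w // p), p * p * c)

def seq2img(seq, h, w, p, c):
    # Inverse of img2seq via pure reshape/permute (no F.fold / Col2Im).
    b = seq.shape[0]
    x = seq.reshape(b, h // p, w // p, p, p, c)
    x = x.permute(0, 5, 1, 3, 2, 4)  # (b, c, h//p, p, w//p, p)
    return x.reshape(b, c, h, w)

img = torch.randn(2, 3, 8, 8)
assert torch.equal(seq2img(img2seq(img, p=2), h=8, w=8, p=2, c=3), img)  # lossless round trip
```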
Sayak Paul
d738ec4141 Merge branch 'main' into control-lora 2025-12-08 18:32:45 +08:00
Sayak Paul
03d1751cce Merge branch 'main' into control-lora 2025-12-05 21:55:26 +08:00
lavinal712
cd71418052 add doc 2025-12-05 10:24:20 +08:00
lavinal712
58559ecc7e no need modify as peft updated 2025-12-05 10:01:11 +08:00
Yuqian Hong
48eeeae1f7 Merge branch 'main' into control-lora 2025-12-05 09:44:22 +08:00
Yuqian Hong
2223722e5b Merge branch 'huggingface:main' into control-lora 2025-11-28 08:51:59 +08:00
lavinal712
4d1e8912d6 rename 2025-11-22 09:50:45 +08:00
Yuqian Hong
dfad05625e Merge branch 'huggingface:main' into control-lora 2025-11-22 09:23:09 +08:00
Yuqian Hong
9d94c377ef Merge branch 'huggingface:main' into control-lora 2025-09-22 11:44:52 +08:00
Yuqian Hong
1e8221ce39 Add files via upload 2025-08-20 21:23:22 +08:00
Yuqian Hong
00a26cd8dd Create control_lora.py 2025-08-20 21:23:04 +08:00
Yuqian Hong
a2eff1c668 Merge branch 'huggingface:main' into control-lora 2025-08-20 09:23:49 +08:00
Yuqian Hong
1c902725b0 Merge branch 'huggingface:main' into control-lora 2025-08-19 06:41:43 +08:00
Yuqian Hong
59a42b23d3 Merge branch 'huggingface:main' into control-lora 2025-08-17 16:32:42 +08:00
Yuqian Hong
4a64d64407 Merge branch 'main' into control-lora 2025-08-14 11:51:09 +08:00
Yuqian Hong
c6c13b6717 Merge branch 'huggingface:main' into control-lora 2025-08-09 00:24:18 +08:00
Yuqian Hong
af8255e934 Merge branch 'main' into control-lora 2025-07-30 15:48:15 +08:00
Yuqian Hong
d3a07558cf Merge branch 'main' into control-lora 2025-07-21 16:00:43 +08:00
lavinal712
23cba1804f fix alpha 2025-07-05 08:18:34 +00:00
lavinal712
53a06cc969 delete state_dict print 2025-07-05 07:52:01 +00:00
lavinal712
0a5bd74931 1 2025-07-05 05:12:57 +00:00
Yuqian Hong
d752992831 Merge branch 'huggingface:main' into control-lora 2025-07-05 09:53:39 +08:00
Yuqian Hong
39e9254208 Merge branch 'huggingface:main' into control-lora 2025-07-02 17:48:25 +08:00
lavinal712
c134bca767 change peft.py 2025-05-29 14:24:16 +00:00
lavinal712
63bafc88cd change peft.py 2025-05-29 14:23:41 +00:00
Yuqian Hong
8f7fc0ada0 Merge branch 'huggingface:main' into control-lora 2025-05-29 21:59:42 +08:00
lavinal712
6fff794e59 merged but bug 2025-04-09 07:56:40 +00:00
Yuqian Hong
ab9eeff757 Merge branch 'main' into control-lora 2025-04-09 15:41:28 +08:00
lavinal712
6a1ff82d08 resolve conflits 2025-04-09 07:41:14 +00:00
Yuqian Hong
ce2b34bba7 Merge branch 'main' into control-lora 2025-03-26 10:07:45 +08:00
Sayak Paul
2de1505e6e Merge branch 'main' into control-lora 2025-03-25 18:58:07 +01:00
lavinal712
81eed41b74 delete json print 2025-03-23 10:29:08 +00:00
lavinal712
0719c20f5e fix module_to_save bug 2025-03-23 10:27:40 +00:00
lavinal712
7c25a06591 fix PeftAdapterMixin 2025-03-23 09:36:00 +00:00
Yuqian Hong
280cf7fd38 Merge branch 'huggingface:main' into control-lora 2025-03-23 13:30:27 +08:00
Yuqian Hong
33288e667f Merge branch 'huggingface:main' into control-lora 2025-03-17 11:38:42 +08:00
Yuqian Hong
dd24464065 Merge branch 'huggingface:main' into control-lora 2025-02-23 23:14:08 +08:00
lavinal712
523967f396 1 2025-02-15 13:45:51 +00:00
lavinal712
10daac7e19 1 2025-02-15 13:41:30 +00:00
lavinal712
de61226385 1 2025-02-15 13:38:02 +00:00
lavinal712
39b3b84acc add control-lora 2025-02-07 18:43:19 +00:00
lavinal712
2453e149d2 1 2025-02-07 10:24:24 +00:00
lavinal712
9cf8ad7a73 test 2025-02-04 16:40:59 +00:00
lavinal712
e9d91e156d cannot load lora adapter 2025-02-01 01:29:19 +00:00
lavinal712
18de3adad1 run control-lora on diffusers 2025-01-30 03:00:59 +00:00
19 changed files with 43 additions and 2827 deletions

View File

@@ -365,8 +365,6 @@
title: HunyuanVideoTransformer3DModel
- local: api/models/latte_transformer3d
title: LatteTransformer3DModel
- local: api/models/longcat_image_transformer2d
title: LongCatImageTransformer2DModel
- local: api/models/ltx_video_transformer3d
title: LTXVideoTransformer3DModel
- local: api/models/lumina2_transformer2d
@@ -404,7 +402,7 @@
- local: api/models/wan_transformer_3d
title: WanTransformer3DModel
- local: api/models/z_image_transformer2d
title: ZImageTransformer2DModel
title: ZImageTransformer2DModel
title: Transformers
- sections:
- local: api/models/stable_cascade_unet
@@ -565,8 +563,6 @@
title: Latent Diffusion
- local: api/pipelines/ledits_pp
title: LEDITS++
- local: api/pipelines/longcat_image
title: LongCat-Image
- local: api/pipelines/lumina2
title: Lumina 2.0
- local: api/pipelines/lumina

View File

@@ -1,25 +0,0 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# LongCatImageTransformer2DModel
The model can be loaded with the following code snippet.
```python
import torch

from diffusers import LongCatImageTransformer2DModel

transformer = LongCatImageTransformer2DModel.from_pretrained("meituan-longcat/LongCat-Image", subfolder="transformer", torch_dtype=torch.bfloat16)
```
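To avoid loading the transformer twice, it can also be handed directly to the pipeline; a minimal sketch, assuming the usual diffusers pattern of passing a preloaded component to `from_pretrained`:

```python
import torch

from diffusers import LongCatImagePipeline, LongCatImageTransformer2DModel

# Reuse the preloaded component instead of loading it a second time.
transformer = LongCatImageTransformer2DModel.from_pretrained(
    "meituan-longcat/LongCat-Image", subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = LongCatImagePipeline.from_pretrained(
    "meituan-longcat/LongCat-Image", transformer=transformer, torch_dtype=torch.bfloat16
)
```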
## LongCatImageTransformer2DModel
[[autodoc]] LongCatImageTransformer2DModel

View File

@@ -1,114 +0,0 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# LongCat-Image
<div class="flex flex-wrap space-x-1">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
</div>
We introduce LongCat-Image, a pioneering open-source and bilingual (Chinese-English) foundation model for image generation, designed to address core challenges in multilingual text rendering, photorealism, deployment efficiency, and developer accessibility prevalent in current leading models.
### Key Features
- 🌟 **Exceptional Efficiency and Performance**: With only **6B parameters**, LongCat-Image surpasses numerous open-source models that are several times larger across multiple benchmarks, demonstrating the immense potential of efficient model design.
- 🌟 **Superior Editing Performance**: LongCat-Image-Edit model achieves state-of-the-art performance among open-source models, delivering leading instruction-following and image quality with superior visual consistency.
- 🌟 **Powerful Chinese Text Rendering**: LongCat-Image demonstrates superior accuracy and stability in rendering common Chinese characters compared to existing SOTA open-source models and achieves industry-leading coverage of the Chinese dictionary.
- 🌟 **Remarkable Photorealism**: Through an innovative data strategy and training framework, LongCat-Image achieves remarkable photorealism in generated images.
- 🌟 **Comprehensive Open-Source Ecosystem**: We provide a complete toolchain, from intermediate checkpoints to full training code, significantly lowering the barrier for further research and development.
For more details, please refer to the comprehensive [***LongCat-Image Technical Report***](https://arxiv.org/abs/2412.11963).
## Usage Example
```py
import torch

from diffusers import LongCatImagePipeline

pipe = LongCatImagePipeline.from_pretrained("meituan-longcat/LongCat-Image", torch_dtype=torch.bfloat16)
pipe.to("cuda")
# pipe.enable_model_cpu_offload()  # optionally offload to save VRAM

# Chinese prompt: a young Asian woman in a yellow knit sweater with a white necklace, hands resting on her
# knees with a serene expression, in front of a rough brick wall in warm afternoon sunlight; a medium shot
# with soft light emphasizing her features and the textures, in a simple, calm composition.
prompt = '一个年轻的亚裔女性,身穿黄色针织衫,搭配白色项链。她的双手放在膝盖上,表情恬静。背景是一堵粗糙的砖墙,午后的阳光温暖地洒在她身上,营造出一种宁静而温馨的氛围。镜头采用中距离视角,突出她的神态和服饰的细节。光线柔和地打在她的脸上,强调她的五官和饰品的质感,增加画面的层次感与亲和力。整个画面构图简洁,砖墙的纹理与阳光的光影效果相得益彰,突显出人物的优雅与从容。'
image = pipe(
    prompt,
    height=768,
    width=1344,
    guidance_scale=4.0,
    num_inference_steps=50,
    num_images_per_prompt=1,
    generator=torch.Generator("cpu").manual_seed(43),
    enable_cfg_renorm=True,
    enable_prompt_rewrite=True,
).images[0]
image.save("longcat_image_t2i_example.png")
```
This pipeline was contributed by the LongCat-Image Team. The original codebase can be found [here](https://github.com/meituan-longcat/LongCat-Image).
Available models:
<div style="overflow-x: auto; margin-bottom: 16px;">
<table style="border-collapse: collapse; width: 100%;">
<thead>
<tr>
<th style="white-space: nowrap; padding: 8px; border: 1px solid #d0d7de; background-color: #f6f8fa;">Models</th>
<th style="white-space: nowrap; padding: 8px; border: 1px solid #d0d7de; background-color: #f6f8fa;">Type</th>
<th style="padding: 8px; border: 1px solid #d0d7de; background-color: #f6f8fa;">Description</th>
<th style="padding: 8px; border: 1px solid #d0d7de; background-color: #f6f8fa;">Download Link</th>
</tr>
</thead>
<tbody>
<tr>
<td style="white-space: nowrap; padding: 8px; border: 1px solid #d0d7de;">LongCat&#8209;Image</td>
<td style="white-space: nowrap; padding: 8px; border: 1px solid #d0d7de;">Text&#8209;to&#8209;Image</td>
<td style="padding: 8px; border: 1px solid #d0d7de;">Final Release. The standard model for out&#8209;of&#8209;the&#8209;box inference.</td>
<td style="padding: 8px; border: 1px solid #d0d7de;">
<span style="white-space: nowrap;">🤗&nbsp;<a href="https://huggingface.co/meituan-longcat/LongCat-Image">Huggingface</a></span>
</td>
</tr>
<tr>
<td style="white-space: nowrap; padding: 8px; border: 1px solid #d0d7de;">LongCat&#8209;Image&#8209;Dev</td>
<td style="white-space: nowrap; padding: 8px; border: 1px solid #d0d7de;">Text&#8209;to&#8209;Image</td>
<td style="padding: 8px; border: 1px solid #d0d7de;">Development. Mid-training checkpoint, suitable for fine-tuning.</td>
<td style="padding: 8px; border: 1px solid #d0d7de;">
<span style="white-space: nowrap;">🤗&nbsp;<a href="https://huggingface.co/meituan-longcat/LongCat-Image-Dev">Huggingface</a></span>
</td>
</tr>
<tr>
<td style="white-space: nowrap; padding: 8px; border: 1px solid #d0d7de;">LongCat&#8209;Image&#8209;Edit</td>
<td style="white-space: nowrap; padding: 8px; border: 1px solid #d0d7de;">Image Editing</td>
<td style="padding: 8px; border: 1px solid #d0d7de;">Specialized model for image editing.</td>
<td style="padding: 8px; border: 1px solid #d0d7de;">
<span style="white-space: nowrap;">🤗&nbsp;<a href="https://huggingface.co/meituan-longcat/LongCat-Image-Edit">Huggingface</a></span>
</td>
</tr>
</tbody>
</table>
</div>
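The Edit checkpoint is expected to load through `LongCatImageEditPipeline`; a minimal sketch, assuming it mirrors the text-to-image loading API (the call arguments below are assumptions, not taken from this page):

```py
import torch

from diffusers import LongCatImageEditPipeline
from diffusers.utils import load_image

# Sketch only: assumes the Edit variant follows the same from_pretrained pattern.
pipe = LongCatImageEditPipeline.from_pretrained(
    "meituan-longcat/LongCat-Image-Edit", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

source = load_image("path/to/input.png")  # hypothetical local input image
# The `prompt` and `image` argument names are assumptions, not confirmed here.
edited = pipe(prompt="Replace the background with a beach at sunset", image=source).images[0]
edited.save("longcat_image_edit_example.png")
```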
## LongCatImagePipeline
[[autodoc]] LongCatImagePipeline
- all
- __call__
## LongCatImagePipelineOutput
[[autodoc]] pipelines.longcat_image.pipeline_output.LongCatImagePipelineOutput

View File

@@ -235,7 +235,6 @@ else:
"Kandinsky3UNet",
"Kandinsky5Transformer3DModel",
"LatteTransformer3DModel",
"LongCatImageTransformer2DModel",
"LTXVideoTransformer3DModel",
"Lumina2Transformer2DModel",
"LuminaNextDiT2DModel",
@@ -533,8 +532,6 @@ else:
"LDMTextToImagePipeline",
"LEditsPPPipelineStableDiffusion",
"LEditsPPPipelineStableDiffusionXL",
"LongCatImageEditPipeline",
"LongCatImagePipeline",
"LTXConditionPipeline",
"LTXImageToVideoPipeline",
"LTXLatentUpsamplePipeline",
@@ -973,7 +970,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
Kandinsky3UNet,
Kandinsky5Transformer3DModel,
LatteTransformer3DModel,
LongCatImageTransformer2DModel,
LTXVideoTransformer3DModel,
Lumina2Transformer2DModel,
LuminaNextDiT2DModel,
@@ -1241,8 +1237,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LDMTextToImagePipeline,
LEditsPPPipelineStableDiffusion,
LEditsPPPipelineStableDiffusionXL,
LongCatImageEditPipeline,
LongCatImagePipeline,
LTXConditionPipeline,
LTXImageToVideoPipeline,
LTXLatentUpsamplePipeline,

View File

@@ -101,7 +101,6 @@ if is_torch_available():
_import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"]
_import_structure["transformers.transformer_hunyuanimage"] = ["HunyuanImageTransformer2DModel"]
_import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"]
_import_structure["transformers.transformer_longcat_image"] = ["LongCatImageTransformer2DModel"]
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
@@ -209,7 +208,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
HunyuanVideoTransformer3DModel,
Kandinsky5Transformer3DModel,
LatteTransformer3DModel,
LongCatImageTransformer2DModel,
LTXVideoTransformer3DModel,
Lumina2Transformer2DModel,
LuminaNextDiT2DModel,

View File

@@ -256,10 +256,6 @@ class _HubKernelConfig:
function_attr: str
revision: Optional[str] = None
kernel_fn: Optional[Callable] = None
wrapped_forward_attr: Optional[str] = None
wrapped_backward_attr: Optional[str] = None
wrapped_forward_fn: Optional[Callable] = None
wrapped_backward_fn: Optional[Callable] = None
# Registry for hub-based attention kernels
@@ -274,11 +270,7 @@ _HUB_KERNELS_REGISTRY: Dict["AttentionBackendName", _HubKernelConfig] = {
# revision="fake-ops-return-probs",
),
AttentionBackendName.FLASH_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn2",
function_attr="flash_attn_func",
revision=None,
wrapped_forward_attr="flash_attn_interface._wrapped_flash_attn_forward",
wrapped_backward_attr="flash_attn_interface._wrapped_flash_attn_backward",
repo_id="kernels-community/flash-attn2", function_attr="flash_attn_func", revision=None
),
AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn2", function_attr="flash_attn_varlen_func", revision=None
@@ -602,39 +594,22 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
# ===== Helpers for downloading kernels =====
def _resolve_kernel_attr(module, attr_path: str):
target = module
for attr in attr_path.split("."):
if not hasattr(target, attr):
raise AttributeError(f"Kernel module '{module.__name__}' does not define attribute path '{attr_path}'.")
target = getattr(target, attr)
return target
def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None:
if backend not in _HUB_KERNELS_REGISTRY:
return
config = _HUB_KERNELS_REGISTRY[backend]
needs_kernel = config.kernel_fn is None
needs_wrapped_forward = config.wrapped_forward_attr is not None and config.wrapped_forward_fn is None
needs_wrapped_backward = config.wrapped_backward_attr is not None and config.wrapped_backward_fn is None
if not (needs_kernel or needs_wrapped_forward or needs_wrapped_backward):
if config.kernel_fn is not None:
return
try:
from kernels import get_kernel
kernel_module = get_kernel(config.repo_id, revision=config.revision)
if needs_kernel:
config.kernel_fn = _resolve_kernel_attr(kernel_module, config.function_attr)
kernel_func = getattr(kernel_module, config.function_attr)
if needs_wrapped_forward:
config.wrapped_forward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_forward_attr)
if needs_wrapped_backward:
config.wrapped_backward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_backward_attr)
# Cache the downloaded kernel function in the config object
config.kernel_fn = kernel_func
except Exception as e:
logger.error(f"An error occurred while fetching kernel '{config.repo_id}' from the Hub: {e}")
@@ -1085,231 +1060,6 @@ def _flash_attention_backward_op(
return grad_query, grad_key, grad_value
def _flash_attention_hub_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
return_lse: bool = False,
_save_ctx: bool = True,
_parallel_config: Optional["ParallelConfig"] = None,
):
if attn_mask is not None:
raise ValueError("`attn_mask` is not yet supported for flash-attn hub kernels.")
if enable_gqa:
raise ValueError("`enable_gqa` is not yet supported for flash-attn hub kernels.")
config = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB]
wrapped_forward_fn = config.wrapped_forward_fn
wrapped_backward_fn = config.wrapped_backward_fn
if wrapped_forward_fn is None or wrapped_backward_fn is None:
raise RuntimeError(
"Flash attention hub kernels must expose `_wrapped_flash_attn_forward` and `_wrapped_flash_attn_backward` "
"for context parallel execution."
)
if scale is None:
scale = query.shape[-1] ** (-0.5)
window_size = (-1, -1)
softcap = 0.0
alibi_slopes = None
deterministic = False
grad_enabled = any(x.requires_grad for x in (query, key, value))
if grad_enabled or (_parallel_config is not None and _parallel_config.context_parallel_config._world_size > 1):
dropout_p = dropout_p if dropout_p > 0 else 1e-30
with torch.set_grad_enabled(grad_enabled):
out, lse, S_dmask, rng_state = wrapped_forward_fn(
query,
key,
value,
dropout_p,
scale,
is_causal,
window_size[0],
window_size[1],
softcap,
alibi_slopes,
return_lse,
)
lse = lse.permute(0, 2, 1).contiguous()
if _save_ctx:
ctx.save_for_backward(query, key, value, out, lse, rng_state)
ctx.dropout_p = dropout_p
ctx.scale = scale
ctx.is_causal = is_causal
ctx.window_size = window_size
ctx.softcap = softcap
ctx.alibi_slopes = alibi_slopes
ctx.deterministic = deterministic
return (out, lse) if return_lse else out
def _flash_attention_hub_backward_op(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
*args,
**kwargs,
):
config = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB]
wrapped_backward_fn = config.wrapped_backward_fn
if wrapped_backward_fn is None:
raise RuntimeError(
"Flash attention hub kernels must expose `_wrapped_flash_attn_backward` for context parallel execution."
)
query, key, value, out, lse, rng_state = ctx.saved_tensors
grad_query, grad_key, grad_value = torch.empty_like(query), torch.empty_like(key), torch.empty_like(value)
_ = wrapped_backward_fn(
grad_out,
query,
key,
value,
out,
lse,
grad_query,
grad_key,
grad_value,
ctx.dropout_p,
ctx.scale,
ctx.is_causal,
ctx.window_size[0],
ctx.window_size[1],
ctx.softcap,
ctx.alibi_slopes,
ctx.deterministic,
rng_state,
)
grad_query = grad_query[..., : grad_out.shape[-1]]
grad_key = grad_key[..., : grad_out.shape[-1]]
grad_value = grad_value[..., : grad_out.shape[-1]]
return grad_query, grad_key, grad_value
def _flash_attention_3_hub_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
return_lse: bool = False,
_save_ctx: bool = True,
_parallel_config: Optional["ParallelConfig"] = None,
*,
window_size: Tuple[int, int] = (-1, -1),
softcap: float = 0.0,
num_splits: int = 1,
pack_gqa: Optional[bool] = None,
deterministic: bool = False,
sm_margin: int = 0,
):
if attn_mask is not None:
raise ValueError("`attn_mask` is not yet supported for flash-attn 3 hub kernels.")
if dropout_p != 0.0:
raise ValueError("`dropout_p` is not yet supported for flash-attn 3 hub kernels.")
if enable_gqa:
raise ValueError("`enable_gqa` is not yet supported for flash-attn 3 hub kernels.")
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
out = func(
q=query,
k=key,
v=value,
softmax_scale=scale,
causal=is_causal,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
window_size=window_size,
softcap=softcap,
num_splits=num_splits,
pack_gqa=pack_gqa,
deterministic=deterministic,
sm_margin=sm_margin,
return_attn_probs=return_lse,
)
lse = None
if return_lse:
out, lse = out
lse = lse.permute(0, 2, 1).contiguous()
if _save_ctx:
ctx.save_for_backward(query, key, value)
ctx.scale = scale
ctx.is_causal = is_causal
ctx._hub_kernel = func
return (out, lse) if return_lse else out
def _flash_attention_3_hub_backward_op(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
*args,
window_size: Tuple[int, int] = (-1, -1),
softcap: float = 0.0,
num_splits: int = 1,
pack_gqa: Optional[bool] = None,
deterministic: bool = False,
sm_margin: int = 0,
):
query, key, value = ctx.saved_tensors
kernel_fn = ctx._hub_kernel
with torch.enable_grad():
query_r = query.detach().requires_grad_(True)
key_r = key.detach().requires_grad_(True)
value_r = value.detach().requires_grad_(True)
out = kernel_fn(
q=query_r,
k=key_r,
v=value_r,
softmax_scale=ctx.scale,
causal=ctx.is_causal,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
window_size=window_size,
softcap=softcap,
num_splits=num_splits,
pack_gqa=pack_gqa,
deterministic=deterministic,
sm_margin=sm_margin,
return_attn_probs=False,
)
if isinstance(out, tuple):
out = out[0]
grad_query, grad_key, grad_value = torch.autograd.grad(
out,
(query_r, key_r, value_r),
grad_out,
retain_graph=False,
allow_unused=False,
)
return grad_query, grad_key, grad_value
def _sage_attention_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
@@ -1356,46 +1106,6 @@ def _sage_attention_backward_op(
raise NotImplementedError("Backward pass is not implemented for Sage attention.")
def _sage_attention_hub_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
return_lse: bool = False,
_save_ctx: bool = True,
_parallel_config: Optional["ParallelConfig"] = None,
):
if attn_mask is not None:
raise ValueError("`attn_mask` is not yet supported for Sage attention.")
if dropout_p > 0.0:
raise ValueError("`dropout_p` is not yet supported for Sage attention.")
if enable_gqa:
raise ValueError("`enable_gqa` is not yet supported for Sage attention.")
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn
out = func(
q=query,
k=key,
v=value,
tensor_layout="NHD",
is_causal=is_causal,
sm_scale=scale,
return_lse=return_lse,
)
lse = None
if return_lse:
out, lse, *_ = out
lse = lse.permute(0, 2, 1).contiguous()
return (out, lse) if return_lse else out
# ===== Context parallel =====
@@ -1753,7 +1463,7 @@ def _flash_attention(
@_AttentionBackendRegistry.register(
AttentionBackendName.FLASH_HUB,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_context_parallel=True,
supports_context_parallel=False,
)
def _flash_attention_hub(
query: torch.Tensor,
@@ -1767,35 +1477,17 @@ def _flash_attention_hub(
) -> torch.Tensor:
lse = None
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].kernel_fn
if _parallel_config is None:
out = func(
q=query,
k=key,
v=value,
dropout_p=dropout_p,
softmax_scale=scale,
causal=is_causal,
return_attn_probs=return_lse,
)
if return_lse:
out, lse, *_ = out
else:
out = _templated_context_parallel_attention(
query,
key,
value,
None,
dropout_p,
is_causal,
scale,
False,
return_lse,
forward_op=_flash_attention_hub_forward_op,
backward_op=_flash_attention_hub_backward_op,
_parallel_config=_parallel_config,
)
if return_lse:
out, lse = out
out = func(
q=query,
k=key,
v=value,
dropout_p=dropout_p,
softmax_scale=scale,
causal=is_causal,
return_attn_probs=return_lse,
)
if return_lse:
out, lse, *_ = out
return (out, lse) if return_lse else out
@@ -1938,7 +1630,7 @@ def _flash_attention_3(
@_AttentionBackendRegistry.register(
AttentionBackendName._FLASH_3_HUB,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_context_parallel=True,
supports_context_parallel=False,
)
def _flash_attention_3_hub(
query: torch.Tensor,
@@ -1952,65 +1644,31 @@ def _flash_attention_3_hub(
return_attn_probs: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
if _parallel_config:
raise NotImplementedError(f"{AttentionBackendName._FLASH_3_HUB.value} is not implemented for parallelism yet.")
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
if _parallel_config is None:
out = func(
q=query,
k=key,
v=value,
softmax_scale=scale,
causal=is_causal,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
window_size=window_size,
softcap=softcap,
num_splits=1,
pack_gqa=None,
deterministic=deterministic,
sm_margin=0,
return_attn_probs=return_attn_probs,
)
return (out[0], out[1]) if return_attn_probs else out
forward_op = functools.partial(
_flash_attention_3_hub_forward_op,
out = func(
q=query,
k=key,
v=value,
softmax_scale=scale,
causal=is_causal,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
window_size=window_size,
softcap=softcap,
num_splits=1,
pack_gqa=None,
deterministic=deterministic,
sm_margin=0,
return_attn_probs=return_attn_probs,
)
backward_op = functools.partial(
_flash_attention_3_hub_backward_op,
window_size=window_size,
softcap=softcap,
num_splits=1,
pack_gqa=None,
deterministic=deterministic,
sm_margin=0,
)
out = _templated_context_parallel_attention(
query,
key,
value,
None,
0.0,
is_causal,
scale,
False,
return_attn_probs,
forward_op=forward_op,
backward_op=backward_op,
_parallel_config=_parallel_config,
)
if return_attn_probs:
out, lse = out
return out, lse
return out
# When `return_attn_probs` is True, the above returns a tuple of
# actual outputs and lse.
return (out[0], out[1]) if return_attn_probs else out
@_AttentionBackendRegistry.register(
@@ -2559,7 +2217,7 @@ def _sage_attention(
@_AttentionBackendRegistry.register(
AttentionBackendName.SAGE_HUB,
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_context_parallel=True,
supports_context_parallel=False,
)
def _sage_attention_hub(
query: torch.Tensor,
@@ -2584,23 +2242,6 @@ def _sage_attention_hub(
)
if return_lse:
out, lse, *_ = out
else:
out = _templated_context_parallel_attention(
query,
key,
value,
None,
0.0,
is_causal,
scale,
False,
return_lse,
forward_op=_sage_attention_hub_forward_op,
backward_op=_sage_attention_backward_op,
_parallel_config=_parallel_config,
)
if return_lse:
out, lse = out
return (out, lse) if return_lse else out

View File

@@ -33,7 +33,6 @@ if is_torch_available():
from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel
from .transformer_hunyuanimage import HunyuanImageTransformer2DModel
from .transformer_kandinsky import Kandinsky5Transformer3DModel
from .transformer_longcat_image import LongCatImageTransformer2DModel
from .transformer_ltx import LTXVideoTransformer3DModel
from .transformer_lumina2 import Lumina2Transformer2DModel
from .transformer_mochi import MochiTransformer3DModel

View File

@@ -1,548 +0,0 @@
# Copyright 2025 MeiTuan LongCat-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import is_torch_npu_available, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
from ..embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def _get_projections(attn: "LongCatImageAttention", hidden_states, encoder_hidden_states=None):
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
encoder_query = encoder_key = encoder_value = None
if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
encoder_query = attn.add_q_proj(encoder_hidden_states)
encoder_key = attn.add_k_proj(encoder_hidden_states)
encoder_value = attn.add_v_proj(encoder_hidden_states)
return query, key, value, encoder_query, encoder_key, encoder_value
def _get_fused_projections(attn: "LongCatImageAttention", hidden_states, encoder_hidden_states=None):
query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
encoder_query = encoder_key = encoder_value = (None,)
if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1)
return query, key, value, encoder_query, encoder_key, encoder_value
def _get_qkv_projections(attn: "LongCatImageAttention", hidden_states, encoder_hidden_states=None):
if attn.fused_projections:
return _get_fused_projections(attn, hidden_states, encoder_hidden_states)
return _get_projections(attn, hidden_states, encoder_hidden_states)
class LongCatImageAttnProcessor:
_attention_backend = None
_parallel_config = None
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
def __call__(
self,
attn: "LongCatImageAttention",
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
attn, hidden_states, encoder_hidden_states
)
query = query.unflatten(-1, (attn.heads, -1))
key = key.unflatten(-1, (attn.heads, -1))
value = value.unflatten(-1, (attn.heads, -1))
query = attn.norm_q(query)
key = attn.norm_k(key)
if attn.added_kv_proj_dim is not None:
encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
encoder_query = attn.norm_added_q(encoder_query)
encoder_key = attn.norm_added_k(encoder_key)
query = torch.cat([encoder_query, query], dim=1)
key = torch.cat([encoder_key, key], dim=1)
value = torch.cat([encoder_value, value], dim=1)
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
hidden_states = dispatch_attention_fn(
query,
key,
value,
attn_mask=attention_mask,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
[encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
)
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
else:
return hidden_states
class LongCatImageAttention(torch.nn.Module, AttentionModuleMixin):
_default_processor_cls = LongCatImageAttnProcessor
_available_processors = [
LongCatImageAttnProcessor,
]
def __init__(
self,
query_dim: int,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias: bool = False,
added_kv_proj_dim: Optional[int] = None,
added_proj_bias: Optional[bool] = True,
out_bias: bool = True,
eps: float = 1e-5,
out_dim: int = None,
context_pre_only: Optional[bool] = None,
pre_only: bool = False,
elementwise_affine: bool = True,
processor=None,
):
super().__init__()
self.head_dim = dim_head
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.query_dim = query_dim
self.use_bias = bias
self.dropout = dropout
self.out_dim = out_dim if out_dim is not None else query_dim
self.context_pre_only = context_pre_only
self.pre_only = pre_only
self.heads = out_dim // dim_head if out_dim is not None else heads
self.added_kv_proj_dim = added_kv_proj_dim
self.added_proj_bias = added_proj_bias
self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
if not self.pre_only:
self.to_out = torch.nn.ModuleList([])
self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
self.to_out.append(torch.nn.Dropout(dropout))
if added_kv_proj_dim is not None:
self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)
if processor is None:
processor = self._default_processor_cls()
self.set_processor(processor)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}
unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters]
if len(unused_kwargs) > 0:
logger.warning(
f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
)
kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
@maybe_allow_in_graph
class LongCatImageSingleTransformerBlock(nn.Module):
def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
super().__init__()
self.mlp_hidden_dim = int(dim * mlp_ratio)
self.norm = AdaLayerNormZeroSingle(dim)
self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
self.act_mlp = nn.GELU(approximate="tanh")
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
self.attn = LongCatImageAttention(
query_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
bias=True,
processor=LongCatImageAttnProcessor(),
eps=1e-6,
pre_only=True,
)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
text_seq_len = encoder_hidden_states.shape[1]
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
residual = hidden_states
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
joint_attention_kwargs = joint_attention_kwargs or {}
attn_output = self.attn(
hidden_states=norm_hidden_states,
image_rotary_emb=image_rotary_emb,
**joint_attention_kwargs,
)
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
gate = gate.unsqueeze(1)
hidden_states = gate * self.proj_out(hidden_states)
hidden_states = residual + hidden_states
if hidden_states.dtype == torch.float16:
hidden_states = hidden_states.clip(-65504, 65504)
encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:]
return encoder_hidden_states, hidden_states
@maybe_allow_in_graph
class LongCatImageTransformerBlock(nn.Module):
def __init__(
self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
):
super().__init__()
self.norm1 = AdaLayerNormZero(dim)
self.norm1_context = AdaLayerNormZero(dim)
self.attn = LongCatImageAttention(
query_dim=dim,
added_kv_proj_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
context_pre_only=False,
bias=True,
processor=LongCatImageAttnProcessor(),
eps=eps,
)
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
encoder_hidden_states, emb=temb
)
joint_attention_kwargs = joint_attention_kwargs or {}
# Attention.
attention_outputs = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
**joint_attention_kwargs,
)
if len(attention_outputs) == 2:
attn_output, context_attn_output = attention_outputs
elif len(attention_outputs) == 3:
attn_output, context_attn_output, ip_attn_output = attention_outputs
# Process attention outputs for the `hidden_states`.
attn_output = gate_msa.unsqueeze(1) * attn_output
hidden_states = hidden_states + attn_output
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
ff_output = self.ff(norm_hidden_states)
ff_output = gate_mlp.unsqueeze(1) * ff_output
hidden_states = hidden_states + ff_output
if len(attention_outputs) == 3:
hidden_states = hidden_states + ip_attn_output
# Process attention outputs for the `encoder_hidden_states`.
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
encoder_hidden_states = encoder_hidden_states + context_attn_output
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
context_ff_output = self.ff_context(norm_encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
if encoder_hidden_states.dtype == torch.float16:
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
return encoder_hidden_states, hidden_states
class LongCatImagePosEmbed(nn.Module):
def __init__(self, theta: int, axes_dim: List[int]):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: torch.Tensor) -> torch.Tensor:
n_axes = ids.shape[-1]
cos_out = []
sin_out = []
pos = ids.float()
is_mps = ids.device.type == "mps"
is_npu = ids.device.type == "npu"
freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
for i in range(n_axes):
cos, sin = get_1d_rotary_pos_embed(
self.axes_dim[i],
pos[:, i],
theta=self.theta,
repeat_interleave_real=True,
use_real=True,
freqs_dtype=freqs_dtype,
)
cos_out.append(cos)
sin_out.append(sin)
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
return freqs_cos, freqs_sin
class LongCatImageTimestepEmbeddings(nn.Module):
def __init__(self, embedding_dim):
super().__init__()
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
def forward(self, timestep, hidden_dtype):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
return timesteps_emb
class LongCatImageTransformer2DModel(
ModelMixin,
ConfigMixin,
PeftAdapterMixin,
FromOriginalModelMixin,
CacheMixin,
):
"""
The Transformer model introduced in LongCat-Image.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
patch_size: int = 1,
in_channels: int = 64,
num_layers: int = 19,
num_single_layers: int = 38,
attention_head_dim: int = 128,
num_attention_heads: int = 24,
joint_attention_dim: int = 3584,
pooled_projection_dim: int = 3584,
axes_dims_rope: List[int] = [16, 56, 56],
):
super().__init__()
self.out_channels = in_channels
self.inner_dim = num_attention_heads * attention_head_dim
self.pooled_projection_dim = pooled_projection_dim
self.pos_embed = LongCatImagePosEmbed(theta=10000, axes_dim=axes_dims_rope)
self.time_embed = LongCatImageTimestepEmbeddings(embedding_dim=self.inner_dim)
self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim)
self.transformer_blocks = nn.ModuleList(
[
LongCatImageTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
)
for i in range(num_layers)
]
)
self.single_transformer_blocks = nn.ModuleList(
[
LongCatImageSingleTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
)
for i in range(num_single_layers)
]
)
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
self.gradient_checkpointing = False
self.use_checkpoint = [True] * num_layers
self.use_single_checkpoint = [True] * num_single_layers
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
timestep: torch.LongTensor = None,
img_ids: torch.Tensor = None,
txt_ids: torch.Tensor = None,
guidance: torch.Tensor = None,
return_dict: bool = True,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
"""
The forward method.
Args:
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
Input `hidden_states`.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
A list of tensors that if specified are added to the residuals of transformer blocks.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
hidden_states = self.x_embedder(hidden_states)
timestep = timestep.to(hidden_states.dtype) * 1000
temb = self.time_embed(timestep, hidden_states.dtype)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
ids = torch.cat((txt_ids, img_ids), dim=0)
if is_torch_npu_available():
freqs_cos, freqs_sin = self.pos_embed(ids.cpu())
image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu())
else:
image_rotary_emb = self.pos_embed(ids)
for index_block, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing and self.use_checkpoint[index_block]:
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
)
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
for index_block, block in enumerate(self.single_transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing and self.use_single_checkpoint[index_block]:
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
)
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)

View File

@@ -14,7 +14,6 @@
import functools
import math
from math import prod
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
@@ -364,13 +363,7 @@ class QwenDoubleStreamAttnProcessor2_0:
@maybe_allow_in_graph
class QwenImageTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
qk_norm: str = "rms_norm",
eps: float = 1e-6,
zero_cond_t: bool = False,
self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
):
super().__init__()
@@ -410,43 +403,10 @@ class QwenImageTransformerBlock(nn.Module):
self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
self.txt_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
self.zero_cond_t = zero_cond_t
def _modulate(self, x, mod_params, index=None):
def _modulate(self, x, mod_params):
"""Apply modulation to input tensor"""
# x: b l d, shift: b d, scale: b d, gate: b d
shift, scale, gate = mod_params.chunk(3, dim=-1)
if index is not None:
# Assuming mod_params batch dim is 2*actual_batch (chunked into 2 parts)
# So shift, scale, gate have shape [2*actual_batch, d]
actual_batch = shift.size(0) // 2
shift_0, shift_1 = shift[:actual_batch], shift[actual_batch:] # each: [actual_batch, d]
scale_0, scale_1 = scale[:actual_batch], scale[actual_batch:]
gate_0, gate_1 = gate[:actual_batch], gate[actual_batch:]
# index: [b, l] where b is actual batch size
# Expand to [b, l, 1] to match feature dimension
index_expanded = index.unsqueeze(-1) # [b, l, 1]
# Expand chunks to [b, 1, d] then broadcast to [b, l, d]
shift_0_exp = shift_0.unsqueeze(1) # [b, 1, d]
shift_1_exp = shift_1.unsqueeze(1) # [b, 1, d]
scale_0_exp = scale_0.unsqueeze(1)
scale_1_exp = scale_1.unsqueeze(1)
gate_0_exp = gate_0.unsqueeze(1)
gate_1_exp = gate_1.unsqueeze(1)
# Use torch.where to select based on index
shift_result = torch.where(index_expanded == 0, shift_0_exp, shift_1_exp)
scale_result = torch.where(index_expanded == 0, scale_0_exp, scale_1_exp)
gate_result = torch.where(index_expanded == 0, gate_0_exp, gate_1_exp)
else:
shift_result = shift.unsqueeze(1)
scale_result = scale.unsqueeze(1)
gate_result = gate.unsqueeze(1)
return x * (1 + scale_result) + shift_result, gate_result
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
def forward(
self,
@@ -456,13 +416,9 @@ class QwenImageTransformerBlock(nn.Module):
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
modulate_index: Optional[List[int]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Get modulation parameters for both streams
img_mod_params = self.img_mod(temb) # [B, 6*dim]
if self.zero_cond_t:
temb = torch.chunk(temb, 2, dim=0)[0]
txt_mod_params = self.txt_mod(temb) # [B, 6*dim]
# Split modulation parameters for norm1 and norm2
@@ -471,7 +427,7 @@ class QwenImageTransformerBlock(nn.Module):
# Process image stream - norm1 + modulation
img_normed = self.img_norm1(hidden_states)
img_modulated, img_gate1 = self._modulate(img_normed, img_mod1, modulate_index)
img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)
# Process text stream - norm1 + modulation
txt_normed = self.txt_norm1(encoder_hidden_states)
@@ -501,7 +457,7 @@ class QwenImageTransformerBlock(nn.Module):
# Process image stream - norm2 + MLP
img_normed2 = self.img_norm2(hidden_states)
img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2, modulate_index)
img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
img_mlp_output = self.img_mlp(img_modulated2)
hidden_states = hidden_states + img_gate2 * img_mlp_output
@@ -577,7 +533,6 @@ class QwenImageTransformer2DModel(
joint_attention_dim: int = 3584,
guidance_embeds: bool = False, # TODO: this should probably be removed
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
zero_cond_t: bool = False,
):
super().__init__()
self.out_channels = out_channels or in_channels
@@ -598,7 +553,6 @@ class QwenImageTransformer2DModel(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
zero_cond_t=zero_cond_t,
)
for _ in range(num_layers)
]
@@ -608,7 +562,6 @@ class QwenImageTransformer2DModel(
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
self.gradient_checkpointing = False
self.zero_cond_t = zero_cond_t
def forward(
self,
@@ -665,17 +618,6 @@ class QwenImageTransformer2DModel(
hidden_states = self.img_in(hidden_states)
timestep = timestep.to(hidden_states.dtype)
if self.zero_cond_t:
timestep = torch.cat([timestep, timestep * 0], dim=0)
modulate_index = torch.tensor(
[[0] * prod(sample[0]) + [1] * sum([prod(s) for s in sample[1:]]) for sample in img_shapes],
device=timestep.device,
dtype=torch.int,
)
else:
modulate_index = None
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
encoder_hidden_states = self.txt_in(encoder_hidden_states)
@@ -699,8 +641,6 @@ class QwenImageTransformer2DModel(
encoder_hidden_states_mask,
temb,
image_rotary_emb,
attention_kwargs,
modulate_index,
)
else:
@@ -711,7 +651,6 @@ class QwenImageTransformer2DModel(
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=attention_kwargs,
modulate_index=modulate_index,
)
# controlnet residual
@@ -720,8 +659,6 @@ class QwenImageTransformer2DModel(
interval_control = int(np.ceil(interval_control))
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
if self.zero_cond_t:
temb = temb.chunk(2, dim=0)[0]
# Use only the image part (hidden_states) from the dual-stream blocks
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)

View File

@@ -291,7 +291,6 @@ else:
_import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"]
_import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"]
_import_structure["lucy"] = ["LucyEditPipeline"]
_import_structure["longcat_image"] = ["LongCatImagePipeline", "LongCatImageEditPipeline"]
_import_structure["marigold"].extend(
[
"MarigoldDepthPipeline",
@@ -719,7 +718,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LEditsPPPipelineStableDiffusion,
LEditsPPPipelineStableDiffusionXL,
)
from .longcat_image import LongCatImageEditPipeline, LongCatImagePipeline
from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXLatentUpsamplePipeline, LTXPipeline
from .lucy import LucyEditPipeline
from .lumina import LuminaPipeline, LuminaText2ImgPipeline

View File

@@ -1,51 +0,0 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa: F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_longcat_image"] = ["LongCatImagePipeline"]
_import_structure["pipeline_longcat_image_edit"] = ["LongCatImageEditPipeline"]
_import_structure["pipeline_output"] = ["LongCatImagePipelineOutput"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_longcat_image import LongCatImagePipeline
from .pipeline_longcat_image_edit import LongCatImageEditPipeline
from .pipeline_output import LongCatImagePipelineOutput
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)

View File

@@ -1,666 +0,0 @@
# Copyright 2025 MeiTuan LongCat-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import re
from typing import Any, Dict, List, Optional, Union
import numpy as np
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor
from ...image_processor import VaeImageProcessor
from ...loaders import FromSingleFileMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import LongCatImageTransformer2DModel
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from .pipeline_output import LongCatImagePipelineOutput
from .system_messages import SYSTEM_PROMPT_EN, SYSTEM_PROMPT_ZH
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import LongCatImagePipeline
>>> pipe = LongCatImagePipeline.from_pretrained("meituan-longcat/LongCat-Image", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> prompt = "一个年轻的亚裔女性,身穿黄色针织衫,搭配白色项链。她的双手放在膝盖上,表情恬静。背景是一堵粗糙的砖墙,午后的阳光温暖地洒在她身上,营造出一种宁静而温馨的氛围。镜头采用中距离视角,突出她的神态和服饰的细节。光线柔和地打在她的脸上,强调她的五官和饰品的质感,增加画面的层次感与亲和力。整个画面构图简洁,砖墙的纹理与阳光的光影效果相得益彰,突显出人物的优雅与从容。"
>>> image = pipe(
... prompt,
... height=768,
... width=1344,
... num_inference_steps=50,
... guidance_scale=4.5,
... generator=torch.Generator("cpu").manual_seed(43),
... enable_cfg_renorm=True,
... ).images[0]
>>> image.save("longcat_image.png")
```
"""
def get_prompt_language(prompt):
pattern = re.compile(r"[\u4e00-\u9fff]")
if bool(pattern.search(prompt)):
return "zh"
return "en"
def split_quotation(prompt, quote_pairs=None):
"""
Implement a regex-based string splitting algorithm that identifies delimiters defined by single or double quote
pairs. Examples::
>>> prompt_en = "Please write 'Hello' on the blackboard for me." >>> print(split_quotation(prompt_en)) >>> #
output: [('Please write ', False), ("'Hello'", True), (' on the blackboard for me.', False)]
"""
word_internal_quote_pattern = re.compile(r"[a-zA-Z]+'[a-zA-Z]+")
matches_word_internal_quote_pattern = word_internal_quote_pattern.findall(prompt)
mapping_word_internal_quote = []
for i, word_src in enumerate(set(matches_word_internal_quote_pattern)):
word_tgt = "longcat_$##$_longcat" * (i + 1)
prompt = prompt.replace(word_src, word_tgt)
mapping_word_internal_quote.append([word_src, word_tgt])
if quote_pairs is None:
quote_pairs = [("'", "'"), ('"', '"'), ("", ""), ("", "")]
pattern = "|".join([re.escape(q1) + r"[^" + re.escape(q1 + q2) + r"]*?" + re.escape(q2) for q1, q2 in quote_pairs])
parts = re.split(f"({pattern})", prompt)
result = []
for part in parts:
for word_src, word_tgt in mapping_word_internal_quote:
part = part.replace(word_tgt, word_src)
if re.match(pattern, part):
if len(part):
result.append((part, True))
else:
if len(part):
result.append((part, False))
return result
def prepare_pos_ids(modality_id=0, type="text", start=(0, 0), num_token=None, height=None, width=None):
if type == "text":
assert num_token
if height or width:
print('Warning: The `height` and `width` arguments are ignored for the "text" type.')
pos_ids = torch.zeros(num_token, 3)
pos_ids[..., 0] = modality_id
pos_ids[..., 1] = torch.arange(num_token) + start[0]
pos_ids[..., 2] = torch.arange(num_token) + start[1]
elif type == "image":
assert height and width
if num_token:
print('Warning: The `num_token` argument is ignored for the "image" type.')
pos_ids = torch.zeros(height, width, 3)
pos_ids[..., 0] = modality_id
pos_ids[..., 1] = pos_ids[..., 1] + torch.arange(height)[:, None] + start[0]
pos_ids[..., 2] = pos_ids[..., 2] + torch.arange(width)[None, :] + start[1]
pos_ids = pos_ids.reshape(height * width, 3)
else:
raise KeyError(f'Unknown type {type}; only "text" or "image" are supported.')
return pos_ids
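# NOTE (editor's sketch, not part of the original file): a hedged illustration of the layout produced by
# `prepare_pos_ids` above. Text tokens get a 1D range in both coordinate columns, image patches enumerate a
# (height, width) grid offset by `start`, and column 0 always carries the modality id.
def _sketch_prepare_pos_ids():
    text_ids = prepare_pos_ids(modality_id=0, type="text", start=(0, 0), num_token=4)
    image_ids = prepare_pos_ids(modality_id=1, type="image", start=(4, 4), height=2, width=3)
    # text_ids.shape == (4, 3); rows read [0, 0, 0], [0, 1, 1], [0, 2, 2], [0, 3, 3]
    # image_ids.shape == (6, 3); rows enumerate the 2x3 grid: [1, 4, 4], [1, 4, 5], [1, 4, 6], [1, 5, 4], ...
    return text_ids, image_ids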
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
max_seq_len: int = 4096,
base_shift: float = 0.5,
max_shift: float = 1.15,
):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
mu = image_seq_len * m + b
return mu
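# NOTE (editor's sketch, not part of the original file): `calculate_shift` interpolates the flow-matching
# shift `mu` linearly in the image sequence length, so the default config maps 256 tokens to 0.5 and 4096
# tokens (a 1024x1024 image after packing) to 1.15. The numbers below are a hedged sanity check only.
def _sketch_calculate_shift():
    assert abs(calculate_shift(256) - 0.5) < 1e-6
    assert abs(calculate_shift(4096) - 1.15) < 1e-6
    # intermediate lengths land on the connecting line, e.g. 2048 tokens -> mu ~= 0.803
    assert abs(calculate_shift(2048) - 0.8033) < 1e-3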
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
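# NOTE (editor's sketch, not part of the original file): a hedged usage example for `retrieve_timesteps`,
# assuming a FlowMatchEulerDiscreteScheduler constructed with dynamic shifting enabled. Passing custom
# `sigmas` together with `mu` (as `__call__` below does) routes through `scheduler.set_timesteps(sigmas=...,
# mu=...)` and returns one timestep per denoising step.
def _sketch_retrieve_timesteps(num_steps: int = 4):
    scheduler = FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=True)
    sigmas = np.linspace(1.0, 1.0 / num_steps, num_steps)
    timesteps, num_inference_steps = retrieve_timesteps(
        scheduler, num_inference_steps=num_steps, sigmas=sigmas, mu=1.15
    )
    # timesteps is a 1D tensor of length `num_steps`; num_inference_steps == num_steps
    return timesteps, num_inference_steps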
class LongCatImagePipeline(DiffusionPipeline, FromSingleFileMixin):
r"""
The pipeline for text-to-image generation.
"""
model_cpu_offload_seq = "text_encoder->transformer->vae"
_optional_components = []
_callback_tensor_inputs = ["latents", "prompt_embeds"]
def __init__(
self,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKL,
text_encoder: Qwen2_5_VLForConditionalGeneration,
tokenizer: Qwen2Tokenizer,
text_processor: Qwen2VLProcessor,
transformer: LongCatImageTransformer2DModel,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
scheduler=scheduler,
text_processor=text_processor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
self.prompt_template_encode_prefix = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n"
self.prompt_template_encode_suffix = "<|im_end|>\n<|im_start|>assistant\n"
self.default_sample_size = 128
self.tokenizer_max_length = 512
def rewire_prompt(self, prompt, device):
prompt = [prompt] if isinstance(prompt, str) else prompt
all_text = []
for each_prompt in prompt:
language = get_prompt_language(each_prompt)
if language == "zh":
question = SYSTEM_PROMPT_ZH + f"\n用户输入为:{each_prompt}\n改写后的prompt为"
else:
question = SYSTEM_PROMPT_EN + f"\nUser Input: {each_prompt}\nRewritten prompt:"
message = [
{
"role": "user",
"content": [
{"type": "text", "text": question},
],
}
]
# Preparation for inference
text = self.text_processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
all_text.append(text)
inputs = self.text_processor(text=all_text, padding=True, return_tensors="pt").to(device)
self.text_encoder.to(device)
generated_ids = self.text_encoder.generate(**inputs, max_new_tokens=self.tokenizer_max_length)
generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
output_text = self.text_processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
rewrite_prompt = output_text
return rewrite_prompt
def _encode_prompt(self, prompt: List[str]):
batch_all_tokens = []
for each_prompt in prompt:
all_tokens = []
for clean_prompt_sub, matched in split_quotation(each_prompt):
if matched:
for sub_word in clean_prompt_sub:
tokens = self.tokenizer(sub_word, add_special_tokens=False)["input_ids"]
all_tokens.extend(tokens)
else:
tokens = self.tokenizer(clean_prompt_sub, add_special_tokens=False)["input_ids"]
all_tokens.extend(tokens)
if len(all_tokens) > self.tokenizer_max_length:
logger.warning(
"Your input was truncated because `max_sequence_length` is set to "
f" {self.tokenizer_max_length} input token nums : {len(all_tokens)}"
)
all_tokens = all_tokens[: self.tokenizer_max_length]
batch_all_tokens.append(all_tokens)
text_tokens_and_mask = self.tokenizer.pad(
{"input_ids": batch_all_tokens},
max_length=self.tokenizer_max_length,
padding="max_length",
return_attention_mask=True,
return_tensors="pt",
)
prefix_tokens = self.tokenizer(self.prompt_template_encode_prefix, add_special_tokens=False)["input_ids"]
suffix_tokens = self.tokenizer(self.prompt_template_encode_suffix, add_special_tokens=False)["input_ids"]
prefix_len = len(prefix_tokens)
suffix_len = len(suffix_tokens)
prefix_tokens_mask = torch.tensor([1] * len(prefix_tokens), dtype=text_tokens_and_mask.attention_mask[0].dtype)
suffix_tokens_mask = torch.tensor([1] * len(suffix_tokens), dtype=text_tokens_and_mask.attention_mask[0].dtype)
prefix_tokens = torch.tensor(prefix_tokens, dtype=text_tokens_and_mask.input_ids.dtype)
suffix_tokens = torch.tensor(suffix_tokens, dtype=text_tokens_and_mask.input_ids.dtype)
batch_size = text_tokens_and_mask.input_ids.size(0)
prefix_tokens_batch = prefix_tokens.unsqueeze(0).expand(batch_size, -1)
suffix_tokens_batch = suffix_tokens.unsqueeze(0).expand(batch_size, -1)
prefix_mask_batch = prefix_tokens_mask.unsqueeze(0).expand(batch_size, -1)
suffix_mask_batch = suffix_tokens_mask.unsqueeze(0).expand(batch_size, -1)
input_ids = torch.cat((prefix_tokens_batch, text_tokens_and_mask.input_ids, suffix_tokens_batch), dim=-1)
attention_mask = torch.cat((prefix_mask_batch, text_tokens_and_mask.attention_mask, suffix_mask_batch), dim=-1)
input_ids = input_ids.to(self.device)
attention_mask = attention_mask.to(self.device)
text_output = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
# use the last hidden state as the prompt embeddings and detach it from the graph;
# the prefix/suffix template tokens are dropped below
prompt_embeds = text_output.hidden_states[-1].detach()
prompt_embeds = prompt_embeds[:, prefix_len:-suffix_len, :]
return prompt_embeds
def encode_prompt(
self,
prompt: Union[str, List[str]] = None,
num_images_per_prompt: Optional[int] = 1,
prompt_embeds: Optional[torch.Tensor] = None,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
# If prompt_embeds is provided and prompt is None, skip encoding
if prompt_embeds is None:
prompt_embeds = self._encode_prompt(prompt)
_, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
text_ids = prepare_pos_ids(modality_id=0, type="text", start=(0, 0), num_token=prompt_embeds.shape[1]).to(
self.device
)
return prompt_embeds.to(self.device), text_ids
@staticmethod
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
return latents
@staticmethod
def _unpack_latents(latents, height, width, vae_scale_factor):
batch_size, num_patches, channels = latents.shape
# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
height = 2 * (int(height) // (vae_scale_factor * 2))
width = 2 * (int(width) // (vae_scale_factor * 2))
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)
latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
return latents
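# NOTE (editor's sketch, not part of the original file): packing folds each 2x2 patch of the latent grid into
# the channel dimension. A hedged shape example with batch=1, 16 latent channels and a 64x64 latent grid
# (i.e. a 512x512 image with vae_scale_factor=8):
#     packed = LongCatImagePipeline._pack_latents(torch.randn(1, 16, 64, 64), 1, 16, 64, 64)
#     packed.shape  -> torch.Size([1, 1024, 64])    # (64//2)*(64//2) tokens, 16*4 channels each
#     LongCatImagePipeline._unpack_latents(packed, 512, 512, vae_scale_factor=8).shape
#                   -> torch.Size([1, 16, 64, 64])  # round-trips back to the unpacked layout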
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
def prepare_latents(
self,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
):
# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
height = 2 * (int(height) // (self.vae_scale_factor * 2))
width = 2 * (int(width) // (self.vae_scale_factor * 2))
shape = (batch_size, num_channels_latents, height, width)
latent_image_ids = prepare_pos_ids(
modality_id=1,
type="image",
start=(self.tokenizer_max_length, self.tokenizer_max_length),
height=height // 2,
width=width // 2,
).to(device)
if latents is not None:
return latents.to(device=device, dtype=dtype), latent_image_ids
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device)
latents = latents.to(dtype=dtype)
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
return latents, latent_image_ids
@property
def guidance_scale(self):
return self._guidance_scale
@property
def joint_attention_kwargs(self):
return self._joint_attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@property
def current_timestep(self):
return self._current_timestep
@property
def interrupt(self):
return self._interrupt
def check_inputs(
self, prompt, height, width, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
):
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
logger.warning(
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
@replace_example_docstring(EXAMPLE_DOC_STRING)
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]] = None,
negative_prompt: Union[str, List[str]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 4.5,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
enable_cfg_renorm: Optional[bool] = True,
cfg_renorm_min: Optional[float] = 0.0,
enable_prompt_rewrite: Optional[bool] = True,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
enable_cfg_renorm (`bool`, *optional*, defaults to `True`):
Whether to renormalize the guided noise prediction. Enabling it generally improves image quality but may
reduce the stability of some outputs.
cfg_renorm_min (`float`, *optional*, defaults to `0.0`):
The lower bound of the renormalization scale, in the range [0, 1]. `cfg_renorm_min=1.0` disables the
renormalization, while `cfg_renorm_min=0.0` allows the largest rescaling range.
enable_prompt_rewrite (`bool`, *optional*, defaults to `True`):
Whether to rewrite the prompt with the text encoder before encoding it.
Examples:
Returns:
[`~pipelines.LongCatImagePipelineOutput`] or `tuple`: [`~pipelines.LongCatImagePipelineOutput`] if
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
generated images.
"""
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
height,
width,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
)
self._guidance_scale = guidance_scale
self._joint_attention_kwargs = joint_attention_kwargs
self._current_timestep = None
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
if enable_prompt_rewrite:
prompt = self.rewire_prompt(prompt, device)
logger.info(f"Rewrite prompt {prompt}!")
negative_prompt = "" if negative_prompt is None else negative_prompt
(prompt_embeds, text_ids) = self.encode_prompt(
prompt=prompt, prompt_embeds=prompt_embeds, num_images_per_prompt=num_images_per_prompt
)
if self.do_classifier_free_guidance:
(negative_prompt_embeds, negative_text_ids) = self.encode_prompt(
prompt=negative_prompt,
prompt_embeds=negative_prompt_embeds,
num_images_per_prompt=num_images_per_prompt,
)
# 4. Prepare latent variables
num_channels_latents = 16
latents, latent_image_ids = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 5. Prepare timesteps
sigmas = np.linspace(1.0, 1.0 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
image_seq_len = latents.shape[1]
mu = calculate_shift(
image_seq_len,
self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.15),
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
sigmas=sigmas,
mu=mu,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# handle guidance
guidance = None
if self.joint_attention_kwargs is None:
self._joint_attention_kwargs = {}
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t
timestep = t.expand(latents.shape[0]).to(latents.dtype)
with self.transformer.cache_context("cond"):
noise_pred_text = self.transformer(
hidden_states=latents,
timestep=timestep / 1000,
guidance=guidance,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids,
img_ids=latent_image_ids,
return_dict=False,
)[0]
if self.do_classifier_free_guidance:
with self.transformer.cache_context("uncond"):
noise_pred_uncond = self.transformer(
hidden_states=latents,
timestep=timestep / 1000,
encoder_hidden_states=negative_prompt_embeds,
txt_ids=negative_text_ids,
img_ids=latent_image_ids,
return_dict=False,
)[0]
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
if enable_cfg_renorm:
cond_norm = torch.norm(noise_pred_text, dim=-1, keepdim=True)
noise_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
scale = (cond_norm / (noise_norm + 1e-8)).clamp(min=cfg_renorm_min, max=1.0)
noise_pred = noise_pred * scale
else:
noise_pred = noise_pred_text
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
self._current_timestep = None
if output_type == "latent":
image = latents
else:
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
if latents.dtype != self.vae.dtype:
latents = latents.to(dtype=self.vae.dtype)
image = self.vae.decode(latents, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return LongCatImagePipelineOutput(images=image)
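# NOTE (editor's sketch, not part of the original file): a minimal restatement of the CFG renormalization
# applied inside `__call__` above. The guided prediction is rescaled so that its per-token norm does not
# exceed the norm of the conditional prediction, with `cfg_renorm_min` bounding how far it may shrink.
def _sketch_cfg_renorm(noise_pred_text, noise_pred_uncond, guidance_scale=4.5, cfg_renorm_min=0.0):
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
    cond_norm = torch.norm(noise_pred_text, dim=-1, keepdim=True)
    noise_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
    scale = (cond_norm / (noise_norm + 1e-8)).clamp(min=cfg_renorm_min, max=1.0)
    return noise_pred * scale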

View File

@@ -1,727 +0,0 @@
# Copyright 2025 MeiTuan LongCat-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import math
import re
from typing import Any, Dict, List, Optional, Union
import numpy as np
import PIL
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor
from ...image_processor import VaeImageProcessor
from ...loaders import FromSingleFileMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import LongCatImageTransformer2DModel
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from .pipeline_output import LongCatImagePipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> from PIL import Image
>>> import torch
>>> from diffusers import LongCatImageEditPipeline
>>> pipe = LongCatImageEditPipeline.from_pretrained(
... "meituan-longcat/LongCat-Image-Edit", torch_dtype=torch.bfloat16
... )
>>> pipe.to("cuda")
>>> prompt = "change the cat to dog."
>>> input_image = Image.open("test.jpg").convert("RGB")
>>> image = pipe(
... input_image,
... prompt,
... num_inference_steps=50,
... guidance_scale=4.5,
... generator=torch.Generator("cpu").manual_seed(43),
... ).images[0]
>>> image.save("longcat_image_edit.png")
```
"""
# Copied from diffusers.pipelines.longcat_image.pipeline_longcat_image.split_quotation
def split_quotation(prompt, quote_pairs=None):
"""
Implement a regex-based string splitting algorithm that identifies delimiters defined by single or double quote
pairs. Examples::
>>> prompt_en = "Please write 'Hello' on the blackboard for me." >>> print(split_quotation(prompt_en)) >>> #
output: [('Please write ', False), ("'Hello'", True), (' on the blackboard for me.', False)]
"""
word_internal_quote_pattern = re.compile(r"[a-zA-Z]+'[a-zA-Z]+")
matches_word_internal_quote_pattern = word_internal_quote_pattern.findall(prompt)
mapping_word_internal_quote = []
for i, word_src in enumerate(set(matches_word_internal_quote_pattern)):
word_tgt = "longcat_$##$_longcat" * (i + 1)
prompt = prompt.replace(word_src, word_tgt)
mapping_word_internal_quote.append([word_src, word_tgt])
if quote_pairs is None:
quote_pairs = [("'", "'"), ('"', '"'), ("", ""), ("", "")]
pattern = "|".join([re.escape(q1) + r"[^" + re.escape(q1 + q2) + r"]*?" + re.escape(q2) for q1, q2 in quote_pairs])
parts = re.split(f"({pattern})", prompt)
result = []
for part in parts:
for word_src, word_tgt in mapping_word_internal_quote:
part = part.replace(word_tgt, word_src)
if re.match(pattern, part):
if len(part):
result.append((part, True))
else:
if len(part):
result.append((part, False))
return result
# Copied from diffusers.pipelines.longcat_image.pipeline_longcat_image.prepare_pos_ids
def prepare_pos_ids(modality_id=0, type="text", start=(0, 0), num_token=None, height=None, width=None):
if type == "text":
assert num_token
if height or width:
print('Warning: The `height` and `width` arguments are ignored for the "text" type.')
pos_ids = torch.zeros(num_token, 3)
pos_ids[..., 0] = modality_id
pos_ids[..., 1] = torch.arange(num_token) + start[0]
pos_ids[..., 2] = torch.arange(num_token) + start[1]
elif type == "image":
assert height and width
if num_token:
print('Warning: The `num_token` argument is ignored for the "image" type.')
pos_ids = torch.zeros(height, width, 3)
pos_ids[..., 0] = modality_id
pos_ids[..., 1] = pos_ids[..., 1] + torch.arange(height)[:, None] + start[0]
pos_ids[..., 2] = pos_ids[..., 2] + torch.arange(width)[None, :] + start[1]
pos_ids = pos_ids.reshape(height * width, 3)
else:
raise KeyError(f'Unknown type {type}; only "text" or "image" are supported.')
return pos_ids
# Copied from diffusers.pipelines.longcat_image.pipeline_longcat_image.calculate_shift
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
max_seq_len: int = 4096,
base_shift: float = 0.5,
max_shift: float = 1.15,
):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
mu = image_seq_len * m + b
return mu
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
def calculate_dimensions(target_area, ratio):
width = math.sqrt(target_area * ratio)
height = width / ratio
width = width if width % 16 == 0 else (width // 16 + 1) * 16
height = height if height % 16 == 0 else (height // 16 + 1) * 16
width = int(width)
height = int(height)
return width, height
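# NOTE (editor's sketch, not part of the original file): `calculate_dimensions` keeps roughly the input
# aspect ratio at the given pixel budget and rounds each side up to a multiple of 16. A worked example:
def _sketch_calculate_dimensions():
    # a 4:3 reference image mapped onto a 1024*1024 target area
    width, height = calculate_dimensions(1024 * 1024, 4 / 3)
    # sqrt(1024*1024 * 4/3) ~= 1182.4 -> 1184; 1182.4 / (4/3) ~= 886.8 -> 896
    assert (width, height) == (1184, 896)
    return width, height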
class LongCatImageEditPipeline(DiffusionPipeline, FromSingleFileMixin):
r"""
The LongCat-Image-Edit pipeline for image editing.
"""
model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae"
_optional_components = []
_callback_tensor_inputs = ["latents", "prompt_embeds"]
def __init__(
self,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKL,
text_encoder: Qwen2_5_VLForConditionalGeneration,
tokenizer: Qwen2Tokenizer,
text_processor: Qwen2VLProcessor,
transformer: LongCatImageTransformer2DModel,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
scheduler=scheduler,
text_processor=text_processor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
self.image_processor_vl = text_processor.image_processor
self.image_token = "<|image_pad|>"
self.prompt_template_encode_prefix = "<|im_start|>system\nAs an image editing expert, first analyze the content and attributes of the input image(s). Then, based on the user's editing instructions, clearly and precisely determine how to modify the given image(s), ensuring that only the specified parts are altered and all other aspects remain consistent with the original(s).<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
self.prompt_template_encode_suffix = "<|im_end|>\n<|im_start|>assistant\n"
self.default_sample_size = 128
self.tokenizer_max_length = 512
def _encode_prompt(self, prompt, image):
raw_vl_input = self.image_processor_vl(images=image, return_tensors="pt")
pixel_values = raw_vl_input["pixel_values"]
image_grid_thw = raw_vl_input["image_grid_thw"]
all_tokens = []
for clean_prompt_sub, matched in split_quotation(prompt[0]):
if matched:
for sub_word in clean_prompt_sub:
tokens = self.tokenizer(sub_word, add_special_tokens=False)["input_ids"]
all_tokens.extend(tokens)
else:
tokens = self.tokenizer(clean_prompt_sub, add_special_tokens=False)["input_ids"]
all_tokens.extend(tokens)
if len(all_tokens) > self.tokenizer_max_length:
logger.warning(
"Your input was truncated because `max_sequence_length` is set to "
f" {self.tokenizer_max_length} input token nums : {len(len(all_tokens))}"
)
all_tokens = all_tokens[: self.tokenizer_max_length]
text_tokens_and_mask = self.tokenizer.pad(
{"input_ids": [all_tokens]},
max_length=self.tokenizer_max_length,
padding="max_length",
return_attention_mask=True,
return_tensors="pt",
)
text = self.prompt_template_encode_prefix
merge_length = self.image_processor_vl.merge_size**2
while self.image_token in text:
num_image_tokens = image_grid_thw.prod() // merge_length
text = text.replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)
text = text.replace("<|placeholder|>", self.image_token)
prefix_tokens = self.tokenizer(text, add_special_tokens=False)["input_ids"]
suffix_tokens = self.tokenizer(self.prompt_template_encode_suffix, add_special_tokens=False)["input_ids"]
vision_start_token_id = self.tokenizer.convert_tokens_to_ids("<|vision_start|>")
prefix_len = prefix_tokens.index(vision_start_token_id)
suffix_len = len(suffix_tokens)
prefix_tokens_mask = torch.tensor([1] * len(prefix_tokens), dtype=text_tokens_and_mask.attention_mask[0].dtype)
suffix_tokens_mask = torch.tensor([1] * len(suffix_tokens), dtype=text_tokens_and_mask.attention_mask[0].dtype)
prefix_tokens = torch.tensor(prefix_tokens, dtype=text_tokens_and_mask.input_ids.dtype)
suffix_tokens = torch.tensor(suffix_tokens, dtype=text_tokens_and_mask.input_ids.dtype)
input_ids = torch.cat((prefix_tokens, text_tokens_and_mask.input_ids[0], suffix_tokens), dim=-1)
attention_mask = torch.cat(
(prefix_tokens_mask, text_tokens_and_mask.attention_mask[0], suffix_tokens_mask), dim=-1
)
input_ids = input_ids.unsqueeze(0).to(self.device)
attention_mask = attention_mask.unsqueeze(0).to(self.device)
pixel_values = pixel_values.to(self.device)
image_grid_thw = image_grid_thw.to(self.device)
text_output = self.text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
pixel_values=pixel_values,
image_grid_thw=image_grid_thw,
output_hidden_states=True,
)
# use the last hidden state as the prompt embeddings and detach it from the graph;
# everything before the vision-start token and the suffix template tokens are dropped below
prompt_embeds = text_output.hidden_states[-1].detach()
prompt_embeds = prompt_embeds[:, prefix_len:-suffix_len, :]
return prompt_embeds
def encode_prompt(
self,
prompt: List[str] = None,
image: Optional[torch.Tensor] = None,
num_images_per_prompt: Optional[int] = 1,
prompt_embeds: Optional[torch.Tensor] = None,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
# If prompt_embeds is provided and prompt is None, skip encoding
if prompt_embeds is None:
prompt_embeds = self._encode_prompt(prompt, image)
_, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
text_ids = prepare_pos_ids(modality_id=0, type="text", start=(0, 0), num_token=prompt_embeds.shape[1]).to(
self.device
)
return prompt_embeds, text_ids
@staticmethod
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
return latents
@staticmethod
def _unpack_latents(latents, height, width, vae_scale_factor):
batch_size, num_patches, channels = latents.shape
# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
height = 2 * (int(height) // (vae_scale_factor * 2))
width = 2 * (int(width) // (vae_scale_factor * 2))
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)
latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
return latents
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
if isinstance(generator, list):
image_latents = [
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax")
for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax")
image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
return image_latents
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
def prepare_latents(
self,
image,
batch_size,
num_channels_latents,
height,
width,
dtype,
prompt_embeds_length,
device,
generator,
latents=None,
):
# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
height = 2 * (int(height) // (self.vae_scale_factor * 2))
width = 2 * (int(width) // (self.vae_scale_factor * 2))
image_latents, image_latents_ids = None, None
if image is not None:
image = image.to(device=self.device, dtype=dtype)
if image.shape[1] != self.vae.config.latent_channels:
image_latents = self._encode_vae_image(image=image, generator=generator)
else:
image_latents = image
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
additional_image_per_prompt = batch_size // image_latents.shape[0]
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
raise ValueError(
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
)
else:
image_latents = torch.cat([image_latents], dim=0)
image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width)
image_latents_ids = prepare_pos_ids(
modality_id=2,
type="image",
start=(prompt_embeds_length, prompt_embeds_length),
height=height // 2,
width=width // 2,
).to(device, dtype=torch.float64)
shape = (batch_size, num_channels_latents, height, width)
latents_ids = prepare_pos_ids(
modality_id=1,
type="image",
start=(prompt_embeds_length, prompt_embeds_length),
height=height // 2,
width=width // 2,
).to(device)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
else:
latents = latents.to(device=device, dtype=dtype)
return latents, image_latents, latents_ids, image_latents_ids
@property
def guidance_scale(self):
return self._guidance_scale
@property
def joint_attention_kwargs(self):
return self._joint_attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@property
def current_timestep(self):
return self._current_timestep
@property
def interrupt(self):
return self._interrupt
def check_inputs(
self, prompt, height, width, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
):
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
logger.warning(
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None:
if isinstance(prompt, str):
pass
elif isinstance(prompt, list) and len(prompt) == 1:
pass
else:
raise ValueError(
f"`prompt` must be a `str` or a `list` of length 1, but is {prompt} (type: {type(prompt)})"
)
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
@replace_example_docstring(EXAMPLE_DOC_STRING)
@torch.no_grad()
def __call__(
self,
image: Optional[PIL.Image.Image] = None,
prompt: Union[str, List[str]] = None,
negative_prompt: Union[str, List[str]] = None,
num_inference_steps: int = 50,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 4.5,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
):
r"""
Function invoked when calling the pipeline for generation.
Examples:
Returns:
[`~pipelines.LongCatImagePipelineOutput`] or `tuple`: [`~pipelines.LongCatImagePipelineOutput`] if
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
generated images.
"""
image_size = image[0].size if isinstance(image, list) else image.size
calculated_width, calculated_height = calculate_dimensions(1024 * 1024, image_size[0] * 1.0 / image_size[1])
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
calculated_height,
calculated_width,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
)
self._guidance_scale = guidance_scale
self._joint_attention_kwargs = joint_attention_kwargs
self._current_timestep = None
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# 3. Preprocess image
if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
image = self.image_processor.resize(image, calculated_height, calculated_width)
prompt_image = self.image_processor.resize(image, calculated_height // 2, calculated_width // 2)
image = self.image_processor.preprocess(image, calculated_height, calculated_width)
negative_prompt = "" if negative_prompt is None else negative_prompt
(prompt_embeds, text_ids) = self.encode_prompt(
prompt=prompt, image=prompt_image, prompt_embeds=prompt_embeds, num_images_per_prompt=num_images_per_prompt
)
if self.do_classifier_free_guidance:
(negative_prompt_embeds, negative_text_ids) = self.encode_prompt(
prompt=negative_prompt,
image=prompt_image,
prompt_embeds=negative_prompt_embeds,
num_images_per_prompt=num_images_per_prompt,
)
# 4. Prepare latent variables
num_channels_latents = 16
latents, image_latents, latents_ids, image_latents_ids = self.prepare_latents(
image,
batch_size * num_images_per_prompt,
num_channels_latents,
calculated_height,
calculated_width,
prompt_embeds.dtype,
prompt_embeds.shape[1],
device,
generator,
latents,
)
# 5. Prepare timesteps
sigmas = np.linspace(1.0, 1.0 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
image_seq_len = latents.shape[1]
mu = calculate_shift(
image_seq_len,
self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.15),
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
sigmas=sigmas,
mu=mu,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# handle guidance
guidance = None
if self.joint_attention_kwargs is None:
self._joint_attention_kwargs = {}
if image is not None:
latent_image_ids = torch.cat([latents_ids, image_latents_ids], dim=0)
else:
latent_image_ids = latents_ids
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t
latent_model_input = latents
if image_latents is not None:
latent_model_input = torch.cat([latents, image_latents], dim=1)
timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
with self.transformer.cache_context("cond"):
noise_pred_text = self.transformer(
hidden_states=latent_model_input,
timestep=timestep / 1000,
guidance=guidance,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids,
img_ids=latent_image_ids,
return_dict=False,
)[0]
noise_pred_text = noise_pred_text[:, :image_seq_len]
if self.do_classifier_free_guidance:
with self.transformer.cache_context("uncond"):
noise_pred_uncond = self.transformer(
hidden_states=latent_model_input,
timestep=timestep / 1000,
encoder_hidden_states=negative_prompt_embeds,
txt_ids=negative_text_ids,
img_ids=latent_image_ids,
return_dict=False,
)[0]
noise_pred_uncond = noise_pred_uncond[:, :image_seq_len]
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
else:
noise_pred = noise_pred_text
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
self._current_timestep = None
if output_type == "latent":
image = latents
else:
latents = self._unpack_latents(latents, calculated_height, calculated_width, self.vae_scale_factor)
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
if latents.dtype != self.vae.dtype:
latents = latents.to(dtype=self.vae.dtype)
image = self.vae.decode(latents, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return LongCatImagePipelineOutput(images=image)
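# NOTE (editor's sketch, not part of the original file): a hedged illustration of how the edit pipeline feeds
# the transformer. The packed reference-image latents are appended to the noisy latents along the sequence
# dimension, and the prediction is sliced back to the first `image_seq_len` tokens before the scheduler step
# (mirroring `noise_pred_text[:, :image_seq_len]` in `__call__` above).
def _sketch_edit_sequence_layout(latents: torch.Tensor, image_latents: torch.Tensor):
    # latents:       packed noisy latents,           shape [B, image_seq_len, C]
    # image_latents: packed reference-image latents, shape [B, ref_seq_len, C]
    image_seq_len = latents.shape[1]
    latent_model_input = torch.cat([latents, image_latents], dim=1)  # [B, image_seq_len + ref_seq_len, C]
    # the transformer output keeps this sequence length; only the noisy-latent tokens are denoised
    prediction_slice = slice(0, image_seq_len)
    return latent_model_input, prediction_slice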

View File

@@ -1,21 +0,0 @@
from dataclasses import dataclass
from typing import List, Union
import numpy as np
import PIL.Image
from diffusers.utils import BaseOutput
@dataclass
class LongCatImagePipelineOutput(BaseOutput):
"""
Output class for LongCat-Image pipelines.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or a numpy array of shape `(batch_size, height, width,
num_channels)`, representing the denoised images produced by the diffusion pipeline.
"""
images: Union[List[PIL.Image.Image], np.ndarray]

View File

@@ -1,142 +0,0 @@
SYSTEM_PROMPT_EN = """
You are a prompt engineering expert for text-to-image models. Since text-to-image models have limited capabilities in
understanding user prompts, you need to identify the core theme and intent of the user's input and improve the model's
understanding accuracy and generation quality through optimization and rewriting. The rewrite must strictly retain all
information from the user's original prompt without deleting or distorting any details. Specific requirements are as
follows:
1. The rewrite must not affect any information expressed in the user's original prompt; the rewritten prompt should use
coherent natural language, avoid low-information redundant descriptions, and keep the rewritten prompt length as
concise as possible.
2. Ensure consistency between input and output languages: Chinese input yields Chinese output, and English input yields
English output. The rewritten token count should not exceed 512.
3. The rewritten description should further refine subject characteristics and aesthetic techniques appearing in the
original prompt, such as lighting and textures.
4. If the original prompt does not specify an image style, ensure the rewritten prompt uses a **realistic photography
style**. If the user specifies a style, retain the user's style.
5. When the original prompt requires reasoning to clarify user intent, use logical reasoning based on world knowledge
to convert vague abstract descriptions into specific tangible objects (e.g., convert "the tallest animal" to "a
giraffe").
6. When the original prompt requires text generation, please use double quotes to enclose the text part (e.g., `"50%
OFF"`).
7. When the original prompt requires generating text-heavy scenes like webpages, logos, UIs, or posters, and no
specific text content is specified, you need to infer appropriate text content and enclose it in double quotes. For
example, if the user inputs: "A tourism flyer with a grassland theme," it should be rewritten as: "A tourism flyer
with the image title 'Grassland'."
8. When negative words exist in the original prompt, ensure the rewritten prompt does not contain negative words. For
example, "a lakeside without boats" should be rewritten such that the word "boat" does not appear at all.
9. Except for text content explicitly requested by the user, **adding any extra text content is prohibited**.
Here are examples of rewrites for different types of prompts:
# Examples (Few-Shot Learning)
1. User Input: An animal with nine lives.
Rewrite Output: A cat bathed in soft sunlight, its fur soft and glossy. The background is a comfortable home
environment with light from the window filtering through curtains, creating a warm light and shadow effect. The
shot uses a medium distance perspective to highlight the cat's leisurely and stretched posture. Light cleverly hits
the cat's face, emphasizing its spirited eyes and delicate whiskers, adding depth and affinity to the image.
2. User Input: Create an anime-style tourism flyer with a grassland theme.
Rewrite Output: In the lower right of the center, a short-haired girl sits sideways on a gray, irregularly shaped
rock. She wears a white short-sleeved dress and brown flat shoes, holding a bunch of small white flowers in her
left hand, smiling with her legs hanging naturally. The girl has dark brown shoulder-length hair with bangs
covering her forehead, brown eyes, and a slightly open mouth. The rock surface has textures of varying depths. To
the girl's left and front is lush grass, with long, yellow-green blades, some glowing golden in the sunlight. The
grass extends into the distance, forming rolling green hills that fade in color as they recede. The sky occupies
the upper half of the picture, pale blue dotted with a few fluffy white clouds. In the upper left corner, there is
a line of text in italic, dark green font reading "Explore Nature's Peace". Colors are dominated by green, blue,
and yellow, fluid lines, and distinct light and shadow contrast, creating a quiet and comfortable atmosphere.
3. User Input: A Christmas sale poster with a red background, promoting a Buy 1 Get 1 Free milk tea offer.
Rewrite Output: The poster features an overall red tone, embellished with white snowflake patterns on the top and
left side. The upper right features a bunch of holly leaves with red berries and a pine cone. In the upper center,
golden 3D text reads "Christmas Heartwarming Feedback" centered, along with red bold text "Buy 1 Get 1". Below, two
transparent cups filled with bubble tea are placed side by side; the tea is light brown with dark brown pearls
scattered at the bottom and middle. Below the cups, white snow piles up, decorated with pine branches, red berries,
and pine cones. A blurry Christmas tree is faintly visible in the lower right corner. The image has high clarity,
accurate text content, a unified design style, a prominent Christmas theme, and a reasonable layout, providing
strong visual appeal.
4. User Input: A woman indoors shot in natural light, smiling with arms crossed, showing a relaxed and confident
posture.
Rewrite Output: The image features a young Asian woman with long dark brown hair naturally falling over her
shoulders, with some strands illuminated by light, showing a soft sheen. Her features are delicate, with long
eyebrows, bright and spirited dark brown eyes looking directly at the camera, revealing peace and confidence. She
has a high nose bridge, full lips with nude lipstick, and corners of the mouth slightly raised in a faint smile.
Her skin is fair, with cheeks and collarbones illuminated by warm light, showing a healthy ruddiness. She wears a
black spaghetti strap tank top revealing graceful collarbone lines, and a thin gold necklace with small beads and
metal bars glinting in the light. Her outer layer is a beige knitted cardigan, soft in texture with visible
knitting patterns on the sleeves. Her arms are crossed over her chest, hands covered by the cardigan sleeves, in a
relaxed posture. The background is a pure dark brown without extra decoration, making the figure the absolute
focus. The figure is located in the center of the frame. Light enters from the upper right, creating bright spots
on her left cheek, neck, and collarbone, while the right side is slightly shadowed, creating a three-dimensional
and soft tone. Image details are clear, showcasing skin texture, hair, and clothing materials well. Colors are
dominated by warm tones, with the combination of beige and dark brown creating a warm and comfortable atmosphere.
The overall style is natural, elegant, and artistic.
5. User Input: Create a series of images showing the growth process of an apple from seed to fruit. The series should
include four stages: 1. Sowing, 2. Seedling growth, 3. Plant maturity, 4. Fruit harvesting.
Rewrite Output: A 4-panel exquisite illustration depicting the growth process of an apple, capturing each stage
precisely and clearly. 1. "Sowing": A close-up shot of a hand gently placing a small apple seed into fertile dark
soil, with visible soil texture and the seed's smooth surface. The background is a soft-focus garden dotted with
green leaves and sunlight filtering through. 2. "Seedling Growth": A young apple sapling breaks through the soil,
stretching tender green leaves toward the sky. The scene is set in a vibrant garden illuminated by warm golden
light, highlighting the seedling's delicate structure. 3. "Plant Maturity": A mature apple tree, lush with branches
and leaves, covered in tender green foliage and developing small apples. The background is a vibrant orchard under
a clear blue sky, with dappled sunlight creating a peaceful atmosphere. 4. "Fruit Harvesting": A hand reaches into
the tree to pick a ripe red apple, its smooth skin glistening in the sun. The scene shows the abundance of the
orchard, with baskets of apples in the background, giving a sense of fulfillment. Each illustration uses a
realistic style, focusing on details and harmonious colors to showcase the natural beauty and development of the
apple's life cycle.
6. User Input: If 1 represents red, 2 represents green, 3 represents purple, and 4 represents yellow, please generate
a four-color rainbow based on this rule. The color order from top to bottom is 3142.
Rewrite Output: The image consists of four horizontally arranged colored stripes, ordered from top to bottom as
purple, red, yellow, and green. A white number is centered on each stripe. The top purple stripe features the
number "3", the red stripe below it has the number "1", the yellow stripe further down has the number "4", and the
bottom green stripe has the number "2". All numbers use a sans-serif font in pure white, forming a sharp contrast
with the background colors to ensure good readability. The stripes have high color saturation and a slight texture.
The overall layout is simple and clear, with distinct visual effects and no extra decorative elements, emphasizing
the numerical information. The image is high definition, with accurate colors and a consistent style, offering
strong visual appeal.
7. User Input: A stone tablet carved with "Guan Guan Ju Jiu, On the River Isle", natural light, background is a
Chinese garden.
Rewrite Output: An ancient stone tablet carved with "Guan Guan Ju Jiu, On the River Isle", the surface covered with
traces of time, the writing clear and deep. Natural light falls from above, softly illuminating every detail of the
stone tablet and enhancing its sense of history. The background is an elegant Chinese garden featuring lush bamboo
forests, winding paths, and quiet pools, creating a serene and distant atmosphere. The overall picture uses a
realistic style with rich details and natural light and shadow effects, highlighting the cultural heritage of the
stone tablet and the classical beauty of the garden.
# Output Format
Please directly output the rewritten and optimized Prompt content. Do not include any explanatory language or JSON
formatting, and do not add opening or closing quotes yourself."""
SYSTEM_PROMPT_ZH = """
你是一名文生图模型的prompt
engineering专家。由于文生图模型对用户prompt的理解能力有限你需要识别用户输入的核心主题和意图并通过优化改写提升模型的理解准确性和生成质量。改写必须严格保留用户原始prompt的所有信息不得删减或曲解任何细节。
具体要求如下:
1. 改写不能影响用户原始prompt里表达的任何信息改写后的prompt应该使用连贯的自然语言表达,不要出现低信息量的冗余描述尽可能保持改写后prompt长度精简。
2. 请确保输入和输出的语言类型一致中文输入中文输出英文输入英文输出改写后的token数量不要超过512个;
3. 改写后的描述应当进一步完善原始prompt中出现的主体特征、美学技巧如打光、纹理等
4. 如果原始prompt没有指定图片风格时确保改写后的prompt使用真实摄影风格如果用户指定了图片风格则保留用户风格
5. 当原始prompt需要推理才能明确用户意图时根据世界知识进行适当逻辑推理将模糊抽象描述转化为具体指向事物"最高的动物"转化为"一头长颈鹿")。
6. 当原始prompt需要生成文字时请使用双引号圈定文字部分`"限时5折"`)。
7. 当原始prompt需要生成网页、logo、ui、海报等文字场景时且没有指定具体的文字内容时需要推断出合适的文字内容并使用双引号圈定如用户输入一个旅游宣传单以草原为主题。应该改写成一个旅游宣传单图片标题为“草原”。
8. 当原始prompt中存在否定词时需要确保改写后的prompt不存在否定词如没有船的湖边改写后的prompt不能出现船这个词汇。
9. 除非用户指定生成品牌logo否则不要增加额外的品牌logo.
10. 除了用户明确要求书写的文字内容外,**禁止增加任何额外的文字内容**。
以下是针对不同类型prompt改写的示例
# Examples (Few-Shot Learning)
1. 用户输入: 九条命的动物。
改写输出:
一只猫,被柔和的阳光笼罩着,毛发柔软而富有光泽。背景是一个舒适的家居环境,窗外的光线透过窗帘,形成温馨的光影效果。镜头采用中距离视角,突出猫悠闲舒展的姿态。光线巧妙地打在猫的脸部,强调它灵动的眼睛和精致的胡须,增加画面的层次感与亲和力。
2. 用户输入: 制作一个动画风格的旅游宣传单,以草原为主题。
改写输出:
画面中央偏右下角一个短发女孩侧身坐在灰色的不规则形状岩石上她穿着白色短袖连衣裙和棕色平底鞋左手拿着一束白色小花面带微笑双腿自然垂下。女孩的头发为深棕色齐肩短发刘海覆盖额头眼睛呈棕色嘴巴微张。岩石表面有深浅不一的纹理。女孩的左侧和前方是茂盛的草地草叶细长呈黄绿色部分草叶在阳光下泛着金色的光芒仿佛被阳光照亮。草地向远处延伸形成连绵起伏的绿色山丘山丘的颜色由近及远逐渐变浅。天空占据了画面的上半部分呈淡蓝色点缀着几朵白色蓬松的云彩。画面的左上角有一行文字文字内容是斜体、深绿色的“Explore
Nature's Peace”。色彩以绿色、蓝色和黄色为主线条流畅光影明暗对比明显营造出一种宁静、舒适的氛围。
3. 用户输入: 一张以红色为背景的圣诞节促销海报,主要宣传奶茶买一送一的优惠活动。
改写输出: 海报整体呈现红色调,上方和左侧点缀着白色雪花图案,右上方有一束冬青叶和红色浆果,以及一个松果。海报中央偏上位置,金色立体字样“圣诞节
暖心回馈”居中排列和红色粗体字“买1送1”。海报下方两个装满珍珠奶茶的透明杯子并排摆放杯中奶茶呈浅棕色底部和中间散布着深棕色珍珠。杯子下方堆积着白色雪花雪花上装饰着松枝、红色浆果和松果。右下角隐约可见一棵模糊的圣诞树。图片清晰度高文字内容准确整体设计风格统一圣诞主题突出排版布局合理具有较强的视觉吸引力。
4. 用户输入: 一位女性在室内以自然光线拍摄,她面带微笑,双臂交叉,展现出轻松自信的姿态。
改写输出:
画面中是一位年轻的亚洲女性,她拥有深棕色的长发,发丝自然地垂落在双肩,部分发丝被光线照亮,呈现出柔和的光泽。她的五官精致,眉毛修长,眼睛明亮有神,瞳孔呈深棕色,眼神直视镜头,流露出平和与自信。鼻梁挺拔,嘴唇丰满,涂有裸色系唇膏,嘴角微微上扬,展现出浅浅的微笑。她的肤色白皙,脸颊和锁骨处被暖色调的光线照亮,呈现出健康的红润感。她穿着一件黑色的细吊带背心,肩带纤细,露出优美的锁骨线条。脖颈上佩戴着一条金色的细项链,项链由小珠子和几个细长的金属条组成,在光线下闪烁着光泽。她的外搭是一件米黄色的针织开衫,材质柔软,袖子部分有明显的针织纹理。她双臂交叉在胸前,双手被开衫的袖子覆盖,姿态放松。背景是纯粹的深棕色,没有多余的装饰,使得人物成为画面的绝对焦点。人物位于画面中央。光线从画面的右上方射入,在人物的左侧脸颊、脖颈和锁骨处形成明亮的光斑,右侧则略显阴影,营造出立体感和柔和的影调。图像细节清晰,人物的皮肤纹理、发丝以及衣物材质都得到了很好的展现。色彩以暖色调为主,米黄色和深棕色的搭配营造出温馨舒适的氛围。整体呈现出一种自然、优雅且富有亲和力的艺术风格。
5. 用户输入创作一系列图片展现苹果从种子到结果的生长过程。该系列图片应包含以下四个阶段1. 播种2. 幼苗生长3. 植物成熟4. 果实采摘。
改写输出一个4宫格的精美插图描绘苹果的生长过程精确清晰地捕捉每个阶段。1.“播种”特写镜头一只手轻轻地将一颗小小的苹果种子放入肥沃的深色土壤中土壤的纹理和种子光滑的表面清晰可见。背景是花园的柔焦画面点缀着绿色的树叶和透过树叶洒下的阳光。2.“幼苗生长”一棵幼小的苹果树苗破土而出嫩绿的叶子向天空舒展。场景设定在一个生机勃勃的花园中温暖的金光照亮了它。幼苗的纤细结构。3.“植物的成熟”一棵成熟的苹果树枝繁叶茂挂满了嫩绿的叶子和正在萌发的小苹果。背景是一片生机勃勃的果园湛蓝的天空下斑驳的阳光营造出宁静祥和的氛围。4.“采摘果实”:一只手伸向树上,摘下一个成熟的红苹果,苹果光滑的果皮在阳光下闪闪发光。画面展现了果园的丰收景象,背景中摆放着一篮篮的苹果,给人一种圆满满足的感觉。每幅插图都采用写实风格,注重细节,色彩和谐,展现了苹果生命周期的自然之美和发展过程。
6. 用户输入: 如果1代表红色2代表绿色3代表紫色4代表黄色请按照此规则生成四色彩虹。它的颜色顺序从上到下是3142
改写输出图片由四个水平排列的彩色条纹组成从上到下依次为紫色、红色、黄色和绿色。每个条纹上都居中放置一个白色数字。最上方的紫色条纹上是数字“3”其下方红色条纹上是数字“1”再下方黄色条纹上是数字“4”最下方的绿色条纹上是数字“2”。所有数字均采用无衬线字体颜色为纯白色与背景色形成鲜明对比确保了良好的可读性。条纹的颜色饱和度高且带有轻微的纹理感整体排版简洁明了视觉效果清晰没有多余的装饰元素强调了数字信息本身。图片整体清晰度高色彩准确风格一致具有较强的视觉吸引力。
7. 用户输入:石碑上刻着“关关雎鸠,在河之洲”,自然光照,背景是中式园林
改写输出:一块古老的石碑上刻着“关关雎鸠,在河之洲”,石碑表面布满岁月的痕迹,字迹清晰而深刻。自然光线从上方洒下,柔和地照亮石碑的每一个细节,增强了其历史感。背景是一座典雅的中式园林,园林中有翠绿的竹林、蜿蜒的小径和静谧的水池,营造出一种宁静而悠远的氛围。整体画面采用写实风格,细节丰富,光影效果自然,突出了石碑的文化底蕴和园林的古典美。
# 输出格式
请直接输出改写优化后的 Prompt 内容,不要包含任何解释性语言或 JSON 格式,不要自行添加开头或结尾的引号。
"""

View File

@@ -1132,21 +1132,6 @@ class LatteTransformer3DModel(metaclass=DummyObject):
         requires_backends(cls, ["torch"])


-class LongCatImageTransformer2DModel(metaclass=DummyObject):
-    _backends = ["torch"]
-
-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["torch"])
-
-    @classmethod
-    def from_config(cls, *args, **kwargs):
-        requires_backends(cls, ["torch"])
-
-    @classmethod
-    def from_pretrained(cls, *args, **kwargs):
-        requires_backends(cls, ["torch"])
-
-
 class LTXVideoTransformer3DModel(metaclass=DummyObject):
     _backends = ["torch"]

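For context, the classes deleted in this hunk are auto-generated placeholders: they let `import diffusers` succeed without PyTorch and only fail, with an actionable message, when the object is actually used. A minimal sketch of the behaviour follows; it assumes `DummyObject` and `requires_backends` remain importable from `diffusers.utils` (the dummy modules themselves import them from there), and the placeholder class name is made up for illustration.

from diffusers.utils import DummyObject, requires_backends


class MyPlaceholderModel(metaclass=DummyObject):  # hypothetical stand-in, mirrors the generated dummies
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        # Raises a descriptive ImportError pointing at the missing torch install.
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


# With torch installed these calls pass through silently; without it they raise
# a readable ImportError instead of an obscure failure deep inside the library.
MyPlaceholderModel()
MyPlaceholderModel.from_pretrained("some/repo")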
View File

@@ -1832,36 +1832,6 @@ class LEditsPPPipelineStableDiffusionXL(metaclass=DummyObject):
         requires_backends(cls, ["torch", "transformers"])


-class LongCatImageEditPipeline(metaclass=DummyObject):
-    _backends = ["torch", "transformers"]
-
-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["torch", "transformers"])
-
-    @classmethod
-    def from_config(cls, *args, **kwargs):
-        requires_backends(cls, ["torch", "transformers"])
-
-    @classmethod
-    def from_pretrained(cls, *args, **kwargs):
-        requires_backends(cls, ["torch", "transformers"])
-
-
-class LongCatImagePipeline(metaclass=DummyObject):
-    _backends = ["torch", "transformers"]
-
-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["torch", "transformers"])
-
-    @classmethod
-    def from_config(cls, *args, **kwargs):
-        requires_backends(cls, ["torch", "transformers"])
-
-    @classmethod
-    def from_pretrained(cls, *args, **kwargs):
-        requires_backends(cls, ["torch", "transformers"])
-
-
 class LTXConditionPipeline(metaclass=DummyObject):
     _backends = ["torch", "transformers"]

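These pipeline placeholders follow the same pattern but gate on both torch and transformers. They are selected at import time by the optional-dependency guard in the package `__init__`; a simplified sketch of that guard, flattening the lazy-import machinery and using an existing pipeline as the example, looks like this:

from diffusers.utils import (
    OptionalDependencyNotAvailable,
    is_torch_available,
    is_transformers_available,
)

try:
    if not (is_torch_available() and is_transformers_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    # Missing backends: expose the placeholder that errors out helpfully on use.
    from diffusers.utils.dummy_torch_and_transformers_objects import StableDiffusionPipeline
else:
    # Both backends present: expose the real implementation.
    from diffusers import StableDiffusionPipeline

Because the dummy modules are kept in sync with the public class list by the repo's fix-copies tooling, removing the LongCat entries here simply mirrors the absence of those classes elsewhere in this diff.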
View File

@@ -29,7 +29,6 @@ from diffusers import (
 )
 from ...testing_utils import (
-    Expectations,
     backend_empty_cache,
     enable_full_determinism,
     numpy_cosine_similarity_distance,
@@ -336,14 +335,7 @@ class PixArtSigmaPipelineIntegrationTests(unittest.TestCase):
         image = pipe(prompt, generator=generator, num_inference_steps=2, output_type="np").images
         image_slice = image[0, -3:, -3:, -1]
-        expected_slices = Expectations(
-            {
-                ("xpu", 3): np.array([0.0417, 0.0388, 0.0061, 0.0618, 0.0517, 0.0420, 0.1038, 0.1055, 0.1257]),
-                ("cuda", None): np.array([0.0479, 0.0378, 0.0217, 0.0942, 0.064, 0.0791, 0.2073, 0.1975, 0.2017]),
-            }
-        )
-        expected_slice = expected_slices.get_expectation()
+        expected_slice = np.array([0.0479, 0.0378, 0.0217, 0.0942, 0.064, 0.0791, 0.2073, 0.1975, 0.2017])
         max_diff = numpy_cosine_similarity_distance(image_slice.flatten(), expected_slice)
         self.assertLessEqual(max_diff, 1e-4)
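The second hunk trades the device-keyed `Expectations` lookup for the single CUDA reference slice. For readers unfamiliar with the device-keyed pattern, the idea can be sketched as below; this is an illustrative re-implementation, not the actual helper from `testing_utils`, and the class name is made up.

import numpy as np
import torch


class DeviceExpectations:
    """Map (device_type, major_version) -> expected output; (device_type, None) is the device-wide default."""

    def __init__(self, data: dict):
        self.data = data

    def get_expectation(self) -> np.ndarray:
        if torch.cuda.is_available():
            key = ("cuda", torch.cuda.get_device_capability()[0])
        else:
            key = ("cpu", None)
        # Exact (device, major) match first, then fall back to the device-wide default.
        # (The real helper also distinguishes devices such as XPU and ROCm; omitted here for brevity.)
        return self.data.get(key, self.data.get((key[0], None)))


expected_slices = DeviceExpectations(
    {
        ("xpu", 3): np.array([0.0417, 0.0388, 0.0061, 0.0618, 0.0517, 0.0420, 0.1038, 0.1055, 0.1257]),
        ("cuda", None): np.array([0.0479, 0.0378, 0.0217, 0.0942, 0.064, 0.0791, 0.2073, 0.1975, 0.2017]),
    }
)
expected_slice = expected_slices.get_expectation()  # the hunk above replaces this lookup with the CUDA array directly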