Compare commits

...

19 Commits

Author SHA1 Message Date
Dhruv Nair
59f2784688 Revert "[Modular] Save Modular Pipeline weights to Hub (#13168)"
This reverts commit 84ff061b1d.
2026-03-03 00:05:54 +05:30
David El Malih
c2fdd2d048 docs: improve docstring scheduling_ipndm.py (#13198)
Improve docstring scheduling ipndm
2026-03-02 09:42:55 -08:00
Dhruv Nair
84ff061b1d [Modular] Save Modular Pipeline weights to Hub (#13168)
* update

* update

* update

* update

* update

* update
2026-03-02 22:20:42 +05:30
Dhruv Nair
3fd14f1acf [AutoModel] Allow registering auto_map to model config (#13186)
* update

* update
2026-03-02 22:13:25 +05:30
Dhruv Nair
e7fe4ce92f [AutoModel] Fix bug with subfolders and local model paths when loading custom code (#13197)
* update

* update
2026-03-02 17:44:25 +05:30
Sayak Paul
3d9085565b remove db utils from benchmarking (#13199) 2026-03-02 16:39:56 +05:30
Sayak Paul
5b54496131 [tests] enable cpu offload test in torchao without compilation. (#12704)
enable cpu offload test in torchao without compilation.
2026-03-02 15:03:58 +05:30
Sayak Paul
fcdd759e39 [chore] updates in the pypi publication workflow. (#12805)
* updates in the pypi publication workflow.

* change to 3.10
2026-03-02 14:34:49 +05:30
YiYi Xu
39188248a7 [modular] fallback to default_blocks_name when loading base block classes in ModularPipeline (#13193)
up

Co-authored-by: yiyi@huggingface.co <yiyi@ip-26-0-161-123.ec2.internal>
2026-02-27 18:58:01 -10:00
Sayak Paul
9b97932424 [tests] consistency tests for modular index (#13192)
* add a test to check modular index consistency

* check for compulsory keys.
2026-02-28 08:47:21 +05:30
YiYi Xu
680076fcc0 [Modular] update the auto pipeline blocks doc (#13148)
* update

* Apply suggestion from @yiyixuxu

* Update docs/source/en/modular_diffusers/auto_pipeline_blocks.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/modular_diffusers/auto_pipeline_blocks.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/modular_diffusers/auto_pipeline_blocks.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/modular_diffusers/auto_pipeline_blocks.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* add to api

---------

Co-authored-by: yiyi@huggingface.co <yiyi@ip-26-0-160-103.ec2.internal>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: yiyi@huggingface.co <yiyi@ip-26-0-161-123.ec2.internal>
2026-02-27 10:50:35 -10:00
Christopher
5910a1cc6c Fixing Kohya loras loading: Flux.1-dev loras with TE ("lora_te1_" prefix) (#13188)
* fixing text encoder lora loading

* following Cursor's review
2026-02-27 15:43:41 +05:30
Jerry Song
40e96454f1 Fix LTX-2 image-to-video generation failure in two stages generation (#13187)
* Fix LTX-2 image-to-video generation failure in two stages generation

In LTX-2's two-stage image-to-video generation task, specifically after
the upsampling step, a shape mismatch occurs between the `latents` and
the `conditioning_mask`, which causes an error in function
`_create_noised_state`.

Fix it by creating the `conditioning_mask` based on the shape of the
`latents`.
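
A minimal sketch of the idea, assuming a `(batch, channels, frames, height, width)` latent layout; the helper name `_make_conditioning_mask` is illustrative, not the pipeline's actual function:

```py
import torch

def _make_conditioning_mask(latents: torch.Tensor, num_cond_frames: int) -> torch.Tensor:
    # Derive the mask shape from the (possibly upsampled) latents themselves,
    # so the two tensors cannot drift apart between the two stages.
    batch, _, frames, height, width = latents.shape
    mask = latents.new_zeros(batch, 1, frames, height, width)
    mask[:, :, :num_cond_frames] = 1.0  # mark the frames that stay conditioned
    return mask
```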

* Add unit test for LTX-2 i2v two stages inference with upsampler

* Downscaling the upsampler in LTX-2 image-to-video unit test

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2026-02-27 00:55:01 -08:00
Varun Chawla
47455bd133 Fix Flash Attention 3 interface for new FA3 return format (#13173)
* Fix Flash Attention 3 interface compatibility for new FA3 versions

Newer versions of flash-attn (after Dao-AILab/flash-attention@ed20940)
no longer return lse by default from flash_attn_3_func. The function
now returns just the output tensor unless return_attn_probs=True is
passed.

Updated _wrapped_flash_attn_3 and _flash_varlen_attention_3 to pass
return_attn_probs and handle both old (always tuple) and new (tensor
or tuple) return formats gracefully.
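
A minimal sketch of the compatibility handling described above; `fa3_func` stands in for `flash_attn_3_func`, and the wrapper name is illustrative:

```py
def _call_fa3_compat(fa3_func, q, k, v, **kwargs):
    # Request the LSE explicitly. Old flash-attn 3 builds always return a tuple;
    # new builds return only the output tensor unless return_attn_probs=True.
    result = fa3_func(q=q, k=k, v=v, return_attn_probs=True, **kwargs)
    if isinstance(result, tuple):
        out, lse, *_ = result
    else:
        out, lse = result, None
    return out, lse
```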

Fixes #12022

* Simplify _wrapped_flash_attn_3 return unpacking

Since return_attn_probs=True is always passed, the result is
guaranteed to be a tuple. Remove the unnecessary isinstance guard.
2026-02-26 17:34:36 +05:30
Kirill Stukalov
97c2c6e397 Fix wrong do_classifier_free_guidance threshold in ZImagePipeline (#13183)
Z-Image uses CFG formula `pred = pos + scale * (pos - neg)` where
`guidance_scale = 0` means no guidance. The threshold should be `> 0`
instead of `> 1` to match this formula.
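
A minimal sketch of the corrected check and the guidance update (variable names are illustrative):

```py
import torch

def apply_z_image_guidance(pos: torch.Tensor, neg: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # guidance_scale == 0 already means "no guidance" in this formulation,
    # so guidance is applied whenever the scale is > 0, not > 1.
    if guidance_scale > 0:
        return pos + guidance_scale * (pos - neg)
    return pos
```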

Co-authored-by: Hezlich2 <typretypre@gmail.com>
2026-02-25 15:08:11 -10:00
Miguel Martin
212db7b999 Cosmos Transfer2.5 Auto-Regressive Inference Pipeline (#13114)
* AR

* address comments

* address comments 2
2026-02-25 14:42:29 -10:00
Sayak Paul
31058485f1 [attention backends] use dedicated wrappers from fa3 for cp. (#13165)
* use dedicated wrappers from fa3 for cp.

* up
2026-02-26 00:36:01 +05:30
Sayak Paul
aac94befce [docs] Fix torchrun command argument order in docs (#13181)
Fix torchrun command argument order in docs
2026-02-24 08:31:39 -08:00
SYM.BOT
1f6ac1c3d1 fix: graceful fallback when attention backends fail to import (#13060)
* fix: graceful fallback when attention backends fail to import

## Problem

External attention backends (flash_attn, xformers, sageattention, etc.) may be
installed but fail to import at runtime due to ABI mismatches. For example,
when `flash_attn` is compiled against PyTorch 2.4 but used with PyTorch 2.8,
the import fails with:

```
OSError: .../flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so: undefined symbol: _ZN3c104cuda9SetDeviceEab
```

The current code uses `importlib.util.find_spec()` to check if packages exist,
but this only verifies the package is installed—not that it can actually be
imported. When the import fails, diffusers crashes instead of falling back to
native PyTorch attention.

## Solution

Wrap all external attention backend imports in try-except blocks that catch
`ImportError` and `OSError`. On failure:
1. Log a warning message explaining the issue
2. Set the corresponding `_CAN_USE_*` flag to `False`
3. Set the imported functions to `None`

This allows diffusers to gracefully degrade to PyTorch's native SDPA
(scaled_dot_product_attention) instead of crashing.
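
A minimal sketch of the guarded import for a single backend (the diff below applies the same pattern to every external backend):

```py
import logging

logger = logging.getLogger(__name__)

_CAN_USE_FLASH_ATTN = True  # assume the version/availability check already passed
try:
    from flash_attn import flash_attn_func  # may raise OSError on an ABI mismatch
except (ImportError, OSError, RuntimeError) as e:
    logger.warning(f"flash_attn is installed but failed to import: {e}. Falling back to native PyTorch attention.")
    _CAN_USE_FLASH_ATTN = False
    flash_attn_func = None
```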

## Affected backends

- flash_attn (Flash Attention)
- flash_attn_3 (Flash Attention 3)
- aiter (AMD Instinct)
- sageattention (SageAttention)
- flex_attention (PyTorch Flex Attention)
- torch_npu (Huawei NPU)
- torch_xla (TPU/XLA)
- xformers (Meta xFormers)

## Testing

Tested with PyTorch 2.8.0 and flash_attn 2.7.4.post1 (compiled for PyTorch 2.4).
Before: crashes on import. After: logs warning and uses native attention.

* address review: use single logger and catch RuntimeError

- Move logger to module level instead of creating per-backend loggers
- Add RuntimeError to exception list alongside ImportError and OSError

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* Apply style fixes

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2026-02-24 13:37:39 +05:30
25 changed files with 1069 additions and 567 deletions

View File

@@ -62,20 +62,6 @@ jobs:
with:
name: benchmark_test_reports
path: benchmarks/${{ env.BASE_PATH }}
# TODO: enable this once the connection problem has been resolved.
- name: Update benchmarking results to DB
env:
PGDATABASE: metrics
PGHOST: ${{ secrets.DIFFUSERS_BENCHMARKS_PGHOST }}
PGUSER: transformers_benchmarks
PGPASSWORD: ${{ secrets.DIFFUSERS_BENCHMARKS_PGPASSWORD }}
BRANCH_NAME: ${{ github.head_ref || github.ref_name }}
run: |
git config --global --add safe.directory /__w/diffusers/diffusers
commit_id=$GITHUB_SHA
commit_msg=$(git show -s --format=%s "$commit_id" | cut -c1-70)
cd benchmarks && python populate_into_db.py "$BRANCH_NAME" "$commit_id" "$commit_msg"
- name: Report success status
if: ${{ success() }}

View File

@@ -54,7 +54,6 @@ jobs:
python -m pip install --upgrade pip
pip install -U setuptools wheel twine
pip install -U torch --index-url https://download.pytorch.org/whl/cpu
pip install -U transformers
- name: Build the dist files
run: python setup.py bdist_wheel && python setup.py sdist
@@ -69,6 +68,8 @@ jobs:
run: |
pip install diffusers && pip uninstall diffusers -y
pip install -i https://test.pypi.org/simple/ diffusers
pip install -U transformers
python utils/print_env.py
python -c "from diffusers import __version__; print(__version__)"
python -c "from diffusers import DiffusionPipeline; pipe = DiffusionPipeline.from_pretrained('fusing/unet-ldm-dummy-update'); pipe()"
python -c "from diffusers import DiffusionPipeline; pipe = DiffusionPipeline.from_pretrained('hf-internal-testing/tiny-stable-diffusion-pipe', safety_checker=None); pipe('ah suh du')"

View File

@@ -1,166 +0,0 @@
import argparse
import os
import sys
import gpustat
import pandas as pd
import psycopg2
import psycopg2.extras
from psycopg2.extensions import register_adapter
from psycopg2.extras import Json
register_adapter(dict, Json)
FINAL_CSV_FILENAME = "collated_results.csv"
# https://github.com/huggingface/transformers/blob/593e29c5e2a9b17baec010e8dc7c1431fed6e841/benchmark/init_db.sql#L27
BENCHMARKS_TABLE_NAME = "benchmarks"
MEASUREMENTS_TABLE_NAME = "model_measurements"
def _init_benchmark(conn, branch, commit_id, commit_msg):
gpu_stats = gpustat.GPUStatCollection.new_query()
metadata = {"gpu_name": gpu_stats[0]["name"]}
repository = "huggingface/diffusers"
with conn.cursor() as cur:
cur.execute(
f"INSERT INTO {BENCHMARKS_TABLE_NAME} (repository, branch, commit_id, commit_message, metadata) VALUES (%s, %s, %s, %s, %s) RETURNING benchmark_id",
(repository, branch, commit_id, commit_msg, metadata),
)
benchmark_id = cur.fetchone()[0]
print(f"Initialised benchmark #{benchmark_id}")
return benchmark_id
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"branch",
type=str,
help="The branch name on which the benchmarking is performed.",
)
parser.add_argument(
"commit_id",
type=str,
help="The commit hash on which the benchmarking is performed.",
)
parser.add_argument(
"commit_msg",
type=str,
help="The commit message associated with the commit, truncated to 70 characters.",
)
args = parser.parse_args()
return args
if __name__ == "__main__":
args = parse_args()
try:
conn = psycopg2.connect(
host=os.getenv("PGHOST"),
database=os.getenv("PGDATABASE"),
user=os.getenv("PGUSER"),
password=os.getenv("PGPASSWORD"),
)
print("DB connection established successfully.")
except Exception as e:
print(f"Problem during DB init: {e}")
sys.exit(1)
try:
benchmark_id = _init_benchmark(
conn=conn,
branch=args.branch,
commit_id=args.commit_id,
commit_msg=args.commit_msg,
)
except Exception as e:
print(f"Problem during initializing benchmark: {e}")
sys.exit(1)
cur = conn.cursor()
df = pd.read_csv(FINAL_CSV_FILENAME)
# Helper to cast values (or None) given a dtype
def _cast_value(val, dtype: str):
if pd.isna(val):
return None
if dtype == "text":
return str(val).strip()
if dtype == "float":
try:
return float(val)
except ValueError:
return None
if dtype == "bool":
s = str(val).strip().lower()
if s in ("true", "t", "yes", "1"):
return True
if s in ("false", "f", "no", "0"):
return False
if val in (1, 1.0):
return True
if val in (0, 0.0):
return False
return None
return val
try:
rows_to_insert = []
for _, row in df.iterrows():
scenario = _cast_value(row.get("scenario"), "text")
model_cls = _cast_value(row.get("model_cls"), "text")
num_params_B = _cast_value(row.get("num_params_B"), "float")
flops_G = _cast_value(row.get("flops_G"), "float")
time_plain_s = _cast_value(row.get("time_plain_s"), "float")
mem_plain_GB = _cast_value(row.get("mem_plain_GB"), "float")
time_compile_s = _cast_value(row.get("time_compile_s"), "float")
mem_compile_GB = _cast_value(row.get("mem_compile_GB"), "float")
fullgraph = _cast_value(row.get("fullgraph"), "bool")
mode = _cast_value(row.get("mode"), "text")
# If "github_sha" column exists in the CSV, cast it; else default to None
if "github_sha" in df.columns:
github_sha = _cast_value(row.get("github_sha"), "text")
else:
github_sha = None
measurements = {
"scenario": scenario,
"model_cls": model_cls,
"num_params_B": num_params_B,
"flops_G": flops_G,
"time_plain_s": time_plain_s,
"mem_plain_GB": mem_plain_GB,
"time_compile_s": time_compile_s,
"mem_compile_GB": mem_compile_GB,
"fullgraph": fullgraph,
"mode": mode,
"github_sha": github_sha,
}
rows_to_insert.append((benchmark_id, measurements))
# Batch-insert all rows
insert_sql = f"""
INSERT INTO {MEASUREMENTS_TABLE_NAME} (
benchmark_id,
measurements
)
VALUES (%s, %s);
"""
psycopg2.extras.execute_batch(cur, insert_sql, rows_to_insert)
conn.commit()
cur.close()
conn.close()
except Exception as e:
print(f"Exception: {e}")
sys.exit(1)

View File

@@ -14,4 +14,8 @@
## AutoPipelineBlocks
[[autodoc]] diffusers.modular_pipelines.modular_pipeline.AutoPipelineBlocks
[[autodoc]] diffusers.modular_pipelines.modular_pipeline.AutoPipelineBlocks
## ConditionalPipelineBlocks
[[autodoc]] diffusers.modular_pipelines.modular_pipeline.ConditionalPipelineBlocks

View File

@@ -46,6 +46,20 @@ output = pipe(
output.save("output.png")
```
## Cosmos2_5_TransferPipeline
[[autodoc]] Cosmos2_5_TransferPipeline
- all
- __call__
## Cosmos2_5_PredictBasePipeline
[[autodoc]] Cosmos2_5_PredictBasePipeline
- all
- __call__
## CosmosTextToWorldPipeline
[[autodoc]] CosmosTextToWorldPipeline
@@ -70,12 +84,6 @@ output.save("output.png")
- all
- __call__
## Cosmos2_5_PredictBasePipeline
[[autodoc]] Cosmos2_5_PredictBasePipeline
- all
- __call__
## CosmosPipelineOutput
[[autodoc]] pipelines.cosmos.pipeline_output.CosmosPipelineOutput

View File

@@ -121,7 +121,7 @@ from diffusers.modular_pipelines import AutoPipelineBlocks
class AutoImageBlocks(AutoPipelineBlocks):
# List of sub-block classes to choose from
block_classes = [block_inpaint_cls, block_i2i_cls, block_t2i_cls]
block_classes = [InpaintBlock, ImageToImageBlock, TextToImageBlock]
# Names for each block in the same order
block_names = ["inpaint", "img2img", "text2img"]
# Trigger inputs that determine which block to run
@@ -129,8 +129,8 @@ class AutoImageBlocks(AutoPipelineBlocks):
# - "image" triggers img2img workflow (but only if mask is not provided)
# - if none of above, runs the text2img workflow (default)
block_trigger_inputs = ["mask", "image", None]
# Description is extremely important for AutoPipelineBlocks
@property
def description(self):
return (
"Pipeline generates images given different types of conditions!\n"
@@ -141,7 +141,7 @@ class AutoImageBlocks(AutoPipelineBlocks):
)
```
It is **very** important to include a `description` to avoid any confusion over how to run a block and what inputs are required. While [`~modular_pipelines.AutoPipelineBlocks`] are convenient, it's conditional logic may be difficult to figure out if it isn't properly explained.
It is **very** important to include a `description` to avoid any confusion over how to run a block and what inputs are required. While [`~modular_pipelines.AutoPipelineBlocks`] are convenient, its conditional logic may be difficult to figure out if it isn't properly explained.
Create an instance of `AutoImageBlocks`.
@@ -152,5 +152,74 @@ auto_blocks = AutoImageBlocks()
For more complex compositions, such as nested [`~modular_pipelines.AutoPipelineBlocks`] blocks when they're used as sub-blocks in larger pipelines, use the [`~modular_pipelines.SequentialPipelineBlocks.get_execution_blocks`] method to extract the block that is actually run based on your input.
```py
auto_blocks.get_execution_blocks("mask")
auto_blocks.get_execution_blocks(mask=True)
```
## ConditionalPipelineBlocks
[`~modular_pipelines.AutoPipelineBlocks`] is a special case of [`~modular_pipelines.ConditionalPipelineBlocks`]. While [`~modular_pipelines.AutoPipelineBlocks`] selects blocks based on whether a trigger input is provided or not, [`~modular_pipelines.ConditionalPipelineBlocks`] is able to select a block based on custom selection logic provided in the `select_block` method.
Here is the same example written using [`~modular_pipelines.ConditionalPipelineBlocks`] directly:
```py
from diffusers.modular_pipelines import ConditionalPipelineBlocks
class AutoImageBlocks(ConditionalPipelineBlocks):
block_classes = [InpaintBlock, ImageToImageBlock, TextToImageBlock]
block_names = ["inpaint", "img2img", "text2img"]
block_trigger_inputs = ["mask", "image"]
default_block_name = "text2img"
@property
def description(self):
return (
"Pipeline generates images given different types of conditions!\n"
+ "This is an auto pipeline block that works for text2img, img2img and inpainting tasks.\n"
+ " - inpaint workflow is run when `mask` is provided.\n"
+ " - img2img workflow is run when `image` is provided (but only when `mask` is not provided).\n"
+ " - text2img workflow is run when neither `image` nor `mask` is provided.\n"
)
def select_block(self, mask=None, image=None) -> str | None:
if mask is not None:
return "inpaint"
if image is not None:
return "img2img"
return None # falls back to default_block_name ("text2img")
```
The inputs listed in `block_trigger_inputs` are passed as keyword arguments to `select_block()`. When `select_block` returns `None`, it falls back to `default_block_name`. If `default_block_name` is also `None`, the entire conditional block is skipped — this is useful for optional processing steps that should only run when specific inputs are provided.
## Workflows
Pipelines that contain conditional blocks ([`~modular_pipelines.AutoPipelineBlocks`] or [`~modular_pipelines.ConditionalPipelineBlocks`]) can support multiple workflows — for example, our SDXL modular pipeline supports a dozen workflows all in one pipeline. But this also means it can be confusing for users to know what workflows are supported and how to run them. For pipeline builders, it's useful to be able to extract only the blocks relevant to a specific workflow.
We recommend defining a `_workflow_map` to give each workflow a name and explicitly list the inputs it requires.
```py
from diffusers.modular_pipelines import SequentialPipelineBlocks
class MyPipelineBlocks(SequentialPipelineBlocks):
block_classes = [TextEncoderBlock, AutoImageBlocks, DecodeBlock]
block_names = ["text_encoder", "auto_image", "decode"]
_workflow_map = {
"text2image": {"prompt": True},
"image2image": {"image": True, "prompt": True},
"inpaint": {"mask": True, "image": True, "prompt": True},
}
```
All of our built-in modular pipelines come with pre-defined workflows. The `available_workflows` property lists all supported workflows:
```py
pipeline_blocks = MyPipelineBlocks()
pipeline_blocks.available_workflows
# ['text2image', 'image2image', 'inpaint']
```
Retrieve a specific workflow with `get_workflow` to inspect and debug a specific block that executes the workflow.
```py
pipeline_blocks.get_workflow("inpaint")
```

View File

@@ -111,7 +111,7 @@ if __name__ == "__main__":
Call `torchrun` to run the inference script and use the `--nproc_per_node` argument to set the number of GPUs to use.
```bash
torchrun run_distributed.py --nproc_per_node=2
torchrun --nproc_per_node=2 run_distributed.py
```
## device_map

View File

@@ -97,5 +97,32 @@ If the custom model inherits from the [`ModelMixin`] class, it gets access to th
> )
> ```
### Saving custom models
Use [`~ConfigMixin.register_for_auto_class`] to add the `auto_map` entry to `config.json` automatically when saving. This avoids having to manually edit the config file.
```py
# my_model.py
from diffusers import ModelMixin, ConfigMixin
class MyCustomModel(ModelMixin, ConfigMixin):
...
MyCustomModel.register_for_auto_class("AutoModel")
model = MyCustomModel(...)
model.save_pretrained("./my_model")
```
The saved `config.json` will include the `auto_map` field.
```json
{
"auto_map": {
"AutoModel": "my_model.MyCustomModel"
}
}
```
> [!NOTE]
> Learn more about implementing custom models in the [Community components](../using-diffusers/custom_pipeline_overview#community-components) guide.
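
For the loading side, a hedged sketch (the local path is the one used above): the `auto_map` entry is what lets `AutoModel` resolve the custom class when `trust_remote_code=True` is passed.

```py
from diffusers import AutoModel

# Resolves MyCustomModel via the auto_map entry written by register_for_auto_class.
# trust_remote_code is required because my_model.py ships with the checkpoint.
model = AutoModel.from_pretrained("./my_model", trust_remote_code=True)
```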

View File

@@ -94,9 +94,15 @@ python scripts/convert_cosmos_to_diffusers.py \
--transformer_type Cosmos-2.5-Transfer-General-2B \
--transformer_ckpt_path $transformer_ckpt_path \
--vae_type wan2.1 \
--output_path converted/transfer/2b/general/depth \
--output_path converted/transfer/2b/general/depth/pipeline \
--save_pipeline
python scripts/convert_cosmos_to_diffusers.py \
--transformer_type Cosmos-2.5-Transfer-General-2B \
--transformer_ckpt_path $transformer_ckpt_path \
--vae_type wan2.1 \
--output_path converted/transfer/2b/general/depth/models
# edge
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Transfer2.5-2B/snapshots/eb5325b77d358944da58a690157dd2b8071bbf85/general/edge/61f5694b-0ad5-4ecd-8ad7-c8545627d125_ema_bf16.pt
@@ -120,9 +126,15 @@ python scripts/convert_cosmos_to_diffusers.py \
--transformer_type Cosmos-2.5-Transfer-General-2B \
--transformer_ckpt_path $transformer_ckpt_path \
--vae_type wan2.1 \
--output_path converted/transfer/2b/general/blur \
--output_path converted/transfer/2b/general/blur/pipeline \
--save_pipeline
python scripts/convert_cosmos_to_diffusers.py \
--transformer_type Cosmos-2.5-Transfer-General-2B \
--transformer_ckpt_path $transformer_ckpt_path \
--vae_type wan2.1 \
--output_path converted/transfer/2b/general/blur/models
# seg
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Transfer2.5-2B/snapshots/eb5325b77d358944da58a690157dd2b8071bbf85/general/seg/5136ef49-6d8d-42e8-8abf-7dac722a304a_ema_bf16.pt
@@ -130,8 +142,14 @@ python scripts/convert_cosmos_to_diffusers.py \
--transformer_type Cosmos-2.5-Transfer-General-2B \
--transformer_ckpt_path $transformer_ckpt_path \
--vae_type wan2.1 \
--output_path converted/transfer/2b/general/seg \
--output_path converted/transfer/2b/general/seg/pipeline \
--save_pipeline
python scripts/convert_cosmos_to_diffusers.py \
--transformer_type Cosmos-2.5-Transfer-General-2B \
--transformer_ckpt_path $transformer_ckpt_path \
--vae_type wan2.1 \
--output_path converted/transfer/2b/general/seg/models
```
"""

View File

@@ -107,6 +107,38 @@ class ConfigMixin:
has_compatibles = False
_deprecated_kwargs = []
_auto_class = None
@classmethod
def register_for_auto_class(cls, auto_class="AutoModel"):
"""
Register this class with the given auto class so that it can be loaded with `AutoModel.from_pretrained(...,
trust_remote_code=True)`.
When the config is saved, the resulting `config.json` will include an `auto_map` entry mapping the auto class
to this class's module and class name.
Args:
auto_class (`str` or type, *optional*, defaults to `"AutoModel"`):
The auto class to register this class with. Can be a string (e.g. `"AutoModel"`) or the class itself.
Currently only `"AutoModel"` is supported.
Example:
```python
from diffusers import ModelMixin, ConfigMixin
class MyCustomModel(ModelMixin, ConfigMixin): ...
MyCustomModel.register_for_auto_class("AutoModel")
```
"""
if auto_class != "AutoModel":
raise ValueError(f"Only 'AutoModel' is supported, got '{auto_class}'.")
cls._auto_class = auto_class
def register_to_config(self, **kwargs):
if self.config_name is None:
@@ -621,6 +653,12 @@ class ConfigMixin:
# pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
_ = config_dict.pop("_pre_quantization_dtype", None)
if getattr(self, "_auto_class", None) is not None:
module = self.__class__.__module__.split(".")[-1]
auto_map = config_dict.get("auto_map", {})
auto_map[self._auto_class] = f"{module}.{self.__class__.__name__}"
config_dict["auto_map"] = auto_map
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
def to_json_file(self, json_file_path: str | os.PathLike):

View File

@@ -856,7 +856,7 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
)
state_dict = {k: v for k, v in state_dict.items() if not k.startswith("text_encoders.t5xxl.transformer.")}
has_diffb = any("diff_b" in k and k.startswith(("lora_unet_", "lora_te_")) for k in state_dict)
has_diffb = any("diff_b" in k and k.startswith(("lora_unet_", "lora_te_", "lora_te1_")) for k in state_dict)
if has_diffb:
zero_status_diff_b = state_dict_all_zero(state_dict, ".diff_b")
if zero_status_diff_b:
@@ -895,7 +895,7 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
state_dict = {
_custom_replace(k, limit_substrings): v
for k, v in state_dict.items()
if k.startswith(("lora_unet_", "lora_te_"))
if k.startswith(("lora_unet_", "lora_te_", "lora_te1_"))
}
if any("text_projection" in k for k in state_dict):
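
For context, a hedged usage sketch of the fix above (the LoRA file name is illustrative): Kohya Flux LoRAs whose text-encoder keys use the `lora_te1_` prefix should now convert and load through the standard entry point.

```py
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
# A Kohya-trained LoRA whose text-encoder keys start with "lora_te1_" is now
# handled by _convert_kohya_flux_lora_to_diffusers during loading.
pipe.load_lora_weights("my_kohya_flux_lora.safetensors")
```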

View File

@@ -62,6 +62,8 @@ _REQUIRED_FLEX_VERSION = "2.5.0"
_REQUIRED_XLA_VERSION = "2.2"
_REQUIRED_XFORMERS_VERSION = "0.0.29"
logger = get_logger(__name__) # pylint: disable=invalid-name
_CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION)
_CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()
_CAN_USE_AITER_ATTN = is_aiter_available() and is_aiter_version(">=", _REQUIRED_AITER_VERSION)
@@ -73,8 +75,18 @@ _CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _
if _CAN_USE_FLASH_ATTN:
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward
try:
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward
except (ImportError, OSError, RuntimeError) as e:
# Handle ABI mismatch or other import failures gracefully.
# This can happen when flash_attn was compiled against a different PyTorch version.
logger.warning(f"flash_attn is installed but failed to import: {e}. Falling back to native PyTorch attention.")
_CAN_USE_FLASH_ATTN = False
flash_attn_func = None
flash_attn_varlen_func = None
_wrapped_flash_attn_backward = None
_wrapped_flash_attn_forward = None
else:
flash_attn_func = None
flash_attn_varlen_func = None
@@ -83,26 +95,47 @@ else:
if _CAN_USE_FLASH_ATTN_3:
from flash_attn_interface import flash_attn_func as flash_attn_3_func
from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
try:
from flash_attn_interface import flash_attn_func as flash_attn_3_func
from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
except (ImportError, OSError, RuntimeError) as e:
logger.warning(f"flash_attn_3 failed to import: {e}. Falling back to native attention.")
_CAN_USE_FLASH_ATTN_3 = False
flash_attn_3_func = None
flash_attn_3_varlen_func = None
else:
flash_attn_3_func = None
flash_attn_3_varlen_func = None
if _CAN_USE_AITER_ATTN:
from aiter import flash_attn_func as aiter_flash_attn_func
try:
from aiter import flash_attn_func as aiter_flash_attn_func
except (ImportError, OSError, RuntimeError) as e:
logger.warning(f"aiter failed to import: {e}. Falling back to native attention.")
_CAN_USE_AITER_ATTN = False
aiter_flash_attn_func = None
else:
aiter_flash_attn_func = None
if _CAN_USE_SAGE_ATTN:
from sageattention import (
sageattn,
sageattn_qk_int8_pv_fp8_cuda,
sageattn_qk_int8_pv_fp8_cuda_sm90,
sageattn_qk_int8_pv_fp16_cuda,
sageattn_qk_int8_pv_fp16_triton,
sageattn_varlen,
)
try:
from sageattention import (
sageattn,
sageattn_qk_int8_pv_fp8_cuda,
sageattn_qk_int8_pv_fp8_cuda_sm90,
sageattn_qk_int8_pv_fp16_cuda,
sageattn_qk_int8_pv_fp16_triton,
sageattn_varlen,
)
except (ImportError, OSError, RuntimeError) as e:
logger.warning(f"sageattention failed to import: {e}. Falling back to native attention.")
_CAN_USE_SAGE_ATTN = False
sageattn = None
sageattn_qk_int8_pv_fp8_cuda = None
sageattn_qk_int8_pv_fp8_cuda_sm90 = None
sageattn_qk_int8_pv_fp16_cuda = None
sageattn_qk_int8_pv_fp16_triton = None
sageattn_varlen = None
else:
sageattn = None
sageattn_qk_int8_pv_fp16_cuda = None
@@ -113,26 +146,48 @@ else:
if _CAN_USE_FLEX_ATTN:
# We cannot import the flex_attention function from the package directly because it is expected (from the
# pytorch documentation) that the user may compile it. If we import directly, we will not have access to the
# compiled function.
import torch.nn.attention.flex_attention as flex_attention
try:
# We cannot import the flex_attention function from the package directly because it is expected (from the
# pytorch documentation) that the user may compile it. If we import directly, we will not have access to the
# compiled function.
import torch.nn.attention.flex_attention as flex_attention
except (ImportError, OSError, RuntimeError) as e:
logger.warning(f"flex_attention failed to import: {e}. Falling back to native attention.")
_CAN_USE_FLEX_ATTN = False
flex_attention = None
else:
flex_attention = None
if _CAN_USE_NPU_ATTN:
from torch_npu import npu_fusion_attention
try:
from torch_npu import npu_fusion_attention
except (ImportError, OSError, RuntimeError) as e:
logger.warning(f"torch_npu failed to import: {e}. Falling back to native attention.")
_CAN_USE_NPU_ATTN = False
npu_fusion_attention = None
else:
npu_fusion_attention = None
if _CAN_USE_XLA_ATTN:
from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
try:
from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
except (ImportError, OSError, RuntimeError) as e:
logger.warning(f"torch_xla failed to import: {e}. Falling back to native attention.")
_CAN_USE_XLA_ATTN = False
xla_flash_attention = None
else:
xla_flash_attention = None
if _CAN_USE_XFORMERS_ATTN:
import xformers.ops as xops
try:
import xformers.ops as xops
except (ImportError, OSError, RuntimeError) as e:
logger.warning(f"xformers failed to import: {e}. Falling back to native attention.")
_CAN_USE_XFORMERS_ATTN = False
xops = None
else:
xops = None
@@ -158,8 +213,6 @@ else:
_register_fake = register_fake_no_op
logger = get_logger(__name__) # pylint: disable=invalid-name
# TODO(aryan): Add support for the following:
# - Sage Attention++
# - block sparse, radial and other attention methods
@@ -276,7 +329,11 @@ class _HubKernelConfig:
_HUB_KERNELS_REGISTRY: dict["AttentionBackendName", _HubKernelConfig] = {
# TODO: temporary revision for now. Remove when merged upstream into `main`.
AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs"
repo_id="kernels-community/flash-attn3",
function_attr="flash_attn_func",
revision="fake-ops-return-probs",
wrapped_forward_attr="flash_attn_interface._flash_attn_forward",
wrapped_backward_attr="flash_attn_interface._flash_attn_backward",
),
AttentionBackendName._FLASH_3_VARLEN_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn3",
@@ -676,7 +733,7 @@ def _wrapped_flash_attn_3(
) -> tuple[torch.Tensor, torch.Tensor]:
# Hardcoded for now because pytorch does not support tuple/int type hints
window_size = (-1, -1)
out, lse, *_ = flash_attn_3_func(
result = flash_attn_3_func(
q=q,
k=k,
v=v,
@@ -693,7 +750,9 @@ def _wrapped_flash_attn_3(
pack_gqa=pack_gqa,
deterministic=deterministic,
sm_margin=sm_margin,
return_attn_probs=True,
)
out, lse, *_ = result
lse = lse.permute(0, 2, 1)
return out, lse
@@ -1237,36 +1296,62 @@ def _flash_attention_3_hub_forward_op(
if enable_gqa:
raise ValueError("`enable_gqa` is not yet supported for flash-attn 3 hub kernels.")
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
out = func(
q=query,
k=key,
v=value,
softmax_scale=scale,
config = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB]
wrapped_forward_fn = config.wrapped_forward_fn
if wrapped_forward_fn is None:
raise RuntimeError(
"Flash attention 3 hub kernels must expose `flash_attn_interface._flash_attn_forward` "
"for context parallel execution."
)
if scale is None:
scale = query.shape[-1] ** (-0.5)
out, softmax_lse, *_ = wrapped_forward_fn(
query,
key,
value,
None,
None, # k_new, v_new
None, # qv
None, # out
None,
None,
None, # cu_seqlens_q/k/k_new
None,
None, # seqused_q/k
None,
None, # max_seqlen_q/k
None,
None,
None, # page_table, kv_batch_idx, leftpad_k
None,
None,
None, # rotary_cos/sin, seqlens_rotary
None,
None,
None, # q_descale, k_descale, v_descale
scale,
causal=is_causal,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
window_size=window_size,
window_size_left=window_size[0],
window_size_right=window_size[1],
attention_chunk=0,
softcap=softcap,
num_splits=num_splits,
pack_gqa=pack_gqa,
deterministic=deterministic,
sm_margin=sm_margin,
return_attn_probs=return_lse,
)
lse = None
if return_lse:
out, lse = out
lse = lse.permute(0, 2, 1).contiguous()
lse = softmax_lse.permute(0, 2, 1).contiguous() if return_lse else None
if _save_ctx:
ctx.save_for_backward(query, key, value)
ctx.save_for_backward(query, key, value, out, softmax_lse)
ctx.scale = scale
ctx.is_causal = is_causal
ctx._hub_kernel = func
ctx.window_size = window_size
ctx.softcap = softcap
ctx.deterministic = deterministic
ctx.sm_margin = sm_margin
return (out, lse) if return_lse else out
@@ -1275,54 +1360,49 @@ def _flash_attention_3_hub_backward_op(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
*args,
window_size: tuple[int, int] = (-1, -1),
softcap: float = 0.0,
num_splits: int = 1,
pack_gqa: bool | None = None,
deterministic: bool = False,
sm_margin: int = 0,
**kwargs,
):
query, key, value = ctx.saved_tensors
kernel_fn = ctx._hub_kernel
# NOTE: Unlike the FA2 hub kernel, the FA3 hub kernel does not expose separate wrapped forward/backward
# primitives (no `wrapped_forward_attr`/`wrapped_backward_attr` in its `_HubKernelConfig`). We
# therefore rerun the forward pass under `torch.enable_grad()` and differentiate through it with
# `torch.autograd.grad()`. This is a second forward pass during backward; it can be avoided once
# the FA3 hub exposes a dedicated fused backward kernel (analogous to `_wrapped_flash_attn_backward`
# in the FA2 hub), at which point this can be refactored to match `_flash_attention_hub_backward_op`.
with torch.enable_grad():
query_r = query.detach().requires_grad_(True)
key_r = key.detach().requires_grad_(True)
value_r = value.detach().requires_grad_(True)
out = kernel_fn(
q=query_r,
k=key_r,
v=value_r,
softmax_scale=ctx.scale,
causal=ctx.is_causal,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
window_size=window_size,
softcap=softcap,
num_splits=num_splits,
pack_gqa=pack_gqa,
deterministic=deterministic,
sm_margin=sm_margin,
return_attn_probs=False,
config = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB]
wrapped_backward_fn = config.wrapped_backward_fn
if wrapped_backward_fn is None:
raise RuntimeError(
"Flash attention 3 hub kernels must expose `flash_attn_interface._flash_attn_backward` "
"for context parallel execution."
)
if isinstance(out, tuple):
out = out[0]
grad_query, grad_key, grad_value = torch.autograd.grad(
out,
(query_r, key_r, value_r),
grad_out,
retain_graph=False,
allow_unused=False,
)
query, key, value, out, softmax_lse = ctx.saved_tensors
grad_query = torch.empty_like(query)
grad_key = torch.empty_like(key)
grad_value = torch.empty_like(value)
wrapped_backward_fn(
grad_out,
query,
key,
value,
out,
softmax_lse,
None,
None, # cu_seqlens_q, cu_seqlens_k
None,
None, # seqused_q, seqused_k
None,
None, # max_seqlen_q, max_seqlen_k
grad_query,
grad_key,
grad_value,
ctx.scale,
ctx.is_causal,
ctx.window_size[0],
ctx.window_size[1],
ctx.softcap,
ctx.deterministic,
ctx.sm_margin,
)
grad_query = grad_query[..., : grad_out.shape[-1]]
grad_key = grad_key[..., : grad_out.shape[-1]]
grad_value = grad_value[..., : grad_out.shape[-1]]
return grad_query, grad_key, grad_value
@@ -2623,7 +2703,7 @@ def _flash_varlen_attention_3(
key_packed = torch.cat(key_valid, dim=0)
value_packed = torch.cat(value_valid, dim=0)
out, lse, *_ = flash_attn_3_varlen_func(
result = flash_attn_3_varlen_func(
q=query_packed,
k=key_packed,
v=value_packed,
@@ -2633,7 +2713,13 @@ def _flash_varlen_attention_3(
max_seqlen_k=max_seqlen_k,
softmax_scale=scale,
causal=is_causal,
return_attn_probs=return_lse,
)
if isinstance(result, tuple):
out, lse, *_ = result
else:
out = result
lse = None
out = out.unflatten(0, (batch_size, -1))
return (out, lse) if return_lse else out

View File

@@ -191,7 +191,12 @@ class CosmosControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
dim=1,
)
control_hidden_states = torch.cat([control_hidden_states, torch.zeros_like(controls_latents[:, :1])], dim=1)
if condition_mask is not None:
control_hidden_states = torch.cat([control_hidden_states, condition_mask], dim=1)
else:
control_hidden_states = torch.cat(
[control_hidden_states, torch.zeros_like(controls_latents[:, :1])], dim=1
)
padding_mask_resized = transforms.functional.resize(
padding_mask, list(control_hidden_states.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST

View File

@@ -1633,7 +1633,14 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
blocks_class_name = self.default_blocks_name
if blocks_class_name is not None:
diffusers_module = importlib.import_module("diffusers")
blocks_class = getattr(diffusers_module, blocks_class_name)
blocks_class = getattr(diffusers_module, blocks_class_name, None)
# If the blocks_class is not found or is a base class (e.g. SequentialPipelineBlocks saved by from_blocks_dict) with empty block_classes
# fall back to default_blocks_name
if blocks_class is None or not blocks_class.block_classes:
blocks_class_name = self.default_blocks_name
blocks_class = getattr(diffusers_module, blocks_class_name)
if blocks_class is not None:
blocks = blocks_class()
else:
logger.warning(f"`blocks` is `None`, no default blocks class found for {self.__class__.__name__}")

View File

@@ -17,9 +17,6 @@ from typing import Callable, Dict, List, Optional, Union
import numpy as np
import PIL.Image
import torch
import torchvision
import torchvision.transforms
import torchvision.transforms.functional
from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
@@ -54,11 +51,13 @@ else:
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def _maybe_pad_video(video: torch.Tensor, num_frames: int):
def _maybe_pad_or_trim_video(video: torch.Tensor, num_frames: int):
n_pad_frames = num_frames - video.shape[2]
if n_pad_frames > 0:
last_frame = video[:, :, -1:, :, :]
video = torch.cat((video, last_frame.repeat(1, 1, n_pad_frames, 1, 1)), dim=2)
elif num_frames < video.shape[2]:
video = video[:, :, :num_frames, :, :]
return video
@@ -134,8 +133,8 @@ EXAMPLE_DOC_STRING = """
>>> controls = [Image.fromarray(x.numpy()) for x in controls.permute(1, 2, 3, 0)]
>>> export_to_video(controls, "edge_controlled_video_edge.mp4", fps=30)
>>> # Transfer inference with controls.
>>> video = pipe(
... video=input_video[:num_frames],
... controls=controls,
... controls_conditioning_scale=1.0,
... prompt=prompt,
@@ -149,7 +148,7 @@ EXAMPLE_DOC_STRING = """
class Cosmos2_5_TransferPipeline(DiffusionPipeline):
r"""
Pipeline for Cosmos Transfer2.5 base model.
Pipeline for Cosmos Transfer2.5, supporting auto-regressive inference.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
@@ -166,12 +165,14 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKLWan`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
controlnet ([`CosmosControlNetModel`]):
ControlNet used to condition generation on control inputs.
"""
model_cpu_offload_seq = "text_encoder->transformer->controlnet->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
# We mark safety_checker as optional here to get around some test failures, but it is not really optional
_optional_components = ["safety_checker", "controlnet"]
_optional_components = ["safety_checker"]
_exclude_from_cpu_offload = ["safety_checker"]
def __init__(
@@ -181,8 +182,8 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
transformer: CosmosTransformer3DModel,
vae: AutoencoderKLWan,
scheduler: UniPCMultistepScheduler,
controlnet: Optional[CosmosControlNetModel],
safety_checker: CosmosSafetyChecker = None,
controlnet: CosmosControlNetModel,
safety_checker: Optional[CosmosSafetyChecker] = None,
):
super().__init__()
@@ -384,10 +385,11 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
num_frames_in: int = 93,
num_frames_out: int = 93,
do_classifier_free_guidance: bool = True,
dtype: torch.dtype | None = None,
device: torch.device | None = None,
generator: torch.Generator | list[torch.Generator] | None = None,
latents: torch.Tensor | None = None,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
num_cond_latent_frames: int = 0,
) -> torch.Tensor:
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
@@ -402,10 +404,14 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
W = width // self.vae_scale_factor_spatial
shape = (B, C, T, H, W)
if num_frames_in == 0:
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
if latents is not None:
if latents.shape[1:] != shape[1:]:
raise ValueError(f"Unexpected `latents` shape, got {latents.shape}, expected {shape}.")
latents = latents.to(device=device, dtype=dtype)
else:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
if num_frames_in == 0:
cond_mask = torch.zeros((B, 1, T, H, W), dtype=latents.dtype, device=latents.device)
cond_indicator = torch.zeros((B, 1, T, 1, 1), dtype=latents.dtype, device=latents.device)
@@ -435,16 +441,12 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
latents_std = self.latents_std.to(device=device, dtype=dtype)
cond_latents = (cond_latents - latents_mean) / latents_std
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device=device, dtype=dtype)
padding_shape = (B, 1, T, H, W)
ones_padding = latents.new_ones(padding_shape)
zeros_padding = latents.new_zeros(padding_shape)
cond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1)
cond_indicator = latents.new_zeros(B, 1, latents.size(2), 1, 1)
cond_indicator[:, :, 0:num_cond_latent_frames, :, :] = 1.0
cond_mask = cond_indicator * ones_padding + (1 - cond_indicator) * zeros_padding
return (
@@ -454,34 +456,7 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
cond_indicator,
)
def _encode_controls(
self,
controls: Optional[torch.Tensor],
height: int,
width: int,
num_frames: int,
dtype: torch.dtype,
device: torch.device,
generator: torch.Generator | list[torch.Generator] | None,
) -> Optional[torch.Tensor]:
if controls is None:
return None
control_video = self.video_processor.preprocess_video(controls, height, width)
control_video = _maybe_pad_video(control_video, num_frames)
control_video = control_video.to(device=device, dtype=self.vae.dtype)
control_latents = [
retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator=generator) for vid in control_video
]
control_latents = torch.cat(control_latents, dim=0).to(dtype)
latents_mean = self.latents_mean.to(device=device, dtype=dtype)
latents_std = self.latents_std.to(device=device, dtype=dtype)
control_latents = (control_latents - latents_mean) / latents_std
return control_latents
# Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.check_inputs
# Modified from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.check_inputs
def check_inputs(
self,
prompt,
@@ -489,9 +464,25 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
width,
prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
num_ar_conditional_frames=None,
num_ar_latent_conditional_frames=None,
num_frames_per_chunk=None,
num_frames=None,
conditional_frame_timestep=0.1,
):
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
if width <= 0 or height <= 0 or height % 16 != 0 or width % 16 != 0:
raise ValueError(
f"`height` and `width` have to be divisible by 16 (& positive) but are {height} and {width}."
)
if num_frames is not None and num_frames <= 0:
raise ValueError(f"`num_frames` has to be a positive integer when provided but is {num_frames}.")
if conditional_frame_timestep < 0 or conditional_frame_timestep > 1:
raise ValueError(
"`conditional_frame_timestep` has to be a float in the [0, 1] interval but is "
f"{conditional_frame_timestep}."
)
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
@@ -512,6 +503,46 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if num_ar_latent_conditional_frames is not None and num_ar_conditional_frames is not None:
raise ValueError(
"Provide only one of `num_ar_conditional_frames` or `num_ar_latent_conditional_frames`, not both."
)
if num_ar_latent_conditional_frames is None and num_ar_conditional_frames is None:
raise ValueError("Provide either `num_ar_conditional_frames` or `num_ar_latent_conditional_frames`.")
if num_ar_latent_conditional_frames is not None and num_ar_latent_conditional_frames < 0:
raise ValueError("`num_ar_latent_conditional_frames` must be >= 0.")
if num_ar_conditional_frames is not None and num_ar_conditional_frames < 0:
raise ValueError("`num_ar_conditional_frames` must be >= 0.")
if num_ar_latent_conditional_frames is not None:
num_ar_conditional_frames = max(
0, (num_ar_latent_conditional_frames - 1) * self.vae_scale_factor_temporal + 1
)
min_chunk_len = self.vae_scale_factor_temporal + 1
if num_frames_per_chunk < min_chunk_len:
logger.warning(f"{num_frames_per_chunk=} must be larger than {min_chunk_len=}, setting to min_chunk_len")
num_frames_per_chunk = min_chunk_len
max_frames_by_rope = None
if getattr(self.transformer.config, "max_size", None) is not None:
max_frames_by_rope = max(
size // patch
for size, patch in zip(self.transformer.config.max_size, self.transformer.config.patch_size)
)
if num_frames_per_chunk > max_frames_by_rope:
raise ValueError(
f"{num_frames_per_chunk=} is too large for RoPE setting ({max_frames_by_rope=}). "
"Please reduce `num_frames_per_chunk`."
)
if num_ar_conditional_frames >= num_frames_per_chunk:
raise ValueError(
f"{num_ar_conditional_frames=} must be smaller than {num_frames_per_chunk=} for chunked generation."
)
return num_frames_per_chunk
@property
def guidance_scale(self):
return self._guidance_scale
@@ -536,23 +567,22 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
image: PipelineImageInput | None = None,
video: List[PipelineImageInput] | None = None,
controls: PipelineImageInput | List[PipelineImageInput],
controls_conditioning_scale: Union[float, List[float]] = 1.0,
prompt: Union[str, List[str]] | None = None,
negative_prompt: Union[str, List[str]] = DEFAULT_NEGATIVE_PROMPT,
height: int = 704,
width: int | None = None,
num_frames: int = 93,
width: Optional[int] = None,
num_frames: Optional[int] = None,
num_frames_per_chunk: int = 93,
num_inference_steps: int = 36,
guidance_scale: float = 3.0,
num_videos_per_prompt: Optional[int] = 1,
generator: torch.Generator | list[torch.Generator] | None = None,
latents: torch.Tensor | None = None,
controls: Optional[PipelineImageInput | List[PipelineImageInput]] = None,
controls_conditioning_scale: float | list[float] = 1.0,
prompt_embeds: torch.Tensor | None = None,
negative_prompt_embeds: torch.Tensor | None = None,
output_type: str = "pil",
num_videos_per_prompt: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
@@ -560,24 +590,26 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
conditional_frame_timestep: float = 0.1,
num_ar_conditional_frames: Optional[int] = 1,
num_ar_latent_conditional_frames: Optional[int] = None,
):
r"""
The call function to the pipeline for generation. Supports three modes:
`controls` drive the conditioning through ControlNet. Controls are assumed to be pre-processed, e.g. edge maps
are pre-computed.
- **Text2World**: `image=None`, `video=None`, `prompt` provided. Generates a world clip.
- **Image2World**: `image` provided, `video=None`, `prompt` provided. Conditions on a single frame.
- **Video2World**: `video` provided, `image=None`, `prompt` provided. Conditions on an input clip.
Setting `num_frames` will restrict the total number of frames output, if not provided or assigned to None
(default) then the number of output frames will match the input `controls`.
Set `num_frames=93` (default) to produce a world video, or `num_frames=1` to produce a single image frame (the
above in "*2Image mode").
Outputs follow `output_type` (e.g., `"pil"` returns a list of `num_frames` PIL images per prompt).
Auto-regressive inference is supported and thus a sliding window of `num_frames_per_chunk` frames is used per
denoising loop. In addition, when auto-regressive inference is performed, the previous
`num_ar_latent_conditional_frames` or `num_ar_conditional_frames` are used to condition the following denoising
inference loops.
Args:
image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, *optional*):
Optional single image for Image2World conditioning. Must be `None` when `video` is provided.
video (`List[PIL.Image.Image]`, `np.ndarray`, `torch.Tensor`, *optional*):
Optional input video for Video2World conditioning. Must be `None` when `image` is provided.
controls (`PipelineImageInput`, `List[PipelineImageInput]`):
Control image or video input used by the ControlNet.
controls_conditioning_scale (`float` or `List[float]`, *optional*, defaults to `1.0`):
The scale factor(s) for the ControlNet outputs. A single float is broadcast to all control blocks.
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide generation. Required unless `prompt_embeds` is supplied.
height (`int`, defaults to `704`):
@@ -585,9 +617,10 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
width (`int`, *optional*):
The width in pixels of the generated image. If not provided, this will be determined based on the
aspect ratio of the input and the provided height.
num_frames (`int`, defaults to `93`):
Number of output frames. Use `93` for world (video) generation; set to `1` to return a single frame.
num_inference_steps (`int`, defaults to `35`):
num_frames (`int`, *optional*):
Number of output frames. Defaults to `None` to output the same number of frames as the input
`controls`.
num_inference_steps (`int`, defaults to `36`):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, defaults to `3.0`):
@@ -601,13 +634,9 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor is generated by sampling using the supplied random `generator`.
controls (`PipelineImageInput`, `List[PipelineImageInput]`, *optional*):
Control image or video input used by the ControlNet. If `None`, ControlNet is skipped.
controls_conditioning_scale (`float` or `List[float]`, *optional*, defaults to `1.0`):
The scale factor(s) for the ControlNet outputs. A single float is broadcast to all control blocks.
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs. Can be used to
tweak the same generation with different prompts. If not provided, a latents tensor is generated by
sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -630,7 +659,18 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
max_sequence_length (`int`, defaults to `512`):
The maximum number of tokens in the prompt. If the prompt exceeds this length, it will be truncated. If
the prompt is shorter than this length, it will be padded.
num_ar_conditional_frames (`int`, *optional*, defaults to `1`):
Number of frames to condition on subsequent inference loops in auto-regressive inference, i.e. for the
second chunk and onwards. Only used if `num_ar_latent_conditional_frames` is `None`.
This is only used when auto-regressive inference is performed, i.e. when the number of frames in
controls is > num_frames_per_chunk
num_ar_latent_conditional_frames (`int`, *optional*):
Number of latent frames to condition on subsequent inference loops in auto-regressive inference, i.e.
for the second chunk and onwards. Only used if `num_ar_conditional_frames` is `None`.
This is only used when auto-regressive inference is performed, i.e. when the number of frames in
controls is > num_frames_per_chunk
Examples:
Returns:
@@ -650,21 +690,40 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
if width is None:
frame = image or video[0] if image or video else None
if frame is None and controls is not None:
frame = controls[0] if isinstance(controls, list) else controls
if isinstance(frame, (torch.Tensor, np.ndarray)) and len(frame.shape) == 4:
frame = controls[0]
frame = controls[0] if isinstance(controls, list) else controls
if isinstance(frame, list):
frame = frame[0]
if isinstance(frame, (torch.Tensor, np.ndarray)):
if frame.ndim == 5:
frame = frame[0, 0]
elif frame.ndim == 4:
frame = frame[0]
if frame is None:
width = int((height + 16) * (1280 / 720))
elif isinstance(frame, PIL.Image.Image):
if isinstance(frame, PIL.Image.Image):
width = int((height + 16) * (frame.width / frame.height))
else:
if frame.ndim != 3:
raise ValueError("`controls` must contain 3D frames in CHW format.")
width = int((height + 16) * (frame.shape[2] / frame.shape[1])) # NOTE: assuming C H W
# Check inputs. Raise error if not correct
self.check_inputs(prompt, height, width, prompt_embeds, callback_on_step_end_tensor_inputs)
num_frames_per_chunk = self.check_inputs(
prompt,
height,
width,
prompt_embeds,
callback_on_step_end_tensor_inputs,
num_ar_conditional_frames,
num_ar_latent_conditional_frames,
num_frames_per_chunk,
num_frames,
conditional_frame_timestep,
)
if num_ar_latent_conditional_frames is not None:
num_cond_latent_frames = num_ar_latent_conditional_frames
num_ar_conditional_frames = max(0, (num_cond_latent_frames - 1) * self.vae_scale_factor_temporal + 1)
else:
num_cond_latent_frames = max(0, (num_ar_conditional_frames - 1) // self.vae_scale_factor_temporal + 1)
self._guidance_scale = guidance_scale
self._current_timestep = None
@@ -709,102 +768,137 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
vae_dtype = self.vae.dtype
transformer_dtype = self.transformer.dtype
img_context = torch.zeros(
batch_size,
self.transformer.config.img_context_num_tokens,
self.transformer.config.img_context_dim_in,
device=prompt_embeds.device,
dtype=transformer_dtype,
)
encoder_hidden_states = (prompt_embeds, img_context)
neg_encoder_hidden_states = (negative_prompt_embeds, img_context)
num_frames_in = None
if image is not None:
if batch_size != 1:
raise ValueError(f"batch_size must be 1 for image input (given {batch_size})")
image = torchvision.transforms.functional.to_tensor(image).unsqueeze(0)
video = torch.cat([image, torch.zeros_like(image).repeat(num_frames - 1, 1, 1, 1)], dim=0)
video = video.unsqueeze(0)
num_frames_in = 1
elif video is None:
video = torch.zeros(batch_size, num_frames, 3, height, width, dtype=torch.uint8)
num_frames_in = 0
else:
num_frames_in = len(video)
if batch_size != 1:
raise ValueError(f"batch_size must be 1 for video input (given {batch_size})")
assert video is not None
video = self.video_processor.preprocess_video(video, height, width)
# pad with last frame (for video2world)
num_frames_out = num_frames
video = _maybe_pad_video(video, num_frames_out)
assert num_frames_in <= num_frames_out, f"expected ({num_frames_in=}) <= ({num_frames_out=})"
video = video.to(device=device, dtype=vae_dtype)
num_channels_latents = self.transformer.config.in_channels - 1
latents, cond_latent, cond_mask, cond_indicator = self.prepare_latents(
video=video,
batch_size=batch_size * num_videos_per_prompt,
num_channels_latents=num_channels_latents,
height=height,
width=width,
num_frames_in=num_frames_in,
num_frames_out=num_frames,
do_classifier_free_guidance=self.do_classifier_free_guidance,
dtype=torch.float32,
device=device,
generator=generator,
latents=latents,
)
cond_timestep = torch.ones_like(cond_indicator) * conditional_frame_timestep
cond_mask = cond_mask.to(transformer_dtype)
controls_latents = None
if controls is not None:
controls_latents = self._encode_controls(
controls,
height=height,
width=width,
num_frames=num_frames,
if getattr(self.transformer.config, "img_context_dim_in", None):
img_context = torch.zeros(
batch_size,
self.transformer.config.img_context_num_tokens,
self.transformer.config.img_context_dim_in,
device=prompt_embeds.device,
dtype=transformer_dtype,
device=device,
generator=generator,
)
padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype)
if num_videos_per_prompt > 1:
img_context = img_context.repeat_interleave(num_videos_per_prompt, dim=0)
# Denoising loop
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
self._num_timesteps = len(timesteps)
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
encoder_hidden_states = (prompt_embeds, img_context)
neg_encoder_hidden_states = (negative_prompt_embeds, img_context)
else:
encoder_hidden_states = prompt_embeds
neg_encoder_hidden_states = negative_prompt_embeds
gt_velocity = (latents - cond_latent) * cond_mask
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t.cpu().item()
# NOTE: assumes sigma(t) \in [0, 1]
sigma_t = (
torch.tensor(self.scheduler.sigmas[i].item())
.unsqueeze(0)
.to(device=device, dtype=transformer_dtype)
control_video = self.video_processor.preprocess_video(controls, height, width)
if control_video.shape[0] != batch_size:
if control_video.shape[0] == 1:
control_video = control_video.repeat(batch_size, 1, 1, 1, 1)
else:
raise ValueError(
f"Expected controls batch size {batch_size} to match prompt batch size, but got {control_video.shape[0]}."
)
in_latents = cond_mask * cond_latent + (1 - cond_mask) * latents
in_latents = in_latents.to(transformer_dtype)
in_timestep = cond_indicator * cond_timestep + (1 - cond_indicator) * sigma_t
control_blocks = None
if controls_latents is not None and self.controlnet is not None:
num_frames_out = control_video.shape[2]
if num_frames is not None:
num_frames_out = min(num_frames_out, num_frames)
control_video = _maybe_pad_or_trim_video(control_video, num_frames_out)
# chunk information
num_latent_frames_per_chunk = (num_frames_per_chunk - 1) // self.vae_scale_factor_temporal + 1
chunk_stride = num_frames_per_chunk - num_ar_conditional_frames
chunk_idxs = [
(start_idx, min(start_idx + num_frames_per_chunk, num_frames_out))
for start_idx in range(0, num_frames_out - num_ar_conditional_frames, chunk_stride)
]
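# Editor's worked example (mirrors the new multi-chunk fast test later in this diff, not
# pipeline code): with num_frames_out = 5, num_frames_per_chunk = 3 and
# num_ar_conditional_frames = 1, chunk_stride = 2 and the chunks are (0, 3) and (2, 5),
# overlapping by the single frame that conditions the second chunk.
_example_chunks = [(s, min(s + 3, 5)) for s in range(0, 5 - 1, 2)]
assert _example_chunks == [(0, 3), (2, 5)]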
video_chunks = []
latents_mean = self.latents_mean.to(dtype=vae_dtype, device=device)
latents_std = self.latents_std.to(dtype=vae_dtype, device=device)
def decode_latents(latents):
latents = latents * latents_std + latents_mean
video = self.vae.decode(latents.to(dtype=self.vae.dtype, device=device), return_dict=False)[0]
return video
latents_arg = latents
initial_num_cond_latent_frames = 0
latent_chunks = []
num_chunks = len(chunk_idxs)
total_steps = num_inference_steps * num_chunks
with self.progress_bar(total=total_steps) as progress_bar:
for chunk_idx, (start_idx, end_idx) in enumerate(chunk_idxs):
if chunk_idx == 0:
prev_output = torch.zeros((batch_size, num_frames_per_chunk, 3, height, width), dtype=vae_dtype)
prev_output = self.video_processor.preprocess_video(prev_output, height, width)
else:
prev_output = video_chunks[-1].clone()
if num_ar_conditional_frames > 0:
prev_output[:, :, :num_ar_conditional_frames] = prev_output[:, :, -num_ar_conditional_frames:]
prev_output[:, :, num_ar_conditional_frames:] = -1 # -1 == 0 in processed video space
else:
prev_output.fill_(-1)
chunk_video = prev_output.to(device=device, dtype=vae_dtype)
chunk_video = _maybe_pad_or_trim_video(chunk_video, num_frames_per_chunk)
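# Editor's sketch (not pipeline code): with num_ar_conditional_frames = 1, the last decoded
# frame of the previous chunk is copied to position 0 and the remaining frames are blanked to
# -1 (black in the [-1, 1] preprocessed range), so only that frame conditions the next chunk.
_prev = torch.arange(3.0).view(1, 1, 3, 1, 1).repeat(1, 3, 1, 2, 2)  # frames valued 0, 1, 2
_prev[:, :, :1] = _prev[:, :, -1:]
_prev[:, :, 1:] = -1
assert _prev[0, 0, 0, 0, 0].item() == 2.0 and _prev[0, 0, 1, 0, 0].item() == -1.0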
latents, cond_latent, cond_mask, cond_indicator = self.prepare_latents(
video=chunk_video,
batch_size=batch_size * num_videos_per_prompt,
num_channels_latents=self.transformer.config.in_channels - 1,
height=height,
width=width,
num_frames_in=chunk_video.shape[2],
num_frames_out=num_frames_per_chunk,
do_classifier_free_guidance=self.do_classifier_free_guidance,
dtype=torch.float32,
device=device,
generator=generator,
num_cond_latent_frames=initial_num_cond_latent_frames
if chunk_idx == 0
else num_cond_latent_frames,
latents=latents_arg,
)
cond_mask = cond_mask.to(transformer_dtype)
cond_timestep = torch.ones_like(cond_indicator) * conditional_frame_timestep
padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype)
chunk_control_video = control_video[:, :, start_idx:end_idx, ...].to(
device=device, dtype=self.vae.dtype
)
chunk_control_video = _maybe_pad_or_trim_video(chunk_control_video, num_frames_per_chunk)
if isinstance(generator, list):
controls_latents = [
retrieve_latents(self.vae.encode(chunk_control_video[i].unsqueeze(0)), generator=generator[i])
for i in range(chunk_control_video.shape[0])
]
else:
controls_latents = [
retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator=generator)
for vid in chunk_control_video
]
controls_latents = torch.cat(controls_latents, dim=0).to(transformer_dtype)
controls_latents = (controls_latents - latents_mean) / latents_std
# Denoising loop
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
self._num_timesteps = len(timesteps)
gt_velocity = (latents - cond_latent) * cond_mask
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t.cpu().item()
# NOTE: assumes sigma(t) \in [0, 1]
sigma_t = (
torch.tensor(self.scheduler.sigmas[i].item())
.unsqueeze(0)
.to(device=device, dtype=transformer_dtype)
)
in_latents = cond_mask * cond_latent + (1 - cond_mask) * latents
in_latents = in_latents.to(transformer_dtype)
in_timestep = cond_indicator * cond_timestep + (1 - cond_indicator) * sigma_t
control_output = self.controlnet(
controls_latents=controls_latents,
latents=in_latents,
@@ -817,20 +911,18 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
)
control_blocks = control_output[0]
noise_pred = self.transformer(
hidden_states=in_latents,
timestep=in_timestep,
encoder_hidden_states=encoder_hidden_states,
block_controlnet_hidden_states=control_blocks,
condition_mask=cond_mask,
padding_mask=padding_mask,
return_dict=False,
)[0]
noise_pred = gt_velocity + noise_pred * (1 - cond_mask)
noise_pred = self.transformer(
hidden_states=in_latents,
timestep=in_timestep,
encoder_hidden_states=encoder_hidden_states,
block_controlnet_hidden_states=control_blocks,
condition_mask=cond_mask,
padding_mask=padding_mask,
return_dict=False,
)[0]
noise_pred = gt_velocity + noise_pred * (1 - cond_mask)
if self.do_classifier_free_guidance:
control_blocks = None
if controls_latents is not None and self.controlnet is not None:
if self.do_classifier_free_guidance:
control_output = self.controlnet(
controls_latents=controls_latents,
latents=in_latents,
@@ -843,46 +935,50 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
)
control_blocks = control_output[0]
noise_pred_neg = self.transformer(
hidden_states=in_latents,
timestep=in_timestep,
encoder_hidden_states=neg_encoder_hidden_states, # NOTE: negative prompt
block_controlnet_hidden_states=control_blocks,
condition_mask=cond_mask,
padding_mask=padding_mask,
return_dict=False,
)[0]
# NOTE: replace velocity (noise_pred_neg) with gt_velocity for conditioning inputs only
noise_pred_neg = gt_velocity + noise_pred_neg * (1 - cond_mask)
noise_pred = noise_pred + self.guidance_scale * (noise_pred - noise_pred_neg)
noise_pred_neg = self.transformer(
hidden_states=in_latents,
timestep=in_timestep,
encoder_hidden_states=neg_encoder_hidden_states, # NOTE: negative prompt
block_controlnet_hidden_states=control_blocks,
condition_mask=cond_mask,
padding_mask=padding_mask,
return_dict=False,
)[0]
# NOTE: replace velocity (noise_pred_neg) with gt_velocity for conditioning inputs only
noise_pred_neg = gt_velocity + noise_pred_neg * (1 - cond_mask)
noise_pred = noise_pred + self.guidance_scale * (noise_pred - noise_pred_neg)
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
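# Editor's note (illustrative, not pipeline code): with guidance scale g, the guidance update
# above is noise_pred + g * (noise_pred - noise_pred_neg) = (1 + g) * cond - g * neg, i.e. it
# extrapolates away from the negative-prompt velocity rather than starting from an
# unconditional prediction.
_g, _cond, _neg = 3.0, 1.0, 0.25
assert _cond + _g * (_cond - _neg) == (1 + _g) * _cond - _g * _neg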
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
# call the callback, if provided
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if i == total_steps - 1 or ((i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
if XLA_AVAILABLE:
xm.mark_step()
video_chunks.append(decode_latents(latents).detach().cpu())
latent_chunks.append(latents.detach().cpu())
self._current_timestep = None
if not output_type == "latent":
latents_mean = self.latents_mean.to(latents.device, latents.dtype)
latents_std = self.latents_std.to(latents.device, latents.dtype)
latents = latents * latents_std + latents_mean
video = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0]
video = self._match_num_frames(video, num_frames)
video_chunks = [
chunk[:, :, num_ar_conditional_frames:, ...] if chunk_idx != 0 else chunk
for chunk_idx, chunk in enumerate(video_chunks)
]
video = torch.cat(video_chunks, dim=2)
video = video[:, :, :num_frames_out, ...]
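# Editor's sketch continuing the chunking example above (not pipeline code): two decoded
# chunks of 3 frames, overlapping by 1 conditioning frame, stitch back into 5 output frames.
_chunk0 = torch.zeros(1, 3, 3, 8, 8)  # (B, C, T, H, W)
_chunk1 = torch.zeros(1, 3, 3, 8, 8)
_stitched = torch.cat([_chunk0, _chunk1[:, :, 1:, ...]], dim=2)
assert _stitched.shape[2] == 5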
assert self.safety_checker is not None
self.safety_checker.to(device)
@@ -899,7 +995,13 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
video = torch.from_numpy(video).permute(0, 4, 1, 2, 3)
video = self.video_processor.postprocess_video(video, output_type=output_type)
else:
video = latents
latent_T = (num_frames_out - 1) // self.vae_scale_factor_temporal + 1
latent_chunks = [
chunk[:, :, num_cond_latent_frames:, ...] if chunk_idx != 0 else chunk
for chunk_idx, chunk in enumerate(latent_chunks)
]
video = torch.cat(latent_chunks, dim=2)
video = video[:, :, :latent_T, ...]
# Offload all models
self.maybe_free_model_hooks()
@@ -908,19 +1010,3 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
return (video,)
return CosmosPipelineOutput(frames=video)
def _match_num_frames(self, video: torch.Tensor, target_num_frames: int) -> torch.Tensor:
if target_num_frames <= 0 or video.shape[2] == target_num_frames:
return video
frames_per_latent = max(self.vae_scale_factor_temporal, 1)
video = torch.repeat_interleave(video, repeats=frames_per_latent, dim=2)
current_frames = video.shape[2]
if current_frames < target_num_frames:
pad = video[:, :, -1:, :, :].repeat(1, 1, target_num_frames - current_frames, 1, 1)
video = torch.cat([video, pad], dim=2)
elif current_frames > target_num_frames:
video = video[:, :, :target_num_frames]
return video

View File

@@ -699,9 +699,13 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
mask_shape = (batch_size, 1, num_frames, height, width)
if latents is not None:
conditioning_mask = latents.new_zeros(mask_shape)
conditioning_mask[:, :, 0] = 1.0
if latents.ndim == 5:
# conditioning_mask needs to be the same shape as latents in two-stage generation.
batch_size, _, num_frames, height, width = latents.shape
mask_shape = (batch_size, 1, num_frames, height, width)
conditioning_mask = latents.new_zeros(mask_shape)
conditioning_mask[:, :, 0] = 1.0
latents = self._normalize_latents(
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
)
@@ -710,6 +714,9 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
latents = self._pack_latents(
latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
)
else:
conditioning_mask = latents.new_zeros(mask_shape)
conditioning_mask[:, :, 0] = 1.0
conditioning_mask = self._pack_latents(
conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
).squeeze(-1)
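# Editor's sketch of the fix (hypothetical shapes, not pipeline code): after the latent
# upsampler, second-stage latents may be e.g. (1, 4, 3, 32, 32) while a mask built from the
# originally computed height/width would be (1, 1, 3, 16, 16); deriving mask_shape from
# latents.shape keeps the two aligned before packing.
_latents = torch.zeros(1, 4, 3, 32, 32)
_b, _, _t, _h, _w = _latents.shape
_mask = _latents.new_zeros((_b, 1, _t, _h, _w))
_mask[:, :, 0] = 1.0
assert _mask.shape[2:] == _latents.shape[2:]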

View File

@@ -276,7 +276,7 @@ class ZImagePipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMix
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
return self._guidance_scale > 0
@property
def joint_attention_kwargs(self):

View File

@@ -31,14 +31,18 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
Args:
num_train_timesteps (`int`, defaults to 1000):
The number of diffusion steps to train the model.
trained_betas (`np.ndarray`, *optional*):
trained_betas (`np.ndarray` or `List[float]`, *optional*):
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
"""
order = 1
@register_to_config
def __init__(self, num_train_timesteps: int = 1000, trained_betas: np.ndarray | list[float] | None = None):
def __init__(
self,
num_train_timesteps: int = 1000,
trained_betas: np.ndarray | list[float] | None = None,
):
# set `betas`, `alphas`, `timesteps`
self.set_timesteps(num_train_timesteps)
@@ -56,21 +60,29 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
self._begin_index = None
@property
def step_index(self):
def step_index(self) -> int | None:
"""
The index counter for the current timestep. It will increase by 1 after each scheduler step.
Returns:
`int` or `None`:
The index counter for the current timestep.
"""
return self._step_index
@property
def begin_index(self):
def begin_index(self) -> int | None:
"""
The index for the first timestep. It should be set from the pipeline with the `set_begin_index` method.
Returns:
`int` or `None`:
The index for the first timestep.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
def set_begin_index(self, begin_index: int = 0) -> None:
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
@@ -169,7 +181,7 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
Args:
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`int`):
timestep (`int` or `torch.Tensor`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
@@ -228,7 +240,30 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
"""
return sample
def _get_prev_sample(self, sample, timestep_index, prev_timestep_index, ets):
def _get_prev_sample(
self,
sample: torch.Tensor,
timestep_index: int,
prev_timestep_index: int,
ets: torch.Tensor,
) -> torch.Tensor:
"""
Predicts the previous sample based on the current sample, timestep indices, and running model outputs.
Args:
sample (`torch.Tensor`):
The current sample.
timestep_index (`int`):
Index of the current timestep in the schedule.
prev_timestep_index (`int`):
Index of the previous timestep in the schedule.
ets (`torch.Tensor`):
The running sequence of model outputs.
Returns:
`torch.Tensor`:
The predicted previous sample.
"""
alpha = self.alphas[timestep_index]
sigma = self.betas[timestep_index]
@@ -240,5 +275,5 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
return prev_sample
def __len__(self):
def __len__(self) -> int:
return self.config.num_train_timesteps

View File

@@ -299,7 +299,10 @@ def get_cached_module_file(
# Download and cache module_file from the repo `pretrained_model_name_or_path` or grab it if it's a local file.
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file)
if subfolder is not None:
module_file_or_url = os.path.join(pretrained_model_name_or_path, subfolder, module_file)
else:
module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file)
if os.path.isfile(module_file_or_url):
resolved_module_file = module_file_or_url
@@ -384,7 +387,11 @@ def get_cached_module_file(
if not os.path.exists(submodule_path / module_folder):
os.makedirs(submodule_path / module_folder)
module_needed = f"{module_needed}.py"
shutil.copyfile(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed)
if subfolder is not None:
source_path = os.path.join(pretrained_model_name_or_path, subfolder, module_needed)
else:
source_path = os.path.join(pretrained_model_name_or_path, module_needed)
shutil.copyfile(source_path, submodule_path / module_needed)
else:
# Get the commit hash
# TODO: we will get this info in the etag soon, so retrieve it from there and not here.

View File

@@ -131,6 +131,26 @@ class CosmosControlNetModelTests(ModelTesterMixin, unittest.TestCase):
self.assertIsInstance(output[0], list)
self.assertEqual(len(output[0]), init_dict["n_controlnet_blocks"])
def test_condition_mask_changes_output(self):
"""Test that condition mask affects control outputs."""
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
inputs_no_mask = dict(inputs_dict)
inputs_no_mask["condition_mask"] = torch.zeros_like(inputs_dict["condition_mask"])
with torch.no_grad():
output_no_mask = model(**inputs_no_mask)
output_with_mask = model(**inputs_dict)
self.assertEqual(len(output_no_mask.control_block_samples), len(output_with_mask.control_block_samples))
for no_mask_tensor, with_mask_tensor in zip(
output_no_mask.control_block_samples, output_with_mask.control_block_samples
):
self.assertFalse(torch.allclose(no_mask_tensor, with_mask_tensor))
def test_conditioning_scale_single(self):
"""Test that a single conditioning scale is broadcast to all blocks."""
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

View File

@@ -1,9 +1,15 @@
import json
import os
import tempfile
import unittest
from unittest.mock import MagicMock, patch
import torch
from transformers import CLIPTextModel, LongformerModel
from diffusers import ConfigMixin
from diffusers.models import AutoModel, UNet2DConditionModel
from diffusers.models.modeling_utils import ModelMixin
class TestAutoModel(unittest.TestCase):
@@ -35,6 +41,45 @@ class TestAutoModel(unittest.TestCase):
)
assert isinstance(model, CLIPTextModel)
def test_load_dynamic_module_from_local_path_with_subfolder(self):
CUSTOM_MODEL_CODE = (
"import torch\n"
"from diffusers import ModelMixin, ConfigMixin\n"
"from diffusers.configuration_utils import register_to_config\n"
"\n"
"class CustomModel(ModelMixin, ConfigMixin):\n"
" @register_to_config\n"
" def __init__(self, hidden_size=8):\n"
" super().__init__()\n"
" self.linear = torch.nn.Linear(hidden_size, hidden_size)\n"
"\n"
" def forward(self, x):\n"
" return self.linear(x)\n"
)
with tempfile.TemporaryDirectory() as tmpdir:
subfolder = "custom_model"
model_dir = os.path.join(tmpdir, subfolder)
os.makedirs(model_dir)
with open(os.path.join(model_dir, "modeling.py"), "w") as f:
f.write(CUSTOM_MODEL_CODE)
config = {
"_class_name": "CustomModel",
"_diffusers_version": "0.0.0",
"auto_map": {"AutoModel": "modeling.CustomModel"},
"hidden_size": 8,
}
with open(os.path.join(model_dir, "config.json"), "w") as f:
json.dump(config, f)
torch.save({}, os.path.join(model_dir, "diffusion_pytorch_model.bin"))
model = AutoModel.from_pretrained(tmpdir, subfolder=subfolder, trust_remote_code=True)
assert model.__class__.__name__ == "CustomModel"
assert model.config["hidden_size"] == 8
class TestAutoModelFromConfig(unittest.TestCase):
@patch(
@@ -100,3 +145,51 @@ class TestAutoModelFromConfig(unittest.TestCase):
def test_from_config_raises_on_none(self):
with self.assertRaises(ValueError, msg="Please provide a `pretrained_model_name_or_path_or_dict`"):
AutoModel.from_config(None)
class TestRegisterForAutoClass(unittest.TestCase):
def test_register_for_auto_class_sets_attribute(self):
class DummyModel(ModelMixin, ConfigMixin):
config_name = "config.json"
DummyModel.register_for_auto_class("AutoModel")
self.assertEqual(DummyModel._auto_class, "AutoModel")
def test_register_for_auto_class_rejects_unsupported(self):
class DummyModel(ModelMixin, ConfigMixin):
config_name = "config.json"
with self.assertRaises(ValueError, msg="Only 'AutoModel' is supported"):
DummyModel.register_for_auto_class("AutoPipeline")
def test_auto_map_in_saved_config(self):
class DummyModel(ModelMixin, ConfigMixin):
config_name = "config.json"
DummyModel.register_for_auto_class("AutoModel")
model = DummyModel()
with tempfile.TemporaryDirectory() as tmpdir:
model.save_config(tmpdir)
config_path = os.path.join(tmpdir, "config.json")
with open(config_path, "r") as f:
config = json.load(f)
self.assertIn("auto_map", config)
self.assertIn("AutoModel", config["auto_map"])
module_name = DummyModel.__module__.split(".")[-1]
self.assertEqual(config["auto_map"]["AutoModel"], f"{module_name}.DummyModel")
def test_no_auto_map_without_register(self):
class DummyModel(ModelMixin, ConfigMixin):
config_name = "config.json"
model = DummyModel()
with tempfile.TemporaryDirectory() as tmpdir:
model.save_config(tmpdir)
config_path = os.path.join(tmpdir, "config.json")
with open(config_path, "r") as f:
config = json.load(f)
self.assertNotIn("auto_map", config)

View File

@@ -1,4 +1,6 @@
import gc
import json
import os
import tempfile
from typing import Callable
@@ -349,6 +351,33 @@ class ModularPipelineTesterMixin:
assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
def test_modular_index_consistency(self):
pipe = self.get_pipeline()
components_spec = pipe._component_specs
components = sorted(components_spec.keys())
with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir)
index_file = os.path.join(tmpdir, "modular_model_index.json")
assert os.path.exists(index_file)
with open(index_file) as f:
index_contents = json.load(f)
compulsory_keys = {"_blocks_class_name", "_class_name", "_diffusers_version"}
for k in compulsory_keys:
assert k in index_contents
to_check_attrs = {"pretrained_model_name_or_path", "revision", "subfolder"}
for component in components:
spec = components_spec[component]
if getattr(spec, "pretrained_model_name_or_path", None) is not None:
assert component in index_contents, f"{component} should be present in index but isn't."
for attr in to_check_attrs:
attr_value_from_index = index_contents[component][2][attr]
assert getattr(spec, attr) == attr_value_from_index
def test_workflow_map(self):
blocks = self.pipeline_blocks_class()
if blocks._workflow_map is None:
@@ -699,3 +728,27 @@ class TestLoadComponentsSkipBehavior:
# Verify test_component was not loaded
assert not hasattr(pipe, "test_component") or pipe.test_component is None
class TestModularPipelineInitFallback:
"""Test that ModularPipeline.__init__ falls back to default_blocks_name when
_blocks_class_name is a base class (e.g. SequentialPipelineBlocks saved by from_blocks_dict)."""
def test_init_fallback_when_blocks_class_name_is_base_class(self, tmp_path):
# 1. Load pipeline and get a workflow (returns a base SequentialPipelineBlocks)
pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe")
t2i_blocks = pipe.blocks.get_workflow("text2image")
assert t2i_blocks.__class__.__name__ == "SequentialPipelineBlocks"
# 2. Use init_pipeline to create a new pipeline from the workflow blocks
t2i_pipe = t2i_blocks.init_pipeline("hf-internal-testing/tiny-stable-diffusion-xl-pipe")
# 3. Save and reload — the saved config will have _blocks_class_name="SequentialPipelineBlocks"
save_dir = str(tmp_path / "pipeline")
t2i_pipe.save_pretrained(save_dir)
loaded_pipe = ModularPipeline.from_pretrained(save_dir)
# 4. Verify it fell back to default_blocks_name and has correct blocks
assert loaded_pipe.__class__.__name__ == pipe.__class__.__name__
assert loaded_pipe._blocks.__class__.__name__ == pipe._blocks.__class__.__name__
assert len(loaded_pipe._blocks.sub_blocks) == len(pipe._blocks.sub_blocks)

View File

@@ -55,7 +55,7 @@ class Cosmos2_5_TransferWrapper(Cosmos2_5_TransferPipeline):
class Cosmos2_5_TransferPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = Cosmos2_5_TransferWrapper
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"controls"})
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
required_optional_params = frozenset(
@@ -176,15 +176,19 @@ class Cosmos2_5_TransferPipelineFastTests(PipelineTesterMixin, unittest.TestCase
else:
generator = torch.Generator(device=device).manual_seed(seed)
controls_generator = torch.Generator(device="cpu").manual_seed(seed)
inputs = {
"prompt": "dance monkey",
"negative_prompt": "bad quality",
"controls": [torch.randn(3, 32, 32, generator=controls_generator) for _ in range(5)],
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 3.0,
"height": 32,
"width": 32,
"num_frames": 3,
"num_frames_per_chunk": 16,
"max_sequence_length": 16,
"output_type": "pt",
}
@@ -212,6 +216,56 @@ class Cosmos2_5_TransferPipelineFastTests(PipelineTesterMixin, unittest.TestCase
self.assertEqual(generated_video.shape, (3, 3, 32, 32))
self.assertTrue(torch.isfinite(generated_video).all())
def test_inference_autoregressive_multi_chunk(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs["num_frames"] = 5
inputs["num_frames_per_chunk"] = 3
inputs["num_ar_conditional_frames"] = 1
video = pipe(**inputs).frames
generated_video = video[0]
self.assertEqual(generated_video.shape, (5, 3, 32, 32))
self.assertTrue(torch.isfinite(generated_video).all())
def test_inference_autoregressive_multi_chunk_no_condition_frames(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs["num_frames"] = 5
inputs["num_frames_per_chunk"] = 3
inputs["num_ar_conditional_frames"] = 0
video = pipe(**inputs).frames
generated_video = video[0]
self.assertEqual(generated_video.shape, (5, 3, 32, 32))
self.assertTrue(torch.isfinite(generated_video).all())
def test_num_frames_per_chunk_above_rope_raises(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs["num_frames_per_chunk"] = 17
with self.assertRaisesRegex(ValueError, "too large for RoPE setting"):
pipe(**inputs)
def test_inference_with_controls(self):
"""Test inference with control inputs (ControlNet)."""
device = "cpu"
@@ -222,13 +276,13 @@ class Cosmos2_5_TransferPipelineFastTests(PipelineTesterMixin, unittest.TestCase
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
# Add control input - a list of per-frame tensors in (C, H, W) format
inputs["controls"] = [torch.randn(3, 3, 32, 32)] # num_frames, channels, height, width
inputs["controls"] = [torch.randn(3, 32, 32) for _ in range(5)] # list of 5 frames (C, H, W)
inputs["controls_conditioning_scale"] = 1.0
inputs["num_frames"] = None
video = pipe(**inputs).frames
generated_video = video[0]
self.assertEqual(generated_video.shape, (3, 3, 32, 32))
self.assertEqual(generated_video.shape, (5, 3, 32, 32))
self.assertTrue(torch.isfinite(generated_video).all())
def test_callback_inputs(self):

View File

@@ -24,7 +24,8 @@ from diffusers import (
LTX2ImageToVideoPipeline,
LTX2VideoTransformer3DModel,
)
from diffusers.pipelines.ltx2 import LTX2TextConnectors
from diffusers.pipelines.ltx2 import LTX2LatentUpsamplePipeline, LTX2TextConnectors
from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel
from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
from ...testing_utils import enable_full_determinism
@@ -174,6 +175,15 @@ class LTX2ImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
return components
def get_dummy_upsample_component(self, in_channels=4, mid_channels=32, num_blocks_per_stage=1):
upsampler = LTX2LatentUpsamplerModel(
in_channels=in_channels,
mid_channels=mid_channels,
num_blocks_per_stage=num_blocks_per_stage,
)
return upsampler
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
@@ -287,5 +297,60 @@ class LTX2ImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4)
assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4)
def test_two_stages_inference_with_upsampler(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs["output_type"] = "latent"
first_stage_output = pipe(**inputs)
video_latent = first_stage_output.frames
audio_latent = first_stage_output.audio
self.assertEqual(video_latent.shape, (1, 4, 3, 16, 16))
self.assertEqual(audio_latent.shape, (1, 2, 5, 2))
self.assertEqual(audio_latent.shape[1], components["vocoder"].config.out_channels)
upsampler = self.get_dummy_upsample_component(in_channels=video_latent.shape[1])
upsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=upsampler)
upscaled_video_latent = upsample_pipe(latents=video_latent, output_type="latent", return_dict=False)[0]
self.assertEqual(upscaled_video_latent.shape, (1, 4, 3, 32, 32))
inputs["latents"] = upscaled_video_latent
inputs["audio_latents"] = audio_latent
inputs["output_type"] = "pt"
second_stage_output = pipe(**inputs)
video = second_stage_output.frames
audio = second_stage_output.audio
self.assertEqual(video.shape, (1, 5, 3, 64, 64))
self.assertEqual(audio.shape[0], 1)
self.assertEqual(audio.shape[1], components["vocoder"].config.out_channels)
# fmt: off
expected_video_slice = torch.tensor(
[
0.4497, 0.6757, 0.4219, 0.7686, 0.4525, 0.6483, 0.3969, 0.7404, 0.3541, 0.3039, 0.4592, 0.3521, 0.3665, 0.2785, 0.3336, 0.3079
]
)
expected_audio_slice = torch.tensor(
[
0.0271, 0.0492, 0.1249, 0.1126, 0.1661, 0.1060, 0.1717, 0.0944, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739
]
)
# fmt: on
video = video.flatten()
audio = audio.flatten()
generated_video_slice = torch.cat([video[:8], video[-8:]])
generated_audio_slice = torch.cat([audio[:8], audio[-8:]])
assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4)
assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4)
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(batch_size=2, expected_max_diff=2e-2)

View File

@@ -74,7 +74,7 @@ if is_torchao_available():
@require_torch
@require_torch_accelerator
@require_torchao_version_greater_or_equal("0.7.0")
@require_torchao_version_greater_or_equal("0.14.0")
class TorchAoConfigTest(unittest.TestCase):
def test_to_dict(self):
"""
@@ -132,7 +132,7 @@ class TorchAoConfigTest(unittest.TestCase):
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
@require_torch
@require_torch_accelerator
@require_torchao_version_greater_or_equal("0.7.0")
@require_torchao_version_greater_or_equal("0.14.0")
class TorchAoTest(unittest.TestCase):
def tearDown(self):
gc.collect()
@@ -587,7 +587,7 @@ class TorchAoTest(unittest.TestCase):
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
@require_torch
@require_torch_accelerator
@require_torchao_version_greater_or_equal("0.7.0")
@require_torchao_version_greater_or_equal("0.14.0")
class TorchAoSerializationTest(unittest.TestCase):
model_name = "hf-internal-testing/tiny-flux-pipe"
@@ -698,23 +698,22 @@ class TorchAoSerializationTest(unittest.TestCase):
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
@require_torchao_version_greater_or_equal("0.7.0")
@require_torchao_version_greater_or_equal("0.14.0")
class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):
@property
def quantization_config(self):
return PipelineQuantizationConfig(
quant_mapping={
"transformer": TorchAoConfig(quant_type="int8_weight_only"),
},
quant_mapping={"transformer": TorchAoConfig(Int8WeightOnlyConfig())},
)
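# Editor's note (hedged): with the torchao >= 0.14 requirement above, the quantization type is
# passed as a config object rather than a string; Int8WeightOnlyConfig is assumed to come from
# torchao.quantization and to be equivalent to the old "int8_weight_only" string here.
# from torchao.quantization import Int8WeightOnlyConfig
# TorchAoConfig(Int8WeightOnlyConfig())  # replaces TorchAoConfig(quant_type="int8_weight_only")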
@unittest.skip(
"Changing the device of AQT tensor with module._apply (called from doing module.to() in accelerate) does not work "
"when compiling."
)
def test_torch_compile_with_cpu_offload(self):
pipe = self._init_pipeline(self.quantization_config, torch.bfloat16)
pipe.enable_model_cpu_offload()
# No compilation because it fails with:
# RuntimeError: _apply(): Couldn't swap Linear.weight
super().test_torch_compile_with_cpu_offload()
# small resolutions to ensure speedy execution.
pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
@parameterized.expand([False, True])
@unittest.skip(
@@ -745,7 +744,7 @@ class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
@require_torch
@require_torch_accelerator
@require_torchao_version_greater_or_equal("0.7.0")
@require_torchao_version_greater_or_equal("0.14.0")
@slow
@nightly
class SlowTorchAoTests(unittest.TestCase):
@@ -907,7 +906,7 @@ class SlowTorchAoTests(unittest.TestCase):
@require_torch
@require_torch_accelerator
@require_torchao_version_greater_or_equal("0.7.0")
@require_torchao_version_greater_or_equal("0.14.0")
@slow
@nightly
class SlowTorchAoPreserializedModelTests(unittest.TestCase):