Mirror of https://github.com/huggingface/diffusers.git (synced 2026-02-08 20:05:05 +08:00)

Compare commits: main...modular-do (8 commits: 86fc6691cb, 7224beb036, 64dba68e0a, 98ea6e0b2e, 64a90fc2e2, 7fdddf012e, 24cbb354c0, 025dfd4c67)
# ComponentsManager

The [`ComponentsManager`] is a model registry and management system for Modular Diffusers. It adds and tracks models, stores useful metadata (model size, device placement, adapters), prevents duplicate model instances, and supports offloading.

This guide will show you how to use [`ComponentsManager`] to manage components and device memory.

## Connect to a pipeline

Create a [`ComponentsManager`] and pass it to a [`ModularPipeline`] with either [`~ModularPipeline.from_pretrained`] or [`~ModularPipelineBlocks.init_pipeline`].

> [!TIP]
> The `collection` parameter is optional but makes it easier to organize and manage components.

<hfoptions id="create">
<hfoption id="from_pretrained">

```py
from diffusers import ModularPipeline, ComponentsManager
import torch

manager = ComponentsManager()
pipe = ModularPipeline.from_pretrained("Tongyi-MAI/Z-Image-Turbo", components_manager=manager)
pipe.load_components(torch_dtype=torch.bfloat16)
```

</hfoption>
<hfoption id="init_pipeline">

```py
from diffusers import ModularPipelineBlocks, ComponentsManager
import torch

manager = ComponentsManager()
blocks = ModularPipelineBlocks.from_pretrained("diffusers/Florence2-image-Annotator", trust_remote_code=True)
pipe = blocks.init_pipeline(components_manager=manager)
pipe.load_components(torch_dtype=torch.bfloat16)
```

</hfoption>
</hfoptions>

Components loaded by the pipeline are automatically registered in the manager. You can inspect them right away.

## Inspect components

Print the [`ComponentsManager`] to see all registered components, including their class, device placement, dtype, memory size, and load ID. The output below corresponds to the `from_pretrained` example above.

```py
Components:
=============================================================================================================================
Models:
-----------------------------------------------------------------------------------------------------------------------------
Name_ID | Class | Device: act(exec) | Dtype | Size (GB) | Load ID
-----------------------------------------------------------------------------------------------------------------------------
text_encoder_140458257514752 | Qwen3Model | cpu | torch.bfloat16 | 7.49 | Tongyi-MAI/Z-Image-Turbo|text_encoder|null|null
vae_140458257515376 | AutoencoderKL | cpu | torch.bfloat16 | 0.16 | Tongyi-MAI/Z-Image-Turbo|vae|null|null
transformer_140458257515616 | ZImageTransformer2DModel | cpu | torch.bfloat16 | 11.46 | Tongyi-MAI/Z-Image-Turbo|transformer|null|null
-----------------------------------------------------------------------------------------------------------------------------

Other Components:
-----------------------------------------------------------------------------------------------------------------------------
ID | Class | Collection
-----------------------------------------------------------------------------------------------------------------------------
scheduler_140461023555264 | FlowMatchEulerDiscreteScheduler | N/A
tokenizer_140458256346432 | Qwen2Tokenizer | N/A
-----------------------------------------------------------------------------------------------------------------------------
```

The table shows models (with device, dtype, and memory info) separately from other components like schedulers and tokenizers. If any models have LoRA adapters, IP-Adapters, or quantization applied, that information is displayed in an additional section at the bottom.

Components are only loaded and registered when [`~ModularPipeline.load_components`] is called. The example below starts from a fresh manager and creates two pipelines that share it: the second pipeline reuses all the components registered by the first one and is assigned to a different collection.

```py
comp = ComponentsManager()
pipe = ModularPipeline.from_pretrained("YiYiXu/modular-demo-auto", components_manager=comp, collection="test1")
pipe.load_components()
pipe2 = ModularPipeline.from_pretrained("YiYiXu/modular-demo-auto", components_manager=comp, collection="test2")
```

Use the [`~ModularPipeline.null_component_names`] property to identify any components that need to be loaded, retrieve them with [`~ComponentsManager.get_components_by_names`], and then call [`~ModularPipeline.update_components`] to add the missing components.

```py
pipe2.null_component_names
['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'image_encoder', 'unet', 'vae', 'scheduler', 'controlnet']

comp_dict = comp.get_components_by_names(names=pipe2.null_component_names)
pipe2.update_components(**comp_dict)
```

## Add a component

To add individual components, use the [`~ComponentsManager.add`] method. This registers a component under a name and returns a unique id.

```py
from diffusers import AutoModel

text_encoder = AutoModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder")
component_id = comp.add("text_encoder", text_encoder)
comp
```

Use [`~ComponentsManager.remove`] to remove a component by its id.

```py
comp.remove("text_encoder_139917733042864")
```

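Since [`~ComponentsManager.add`] returns the generated id, you can also hold on to that value and pass it straight to [`~ComponentsManager.remove`]. A minimal sketch, reusing `component_id` from the example above:

```py
# component_id was returned by comp.add("text_encoder", text_encoder)
comp.remove(component_id)
```
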
## Retrieve a component

The [`ComponentsManager`] provides several methods to retrieve registered components.

### get_one

The [`~ComponentsManager.get_one`] method returns a single component and supports pattern matching for the `name` parameter. If multiple components match, [`~ComponentsManager.get_one`] raises an error.

| Pattern | Example | Description |
|-------------|----------------------------------|-------------------------------------------|
| exact | `comp.get_one(name="unet")` | exact name match |
| wildcard | `comp.get_one(name="unet*")` | names starting with "unet" |
| exclusion | `comp.get_one(name="!unet")` | exclude components named "unet" |
| or | `comp.get_one(name="unet|vae")` | name is "unet" or "vae" |

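For example, a wildcard retrieves a component when you only know the prefix of its registered name. A minimal sketch, assuming exactly one registered component whose name starts with "unet":

```py
# raises an error if more than one component matches the pattern
unet = comp.get_one(name="unet*")
```
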
[`~ComponentsManager.get_one`] also filters components by the `collection` or `load_id` argument.

```py
comp.get_one(name="unet", collection="sdxl")
```

### get_components_by_names

The [`~ComponentsManager.get_components_by_names`] method accepts a list of names and returns a dictionary mapping names to components. This is especially useful with [`ModularPipeline`] because a pipeline provides the list of required component names, and the returned dictionary can be passed directly to [`~ModularPipeline.update_components`].

```py
component_dict = comp.get_components_by_names(names=["text_encoder", "unet", "vae"])
{"text_encoder": component1, "unet": component2, "vae": component3}
```

## Duplicate detection

It is recommended to load model components with [`ComponentSpec`] because it assigns each component a unique id that encodes its loading parameters. This allows [`ComponentsManager`] to automatically detect and prevent duplicate model instances even when different objects represent the same underlying checkpoint.

```py
from diffusers import AutoModel, ComponentSpec, ComponentsManager
from transformers import CLIPTextModel

comp = ComponentsManager()

# create a ComponentSpec for the first text encoder
spec = ComponentSpec(name="text_encoder", repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder", type_hint=AutoModel)
# create a ComponentSpec for a duplicate text encoder (same checkpoint, from the same repo/subfolder)
spec_duplicated = ComponentSpec(name="text_encoder_duplicated", repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder", type_hint=CLIPTextModel)

# load and add both components - the manager detects they're the same model
comp.add("text_encoder", spec.load())
comp.add("text_encoder_duplicated", spec_duplicated.load())
```

This returns a warning with instructions for removing the duplicate.

```py
ComponentsManager: adding component 'text_encoder_duplicated_139917580682672', but it has duplicate load_id 'stabilityai/stable-diffusion-xl-base-1.0|text_encoder|null|null' with existing components: text_encoder_139918506246832. To remove a duplicate, call `components_manager.remove('<component_id>')`.
'text_encoder_duplicated_139917580682672'
```

You can also add a component without using [`ComponentSpec`], and duplicate detection still works in most cases even if you add the same component under a different name.

However, [`ComponentsManager`] can't detect duplicates when you load the same checkpoint into two different objects. In this case, you should load the model with [`ComponentSpec`].

```py
text_encoder_2 = AutoModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder")
comp.add("text_encoder", text_encoder_2)
'text_encoder_139917732983664'
```

## Collections

Collections are labels assigned to components for better organization and management. Add a component to a collection with the `collection` argument in [`~ComponentsManager.add`].

Only one component per name is allowed in each collection. Adding a second component with the same name automatically removes the first one.

```py
from diffusers import AutoModel, ComponentSpec, ComponentsManager

comp = ComponentsManager()

# create a ComponentSpec for the first UNet
spec = ComponentSpec(name="unet", repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", type_hint=AutoModel)
# create a ComponentSpec for a different UNet
spec2 = ComponentSpec(name="unet", repo="RunDiffusion/Juggernaut-XL-v9", subfolder="unet", type_hint=AutoModel, variant="fp16")

# add both UNets to the same collection - the second one replaces the first
comp.add("unet", spec.load(), collection="sdxl")
comp.add("unet", spec2.load(), collection="sdxl")
```

This makes it convenient to work with node-based systems because you can:

- Mark all models as loaded from one node with the `collection` label.
- Automatically replace models when new checkpoints are loaded under the same name.
- Batch delete all models in a collection when a node is removed.

## Offloading

The [`~ComponentsManager.enable_auto_cpu_offload`] method is a global offloading strategy that works across all models regardless of which pipeline is using them. Once enabled, you don't need to worry about device placement when you add or remove components.

```py
manager.enable_auto_cpu_offload(device="cuda")
```

All models begin on the CPU, and [`ComponentsManager`] moves them to the appropriate device right before they're needed and moves other models back to the CPU when GPU memory is low.

Call [`~ComponentsManager.disable_auto_cpu_offload`] to disable offloading.

```py
manager.disable_auto_cpu_offload()
```

You can also set your own rules for which models to offload first.

# ModularPipeline

[`ModularPipeline`] converts [`~modular_pipelines.ModularPipelineBlocks`] into an executable pipeline that loads models and performs the computation steps defined in the blocks. It is the main interface for running a pipeline, and the API is very similar to [`DiffusionPipeline`] but with a few key differences.

- **Loading is lazy.** With [`DiffusionPipeline`], [`~DiffusionPipeline.from_pretrained`] creates the pipeline and loads all models at the same time. With [`ModularPipeline`], creating and loading are two separate steps: [`~ModularPipeline.from_pretrained`] reads the configuration and knows where to load each component from, but doesn't actually load the model weights. You load the models later with [`~ModularPipeline.load_components`], which is where you pass loading arguments like `torch_dtype` and `quantization_config`.

- **Two ways to create a pipeline.** You can use [`~ModularPipeline.from_pretrained`] with an existing diffusers model repository — it automatically maps to the default pipeline blocks and then converts to a [`ModularPipeline`] with no extra setup. Currently supported models include SDXL, Wan, Qwen, Z-Image, Flux, and Flux2. You can also assemble your own pipeline from [`ModularPipelineBlocks`] and convert it with the [`~ModularPipelineBlocks.init_pipeline`] method (see [Creating a pipeline](#creating-a-pipeline) for more details).

- **Running the pipeline is the same.** Once loaded, you call the pipeline with the same arguments you're used to. A single [`ModularPipeline`] can support multiple workflows (text-to-image, image-to-image, inpainting, etc.) when the pipeline blocks use [`AutoPipelineBlocks`](./auto_pipeline) to automatically select the workflow based on your inputs.

Below are complete examples for text-to-image, image-to-image, and inpainting with SDXL.

<hfoptions id="example">
<hfoption id="text-to-image">

```py
import torch
from diffusers import ModularPipeline

pipeline = ModularPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
pipeline.load_components(torch_dtype=torch.float16)
pipeline.to("cuda")

image = pipeline(prompt="Astronaut in a jungle, cold color palette, muted colors, detailed, 8k").images[0]
image.save("modular_t2i_out.png")
```

</hfoption>
<hfoption id="image-to-image">

```py
import torch
from diffusers import ModularPipeline
from diffusers.utils import load_image

pipeline = ModularPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
pipeline.load_components(torch_dtype=torch.float16)
pipeline.to("cuda")

url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-text2img.png"
init_image = load_image(url)
prompt = "a dog catching a frisbee in the jungle"
image = pipeline(prompt=prompt, image=init_image, strength=0.8).images[0]
image.save("modular_i2i_out.png")
```

</hfoption>
<hfoption id="inpainting">

```py
import torch
from diffusers import ModularPipeline
from diffusers.utils import load_image

pipeline = ModularPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
pipeline.load_components(torch_dtype=torch.float16)
pipeline.to("cuda")

# img_url and mask_url point to the base image and its mask (their definitions are elided in this excerpt)
init_image = load_image(img_url)
mask_image = load_image(mask_url)

prompt = "A deep sea diver floating"
image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, strength=0.85).images[0]
image.save("modular_inpaint_out.png")
```

</hfoption>
</hfoptions>

This guide will show you how to create a [`ModularPipeline`], manage the components in it, and run it.

## Adding blocks

Blocks are [`InsertableDict`] objects that can be inserted at specific positions, providing a flexible way to mix-and-match blocks.

Use [`~modular_pipelines.modular_pipeline_utils.InsertableDict.insert`] on either the block class preset or the `sub_blocks` attribute to add a block.

```py
# BLOCKS is a dict of block classes, so add the class to it
BLOCKS.insert("block_name", BlockClass, index)
# the sub_blocks attribute contains instances, so add a block instance to it
t2i_blocks.sub_blocks.insert("block_name", block_instance, index)
```

Use [`~modular_pipelines.modular_pipeline_utils.InsertableDict.pop`] on either the block class preset or the `sub_blocks` attribute to remove a block.

```py
# remove a block class from the preset
BLOCKS.pop("text_encoder")
# split out a block instance on its own
text_encoder_block = t2i_blocks.sub_blocks.pop("text_encoder")
```

Swap blocks by assigning a new block to the existing key.

```py
# replace a block class in the preset
BLOCKS["prepare_latents"] = CustomPrepareLatents
# replace in the sub_blocks attribute using a block instance
t2i_blocks.sub_blocks["prepare_latents"] = CustomPrepareLatents()
```

## Creating a pipeline

There are two ways to create a [`ModularPipeline`]: assemble one from [`ModularPipelineBlocks`] with [`~ModularPipelineBlocks.init_pipeline`], or load an existing pipeline with [`~ModularPipeline.from_pretrained`].

You should also initialize a [`ComponentsManager`] to handle device placement, memory, and component management.

> [!TIP]
> Refer to the [ComponentsManager](./components_manager) doc for more details about how it can help manage components across different workflows.

### init_pipeline

[`~ModularPipelineBlocks.init_pipeline`] converts any [`ModularPipelineBlocks`] into a [`ModularPipeline`]. It creates the pipeline from the component and configuration specifications, but it doesn't load the models yet.

Let's define a minimal block to see how it works:

```py
from transformers import CLIPTextModel
from diffusers.modular_pipelines import (
    ComponentSpec,
    ModularPipelineBlocks,
    PipelineState,
)


class MyBlock(ModularPipelineBlocks):
    @property
    def expected_components(self):
        return [
            ComponentSpec(
                name="text_encoder",
                type_hint=CLIPTextModel,
                pretrained_model_name_or_path="openai/clip-vit-large-patch14",
            ),
        ]

    def __call__(self, components, state: PipelineState) -> PipelineState:
        return components, state
```

Call [`~ModularPipelineBlocks.init_pipeline`] to convert the block into a pipeline. The `blocks` attribute on the pipeline is the blocks it was created from — it determines the expected inputs, outputs, and computation logic.

```py
block = MyBlock()
pipe = block.init_pipeline()
pipe.blocks
```

```
MyBlock {
  "_class_name": "MyBlock",
  "_diffusers_version": "0.37.0.dev0"
}
```

> [!WARNING]
> Blocks are mutable — you can freely add, remove, or swap blocks before creating a pipeline. However, once a pipeline is created, modifying `pipeline.blocks` won't affect the pipeline because it returns a copy. If you want a different block structure, create a new pipeline after modifying the blocks.

When you call [`~ModularPipelineBlocks.init_pipeline`] without a repository, it uses the `pretrained_model_name_or_path` defined in the block's [`ComponentSpec`] to determine where to load each component from. Printing the pipeline shows the component loading configuration.

```py
pipe
ModularPipeline {
  "_blocks_class_name": "MyBlock",
  "_class_name": "ModularPipeline",
  "_diffusers_version": "0.37.0.dev0",
  "text_encoder": [
    null,
    null,
    {
      "pretrained_model_name_or_path": "openai/clip-vit-large-patch14",
      "revision": null,
      "subfolder": "",
      "type_hint": [
        "transformers",
        "CLIPTextModel"
      ],
      "variant": null
    }
  ]
}
```

If you pass a repository to [`~ModularPipelineBlocks.init_pipeline`], it overrides the loading path by matching your block's components against the pipeline config in that repository (`model_index.json` or `modular_model_index.json`).

In the example below, the `pretrained_model_name_or_path` will be updated to `"stabilityai/stable-diffusion-xl-base-1.0"`.

```py
pipe = block.init_pipeline("stabilityai/stable-diffusion-xl-base-1.0")
pipe
ModularPipeline {
  "_blocks_class_name": "MyBlock",
  "_class_name": "ModularPipeline",
  "_diffusers_version": "0.37.0.dev0",
  "text_encoder": [
    null,
    null,
    {
      "pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl-base-1.0",
      "revision": null,
      "subfolder": "text_encoder",
      "type_hint": [
        "transformers",
        "CLIPTextModel"
      ],
      "variant": null
    }
  ]
}
```

If a component in your block doesn't exist in the repository, it remains `null` and is skipped during [`~ModularPipeline.load_components`].

### from_pretrained

[`~ModularPipeline.from_pretrained`] is a convenient way to create a [`ModularPipeline`] without defining blocks yourself. It works with three types of repositories.

**A regular diffusers repository.** Pass any supported model repository and it automatically maps to the default pipeline blocks. Currently supported models include SDXL, Wan, Qwen, Z-Image, Flux, and Flux2.

```py
from diffusers import ModularPipeline, ComponentsManager

components = ComponentsManager()
pipeline = ModularPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", components_manager=components
)
```

**A modular repository.** These repositories contain a `modular_model_index.json` that specifies where to load each component from. The components can come from different repositories, and the modular repository itself may not contain any model weights. For example, [diffusers/flux2-bnb-4bit-modular](https://huggingface.co/diffusers/flux2-bnb-4bit-modular) loads a quantized transformer from one repository and the remaining components from another. See [Modular repository](#modular-repository) for more details on the format.

```py
from diffusers import ModularPipeline, ComponentsManager

components = ComponentsManager()
pipeline = ModularPipeline.from_pretrained(
    "diffusers/flux2-bnb-4bit-modular", components_manager=components
)
```

**A modular repository with custom code.** Some repositories include custom pipeline blocks alongside the loading configuration. Add `trust_remote_code=True` to load them. See [Custom blocks](./custom_blocks) for how to create your own.

```py
from diffusers import ModularPipeline, ComponentsManager

components = ComponentsManager()
pipeline = ModularPipeline.from_pretrained(
    "diffusers/Florence2-image-Annotator", trust_remote_code=True, components_manager=components
)
```

## Loading components

A [`ModularPipeline`] doesn't automatically instantiate with components. It only loads the configuration and component specifications. You load the actual models with [`~ModularPipeline.load_components`], which loads every component that has a valid loading spec.

```py
import torch

pipeline.load_components(torch_dtype=torch.float16)
```

You can also load specific components by name. The example below only loads the text encoder.

```py
pipeline.load_components(names=["text_encoder"], torch_dtype=torch.float16)
```

Loading keyword arguments like `torch_dtype`, `variant`, `revision`, and `quantization_config` are passed through to `from_pretrained()` for each component. You can pass a single value to apply to all components, or a dict to set per-component values.

```py
# apply bfloat16 to all components
pipeline.load_components(torch_dtype=torch.bfloat16)

# different dtypes per component
pipeline.load_components(torch_dtype={"transformer": torch.bfloat16, "default": torch.float32})
```

Note that [`~ModularPipeline.load_components`] only loads components that haven't been loaded yet and have a valid loading spec. If you've already set a component on the pipeline, calling [`~ModularPipeline.load_components`] again won't reload it.

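As a minimal illustration of that behavior, reusing the `pipeline` above, a second call is effectively a no-op for components that were already loaded:

```py
pipeline.load_components(names=["unet"], torch_dtype=torch.float16)

# "unet" is already loaded, so only the remaining components are loaded here
pipeline.load_components(torch_dtype=torch.float16)
```
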
After loading, printing the pipeline shows which components are loaded — the first two fields of each entry change from `null` to the component's library and class.

```py
pipeline
```

```json
# text_encoder is loaded - shows library and class
"text_encoder": [
  "transformers",
  "CLIPTextModel",
  { ... }
]

# unet is not loaded yet - still null
"unet": [
  null,
  null,
  { ... }
]
```

This output should match the `modular_model_index.json` file from the modular repository a pipeline is initialized from. If a pipeline doesn't need a component, it won't be included even if it exists in the modular repository.

To modify where components are loaded from, edit the `modular_model_index.json` file in the repository and change it to your desired loading path. The example below loads a UNet from a different repository.

```json
# original
"unet": [
  null, null,
  {
    "repo": "stabilityai/stable-diffusion-xl-base-1.0",
    "subfolder": "unet",
    "variant": "fp16"
  }
]

# modified
"unet": [
  null, null,
  {
    "repo": "RunDiffusion/Juggernaut-XL-v9",
    "subfolder": "unet",
    "variant": "fp16"
  }
]
```

### Component loading status

The pipeline properties below provide more information about which components are loaded.

Use `component_names` to return all expected components.

```py
pipeline.component_names
['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'guider', 'scheduler', 'unet', 'vae', 'image_processor']
```

Use `null_component_names` to return components that aren't loaded yet. Load them with [`~ModularPipeline.load_components`].

```py
pipeline.null_component_names
['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'scheduler']
```

Use `pretrained_component_names` to return components that will be loaded from pretrained models.

```py
pipeline.pretrained_component_names
['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'scheduler', 'unet', 'vae']
```

Use `config_component_names` to return components that are created from a default config (not loaded from a repository). Config components are already initialized during pipeline creation, which is why they aren't listed in `null_component_names`.

```py
pipeline.config_component_names
['guider', 'image_processor']
```

## Updating components

[`~ModularPipeline.update_components`] replaces a component on the pipeline with a new one. When a component is updated, the loading specifications are also updated in the pipeline config and [`~ModularPipeline.load_components`] will skip it on subsequent calls.

Components are updated differently depending on whether they are a *pretrained component* or a *config component*. The component type is initially defined in a block's `expected_components` field, and its [`ComponentSpec`] shows `default_creation_method="from_pretrained"` for a pretrained component and `default_creation_method="from_config"` for a config component. A pretrained component is updated by passing a tagged model object or a [`ComponentSpec`], whereas a config component is updated by either passing the object directly or with a [`ComponentSpec`].

> [!WARNING]
> A component may change from pretrained to config when updating a component.

### From AutoModel

You can pass a model object loaded with `AutoModel.from_pretrained()`. Models loaded this way are automatically tagged with their loading information.

```py
import torch
from diffusers import AutoModel

unet = AutoModel.from_pretrained(
    "RunDiffusion/Juggernaut-XL-v9", subfolder="unet", variant="fp16", torch_dtype=torch.float16
)
pipeline.update_components(unet=unet)
```

### From ComponentSpec

A pretrained component can also be updated with a [`ComponentSpec`]. Create a [`ComponentSpec`] with the name of the component and where to load it from, then use the [`~ComponentSpec.load`] method to load it and [`~ModularPipeline.update_components`] to swap it in.

```py
import torch
from diffusers import ComponentSpec, UNet2DConditionModel

unet_spec = ComponentSpec(name="unet", type_hint=UNet2DConditionModel, repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", variant="fp16")
unet2 = unet_spec.load(torch_dtype=torch.float16)
pipeline.update_components(unet=unet2)
```

When you use [`~ComponentSpec.load`], the new component maintains its loading specifications. This makes it possible to extract the specification and recreate the component.

```py
spec = ComponentSpec.from_component("unet", unet2)
spec
ComponentSpec(name='unet', type_hint=<class 'diffusers.models.unets.unet_2d_condition.UNet2DConditionModel'>, description=None, config=None, repo='stabilityai/stable-diffusion-xl-base-1.0', subfolder='unet', variant='fp16', revision=None, default_creation_method='from_pretrained')
unet2_recreated = spec.load(torch_dtype=torch.float16)
```

Use [`~ModularPipeline.get_component_spec`] to get a copy of the current component specification, modify it, and load a new component.

```py
unet_spec = pipeline.get_component_spec("unet")

# modify to load from a different repository
unet_spec.pretrained_model_name_or_path = "RunDiffusion/Juggernaut-XL-v9"

# load and update
unet = unet_spec.load(torch_dtype=torch.float16)
pipeline.update_components(unet=unet)
```

You can also create a [`ComponentSpec`] from scratch.

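A minimal sketch of a spec written from scratch — the field values below are illustrative and mirror the SDXL UNet used above:

```py
import torch
from diffusers import ComponentSpec, UNet2DConditionModel

# illustrative values; any repo/subfolder containing a compatible model works
scratch_spec = ComponentSpec(
    name="unet",
    type_hint=UNet2DConditionModel,
    repo="stabilityai/stable-diffusion-xl-base-1.0",
    subfolder="unet",
    variant="fp16",
    default_creation_method="from_pretrained",
)
unet = scratch_spec.load(torch_dtype=torch.float16)
```
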
Not all components are loaded from pretrained weights — some are created from a config (listed under `pipeline.config_component_names`). For these, use [`~ComponentSpec.create`] instead of [`~ComponentSpec.load`].

```py
guider_spec = pipeline.get_component_spec("guider")
guider_spec.config = {"guidance_scale": 5.0}
guider = guider_spec.create()
pipeline.update_components(guider=guider)
```

Or simply pass the object directly.

```py
from diffusers.guiders import ClassifierFreeGuidance

guider = ClassifierFreeGuidance(guidance_scale=5.0)
pipeline.update_components(guider=guider)
```

See the [Guiders](./guiders) guide for more details on available guiders and how to configure them.

## Splitting a pipeline into stages

Since blocks are composable, you can take a pipeline apart and reconstruct it into separate pipelines for each stage. The example below shows how we can separate the text encoder block from the rest of the pipeline, so you can encode the prompt independently and pass the embeddings to the main pipeline.

```py
from diffusers import ModularPipeline, ComponentsManager
import torch

device = "cuda"
dtype = torch.bfloat16
repo_id = "black-forest-labs/FLUX.2-klein-4B"

# get the blocks and separate out the text encoder
blocks = ModularPipeline.from_pretrained(repo_id).blocks
text_block = blocks.sub_blocks.pop("text_encoder")

# use ComponentsManager to handle offloading across multiple pipelines
manager = ComponentsManager()
manager.enable_auto_cpu_offload(device=device)

# create separate pipelines for each stage
text_encoder_pipeline = text_block.init_pipeline(repo_id, components_manager=manager)
pipeline = blocks.init_pipeline(repo_id, components_manager=manager)

# encode text
text_encoder_pipeline.load_components(torch_dtype=dtype)
text_embeddings = text_encoder_pipeline(prompt="a cat").get_by_kwargs("denoiser_input_fields")

# denoise and decode
pipeline.load_components(torch_dtype=dtype)
output = pipeline(
    **text_embeddings,
    num_inference_steps=4,
).images[0]
```

[`ComponentsManager`] handles memory across multiple pipelines. Unlike the offloading strategies in [`DiffusionPipeline`] that follow a fixed order, [`ComponentsManager`] makes offloading decisions dynamically each time a model forward pass runs, based on the current memory situation. This means it works regardless of how many pipelines you create or what order you run them in. See the [ComponentsManager](./components_manager) guide for more details.

If pipeline stages share components (e.g., the same VAE used for encoding and decoding), you can use [`~ModularPipeline.update_components`] to pass an already-loaded component to another pipeline instead of loading it again.

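For instance, a minimal sketch, assuming loaded components are accessible as pipeline attributes (e.g. `pipeline.vae`) and that a hypothetical decode-only pipeline named `decoder_pipeline` was created the same way:

```py
# reuse the already-loaded VAE instead of loading it a second time
decoder_pipeline.update_components(vae=pipeline.vae)
```
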
## Modular repository

A repository is required if the pipeline blocks use *pretrained components*. The repository supplies loading specifications and metadata.

[`ModularPipeline`] works with regular diffusers repositories out of the box. However, you can also create a *modular repository* (see this [example repository](https://huggingface.co/YiYiXu/modular-diffdiff)) for more flexibility. A modular repository contains a `modular_model_index.json` file with the following 3 elements for each component.

- `library` and `class` show which library the component was loaded from and its class. If `null`, the component hasn't been loaded yet.
- `loading_specs_dict` contains the information required to load the component, such as the repository and subfolder it is loaded from.

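A representative entry pairs the library and class with its loading spec. The snippet below is illustrative, with field values following the examples shown earlier in this guide:

```json
"transformer": [
  null,
  null,
  {
    "repo": "diffusers/FLUX.2-dev-bnb-4bit",
    "subfolder": "transformer",
    "variant": null,
    "revision": null
  }
]
```
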
Unlike standard repositories, a modular repository can fetch components from different repositories based on the `loading_specs_dict`; components don't need to exist in the same repository. For example, [diffusers/flux2-bnb-4bit-modular](https://huggingface.co/diffusers/flux2-bnb-4bit-modular) loads a quantized transformer from `diffusers/FLUX.2-dev-bnb-4bit` while loading the remaining components from `black-forest-labs/FLUX.2-dev`.

To convert a regular diffusers repository into a modular one, create the pipeline from the regular repository and then push it to the Hub. The saved repository will contain a `modular_model_index.json` with all the loading specifications.

```py
from diffusers import ModularPipeline

# load from a regular repo
pipeline = ModularPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")

# push as a modular repository
pipeline.save_pretrained("local/path", repo_id="my-username/sdxl-modular", push_to_hub=True)
```

A modular repository can also include custom pipeline blocks as Python code. This allows you to share specialized blocks that aren't native to Diffusers. For example, [diffusers/Florence2-image-Annotator](https://huggingface.co/diffusers/Florence2-image-Annotator) contains custom blocks alongside the loading configuration:

```
Florence2-image-Annotator/
├── block.py                  # custom pipeline blocks implementation
├── config.json               # pipeline configuration and auto_map
├── mellon_config.json        # UI configuration for Mellon
└── modular_model_index.json  # component loading specifications
```

The `config.json` file contains an `auto_map` key that tells [`ModularPipeline`] where to find the custom blocks in `block.py`:

```json
{
  "_class_name": "Florence2AnnotatorBlocks",
  "auto_map": {
    "ModularPipelineBlocks": "block.Florence2AnnotatorBlocks"
  }
}
```

Load custom code repositories with `trust_remote_code=True` as shown in [from_pretrained](#from_pretrained). See [Custom blocks](./custom_blocks) for how to create and share your own.

@@ -417,7 +417,6 @@ else:
    "Flux2AutoBlocks",
    "Flux2KleinAutoBlocks",
    "Flux2KleinBaseAutoBlocks",
    "Flux2KleinBaseModularPipeline",
    "Flux2KleinModularPipeline",
    "Flux2ModularPipeline",
    "FluxAutoBlocks",

@@ -434,13 +433,8 @@ else:
    "QwenImageModularPipeline",
    "StableDiffusionXLAutoBlocks",
    "StableDiffusionXLModularPipeline",
    "Wan22Blocks",
    "Wan22Image2VideoBlocks",
    "Wan22Image2VideoModularPipeline",
    "Wan22ModularPipeline",
    "WanBlocks",
    "WanImage2VideoAutoBlocks",
    "WanImage2VideoModularPipeline",
    "Wan22AutoBlocks",
    "WanAutoBlocks",
    "WanModularPipeline",
    "ZImageAutoBlocks",
    "ZImageModularPipeline",

@@ -1162,7 +1156,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    Flux2AutoBlocks,
    Flux2KleinAutoBlocks,
    Flux2KleinBaseAutoBlocks,
    Flux2KleinBaseModularPipeline,
    Flux2KleinModularPipeline,
    Flux2ModularPipeline,
    FluxAutoBlocks,

@@ -1179,13 +1172,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    QwenImageModularPipeline,
    StableDiffusionXLAutoBlocks,
    StableDiffusionXLModularPipeline,
    Wan22Blocks,
    Wan22Image2VideoBlocks,
    Wan22Image2VideoModularPipeline,
    Wan22ModularPipeline,
    WanBlocks,
    WanImage2VideoAutoBlocks,
    WanImage2VideoModularPipeline,
    Wan22AutoBlocks,
    WanAutoBlocks,
    WanModularPipeline,
    ZImageAutoBlocks,
    ZImageModularPipeline,

@@ -45,16 +45,7 @@ else:
        "InsertableDict",
    ]
    _import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"]
    _import_structure["wan"] = [
        "WanBlocks",
        "Wan22Blocks",
        "WanImage2VideoAutoBlocks",
        "Wan22Image2VideoBlocks",
        "WanModularPipeline",
        "Wan22ModularPipeline",
        "WanImage2VideoModularPipeline",
        "Wan22Image2VideoModularPipeline",
    ]
    _import_structure["wan"] = ["WanAutoBlocks", "Wan22AutoBlocks", "WanModularPipeline"]
    _import_structure["flux"] = [
        "FluxAutoBlocks",
        "FluxModularPipeline",

@@ -67,7 +58,6 @@ else:
        "Flux2KleinBaseAutoBlocks",
        "Flux2ModularPipeline",
        "Flux2KleinModularPipeline",
        "Flux2KleinBaseModularPipeline",
    ]
    _import_structure["qwenimage"] = [
        "QwenImageAutoBlocks",

@@ -98,7 +88,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        Flux2AutoBlocks,
        Flux2KleinAutoBlocks,
        Flux2KleinBaseAutoBlocks,
        Flux2KleinBaseModularPipeline,
        Flux2KleinModularPipeline,
        Flux2ModularPipeline,
    )

@@ -123,16 +112,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        QwenImageModularPipeline,
    )
    from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline
    from .wan import (
        Wan22Blocks,
        Wan22Image2VideoBlocks,
        Wan22Image2VideoModularPipeline,
        Wan22ModularPipeline,
        WanBlocks,
        WanImage2VideoAutoBlocks,
        WanImage2VideoModularPipeline,
        WanModularPipeline,
    )
    from .wan import Wan22AutoBlocks, WanAutoBlocks, WanModularPipeline
    from .z_image import ZImageAutoBlocks, ZImageModularPipeline
else:
    import sys

@@ -55,11 +55,7 @@ else:
        "Flux2VaeEncoderSequentialStep",
    ]
    _import_structure["modular_blocks_flux2_klein"] = ["Flux2KleinAutoBlocks", "Flux2KleinBaseAutoBlocks"]
    _import_structure["modular_pipeline"] = [
        "Flux2ModularPipeline",
        "Flux2KleinModularPipeline",
        "Flux2KleinBaseModularPipeline",
    ]
    _import_structure["modular_pipeline"] = ["Flux2ModularPipeline", "Flux2KleinModularPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:

@@ -105,7 +101,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        Flux2KleinAutoBlocks,
        Flux2KleinBaseAutoBlocks,
    )
    from .modular_pipeline import Flux2KleinBaseModularPipeline, Flux2KleinModularPipeline, Flux2ModularPipeline
    from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline
else:
    import sys

@@ -13,6 +13,8 @@
# limitations under the License.


from typing import Any, Dict, Optional

from ...loaders import Flux2LoraLoaderMixin
from ...utils import logging
from ..modular_pipeline import ModularPipeline

@@ -57,36 +59,47 @@ class Flux2ModularPipeline(ModularPipeline, Flux2LoraLoaderMixin):
        return num_channels_latents


class Flux2KleinModularPipeline(Flux2ModularPipeline):
class Flux2KleinModularPipeline(ModularPipeline, Flux2LoraLoaderMixin):
    """
    A ModularPipeline for Flux2-Klein (distilled model).

    > [!WARNING] > This is an experimental feature and is likely to change in the future.
    """

    default_blocks_name = "Flux2KleinAutoBlocks"

    @property
    def requires_unconditional_embeds(self):
        if hasattr(self.config, "is_distilled") and self.config.is_distilled:
            return False

        requires_unconditional_embeds = False
        if hasattr(self, "guider") and self.guider is not None:
            requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1

        return requires_unconditional_embeds


class Flux2KleinBaseModularPipeline(Flux2ModularPipeline):
    """
    A ModularPipeline for Flux2-Klein (base model).
    A ModularPipeline for Flux2-Klein.

    > [!WARNING] > This is an experimental feature and is likely to change in the future.
    """

    default_blocks_name = "Flux2KleinBaseAutoBlocks"

    def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]:
        if config_dict is not None and "is_distilled" in config_dict and config_dict["is_distilled"]:
            return "Flux2KleinAutoBlocks"
        else:
            return "Flux2KleinBaseAutoBlocks"

    @property
    def default_height(self):
        return self.default_sample_size * self.vae_scale_factor

    @property
    def default_width(self):
        return self.default_sample_size * self.vae_scale_factor

    @property
    def default_sample_size(self):
        return 128

    @property
    def vae_scale_factor(self):
        vae_scale_factor = 8
        if getattr(self, "vae", None) is not None:
            vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        return vae_scale_factor

    @property
    def num_channels_latents(self):
        num_channels_latents = 32
        if getattr(self, "transformer", None):
            num_channels_latents = self.transformer.config.in_channels // 4
        return num_channels_latents

    @property
    def requires_unconditional_embeds(self):
        if hasattr(self.config, "is_distilled") and self.config.is_distilled:

@@ -156,12 +156,6 @@ MELLON_PARAM_TEMPLATES = {
        "display": "slider",
        "required_block_params": ["layers"],
    },
    "output_type": {
        "label": "Output Type",
        "type": "dropdown",
        "default": "np",
        "options": ["np", "pil", "pt"],
    },
    # ControlNet
    "controlnet_conditioning_scale": {
        "label": "Controlnet Conditioning Scale",

@@ -54,61 +54,19 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# map regular pipeline to modular pipeline class name
|
||||
|
||||
|
||||
def _create_default_map_fn(pipeline_class_name: str):
|
||||
"""Create a mapping function that always returns the same pipeline class."""
|
||||
|
||||
def _map_fn(config_dict=None):
|
||||
return pipeline_class_name
|
||||
|
||||
return _map_fn
|
||||
|
||||
|
||||
def _flux2_klein_map_fn(config_dict=None):
|
||||
if config_dict is None:
|
||||
return "Flux2KleinModularPipeline"
|
||||
|
||||
if "is_distilled" in config_dict and config_dict["is_distilled"]:
|
||||
return "Flux2KleinModularPipeline"
|
||||
else:
|
||||
return "Flux2KleinBaseModularPipeline"
|
||||
|
||||
|
||||
def _wan_map_fn(config_dict=None):
|
||||
if config_dict is None:
|
||||
return "WanModularPipeline"
|
||||
|
||||
if "boundary_ratio" in config_dict and config_dict["boundary_ratio"] is not None:
|
||||
return "Wan22ModularPipeline"
|
||||
else:
|
||||
return "WanModularPipeline"
|
||||
|
||||
|
||||
def _wan_i2v_map_fn(config_dict=None):
|
||||
if config_dict is None:
|
||||
return "WanImage2VideoModularPipeline"
|
||||
|
||||
if "boundary_ratio" in config_dict and config_dict["boundary_ratio"] is not None:
|
||||
return "Wan22Image2VideoModularPipeline"
|
||||
else:
|
||||
return "WanImage2VideoModularPipeline"
|
||||
|
||||
|
||||
MODULAR_PIPELINE_MAPPING = OrderedDict(
|
||||
[
|
||||
("stable-diffusion-xl", _create_default_map_fn("StableDiffusionXLModularPipeline")),
|
||||
("wan", _wan_map_fn),
|
||||
("wan-i2v", _wan_i2v_map_fn),
|
||||
("flux", _create_default_map_fn("FluxModularPipeline")),
|
||||
("flux-kontext", _create_default_map_fn("FluxKontextModularPipeline")),
|
||||
("flux2", _create_default_map_fn("Flux2ModularPipeline")),
|
||||
("flux2-klein", _flux2_klein_map_fn),
|
||||
("qwenimage", _create_default_map_fn("QwenImageModularPipeline")),
|
||||
("qwenimage-edit", _create_default_map_fn("QwenImageEditModularPipeline")),
|
||||
("qwenimage-edit-plus", _create_default_map_fn("QwenImageEditPlusModularPipeline")),
|
||||
("qwenimage-layered", _create_default_map_fn("QwenImageLayeredModularPipeline")),
|
||||
("z-image", _create_default_map_fn("ZImageModularPipeline")),
|
||||
("stable-diffusion-xl", "StableDiffusionXLModularPipeline"),
|
||||
("wan", "WanModularPipeline"),
|
||||
("flux", "FluxModularPipeline"),
|
||||
("flux-kontext", "FluxKontextModularPipeline"),
|
||||
("flux2", "Flux2ModularPipeline"),
|
||||
("flux2-klein", "Flux2KleinModularPipeline"),
|
||||
("qwenimage", "QwenImageModularPipeline"),
|
||||
("qwenimage-edit", "QwenImageEditModularPipeline"),
|
||||
("qwenimage-edit-plus", "QwenImageEditPlusModularPipeline"),
|
||||
("qwenimage-layered", "QwenImageLayeredModularPipeline"),
|
||||
("z-image", "ZImageModularPipeline"),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -410,8 +368,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
"""
|
||||
create a ModularPipeline, optionally accept pretrained_model_name_or_path to load from hub.
|
||||
"""
|
||||
map_fn = MODULAR_PIPELINE_MAPPING.get(self.model_name, _create_default_map_fn("ModularPipeline"))
|
||||
pipeline_class_name = map_fn()
|
||||
pipeline_class_name = MODULAR_PIPELINE_MAPPING.get(self.model_name, ModularPipeline.__name__)
|
||||
diffusers_module = importlib.import_module("diffusers")
|
||||
pipeline_class = getattr(diffusers_module, pipeline_class_name)
|
||||
|
||||
@@ -1590,7 +1547,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
if modular_config_dict is not None:
|
||||
blocks_class_name = modular_config_dict.get("_blocks_class_name")
|
||||
else:
|
||||
blocks_class_name = self.default_blocks_name
|
||||
blocks_class_name = self.get_default_blocks_name(config_dict)
|
||||
if blocks_class_name is not None:
|
||||
diffusers_module = importlib.import_module("diffusers")
|
||||
blocks_class = getattr(diffusers_module, blocks_class_name)
|
||||
@@ -1662,6 +1619,9 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
                params[input_param.name] = input_param.default
        return params

    def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]:
        return self.default_blocks_name

    @classmethod
    def _load_pipeline_config(
        cls,
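`get_default_blocks_name` gives pipeline subclasses a hook to pick their default blocks from the loaded config. A hypothetical sketch of how a subclass might override it (the class and block names here are made up), mirroring the `boundary_ratio` dispatch used by the Wan map functions above:

```python
# Hypothetical subclass: choose between two sets of blocks based on the config
# that was loaded from the repository.
class MyVideoModularPipeline(ModularPipeline):
    default_blocks_name = "MyVideoAutoBlocks"

    def get_default_blocks_name(self, config_dict):
        if config_dict and config_dict.get("boundary_ratio") is not None:
            return "MyVideo22AutoBlocks"
        return self.default_blocks_name
```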
@@ -1757,8 +1717,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
            logger.debug(" try to determine the modular pipeline class from model_index.json")
            standard_pipeline_class = _get_pipeline_class(cls, config=config_dict)
            model_name = _get_model(standard_pipeline_class.__name__)
            map_fn = MODULAR_PIPELINE_MAPPING.get(model_name, _create_default_map_fn("ModularPipeline"))
            pipeline_class_name = map_fn(config_dict)
            pipeline_class_name = MODULAR_PIPELINE_MAPPING.get(model_name, ModularPipeline.__name__)
            diffusers_module = importlib.import_module("diffusers")
            pipeline_class = getattr(diffusers_module, pipeline_class_name)
        else:
@@ -2057,58 +2016,29 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
        - the `config` dict, which will be saved as `modular_model_index.json` during `save_pretrained`

        Args:
            **kwargs: Component objects, ComponentSpec objects, or configuration values to update:
                - Component objects: Only supports components we can extract specs using
                  `ComponentSpec.from_component()` method i.e. components created with ComponentSpec.load() or
                  ConfigMixin subclasses that aren't nn.Modules (e.g., `unet=new_unet, text_encoder=new_encoder`)
                - ComponentSpec objects: Only supports default_creation_method == "from_config", will call create()
                  method to create a new component (e.g., `guider=ComponentSpec(name="guider",
                  type_hint=ClassifierFreeGuidance, config={...}, default_creation_method="from_config")`)
                - Configuration values: Simple values to update configuration settings (e.g.,
                  `requires_safety_checker=False`)

        Raises:
            ValueError: If a component object is not supported in ComponentSpec.from_component() method:
                - nn.Module components without a valid `_diffusers_load_id` attribute
                - Non-ConfigMixin components without a valid `_diffusers_load_id` attribute
            **kwargs: Component objects or configuration values to update:
                - Component objects: Models loaded with `AutoModel.from_pretrained()` or `ComponentSpec.load()`
                  are automatically tagged with loading information. ConfigMixin objects without weights (e.g.,
                  schedulers, guiders) can be passed directly.
                - Configuration values: Simple values to update configuration settings
                  (e.g., `requires_safety_checker=False`)

        Examples:
            ```python
            # Update multiple components at once
            # Update pretrained models
            pipeline.update_components(unet=new_unet_model, text_encoder=new_text_encoder)

            # Update configuration values
            pipeline.update_components(requires_safety_checker=False)

            # Update both components and configs together
            pipeline.update_components(unet=new_unet_model, requires_safety_checker=False)

            # Update with ComponentSpec objects (from_config only)
            pipeline.update_components(
                guider=ComponentSpec(
                    name="guider",
                    type_hint=ClassifierFreeGuidance,
                    config={"guidance_scale": 5.0},
                    default_creation_method="from_config",
                )
            )
            ```

        Notes:
            - Components with trained weights must be created using ComponentSpec.load(). If the component has not been
              shared on the Hugging Face Hub and you don't have loading specs, you can upload it using `push_to_hub()`
            - ConfigMixin objects without weights (e.g., schedulers, guiders) can be passed directly
            - ComponentSpec objects with default_creation_method="from_pretrained" are not supported in
              update_components()
            - Components with trained weights should be loaded with `AutoModel.from_pretrained()` or
              `ComponentSpec.load()` so that loading specs are preserved for serialization.
            - ConfigMixin objects without weights (e.g., schedulers, guiders) can be passed directly.
        """

        # extract component_specs_updates & config_specs_updates from `specs`
        passed_component_specs = {
            k: kwargs.pop(k) for k in self._component_specs if k in kwargs and isinstance(kwargs[k], ComponentSpec)
        }
        passed_components = {
            k: kwargs.pop(k) for k in self._component_specs if k in kwargs and not isinstance(kwargs[k], ComponentSpec)
        }
        passed_components = {k: kwargs.pop(k) for k in self._component_specs if k in kwargs}
        passed_config_values = {k: kwargs.pop(k) for k in self._config_specs if k in kwargs}

        for name, component in passed_components.items():
@@ -2147,33 +2077,14 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
        if len(kwargs) > 0:
            logger.warning(f"Unexpected keyword arguments, will be ignored: {kwargs.keys()}")

        created_components = {}
        for name, component_spec in passed_component_specs.items():
            if component_spec.default_creation_method == "from_pretrained":
                raise ValueError(
                    "ComponentSpec object with default_creation_method == 'from_pretrained' is not supported in update_components() method"
                )
            created_components[name] = component_spec.create()
            current_component_spec = self._component_specs[name]
            # warn if type changed
            if current_component_spec.type_hint is not None and not isinstance(
                created_components[name], current_component_spec.type_hint
            ):
                logger.info(
                    f"ModularPipeline.update_components: adding {name} with new type: {created_components[name].__class__.__name__}, previous type: {current_component_spec.type_hint.__name__}"
                )
            # update _component_specs based on the user passed component_spec
            self._component_specs[name] = component_spec
        self.register_components(**passed_components, **created_components)
        self.register_components(**passed_components)

        config_to_register = {}
        for name, new_value in passed_config_values.items():
            # e.g. requires_aesthetics_score = False
            self._config_specs[name].default = new_value
            config_to_register[name] = new_value
        self.register_to_config(**config_to_register)

    # YiYi TODO: support map for additional from_pretrained kwargs
    def load_components(self, names: Optional[Union[List[str], str]] = None, **kwargs):
        """
        Load selected components from specs.
@@ -21,16 +21,16 @@ except OptionalDependencyNotAvailable:

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["modular_blocks_wan"] = ["WanBlocks"]
    _import_structure["modular_blocks_wan22"] = ["Wan22Blocks"]
    _import_structure["modular_blocks_wan22_i2v"] = ["Wan22Image2VideoBlocks"]
    _import_structure["modular_blocks_wan_i2v"] = ["WanImage2VideoAutoBlocks"]
    _import_structure["modular_pipeline"] = [
        "Wan22Image2VideoModularPipeline",
        "Wan22ModularPipeline",
        "WanImage2VideoModularPipeline",
        "WanModularPipeline",
    _import_structure["decoders"] = ["WanImageVaeDecoderStep"]
    _import_structure["encoders"] = ["WanTextEncoderStep"]
    _import_structure["modular_blocks"] = [
        "ALL_BLOCKS",
        "Wan22AutoBlocks",
        "WanAutoBlocks",
        "WanAutoImageEncoderStep",
        "WanAutoVaeImageEncoderStep",
    ]
    _import_structure["modular_pipeline"] = ["WanModularPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
@@ -39,16 +39,16 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
    else:
        from .modular_blocks_wan import WanBlocks
        from .modular_blocks_wan22 import Wan22Blocks
        from .modular_blocks_wan22_i2v import Wan22Image2VideoBlocks
        from .modular_blocks_wan_i2v import WanImage2VideoAutoBlocks
        from .modular_pipeline import (
            Wan22Image2VideoModularPipeline,
            Wan22ModularPipeline,
            WanImage2VideoModularPipeline,
            WanModularPipeline,
        from .decoders import WanImageVaeDecoderStep
        from .encoders import WanTextEncoderStep
        from .modular_blocks import (
            ALL_BLOCKS,
            Wan22AutoBlocks,
            WanAutoBlocks,
            WanAutoImageEncoderStep,
            WanAutoVaeImageEncoderStep,
        )
        from .modular_pipeline import WanModularPipeline
else:
    import sys
@@ -280,7 +280,7 @@ class WanAdditionalInputsStep(ModularPipelineBlocks):

    def __init__(
        self,
        image_latent_inputs: List[str] = ["image_condition_latents"],
        image_latent_inputs: List[str] = ["first_frame_latents"],
        additional_batch_inputs: List[str] = [],
    ):
        """Initialize a configurable step that standardizes the inputs for the denoising step. It:\n"
@@ -294,16 +294,20 @@ class WanAdditionalInputsStep(ModularPipelineBlocks):
        Args:
            image_latent_inputs (List[str], optional): Names of image latent tensors to process.
                In addition to adjusting the batch size of these inputs, they are used to determine height/width. Can be
                a single string or list of strings. Defaults to ["image_condition_latents"].
                a single string or list of strings. Defaults to ["first_frame_latents"].
            additional_batch_inputs (List[str], optional):
                Names of additional conditional input tensors to expand batch size. These tensors will only have their
                batch dimensions adjusted to match the final batch size. Can be a single string or list of strings.
                Defaults to [].

        Examples:
            # Configure to process image_condition_latents (default behavior) WanAdditionalInputsStep() # Configure to
            process image latents and additional batch inputs WanAdditionalInputsStep(
                image_latent_inputs=["image_condition_latents"], additional_batch_inputs=["image_embeds"]
            # Configure to process first_frame_latents (default behavior) WanAdditionalInputsStep()

            # Configure to process multiple image latent inputs
            WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents", "last_frame_latents"])

            # Configure to process image latents and additional batch inputs WanAdditionalInputsStep(
                image_latent_inputs=["first_frame_latents"], additional_batch_inputs=["image_embeds"]
            )
        """
        if not isinstance(image_latent_inputs, list):
@@ -553,3 +557,81 @@ class WanPrepareLatentsStep(ModularPipelineBlocks):
        self.set_block_state(state, block_state)

        return components, state


class WanPrepareFirstFrameLatentsStep(ModularPipelineBlocks):
    model_name = "wan"

    @property
    def description(self) -> str:
        return "step that prepares the masked first frame latents and add it to the latent condition"

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("first_frame_latents", type_hint=Optional[torch.Tensor]),
            InputParam("num_frames", type_hint=int),
        ]

    def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        batch_size, _, _, latent_height, latent_width = block_state.first_frame_latents.shape

        mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width)
        mask_lat_size[:, :, list(range(1, block_state.num_frames))] = 0

        first_frame_mask = mask_lat_size[:, :, 0:1]
        first_frame_mask = torch.repeat_interleave(
            first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal
        )
        mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
        mask_lat_size = mask_lat_size.view(
            batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width
        )
        mask_lat_size = mask_lat_size.transpose(1, 2)
        mask_lat_size = mask_lat_size.to(block_state.first_frame_latents.device)
        block_state.first_frame_latents = torch.concat([mask_lat_size, block_state.first_frame_latents], dim=1)

        self.set_block_state(state, block_state)
        return components, state


class WanPrepareFirstLastFrameLatentsStep(ModularPipelineBlocks):
    model_name = "wan"

    @property
    def description(self) -> str:
        return "step that prepares the masked latents with first and last frames and add it to the latent condition"

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("first_last_frame_latents", type_hint=Optional[torch.Tensor]),
            InputParam("num_frames", type_hint=int),
        ]

    def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        batch_size, _, _, latent_height, latent_width = block_state.first_last_frame_latents.shape

        mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width)
        mask_lat_size[:, :, list(range(1, block_state.num_frames - 1))] = 0

        first_frame_mask = mask_lat_size[:, :, 0:1]
        first_frame_mask = torch.repeat_interleave(
            first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal
        )
        mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
        mask_lat_size = mask_lat_size.view(
            batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width
        )
        mask_lat_size = mask_lat_size.transpose(1, 2)
        mask_lat_size = mask_lat_size.to(block_state.first_last_frame_latents.device)
        block_state.first_last_frame_latents = torch.concat(
            [mask_lat_size, block_state.first_last_frame_latents], dim=1
        )

        self.set_block_state(state, block_state)
        return components, state
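To make the mask reshaping above concrete (assuming the usual Wan defaults of `vae_scale_factor_temporal=4` and `num_frames=81`): the mask starts as `[B, 1, 81, h, w]`; the first-frame slice is repeated 4 times along the temporal axis and concatenated with the remaining 80 entries, giving 84; the `view` to `[B, -1, 4, h, w]` folds that into 21 latent frames, matching the `(num_frames - 1) / 4 + 1` latent length the denoiser expects; and the `transpose` produces a `[B, 4, 21, h, w]` mask that can be concatenated with the frame latents along the channel dimension.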
@@ -29,7 +29,7 @@ from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class WanVaeDecoderStep(ModularPipelineBlocks):
class WanImageVaeDecoderStep(ModularPipelineBlocks):
    model_name = "wan"

    @property
@@ -56,10 +56,7 @@ class WanVaeDecoderStep(ModularPipelineBlocks):
                required=True,
                type_hint=torch.Tensor,
                description="The denoised latents from the denoising step",
            ),
            InputParam(
                "output_type", default="np", type_hint=str, description="The output type of the decoded videos"
            ),
            )
        ]

    @property
@@ -90,8 +87,7 @@ class WanVaeDecoderStep(ModularPipelineBlocks):
        latents = latents.to(vae_dtype)
        block_state.videos = components.vae.decode(latents, return_dict=False)[0]

        output_type = getattr(block_state, "output_type", "np")
        block_state.videos = components.video_processor.postprocess_video(block_state.videos, output_type=output_type)
        block_state.videos = components.video_processor.postprocess_video(block_state.videos, output_type="np")

        self.set_block_state(state, block_state)
@@ -89,10 +89,52 @@ class WanImage2VideoLoopBeforeDenoiser(ModularPipelineBlocks):
                description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
            ),
            InputParam(
                "image_condition_latents",
                "first_frame_latents",
                required=True,
                type_hint=torch.Tensor,
                description="The image condition latents to use for the denoising process. Can be generated in prepare_first_frame_latents/prepare_first_last_frame_latents step.",
                description="The first frame latents to use for the denoising process. Can be generated in prepare_first_frame_latents step.",
            ),
            InputParam(
                "dtype",
                required=True,
                type_hint=torch.dtype,
                description="The dtype of the model inputs. Can be generated in input step.",
            ),
        ]

    @torch.no_grad()
    def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
        block_state.latent_model_input = torch.cat([block_state.latents, block_state.first_frame_latents], dim=1).to(
            block_state.dtype
        )
        return components, block_state


class WanFLF2VLoopBeforeDenoiser(ModularPipelineBlocks):
    model_name = "wan"

    @property
    def description(self) -> str:
        return (
            "step within the denoising loop that prepares the latent input for the denoiser. "
            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
            "object (e.g. `WanDenoiseLoopWrapper`)"
        )

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "latents",
                required=True,
                type_hint=torch.Tensor,
                description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
            ),
            InputParam(
                "first_last_frame_latents",
                required=True,
                type_hint=torch.Tensor,
                description="The first and last frame latents to use for the denoising process. Can be generated in prepare_first_last_frame_latents step.",
            ),
            InputParam(
                "dtype",
@@ -105,7 +147,7 @@ class WanImage2VideoLoopBeforeDenoiser(ModularPipelineBlocks):
    @torch.no_grad()
    def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
        block_state.latent_model_input = torch.cat(
            [block_state.latents, block_state.image_condition_latents], dim=1
            [block_state.latents, block_state.first_last_frame_latents], dim=1
        ).to(block_state.dtype)
        return components, block_state
@@ -542,3 +584,29 @@ class Wan22Image2VideoDenoiseStep(WanDenoiseLoopWrapper):
            " - `WanLoopAfterDenoiser`\n"
            "This block supports image-to-video tasks for Wan2.2."
        )


class WanFLF2VDenoiseStep(WanDenoiseLoopWrapper):
    block_classes = [
        WanFLF2VLoopBeforeDenoiser,
        WanLoopDenoiser(
            guider_input_fields={
                "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
                "encoder_hidden_states_image": "image_embeds",
            }
        ),
        WanLoopAfterDenoiser,
    ]
    block_names = ["before_denoiser", "denoiser", "after_denoiser"]

    @property
    def description(self) -> str:
        return (
            "Denoise step that iteratively denoise the latents. \n"
            "Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n"
            "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
            " - `WanFLF2VLoopBeforeDenoiser`\n"
            " - `WanLoopDenoiser`\n"
            " - `WanLoopAfterDenoiser`\n"
            "This block supports FLF2V tasks for wan2.1."
        )
@@ -468,7 +468,7 @@ class WanFirstLastFrameImageEncoderStep(ModularPipelineBlocks):
        return components, state


class WanVaeEncoderStep(ModularPipelineBlocks):
class WanVaeImageEncoderStep(ModularPipelineBlocks):
    model_name = "wan"

    @property
@@ -493,7 +493,7 @@ class WanVaeEncoderStep(ModularPipelineBlocks):
            InputParam("resized_image", type_hint=PIL.Image.Image, required=True),
            InputParam("height"),
            InputParam("width"),
            InputParam("num_frames", type_hint=int, default=81),
            InputParam("num_frames"),
            InputParam("generator"),
        ]

@@ -564,51 +564,7 @@ class WanVaeEncoderStep(ModularPipelineBlocks):
        return components, state


class WanPrepareFirstFrameLatentsStep(ModularPipelineBlocks):
    model_name = "wan"

    @property
    def description(self) -> str:
        return "step that prepares the masked first frame latents and add it to the latent condition"

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("first_frame_latents", type_hint=Optional[torch.Tensor]),
            InputParam("num_frames", required=True),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam("image_condition_latents", type_hint=Optional[torch.Tensor]),
        ]

    def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        batch_size, _, _, latent_height, latent_width = block_state.first_frame_latents.shape

        mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width)
        mask_lat_size[:, :, list(range(1, block_state.num_frames))] = 0

        first_frame_mask = mask_lat_size[:, :, 0:1]
        first_frame_mask = torch.repeat_interleave(
            first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal
        )
        mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
        mask_lat_size = mask_lat_size.view(
            batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width
        )
        mask_lat_size = mask_lat_size.transpose(1, 2)
        mask_lat_size = mask_lat_size.to(block_state.first_frame_latents.device)
        block_state.image_condition_latents = torch.concat([mask_lat_size, block_state.first_frame_latents], dim=1)

        self.set_block_state(state, block_state)
        return components, state


class WanFirstLastFrameVaeEncoderStep(ModularPipelineBlocks):
class WanFirstLastFrameVaeImageEncoderStep(ModularPipelineBlocks):
    model_name = "wan"

    @property
@@ -634,7 +590,7 @@ class WanFirstLastFrameVaeEncoderStep(ModularPipelineBlocks):
            InputParam("resized_last_image", type_hint=PIL.Image.Image, required=True),
            InputParam("height"),
            InputParam("width"),
            InputParam("num_frames", type_hint=int, default=81),
            InputParam("num_frames"),
            InputParam("generator"),
        ]

@@ -711,49 +667,3 @@ class WanFirstLastFrameVaeEncoderStep(ModularPipelineBlocks):

        self.set_block_state(state, block_state)
        return components, state


class WanPrepareFirstLastFrameLatentsStep(ModularPipelineBlocks):
    model_name = "wan"

    @property
    def description(self) -> str:
        return "step that prepares the masked latents with first and last frames and add it to the latent condition"

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("first_last_frame_latents", type_hint=Optional[torch.Tensor]),
            InputParam("num_frames", type_hint=int, required=True),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam("image_condition_latents", type_hint=Optional[torch.Tensor]),
        ]

    def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        batch_size, _, _, latent_height, latent_width = block_state.first_last_frame_latents.shape

        mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width)
        mask_lat_size[:, :, list(range(1, block_state.num_frames - 1))] = 0

        first_frame_mask = mask_lat_size[:, :, 0:1]
        first_frame_mask = torch.repeat_interleave(
            first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal
        )
        mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
        mask_lat_size = mask_lat_size.view(
            batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width
        )
        mask_lat_size = mask_lat_size.transpose(1, 2)
        mask_lat_size = mask_lat_size.to(block_state.first_last_frame_latents.device)
        block_state.image_condition_latents = torch.concat(
            [mask_lat_size, block_state.first_last_frame_latents], dim=1
        )

        self.set_block_state(state, block_state)
        return components, state
src/diffusers/modular_pipelines/wan/modular_blocks.py (new file, 474 lines)
@@ -0,0 +1,474 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict
from .before_denoise import (
    WanAdditionalInputsStep,
    WanPrepareFirstFrameLatentsStep,
    WanPrepareFirstLastFrameLatentsStep,
    WanPrepareLatentsStep,
    WanSetTimestepsStep,
    WanTextInputStep,
)
from .decoders import WanImageVaeDecoderStep
from .denoise import (
    Wan22DenoiseStep,
    Wan22Image2VideoDenoiseStep,
    WanDenoiseStep,
    WanFLF2VDenoiseStep,
    WanImage2VideoDenoiseStep,
)
from .encoders import (
    WanFirstLastFrameImageEncoderStep,
    WanFirstLastFrameVaeImageEncoderStep,
    WanImageCropResizeStep,
    WanImageEncoderStep,
    WanImageResizeStep,
    WanTextEncoderStep,
    WanVaeImageEncoderStep,
)


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


# wan2.1
# wan2.1: text2vid
class WanCoreDenoiseStep(SequentialPipelineBlocks):
    block_classes = [
        WanTextInputStep,
        WanSetTimestepsStep,
        WanPrepareLatentsStep,
        WanDenoiseStep,
    ]
    block_names = ["input", "set_timesteps", "prepare_latents", "denoise"]

    @property
    def description(self):
        return (
            "denoise block that takes encoded conditions and runs the denoising process.\n"
            + "This is a sequential pipeline blocks:\n"
            + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
            + " - `WanSetTimestepsStep` is used to set the timesteps\n"
            + " - `WanPrepareLatentsStep` is used to prepare the latents\n"
            + " - `WanDenoiseStep` is used to denoise the latents\n"
        )


# wan2.1: image2video
## image encoder
class WanImage2VideoImageEncoderStep(SequentialPipelineBlocks):
    model_name = "wan"
    block_classes = [WanImageResizeStep, WanImageEncoderStep]
    block_names = ["image_resize", "image_encoder"]

    @property
    def description(self):
        return "Image2Video Image Encoder step that resize the image and encode the image to generate the image embeddings"


## vae encoder
class WanImage2VideoVaeImageEncoderStep(SequentialPipelineBlocks):
    model_name = "wan"
    block_classes = [WanImageResizeStep, WanVaeImageEncoderStep]
    block_names = ["image_resize", "vae_encoder"]

    @property
    def description(self):
        return "Image2Video Vae Image Encoder step that resize the image and encode the first frame image to its latent representation"


## denoise
class WanImage2VideoCoreDenoiseStep(SequentialPipelineBlocks):
    block_classes = [
        WanTextInputStep,
        WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents"]),
        WanSetTimestepsStep,
        WanPrepareLatentsStep,
        WanPrepareFirstFrameLatentsStep,
        WanImage2VideoDenoiseStep,
    ]
    block_names = [
        "input",
        "additional_inputs",
        "set_timesteps",
        "prepare_latents",
        "prepare_first_frame_latents",
        "denoise",
    ]

    @property
    def description(self):
        return (
            "denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
            + "This is a sequential pipeline blocks:\n"
            + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
            + " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
            + " - `WanSetTimestepsStep` is used to set the timesteps\n"
            + " - `WanPrepareLatentsStep` is used to prepare the latents\n"
            + " - `WanPrepareFirstFrameLatentsStep` is used to prepare the first frame latent conditions\n"
            + " - `WanImage2VideoDenoiseStep` is used to denoise the latents\n"
        )


# wan2.1: FLF2V
## image encoder
class WanFLF2VImageEncoderStep(SequentialPipelineBlocks):
    model_name = "wan"
    block_classes = [WanImageResizeStep, WanImageCropResizeStep, WanFirstLastFrameImageEncoderStep]
    block_names = ["image_resize", "last_image_resize", "image_encoder"]

    @property
    def description(self):
        return "FLF2V Image Encoder step that resizes and encodes the first and last frame images to generate the image embeddings"


## vae encoder
class WanFLF2VVaeImageEncoderStep(SequentialPipelineBlocks):
    model_name = "wan"
    block_classes = [WanImageResizeStep, WanImageCropResizeStep, WanFirstLastFrameVaeImageEncoderStep]
    block_names = ["image_resize", "last_image_resize", "vae_encoder"]

    @property
    def description(self):
        return "FLF2V Vae Image Encoder step that resizes and encodes the first and last frame images to generate the latent conditions"


## denoise
class WanFLF2VCoreDenoiseStep(SequentialPipelineBlocks):
    block_classes = [
        WanTextInputStep,
        WanAdditionalInputsStep(image_latent_inputs=["first_last_frame_latents"]),
        WanSetTimestepsStep,
        WanPrepareLatentsStep,
        WanPrepareFirstLastFrameLatentsStep,
        WanFLF2VDenoiseStep,
    ]
    block_names = [
        "input",
        "additional_inputs",
        "set_timesteps",
        "prepare_latents",
        "prepare_first_last_frame_latents",
        "denoise",
    ]

    @property
    def description(self):
        return (
            "denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
            + "This is a sequential pipeline blocks:\n"
            + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
            + " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
            + " - `WanSetTimestepsStep` is used to set the timesteps\n"
            + " - `WanPrepareLatentsStep` is used to prepare the latents\n"
            + " - `WanPrepareFirstLastFrameLatentsStep` is used to prepare the latent conditions\n"
            + " - `WanFLF2VDenoiseStep` is used to denoise the latents\n"
        )


# wan2.1: auto blocks
## image encoder
class WanAutoImageEncoderStep(AutoPipelineBlocks):
    block_classes = [WanFLF2VImageEncoderStep, WanImage2VideoImageEncoderStep]
    block_names = ["flf2v_image_encoder", "image2video_image_encoder"]
    block_trigger_inputs = ["last_image", "image"]

    @property
    def description(self):
        return (
            "Image Encoder step that encode the image to generate the image embeddings"
            + "This is an auto pipeline block that works for image2video tasks."
            + " - `WanFLF2VImageEncoderStep` (flf2v) is used when `last_image` is provided."
            + " - `WanImage2VideoImageEncoderStep` (image2video) is used when `image` is provided."
            + " - if `last_image` or `image` is not provided, step will be skipped."
        )


## vae encoder
class WanAutoVaeImageEncoderStep(AutoPipelineBlocks):
    block_classes = [WanFLF2VVaeImageEncoderStep, WanImage2VideoVaeImageEncoderStep]
    block_names = ["flf2v_vae_encoder", "image2video_vae_encoder"]
    block_trigger_inputs = ["last_image", "image"]

    @property
    def description(self):
        return (
            "Vae Image Encoder step that encode the image to generate the image latents"
            + "This is an auto pipeline block that works for image2video tasks."
            + " - `WanFLF2VVaeImageEncoderStep` (flf2v) is used when `last_image` is provided."
            + " - `WanImage2VideoVaeImageEncoderStep` (image2video) is used when `image` is provided."
            + " - if `last_image` or `image` is not provided, step will be skipped."
        )


## denoise
class WanAutoDenoiseStep(AutoPipelineBlocks):
    block_classes = [
        WanFLF2VCoreDenoiseStep,
        WanImage2VideoCoreDenoiseStep,
        WanCoreDenoiseStep,
    ]
    block_names = ["flf2v", "image2video", "text2video"]
    block_trigger_inputs = ["first_last_frame_latents", "first_frame_latents", None]

    @property
    def description(self) -> str:
        return (
            "Denoise step that iteratively denoise the latents. "
            "This is a auto pipeline block that works for text2video and image2video tasks."
            " - `WanCoreDenoiseStep` (text2video) for text2vid tasks."
            " - `WanImage2VideoCoreDenoiseStep` (image2video) for image2video tasks."
            + " - if `first_frame_latents` is provided, `WanImage2VideoCoreDenoiseStep` will be used.\n"
            + " - if `first_frame_latents` is not provided, `WanCoreDenoiseStep` will be used.\n"
        )


# auto pipeline blocks
class WanAutoBlocks(SequentialPipelineBlocks):
    block_classes = [
        WanTextEncoderStep,
        WanAutoImageEncoderStep,
        WanAutoVaeImageEncoderStep,
        WanAutoDenoiseStep,
        WanImageVaeDecoderStep,
    ]
    block_names = [
        "text_encoder",
        "image_encoder",
        "vae_encoder",
        "denoise",
        "decode",
    ]

    @property
    def description(self):
        return (
            "Auto Modular pipeline for text-to-video using Wan.\n"
            + "- for text-to-video generation, all you need to provide is `prompt`"
        )


# wan22
# wan2.2: text2vid
## denoise
class Wan22CoreDenoiseStep(SequentialPipelineBlocks):
    block_classes = [
        WanTextInputStep,
        WanSetTimestepsStep,
        WanPrepareLatentsStep,
        Wan22DenoiseStep,
    ]
    block_names = ["input", "set_timesteps", "prepare_latents", "denoise"]

    @property
    def description(self):
        return (
            "denoise block that takes encoded conditions and runs the denoising process.\n"
            + "This is a sequential pipeline blocks:\n"
            + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
            + " - `WanSetTimestepsStep` is used to set the timesteps\n"
            + " - `WanPrepareLatentsStep` is used to prepare the latents\n"
            + " - `Wan22DenoiseStep` is used to denoise the latents in wan2.2\n"
        )


# wan2.2: image2video
## denoise
class Wan22Image2VideoCoreDenoiseStep(SequentialPipelineBlocks):
    block_classes = [
        WanTextInputStep,
        WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents"]),
        WanSetTimestepsStep,
        WanPrepareLatentsStep,
        WanPrepareFirstFrameLatentsStep,
        Wan22Image2VideoDenoiseStep,
    ]
    block_names = [
        "input",
        "additional_inputs",
        "set_timesteps",
        "prepare_latents",
        "prepare_first_frame_latents",
        "denoise",
    ]

    @property
    def description(self):
        return (
            "denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
            + "This is a sequential pipeline blocks:\n"
            + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
            + " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
            + " - `WanSetTimestepsStep` is used to set the timesteps\n"
            + " - `WanPrepareLatentsStep` is used to prepare the latents\n"
            + " - `WanPrepareFirstFrameLatentsStep` is used to prepare the first frame latent conditions\n"
            + " - `Wan22Image2VideoDenoiseStep` is used to denoise the latents in wan2.2\n"
        )


class Wan22AutoDenoiseStep(AutoPipelineBlocks):
    block_classes = [
        Wan22Image2VideoCoreDenoiseStep,
        Wan22CoreDenoiseStep,
    ]
    block_names = ["image2video", "text2video"]
    block_trigger_inputs = ["first_frame_latents", None]

    @property
    def description(self) -> str:
        return (
            "Denoise step that iteratively denoise the latents. "
            "This is a auto pipeline block that works for text2video and image2video tasks."
            " - `Wan22Image2VideoCoreDenoiseStep` (image2video) for image2video tasks."
            " - `Wan22CoreDenoiseStep` (text2video) for text2vid tasks."
            + " - if `first_frame_latents` is provided, `Wan22Image2VideoCoreDenoiseStep` will be used.\n"
            + " - if `first_frame_latents` is not provided, `Wan22CoreDenoiseStep` will be used.\n"
        )


class Wan22AutoBlocks(SequentialPipelineBlocks):
    block_classes = [
        WanTextEncoderStep,
        WanAutoVaeImageEncoderStep,
        Wan22AutoDenoiseStep,
        WanImageVaeDecoderStep,
    ]
    block_names = [
        "text_encoder",
        "vae_encoder",
        "denoise",
        "decode",
    ]

    @property
    def description(self):
        return (
            "Auto Modular pipeline for text-to-video using Wan2.2.\n"
            + "- for text-to-video generation, all you need to provide is `prompt`"
        )


# presets for wan2.1 and wan2.2
# YiYi Notes: should we move these to doc?
# wan2.1
TEXT2VIDEO_BLOCKS = InsertableDict(
    [
        ("text_encoder", WanTextEncoderStep),
        ("input", WanTextInputStep),
        ("set_timesteps", WanSetTimestepsStep),
        ("prepare_latents", WanPrepareLatentsStep),
        ("denoise", WanDenoiseStep),
        ("decode", WanImageVaeDecoderStep),
    ]
)

IMAGE2VIDEO_BLOCKS = InsertableDict(
    [
        ("image_resize", WanImageResizeStep),
        ("image_encoder", WanImage2VideoImageEncoderStep),
        ("vae_encoder", WanImage2VideoVaeImageEncoderStep),
        ("input", WanTextInputStep),
        ("additional_inputs", WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents"])),
        ("set_timesteps", WanSetTimestepsStep),
        ("prepare_latents", WanPrepareLatentsStep),
        ("prepare_first_frame_latents", WanPrepareFirstFrameLatentsStep),
        ("denoise", WanImage2VideoDenoiseStep),
        ("decode", WanImageVaeDecoderStep),
    ]
)

FLF2V_BLOCKS = InsertableDict(
    [
        ("image_resize", WanImageResizeStep),
        ("last_image_resize", WanImageCropResizeStep),
        ("image_encoder", WanFLF2VImageEncoderStep),
        ("vae_encoder", WanFLF2VVaeImageEncoderStep),
        ("input", WanTextInputStep),
        ("additional_inputs", WanAdditionalInputsStep(image_latent_inputs=["first_last_frame_latents"])),
        ("set_timesteps", WanSetTimestepsStep),
        ("prepare_latents", WanPrepareLatentsStep),
        ("prepare_first_last_frame_latents", WanPrepareFirstLastFrameLatentsStep),
        ("denoise", WanFLF2VDenoiseStep),
        ("decode", WanImageVaeDecoderStep),
    ]
)

AUTO_BLOCKS = InsertableDict(
    [
        ("text_encoder", WanTextEncoderStep),
        ("image_encoder", WanAutoImageEncoderStep),
        ("vae_encoder", WanAutoVaeImageEncoderStep),
        ("denoise", WanAutoDenoiseStep),
        ("decode", WanImageVaeDecoderStep),
    ]
)

# wan2.2 presets
TEXT2VIDEO_BLOCKS_WAN22 = InsertableDict(
    [
        ("text_encoder", WanTextEncoderStep),
        ("input", WanTextInputStep),
        ("set_timesteps", WanSetTimestepsStep),
        ("prepare_latents", WanPrepareLatentsStep),
        ("denoise", Wan22DenoiseStep),
        ("decode", WanImageVaeDecoderStep),
    ]
)

IMAGE2VIDEO_BLOCKS_WAN22 = InsertableDict(
    [
        ("image_resize", WanImageResizeStep),
        ("vae_encoder", WanImage2VideoVaeImageEncoderStep),
        ("input", WanTextInputStep),
        ("set_timesteps", WanSetTimestepsStep),
        ("prepare_latents", WanPrepareLatentsStep),
        ("denoise", Wan22DenoiseStep),
        ("decode", WanImageVaeDecoderStep),
    ]
)

AUTO_BLOCKS_WAN22 = InsertableDict(
    [
        ("text_encoder", WanTextEncoderStep),
        ("vae_encoder", WanAutoVaeImageEncoderStep),
        ("denoise", Wan22AutoDenoiseStep),
        ("decode", WanImageVaeDecoderStep),
    ]
)

# presets all blocks (wan and wan22)
ALL_BLOCKS = {
    "wan2.1": {
        "text2video": TEXT2VIDEO_BLOCKS,
        "image2video": IMAGE2VIDEO_BLOCKS,
        "flf2v": FLF2V_BLOCKS,
        "auto": AUTO_BLOCKS,
    },
    "wan2.2": {
        "text2video": TEXT2VIDEO_BLOCKS_WAN22,
        "image2video": IMAGE2VIDEO_BLOCKS_WAN22,
        "auto": AUTO_BLOCKS_WAN22,
    },
}
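For orientation, a minimal usage sketch of the presets above, following the `from_blocks_dict` pattern used elsewhere in Modular Diffusers; the repository id, dtype handling, and output key are assumptions for illustration, not tested values:

```python
import torch

from diffusers.modular_pipelines import SequentialPipelineBlocks
from diffusers.modular_pipelines.wan import ALL_BLOCKS

# Assemble the Wan 2.1 text-to-video preset into a runnable pipeline.
t2v_blocks = SequentialPipelineBlocks.from_blocks_dict(ALL_BLOCKS["wan2.1"]["text2video"])
pipe = t2v_blocks.init_pipeline("Wan-AI/Wan2.1-T2V-1.3B-Diffusers")  # placeholder repo id
pipe.load_components(torch_dtype=torch.bfloat16)
pipe.to("cuda")

# `output="videos"` matches the name produced by WanImageVaeDecoderStep above.
video = pipe(prompt="a cat surfing a wave", num_frames=81, output="videos")
```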
@@ -1,83 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ...utils import logging
from ..modular_pipeline import SequentialPipelineBlocks
from .before_denoise import (
    WanPrepareLatentsStep,
    WanSetTimestepsStep,
    WanTextInputStep,
)
from .decoders import WanVaeDecoderStep
from .denoise import (
    WanDenoiseStep,
)
from .encoders import (
    WanTextEncoderStep,
)


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


# ====================
# 1. DENOISE
# ====================


# inputs(text) -> set_timesteps -> prepare_latents -> denoise
class WanCoreDenoiseStep(SequentialPipelineBlocks):
    model_name = "wan"
    block_classes = [
        WanTextInputStep,
        WanSetTimestepsStep,
        WanPrepareLatentsStep,
        WanDenoiseStep,
    ]
    block_names = ["input", "set_timesteps", "prepare_latents", "denoise"]

    @property
    def description(self):
        return (
            "denoise block that takes encoded conditions and runs the denoising process.\n"
            + "This is a sequential pipeline blocks:\n"
            + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
            + " - `WanSetTimestepsStep` is used to set the timesteps\n"
            + " - `WanPrepareLatentsStep` is used to prepare the latents\n"
            + " - `WanDenoiseStep` is used to denoise the latents\n"
        )


# ====================
# 2. BLOCKS (Wan2.1 text2video)
# ====================


class WanBlocks(SequentialPipelineBlocks):
    model_name = "wan"
    block_classes = [
        WanTextEncoderStep,
        WanCoreDenoiseStep,
        WanVaeDecoderStep,
    ]
    block_names = ["text_encoder", "denoise", "decode"]

    @property
    def description(self):
        return (
            "Modular pipeline blocks for Wan2.1.\n"
            + "- `WanTextEncoderStep` is used to encode the text\n"
            + "- `WanCoreDenoiseStep` is used to denoise the latents\n"
            + "- `WanVaeDecoderStep` is used to decode the latents to images"
        )
@@ -1,88 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ...utils import logging
from ..modular_pipeline import SequentialPipelineBlocks
from .before_denoise import (
    WanPrepareLatentsStep,
    WanSetTimestepsStep,
    WanTextInputStep,
)
from .decoders import WanVaeDecoderStep
from .denoise import (
    Wan22DenoiseStep,
)
from .encoders import (
    WanTextEncoderStep,
)


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


# ====================
# 1. DENOISE
# ====================

# inputs(text) -> set_timesteps -> prepare_latents -> denoise


class Wan22CoreDenoiseStep(SequentialPipelineBlocks):
    model_name = "wan"
    block_classes = [
        WanTextInputStep,
        WanSetTimestepsStep,
        WanPrepareLatentsStep,
        Wan22DenoiseStep,
    ]
    block_names = ["input", "set_timesteps", "prepare_latents", "denoise"]

    @property
    def description(self):
        return (
            "denoise block that takes encoded conditions and runs the denoising process.\n"
            + "This is a sequential pipeline blocks:\n"
            + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
            + " - `WanSetTimestepsStep` is used to set the timesteps\n"
            + " - `WanPrepareLatentsStep` is used to prepare the latents\n"
            + " - `Wan22DenoiseStep` is used to denoise the latents in wan2.2\n"
        )


# ====================
# 2. BLOCKS (Wan2.2 text2video)
# ====================


class Wan22Blocks(SequentialPipelineBlocks):
    model_name = "wan"
    block_classes = [
        WanTextEncoderStep,
        Wan22CoreDenoiseStep,
        WanVaeDecoderStep,
    ]
    block_names = [
        "text_encoder",
        "denoise",
        "decode",
    ]

    @property
    def description(self):
        return (
            "Modular pipeline for text-to-video using Wan2.2.\n"
            + " - `WanTextEncoderStep` encodes the text\n"
            + " - `Wan22CoreDenoiseStep` denoises the latents\n"
            + " - `WanVaeDecoderStep` decodes the latents to video frames\n"
        )
@@ -1,117 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ...utils import logging
from ..modular_pipeline import SequentialPipelineBlocks
from .before_denoise import (
    WanAdditionalInputsStep,
    WanPrepareLatentsStep,
    WanSetTimestepsStep,
    WanTextInputStep,
)
from .decoders import WanVaeDecoderStep
from .denoise import (
    Wan22Image2VideoDenoiseStep,
)
from .encoders import (
    WanImageResizeStep,
    WanPrepareFirstFrameLatentsStep,
    WanTextEncoderStep,
    WanVaeEncoderStep,
)


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


# ====================
# 1. VAE ENCODER
# ====================


class WanImage2VideoVaeEncoderStep(SequentialPipelineBlocks):
    model_name = "wan-i2v"
    block_classes = [WanImageResizeStep, WanVaeEncoderStep, WanPrepareFirstFrameLatentsStep]
    block_names = ["image_resize", "vae_encoder", "prepare_first_frame_latents"]

    @property
    def description(self):
        return "Image2Video Vae Image Encoder step that resize the image and encode the first frame image to its latent representation"


# ====================
# 2. DENOISE
# ====================


# inputs (text + image_condition_latents) -> set_timesteps -> prepare_latents -> denoise (latents)
class Wan22Image2VideoCoreDenoiseStep(SequentialPipelineBlocks):
    model_name = "wan-i2v"
    block_classes = [
        WanTextInputStep,
        WanAdditionalInputsStep(image_latent_inputs=["image_condition_latents"]),
        WanSetTimestepsStep,
        WanPrepareLatentsStep,
        Wan22Image2VideoDenoiseStep,
    ]
    block_names = [
        "input",
        "additional_inputs",
        "set_timesteps",
        "prepare_latents",
        "denoise",
    ]

    @property
    def description(self):
        return (
            "denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
            + "This is a sequential pipeline blocks:\n"
            + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
            + " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
            + " - `WanSetTimestepsStep` is used to set the timesteps\n"
            + " - `WanPrepareLatentsStep` is used to prepare the latents\n"
            + " - `Wan22Image2VideoDenoiseStep` is used to denoise the latents in wan2.2\n"
        )


# ====================
# 3. BLOCKS (Wan2.2 Image2Video)
# ====================


class Wan22Image2VideoBlocks(SequentialPipelineBlocks):
    model_name = "wan-i2v"
    block_classes = [
        WanTextEncoderStep,
        WanImage2VideoVaeEncoderStep,
        Wan22Image2VideoCoreDenoiseStep,
        WanVaeDecoderStep,
    ]
    block_names = [
        "text_encoder",
        "vae_encoder",
        "denoise",
        "decode",
    ]

    @property
    def description(self):
        return (
            "Modular pipeline for image-to-video using Wan2.2.\n"
            + " - `WanTextEncoderStep` encodes the text\n"
            + " - `WanImage2VideoVaeEncoderStep` encodes the image\n"
            + " - `Wan22Image2VideoCoreDenoiseStep` denoises the latents\n"
            + " - `WanVaeDecoderStep` decodes the latents to video frames\n"
        )
@@ -1,203 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
|
||||
from .before_denoise import (
|
||||
WanAdditionalInputsStep,
|
||||
WanPrepareLatentsStep,
|
||||
WanSetTimestepsStep,
|
||||
WanTextInputStep,
|
||||
)
|
||||
from .decoders import WanVaeDecoderStep
|
||||
from .denoise import (
|
||||
WanImage2VideoDenoiseStep,
|
||||
)
|
||||
from .encoders import (
|
||||
WanFirstLastFrameImageEncoderStep,
|
||||
WanFirstLastFrameVaeEncoderStep,
|
||||
WanImageCropResizeStep,
|
||||
WanImageEncoderStep,
|
||||
WanImageResizeStep,
|
||||
WanPrepareFirstFrameLatentsStep,
|
||||
WanPrepareFirstLastFrameLatentsStep,
|
||||
WanTextEncoderStep,
|
||||
WanVaeEncoderStep,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
# ====================
|
||||
# 1. IMAGE ENCODER
|
||||
# ====================
|
||||
|
||||
|
||||
# wan2.1 I2V (first frame only)
|
||||
class WanImage2VideoImageEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "wan-i2v"
|
||||
block_classes = [WanImageResizeStep, WanImageEncoderStep]
|
||||
block_names = ["image_resize", "image_encoder"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Image2Video Image Encoder step that resize the image and encode the image to generate the image embeddings"
|
||||
|
||||
|
||||
# wan2.1 FLF2V (first and last frame)
|
||||
class WanFLF2VImageEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "wan-i2v"
|
||||
block_classes = [WanImageResizeStep, WanImageCropResizeStep, WanFirstLastFrameImageEncoderStep]
|
||||
block_names = ["image_resize", "last_image_resize", "image_encoder"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "FLF2V Image Encoder step that resize and encode and encode the first and last frame images to generate the image embeddings"
|
||||
|
||||
|
||||
# wan2.1 Auto Image Encoder
|
||||
class WanAutoImageEncoderStep(AutoPipelineBlocks):
|
||||
block_classes = [WanFLF2VImageEncoderStep, WanImage2VideoImageEncoderStep]
|
||||
block_names = ["flf2v_image_encoder", "image2video_image_encoder"]
|
||||
block_trigger_inputs = ["last_image", "image"]
|
||||
model_name = "wan-i2v"
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Image Encoder step that encode the image to generate the image embeddings"
|
||||
+ "This is an auto pipeline block that works for image2video tasks."
|
||||
+ " - `WanFLF2VImageEncoderStep` (flf2v) is used when `last_image` is provided."
|
||||
+ " - `WanImage2VideoImageEncoderStep` (image2video) is used when `image` is provided."
|
||||
+ " - if `last_image` or `image` is not provided, step will be skipped."
|
||||
)
|
||||
|
||||
|
||||
# ====================
|
||||
# 2. VAE ENCODER
|
||||
# ====================
|
||||
|
||||
|
||||
# wan2.1 I2V (first frame only)
|
||||
class WanImage2VideoVaeEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "wan-i2v"
|
||||
block_classes = [WanImageResizeStep, WanVaeEncoderStep, WanPrepareFirstFrameLatentsStep]
|
||||
block_names = ["image_resize", "vae_encoder", "prepare_first_frame_latents"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Image2Video Vae Image Encoder step that resize the image and encode the first frame image to its latent representation"
|
||||
|
||||
|
||||
# wan2.1 FLF2V (first and last frame)
class WanFLF2VVaeEncoderStep(SequentialPipelineBlocks):
    model_name = "wan-i2v"
    block_classes = [
        WanImageResizeStep,
        WanImageCropResizeStep,
        WanFirstLastFrameVaeEncoderStep,
        WanPrepareFirstLastFrameLatentsStep,
    ]
    block_names = ["image_resize", "last_image_resize", "vae_encoder", "prepare_first_last_frame_latents"]

    @property
    def description(self):
        return "FLF2V Vae Image Encoder step that resizes and encodes the first and last frame images to generate the latent conditions"

# wan2.1 Auto Vae Encoder
class WanAutoVaeEncoderStep(AutoPipelineBlocks):
    model_name = "wan-i2v"
    block_classes = [WanFLF2VVaeEncoderStep, WanImage2VideoVaeEncoderStep]
    block_names = ["flf2v_vae_encoder", "image2video_vae_encoder"]
    block_trigger_inputs = ["last_image", "image"]

    @property
    def description(self):
        return (
            "Vae Image Encoder step that encodes the image to generate the image latents.\n"
            + "This is an auto pipeline block that works for image2video tasks.\n"
            + " - `WanFLF2VVaeEncoderStep` (flf2v) is used when `last_image` is provided.\n"
            + " - `WanImage2VideoVaeEncoderStep` (image2video) is used when `image` is provided.\n"
            + " - if neither `last_image` nor `image` is provided, this step will be skipped."
        )

# ====================
# 3. DENOISE (inputs -> set_timesteps -> prepare_latents -> denoise)
# ====================


# wan2.1 I2V core denoise (supports both I2V and FLF2V)
# inputs (text + image_condition_latents) -> set_timesteps -> prepare_latents -> denoise (latents)
class WanImage2VideoCoreDenoiseStep(SequentialPipelineBlocks):
    model_name = "wan-i2v"
    block_classes = [
        WanTextInputStep,
        WanAdditionalInputsStep(image_latent_inputs=["image_condition_latents"]),
        WanSetTimestepsStep,
        WanPrepareLatentsStep,
        WanImage2VideoDenoiseStep,
    ]
    block_names = [
        "input",
        "additional_inputs",
        "set_timesteps",
        "prepare_latents",
        "denoise",
    ]

    @property
    def description(self):
        return (
            "Denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
            + "This is a sequential pipeline block:\n"
            + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
            + " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
            + " - `WanSetTimestepsStep` is used to set the timesteps\n"
            + " - `WanPrepareLatentsStep` is used to prepare the latents\n"
            + " - `WanImage2VideoDenoiseStep` is used to denoise the latents\n"
        )

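# Hedged sketch: inspecting the composition of the core denoise step defined above. The
# import path is an assumption; `block_names` and `description` are the class attributes
# defined in this module. The sub-blocks run strictly in order:
# input -> additional_inputs -> set_timesteps -> prepare_latents -> denoise.
from diffusers.modular_pipelines.wan import WanImage2VideoCoreDenoiseStep

core_denoise = WanImage2VideoCoreDenoiseStep()
print(core_denoise.block_names)
print(core_denoise.description)
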
# ====================
# 4. BLOCKS (Wan2.1 Image2Video)
# ====================


# wan2.1 Image2Video Auto Blocks
class WanImage2VideoAutoBlocks(SequentialPipelineBlocks):
    model_name = "wan-i2v"
    block_classes = [
        WanTextEncoderStep,
        WanAutoImageEncoderStep,
        WanAutoVaeEncoderStep,
        WanImage2VideoCoreDenoiseStep,
        WanVaeDecoderStep,
    ]
    block_names = [
        "text_encoder",
        "image_encoder",
        "vae_encoder",
        "denoise",
        "decode",
    ]

    @property
    def description(self):
        return (
            "Auto Modular pipeline for image-to-video using Wan.\n"
            + "- for the I2V workflow, all you need to provide is `image`\n"
            + "- for the FLF2V workflow, all you need to provide is `last_image` and `image`"
        )

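# Hedged end-to-end sketch for the auto blocks above. The repo id, the input/output
# names ("prompt", "image", "last_image", "videos"), the image path, and the dtype are
# illustrative assumptions; only `init_pipeline`, `load_components`, and `load_image`
# are taken from the existing diffusers API.
import torch
from diffusers.modular_pipelines.wan import WanImage2VideoAutoBlocks
from diffusers.utils import load_image

blocks = WanImage2VideoAutoBlocks()
pipe = blocks.init_pipeline("Wan-AI/Wan2.1-I2V-14B-480P-Diffusers")  # illustrative repo id
pipe.load_components(torch_dtype=torch.bfloat16)

first_frame = load_image("path/to/first_frame.png")  # illustrative path
# I2V: provide only `image`; FLF2V: additionally provide `last_image`.
video = pipe(prompt="a cat surfing", image=first_frame, output="videos")
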
@@ -13,6 +13,8 @@
# limitations under the License.


from typing import Any, Dict, Optional

from ...loaders import WanLoraLoaderMixin
from ...pipelines.pipeline_utils import StableDiffusionMixin
from ...utils import logging
@@ -28,12 +30,19 @@ class WanModularPipeline(
    WanLoraLoaderMixin,
):
    """
    A ModularPipeline for Wan2.1 text2video.
    A ModularPipeline for Wan.

    > [!WARNING] > This is an experimental feature and is likely to change in the future.
    """

    default_blocks_name = "WanBlocks"
    default_blocks_name = "WanAutoBlocks"

    # overrides get_default_blocks_name from the base class, which just returns self.default_blocks_name
    def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]:
        if config_dict is not None and "boundary_ratio" in config_dict and config_dict["boundary_ratio"] is not None:
            return "Wan22AutoBlocks"
        else:
            return "WanAutoBlocks"

    @property
    def default_height(self):
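# Hedged sketch of the default-blocks selection above: a config carrying a non-null
# `boundary_ratio` (the Wan 2.2 two-expert setup) maps to Wan22AutoBlocks, anything
# else falls back to WanAutoBlocks. `pipe` is assumed to be an existing WanModularPipeline
# instance and the config dicts below are illustrative, not real model configs.
print(pipe.get_default_blocks_name({"boundary_ratio": 0.875}))  # -> "Wan22AutoBlocks"
print(pipe.get_default_blocks_name({"boundary_ratio": None}))   # -> "WanAutoBlocks"
print(pipe.get_default_blocks_name(None))                       # -> "WanAutoBlocks"
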
@@ -109,33 +118,3 @@ class WanModularPipeline(
        if hasattr(self, "scheduler") and self.scheduler is not None:
            num_train_timesteps = self.scheduler.config.num_train_timesteps
        return num_train_timesteps


class WanImage2VideoModularPipeline(WanModularPipeline):
    """
    A ModularPipeline for Wan2.1 image2video (both I2V and FLF2V).

    > [!WARNING] > This is an experimental feature and is likely to change in the future.
    """

    default_blocks_name = "WanImage2VideoAutoBlocks"


class Wan22ModularPipeline(WanModularPipeline):
    """
    A ModularPipeline for Wan2.2 text2video.

    > [!WARNING] > This is an experimental feature and is likely to change in the future.
    """

    default_blocks_name = "Wan22Blocks"


class Wan22Image2VideoModularPipeline(Wan22ModularPipeline):
    """
    A ModularPipeline for Wan2.2 image2video.

    > [!WARNING] > This is an experimental feature and is likely to change in the future.
    """

    default_blocks_name = "Wan22Image2VideoBlocks"

@@ -248,7 +248,7 @@ AUTO_TEXT2VIDEO_PIPELINES_MAPPING = OrderedDict(

AUTO_IMAGE2VIDEO_PIPELINES_MAPPING = OrderedDict(
    [
        ("wan-i2v", WanImageToVideoPipeline),
        ("wan", WanImageToVideoPipeline),
    ]
)

@@ -635,12 +635,10 @@ class ZImageControlNetPipeline(DiffusionPipeline, FromSingleFileMixin):
                latent_model_input = latents_typed.repeat(2, 1, 1, 1)
                prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds
                timestep_model_input = timestep.repeat(2)
                control_image_input = control_image.repeat(2, 1, 1, 1, 1)
            else:
                latent_model_input = latents.to(self.transformer.dtype)
                prompt_embeds_model_input = prompt_embeds
                timestep_model_input = timestep
                control_image_input = control_image

            latent_model_input = latent_model_input.unsqueeze(2)
            latent_model_input_list = list(latent_model_input.unbind(dim=0))
@@ -649,7 +647,7 @@ class ZImageControlNetPipeline(DiffusionPipeline, FromSingleFileMixin):
                latent_model_input_list,
                timestep_model_input,
                prompt_embeds_model_input,
                control_image_input,
                control_image,
                conditioning_scale=controlnet_conditioning_scale,
            )

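# Hedged sketch of the classifier-free-guidance batching pattern above, with dummy
# tensors (shapes are illustrative, not the real Z-Image shapes). The conditional and
# unconditional branches are stacked into one doubled batch, and the prompt embeddings
# are Python lists, so `+` concatenates them rather than adding values elementwise.
import torch

latents = torch.randn(1, 16, 64, 64)
prompt_embeds = [torch.randn(77, 2048)]
negative_prompt_embeds = [torch.randn(77, 2048)]
timestep = torch.tensor([500.0])

latent_model_input = latents.repeat(2, 1, 1, 1)                      # (2, 16, 64, 64)
prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds   # list of 2 tensors
timestep_model_input = timestep.repeat(2)                            # (2,)
print(latent_model_input.shape, len(prompt_embeds_model_input), timestep_model_input.shape)
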
@@ -657,12 +657,10 @@ class ZImageControlNetInpaintPipeline(DiffusionPipeline, FromSingleFileMixin):
                latent_model_input = latents_typed.repeat(2, 1, 1, 1)
                prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds
                timestep_model_input = timestep.repeat(2)
                control_image_input = control_image.repeat(2, 1, 1, 1, 1)
            else:
                latent_model_input = latents.to(self.transformer.dtype)
                prompt_embeds_model_input = prompt_embeds
                timestep_model_input = timestep
                control_image_input = control_image

            latent_model_input = latent_model_input.unsqueeze(2)
            latent_model_input_list = list(latent_model_input.unbind(dim=0))
@@ -671,7 +669,7 @@ class ZImageControlNetInpaintPipeline(DiffusionPipeline, FromSingleFileMixin):
                latent_model_input_list,
                timestep_model_input,
                prompt_embeds_model_input,
                control_image_input,
                control_image,
                conditioning_scale=controlnet_conditioning_scale,
            )

@@ -14,7 +14,7 @@

import math
from dataclasses import dataclass
from typing import Callable, List, Literal, Optional, Tuple, Union
from typing import List, Literal, Optional, Tuple, Union

import numpy as np
import torch
@@ -51,14 +51,7 @@ class DPMSolverSDESchedulerOutput(BaseOutput):
class BatchedBrownianTree:
    """A wrapper around torchsde.BrownianTree that enables batches of entropy."""

    def __init__(
        self,
        x: torch.Tensor,
        t0: float,
        t1: float,
        seed: Optional[Union[int, List[int]]] = None,
        **kwargs,
    ):
    def __init__(self, x, t0, t1, seed=None, **kwargs):
        t0, t1, self.sign = self.sort(t0, t1)
        w0 = kwargs.get("w0", torch.zeros_like(x))
        if seed is None:
@@ -86,23 +79,10 @@ class BatchedBrownianTree:
        ]

    @staticmethod
    def sort(a: float, b: float) -> Tuple[float, float, float]:
        """
        Sorts two float values and returns them along with a sign indicating if they were swapped.
    def sort(a, b):
        return (a, b, 1) if a < b else (b, a, -1)

        Args:
            a (`float`):
                The first value.
            b (`float`):
                The second value.

        Returns:
            `Tuple[float, float, float]`:
                A tuple containing the sorted values (min, max) and a sign (1.0 if a < b, -1.0 otherwise).
        """
        return (a, b, 1.0) if a < b else (b, a, -1.0)

    def __call__(self, t0: float, t1: float) -> torch.Tensor:
    def __call__(self, t0, t1):
        t0, t1, sign = self.sort(t0, t1)
        w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
        return w if self.batched else w[0]
|
||||
"""A noise sampler backed by a torchsde.BrownianTree.
|
||||
|
||||
Args:
|
||||
x (`torch.Tensor`): The tensor whose shape, device and dtype is used to generate random samples.
|
||||
sigma_min (`float`): The low end of the valid interval.
|
||||
sigma_max (`float`): The high end of the valid interval.
|
||||
seed (`int` or `List[int]`): The random seed. If a list of seeds is
|
||||
x (Tensor): The tensor whose shape, device and dtype to use to generate
|
||||
random samples.
|
||||
sigma_min (float): The low end of the valid interval.
|
||||
sigma_max (float): The high end of the valid interval.
|
||||
seed (int or List[int]): The random seed. If a list of seeds is
|
||||
supplied instead of a single integer, then the noise sampler will use one BrownianTree per batch item, each
|
||||
with its own seed.
|
||||
transform (`callable`): A function that maps sigma to the sampler's
|
||||
transform (callable): A function that maps sigma to the sampler's
|
||||
internal timestep.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
sigma_min: float,
|
||||
sigma_max: float,
|
||||
seed: Optional[Union[int, List[int]]] = None,
|
||||
transform: Callable[[float], float] = lambda x: x,
|
||||
):
|
||||
def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x):
|
||||
self.transform = transform
|
||||
t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max))
|
||||
self.tree = BatchedBrownianTree(x, t0, t1, seed)
|
||||
|
||||
def __call__(self, sigma: float, sigma_next: float) -> torch.Tensor:
|
||||
def __call__(self, sigma, sigma_next):
|
||||
t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next))
|
||||
return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
|
||||
|
||||
@@ -242,28 +216,19 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
||||
num_train_timesteps: int = 1000,
|
||||
beta_start: float = 0.00085, # sensible defaults
|
||||
beta_end: float = 0.012,
|
||||
beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
|
||||
beta_schedule: str = "linear",
|
||||
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
||||
prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
|
||||
prediction_type: str = "epsilon",
|
||||
use_karras_sigmas: Optional[bool] = False,
|
||||
use_exponential_sigmas: Optional[bool] = False,
|
||||
use_beta_sigmas: Optional[bool] = False,
|
||||
noise_sampler_seed: Optional[int] = None,
|
||||
timestep_spacing: Literal["linspace", "leading", "trailing"] = "linspace",
|
||||
timestep_spacing: str = "linspace",
|
||||
steps_offset: int = 0,
|
||||
):
|
||||
if self.config.use_beta_sigmas and not is_scipy_available():
|
||||
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
|
||||
if (
|
||||
sum(
|
||||
[
|
||||
self.config.use_beta_sigmas,
|
||||
self.config.use_exponential_sigmas,
|
||||
self.config.use_karras_sigmas,
|
||||
]
|
||||
)
|
||||
> 1
|
||||
):
|
||||
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
|
||||
raise ValueError(
|
||||
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
|
||||
)
|
||||
@@ -273,15 +238,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
||||
elif beta_schedule == "scaled_linear":
|
||||
# this schedule is very specific to the latent diffusion model.
|
||||
self.betas = (
|
||||
torch.linspace(
|
||||
beta_start**0.5,
|
||||
beta_end**0.5,
|
||||
num_train_timesteps,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
** 2
|
||||
)
|
||||
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
||||
elif beta_schedule == "squaredcos_cap_v2":
|
||||
# Glide cosine schedule
|
||||
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
||||
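# Hedged numeric check of the "scaled_linear" schedule above: the betas are linear in
# sqrt-space, so the endpoints recover beta_start and beta_end (values below are the
# scheduler defaults).
import torch

beta_start, beta_end, num_train_timesteps = 0.00085, 0.012, 1000
betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
print(betas[0].item(), betas[-1].item())  # ~0.00085, ~0.012
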
@@ -348,7 +305,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
||||
self._step_index = self._begin_index
|
||||
|
||||
@property
|
||||
def init_noise_sigma(self) -> torch.Tensor:
|
||||
def init_noise_sigma(self):
|
||||
# standard deviation of the initial noise distribution
|
||||
if self.config.timestep_spacing in ["linspace", "trailing"]:
|
||||
return self.sigmas.max()
|
||||
@@ -356,21 +313,21 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
||||
return (self.sigmas.max() ** 2 + 1) ** 0.5
|
||||
|
||||
@property
|
||||
def step_index(self) -> Union[int, None]:
|
||||
def step_index(self):
|
||||
"""
|
||||
The index counter for current timestep. It will increase 1 after each scheduler step.
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
@property
|
||||
def begin_index(self) -> Union[int, None]:
|
||||
def begin_index(self):
|
||||
"""
|
||||
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
||||
"""
|
||||
return self._begin_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
||||
def set_begin_index(self, begin_index: int = 0) -> None:
|
||||
def set_begin_index(self, begin_index: int = 0):
|
||||
"""
|
||||
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
||||
|
||||
@@ -412,7 +369,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
||||
num_inference_steps: int,
|
||||
device: Union[str, torch.device] = None,
|
||||
num_train_timesteps: Optional[int] = None,
|
||||
) -> None:
|
||||
):
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
|
||||
@@ -421,8 +378,6 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
num_train_timesteps (`int`, *optional*):
|
||||
The number of train timesteps. If `None`, uses `self.config.num_train_timesteps`.
|
||||
"""
|
||||
self.num_inference_steps = num_inference_steps
|
||||
|
||||
@@ -488,7 +443,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
self.noise_sampler = None
|
||||
|
||||
def _second_order_timesteps(self, sigmas: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
|
||||
def _second_order_timesteps(self, sigmas, log_sigmas):
|
||||
def sigma_fn(_t):
|
||||
return np.exp(-_t)
|
||||
|
||||
@@ -504,7 +459,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
||||
return timesteps
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
|
||||
def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
|
||||
def _sigma_to_t(self, sigma, log_sigmas):
|
||||
"""
|
||||
Convert sigma values to corresponding timestep values through interpolation.
|
||||
|
||||
@@ -649,14 +604,14 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
||||
return sigmas
|
||||
|
||||
@property
|
||||
def state_in_first_order(self) -> bool:
|
||||
def state_in_first_order(self):
|
||||
return self.sample is None
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.Tensor,
|
||||
model_output: Union[torch.Tensor, np.ndarray],
|
||||
timestep: Union[float, torch.Tensor],
|
||||
sample: torch.Tensor,
|
||||
sample: Union[torch.Tensor, np.ndarray],
|
||||
return_dict: bool = True,
|
||||
s_noise: float = 1.0,
|
||||
) -> Union[DPMSolverSDESchedulerOutput, Tuple]:
|
||||
@@ -665,11 +620,11 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
||||
process from the learned model outputs (most often the predicted noise).
|
||||
|
||||
Args:
|
||||
model_output (`torch.Tensor`):
|
||||
model_output (`torch.Tensor` or `np.ndarray`):
|
||||
The direct output from learned diffusion model.
|
||||
timestep (`float` or `torch.Tensor`):
|
||||
The current discrete timestep in the diffusion chain.
|
||||
sample (`torch.Tensor`):
|
||||
sample (`torch.Tensor` or `np.ndarray`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
return_dict (`bool`):
|
||||
Whether or not to return a [`~schedulers.scheduling_dpmsolver_sde.DPMSolverSDESchedulerOutput`] or
|
||||
@@ -688,9 +643,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
||||
# Create a noise sampler if it hasn't been created yet
|
||||
if self.noise_sampler is None:
|
||||
min_sigma, max_sigma = self.sigmas[self.sigmas > 0].min(), self.sigmas.max()
|
||||
self.noise_sampler = BrownianTreeNoiseSampler(
|
||||
sample, min_sigma.item(), max_sigma.item(), self.noise_sampler_seed
|
||||
)
|
||||
self.noise_sampler = BrownianTreeNoiseSampler(sample, min_sigma, max_sigma, self.noise_sampler_seed)
|
||||
|
||||
# Define functions to compute sigma and t from each other
|
||||
def sigma_fn(_t: torch.Tensor) -> torch.Tensor:
|
||||
@@ -741,10 +694,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
sigma_from = sigma_fn(t)
|
||||
sigma_to = sigma_fn(t_next)
|
||||
sigma_up = min(
|
||||
sigma_to,
|
||||
(sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5,
|
||||
)
|
||||
sigma_up = min(sigma_to, (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5)
|
||||
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
|
||||
ancestral_t = t_fn(sigma_down)
|
||||
prev_sample = (sigma_fn(ancestral_t) / sigma_fn(t)) * sample - (
|
||||
@@ -821,5 +771,5 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
||||
noisy_samples = original_samples + noise * sigma
|
||||
return noisy_samples
|
||||
|
||||
def __len__(self) -> int:
|
||||
def __len__(self):
|
||||
return self.config.num_train_timesteps
|
||||
|
||||
@@ -47,21 +47,6 @@ class Flux2KleinBaseAutoBlocks(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class Flux2KleinBaseModularPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class Flux2KleinModularPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
@@ -302,7 +287,7 @@ class StableDiffusionXLModularPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class Wan22Blocks(metaclass=DummyObject):
|
||||
class Wan22AutoBlocks(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
@@ -317,82 +302,7 @@ class Wan22Blocks(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class Wan22Image2VideoBlocks(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class Wan22Image2VideoModularPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class Wan22ModularPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class WanBlocks(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class WanImage2VideoAutoBlocks(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class WanImage2VideoModularPipeline(metaclass=DummyObject):
|
||||
class WanAutoBlocks(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
||||