mirror of https://github.com/huggingface/diffusers.git
synced 2026-01-12 06:35:38 +08:00

Compare commits: modular-wa ... main (7 commits)
| Author | SHA1 | Date |
|---|---|---|
| | b86bd99eac | |
| | 5b202111bf | |
| | 4ac2b4a521 | |
| | 418313bbf6 | |
| | 2120c3096f | |
| | ed6e5ecf67 | |
| | d44b5f86e6 | |
@@ -29,7 +29,7 @@ Cache methods speedup diffusion transformers by storing and reusing intermediate

[[autodoc]] apply_faster_cache

-### FirstBlockCacheConfig
+## FirstBlockCacheConfig

[[autodoc]] FirstBlockCacheConfig

@@ -33,6 +33,7 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
- [`QwenImageLoraLoaderMixin`] provides similar functions for [Qwen Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/qwen).
- [`ZImageLoraLoaderMixin`] provides similar functions for [Z-Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/zimage).
- [`Flux2LoraLoaderMixin`] provides similar functions for [Flux2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux2).
- [`LTX2LoraLoaderMixin`] provides similar functions for [LTX-2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/ltx2).
- [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, and unload LoRAs, and more (see the sketch below).
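
Below is a minimal, hedged sketch of those utilities in action. The checkpoint path and adapter name are placeholders, and `Qwen/Qwen-Image` is only an example base model; any pipeline that inherits one of the mixins above exposes the same methods.

```py
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)

# Load a LoRA checkpoint (placeholder path) under an explicit adapter name.
pipeline.load_lora_weights("path/to/lora.safetensors", adapter_name="my_lora")

# Merge the LoRA weights into the base weights for faster inference, then undo the merge.
pipeline.fuse_lora()
pipeline.unfuse_lora()

# Remove the LoRA layers entirely.
pipeline.unload_lora_weights()
```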
> [!TIP]

@@ -62,6 +63,10 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi

[[autodoc]] loaders.lora_pipeline.Flux2LoraLoaderMixin

## LTX2LoraLoaderMixin

[[autodoc]] loaders.lora_pipeline.LTX2LoraLoaderMixin

## CogVideoXLoraLoaderMixin

[[autodoc]] loaders.lora_pipeline.CogVideoXLoraLoaderMixin

@@ -21,7 +21,7 @@ The abstract from the paper is:

*Image generation has recently seen tremendous advances, with diffusion models allowing to synthesize convincing images for a large variety of text prompts. In this article, we propose DiffEdit, a method to take advantage of text-conditioned diffusion models for the task of semantic image editing, where the goal is to edit an image based on a text query. Semantic image editing is an extension of image generation, with the additional constraint that the generated image should be as similar as possible to a given input image. Current editing methods based on diffusion models usually require to provide a mask, making the task much easier by treating it as a conditional inpainting task. In contrast, our main contribution is able to automatically generate a mask highlighting regions of the input image that need to be edited, by contrasting predictions of a diffusion model conditioned on different text prompts. Moreover, we rely on latent inference to preserve content in those regions of interest and show excellent synergies with mask-based diffusion. DiffEdit achieves state-of-the-art editing performance on ImageNet. In addition, we evaluate semantic image editing in more challenging settings, using images from the COCO dataset as well as text-based generated images.*

-The original codebase can be found at [Xiang-cd/DiffEdit-stable-diffusion](https://github.com/Xiang-cd/DiffEdit-stable-diffusion), and you can try it out in this [demo](https://blog.problemsolversguild.com/technical/research/2022/11/02/DiffEdit-Implementation.html).
+The original codebase can be found at [Xiang-cd/DiffEdit-stable-diffusion](https://github.com/Xiang-cd/DiffEdit-stable-diffusion), and you can try it out in this [demo](https://blog.problemsolversguild.com/posts/2022-11-02-diffedit-implementation.html).

This pipeline was contributed by [clarencechen](https://github.com/clarencechen). ❤️

@@ -14,6 +14,10 @@

# LTX-2

<div class="flex flex-wrap space-x-1">
  <img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
</div>

LTX-2 is a DiT-based audio-video foundation model designed to generate synchronized video and audio within a single model. It brings together the core building blocks of modern video generation, with open weights and a focus on practical, local execution.

You can find all the original LTX-Video checkpoints under the [Lightricks](https://huggingface.co/Lightricks) organization.

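A short, hedged sketch of what text-to-video generation with [`LTX2Pipeline`] could look like. The repository id below is a placeholder, the call arguments only mirror parameter names used elsewhere in this changeset (`num_frames`, `frame_rate`, `guidance_scale`, and so on), and the `frames` field on the output is likewise an assumption.

```py
import torch
from diffusers import LTX2Pipeline

# Placeholder repo id; pick an actual LTX-2 checkpoint from the Lightricks organization.
pipeline = LTX2Pipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16)
pipeline.to("cuda")

output = pipeline(
    prompt="a robot dancing in the rain",
    num_frames=121,
    height=512,
    width=768,
    frame_rate=25.0,
    num_inference_steps=40,
    guidance_scale=4.0,
)
video = output.frames[0]  # assumed output field
```
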
@@ -140,7 +140,7 @@ class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
                type_hint=str,
                required=True,
                default="mask_image",
-                description="""Output type from annotation predictions. Availabe options are
+                description="""Output type from annotation predictions. Available options are
                mask_image:
                    -black and white mask image for the given image based on the task type
                mask_overlay:

@@ -256,7 +256,7 @@ class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
                type_hint=str,
                required=True,
                default="mask_image",
-                description="""Output type from annotation predictions. Availabe options are
+                description="""Output type from annotation predictions. Available options are
                mask_image:
                    -black and white mask image for the given image based on the task type
                mask_overlay:

@@ -53,7 +53,7 @@ The loop wrapper can pass additional arguments, like current iteration index, to

A loop block is a [`~modular_pipelines.ModularPipelineBlocks`], but the `__call__` method behaves differently.

-- It recieves the iteration variable from the loop wrapper.
+- It receives the iteration variable from the loop wrapper.
- It works directly with the [`~modular_pipelines.BlockState`] instead of the [`~modular_pipelines.PipelineState`].
- It doesn't require retrieving or updating the [`~modular_pipelines.BlockState`].

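To make the list above concrete, here is a minimal, illustrative sketch of a loop block. The exact `__call__` signature and attribute names are assumptions for illustration, not the precise diffusers API; the point is that the block receives the iteration variable from the wrapper and reads and writes the `BlockState` directly.

```py
from diffusers.modular_pipelines import ModularPipelineBlocks


class MyLoopStep(ModularPipelineBlocks):
    def __call__(self, components, block_state, i):
        # `i` is the iteration variable handed in by the loop wrapper.
        # The block works on `block_state` directly; there is no need to
        # retrieve it from, or write it back to, the PipelineState.
        t = block_state.timesteps[i]
        noise_pred = components.transformer(block_state.latents, timestep=t).sample
        block_state.latents = components.scheduler.step(noise_pred, t, block_state.latents).prev_sample
        return components, block_state
```
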
@@ -68,6 +68,20 @@ config = FasterCacheConfig(
pipeline.transformer.enable_cache(config)
```

## FirstBlockCache

[FirstBlock Cache](https://huggingface.co/docs/diffusers/main/en/api/cache#diffusers.FirstBlockCacheConfig) checks how much the early layers of the denoiser change from one timestep to the next. If the change is small, the model skips the expensive later layers and reuses the previous output.

```py
import torch
from diffusers import DiffusionPipeline
from diffusers.hooks import apply_first_block_cache, FirstBlockCacheConfig

pipeline = DiffusionPipeline.from_pretrained(
    "Qwen/Qwen-Image", torch_dtype=torch.bfloat16
)
apply_first_block_cache(pipeline.transformer, FirstBlockCacheConfig(threshold=0.2))
```

## TaylorSeer Cache

[TaylorSeer Cache](https://huggingface.co/papers/2403.06923) accelerates diffusion inference by using Taylor series expansions to approximate and cache intermediate activations across denoising steps. The method predicts future outputs based on past computations, reusing them at specified intervals to reduce redundant calculations.

@@ -87,8 +101,7 @@ from diffusers import FluxPipeline, TaylorSeerCacheConfig
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
-)
-pipe.to("cuda")
+).to("cuda")

config = TaylorSeerCacheConfig(
    cache_interval=5,

@@ -97,4 +110,4 @@ config = TaylorSeerCacheConfig(
    taylor_factors_dtype=torch.bfloat16,
)
pipe.transformer.enable_cache(config)
```

@@ -149,13 +149,13 @@ def get_args():
        "--validation_prompt",
        type=str,
        default=None,
-        help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_seperator' string.",
+        help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_separator' string.",
    )
    parser.add_argument(
        "--validation_images",
        type=str,
        default=None,
-        help="One or more image path(s) that is used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_seperator' string. These should correspond to the order of the validation prompts.",
+        help="One or more image path(s) that is used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_separator' string. These should correspond to the order of the validation prompts.",
    )
    parser.add_argument(
        "--validation_prompt_separator",

@@ -140,7 +140,7 @@ def get_args():
        "--validation_prompt",
        type=str,
        default=None,
-        help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_seperator' string.",
+        help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_separator' string.",
    )
    parser.add_argument(
        "--validation_prompt_separator",

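For context, a hedged illustration of how a training script typically consumes these two flags (the separator value shown is only an example, not necessarily the script's default):

```py
# e.g. --validation_prompt "a cat:::a dog" --validation_prompt_separator ":::"
validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
# -> ["a cat", "a dog"]
```
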
@@ -1228,7 +1228,7 @@ def main(args):
        else {"device": accelerator.device, "dtype": weight_dtype}
    )

-    is_fsdp = accelerator.state.fsdp_plugin is not None
+    is_fsdp = getattr(accelerator.state, "fsdp_plugin", None) is not None
    if not is_fsdp:
        transformer.to(**transformer_to_kwargs)

@@ -1178,7 +1178,7 @@ def main(args):
        else {"device": accelerator.device, "dtype": weight_dtype}
    )

-    is_fsdp = accelerator.state.fsdp_plugin is not None
+    is_fsdp = getattr(accelerator.state, "fsdp_plugin", None) is not None
    if not is_fsdp:
        transformer.to(**transformer_to_kwargs)

@@ -4,7 +4,7 @@ The `train_text_to_image.py` script shows how to fine-tune stable diffusion mode

___Note___:

-___This script is experimental. The script fine-tunes the whole model and often times the model overfits and runs into issues like catastrophic forgetting. It's recommended to try different hyperparamters to get the best result on your dataset.___
+___This script is experimental. The script fine-tunes the whole model and often times the model overfits and runs into issues like catastrophic forgetting. It's recommended to try different hyperparameters to get the best result on your dataset.___

## Running locally with PyTorch

@@ -18,7 +18,7 @@ cc.initialize_cache("/tmp/sdxl_cache")
NUM_DEVICES = jax.device_count()

# 1. Let's start by downloading the model and loading it into our pipeline class
-# Adhering to JAX's functional approach, the model's parameters are returned seperatetely and
+# Adhering to JAX's functional approach, the model's parameters are returned separately and
# will have to be passed to the pipeline during inference
pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", revision="refs/pr/95", split_head_dim=True

@@ -67,6 +67,7 @@ if is_torch_available():
        "SD3LoraLoaderMixin",
        "AuraFlowLoraLoaderMixin",
        "StableDiffusionXLLoraLoaderMixin",
        "LTX2LoraLoaderMixin",
        "LTXVideoLoraLoaderMixin",
        "LoraLoaderMixin",
        "FluxLoraLoaderMixin",

@@ -121,6 +122,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        HunyuanVideoLoraLoaderMixin,
        KandinskyLoraLoaderMixin,
        LoraLoaderMixin,
        LTX2LoraLoaderMixin,
        LTXVideoLoraLoaderMixin,
        Lumina2LoraLoaderMixin,
        Mochi1LoraLoaderMixin,

@@ -2140,6 +2140,54 @@ def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_pref
    return converted_state_dict


def _convert_non_diffusers_ltx2_lora_to_diffusers(state_dict, non_diffusers_prefix="diffusion_model"):
    # Remove the prefix
    state_dict = {k: v for k, v in state_dict.items() if k.startswith(f"{non_diffusers_prefix}.")}
    converted_state_dict = {k.removeprefix(f"{non_diffusers_prefix}."): v for k, v in state_dict.items()}

    if non_diffusers_prefix == "diffusion_model":
        rename_dict = {
            "patchify_proj": "proj_in",
            "audio_patchify_proj": "audio_proj_in",
            "av_ca_video_scale_shift_adaln_single": "av_cross_attn_video_scale_shift",
            "av_ca_a2v_gate_adaln_single": "av_cross_attn_video_a2v_gate",
            "av_ca_audio_scale_shift_adaln_single": "av_cross_attn_audio_scale_shift",
            "av_ca_v2a_gate_adaln_single": "av_cross_attn_audio_v2a_gate",
            "scale_shift_table_a2v_ca_video": "video_a2v_cross_attn_scale_shift_table",
            "scale_shift_table_a2v_ca_audio": "audio_a2v_cross_attn_scale_shift_table",
            "q_norm": "norm_q",
            "k_norm": "norm_k",
        }
    else:
        rename_dict = {"aggregate_embed": "text_proj_in"}

    # Apply renaming
    renamed_state_dict = {}
    for key, value in converted_state_dict.items():
        new_key = key[:]
        for old_pattern, new_pattern in rename_dict.items():
            new_key = new_key.replace(old_pattern, new_pattern)
        renamed_state_dict[new_key] = value

    # Handle adaln_single -> time_embed and audio_adaln_single -> audio_time_embed
    final_state_dict = {}
    for key, value in renamed_state_dict.items():
        if key.startswith("adaln_single."):
            new_key = key.replace("adaln_single.", "time_embed.")
            final_state_dict[new_key] = value
        elif key.startswith("audio_adaln_single."):
            new_key = key.replace("audio_adaln_single.", "audio_time_embed.")
            final_state_dict[new_key] = value
        else:
            final_state_dict[key] = value

    # Add transformer prefix
    prefix = "transformer" if non_diffusers_prefix == "diffusion_model" else "connectors"
    final_state_dict = {f"{prefix}.{k}": v for k, v in final_state_dict.items()}

    return final_state_dict

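A quick, illustrative check of what the converter above does to a single (hypothetical) key: the `diffusion_model.` prefix is stripped, `patchify_proj` is renamed to `proj_in`, and the result is re-prefixed for the diffusers transformer.

```py
import torch

dummy = {"diffusion_model.patchify_proj.lora_A.weight": torch.zeros(4, 4)}
converted = _convert_non_diffusers_ltx2_lora_to_diffusers(dummy)
print(list(converted))  # ['transformer.proj_in.lora_A.weight']
```
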
def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict):
    has_diffusion_model = any(k.startswith("diffusion_model.") for k in state_dict)
    if has_diffusion_model:

@@ -48,6 +48,7 @@ from .lora_conversion_utils import (
    _convert_non_diffusers_flux2_lora_to_diffusers,
    _convert_non_diffusers_hidream_lora_to_diffusers,
    _convert_non_diffusers_lora_to_diffusers,
    _convert_non_diffusers_ltx2_lora_to_diffusers,
    _convert_non_diffusers_ltxv_lora_to_diffusers,
    _convert_non_diffusers_lumina2_lora_to_diffusers,
    _convert_non_diffusers_qwen_lora_to_diffusers,

@@ -74,6 +75,7 @@ logger = logging.get_logger(__name__)
TEXT_ENCODER_NAME = "text_encoder"
UNET_NAME = "unet"
TRANSFORMER_NAME = "transformer"
LTX2_CONNECTOR_NAME = "connectors"

_MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX = {"x_embedder": "in_channels"}

@@ -3011,6 +3013,233 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
        super().unfuse_lora(components=components, **kwargs)


class LTX2LoraLoaderMixin(LoraBaseMixin):
    r"""
    Load LoRA layers into [`LTX2VideoTransformer3DModel`]. Specific to [`LTX2Pipeline`].
    """

    _lora_loadable_modules = ["transformer", "connectors"]
    transformer_name = TRANSFORMER_NAME
    connectors_name = LTX2_CONNECTOR_NAME

    @classmethod
    @validate_hf_hub_args
    def lora_state_dict(
        cls,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        **kwargs,
    ):
        r"""
        See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
        """
        # Load the main state dict first which has the LoRA layers for either of
        # transformer and text encoder or both.
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)
        weight_name = kwargs.pop("weight_name", None)
        use_safetensors = kwargs.pop("use_safetensors", None)
        return_lora_metadata = kwargs.pop("return_lora_metadata", False)

        allow_pickle = False
        if use_safetensors is None:
            use_safetensors = True
            allow_pickle = True

        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

        state_dict, metadata = _fetch_state_dict(
            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
            weight_name=weight_name,
            use_safetensors=use_safetensors,
            local_files_only=local_files_only,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            token=token,
            revision=revision,
            subfolder=subfolder,
            user_agent=user_agent,
            allow_pickle=allow_pickle,
        )

        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
        if is_dora_scale_present:
            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
            logger.warning(warn_msg)
            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

        final_state_dict = state_dict
        is_non_diffusers_format = any(k.startswith("diffusion_model.") for k in state_dict)
        has_connector = any(k.startswith("text_embedding_projection.") for k in state_dict)
        if is_non_diffusers_format:
            final_state_dict = _convert_non_diffusers_ltx2_lora_to_diffusers(state_dict)
        if has_connector:
            connectors_state_dict = _convert_non_diffusers_ltx2_lora_to_diffusers(
                state_dict, "text_embedding_projection"
            )
            final_state_dict.update(connectors_state_dict)
        out = (final_state_dict, metadata) if return_lora_metadata else final_state_dict
        return out

    def load_lora_weights(
        self,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        adapter_name: Optional[str] = None,
        hotswap: bool = False,
        **kwargs,
    ):
        """
        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )

        # if a dict is passed, copy it instead of modifying it inplace
        if isinstance(pretrained_model_name_or_path_or_dict, dict):
            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
        kwargs["return_lora_metadata"] = True
        state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

        is_correct_format = all("lora" in key for key in state_dict.keys())
        if not is_correct_format:
            raise ValueError("Invalid LoRA checkpoint.")

        transformer_peft_state_dict = {
            k: v for k, v in state_dict.items() if k.startswith(f"{self.transformer_name}.")
        }
        connectors_peft_state_dict = {k: v for k, v in state_dict.items() if k.startswith(f"{self.connectors_name}.")}
        self.load_lora_into_transformer(
            transformer_peft_state_dict,
            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
            adapter_name=adapter_name,
            metadata=metadata,
            _pipeline=self,
            low_cpu_mem_usage=low_cpu_mem_usage,
            hotswap=hotswap,
        )
        if connectors_peft_state_dict:
            self.load_lora_into_transformer(
                connectors_peft_state_dict,
                transformer=getattr(self, self.connectors_name)
                if not hasattr(self, "connectors")
                else self.connectors,
                adapter_name=adapter_name,
                metadata=metadata,
                _pipeline=self,
                low_cpu_mem_usage=low_cpu_mem_usage,
                hotswap=hotswap,
                prefix=self.connectors_name,
            )

    @classmethod
    def load_lora_into_transformer(
        cls,
        state_dict,
        transformer,
        adapter_name=None,
        _pipeline=None,
        low_cpu_mem_usage=False,
        hotswap: bool = False,
        metadata=None,
        prefix: str = "transformer",
    ):
        """
        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
        """
        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )

        # Load the layers corresponding to transformer.
        logger.info(f"Loading {prefix}.")
        transformer.load_lora_adapter(
            state_dict,
            network_alphas=None,
            adapter_name=adapter_name,
            metadata=metadata,
            _pipeline=_pipeline,
            low_cpu_mem_usage=low_cpu_mem_usage,
            hotswap=hotswap,
            prefix=prefix,
        )

    @classmethod
    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
    def save_lora_weights(
        cls,
        save_directory: Union[str, os.PathLike],
        transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        is_main_process: bool = True,
        weight_name: str = None,
        save_function: Callable = None,
        safe_serialization: bool = True,
        transformer_lora_adapter_metadata: Optional[dict] = None,
    ):
        r"""
        See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
        """
        lora_layers = {}
        lora_metadata = {}

        if transformer_lora_layers:
            lora_layers[cls.transformer_name] = transformer_lora_layers
            lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata

        if not lora_layers:
            raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")

        cls._save_lora_weights(
            save_directory=save_directory,
            lora_layers=lora_layers,
            lora_metadata=lora_metadata,
            is_main_process=is_main_process,
            weight_name=weight_name,
            save_function=save_function,
            safe_serialization=safe_serialization,
        )

    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
    def fuse_lora(
        self,
        components: List[str] = ["transformer"],
        lora_scale: float = 1.0,
        safe_fusing: bool = False,
        adapter_names: Optional[List[str]] = None,
        **kwargs,
    ):
        r"""
        See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
        """
        super().fuse_lora(
            components=components,
            lora_scale=lora_scale,
            safe_fusing=safe_fusing,
            adapter_names=adapter_names,
            **kwargs,
        )

    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
    def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
        r"""
        See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
        """
        super().unfuse_lora(components=components, **kwargs)


class SanaLoraLoaderMixin(LoraBaseMixin):
    r"""
    Load LoRA layers into [`SanaTransformer2DModel`]. Specific to [`SanaPipeline`].

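A minimal, hedged sketch of loading an LTX-2 LoRA through the new mixin (the repo id, file path, and adapter name below are placeholders):

```py
import torch
from diffusers import LTX2Pipeline

pipeline = LTX2Pipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16)  # placeholder repo id
pipeline.load_lora_weights("path/to/ltx2_lora.safetensors", adapter_name="ltx2-lora")
# Non-diffusers checkpoints (keys prefixed with `diffusion_model.` and, optionally,
# `text_embedding_projection.`) are converted on the fly and routed to the transformer
# and the text connectors, respectively.
```
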
@@ -67,6 +67,8 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
    "QwenImageTransformer2DModel": lambda model_cls, weights: weights,
    "Flux2Transformer2DModel": lambda model_cls, weights: weights,
    "ZImageTransformer2DModel": lambda model_cls, weights: weights,
    "LTX2VideoTransformer3DModel": lambda model_cls, weights: weights,
    "LTX2TextConnectors": lambda model_cls, weights: weights,
}

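With these two entries registered, per-adapter scaling works for LTX-2 the same way as for the other models in the mapping; a hedged example (the adapter name is a placeholder) on a pipeline with a loaded LoRA:

```py
pipeline.set_adapters(["ltx2-lora"], adapter_weights=[0.8])
```
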
@@ -102,14 +102,14 @@ def get_block(
    attention_head_dim: int,
    norm_type: str,
    act_fn: str,
-    qkv_mutliscales: Tuple[int, ...] = (),
+    qkv_multiscales: Tuple[int, ...] = (),
):
    if block_type == "ResBlock":
        block = ResBlock(in_channels, out_channels, norm_type, act_fn)

    elif block_type == "EfficientViTBlock":
        block = EfficientViTBlock(
-            in_channels, attention_head_dim=attention_head_dim, norm_type=norm_type, qkv_multiscales=qkv_mutliscales
+            in_channels, attention_head_dim=attention_head_dim, norm_type=norm_type, qkv_multiscales=qkv_multiscales
        )

    else:

@@ -247,7 +247,7 @@ class Encoder(nn.Module):
                attention_head_dim=attention_head_dim,
                norm_type="rms_norm",
                act_fn="silu",
-                qkv_mutliscales=qkv_multiscales[i],
+                qkv_multiscales=qkv_multiscales[i],
            )
            down_block_list.append(block)

@@ -339,7 +339,7 @@ class Decoder(nn.Module):
                attention_head_dim=attention_head_dim,
                norm_type=norm_type[i],
                act_fn=act_fn[i],
-                qkv_mutliscales=qkv_multiscales[i],
+                qkv_multiscales=qkv_multiscales[i],
            )
            up_block_list.append(block)

@@ -41,9 +41,11 @@ class CacheMixin:
        Enable caching techniques on the model.

        Args:
-            config (`Union[PyramidAttentionBroadcastConfig]`):
+            config (`Union[PyramidAttentionBroadcastConfig, FasterCacheConfig, FirstBlockCacheConfig]`):
                The configuration for applying the caching technique. Currently supported caching techniques are:
                - [`~hooks.PyramidAttentionBroadcastConfig`]
                - [`~hooks.FasterCacheConfig`]
                - [`~hooks.FirstBlockCacheConfig`]

        Example:

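A short sketch of the documented usage, with `Qwen/Qwen-Image` purely as an example model: any of the configs listed above can be passed to `enable_cache` on a model that inherits `CacheMixin`.

```py
import torch
from diffusers import DiffusionPipeline
from diffusers.hooks import FirstBlockCacheConfig

pipeline = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
pipeline.transformer.enable_cache(FirstBlockCacheConfig(threshold=0.2))
```
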
@@ -455,7 +455,7 @@ class QwenImageSetTimestepsStep(ModularPipelineBlocks):

    @property
    def description(self) -> str:
-        return "Step that sets the the scheduler's timesteps for text-to-image generation. Should be run after prepare latents step."
+        return "Step that sets the scheduler's timesteps for text-to-image generation. Should be run after prepare latents step."

    @property
    def expected_components(self) -> List[ComponentSpec]:

@@ -579,7 +579,7 @@ class QwenImageSetTimestepsWithStrengthStep(ModularPipelineBlocks):

    @property
    def description(self) -> str:
-        return "Step that sets the the scheduler's timesteps for image-to-image generation, and inpainting. Should be run after prepare latents step."
+        return "Step that sets the scheduler's timesteps for image-to-image generation, and inpainting. Should be run after prepare latents step."

    @property
    def expected_components(self) -> List[ComponentSpec]:

@@ -12,9 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

import PIL.Image
import torch

from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, ConditionalPipelineBlocks, SequentialPipelineBlocks
-from ..modular_pipeline_utils import InsertableDict
+from ..modular_pipeline_utils import InsertableDict, OutputParam
from .before_denoise import (
    QwenImageControlNetBeforeDenoiserStep,
    QwenImageCreateMaskLatentsStep,

@@ -394,6 +399,14 @@ class QwenImageAutoCoreDenoiseStep(ConditionalPipelineBlocks):
            + " - for text-to-image generation, all you need to provide is prompt embeddings"
        )

    @property
    def outputs(self):
        return [
            OutputParam(
                name="latents", type_hint=torch.Tensor, description="The latents generated by the denoising step"
            ),
        ]


# ====================
# 3. DECODE

@@ -467,3 +480,9 @@ class QwenImageAutoBlocks(SequentialPipelineBlocks):
            + "- to run the controlnet workflow, you need to provide `control_image`\n"
            + "- for text-to-image generation, all you need to provide is `prompt`"
        )

    @property
    def outputs(self):
        return [
            OutputParam(name="images", type_hint=List[List[PIL.Image.Image]]),
        ]

@@ -12,11 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from typing import Optional
+from typing import List, Optional

import PIL.Image
import torch

from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, ConditionalPipelineBlocks, SequentialPipelineBlocks
-from ..modular_pipeline_utils import InsertableDict
+from ..modular_pipeline_utils import InsertableDict, OutputParam
from .before_denoise import (
    QwenImageCreateMaskLatentsStep,
    QwenImageEditRoPEInputsStep,

@@ -307,6 +310,14 @@ class QwenImageEditAutoDecodeStep(AutoPipelineBlocks):
            " - `QwenImageEditDecodeStep` (edit) is used when `mask` is not provided.\n"
        )

    @property
    def outputs(self):
        return [
            OutputParam(
                name="latents", type_hint=torch.Tensor, description="The latents generated by the denoising step"
            ),
        ]


# ====================
# 5. AUTO BLOCKS & PRESETS

@@ -334,3 +345,9 @@ class QwenImageEditAutoBlocks(SequentialPipelineBlocks):
            "- for edit (img2img) generation, you need to provide `image`\n"
            "- for edit inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop`\n"
        )

    @property
    def outputs(self):
        return [
            OutputParam(name="images", type_hint=List[List[PIL.Image.Image]], description="The generated images"),
        ]

@@ -12,9 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

import PIL.Image
import torch

from ...utils import logging
from ..modular_pipeline import SequentialPipelineBlocks
-from ..modular_pipeline_utils import InsertableDict
+from ..modular_pipeline_utils import InsertableDict, OutputParam
from .before_denoise import (
    QwenImageEditPlusRoPEInputsStep,
    QwenImagePrepareLatentsStep,

@@ -136,6 +141,14 @@ class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks):
    def description(self):
        return "Core denoising workflow for QwenImage-Edit Plus edit (img2img) task."

    @property
    def outputs(self):
        return [
            OutputParam(
                name="latents", type_hint=torch.Tensor, description="The latents generated by the denoising step"
            ),
        ]


# ====================
# 4. DECODE

@@ -179,3 +192,9 @@ class QwenImageEditPlusAutoBlocks(SequentialPipelineBlocks):
            "- Each image is resized independently based on its own aspect ratio.\n"
            "- VL encoder uses 384x384 target area, VAE encoder uses 1024x1024 target area."
        )

    @property
    def outputs(self):
        return [
            OutputParam(name="images", type_hint=List[List[PIL.Image.Image]], description="The generated images"),
        ]

@@ -13,9 +13,14 @@
# limitations under the License.


from typing import List

import PIL.Image
import torch

from ...utils import logging
from ..modular_pipeline import SequentialPipelineBlocks
-from ..modular_pipeline_utils import InsertableDict
+from ..modular_pipeline_utils import InsertableDict, OutputParam
from .before_denoise import (
    QwenImageLayeredPrepareLatentsStep,
    QwenImageLayeredRoPEInputsStep,

@@ -134,6 +139,14 @@ class QwenImageLayeredCoreDenoiseStep(SequentialPipelineBlocks):
    def description(self):
        return "Core denoising workflow for QwenImage-Layered img2img task."

    @property
    def outputs(self):
        return [
            OutputParam(
                name="latents", type_hint=torch.Tensor, description="The latents generated by the denoising step"
            ),
        ]


# ====================
# 4. AUTO BLOCKS & PRESETS

@@ -157,3 +170,9 @@ class QwenImageLayeredAutoBlocks(SequentialPipelineBlocks):
    @property
    def description(self):
        return "Auto Modular pipeline for layered denoising tasks using QwenImage-Layered."

    @property
    def outputs(self):
        return [
            OutputParam(name="images", type_hint=List[List[PIL.Image.Image]], description="The generated images"),
        ]

@@ -5,6 +5,7 @@ import torch.nn as nn
import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...models.attention import FeedForward
from ...models.modeling_utils import ModelMixin
from ...models.transformers.transformer_ltx2 import LTX2Attention, LTX2AudioVideoAttnProcessor

@@ -252,7 +253,7 @@ class LTX2ConnectorTransformer1d(nn.Module):
        return hidden_states, attention_mask


-class LTX2TextConnectors(ModelMixin, ConfigMixin):
+class LTX2TextConnectors(ModelMixin, PeftAdapterMixin, ConfigMixin):
    """
    Text connector stack used by LTX 2.0 to process the packed text encoder hidden states for both the video and audio
    streams.

@@ -21,7 +21,7 @@ import torch
from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast

from ...callbacks import MultiPipelineCallbacks, PipelineCallback
-from ...loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin
+from ...loaders import FromSingleFileMixin, LTX2LoraLoaderMixin
from ...models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video
from ...models.transformers import LTX2VideoTransformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler

@@ -184,7 +184,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    return noise_cfg


-class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
+class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
    r"""
    Pipeline for text-to-video generation.

@@ -14,7 +14,7 @@
# limitations under the License.
#
# Modifications by Decart AI Team:
-# - Based on pipeline_wan.py, but with supports recieving a condition video appended to the channel dimension.
+# - Based on pipeline_wan.py, but with supports receiving a condition video appended to the channel dimension.

import html
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

tests/lora/test_lora_layers_ltx2.py (new file, 293 lines)
@@ -0,0 +1,293 @@
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import unittest

import torch
from transformers import AutoTokenizer, Gemma3ForConditionalGeneration

from diffusers import (
    AutoencoderKLLTX2Audio,
    AutoencoderKLLTX2Video,
    FlowMatchEulerDiscreteScheduler,
    LTX2Pipeline,
    LTX2VideoTransformer3DModel,
)
from diffusers.pipelines.ltx2 import LTX2TextConnectors
from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
from diffusers.utils.import_utils import is_peft_available

from ..testing_utils import floats_tensor, require_peft_backend


if is_peft_available():
    from peft import LoraConfig


sys.path.append(".")

from .utils import PeftLoraLoaderMixinTests  # noqa: E402


@require_peft_backend
class LTX2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    pipeline_class = LTX2Pipeline
    scheduler_cls = FlowMatchEulerDiscreteScheduler
    scheduler_kwargs = {}

    transformer_kwargs = {
        "in_channels": 4,
        "out_channels": 4,
        "patch_size": 1,
        "patch_size_t": 1,
        "num_attention_heads": 2,
        "attention_head_dim": 8,
        "cross_attention_dim": 16,
        "audio_in_channels": 4,
        "audio_out_channels": 4,
        "audio_num_attention_heads": 2,
        "audio_attention_head_dim": 4,
        "audio_cross_attention_dim": 8,
        "num_layers": 1,
        "qk_norm": "rms_norm_across_heads",
        "caption_channels": 32,
        "rope_double_precision": False,
        "rope_type": "split",
    }
    transformer_cls = LTX2VideoTransformer3DModel

    vae_kwargs = {
        "in_channels": 3,
        "out_channels": 3,
        "latent_channels": 4,
        "block_out_channels": (8,),
        "decoder_block_out_channels": (8,),
        "layers_per_block": (1,),
        "decoder_layers_per_block": (1, 1),
        "spatio_temporal_scaling": (True,),
        "decoder_spatio_temporal_scaling": (True,),
        "decoder_inject_noise": (False, False),
        "downsample_type": ("spatial",),
        "upsample_residual": (False,),
        "upsample_factor": (1,),
        "timestep_conditioning": False,
        "patch_size": 1,
        "patch_size_t": 1,
        "encoder_causal": True,
        "decoder_causal": False,
    }
    vae_cls = AutoencoderKLLTX2Video

    audio_vae_kwargs = {
        "base_channels": 4,
        "output_channels": 2,
        "ch_mult": (1,),
        "num_res_blocks": 1,
        "attn_resolutions": None,
        "in_channels": 2,
        "resolution": 32,
        "latent_channels": 2,
        "norm_type": "pixel",
        "causality_axis": "height",
        "dropout": 0.0,
        "mid_block_add_attention": False,
        "sample_rate": 16000,
        "mel_hop_length": 160,
        "is_causal": True,
        "mel_bins": 8,
    }
    audio_vae_cls = AutoencoderKLLTX2Audio

    vocoder_kwargs = {
        "in_channels": 16,  # output_channels * mel_bins = 2 * 8
        "hidden_channels": 32,
        "out_channels": 2,
        "upsample_kernel_sizes": [4, 4],
        "upsample_factors": [2, 2],
        "resnet_kernel_sizes": [3],
        "resnet_dilations": [[1, 3, 5]],
        "leaky_relu_negative_slope": 0.1,
        "output_sampling_rate": 16000,
    }
    vocoder_cls = LTX2Vocoder

    connectors_kwargs = {
        "caption_channels": 32,  # Will be set dynamically from text_encoder
        "text_proj_in_factor": 2,  # Will be set dynamically from text_encoder
        "video_connector_num_attention_heads": 4,
        "video_connector_attention_head_dim": 8,
        "video_connector_num_layers": 1,
        "video_connector_num_learnable_registers": None,
        "audio_connector_num_attention_heads": 4,
        "audio_connector_attention_head_dim": 8,
        "audio_connector_num_layers": 1,
        "audio_connector_num_learnable_registers": None,
        "connector_rope_base_seq_len": 32,
        "rope_theta": 10000.0,
        "rope_double_precision": False,
        "causal_temporal_positioning": False,
        "rope_type": "split",
    }
    connectors_cls = LTX2TextConnectors

    tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/tiny-gemma3"
    text_encoder_cls, text_encoder_id = (
        Gemma3ForConditionalGeneration,
        "hf-internal-testing/tiny-gemma3",
    )

    denoiser_target_modules = ["to_q", "to_k", "to_out.0"]

    @property
    def output_shape(self):
        return (1, 5, 32, 32, 3)

    def get_dummy_inputs(self, with_generator=True):
        batch_size = 1
        sequence_length = 16
        num_channels = 4
        num_frames = 5
        num_latent_frames = 2
        latent_height = 8
        latent_width = 8

        generator = torch.manual_seed(0)
        noise = floats_tensor((batch_size, num_latent_frames, num_channels, latent_height, latent_width))
        input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)

        pipeline_inputs = {
            "prompt": "a robot dancing",
            "num_frames": num_frames,
            "num_inference_steps": 2,
            "guidance_scale": 1.0,
            "height": 32,
            "width": 32,
            "frame_rate": 25.0,
            "max_sequence_length": sequence_length,
            "output_type": "np",
        }
        if with_generator:
            pipeline_inputs.update({"generator": generator})

        return noise, input_ids, pipeline_inputs

    def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
        # Override to instantiate LTX2-specific components (connectors, audio_vae, vocoder)
        torch.manual_seed(0)
        text_encoder = self.text_encoder_cls.from_pretrained(self.text_encoder_id)
        tokenizer = self.tokenizer_cls.from_pretrained(self.tokenizer_id)

        # Update caption_channels and text_proj_in_factor based on text_encoder config
        transformer_kwargs = self.transformer_kwargs.copy()
        transformer_kwargs["caption_channels"] = text_encoder.config.text_config.hidden_size

        connectors_kwargs = self.connectors_kwargs.copy()
        connectors_kwargs["caption_channels"] = text_encoder.config.text_config.hidden_size
        connectors_kwargs["text_proj_in_factor"] = text_encoder.config.text_config.num_hidden_layers + 1

        torch.manual_seed(0)
        transformer = self.transformer_cls(**transformer_kwargs)

        torch.manual_seed(0)
        vae = self.vae_cls(**self.vae_kwargs)
        vae.use_framewise_encoding = False
        vae.use_framewise_decoding = False

        torch.manual_seed(0)
        audio_vae = self.audio_vae_cls(**self.audio_vae_kwargs)

        torch.manual_seed(0)
        vocoder = self.vocoder_cls(**self.vocoder_kwargs)

        torch.manual_seed(0)
        connectors = self.connectors_cls(**connectors_kwargs)

        if scheduler_cls is None:
            scheduler_cls = self.scheduler_cls
        scheduler = scheduler_cls(**self.scheduler_kwargs)

        rank = 4
        lora_alpha = rank if lora_alpha is None else lora_alpha

        text_lora_config = LoraConfig(
            r=rank,
            lora_alpha=lora_alpha,
            target_modules=self.text_encoder_target_modules,
            init_lora_weights=False,
            use_dora=use_dora,
        )

        denoiser_lora_config = LoraConfig(
            r=rank,
            lora_alpha=lora_alpha,
            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
            init_lora_weights=False,
            use_dora=use_dora,
        )

        pipeline_components = {
            "transformer": transformer,
            "vae": vae,
            "audio_vae": audio_vae,
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "connectors": connectors,
            "vocoder": vocoder,
        }

        return pipeline_components, text_lora_config, denoiser_lora_config

    def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
        super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)

    def test_simple_inference_with_text_denoiser_lora_unfused(self):
        super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)

    @unittest.skip("Not supported in LTX2.")
    def test_simple_inference_with_text_denoiser_block_scale(self):
        pass

    @unittest.skip("Not supported in LTX2.")
    def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
        pass

    @unittest.skip("Not supported in LTX2.")
    def test_modify_padding_mode(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in LTX2.")
    def test_simple_inference_with_partial_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in LTX2.")
    def test_simple_inference_with_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in LTX2.")
    def test_simple_inference_with_text_lora_and_scale(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in LTX2.")
    def test_simple_inference_with_text_lora_fused(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in LTX2.")
    def test_simple_inference_with_text_lora_save_load(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in LTX2.")
    def test_simple_inference_save_pretrained_with_text_lora(self):
        pass