Mirror of https://github.com/huggingface/diffusers.git, synced 2025-12-10 22:44:38 +08:00
Compare commits: dynamic-te...wan-cache (3 commits)
| Author | SHA1 | Date |
|---|---|---|
| | ca5cfbd37e | |
| | 5ac65c4513 | |
| | 8c2b2cdc52 | |
@@ -26,6 +26,7 @@ from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import FeedForward
 from ..attention_processor import Attention
+from ..cache_utils import CacheMixin
 from ..embeddings import PixArtAlphaTextProjection
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
@@ -298,7 +299,7 @@ class LTXVideoTransformerBlock(nn.Module):
 
 
 @maybe_allow_in_graph
-class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
+class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin, CacheMixin):
     r"""
     A Transformer model for video-like data used in [LTX](https://huggingface.co/Lightricks/LTX-Video).
 
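With `CacheMixin` in its base classes, the LTX transformer picks up the mixin's `enable_cache()` / `disable_cache()` methods. A minimal sketch of what that enables, assuming diffusers' documented `PyramidAttentionBroadcastConfig` API (the config values are illustrative, not part of this diff); note how the callback relies on the pipeline-side `current_timestep` property added later in this diff:

```python
import torch
from diffusers import LTXPipeline, PyramidAttentionBroadcastConfig

pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# CacheMixin adds enable_cache()/disable_cache() to the transformer; the cache
# hooks query the pipeline's current timestep through the callback below.
config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,            # illustrative value
    spatial_attention_timestep_skip_range=(100, 800),  # illustrative value
    current_timestep_callback=lambda: pipe.current_timestep,
)
pipe.transformer.enable_cache(config)
```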
@@ -24,6 +24,7 @@ from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ..attention import FeedForward
 from ..attention_processor import Attention
+from ..cache_utils import CacheMixin
 from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
@@ -288,7 +289,7 @@ class WanTransformerBlock(nn.Module):
         return hidden_states
 
 
-class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
     r"""
     A Transformer model for video-like data used in the Wan model.
 
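The same mixin methods now exist on the Wan transformer, independent of any particular pipeline. A sketch of toggling the cache directly on the bare model (the checkpoint id is an assumption, not named in this diff):

```python
import torch
from diffusers import PyramidAttentionBroadcastConfig, WanTransformer3DModel

# Load just the transformer from a Wan checkpoint (model id assumed here).
transformer = WanTransformer3DModel.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)

# enable_cache()/disable_cache() come from CacheMixin; inside a pipeline the
# current_timestep_callback would typically close over the pipeline instance.
transformer.enable_cache(
    PyramidAttentionBroadcastConfig(
        spatial_attention_block_skip_range=2,   # illustrative value
        current_timestep_callback=lambda: 0,    # placeholder; see pipeline hunks below
    )
)
transformer.disable_cache()  # removes the hooks, restoring the plain forward pass
```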
@@ -489,6 +489,10 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
     def num_timesteps(self):
         return self._num_timesteps
 
+    @property
+    def current_timestep(self):
+        return self._current_timestep
+
     @property
     def attention_kwargs(self):
         return self._attention_kwargs
@@ -622,6 +626,7 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
         self._guidance_scale = guidance_scale
         self._attention_kwargs = attention_kwargs
         self._interrupt = False
+        self._current_timestep = None
 
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
@@ -706,6 +711,8 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
             if self.interrupt:
                 continue
 
+            self._current_timestep = t
+
             latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
             latent_model_input = latent_model_input.to(prompt_embeds.dtype)
 
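The new read-only `current_timestep` property mirrors `_current_timestep`, which the denoising loop now updates on every step and resets to `None` at call setup. Anything holding a reference to the pipeline can poll it mid-sampling; a sketch assuming diffusers' usual `callback_on_step_end` signature:

```python
# Observe the timestep that cache hooks would see at each denoising step.
def log_timestep(pipe, step, timestep, callback_kwargs):
    print(f"step {step}: pipe.current_timestep = {pipe.current_timestep}")
    return callback_kwargs

video = pipe(
    prompt="a river winding through a forest",
    callback_on_step_end=log_timestep,
).frames[0]
```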
@@ -774,6 +774,10 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
     def num_timesteps(self):
         return self._num_timesteps
 
+    @property
+    def current_timestep(self):
+        return self._current_timestep
+
     @property
     def attention_kwargs(self):
         return self._attention_kwargs
@@ -933,6 +937,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
         self._guidance_scale = guidance_scale
         self._attention_kwargs = attention_kwargs
         self._interrupt = False
+        self._current_timestep = None
 
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
@@ -1066,6 +1071,8 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
             if self.interrupt:
                 continue
 
+            self._current_timestep = t
+
             if image_cond_noise_scale > 0:
                 # Add timestep-dependent noise to the hard-conditioning latents
                 # This helps with motion continuity, especially when conditioned on a single frame
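The `image_cond_noise_scale` branch above is only unchanged context in this hunk, but the idea behind it is worth spelling out: hard-conditioning latents receive a little fresh noise early in sampling so they do not look "frozen" relative to the still-noisy generated frames. A hypothetical sketch of such a blend (the function name and the exact schedule are assumptions, not this pipeline's actual helper):

```python
import torch

def add_cond_noise(
    cond_latents: torch.Tensor,
    t: float,
    scale: float,
    generator: torch.Generator | None = None,
) -> torch.Tensor:
    """Blend fresh noise into conditioning latents, more at high (early) t."""
    noise = torch.randn(
        cond_latents.shape,
        generator=generator,
        device=cond_latents.device,
        dtype=cond_latents.dtype,
    )
    # t assumed normalized to [0, 1]; the blend decays to 0 late in sampling.
    alpha = scale * t
    return (1.0 - alpha) * cond_latents + alpha * noise
```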
@@ -550,6 +550,10 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
     def num_timesteps(self):
         return self._num_timesteps
 
+    @property
+    def current_timestep(self):
+        return self._current_timestep
+
     @property
     def attention_kwargs(self):
         return self._attention_kwargs
@@ -686,6 +690,7 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
         self._guidance_scale = guidance_scale
         self._attention_kwargs = attention_kwargs
         self._interrupt = False
+        self._current_timestep = None
 
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
@@ -778,6 +783,8 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
             if self.interrupt:
                 continue
 
+            self._current_timestep = t
+
             latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
             latent_model_input = latent_model_input.to(prompt_embeds.dtype)
 
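Taken together, the transformer-side mixin and the pipeline-side timestep tracking compose into the usual cache workflow for the image-to-video variant as well; an end-to-end sketch (config values illustrative, input path a placeholder):

```python
import torch
from diffusers import LTXImageToVideoPipeline, PyramidAttentionBroadcastConfig
from diffusers.utils import load_image

pipe = LTXImageToVideoPipeline.from_pretrained(
    "Lightricks/LTX-Video", torch_dtype=torch.bfloat16
).to("cuda")

# Same pattern as the text-to-video pipeline: the cache hooks read the
# timestep that the denoising loop writes into pipe._current_timestep.
pipe.transformer.enable_cache(
    PyramidAttentionBroadcastConfig(
        spatial_attention_block_skip_range=2,  # illustrative value
        current_timestep_callback=lambda: pipe.current_timestep,
    )
)

image = load_image("path/to/first_frame.png")
video = pipe(image=image, prompt="a scenic timelapse", num_frames=49).frames[0]
```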