mirror of https://github.com/huggingface/diffusers.git

Compare commits: sage-kerne… → rope-init… (4 commits)

| SHA1 |
|---|
| 100166ed53 |
| 16704379a0 |
| 4bd87a1fe9 |
| 484443e0b4 |
src/diffusers/models/embeddings.py

```diff
@@ -550,8 +550,11 @@ def get_1d_rotary_pos_embed(
         pos = torch.from_numpy(pos)  # type: ignore  # [S]
 
     theta = theta * ntk_factor
-    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor  # [D/2]
-    freqs = freqs.to(pos.device)
+    freqs = (
+        1.0
+        / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
+        / linear_factor
+    )  # [D/2]
     freqs = torch.outer(pos, freqs)  # type: ignore  # [S, D/2]
     if use_real and repeat_interleave_real:
         # flux, hunyuan-dit, cogvideox
```
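Why this matters: the old code materialized `freqs` on the CPU and then copied it to `pos.device`, which issues an extra host-to-device transfer on every RoPE call; building the `arange` directly on `pos.device` avoids it. A minimal standalone sketch of the before/after behavior (the values of `dim`, `theta`, and the sequence length are arbitrary, not from the diff):

```python
import torch

dim, theta, linear_factor = 128, 10000.0, 1.0
device = "cuda" if torch.cuda.is_available() else "cpu"
pos = torch.arange(512, dtype=torch.float32, device=device)  # [S]

# Before: frequencies are built on the CPU, then copied to pos.device.
freqs_old = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32)[: (dim // 2)] / dim)) / linear_factor
freqs_old = freqs_old.to(pos.device)

# After: the arange is allocated on pos.device, so no copy is needed.
freqs_new = (
    1.0
    / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device)[: (dim // 2)] / dim))
    / linear_factor
)  # [D/2]

assert torch.allclose(freqs_old, freqs_new)
print(torch.outer(pos, freqs_new).shape)  # torch.Size([512, 64])
```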
src/diffusers/models/transformers/transformer_flux.py

```diff
@@ -38,6 +38,7 @@ from ..modeling_outputs import Transformer2DModelOutput
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+from torch.profiler import record_function
 
 
 @maybe_allow_in_graph
```
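`record_function` is the standard `torch.profiler` hook for labelling a region of code: anything executed inside the context manager is grouped under the given name in profiler output. A minimal sketch of how such a label surfaces (the `Linear` module and shapes are placeholders, not part of the diff):

```python
import torch
from torch.profiler import ProfilerActivity, profile, record_function

linear = torch.nn.Linear(64, 64)
x = torch.randn(8, 64)

with profile(activities=[ProfilerActivity.CPU]) as prof:
    with record_function(" x_embedder"):  # appears as a named range in the trace
        y = linear(x)

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))
```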
```diff
@@ -439,8 +440,10 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
             logger.warning(
                 "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
             )
+        with record_function(" x_embedder"):
             hidden_states = self.x_embedder(hidden_states)
 
+        with record_function(" time_text_embed"):
             timestep = timestep.to(hidden_states.dtype) * 1000
             if guidance is not None:
                 guidance = guidance.to(hidden_states.dtype) * 1000
```
```diff
@@ -453,6 +456,7 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
             )
         encoder_hidden_states = self.context_embedder(encoder_hidden_states)
 
+        with record_function(" pos_embeds (rotary)"):
             if txt_ids.ndim == 3:
                 logger.warning(
                     "Passing `txt_ids` 3d torch.Tensor is deprecated."
```
```diff
@@ -468,6 +472,7 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
             ids = torch.cat((txt_ids, img_ids), dim=0)
             image_rotary_emb = self.pos_embed(ids)
 
+        with record_function(" blocks"):
             for index_block, block in enumerate(self.transformer_blocks):
                 if self.training and self.gradient_checkpointing:
 
```
```diff
@@ -506,6 +511,7 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
 
         hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
 
+        with record_function(" single blocks"):
             for index_block, block in enumerate(self.single_transformer_blocks):
                 if self.training and self.gradient_checkpointing:
 
```
src/diffusers/pipelines/flux/pipeline_flux.py

```diff
@@ -36,6 +36,8 @@ from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from .pipeline_output import FluxPipelineOutput
 
+from torch.profiler import record_function
+
 
 if is_torch_xla_available():
     import torch_xla.core.xla_model as xm
```
```diff
@@ -716,6 +718,8 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)
 
+                with record_function(f"transformer iter_{i}"):
+
                 noise_pred = self.transformer(
                     hidden_states=latents,
                     timestep=timestep / 1000,
```
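Using an f-string per denoising step, as above, gives every iteration its own row in the profile instead of aggregating all transformer calls under one name, so a single slow step stands out. A self-contained sketch of the pattern, with a dummy `step` function standing in for the transformer call:

```python
import torch
from torch.profiler import ProfilerActivity, profile, record_function

def step(latents):
    return latents * 0.99  # placeholder for self.transformer(...)

latents = torch.randn(1, 16)
with profile(activities=[ProfilerActivity.CPU]) as prof:
    for i in range(4):
        with record_function(f"transformer iter_{i}"):
            latents = step(latents)

# Each iteration appears under its own label in the summary.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=8))
```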
```diff
@@ -730,6 +734,7 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
 
                 # compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
+                with record_function(f"scheduler.step (iter_{i})"):
                     latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
 
                 if latents.dtype != latents_dtype:
```
```diff
@@ -757,6 +762,7 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
             image = latents
 
         else:
+            with record_function(f"decode latent"):
                 latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
                 latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
                 image = self.vae.decode(latents, return_dict=False)[0]
```
src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py

```diff
@@ -22,7 +22,7 @@ import torch
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput, logging
 from .scheduling_utils import SchedulerMixin
-
+from torch.profiler import record_function
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
```
```diff
@@ -284,16 +284,18 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
                     " one of the `scheduler.timesteps` as a timestep."
                 ),
             )
-
+        with record_function(" _init_step_index"):
             if self.step_index is None:
                 self._init_step_index(timestep)
 
         # Upcast to avoid precision issues when computing prev_sample
         sample = sample.to(torch.float32)
 
+        with record_function(" get sigma and sigma_next"):
             sigma = self.sigmas[self.step_index]
             sigma_next = self.sigmas[self.step_index + 1]
 
+        with record_function(" get prev_sample"):
             prev_sample = sample + (sigma_next - sigma) * model_output
 
         # Cast sample back to model compatible dtype
```
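To read these labels end to end, one option (not part of the diff) is to run the whole pipeline under `torch.profiler.profile` and export a Chrome trace; the ranges added in the transformer, pipeline, and scheduler then show up as nested slices. A sketch with a stand-in workload and an arbitrary output file name:

```python
import torch
from torch.profiler import ProfilerActivity, profile

activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities) as prof:
    # Stand-in workload; in practice this would be the pipeline call,
    # e.g. pipe(prompt, num_inference_steps=4).
    torch.randn(512, 512) @ torch.randn(512, 512)

# Open the resulting JSON in chrome://tracing or https://ui.perfetto.dev
prof.export_chrome_trace("flux_trace.json")
```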