Compare commits

...

4 Commits

Author     SHA1        Message               Date
yiyixuxu   100166ed53  update                2024-08-31 10:37:18 +02:00
yiyixuxu   16704379a0  add record function   2024-08-31 04:39:44 +02:00
yiyixuxu   4bd87a1fe9  fix                   2024-08-29 22:35:31 +02:00
yiyixuxu   484443e0b4  put arange on device  2024-08-29 22:28:51 +02:00
4 changed files with 135 additions and 118 deletions
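These four commits thread `torch.profiler.record_function` labels through the Flux pipeline, the Flux transformer, the rotary-embedding helper, and the flow-match scheduler, and move one `torch.arange` onto the target device. A minimal driver sketch for surfacing those labels (not part of this diff; the checkpoint name, prompt, and call arguments are assumptions):

import torch
from torch.profiler import profile, ProfilerActivity
from diffusers import FluxPipeline

# Hypothetical driver, not included in these commits: it only shows how the
# record_function labels added below surface in a profiler trace.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    pipe("a cat holding a sign", num_inference_steps=4, guidance_scale=3.5)

# The user annotations (" x_embedder", " blocks", "transformer iter_0", ...) show up as rows here.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=30))
prof.export_chrome_trace("flux_profile.json")  # open in chrome://tracing or Perfetto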

View File

@@ -550,8 +550,11 @@ def get_1d_rotary_pos_embed(
         pos = torch.from_numpy(pos)  # type: ignore  # [S]
 
     theta = theta * ntk_factor
-    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor  # [D/2]
-    freqs = freqs.to(pos.device)
+    freqs = (
+        1.0
+        / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
+        / linear_factor
+    )  # [D/2]
     freqs = torch.outer(pos, freqs)  # type: ignore  # [S, D/2]
     if use_real and repeat_interleave_real:
         # flux, hunyuan-dit, cogvideox
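This first hunk is a device-placement fix rather than a profiling hook: instead of building `freqs` on the default device and copying it with `freqs.to(pos.device)`, the `torch.arange` now takes `device=pos.device`, so the tensor is created where it is needed and the extra copy (a host-to-device transfer when `pos` lives on the GPU) disappears. A standalone before/after sketch, with placeholder sizes and dtype:

import torch

dim, theta, linear_factor = 128, 10000.0, 1.0
device = "cuda" if torch.cuda.is_available() else "cpu"
pos = torch.arange(64, device=device)  # stand-in for the rotary positions

# Before: arange lands on the default (CPU) device, then the result is copied over.
freqs_old = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32)[: (dim // 2)] / dim)) / linear_factor
freqs_old = freqs_old.to(pos.device)

# After: the tensor is created directly on pos.device, no extra copy.
freqs_new = (
    1.0
    / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device)[: (dim // 2)] / dim))
    / linear_factor
)

assert torch.allclose(freqs_old, freqs_new)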

View File

@@ -38,6 +38,7 @@ from ..modeling_outputs import Transformer2DModelOutput
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+from torch.profiler import record_function
 
 
 @maybe_allow_in_graph
@@ -439,8 +440,10 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
             logger.warning(
                 "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
             )
+        with record_function(" x_embedder"):
             hidden_states = self.x_embedder(hidden_states)
 
+        with record_function(" time_text_embed"):
             timestep = timestep.to(hidden_states.dtype) * 1000
             if guidance is not None:
                 guidance = guidance.to(hidden_states.dtype) * 1000
@@ -453,6 +456,7 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
         )
         encoder_hidden_states = self.context_embedder(encoder_hidden_states)
 
+        with record_function(" pos_embeds (rotary)"):
             if txt_ids.ndim == 3:
                 logger.warning(
                     "Passing `txt_ids` 3d torch.Tensor is deprecated."
@@ -468,6 +472,7 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
         ids = torch.cat((txt_ids, img_ids), dim=0)
         image_rotary_emb = self.pos_embed(ids)
 
+        with record_function(" blocks"):
             for index_block, block in enumerate(self.transformer_blocks):
                 if self.training and self.gradient_checkpointing:
@@ -506,6 +511,7 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
         hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
 
+        with record_function(" single blocks"):
             for index_block, block in enumerate(self.single_transformer_blocks):
                 if self.training and self.gradient_checkpointing:
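The `record_function` labels above (note the deliberate leading spaces) become user annotations in the profile, nested inside whatever region of the pipeline invoked the transformer. A hedged sketch of pulling just these annotations out of a finished profile, reusing the `prof` object from the first sketch:

# Assumes `prof` is the torch.profiler.profile object from the earlier driver sketch.
for evt in prof.key_averages():
    if evt.key.startswith(" ") or evt.key.startswith("transformer iter_"):
        # cpu_time_total is reported in microseconds; with asynchronous CUDA execution it
        # mostly measures launch/dispatch cost, so keep ProfilerActivity.CUDA enabled (or
        # synchronize) before drawing conclusions about GPU time inside a region.
        print(f"{evt.key!r}: {evt.cpu_time_total / 1000:.2f} ms CPU over {evt.count} call(s)")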

View File

@@ -36,6 +36,8 @@ from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from .pipeline_output import FluxPipelineOutput
+from torch.profiler import record_function
+
 
 
 if is_torch_xla_available():
     import torch_xla.core.xla_model as xm
@@ -716,6 +718,8 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)
 
+                with record_function(f"transformer iter_{i}"):
                     noise_pred = self.transformer(
                         hidden_states=latents,
                         timestep=timestep / 1000,
@@ -730,6 +734,7 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
                 # compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
+                with record_function(f"scheduler.step (iter_{i})"):
                     latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
 
                 if latents.dtype != latents_dtype:
@@ -757,6 +762,7 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
             image = latents
 
         else:
+            with record_function(f"decode latent"):
                 latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
                 latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
                 image = self.vae.decode(latents, return_dict=False)[0]
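Because the pipeline labels each denoising step separately (`f"transformer iter_{i}"`, `f"scheduler.step (iter_{i})"`), the annotations are not merged by `key_averages()`; each step appears as its own row and trace slice, which is convenient for spotting warm-up effects. A sketch (assuming the `pipe` object from the first example) that records several full calls and discards the first so allocator and autotuning warm-up does not skew the labeled regions:

from torch.profiler import profile, schedule, ProfilerActivity, tensorboard_trace_handler

sched = schedule(wait=0, warmup=1, active=2)  # skip one call, record the next two
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=sched,
    on_trace_ready=tensorboard_trace_handler("./flux_prof"),
) as prof:
    for _ in range(3):
        pipe("a cat holding a sign", num_inference_steps=4)
        prof.step()  # advance the profiler schedule after each full pipeline call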

View File

@@ -22,7 +22,7 @@ import torch
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput, logging
 from .scheduling_utils import SchedulerMixin
+from torch.profiler import record_function
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -284,16 +284,18 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
                     " one of the `scheduler.timesteps` as a timestep."
                 ),
             )
 
+        with record_function(" _init_step_index"):
             if self.step_index is None:
                 self._init_step_index(timestep)
 
         # Upcast to avoid precision issues when computing prev_sample
         sample = sample.to(torch.float32)
 
+        with record_function(" get sigma and sigma_next"):
             sigma = self.sigmas[self.step_index]
             sigma_next = self.sigmas[self.step_index + 1]
 
+        with record_function(" get prev_sample"):
             prev_sample = sample + (sigma_next - sigma) * model_output
 
         # Cast sample back to model compatible dtype