mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-19 02:44:53 +08:00
Compare commits
4 Commits
style-fixe
...
rope-init-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
100166ed53 | ||
|
|
16704379a0 | ||
|
|
4bd87a1fe9 | ||
|
|
484443e0b4 |
@@ -550,8 +550,11 @@ def get_1d_rotary_pos_embed(
|
|||||||
pos = torch.from_numpy(pos) # type: ignore # [S]
|
pos = torch.from_numpy(pos) # type: ignore # [S]
|
||||||
|
|
||||||
theta = theta * ntk_factor
|
theta = theta * ntk_factor
|
||||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor # [D/2]
|
freqs = (
|
||||||
freqs = freqs.to(pos.device)
|
1.0
|
||||||
|
/ (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
|
||||||
|
/ linear_factor
|
||||||
|
) # [D/2]
|
||||||
freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
|
freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
|
||||||
if use_real and repeat_interleave_real:
|
if use_real and repeat_interleave_real:
|
||||||
# flux, hunyuan-dit, cogvideox
|
# flux, hunyuan-dit, cogvideox
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ from ..modeling_outputs import Transformer2DModelOutput
|
|||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
from torch.profiler import record_function
|
||||||
|
|
||||||
|
|
||||||
@maybe_allow_in_graph
|
@maybe_allow_in_graph
|
||||||
@@ -439,109 +440,114 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
|
|||||||
logger.warning(
|
logger.warning(
|
||||||
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
||||||
)
|
)
|
||||||
hidden_states = self.x_embedder(hidden_states)
|
with record_function(" x_embedder"):
|
||||||
|
hidden_states = self.x_embedder(hidden_states)
|
||||||
timestep = timestep.to(hidden_states.dtype) * 1000
|
|
||||||
if guidance is not None:
|
|
||||||
guidance = guidance.to(hidden_states.dtype) * 1000
|
|
||||||
else:
|
|
||||||
guidance = None
|
|
||||||
temb = (
|
|
||||||
self.time_text_embed(timestep, pooled_projections)
|
|
||||||
if guidance is None
|
|
||||||
else self.time_text_embed(timestep, guidance, pooled_projections)
|
|
||||||
)
|
|
||||||
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
|
||||||
|
|
||||||
if txt_ids.ndim == 3:
|
|
||||||
logger.warning(
|
|
||||||
"Passing `txt_ids` 3d torch.Tensor is deprecated."
|
|
||||||
"Please remove the batch dimension and pass it as a 2d torch Tensor"
|
|
||||||
)
|
|
||||||
txt_ids = txt_ids[0]
|
|
||||||
if img_ids.ndim == 3:
|
|
||||||
logger.warning(
|
|
||||||
"Passing `img_ids` 3d torch.Tensor is deprecated."
|
|
||||||
"Please remove the batch dimension and pass it as a 2d torch Tensor"
|
|
||||||
)
|
|
||||||
img_ids = img_ids[0]
|
|
||||||
ids = torch.cat((txt_ids, img_ids), dim=0)
|
|
||||||
image_rotary_emb = self.pos_embed(ids)
|
|
||||||
|
|
||||||
for index_block, block in enumerate(self.transformer_blocks):
|
|
||||||
if self.training and self.gradient_checkpointing:
|
|
||||||
|
|
||||||
def create_custom_forward(module, return_dict=None):
|
|
||||||
def custom_forward(*inputs):
|
|
||||||
if return_dict is not None:
|
|
||||||
return module(*inputs, return_dict=return_dict)
|
|
||||||
else:
|
|
||||||
return module(*inputs)
|
|
||||||
|
|
||||||
return custom_forward
|
|
||||||
|
|
||||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
|
||||||
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
|
|
||||||
create_custom_forward(block),
|
|
||||||
hidden_states,
|
|
||||||
encoder_hidden_states,
|
|
||||||
temb,
|
|
||||||
image_rotary_emb,
|
|
||||||
**ckpt_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
with record_function(" time_text_embed"):
|
||||||
|
timestep = timestep.to(hidden_states.dtype) * 1000
|
||||||
|
if guidance is not None:
|
||||||
|
guidance = guidance.to(hidden_states.dtype) * 1000
|
||||||
else:
|
else:
|
||||||
encoder_hidden_states, hidden_states = block(
|
guidance = None
|
||||||
hidden_states=hidden_states,
|
temb = (
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
self.time_text_embed(timestep, pooled_projections)
|
||||||
temb=temb,
|
if guidance is None
|
||||||
image_rotary_emb=image_rotary_emb,
|
else self.time_text_embed(timestep, guidance, pooled_projections)
|
||||||
)
|
)
|
||||||
|
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
||||||
|
|
||||||
# controlnet residual
|
with record_function(" pos_embeds (rotary)"):
|
||||||
if controlnet_block_samples is not None:
|
if txt_ids.ndim == 3:
|
||||||
interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
|
logger.warning(
|
||||||
interval_control = int(np.ceil(interval_control))
|
"Passing `txt_ids` 3d torch.Tensor is deprecated."
|
||||||
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
|
"Please remove the batch dimension and pass it as a 2d torch Tensor"
|
||||||
|
)
|
||||||
|
txt_ids = txt_ids[0]
|
||||||
|
if img_ids.ndim == 3:
|
||||||
|
logger.warning(
|
||||||
|
"Passing `img_ids` 3d torch.Tensor is deprecated."
|
||||||
|
"Please remove the batch dimension and pass it as a 2d torch Tensor"
|
||||||
|
)
|
||||||
|
img_ids = img_ids[0]
|
||||||
|
ids = torch.cat((txt_ids, img_ids), dim=0)
|
||||||
|
image_rotary_emb = self.pos_embed(ids)
|
||||||
|
|
||||||
|
with record_function(" blocks"):
|
||||||
|
for index_block, block in enumerate(self.transformer_blocks):
|
||||||
|
if self.training and self.gradient_checkpointing:
|
||||||
|
|
||||||
|
def create_custom_forward(module, return_dict=None):
|
||||||
|
def custom_forward(*inputs):
|
||||||
|
if return_dict is not None:
|
||||||
|
return module(*inputs, return_dict=return_dict)
|
||||||
|
else:
|
||||||
|
return module(*inputs)
|
||||||
|
|
||||||
|
return custom_forward
|
||||||
|
|
||||||
|
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||||
|
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(block),
|
||||||
|
hidden_states,
|
||||||
|
encoder_hidden_states,
|
||||||
|
temb,
|
||||||
|
image_rotary_emb,
|
||||||
|
**ckpt_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
encoder_hidden_states, hidden_states = block(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
temb=temb,
|
||||||
|
image_rotary_emb=image_rotary_emb,
|
||||||
|
)
|
||||||
|
|
||||||
|
# controlnet residual
|
||||||
|
if controlnet_block_samples is not None:
|
||||||
|
interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
|
||||||
|
interval_control = int(np.ceil(interval_control))
|
||||||
|
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
|
||||||
|
|
||||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||||
|
|
||||||
for index_block, block in enumerate(self.single_transformer_blocks):
|
with record_function(" single blocks"):
|
||||||
if self.training and self.gradient_checkpointing:
|
for index_block, block in enumerate(self.single_transformer_blocks):
|
||||||
|
if self.training and self.gradient_checkpointing:
|
||||||
|
|
||||||
def create_custom_forward(module, return_dict=None):
|
def create_custom_forward(module, return_dict=None):
|
||||||
def custom_forward(*inputs):
|
def custom_forward(*inputs):
|
||||||
if return_dict is not None:
|
if return_dict is not None:
|
||||||
return module(*inputs, return_dict=return_dict)
|
return module(*inputs, return_dict=return_dict)
|
||||||
else:
|
else:
|
||||||
return module(*inputs)
|
return module(*inputs)
|
||||||
|
|
||||||
return custom_forward
|
return custom_forward
|
||||||
|
|
||||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||||
create_custom_forward(block),
|
create_custom_forward(block),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
temb,
|
temb,
|
||||||
image_rotary_emb,
|
image_rotary_emb,
|
||||||
**ckpt_kwargs,
|
**ckpt_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
hidden_states = block(
|
hidden_states = block(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
temb=temb,
|
temb=temb,
|
||||||
image_rotary_emb=image_rotary_emb,
|
image_rotary_emb=image_rotary_emb,
|
||||||
)
|
)
|
||||||
|
|
||||||
# controlnet residual
|
# controlnet residual
|
||||||
if controlnet_single_block_samples is not None:
|
if controlnet_single_block_samples is not None:
|
||||||
interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
|
interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
|
||||||
interval_control = int(np.ceil(interval_control))
|
interval_control = int(np.ceil(interval_control))
|
||||||
hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
|
hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
|
||||||
hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
||||||
+ controlnet_single_block_samples[index_block // interval_control]
|
+ controlnet_single_block_samples[index_block // interval_control]
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
||||||
|
|
||||||
|
|||||||
@@ -36,6 +36,8 @@ from ...utils.torch_utils import randn_tensor
|
|||||||
from ..pipeline_utils import DiffusionPipeline
|
from ..pipeline_utils import DiffusionPipeline
|
||||||
from .pipeline_output import FluxPipelineOutput
|
from .pipeline_output import FluxPipelineOutput
|
||||||
|
|
||||||
|
from torch.profiler import record_function
|
||||||
|
|
||||||
|
|
||||||
if is_torch_xla_available():
|
if is_torch_xla_available():
|
||||||
import torch_xla.core.xla_model as xm
|
import torch_xla.core.xla_model as xm
|
||||||
@@ -716,21 +718,24 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
|
|||||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||||
|
|
||||||
noise_pred = self.transformer(
|
with record_function(f"transformer iter_{i}"):
|
||||||
hidden_states=latents,
|
|
||||||
timestep=timestep / 1000,
|
noise_pred = self.transformer(
|
||||||
guidance=guidance,
|
hidden_states=latents,
|
||||||
pooled_projections=pooled_prompt_embeds,
|
timestep=timestep / 1000,
|
||||||
encoder_hidden_states=prompt_embeds,
|
guidance=guidance,
|
||||||
txt_ids=text_ids,
|
pooled_projections=pooled_prompt_embeds,
|
||||||
img_ids=latent_image_ids,
|
encoder_hidden_states=prompt_embeds,
|
||||||
joint_attention_kwargs=self.joint_attention_kwargs,
|
txt_ids=text_ids,
|
||||||
return_dict=False,
|
img_ids=latent_image_ids,
|
||||||
)[0]
|
joint_attention_kwargs=self.joint_attention_kwargs,
|
||||||
|
return_dict=False,
|
||||||
|
)[0]
|
||||||
|
|
||||||
# compute the previous noisy sample x_t -> x_t-1
|
# compute the previous noisy sample x_t -> x_t-1
|
||||||
latents_dtype = latents.dtype
|
latents_dtype = latents.dtype
|
||||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
with record_function(f"scheduler.step (iter_{i})"):
|
||||||
|
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||||
|
|
||||||
if latents.dtype != latents_dtype:
|
if latents.dtype != latents_dtype:
|
||||||
if torch.backends.mps.is_available():
|
if torch.backends.mps.is_available():
|
||||||
@@ -757,10 +762,11 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
|
|||||||
image = latents
|
image = latents
|
||||||
|
|
||||||
else:
|
else:
|
||||||
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
with record_function(f"decode latent"):
|
||||||
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
||||||
image = self.vae.decode(latents, return_dict=False)[0]
|
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
||||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
image = self.vae.decode(latents, return_dict=False)[0]
|
||||||
|
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||||
|
|
||||||
# Offload all models
|
# Offload all models
|
||||||
self.maybe_free_model_hooks()
|
self.maybe_free_model_hooks()
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ import torch
|
|||||||
from ..configuration_utils import ConfigMixin, register_to_config
|
from ..configuration_utils import ConfigMixin, register_to_config
|
||||||
from ..utils import BaseOutput, logging
|
from ..utils import BaseOutput, logging
|
||||||
from .scheduling_utils import SchedulerMixin
|
from .scheduling_utils import SchedulerMixin
|
||||||
|
from torch.profiler import record_function
|
||||||
|
|
||||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
@@ -284,20 +284,22 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
" one of the `scheduler.timesteps` as a timestep."
|
" one of the `scheduler.timesteps` as a timestep."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
with record_function(" _init_step_index"):
|
||||||
if self.step_index is None:
|
if self.step_index is None:
|
||||||
self._init_step_index(timestep)
|
self._init_step_index(timestep)
|
||||||
|
|
||||||
# Upcast to avoid precision issues when computing prev_sample
|
# Upcast to avoid precision issues when computing prev_sample
|
||||||
sample = sample.to(torch.float32)
|
sample = sample.to(torch.float32)
|
||||||
|
|
||||||
sigma = self.sigmas[self.step_index]
|
with record_function(" get sigma and sigma_next"):
|
||||||
sigma_next = self.sigmas[self.step_index + 1]
|
sigma = self.sigmas[self.step_index]
|
||||||
|
sigma_next = self.sigmas[self.step_index + 1]
|
||||||
|
|
||||||
prev_sample = sample + (sigma_next - sigma) * model_output
|
with record_function(" get prev_sample"):
|
||||||
|
prev_sample = sample + (sigma_next - sigma) * model_output
|
||||||
|
|
||||||
# Cast sample back to model compatible dtype
|
# Cast sample back to model compatible dtype
|
||||||
prev_sample = prev_sample.to(model_output.dtype)
|
prev_sample = prev_sample.to(model_output.dtype)
|
||||||
|
|
||||||
# upon completion increase step index by one
|
# upon completion increase step index by one
|
||||||
self._step_index += 1
|
self._step_index += 1
|
||||||
|
|||||||
Reference in New Issue
Block a user