Compare commits

..

1 Commits

Author SHA1 Message Date
DN6
b59c02fb66 update 2025-01-14 13:55:48 +05:30
15 changed files with 26 additions and 268 deletions

View File

@@ -158,9 +158,6 @@ def log_validation(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)
if args.enable_vae_tiling:
pipeline.vae.enable_tiling(tile_sample_min_height=1024, tile_sample_stride_width=1024)
pipeline.text_encoder = pipeline.text_encoder.to(torch.bfloat16)
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)
@@ -600,7 +597,6 @@ def parse_args(input_args=None):
help="Whether to offload the VAE and the text encoder to CPU when they are not used.",
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument("--enable_vae_tiling", action="store_true", help="Enabla vae tiling in log validation")
if input_args is not None:
args = parser.parse_args(input_args)

View File

@@ -74,9 +74,8 @@ To create the package for PyPI.
twine upload dist/* -r pypi
10. Prepare the release notes and publish them on GitHub once everything is looking hunky-dory. You can use the following
Space to fetch all the commits applicable for the release: https://huggingface.co/spaces/sayakpaul/auto-release-notes-diffusers.
It automatically fetches the correct tag and branch but also provides the option to configure them.
`tag` should be the previous release tag (v0.26.1, for example), and `branch` should be
Space to fetch all the commits applicable for the release: https://huggingface.co/spaces/lysandre/github-release. Repo should
be `huggingface/diffusers`. `tag` should be the previous release tag (v0.26.1, for example), and `branch` should be
the latest release branch (v0.27.0-release, for example). It denotes all commits that have happened on branch
v0.27.0-release after the tag v0.26.1 was created.

View File

@@ -21,7 +21,6 @@ from huggingface_hub.utils import validate_hf_hub_args
from ..utils import (
USE_PEFT_BACKEND,
deprecate,
get_submodule_by_name,
is_peft_available,
is_peft_version,
is_torch_version,
@@ -1982,17 +1981,10 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
in_features = state_dict[lora_A_weight_name].shape[1]
out_features = state_dict[lora_B_weight_name].shape[0]
# Model maybe loaded with different quantization schemes which may flatten the params.
# `bitsandbytes`, for example, flatten the weights when using 4bit. 8bit bnb models
# preserve weight shape.
module_weight_shape = cls._calculate_module_shape(model=transformer, base_module=module)
# This means there's no need for an expansion in the params, so we simply skip.
if tuple(module_weight_shape) == (out_features, in_features):
if tuple(module_weight.shape) == (out_features, in_features):
continue
# TODO (sayakpaul): We still need to consider if the module we're expanding is
# quantized and handle it accordingly if that is the case.
module_out_features, module_in_features = module_weight.shape
debug_message = ""
if in_features > module_in_features:
@@ -2088,16 +2080,13 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
base_weight_param = transformer_state_dict[base_param_name]
lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"]
# TODO (sayakpaul): Handle the cases when we actually need to expand when using quantization.
base_module_shape = cls._calculate_module_shape(model=transformer, base_weight_param_name=base_param_name)
if base_module_shape[1] > lora_A_param.shape[1]:
if base_weight_param.shape[1] > lora_A_param.shape[1]:
shape = (lora_A_param.shape[0], base_weight_param.shape[1])
expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device)
expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param)
lora_state_dict[f"{prefix}{k}.lora_A.weight"] = expanded_state_dict_weight
expanded_module_names.add(k)
elif base_module_shape[1] < lora_A_param.shape[1]:
elif base_weight_param.shape[1] < lora_A_param.shape[1]:
raise NotImplementedError(
f"This LoRA param ({k}.lora_A.weight) has an incompatible shape {lora_A_param.shape}. Please open an issue to file for a feature request - https://github.com/huggingface/diffusers/issues/new."
)
@@ -2109,28 +2098,6 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
return lora_state_dict
@staticmethod
def _calculate_module_shape(
model: "torch.nn.Module",
base_module: "torch.nn.Linear" = None,
base_weight_param_name: str = None,
) -> "torch.Size":
def _get_weight_shape(weight: torch.Tensor):
return weight.quant_state.shape if weight.__class__.__name__ == "Params4bit" else weight.shape
if base_module is not None:
return _get_weight_shape(base_module.weight)
elif base_weight_param_name is not None:
if not base_weight_param_name.endswith(".weight"):
raise ValueError(
f"Invalid `base_weight_param_name` passed as it does not end with '.weight' {base_weight_param_name=}."
)
module_path = base_weight_param_name.rsplit(".weight", 1)[0]
submodule = get_submodule_by_name(model, module_path)
return _get_weight_shape(submodule.weight)
raise ValueError("Either `base_module` or `base_weight_param_name` must be provided.")
# The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
# relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.

View File

@@ -40,7 +40,7 @@ def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
hf_token = kwargs.pop("hf_token", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
@@ -73,7 +73,7 @@ def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs)
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=hf_token,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
@@ -93,7 +93,7 @@ def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs)
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=hf_token,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
@@ -312,7 +312,7 @@ class TextualInversionLoaderMixin:
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
hf_token (`str` or *bool*, *optional*):
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):

View File

@@ -21,7 +21,7 @@ from transformers import T5EncoderModel, T5TokenizerFast
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...loaders import Mochi1LoraLoaderMixin
from ...models.autoencoders import AutoencoderKLMochi
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import MochiTransformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
@@ -151,8 +151,8 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
Conditional Transformer architecture to denoise the encoded video latents.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKLMochi`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`T5EncoderModel`]):
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
@@ -171,7 +171,7 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
def __init__(
self,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKLMochi,
vae: AutoencoderKL,
text_encoder: T5EncoderModel,
tokenizer: T5TokenizerFast,
transformer: MochiTransformer3DModel,

View File

@@ -16,7 +16,6 @@ import html
import inspect
import re
import urllib.parse as ul
import warnings
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
@@ -42,7 +41,6 @@ from ..pixart_alpha.pipeline_pixart_alpha import (
ASPECT_RATIO_1024_BIN,
)
from ..pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN
from ..sana.pipeline_sana import ASPECT_RATIO_4096_BIN
from .pag_utils import PAGMixin
@@ -641,7 +639,7 @@ class SanaPAGPipeline(DiffusionPipeline, PAGMixin):
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
clean_caption: bool = False,
clean_caption: bool = True,
use_resolution_binning: bool = True,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
@@ -757,9 +755,7 @@ class SanaPAGPipeline(DiffusionPipeline, PAGMixin):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
if use_resolution_binning:
if self.transformer.config.sample_size == 128:
aspect_ratio_bin = ASPECT_RATIO_4096_BIN
elif self.transformer.config.sample_size == 64:
if self.transformer.config.sample_size == 64:
aspect_ratio_bin = ASPECT_RATIO_2048_BIN
elif self.transformer.config.sample_size == 32:
aspect_ratio_bin = ASPECT_RATIO_1024_BIN
@@ -916,14 +912,7 @@ class SanaPAGPipeline(DiffusionPipeline, PAGMixin):
image = latents
else:
latents = latents.to(self.vae.dtype)
try:
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
except torch.cuda.OutOfMemoryError as e:
warnings.warn(
f"{e}. \n"
f"Try to use VAE tiling for large images. For example: \n"
f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)"
)
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
if use_resolution_binning:
image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)

View File

@@ -16,7 +16,6 @@ import html
import inspect
import re
import urllib.parse as ul
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
@@ -954,14 +953,7 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
image = latents
else:
latents = latents.to(self.vae.dtype)
try:
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
except torch.cuda.OutOfMemoryError as e:
warnings.warn(
f"{e}. \n"
f"Try to use VAE tiling for large images. For example: \n"
f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)"
)
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
if use_resolution_binning:
image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)

View File

@@ -13,21 +13,19 @@
# limitations under the License.
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Callable, Dict, List, Optional, Union
import torch
from transformers import (
BaseImageProcessor,
CLIPTextModelWithProjection,
CLIPTokenizer,
PreTrainedModel,
T5EncoderModel,
T5TokenizerFast,
)
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import SD3Transformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -164,7 +162,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin):
class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
r"""
Args:
transformer ([`SD3Transformer2DModel`]):
@@ -196,14 +194,10 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
tokenizer_3 (`T5TokenizerFast`):
Tokenizer of class
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
image_encoder (`PreTrainedModel`, *optional*):
Pre-trained Vision Model for IP Adapter.
feature_extractor (`BaseImageProcessor`, *optional*):
Image processor for IP Adapter.
"""
model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
_optional_components = ["image_encoder", "feature_extractor"]
model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
_optional_components = []
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
def __init__(
@@ -217,8 +211,6 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
tokenizer_2: CLIPTokenizer,
text_encoder_3: T5EncoderModel,
tokenizer_3: T5TokenizerFast,
image_encoder: PreTrainedModel = None,
feature_extractor: BaseImageProcessor = None,
):
super().__init__()
@@ -232,8 +224,6 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
tokenizer_3=tokenizer_3,
transformer=transformer,
scheduler=scheduler,
image_encoder=image_encoder,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
@@ -828,10 +818,6 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def joint_attention_kwargs(self):
return self._joint_attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@@ -840,84 +826,6 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
def interrupt(self):
return self._interrupt
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_image
def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor:
"""Encodes the given image into a feature representation using a pre-trained image encoder.
Args:
image (`PipelineImageInput`):
Input image to be encoded.
device: (`torch.device`):
Torch device.
Returns:
`torch.Tensor`: The encoded image feature representation.
"""
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=self.dtype)
return self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(
self,
ip_adapter_image: Optional[PipelineImageInput] = None,
ip_adapter_image_embeds: Optional[torch.Tensor] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
do_classifier_free_guidance: bool = True,
) -> torch.Tensor:
"""Prepares image embeddings for use in the IP-Adapter.
Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed.
Args:
ip_adapter_image (`PipelineImageInput`, *optional*):
The input image to extract features from for IP-Adapter.
ip_adapter_image_embeds (`torch.Tensor`, *optional*):
Precomputed image embeddings.
device: (`torch.device`, *optional*):
Torch device.
num_images_per_prompt (`int`, defaults to 1):
Number of images that should be generated per prompt.
do_classifier_free_guidance (`bool`, defaults to True):
Whether to use classifier free guidance or not.
"""
device = device or self._execution_device
if ip_adapter_image_embeds is not None:
if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2)
else:
single_image_embeds = ip_adapter_image_embeds
elif ip_adapter_image is not None:
single_image_embeds = self.encode_image(ip_adapter_image, device)
if do_classifier_free_guidance:
single_negative_image_embeds = torch.zeros_like(single_image_embeds)
else:
raise ValueError("Neither `ip_adapter_image_embeds` or `ip_adapter_image_embeds` were provided.")
image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
if do_classifier_free_guidance:
negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
return image_embeds.to(device=device)
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.enable_sequential_cpu_offload
def enable_sequential_cpu_offload(self, *args, **kwargs):
if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload:
logger.warning(
"`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses "
"`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling "
"`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`."
)
super().enable_sequential_cpu_offload(*args, **kwargs)
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -945,11 +853,8 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
negative_prompt_embeds: Optional[torch.Tensor] = None,
pooled_prompt_embeds: Optional[torch.Tensor] = None,
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
ip_adapter_image_embeds: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
@@ -985,9 +890,9 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
mask_image_latent (`torch.Tensor`, `List[torch.Tensor]`):
`Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask
latents tensor will ge generated by `mask_image`.
height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
padding_mask_crop (`int`, *optional*, defaults to `None`):
The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to
@@ -1048,22 +953,12 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
input argument.
ip_adapter_image (`PipelineImageInput`, *optional*):
Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`torch.Tensor`, *optional*):
Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images,
emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to
`True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead of
a plain tuple.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
@@ -1111,7 +1006,6 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False
# 2. Define call parameters
@@ -1266,22 +1160,7 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
f"The transformer {self.transformer.__class__} should have 16 input channels or 33 input channels, not {self.transformer.config.in_channels}."
)
# 7. Prepare image embeddings
if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None:
ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
ip_adapter_image,
ip_adapter_image_embeds,
device,
batch_size * num_images_per_prompt,
self.do_classifier_free_guidance,
)
if self.joint_attention_kwargs is None:
self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds}
else:
self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds)
# 8. Denoising loop
# 7. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -1302,7 +1181,6 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
timestep=timestep,
encoder_hidden_states=prompt_embeds,
pooled_projections=pooled_prompt_embeds,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]

View File

@@ -101,7 +101,7 @@ from .import_utils import (
is_xformers_available,
requires_backends,
)
from .loading_utils import get_module_from_name, get_submodule_by_name, load_image, load_video
from .loading_utils import get_module_from_name, load_image, load_video
from .logging import get_logger
from .outputs import BaseOutput
from .peft_utils import (

View File

@@ -148,15 +148,3 @@ def get_module_from_name(module, tensor_name: str) -> Tuple[Any, str]:
module = new_module
tensor_name = splits[-1]
return module, tensor_name
def get_submodule_by_name(root_module, module_path: str):
current = root_module
parts = module_path.split(".")
for part in parts:
if part.isdigit():
idx = int(part)
current = current[idx] # e.g., for nn.ModuleList or nn.Sequential
else:
current = getattr(current, part)
return current

View File

@@ -33,7 +33,6 @@ class CogVideoXTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = CogVideoXTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
model_split_percents = [0.7, 0.7, 0.8]
@property
def dummy_input(self):

View File

@@ -33,7 +33,6 @@ class CogView3PlusTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = CogView3PlusTransformer2DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
model_split_percents = [0.7, 0.6, 0.6]
@property
def dummy_input(self):

View File

@@ -106,8 +106,6 @@ class StableDiffusion3InpaintPipelineFastTests(PipelineLatentTesterMixin, unitte
"tokenizer_3": tokenizer_3,
"transformer": transformer,
"vae": vae,
"image_encoder": None,
"feature_extractor": None,
}
def get_dummy_inputs(self, device, seed=0):

View File

@@ -20,7 +20,6 @@ import unittest
import numpy as np
import pytest
import safetensors.torch
from huggingface_hub import hf_hub_download
from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel
from diffusers.utils import is_accelerate_version, logging
@@ -569,27 +568,6 @@ class SlowBnb4BitFluxTests(Base4bitTests):
max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
self.assertTrue(max_diff < 1e-3)
def test_lora_loading(self):
self.pipeline_4bit.load_lora_weights(
hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd"
)
self.pipeline_4bit.set_adapters("hyper-sd", adapter_weights=0.125)
output = self.pipeline_4bit(
prompt=self.prompt,
height=256,
width=256,
max_sequence_length=64,
output_type="np",
num_inference_steps=8,
generator=torch.Generator().manual_seed(42),
).images
out_slice = output[0, -3:, -3:, -1].flatten()
expected_slice = np.array([0.5347, 0.5342, 0.5283, 0.5093, 0.4988, 0.5093, 0.5044, 0.5015, 0.4946])
max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
self.assertTrue(max_diff < 1e-3)
@slow
class BaseBnb4BitSerializationTests(Base4bitTests):

View File

@@ -18,7 +18,6 @@ import unittest
import numpy as np
import pytest
from huggingface_hub import hf_hub_download
from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel, logging
from diffusers.utils import is_accelerate_version
@@ -31,7 +30,6 @@ from diffusers.utils.testing_utils import (
numpy_cosine_similarity_distance,
require_accelerate,
require_bitsandbytes_version_greater,
require_peft_version_greater,
require_torch,
require_torch_gpu,
require_transformers_version_greater,
@@ -511,29 +509,6 @@ class SlowBnb8bitFluxTests(Base8bitTests):
max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
self.assertTrue(max_diff < 1e-3)
@require_peft_version_greater("0.14.0")
def test_lora_loading(self):
self.pipeline_8bit.load_lora_weights(
hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd"
)
self.pipeline_8bit.set_adapters("hyper-sd", adapter_weights=0.125)
output = self.pipeline_8bit(
prompt=self.prompt,
height=256,
width=256,
max_sequence_length=64,
output_type="np",
num_inference_steps=8,
generator=torch.manual_seed(42),
).images
out_slice = output[0, -3:, -3:, -1].flatten()
expected_slice = np.array([0.3916, 0.3916, 0.3887, 0.4243, 0.4155, 0.4233, 0.4570, 0.4531, 0.4248])
max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
self.assertTrue(max_diff < 1e-3)
@slow
class BaseBnb8bitSerializationTests(Base8bitTests):