Mirror of https://github.com/huggingface/diffusers.git, synced 2025-12-06 20:44:33 +08:00

Compare commits: animatelcm...single-fil (6 commits)

| SHA1 |
|---|
| 224a562b6b |
| baf9924be7 |
| d8d208acde |
| e0f33dfca4 |
| 15b125bb0e |
| 12004bf3a7 |
```diff
@@ -444,7 +444,7 @@ export_to_gif(frames, "animatelcm.gif")
 A space rocket, 4K.
 <br>
 <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatelcm-output.gif"
-alt="masterpiece, bestquality, sunset"
+alt="A space rocket, 4K"
 style="width: 300px;" />
 </center></td>
 </tr>
```
```diff
@@ -486,7 +486,7 @@ export_to_gif(frames, "animatelcm-motion-lora.gif")
 A space rocket, 4K.
 <br>
 <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatelcm-motion-lora.gif"
-alt="masterpiece, bestquality, sunset"
+alt="A space rocket, 4K"
 style="width: 300px;" />
 </center></td>
 </tr>
```
```diff
@@ -63,6 +63,7 @@ If a community doesn't work as expected, please open an issue and ping the autho
 | IP Adapter FaceID Stable Diffusion | Stable Diffusion Pipeline that supports IP Adapter Face ID | [IP Adapter Face ID](#ip-adapter-face-id) | - | [Fabio Rigano](https://github.com/fabiorigano) |
 | InstantID Pipeline | Stable Diffusion XL Pipeline that supports InstantID | [InstantID Pipeline](#instantid-pipeline) | [](https://huggingface.co/spaces/InstantX/InstantID) | [Haofan Wang](https://github.com/haofanwang) |
 | UFOGen Scheduler | Scheduler for UFOGen Model (compatible with Stable Diffusion pipelines) | [UFOGen Scheduler](#ufogen-scheduler) | - | [dg845](https://github.com/dg845) |
+| Stable Diffusion XL IPEX Pipeline | Accelerate Stable Diffusion XL inference pipeline with BF16/FP32 precision on Intel Xeon CPUs with [IPEX](https://github.com/intel/intel-extension-for-pytorch) | [Stable Diffusion XL on IPEX](#stable-diffusion-xl-on-ipex) | - | [Dan Li](https://github.com/ustcuna/) |
 
 To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
 
```
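The context line above mentions loading these community pipelines by passing `custom_pipeline` to `DiffusionPipeline`. A minimal sketch of that pattern is shown below; the base checkpoint id is illustrative, and loading the new IPEX pipeline this way assumes the file has been merged into `diffusers/examples/community`:

```python
# Illustrative sketch: load a community pipeline by passing the name of its file
# in diffusers/examples/community via the `custom_pipeline` argument.
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/sdxl-turbo",                              # illustrative base checkpoint
    custom_pipeline="pipeline_stable_diffusion_xl_ipex",   # community file name, without .py
    low_cpu_mem_usage=True,
    use_safetensors=True,
)
```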
@@ -1707,6 +1708,111 @@ print("Latency of StableDiffusionPipeline--fp32",latency)

### Stable Diffusion XL on IPEX

This diffusion pipeline aims to accelerate the inference of Stable Diffusion XL on Intel Xeon CPUs with BF16/FP32 precision using [IPEX](https://github.com/intel/intel-extension-for-pytorch).

To use this pipeline, you need to:

1. Install [IPEX](https://github.com/intel/intel-extension-for-pytorch)

**Note:** For each PyTorch release there is a corresponding IPEX release; the mapping is listed below. Installing PyTorch/IPEX 2.0 is recommended for the best performance.

|PyTorch Version|IPEX Version|
|--|--|
|[v2.0.\*](https://github.com/pytorch/pytorch/tree/v2.0.1 "v2.0.1")|[v2.0.\*](https://github.com/intel/intel-extension-for-pytorch/tree/v2.0.100+cpu)|
|[v1.13.\*](https://github.com/pytorch/pytorch/tree/v1.13.0 "v1.13.0")|[v1.13.\*](https://github.com/intel/intel-extension-for-pytorch/tree/v1.13.100+cpu)|

You can simply use pip to install the latest version of IPEX:
```bash
python -m pip install intel_extension_for_pytorch
```

**Note:** To install a specific version, run the following command:

```bash
python -m pip install intel_extension_for_pytorch==<version_name> -f https://developer.intel.com/ipex-whl-stable-cpu
```
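As an optional sanity check (not part of the original instructions), you can confirm that the installed PyTorch and IPEX builds line up with the mapping table above:

```python
# Optional sanity check: the PyTorch and IPEX major/minor versions should match
# according to the mapping table above.
import torch
import intel_extension_for_pytorch as ipex

print("PyTorch:", torch.__version__)
print("IPEX:", ipex.__version__)
```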
2. After pipeline initialization, call `prepare_for_ipex()` to enable IPEX acceleration. The supported inference data types are Float32 and BFloat16.

**Note:** The `height` and `width` values used during preparation with `prepare_for_ipex()` should match the values used when running inference with the prepared pipeline.

```python
import torch

from pipeline_stable_diffusion_xl_ipex import StableDiffusionXLPipelineIpex

prompt = "sailing ship in storm by Rembrandt"

pipe = StableDiffusionXLPipelineIpex.from_pretrained("stabilityai/sdxl-turbo", low_cpu_mem_usage=True, use_safetensors=True)

# value of image height/width should be consistent with the pipeline inference
# For Float32
pipe.prepare_for_ipex(torch.float32, prompt, height=512, width=512)
# For BFloat16
pipe.prepare_for_ipex(torch.bfloat16, prompt, height=512, width=512)
```
Then you can use the IPEX pipeline in a similar way to the default Stable Diffusion XL pipeline:

```python
# sdxl-turbo is designed for very few steps and no classifier-free guidance
num_inference_steps = 4
guidance_scale = 0.0

# value of image height/width should be consistent with 'prepare_for_ipex()'
# For Float32
image = pipe(prompt, num_inference_steps=num_inference_steps, height=512, width=512, guidance_scale=guidance_scale).images[0]
# For BFloat16
with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
    image = pipe(prompt, num_inference_steps=num_inference_steps, height=512, width=512, guidance_scale=guidance_scale).images[0]
```
The following code compares the performance of the original Stable Diffusion XL pipeline with the IPEX-optimized pipeline. With the optimized pipeline, BFloat16 inference gives roughly a 1.4-2x speedup on 4th generation Intel Xeon CPUs, code-named Sapphire Rapids.
```python
import torch
from diffusers import StableDiffusionXLPipeline
from pipeline_stable_diffusion_xl_ipex import StableDiffusionXLPipelineIpex
import time

prompt = "sailing ship in storm by Rembrandt"
model_id = "stabilityai/sdxl-turbo"
steps = 4

# Helper function for time evaluation
def elapsed_time(pipeline, nb_pass=3, num_inference_steps=1):
    # warmup
    for _ in range(2):
        images = pipeline(prompt, num_inference_steps=num_inference_steps, height=512, width=512, guidance_scale=0.0).images
    # time evaluation
    start = time.time()
    for _ in range(nb_pass):
        pipeline(prompt, num_inference_steps=num_inference_steps, height=512, width=512, guidance_scale=0.0)
    end = time.time()
    return (end - start) / nb_pass

############## bf16 inference performance ###############

# 1. IPEX Pipeline initialization
pipe = StableDiffusionXLPipelineIpex.from_pretrained(model_id, low_cpu_mem_usage=True, use_safetensors=True)
pipe.prepare_for_ipex(torch.bfloat16, prompt, height=512, width=512)

# 2. Original Pipeline initialization
pipe2 = StableDiffusionXLPipeline.from_pretrained(model_id, low_cpu_mem_usage=True, use_safetensors=True)

# 3. Compare performance between Original Pipeline and IPEX Pipeline
with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
    latency = elapsed_time(pipe, num_inference_steps=steps)
    print("Latency of StableDiffusionXLPipelineIpex--bf16", latency, "s for total", steps, "steps")
    latency = elapsed_time(pipe2, num_inference_steps=steps)
    print("Latency of StableDiffusionXLPipeline--bf16", latency, "s for total", steps, "steps")

############## fp32 inference performance ###############

# 1. IPEX Pipeline initialization
pipe3 = StableDiffusionXLPipelineIpex.from_pretrained(model_id, low_cpu_mem_usage=True, use_safetensors=True)
pipe3.prepare_for_ipex(torch.float32, prompt, height=512, width=512)

# 2. Original Pipeline initialization
pipe4 = StableDiffusionXLPipeline.from_pretrained(model_id, low_cpu_mem_usage=True, use_safetensors=True)

# 3. Compare performance between Original Pipeline and IPEX Pipeline
latency = elapsed_time(pipe3, num_inference_steps=steps)
print("Latency of StableDiffusionXLPipelineIpex--fp32", latency, "s for total", steps, "steps")
latency = elapsed_time(pipe4, num_inference_steps=steps)
print("Latency of StableDiffusionXLPipeline--fp32", latency, "s for total", steps, "steps")
```
### CLIP Guided Images Mixing With Stable Diffusion


examples/community/pipeline_stable_diffusion_xl_ipex.py (new file, 1490 lines): file diff suppressed because it is too large.
```diff
@@ -876,7 +876,7 @@ def create_diffusers_controlnet_model_from_ldm(
         from ..models.modeling_utils import load_model_dict_into_meta
 
         unexpected_keys = load_model_dict_into_meta(
-            controlnet, diffusers_format_controlnet_checkpoint, torch_dtype=torch_dtype
+            controlnet, diffusers_format_controlnet_checkpoint, dtype=torch_dtype
         )
         if controlnet._keys_to_ignore_on_load_unexpected is not None:
             for pat in controlnet._keys_to_ignore_on_load_unexpected:
```
```diff
@@ -19,11 +19,22 @@ import numpy as np
 import PIL.Image
 import torch
 import torch.nn.functional as F
-from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
+from transformers import (
+    CLIPImageProcessor,
+    CLIPTextModel,
+    CLIPTextModelWithProjection,
+    CLIPTokenizer,
+    CLIPVisionModelWithProjection,
+)
 
 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
+from ...loaders import (
+    FromSingleFileMixin,
+    IPAdapterMixin,
+    StableDiffusionXLLoraLoaderMixin,
+    TextualInversionLoaderMixin,
+)
+from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
 from ...models.attention_processor import (
     AttnProcessor2_0,
     LoRAAttnProcessor2_0,
```
```diff
@@ -140,7 +151,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
 
 
 class StableDiffusionXLControlNetInpaintPipeline(
-    DiffusionPipeline, StableDiffusionXLLoraLoaderMixin, FromSingleFileMixin
+    DiffusionPipeline, StableDiffusionXLLoraLoaderMixin, FromSingleFileMixin, IPAdapterMixin
 ):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion XL.
```
```diff
@@ -152,6 +163,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
         - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
         - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
         - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
 
     Args:
         vae ([`AutoencoderKL`]):
```
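For context, a rough usage sketch of the IP-Adapter support being wired into this pipeline; the checkpoint ids, adapter weight name, and input images below are placeholders rather than values taken from this PR:

```python
# Hypothetical end-to-end sketch: ControlNet inpainting with an IP-Adapter reference image.
import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetInpaintPipeline
from diffusers.utils import load_image

controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16
)
pipe = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

# load_ip_adapter comes from the IPAdapterMixin added to the class bases above
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")

init_image = load_image("https://example.com/init.png")        # placeholder inputs
mask_image = load_image("https://example.com/mask.png")
control_image = load_image("https://example.com/canny.png")
ip_image = load_image("https://example.com/style_reference.png")

image = pipe(
    prompt="a photo of a cat sitting on a bench",
    image=init_image,
    mask_image=mask_image,
    control_image=control_image,
    ip_adapter_image=ip_image,  # new argument added in this diff
    num_inference_steps=30,
).images[0]
```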
```diff
@@ -195,6 +207,8 @@ class StableDiffusionXLControlNetInpaintPipeline(
         requires_aesthetics_score: bool = False,
         force_zeros_for_empty_prompt: bool = True,
         add_watermarker: Optional[bool] = None,
+        feature_extractor: Optional[CLIPImageProcessor] = None,
+        image_encoder: Optional[CLIPVisionModelWithProjection] = None,
     ):
         super().__init__()
 
```
```diff
@@ -210,6 +224,8 @@ class StableDiffusionXLControlNetInpaintPipeline(
             unet=unet,
             controlnet=controlnet,
             scheduler=scheduler,
+            feature_extractor=feature_extractor,
+            image_encoder=image_encoder,
         )
         self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
         self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
```
```diff
@@ -497,6 +513,66 @@ class StableDiffusionXLControlNetInpaintPipeline(
 
         return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
 
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+        dtype = next(self.image_encoder.parameters()).dtype
+
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=dtype)
+        if output_hidden_states:
+            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_enc_hidden_states = self.image_encoder(
+                torch.zeros_like(image), output_hidden_states=True
+            ).hidden_states[-2]
+            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
+            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        else:
+            image_embeds = self.image_encoder(image).image_embeds
+            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_embeds = torch.zeros_like(image_embeds)
+
+            return image_embeds, uncond_image_embeds
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+    def prepare_ip_adapter_image_embeds(
+        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
+    ):
+        if ip_adapter_image_embeds is None:
+            if not isinstance(ip_adapter_image, list):
+                ip_adapter_image = [ip_adapter_image]
+
+            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+                raise ValueError(
+                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+                )
+
+            image_embeds = []
+            for single_ip_adapter_image, image_proj_layer in zip(
+                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+            ):
+                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+                single_image_embeds, single_negative_image_embeds = self.encode_image(
+                    single_ip_adapter_image, device, 1, output_hidden_state
+                )
+                single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+                single_negative_image_embeds = torch.stack(
+                    [single_negative_image_embeds] * num_images_per_prompt, dim=0
+                )
+
+                if self.do_classifier_free_guidance:
+                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+                    single_image_embeds = single_image_embeds.to(device)
+
+                image_embeds.append(single_image_embeds)
+        else:
+            image_embeds = ip_adapter_image_embeds
+        return image_embeds
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
     def prepare_extra_step_kwargs(self, generator, eta):
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
```
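A small sketch of how the two modes of `encode_image` above behave, assuming `pipe` and `ip_image` from the earlier sketch: without `output_hidden_states` it returns pooled image embeddings plus a zero tensor for the unconditional branch; with `output_hidden_states=True` it returns the penultimate hidden states for both branches.

```python
# Hypothetical illustration of the two encode_image modes added above.
cond_embeds, uncond_embeds = pipe.encode_image(ip_image, device=pipe.device, num_images_per_prompt=1)
print(cond_embeds.shape)      # pooled CLIP image embeddings
print(uncond_embeds.shape)    # zeros_like(cond_embeds), used for the negative branch

cond_hidden, uncond_hidden = pipe.encode_image(
    ip_image, device=pipe.device, num_images_per_prompt=1, output_hidden_states=True
)
print(cond_hidden.shape)      # penultimate hidden states of the image encoder
```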
```diff
@@ -566,6 +642,8 @@ class StableDiffusionXLControlNetInpaintPipeline(
         negative_prompt_2=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
+        ip_adapter_image=None,
+        ip_adapter_image_embeds=None,
         pooled_prompt_embeds=None,
         negative_pooled_prompt_embeds=None,
         controlnet_conditioning_scale=1.0,
```
```diff
@@ -752,6 +830,11 @@ class StableDiffusionXLControlNetInpaintPipeline(
             if end > 1.0:
                 raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
 
+        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
+            raise ValueError(
+                "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
+            )
+
     def prepare_control_image(
         self,
         image,
```
```diff
@@ -1100,6 +1183,8 @@ class StableDiffusionXLControlNetInpaintPipeline(
         latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
```
```diff
@@ -1194,6 +1279,10 @@ class StableDiffusionXLControlNetInpaintPipeline(
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                 argument.
+            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+            ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. If not
+                provided, embeddings are computed from the `ip_adapter_image` input argument.
             pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                 If not provided, pooled text embeddings will be generated from `prompt` input argument.
```
```diff
@@ -1326,6 +1415,8 @@ class StableDiffusionXLControlNetInpaintPipeline(
             negative_prompt_2,
             prompt_embeds,
             negative_prompt_embeds,
+            ip_adapter_image,
+            ip_adapter_image_embeds,
             pooled_prompt_embeds,
             negative_pooled_prompt_embeds,
             controlnet_conditioning_scale,
```
```diff
@@ -1378,6 +1469,12 @@ class StableDiffusionXLControlNetInpaintPipeline(
             clip_skip=self.clip_skip,
         )
 
+        # 3.1 Encode ip_adapter_image
+        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+            image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_images_per_prompt
+            )
+
         # 4. set timesteps
         def denoising_value_valid(dnv):
             return isinstance(denoising_end, float) and 0 < dnv < 1
```
```diff
@@ -1649,6 +1746,9 @@ class StableDiffusionXLControlNetInpaintPipeline(
                     down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
                     mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
 
+                if ip_adapter_image is not None:
+                    added_cond_kwargs["image_embeds"] = image_embeds
+
                 if num_channels_unet == 9:
                     latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
 
```