mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-11 15:04:45 +08:00
Compare commits
2 Commits
fix-pipe-e
...
fix-addnoi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9c112aaaca | ||
|
|
0f348e5405 |
@@ -485,69 +485,6 @@ image.save("sdxl_t2i.png")
|
||||
</div>
|
||||
</div>
|
||||
|
||||
You can use the IP-Adapter face model to apply specific faces to your images. It is an effective way to maintain consistent characters in your image generations.
|
||||
Weights are loaded with the same method used for the other IP-Adapters.
|
||||
|
||||
```python
|
||||
# Load ip-adapter-full-face_sd15.bin
|
||||
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-full-face_sd15.bin")
|
||||
```
|
||||
|
||||
<Tip>
|
||||
|
||||
It is recommended to use `DDIMScheduler` and `EulerDiscreteScheduler` for face model.
|
||||
|
||||
|
||||
</Tip>
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import StableDiffusionPipeline, DDIMScheduler
|
||||
from diffusers.utils import load_image
|
||||
|
||||
noise_scheduler = DDIMScheduler(
|
||||
num_train_timesteps=1000,
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
clip_sample=False,
|
||||
set_alpha_to_one=False,
|
||||
steps_offset=1
|
||||
)
|
||||
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
torch_dtype=torch.float16,
|
||||
scheduler=noise_scheduler,
|
||||
).to("cuda")
|
||||
|
||||
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-full-face_sd15.bin")
|
||||
|
||||
pipeline.set_ip_adapter_scale(0.7)
|
||||
|
||||
image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ai_face2.png")
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(33)
|
||||
|
||||
image = pipeline(
|
||||
prompt="A photo of a girl wearing a black dress, holding red roses in hand, upper body, behind is the Eiffel Tower",
|
||||
ip_adapter_image=image,
|
||||
negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
|
||||
num_inference_steps=50, num_images_per_prompt=1, width=512, height=704,
|
||||
generator=generator,
|
||||
).images[0]
|
||||
```
|
||||
|
||||
<div class="flex flex-row gap-4">
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ai_face2.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">input image</figcaption>
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ipadapter_full_face_output.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">output image</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
### LCM-Lora
|
||||
|
||||
|
||||
@@ -135,8 +135,8 @@ from safetensors.torch import load_file
|
||||
"""
|
||||
diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename="embeddings.safetensors", repo_type="model")
|
||||
state_dict = load_file(embedding_path)
|
||||
pipeline.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
|
||||
pipeline.load_textual_inversion(state_dict["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
|
||||
pipeline.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
|
||||
pipeline.load_textual_inversion(state_dict["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
|
||||
"""
|
||||
if token_abstraction_dict:
|
||||
for key, value in token_abstraction_dict.items():
|
||||
@@ -157,8 +157,6 @@ tags:
|
||||
base_model: {base_model}
|
||||
instance_prompt: {instance_prompt}
|
||||
license: openrail++
|
||||
widget:
|
||||
- text: '{validation_prompt if validation_prompt else instance_prompt}'
|
||||
---
|
||||
"""
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ import torch.nn.functional as F
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
from torch import nn
|
||||
|
||||
from ..models.embeddings import ImageProjection, MLPProjection, Resampler
|
||||
from ..models.embeddings import ImageProjection, Resampler
|
||||
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
|
||||
from ..utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
@@ -675,9 +675,6 @@ class UNet2DConditionLoadersMixin:
|
||||
if "proj.weight" in state_dict["image_proj"]:
|
||||
# IP-Adapter
|
||||
num_image_text_embeds = 4
|
||||
elif "proj.3.weight" in state_dict["image_proj"]:
|
||||
# IP-Adapter Full Face
|
||||
num_image_text_embeds = 257 # 256 CLIP tokens + 1 CLS token
|
||||
else:
|
||||
# IP-Adapter Plus
|
||||
num_image_text_embeds = state_dict["image_proj"]["latents"].shape[1]
|
||||
@@ -747,32 +744,8 @@ class UNet2DConditionLoadersMixin:
|
||||
"norm.bias": state_dict["image_proj"]["norm.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
image_projection.load_state_dict(image_proj_state_dict)
|
||||
del image_proj_state_dict
|
||||
|
||||
elif "proj.3.weight" in state_dict["image_proj"]:
|
||||
clip_embeddings_dim = state_dict["image_proj"]["proj.0.weight"].shape[0]
|
||||
cross_attention_dim = state_dict["image_proj"]["proj.3.weight"].shape[0]
|
||||
|
||||
image_projection = MLPProjection(
|
||||
cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim
|
||||
)
|
||||
image_projection.to(dtype=self.dtype, device=self.device)
|
||||
|
||||
# load image projection layer weights
|
||||
image_proj_state_dict = {}
|
||||
image_proj_state_dict.update(
|
||||
{
|
||||
"ff.net.0.proj.weight": state_dict["image_proj"]["proj.0.weight"],
|
||||
"ff.net.0.proj.bias": state_dict["image_proj"]["proj.0.bias"],
|
||||
"ff.net.2.weight": state_dict["image_proj"]["proj.2.weight"],
|
||||
"ff.net.2.bias": state_dict["image_proj"]["proj.2.bias"],
|
||||
"norm.weight": state_dict["image_proj"]["proj.3.weight"],
|
||||
"norm.bias": state_dict["image_proj"]["proj.3.bias"],
|
||||
}
|
||||
)
|
||||
image_projection.load_state_dict(image_proj_state_dict)
|
||||
del image_proj_state_dict
|
||||
|
||||
else:
|
||||
# IP-Adapter Plus
|
||||
|
||||
@@ -461,18 +461,6 @@ class ImageProjection(nn.Module):
|
||||
return image_embeds
|
||||
|
||||
|
||||
class MLPProjection(nn.Module):
|
||||
def __init__(self, image_embed_dim=1024, cross_attention_dim=1024):
|
||||
super().__init__()
|
||||
from .attention import FeedForward
|
||||
|
||||
self.ff = FeedForward(image_embed_dim, cross_attention_dim, mult=1, activation_fn="gelu")
|
||||
self.norm = nn.LayerNorm(cross_attention_dim)
|
||||
|
||||
def forward(self, image_embeds: torch.FloatTensor):
|
||||
return self.norm(self.ff(image_embeds))
|
||||
|
||||
|
||||
class CombinedTimestepLabelEmbeddings(nn.Module):
|
||||
def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
|
||||
super().__init__()
|
||||
|
||||
@@ -92,43 +92,6 @@ def betas_for_alpha_bar(
|
||||
return torch.tensor(betas, dtype=torch.float32)
|
||||
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
|
||||
def rescale_zero_terminal_snr(betas):
|
||||
"""
|
||||
Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
|
||||
|
||||
|
||||
Args:
|
||||
betas (`torch.FloatTensor`):
|
||||
the betas that the scheduler is being initialized with.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: rescaled betas with zero terminal SNR
|
||||
"""
|
||||
# Convert betas to alphas_bar_sqrt
|
||||
alphas = 1.0 - betas
|
||||
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
||||
alphas_bar_sqrt = alphas_cumprod.sqrt()
|
||||
|
||||
# Store old values.
|
||||
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
||||
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
||||
|
||||
# Shift so the last timestep is zero.
|
||||
alphas_bar_sqrt -= alphas_bar_sqrt_T
|
||||
|
||||
# Scale so the first timestep is back to the old value.
|
||||
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
||||
|
||||
# Convert alphas_bar_sqrt to betas
|
||||
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
|
||||
alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
|
||||
alphas = torch.cat([alphas_bar[0:1], alphas])
|
||||
betas = 1 - alphas
|
||||
|
||||
return betas
|
||||
|
||||
|
||||
class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
Euler scheduler.
|
||||
@@ -165,10 +128,6 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
An offset added to the inference steps. You can use a combination of `offset=1` and
|
||||
`set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
|
||||
Diffusion.
|
||||
rescale_betas_zero_snr (`bool`, defaults to `False`):
|
||||
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
|
||||
dark samples instead of limiting it to samples with medium brightness. Loosely related to
|
||||
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
|
||||
"""
|
||||
|
||||
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
||||
@@ -190,7 +149,6 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
timestep_spacing: str = "linspace",
|
||||
timestep_type: str = "discrete", # can be "discrete" or "continuous"
|
||||
steps_offset: int = 0,
|
||||
rescale_betas_zero_snr: bool = False,
|
||||
):
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
||||
@@ -205,17 +163,9 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
else:
|
||||
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
|
||||
|
||||
if rescale_betas_zero_snr:
|
||||
self.betas = rescale_zero_terminal_snr(self.betas)
|
||||
|
||||
self.alphas = 1.0 - self.betas
|
||||
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
||||
|
||||
if rescale_betas_zero_snr:
|
||||
# Close to 0 without being 0 so first sigma is not inf
|
||||
# FP16 smallest positive subnormal works well here
|
||||
self.alphas_cumprod[-1] = 2**-24
|
||||
|
||||
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
||||
timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
|
||||
|
||||
@@ -470,9 +420,6 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
|
||||
# Upcast to avoid precision issues when computing prev_sample
|
||||
sample = sample.to(torch.float32)
|
||||
|
||||
sigma = self.sigmas[self.step_index]
|
||||
|
||||
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
|
||||
@@ -509,9 +456,6 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
prev_sample = sample + derivative * dt
|
||||
|
||||
# Cast sample back to model compatible dtype
|
||||
prev_sample = prev_sample.to(model_output.dtype)
|
||||
|
||||
# upon completion increase step index by one
|
||||
self._step_index += 1
|
||||
|
||||
|
||||
@@ -182,25 +182,6 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
|
||||
|
||||
assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
|
||||
|
||||
def test_text_to_image_full_face(self):
|
||||
image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", image_encoder=image_encoder, safety_checker=None, torch_dtype=self.dtype
|
||||
)
|
||||
pipeline.to(torch_device)
|
||||
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-full-face_sd15.bin")
|
||||
pipeline.set_ip_adapter_scale(0.7)
|
||||
|
||||
inputs = self.get_dummy_inputs()
|
||||
images = pipeline(**inputs).images
|
||||
image_slice = images[0, :3, :3, -1].flatten()
|
||||
|
||||
expected_slice = np.array(
|
||||
[0.1706543, 0.1303711, 0.12573242, 0.21777344, 0.14550781, 0.14038086, 0.40820312, 0.41455078, 0.42529297]
|
||||
)
|
||||
|
||||
assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
|
||||
@@ -45,10 +45,6 @@ class EulerDiscreteSchedulerTest(SchedulerCommonTest):
|
||||
def test_karras_sigmas(self):
|
||||
self.check_over_configs(use_karras_sigmas=True, sigma_min=0.02, sigma_max=700.0)
|
||||
|
||||
def test_rescale_betas_zero_snr(self):
|
||||
for rescale_betas_zero_snr in [True, False]:
|
||||
self.check_over_configs(rescale_betas_zero_snr=rescale_betas_zero_snr)
|
||||
|
||||
def test_full_loop_no_noise(self):
|
||||
scheduler_class = self.scheduler_classes[0]
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
|
||||
Reference in New Issue
Block a user