Mirror of https://github.com/huggingface/diffusers.git, synced 2026-02-12 05:45:23 +08:00

Compare commits: main...z-image-di (1 commit)
| Author | SHA1 | Date |
|---|---|---|
|  | 2b16351270 |  |
@@ -2455,18 +2455,22 @@ def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
     if has_diffusion_model:
         state_dict = {k.removeprefix("diffusion_model."): v for k, v in state_dict.items()}
 
-    has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
+    has_lora_unet = any(k.startswith("lora_unet_") or k.startswith("lora_unet__") for k in state_dict)
     if has_lora_unet:
-        state_dict = {k.removeprefix("lora_unet_"): v for k, v in state_dict.items()}
+        state_dict = {k.removeprefix("lora_unet__").removeprefix("lora_unet_"): v for k, v in state_dict.items()}
 
     def convert_key(key: str) -> str:
         # ZImage has: layers, noise_refiner, context_refiner blocks
         # Keys may be like: layers_0_attention_to_q.lora_down.weight
 
-        if "." in key:
-            base, suffix = key.rsplit(".", 1)
-        else:
-            base, suffix = key, ""
+        suffix = ""
+        for sfx in (".lora_down.weight", ".lora_up.weight", ".alpha"):
+            if key.endswith(sfx):
+                base = key[: -len(sfx)]
+                suffix = sfx
+                break
+        else:
+            base = key
 
         # Protected n-grams that must keep their internal underscores
         protected = {
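Note that `k.startswith("lora_unet__")` is already implied by `k.startswith("lora_unet_")`, so the new `or` clause documents the double-underscore variant rather than widening the match; the behavioral change is the chained `removeprefix`. A minimal sketch with made-up keys:

# Both prefix variants normalize to the same key (illustrative keys only).
keys = [
    "lora_unet_layers_0_attention_to_q.lora_down.weight",
    "lora_unet__layers_0_attention_to_q.lora_down.weight",
]
normalized = [k.removeprefix("lora_unet__").removeprefix("lora_unet_") for k in keys]
assert normalized[0] == normalized[1] == "layers_0_attention_to_q.lora_down.weight"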
@@ -2477,6 +2481,9 @@ def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
             ("to", "out"),
             # feed_forward
             ("feed", "forward"),
+            # noise and context refiner
+            ("noise", "refiner"),
+            ("context", "refiner"),
         }
 
         prot_by_len = {}
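The hunks only show the protected set, the loop counter, and the final join; the merge loop itself is off-screen. A rough reconstruction of how underscore-to-dot conversion with protected n-grams can work (my own sketch, assuming the `prot_by_len` grouping seen above; the real `convert_key` in lora_conversion_utils.py may differ in details):

# Sketch only: group protected n-grams by length, try longest match first.
def merge_tokens(base, protected):
    prot_by_len = {}
    for ngram in protected:
        prot_by_len.setdefault(len(ngram), set()).add(ngram)

    tokens = base.split("_")
    merged = []
    i = 0
    while i < len(tokens):
        for n in sorted(prot_by_len, reverse=True):
            if tuple(tokens[i : i + n]) in prot_by_len[n]:
                merged.append("_".join(tokens[i : i + n]))  # keep internal underscore
                i += n
                break
        else:
            merged.append(tokens[i])
            i += 1
    return ".".join(merged)

print(merge_tokens("layers_0_feed_forward_w1", {("feed", "forward"), ("to", "out")}))
# -> layers.0.feed_forward.w1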
@@ -2501,7 +2508,7 @@ def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
                 i += 1
 
         converted_base = ".".join(merged)
-        return converted_base + (("." + suffix) if suffix else "")
+        return converted_base + suffix
 
     state_dict = {convert_key(k): v for k, v in state_dict.items()}
 
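The old `rsplit(".", 1)` handling is what the suffix loop replaces: it split off only `weight`, leaving `.lora_down` inside the base that then went through underscore merging. A quick check of both behaviors on the key shape from the comment above:

key = "layers_0_attention_to_q.lora_down.weight"

# Old: only the last dot segment is treated as the suffix.
base_old, suffix_old = key.rsplit(".", 1)
print(base_old)  # layers_0_attention_to_q.lora_down  <- ".lora_down" leaks into the base
print(suffix_old)  # weight

# New: the full LoRA suffix is matched, so the base is clean and
# `converted_base + suffix` round-trips without re-adding a dot.
for sfx in (".lora_down.weight", ".lora_up.weight", ".alpha"):
    if key.endswith(sfx):
        base_new, suffix_new = key[: -len(sfx)], sfx
        break
print(base_new)  # layers_0_attention_to_q
print(suffix_new)  # .lora_down.weight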
@@ -658,7 +658,12 @@ class GlmImagePipeline(DiffusionPipeline):
             )
         elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
             raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
-        if prompt is None and prior_token_ids is None:
+        if prompt is not None and prior_token_ids is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt`: {prompt} and `prior_token_ids`: {prior_token_ids}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt is None and prior_token_ids is None:
             raise ValueError(
                 "Provide either `prompt` or `prior_token_ids`. Cannot leave both `prompt` and `prior_token_ids` undefined."
             )
@@ -689,8 +694,8 @@ class GlmImagePipeline(DiffusionPipeline):
                 "for i2i mode, as the images are needed for VAE encoding to build the KV cache."
             )
 
-        if prior_token_ids is not None and prompt_embeds is None and prompt is None:
-            raise ValueError("`prompt_embeds` or `prompt` must also be provided with `prior_token_ids`.")
+        if prior_token_ids is not None and prompt_embeds is None:
+            raise ValueError("`prompt_embeds` must also be provided with `prior_token_ids`.")
 
     @property
     def guidance_scale(self):
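Taken together, the two pipeline hunks tighten `check_inputs` to a mutually exclusive contract: pass `prompt` alone, or pass `prior_token_ids` together with `prompt_embeds`, but never `prompt` and `prior_token_ids` at once. A condensed, standalone restatement of the validation (illustration only, with just the arguments relevant here):

def check_prompt_inputs(prompt=None, prior_token_ids=None, prompt_embeds=None):
    if prompt is not None and prior_token_ids is not None:
        raise ValueError("Cannot forward both `prompt` and `prior_token_ids`.")
    elif prompt is None and prior_token_ids is None:
        raise ValueError("Provide either `prompt` or `prior_token_ids`.")
    if prior_token_ids is not None and prompt_embeds is None:
        raise ValueError("`prompt_embeds` must also be provided with `prior_token_ids`.")

check_prompt_inputs(prompt="A photo of a cat")  # ok
check_prompt_inputs(prior_token_ids=[1, 2], prompt_embeds=object())  # ok
# check_prompt_inputs(prompt="x", prior_token_ids=[1, 2])  # raises ValueError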
@@ -15,7 +15,7 @@
 # DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver and https://github.com/NVlabs/edm
 
 import math
-from typing import List, Literal, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -51,15 +51,13 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         schedule was incorporated in this model: https://huggingface.co/stabilityai/cosxl.
     num_train_timesteps (`int`, defaults to 1000):
         The number of diffusion steps to train the model.
-    solver_order (`int`, defaults to 2):
-        The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
-        sampling, and `solver_order=3` for unconditional sampling.
     prediction_type (`str`, defaults to `epsilon`, *optional*):
         Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
         `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
         Video](https://huggingface.co/papers/2210.02303) paper).
     rho (`float`, *optional*, defaults to 7.0):
         The rho parameter in the Karras sigma schedule. This was set to 7.0 in the EDM paper [1].
+    solver_order (`int`, defaults to 2):
+        The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
+        sampling, and `solver_order=3` for unconditional sampling.
     thresholding (`bool`, defaults to `False`):
         Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
         as Stable Diffusion.
@@ -96,19 +94,19 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         sigma_min: float = 0.002,
         sigma_max: float = 80.0,
         sigma_data: float = 0.5,
-        sigma_schedule: Literal["karras", "exponential"] = "karras",
+        sigma_schedule: str = "karras",
         num_train_timesteps: int = 1000,
-        prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
+        prediction_type: str = "epsilon",
         rho: float = 7.0,
         solver_order: int = 2,
         thresholding: bool = False,
         dynamic_thresholding_ratio: float = 0.995,
         sample_max_value: float = 1.0,
-        algorithm_type: Literal["dpmsolver++", "sde-dpmsolver++"] = "dpmsolver++",
-        solver_type: Literal["midpoint", "heun"] = "midpoint",
+        algorithm_type: str = "dpmsolver++",
+        solver_type: str = "midpoint",
         lower_order_final: bool = True,
         euler_at_final: bool = False,
-        final_sigmas_type: Optional[Literal["zero", "sigma_min"]] = "zero",  # "zero", "sigma_min"
+        final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
     ):
         # settings for DPM-Solver
         if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"]:
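Dropping the `Literal[...]` annotations loses static checking of the choices, but the allowed values are still enforced at runtime by the `algorithm_type not in [...]` guard visible at the end of the hunk. A small sketch (the import path is real; the exact exception raised on an invalid value is not shown in this diff, so it is left generic):

from diffusers import EDMDPMSolverMultistepScheduler

# A plain `str` hint accepts anything statically...
scheduler = EDMDPMSolverMultistepScheduler(algorithm_type="dpmsolver++")  # fine

# ...but the runtime guard still rejects unknown values.
try:
    EDMDPMSolverMultistepScheduler(algorithm_type="not-a-solver")
except Exception as err:  # exception type inferred, not shown in the hunk
    print(type(err).__name__)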
@@ -147,19 +145,19 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
 
     @property
-    def init_noise_sigma(self) -> float:
+    def init_noise_sigma(self):
         # standard deviation of the initial noise distribution
         return (self.config.sigma_max**2 + 1) ** 0.5
 
     @property
-    def step_index(self) -> int:
+    def step_index(self):
         """
         The index counter for current timestep. It will increase 1 after each scheduler step.
         """
         return self._step_index
 
     @property
-    def begin_index(self) -> int:
+    def begin_index(self):
         """
         The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
         """
@@ -276,11 +274,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         self.is_scale_input_called = True
         return sample
 
-    def set_timesteps(
-        self,
-        num_inference_steps: int = None,
-        device: Optional[Union[str, torch.device]] = None,
-    ):
+    def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
 
@@ -466,12 +460,13 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
     def _sigma_to_alpha_sigma_t(self, sigma):
         alpha_t = torch.tensor(1)  # Inputs are pre-scaled before going into unet, so alpha_t = 1
         sigma_t = sigma
 
         return alpha_t, sigma_t
 
     def convert_model_output(
         self,
         model_output: torch.Tensor,
-        sample: torch.Tensor,
+        sample: torch.Tensor = None,
     ) -> torch.Tensor:
         """
         Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
@@ -502,7 +497,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
     def dpm_solver_first_order_update(
         self,
         model_output: torch.Tensor,
-        sample: torch.Tensor,
+        sample: torch.Tensor = None,
         noise: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """
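A default of `None` on what is logically a required `sample` argument looks odd; in diffusers schedulers this signature shape is typically a leftover from deprecating old positional arguments, with a runtime check standing in for the type system. A sketch of that pattern (hypothetical function, not this file's code):

import torch

def first_order_update(model_output: torch.Tensor, sample: torch.Tensor = None) -> torch.Tensor:
    if sample is None:
        # The default exists only for signature compatibility; callers must pass it.
        raise ValueError("missing `sample` as a required keyword argument")
    return sample - model_output  # placeholder update, not the real DPM-Solver step

out = first_order_update(torch.zeros(2), sample=torch.ones(2))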
@@ -513,8 +508,6 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
             The direct output from the learned diffusion model.
         sample (`torch.Tensor`):
             A current instance of a sample created by the diffusion process.
-        noise (`torch.Tensor`, *optional*):
-            The noise tensor to add to the original samples.
 
         Returns:
             `torch.Tensor`:
@@ -545,7 +538,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
     def multistep_dpm_solver_second_order_update(
         self,
         model_output_list: List[torch.Tensor],
-        sample: torch.Tensor,
+        sample: torch.Tensor = None,
         noise: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """
@@ -556,8 +549,6 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
             The direct outputs from learned diffusion model at current and latter timesteps.
         sample (`torch.Tensor`):
             A current instance of a sample created by the diffusion process.
-        noise (`torch.Tensor`, *optional*):
-            The noise tensor to add to the original samples.
 
         Returns:
             `torch.Tensor`:
@@ -618,7 +609,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
     def multistep_dpm_solver_third_order_update(
         self,
         model_output_list: List[torch.Tensor],
-        sample: torch.Tensor,
+        sample: torch.Tensor = None,
     ) -> torch.Tensor:
         """
         One step for the third-order multistep DPMSolver.
@@ -707,7 +698,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         return step_index
 
     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
-    def _init_step_index(self, timestep: Union[int, torch.Tensor]) -> None:
+    def _init_step_index(self, timestep):
         """
         Initialize the step_index counter for the scheduler.
 
@@ -728,7 +719,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         model_output: torch.Tensor,
         timestep: Union[int, torch.Tensor],
         sample: torch.Tensor,
-        generator: Optional[torch.Generator] = None,
+        generator=None,
         return_dict: bool = True,
     ) -> Union[SchedulerOutput, Tuple]:
         """
@@ -869,5 +860,5 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
         return c_in
 
-    def __len__(self) -> int:
+    def __len__(self):
         return self.config.num_train_timesteps
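The `c_in` line in the last hunk is the EDM input preconditioning, c_in = 1 / sqrt(sigma^2 + sigma_data^2); per the comment in `_sigma_to_alpha_sigma_t` above, inputs are pre-scaled this way before the UNet, which is why `alpha_t = 1` there. A quick numerical check using the defaults from the `__init__` hunk (`sigma_data = 0.5`, `sigma_max = 80.0`):

import torch

sigma_data, sigma_max = 0.5, 80.0
sigma = torch.tensor(sigma_max)

c_in = 1 / ((sigma**2 + sigma_data**2) ** 0.5)
print(c_in)  # ~0.0125: large-sigma inputs are scaled toward unit variance

# init_noise_sigma from the property hunk: std of the initial noise.
print((sigma_max**2 + 1) ** 0.5)  # ~80.006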
@@ -281,86 +281,6 @@ class GlmImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         # Should return 4 images (2 prompts × 2 images per prompt)
         self.assertEqual(len(images), 4)
 
-    def test_prompt_with_prior_token_ids(self):
-        """Test that prompt and prior_token_ids can be provided together.
-
-        When both are given, the AR generation step is skipped (prior_token_ids is used
-        directly) and prompt is used to generate prompt_embeds via the glyph encoder.
-        """
-        device = "cpu"
-        components = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe.to(device)
-        pipe.set_progress_bar_config(disable=None)
-
-        height, width = 32, 32
-
-        # Step 1: Run with prompt only to get prior_token_ids from AR model
-        generator = torch.Generator(device=device).manual_seed(0)
-        prior_token_ids, _, _ = pipe.generate_prior_tokens(
-            prompt="A photo of a cat",
-            height=height,
-            width=width,
-            device=torch.device(device),
-            generator=torch.Generator(device=device).manual_seed(0),
-        )
-
-        # Step 2: Run with both prompt and prior_token_ids — should not raise
-        generator = torch.Generator(device=device).manual_seed(0)
-        inputs_both = {
-            "prompt": "A photo of a cat",
-            "prior_token_ids": prior_token_ids,
-            "generator": generator,
-            "num_inference_steps": 2,
-            "guidance_scale": 1.5,
-            "height": height,
-            "width": width,
-            "max_sequence_length": 16,
-            "output_type": "pt",
-        }
-        images = pipe(**inputs_both).images
-        self.assertEqual(len(images), 1)
-        self.assertEqual(images[0].shape, (3, 32, 32))
-
-    def test_check_inputs_rejects_invalid_combinations(self):
-        """Test that check_inputs correctly rejects invalid input combinations."""
-        device = "cpu"
-        components = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe.to(device)
-
-        height, width = 32, 32
-
-        # Neither prompt nor prior_token_ids → error
-        with self.assertRaises(ValueError):
-            pipe.check_inputs(
-                prompt=None,
-                height=height,
-                width=width,
-                callback_on_step_end_tensor_inputs=None,
-                prompt_embeds=torch.randn(1, 16, 32),
-            )
-
-        # prior_token_ids alone without prompt or prompt_embeds → error
-        with self.assertRaises(ValueError):
-            pipe.check_inputs(
-                prompt=None,
-                height=height,
-                width=width,
-                callback_on_step_end_tensor_inputs=None,
-                prior_token_ids=torch.randint(0, 100, (1, 64)),
-            )
-
-        # prompt + prompt_embeds together → error
-        with self.assertRaises(ValueError):
-            pipe.check_inputs(
-                prompt="A cat",
-                height=height,
-                width=width,
-                callback_on_step_end_tensor_inputs=None,
-                prompt_embeds=torch.randn(1, 16, 32),
-            )
-
     @unittest.skip("Needs to be revisited.")
     def test_encode_prompt_works_in_isolation(self):
         pass
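With these tests deleted, the supported way to skip AR generation is to pass `prior_token_ids` together with precomputed `prompt_embeds`, and without `prompt`. A hedged sketch of the resulting calling pattern (`pipe.generate_prior_tokens` appears in the deleted test above; how `prompt_embeds` is produced is not shown in this diff, so it is assumed precomputed):

# pipe is a GlmImagePipeline; prompt_embeds is assumed to be precomputed.
prior_token_ids, _, _ = pipe.generate_prior_tokens(
    prompt="A photo of a cat",
    height=32,
    width=32,
    device=torch.device("cpu"),
)
images = pipe(
    prior_token_ids=prior_token_ids,
    prompt_embeds=prompt_embeds,  # now required alongside prior_token_ids
    num_inference_steps=2,
).images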