mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-07 04:54:47 +08:00
Compare commits
4 Commits
fix-mirror
...
cascade-fi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4b682be902 | ||
|
|
3d0bb51d53 | ||
|
|
4b72aae0cd | ||
|
|
33bbe58ea7 |
@@ -289,7 +289,9 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
|
||||
guidance_scale: float = 0.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds_pooled: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds_pooled: Optional[torch.FloatTensor] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
@@ -321,10 +323,17 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
prompt_embeds_pooled (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
||||
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
negative_prompt_embeds_pooled (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds_pooled will be generated from `negative_prompt` input
|
||||
argument.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
@@ -378,7 +387,7 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
|
||||
|
||||
# 2. Encode caption
|
||||
if prompt_embeds is None and negative_prompt_embeds is None:
|
||||
prompt_embeds, _, negative_prompt_embeds, _ = self.encode_prompt(
|
||||
_, prompt_embeds_pooled, _, negative_prompt_embeds_pooled = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
batch_size=batch_size,
|
||||
@@ -386,10 +395,16 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
prompt_embeds_pooled=prompt_embeds_pooled,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
|
||||
)
|
||||
|
||||
# The pooled embeds from the prior are pooled again before being passed to the decoder
|
||||
prompt_embeds_pooled = (
|
||||
torch.cat([prompt_embeds, negative_prompt_embeds]) if self.do_classifier_free_guidance else prompt_embeds
|
||||
torch.cat([prompt_embeds_pooled, negative_prompt_embeds_pooled])
|
||||
if self.do_classifier_free_guidance
|
||||
else prompt_embeds_pooled
|
||||
)
|
||||
effnet = (
|
||||
torch.cat([image_embeddings, torch.zeros_like(image_embeddings)])
|
||||
|
||||
@@ -155,14 +155,14 @@ class StableCascadeCombinedPipeline(DiffusionPipeline):
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
prior_num_inference_steps: int = 60,
|
||||
prior_timesteps: Optional[List[float]] = None,
|
||||
prior_guidance_scale: float = 4.0,
|
||||
num_inference_steps: int = 12,
|
||||
decoder_timesteps: Optional[List[float]] = None,
|
||||
decoder_guidance_scale: float = 0.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds_pooled: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds_pooled: Optional[torch.FloatTensor] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
@@ -187,10 +187,17 @@ class StableCascadeCombinedPipeline(DiffusionPipeline):
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, text embeddings will be generated from `prompt` input argument.
|
||||
prompt_embeds_pooled (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated pooled text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.*
|
||||
prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt`
|
||||
input argument.
|
||||
negative_prompt_embeds_pooled (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative pooled text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.*
|
||||
prompt weighting. If not provided, negative_prompt_embeds_pooled will be generated from `negative_prompt`
|
||||
input argument.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
height (`int`, *optional*, defaults to 512):
|
||||
@@ -253,7 +260,6 @@ class StableCascadeCombinedPipeline(DiffusionPipeline):
|
||||
[`~pipelines.ImagePipelineOutput`] or `tuple` [`~pipelines.ImagePipelineOutput`] if `return_dict` is True,
|
||||
otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
|
||||
"""
|
||||
|
||||
prior_outputs = self.prior_pipe(
|
||||
prompt=prompt if prompt_embeds is None else None,
|
||||
images=images,
|
||||
@@ -263,7 +269,9 @@ class StableCascadeCombinedPipeline(DiffusionPipeline):
|
||||
guidance_scale=prior_guidance_scale,
|
||||
negative_prompt=negative_prompt if negative_prompt_embeds is None else None,
|
||||
prompt_embeds=prompt_embeds,
|
||||
prompt_embeds_pooled=prompt_embeds_pooled,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
@@ -274,7 +282,9 @@ class StableCascadeCombinedPipeline(DiffusionPipeline):
|
||||
)
|
||||
image_embeddings = prior_outputs.image_embeddings
|
||||
prompt_embeds = prior_outputs.get("prompt_embeds", None)
|
||||
prompt_embeds_pooled = prior_outputs.get("prompt_embeds_pooled", None)
|
||||
negative_prompt_embeds = prior_outputs.get("negative_prompt_embeds", None)
|
||||
negative_prompt_embeds_pooled = prior_outputs.get("negative_prompt_embeds_pooled", None)
|
||||
|
||||
outputs = self.decoder_pipe(
|
||||
image_embeddings=image_embeddings,
|
||||
@@ -283,7 +293,9 @@ class StableCascadeCombinedPipeline(DiffusionPipeline):
|
||||
guidance_scale=decoder_guidance_scale,
|
||||
negative_prompt=negative_prompt if negative_prompt_embeds is None else None,
|
||||
prompt_embeds=prompt_embeds,
|
||||
prompt_embeds_pooled=prompt_embeds_pooled,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
|
||||
generator=generator,
|
||||
output_type=output_type,
|
||||
return_dict=return_dict,
|
||||
|
||||
@@ -64,7 +64,9 @@ class StableCascadePriorPipelineOutput(BaseOutput):
|
||||
|
||||
image_embeddings: Union[torch.FloatTensor, np.ndarray]
|
||||
prompt_embeds: Union[torch.FloatTensor, np.ndarray]
|
||||
prompt_embeds_pooled: Union[torch.FloatTensor, np.ndarray]
|
||||
negative_prompt_embeds: Union[torch.FloatTensor, np.ndarray]
|
||||
negative_prompt_embeds_pooled: Union[torch.FloatTensor, np.ndarray]
|
||||
|
||||
|
||||
class StableCascadePriorPipeline(DiffusionPipeline):
|
||||
@@ -305,6 +307,16 @@ class StableCascadePriorPipeline(DiffusionPipeline):
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and prompt_embeds_pooled is None:
|
||||
raise ValueError(
|
||||
"If `prompt_embeds` are provided, `prompt_embeds_pooled` must also be provided. Make sure to generate `prompt_embeds_pooled` from the same text encoder that was used to generate `prompt_embeds`"
|
||||
)
|
||||
|
||||
if negative_prompt_embeds is not None and negative_prompt_embeds_pooled is None:
|
||||
raise ValueError(
|
||||
"If `negative_prompt_embeds` are provided, `negative_prompt_embeds_pooled` must also be provided. Make sure to generate `prompt_embeds_pooled` from the same text encoder that was used to generate `prompt_embeds`"
|
||||
)
|
||||
|
||||
if prompt_embeds_pooled is not None and negative_prompt_embeds_pooled is not None:
|
||||
if prompt_embeds_pooled.shape != negative_prompt_embeds_pooled.shape:
|
||||
raise ValueError(
|
||||
@@ -339,7 +351,7 @@ class StableCascadePriorPipeline(DiffusionPipeline):
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
def get_t_condioning(self, t, alphas_cumprod):
|
||||
def get_timestep_ratio_conditioning(self, t, alphas_cumprod):
|
||||
s = torch.tensor([0.003])
|
||||
clamp_range = [0, 1]
|
||||
min_var = torch.cos(s / (1 + s) * torch.pi * 0.5) ** 2
|
||||
@@ -558,7 +570,7 @@ class StableCascadePriorPipeline(DiffusionPipeline):
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
if not isinstance(self.scheduler, DDPMWuerstchenScheduler):
|
||||
if len(alphas_cumprod) > 0:
|
||||
timestep_ratio = self.get_t_condioning(t.long().cpu(), alphas_cumprod)
|
||||
timestep_ratio = self.get_timestep_ratio_conditioning(t.long().cpu(), alphas_cumprod)
|
||||
timestep_ratio = timestep_ratio.expand(latents.size(0)).to(dtype).to(device)
|
||||
else:
|
||||
timestep_ratio = t.float().div(self.scheduler.timesteps[-1]).expand(latents.size(0)).to(dtype)
|
||||
@@ -609,6 +621,18 @@ class StableCascadePriorPipeline(DiffusionPipeline):
|
||||
) # float() as bfloat16-> numpy doesnt work
|
||||
|
||||
if not return_dict:
|
||||
return (latents, prompt_embeds, negative_prompt_embeds)
|
||||
return (
|
||||
latents,
|
||||
prompt_embeds,
|
||||
prompt_embeds_pooled,
|
||||
negative_prompt_embeds,
|
||||
negative_prompt_embeds_pooled,
|
||||
)
|
||||
|
||||
return StableCascadePriorPipelineOutput(latents, prompt_embeds, negative_prompt_embeds)
|
||||
return StableCascadePriorPipelineOutput(
|
||||
image_embeddings=latents,
|
||||
prompt_embeds=prompt_embeds,
|
||||
prompt_embeds_pooled=prompt_embeds_pooled,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
|
||||
)
|
||||
|
||||
@@ -241,6 +241,39 @@ class StableCascadeCombinedPipelineFastTests(PipelineTesterMixin, unittest.TestC
|
||||
def test_callback_inputs(self):
|
||||
super().test_callback_inputs()
|
||||
|
||||
# def test_callback_cfg(self):
|
||||
# pass
|
||||
# pass
|
||||
def test_stable_cascade_combined_prompt_embeds(self):
    """Compare a prompt-driven run of the combined pipeline with an embedding-driven run.

    The prompt is encoded once through ``pipe.prior_pipe.encode_prompt``. The pipeline is
    then called with the raw prompt, and again with ``prompt=None`` plus the four values
    returned by the encoder. The generator is re-seeded to 0 before each call, and the two
    ``images`` outputs must differ by less than 1e-5 in maximum absolute value.
    """
    device = "cpu"

    pipe = StableCascadeCombinedPipeline(**self.get_dummy_components())
    pipe.set_progress_bar_config(disable=None)

    prompt = "A photograph of a shiba inu, wearing a hat"
    # The leading positional arguments are passed exactly as (device, 1, 1, False).
    encoded = pipe.prior_pipe.encode_prompt(device, 1, 1, False, prompt=prompt)
    (
        prompt_embeds,
        prompt_embeds_pooled,
        negative_prompt_embeds,
        negative_prompt_embeds_pooled,
    ) = encoded
    generator = torch.Generator(device=device)

    # Settings that are identical for both pipeline invocations.
    shared_kwargs = {
        "num_inference_steps": 1,
        "prior_num_inference_steps": 1,
        "output_type": "np",
    }

    output_prompt = pipe(
        prompt=prompt,
        generator=generator.manual_seed(0),
        **shared_kwargs,
    )
    output_prompt_embeds = pipe(
        prompt=None,
        prompt_embeds=prompt_embeds,
        prompt_embeds_pooled=prompt_embeds_pooled,
        negative_prompt_embeds=negative_prompt_embeds,
        negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
        generator=generator.manual_seed(0),
        **shared_kwargs,
    )

    max_abs_diff = np.abs(output_prompt.images - output_prompt_embeds.images).max()
    assert max_abs_diff < 1e-5
|
||||
|
||||
@@ -207,6 +207,45 @@ class StableCascadeDecoderPipelineFastTests(PipelineTesterMixin, unittest.TestCa
|
||||
def test_float16_inference(self):
|
||||
super().test_float16_inference()
|
||||
|
||||
def test_stable_cascade_decoder_prompt_embeds(self):
    """Compare a prompt-driven decoder run with an embedding-driven decoder run.

    Both runs receive the same ``image_embeddings`` taken from the dummy inputs. The first
    passes the raw prompt; the second passes ``prompt=None`` plus the four values returned
    by ``pipe.encode_prompt``. The generator is re-seeded to 0 before each call, and the
    two ``images`` outputs must differ by less than 1e-5 in maximum absolute value.
    """
    device = "cpu"

    pipe = StableCascadeDecoderPipeline(**self.get_dummy_components())
    pipe.set_progress_bar_config(disable=None)

    image_embeddings = self.get_dummy_inputs(device)["image_embeddings"]
    prompt = "A photograph of a shiba inu, wearing a hat"
    # The leading positional arguments are passed exactly as (device, 1, 1, False).
    encoded = pipe.encode_prompt(device, 1, 1, False, prompt=prompt)
    (
        prompt_embeds,
        prompt_embeds_pooled,
        negative_prompt_embeds,
        negative_prompt_embeds_pooled,
    ) = encoded
    generator = torch.Generator(device=device)

    # Settings that are identical for both pipeline invocations.
    shared_kwargs = {"num_inference_steps": 1, "output_type": "np"}

    decoder_output_prompt = pipe(
        image_embeddings=image_embeddings,
        prompt=prompt,
        generator=generator.manual_seed(0),
        **shared_kwargs,
    )
    decoder_output_prompt_embeds = pipe(
        image_embeddings=image_embeddings,
        prompt=None,
        prompt_embeds=prompt_embeds,
        prompt_embeds_pooled=prompt_embeds_pooled,
        negative_prompt_embeds=negative_prompt_embeds,
        negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
        generator=generator.manual_seed(0),
        **shared_kwargs,
    )

    max_abs_diff = np.abs(decoder_output_prompt.images - decoder_output_prompt_embeds.images).max()
    assert max_abs_diff < 1e-5
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
|
||||
@@ -273,6 +273,41 @@ class StableCascadePriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase
|
||||
|
||||
self.assertTrue(image_embed.shape == lora_image_embed.shape)
|
||||
|
||||
def test_stable_cascade_decoder_prompt_embeds(self):
    """Compare a prompt-driven run of ``self.pipeline_class`` with an embedding-driven run.

    The prompt is encoded once through ``pipe.encode_prompt``. The pipeline is then called
    with the raw prompt, and again with ``prompt=None`` plus the four values returned by
    the encoder. The generator is re-seeded to 0 before each call, and the two
    ``image_embeddings`` outputs must differ by less than 1e-5 in maximum absolute value.

    NOTE(review): the method name says "decoder", but the diff hunk header places it in
    ``StableCascadePriorPipelineFastTests`` and it compares ``image_embeddings`` -- confirm
    whether the name is intentional.
    """
    device = "cpu"

    pipe = self.pipeline_class(**self.get_dummy_components())
    pipe.set_progress_bar_config(disable=None)

    prompt = "A photograph of a shiba inu, wearing a hat"
    # The leading positional arguments are passed exactly as (device, 1, 1, False).
    encoded = pipe.encode_prompt(device, 1, 1, False, prompt=prompt)
    (
        prompt_embeds,
        prompt_embeds_pooled,
        negative_prompt_embeds,
        negative_prompt_embeds_pooled,
    ) = encoded
    generator = torch.Generator(device=device)

    # Settings that are identical for both pipeline invocations.
    shared_kwargs = {"num_inference_steps": 1, "output_type": "np"}

    output_prompt = pipe(
        prompt=prompt,
        generator=generator.manual_seed(0),
        **shared_kwargs,
    )
    output_prompt_embeds = pipe(
        prompt=None,
        prompt_embeds=prompt_embeds,
        prompt_embeds_pooled=prompt_embeds_pooled,
        negative_prompt_embeds=negative_prompt_embeds,
        negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
        generator=generator.manual_seed(0),
        **shared_kwargs,
    )

    max_abs_diff = np.abs(output_prompt.image_embeddings - output_prompt_embeds.image_embeddings).max()
    assert max_abs_diff < 1e-5
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
|
||||
Reference in New Issue
Block a user