mirror of https://github.com/huggingface/diffusers.git
synced 2025-12-07 04:54:47 +08:00

Compare commits: 30 commits (`modular-re…` → `pipeline-i…`)

9c7e2f6721, 59a1524aad, e23f6051a1, 6648af6d83, a894d9f921, 5fedea920f,
852024a34a, 2f2775d5f0, 52c45764ea, e8bf891380, fa213518ee, aefc1d72df,
2e31ef21c4, 69f68d34ab, f4f3aed0a7, dacaacc054, 5d415e970d, f608d21779,
3d69fd0087, 832f487a61, 36b4de2e21, b693e254d7, 914b35332b, fb832f7fdc,
717cb97b83, 6e61b0fb79, 78201dd10f, b9a49ccfe5, ac92d8513c, 4683961b53
@@ -63,3 +63,42 @@ With callbacks, you can implement features such as dynamic CFG without having to

🤗 Diffusers currently only supports `callback_on_step_end`, but feel free to open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) if you have a cool use case and require a callback function with a different execution point!

</Tip>

## Using callbacks to interrupt the diffusion process

The following pipelines support interrupting the diffusion process via a callback:

- [StableDiffusionPipeline](../api/pipelines/stable_diffusion/overview.md)
- [StableDiffusionImg2ImgPipeline](../api/pipelines/stable_diffusion/img2img.md)
- [StableDiffusionInpaintPipeline](../api/pipelines/stable_diffusion/inpaint.md)
- [StableDiffusionXLPipeline](../api/pipelines/stable_diffusion/stable_diffusion_xl.md)
- [StableDiffusionXLImg2ImgPipeline](../api/pipelines/stable_diffusion/stable_diffusion_xl.md)
- [StableDiffusionXLInpaintPipeline](../api/pipelines/stable_diffusion/stable_diffusion_xl.md)

Interrupting the diffusion process is particularly useful when building UIs on top of Diffusers because it lets users stop generation as soon as they're unhappy with the intermediate results. You can incorporate this into your pipeline with a callback.

The callback function must accept the arguments `pipe`, `i`, `t`, and `callback_kwargs`, and it must return `callback_kwargs`. Set the pipeline's `_interrupt` attribute to `True` to stop the diffusion process after a certain number of steps. You are also free to implement your own custom stopping logic inside the callback (see the sketch after this example).

In this example, the diffusion process is stopped after 10 steps even though `num_inference_steps` is set to 50.

```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.enable_model_cpu_offload()
num_inference_steps = 50


def interrupt_callback(pipe, i, t, callback_kwargs):
    stop_idx = 10
    if i == stop_idx:
        # the flag is checked at the top of every denoising step,
        # so all remaining steps are skipped
        pipe._interrupt = True

    return callback_kwargs


pipe(
    "A photo of a cat",
    num_inference_steps=num_inference_steps,
    callback_on_step_end=interrupt_callback,
)
```
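Beyond a fixed step index, the same hook can carry any stopping condition. The following is a minimal sketch (not part of this diff) that interrupts generation after a wall-clock budget; the `TimedInterrupt` helper and its `max_seconds` parameter are illustrative, and only the `callback_on_step_end` signature and the `_interrupt` attribute come from the pipeline API shown above.

```python
import time

from diffusers import StableDiffusionPipeline


class TimedInterrupt:
    """Illustrative helper: stop the denoising loop after a wall-clock budget."""

    def __init__(self, max_seconds):
        self.max_seconds = max_seconds  # hypothetical parameter, not a Diffusers API
        self.start = None

    def __call__(self, pipe, i, t, callback_kwargs):
        if self.start is None:
            self.start = time.monotonic()
        if time.monotonic() - self.start > self.max_seconds:
            pipe._interrupt = True  # same flag the built-in loop checks each step
        return callback_kwargs


pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.enable_model_cpu_offload()

pipe(
    "A photo of a cat",
    num_inference_steps=50,
    callback_on_step_end=TimedInterrupt(max_seconds=30),
)
```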
@@ -768,6 +768,10 @@ class StableDiffusionPipeline(
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def interrupt(self):
        return self._interrupt

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(

@@ -909,6 +913,7 @@ class StableDiffusionPipeline(
        self._guidance_rescale = guidance_rescale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):

@@ -986,6 +991,9 @@ class StableDiffusionPipeline(
        self._num_timesteps = len(timesteps)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
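The same three additions recur in every pipeline touched by this change: a read-only `interrupt` property, a `self._interrupt = False` reset at the start of `__call__`, and a check at the top of the denoising loop that drains the remaining steps with `continue`. A condensed sketch of the pattern follows, with the real pipeline machinery elided; `MyPipeline` and `denoise_step` are illustrative stand-ins, not Diffusers APIs.

```python
import torch


class MyPipeline:  # illustrative stand-in for a DiffusionPipeline subclass
    @property
    def interrupt(self):
        # read-only view of the private flag; callbacks set pipe._interrupt directly
        return self._interrupt

    def denoise_step(self, latents, t):
        # placeholder for the real UNet prediction + scheduler step
        return latents * 0.99

    @torch.no_grad()
    def __call__(self, timesteps, latents, callback_on_step_end=None):
        self._interrupt = False  # reset so an interrupted run doesn't leak into the next

        for i, t in enumerate(timesteps):
            if self.interrupt:
                # `continue` (rather than `break`) drains the loop without doing
                # any further denoising work once the flag is set
                continue
            latents = self.denoise_step(latents, t)
            if callback_on_step_end is not None:
                callback_on_step_end(self, i, t, {"latents": latents})
        return latents
```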
@@ -832,6 +832,10 @@ class StableDiffusionImg2ImgPipeline(
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def interrupt(self):
        return self._interrupt

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(

@@ -963,6 +967,7 @@ class StableDiffusionImg2ImgPipeline(
        self._guidance_scale = guidance_scale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):

@@ -1041,6 +1046,9 @@ class StableDiffusionImg2ImgPipeline(
        self._num_timesteps = len(timesteps)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -958,6 +958,10 @@ class StableDiffusionInpaintPipeline(
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def interrupt(self):
        return self._interrupt

    @torch.no_grad()
    def __call__(
        self,

@@ -1144,6 +1148,7 @@ class StableDiffusionInpaintPipeline(
        self._guidance_scale = guidance_scale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):

@@ -1288,6 +1293,9 @@ class StableDiffusionInpaintPipeline(
        self._num_timesteps = len(timesteps)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
@@ -849,6 +849,10 @@ class StableDiffusionXLPipeline(
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def interrupt(self):
        return self._interrupt

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(

@@ -1067,6 +1071,7 @@ class StableDiffusionXLPipeline(
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs
        self._denoising_end = denoising_end
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):

@@ -1196,6 +1201,9 @@ class StableDiffusionXLPipeline(
        self._num_timesteps = len(timesteps)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
@@ -990,6 +990,10 @@ class StableDiffusionXLImg2ImgPipeline(
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def interrupt(self):
        return self._interrupt

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(

@@ -1221,6 +1225,7 @@ class StableDiffusionXLImg2ImgPipeline(
        self._cross_attention_kwargs = cross_attention_kwargs
        self._denoising_end = denoising_end
        self._denoising_start = denoising_start
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):

@@ -1376,6 +1381,9 @@ class StableDiffusionXLImg2ImgPipeline(
        self._num_timesteps = len(timesteps)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
@@ -1210,6 +1210,10 @@ class StableDiffusionXLInpaintPipeline(
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def interrupt(self):
        return self._interrupt

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(

@@ -1462,6 +1466,7 @@ class StableDiffusionXLInpaintPipeline(
        self._cross_attention_kwargs = cross_attention_kwargs
        self._denoising_end = denoising_end
        self._denoising_start = denoising_start
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):

@@ -1684,6 +1689,8 @@ class StableDiffusionXLInpaintPipeline(
        self._num_timesteps = len(timesteps)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
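In a UI, generation typically runs on a worker thread while a button handler flips a shared flag; the step-end callback then propagates that flag into `pipe._interrupt`. A minimal sketch of that wiring, assuming a `threading.Event` as the shared cancel signal (the app-side wiring shown in comments is illustrative):

```python
import threading

cancel_requested = threading.Event()  # set from the UI thread, e.g. by a Stop button


def make_cancel_callback(cancel_event):
    def on_step_end(pipe, i, t, callback_kwargs):
        if cancel_event.is_set():
            pipe._interrupt = True  # checked at the top of the next denoising step
        return callback_kwargs

    return on_step_end


# worker thread:
#   pipe(prompt, callback_on_step_end=make_cancel_callback(cancel_requested))
# UI thread (on Stop click):
#   cancel_requested.set()
```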
@@ -692,6 +692,58 @@ class StableDiffusionPipelineFastTests(
            original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
        ), "Original outputs should match when fused QKV projections are disabled."

    def test_pipeline_interrupt(self):
        components = self.get_dummy_components()
        sd_pipe = StableDiffusionPipeline(**components)
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        prompt = "hey"
        num_inference_steps = 3

        # store intermediate latents from the generation process
        class PipelineState:
            def __init__(self):
                self.state = []

            def apply(self, pipe, i, t, callback_kwargs):
                self.state.append(callback_kwargs["latents"])
                return callback_kwargs

        pipe_state = PipelineState()
        sd_pipe(
            prompt,
            num_inference_steps=num_inference_steps,
            output_type="np",
            generator=torch.Generator("cpu").manual_seed(0),
            callback_on_step_end=pipe_state.apply,
        ).images

        # interrupt generation at step index
        interrupt_step_idx = 1

        def callback_on_step_end(pipe, i, t, callback_kwargs):
            if i == interrupt_step_idx:
                pipe._interrupt = True

            return callback_kwargs

        output_interrupted = sd_pipe(
            prompt,
            num_inference_steps=num_inference_steps,
            output_type="latent",
            generator=torch.Generator("cpu").manual_seed(0),
            callback_on_step_end=callback_on_step_end,
        ).images

        # fetch intermediate latents at the interrupted step
        # from the completed generation process
        intermediate_latent = pipe_state.state[interrupt_step_idx]

        # compare the intermediate latent to the output of the interrupted process
        # they should be the same
        assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4)


@slow
@require_torch_gpu
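Each of the tests below follows the same strategy: run a full generation while recording the latents at every step, then run a second generation from the same seed that is interrupted at `interrupt_step_idx`, and assert that the interrupted run's output equals the recorded latent at that step. Two prerequisites make this comparison valid: seeded generators make the two runs produce identical step-by-step latents, and `output_type="latent"` returns the raw latents without a VAE decode. A minimal demonstration of the seeding prerequisite with plain tensors:

```python
import torch

# Two generators seeded identically produce identical random draws, which is why
# the full run and the interrupted run walk through the same latents step by step.
gen_a = torch.Generator("cpu").manual_seed(0)
gen_b = torch.Generator("cpu").manual_seed(0)
a = torch.randn(2, 3, generator=gen_a)
b = torch.randn(2, 3, generator=gen_b)
assert torch.allclose(a, b)
```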
@@ -320,6 +320,62 @@ class StableDiffusionImg2ImgPipelineFastTests(
    def test_float16_inference(self):
        super().test_float16_inference(expected_max_diff=5e-1)

    def test_pipeline_interrupt(self):
        components = self.get_dummy_components()
        sd_pipe = StableDiffusionImg2ImgPipeline(**components)
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)

        prompt = "hey"
        num_inference_steps = 3

        # store intermediate latents from the generation process
        class PipelineState:
            def __init__(self):
                self.state = []

            def apply(self, pipe, i, t, callback_kwargs):
                self.state.append(callback_kwargs["latents"])
                return callback_kwargs

        pipe_state = PipelineState()
        sd_pipe(
            prompt,
            image=inputs["image"],
            num_inference_steps=num_inference_steps,
            output_type="np",
            generator=torch.Generator("cpu").manual_seed(0),
            callback_on_step_end=pipe_state.apply,
        ).images

        # interrupt generation at step index
        interrupt_step_idx = 1

        def callback_on_step_end(pipe, i, t, callback_kwargs):
            if i == interrupt_step_idx:
                pipe._interrupt = True

            return callback_kwargs

        output_interrupted = sd_pipe(
            prompt,
            image=inputs["image"],
            num_inference_steps=num_inference_steps,
            output_type="latent",
            generator=torch.Generator("cpu").manual_seed(0),
            callback_on_step_end=callback_on_step_end,
        ).images

        # fetch intermediate latents at the interrupted step
        # from the completed generation process
        intermediate_latent = pipe_state.state[interrupt_step_idx]

        # compare the intermediate latent to the output of the interrupted process
        # they should be the same
        assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4)


@slow
@require_torch_gpu
@@ -319,6 +319,64 @@ class StableDiffusionInpaintPipelineFastTests(
        out_1 = sd_pipe(**inputs).images
        assert np.abs(out_0 - out_1).max() < 1e-2

    def test_pipeline_interrupt(self):
        components = self.get_dummy_components()
        sd_pipe = StableDiffusionInpaintPipeline(**components)
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)

        prompt = "hey"
        num_inference_steps = 3

        # store intermediate latents from the generation process
        class PipelineState:
            def __init__(self):
                self.state = []

            def apply(self, pipe, i, t, callback_kwargs):
                self.state.append(callback_kwargs["latents"])
                return callback_kwargs

        pipe_state = PipelineState()
        sd_pipe(
            prompt,
            image=inputs["image"],
            mask_image=inputs["mask_image"],
            num_inference_steps=num_inference_steps,
            output_type="np",
            generator=torch.Generator("cpu").manual_seed(0),
            callback_on_step_end=pipe_state.apply,
        ).images

        # interrupt generation at step index
        interrupt_step_idx = 1

        def callback_on_step_end(pipe, i, t, callback_kwargs):
            if i == interrupt_step_idx:
                pipe._interrupt = True

            return callback_kwargs

        output_interrupted = sd_pipe(
            prompt,
            image=inputs["image"],
            mask_image=inputs["mask_image"],
            num_inference_steps=num_inference_steps,
            output_type="latent",
            generator=torch.Generator("cpu").manual_seed(0),
            callback_on_step_end=callback_on_step_end,
        ).images

        # fetch intermediate latents at the interrupted step
        # from the completed generation process
        intermediate_latent = pipe_state.state[interrupt_step_idx]

        # compare the intermediate latent to the output of the interrupted process
        # they should be the same
        assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4)


class StableDiffusionSimpleInpaintPipelineFastTests(StableDiffusionInpaintPipelineFastTests):
    pipeline_class = StableDiffusionInpaintPipeline
@@ -969,6 +969,58 @@ class StableDiffusionXLPipelineFastTests(
            original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
        ), "Original outputs should match when fused QKV projections are disabled."

    def test_pipeline_interrupt(self):
        components = self.get_dummy_components()
        sd_pipe = StableDiffusionXLPipeline(**components)
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        prompt = "hey"
        num_inference_steps = 3

        # store intermediate latents from the generation process
        class PipelineState:
            def __init__(self):
                self.state = []

            def apply(self, pipe, i, t, callback_kwargs):
                self.state.append(callback_kwargs["latents"])
                return callback_kwargs

        pipe_state = PipelineState()
        sd_pipe(
            prompt,
            num_inference_steps=num_inference_steps,
            output_type="np",
            generator=torch.Generator("cpu").manual_seed(0),
            callback_on_step_end=pipe_state.apply,
        ).images

        # interrupt generation at step index
        interrupt_step_idx = 1

        def callback_on_step_end(pipe, i, t, callback_kwargs):
            if i == interrupt_step_idx:
                pipe._interrupt = True

            return callback_kwargs

        output_interrupted = sd_pipe(
            prompt,
            num_inference_steps=num_inference_steps,
            output_type="latent",
            generator=torch.Generator("cpu").manual_seed(0),
            callback_on_step_end=callback_on_step_end,
        ).images

        # fetch intermediate latents at the interrupted step
        # from the completed generation process
        intermediate_latent = pipe_state.state[interrupt_step_idx]

        # compare the intermediate latent to the output of the interrupted process
        # they should be the same
        assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4)


@slow
class StableDiffusionXLPipelineIntegrationTests(unittest.TestCase):
@@ -439,6 +439,64 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
            > 1e-4
        )

    def test_pipeline_interrupt(self):
        components = self.get_dummy_components()
        sd_pipe = StableDiffusionXLImg2ImgPipeline(**components)
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)

        prompt = "hey"
        num_inference_steps = 5

        # store intermediate latents from the generation process
        class PipelineState:
            def __init__(self):
                self.state = []

            def apply(self, pipe, i, t, callback_kwargs):
                self.state.append(callback_kwargs["latents"])
                return callback_kwargs

        pipe_state = PipelineState()
        sd_pipe(
            prompt,
            image=inputs["image"],
            strength=0.8,
            num_inference_steps=num_inference_steps,
            output_type="np",
            generator=torch.Generator("cpu").manual_seed(0),
            callback_on_step_end=pipe_state.apply,
        ).images

        # interrupt generation at step index
        interrupt_step_idx = 1

        def callback_on_step_end(pipe, i, t, callback_kwargs):
            if i == interrupt_step_idx:
                pipe._interrupt = True

            return callback_kwargs

        output_interrupted = sd_pipe(
            prompt,
            image=inputs["image"],
            strength=0.8,
            num_inference_steps=num_inference_steps,
            output_type="latent",
            generator=torch.Generator("cpu").manual_seed(0),
            callback_on_step_end=callback_on_step_end,
        ).images

        # fetch intermediate latents at the interrupted step
        # from the completed generation process
        intermediate_latent = pipe_state.state[interrupt_step_idx]

        # compare the intermediate latent to the output of the interrupted process
        # they should be the same
        assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4)


class StableDiffusionXLImg2ImgRefinerOnlyPipelineFastTests(
    PipelineLatentTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase
@@ -746,3 +746,63 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
        image_slice1 = images[0, -3:, -3:, -1]
        image_slice2 = images[1, -3:, -3:, -1]
        assert np.abs(image_slice1.flatten() - image_slice2.flatten()).max() > 1e-2

    def test_pipeline_interrupt(self):
        components = self.get_dummy_components()
        sd_pipe = StableDiffusionXLInpaintPipeline(**components)
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)

        prompt = "hey"
        num_inference_steps = 5

        # store intermediate latents from the generation process
        class PipelineState:
            def __init__(self):
                self.state = []

            def apply(self, pipe, i, t, callback_kwargs):
                self.state.append(callback_kwargs["latents"])
                return callback_kwargs

        pipe_state = PipelineState()
        sd_pipe(
            prompt,
            image=inputs["image"],
            mask_image=inputs["mask_image"],
            strength=0.8,
            num_inference_steps=num_inference_steps,
            output_type="np",
            generator=torch.Generator("cpu").manual_seed(0),
            callback_on_step_end=pipe_state.apply,
        ).images

        # interrupt generation at step index
        interrupt_step_idx = 1

        def callback_on_step_end(pipe, i, t, callback_kwargs):
            if i == interrupt_step_idx:
                pipe._interrupt = True

            return callback_kwargs

        output_interrupted = sd_pipe(
            prompt,
            image=inputs["image"],
            mask_image=inputs["mask_image"],
            strength=0.8,
            num_inference_steps=num_inference_steps,
            output_type="latent",
            generator=torch.Generator("cpu").manual_seed(0),
            callback_on_step_end=callback_on_step_end,
        ).images

        # fetch intermediate latents at the interrupted step
        # from the completed generation process
        intermediate_latent = pipe_state.state[interrupt_step_idx]

        # compare the intermediate latent to the output of the interrupted process
        # they should be the same
        assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4)