Compare commits

...

30 Commits

Author SHA1 Message Date
Dhruv Nair
9c7e2f6721 Merge branch 'pipeline-interrupt' of https://github.com/huggingface/diffusers into pipeline-interrupt 2023-12-26 05:58:58 +00:00
Dhruv Nair
59a1524aad update 2023-12-26 05:58:52 +00:00
Sayak Paul
e23f6051a1 Merge branch 'main' into pipeline-interrupt 2023-12-26 08:56:18 +05:30
Dhruv Nair
6648af6d83 fix 2023-12-25 13:12:12 +00:00
Sayak Paul
a894d9f921 Merge branch 'main' into pipeline-interrupt 2023-12-25 12:02:07 +05:30
Sayak Paul
5fedea920f Merge branch 'main' into pipeline-interrupt 2023-12-20 10:05:00 +05:30
Dhruv Nair
852024a34a Merge branch 'main' into pipeline-interrupt 2023-12-19 04:00:14 +00:00
Dhruv Nair
2f2775d5f0 Merge branch 'pipeline-interrupt' of https://github.com/huggingface/diffusers into pipeline-interrupt 2023-12-04 15:31:41 +00:00
Dhruv Nair
52c45764ea fix quality issues 2023-12-04 15:30:53 +00:00
Dhruv Nair
e8bf891380 Merge branch 'main' into pipeline-interrupt 2023-12-04 12:58:31 +00:00
Patrick von Platen
fa213518ee Merge branch 'main' into pipeline-interrupt 2023-12-04 12:06:18 +01:00
Dhruv Nair
aefc1d72df Merge branch 'pipeline-interrupt' of https://github.com/huggingface/diffusers into pipeline-interrupt 2023-12-04 07:31:02 +00:00
Dhruv Nair
2e31ef21c4 Merge branch 'main' into pipeline-interrupt 2023-12-04 07:30:33 +00:00
Patrick von Platen
69f68d34ab Merge branch 'main' into pipeline-interrupt 2023-12-01 16:51:08 +01:00
Dhruv Nair
f4f3aed0a7 Merge branch 'main' into pipeline-interrupt 2023-11-30 17:47:55 +05:30
Dhruv Nair
dacaacc054 update 2023-11-30 10:38:00 +00:00
Dhruv Nair
5d415e970d Update docs/source/en/tutorials/interrupting_diffusion_process.md
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2023-11-30 15:49:37 +05:30
Dhruv Nair
f608d21779 Update docs/source/en/tutorials/interrupting_diffusion_process.md
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2023-11-30 15:49:29 +05:30
Dhruv Nair
3d69fd0087 add tutorial 2023-11-22 12:25:04 +00:00
Dhruv Nair
832f487a61 Merge branch 'main' into pipeline-interrupt 2023-11-22 11:38:00 +00:00
Dhruv Nair
36b4de2e21 add docs 2023-11-22 11:12:21 +00:00
Dhruv Nair
b693e254d7 Revert "make fix copies"
This reverts commit 914b35332b.
2023-11-21 12:23:46 +00:00
Dhruv Nair
914b35332b make fix copies 2023-11-21 12:07:46 +00:00
Dhruv Nair
fb832f7fdc Merge branch 'pipeline-interrupt' of https://github.com/huggingface/diffusers into pipeline-interrupt 2023-11-21 12:02:28 +00:00
Dhruv Nair
717cb97b83 add interrupt property 2023-11-21 12:01:17 +00:00
Dhruv Nair
6e61b0fb79 updatemsmq 2023-11-21 12:00:32 +00:00
Dhruv Nair
78201dd10f Merge branch 'main' into pipeline-interrupt 2023-11-21 11:33:23 +00:00
Dhruv Nair
b9a49ccfe5 Merge branch 'main' into pipeline-interrupt 2023-11-20 20:56:31 +05:30
Dhruv Nair
ac92d8513c add tests 2023-11-20 11:55:40 +00:00
Dhruv Nair
4683961b53 add interruptable pipelines 2023-11-20 10:23:20 +00:00
13 changed files with 422 additions and 0 deletions

View File

@@ -63,3 +63,42 @@ With callbacks, you can implement features such as dynamic CFG without having to
🤗 Diffusers currently only supports `callback_on_step_end`, but feel free to open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) if you have a cool use-case and require a callback function with a different execution point!
</Tip>
## Using Callbacks to Interrupt the Diffusion Process

The following pipelines support interrupting the diffusion process via a callback:

- [StableDiffusionPipeline](../api/pipelines/stable_diffusion/overview.md)
- [StableDiffusionImg2ImgPipeline](../api/pipelines/stable_diffusion/img2img.md)
- [StableDiffusionInpaintPipeline](../api/pipelines/stable_diffusion/inpaint.md)
- [StableDiffusionXLPipeline](../api/pipelines/stable_diffusion/stable_diffusion_xl.md)
- [StableDiffusionXLImg2ImgPipeline](../api/pipelines/stable_diffusion/stable_diffusion_xl.md)
- [StableDiffusionXLInpaintPipeline](../api/pipelines/stable_diffusion/stable_diffusion_xl.md)

Interrupting the diffusion process is particularly useful when building UIs that work with Diffusers because it allows users to stop the generation process if they're unhappy with the intermediate results. You can incorporate this into your pipeline with a callback.

The callback function should take the following arguments: `pipe`, `i`, `t`, and `callback_kwargs` (which must be returned). Set the pipeline's `_interrupt` attribute to `True` to stop the diffusion process after a certain number of steps. You are also free to implement your own custom stopping logic inside the callback.

In the following example, the diffusion process is stopped after 10 steps even though `num_inference_steps` is set to 50.
```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.enable_model_cpu_offload()
num_inference_steps = 50

def interrupt_callback(pipe, i, t, callback_kwargs):
    stop_idx = 10
    if i == stop_idx:
        pipe._interrupt = True

    return callback_kwargs

pipe(
    "A photo of a cat",
    num_inference_steps=num_inference_steps,
    callback_on_step_end=interrupt_callback,
)
```
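
Because the stopping condition lives entirely inside the callback, it doesn't have to be a fixed step index. The sketch below shows one way to implement the custom stopping logic mentioned above, interrupting once a wall-clock budget is exceeded; the `time_limit` value and the `timeout_callback` name are illustrative, not part of this PR.

```python
import time

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.enable_model_cpu_offload()

time_limit = 30.0  # illustrative wall-clock budget in seconds
start = time.monotonic()

def timeout_callback(pipe, i, t, callback_kwargs):
    # interrupt once the time budget is exhausted, regardless of step count
    if time.monotonic() - start > time_limit:
        pipe._interrupt = True

    return callback_kwargs

pipe(
    "A photo of a cat",
    num_inference_steps=50,
    callback_on_step_end=timeout_callback,
)
```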

View File

@@ -768,6 +768,10 @@ class StableDiffusionPipeline(
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def interrupt(self):
        return self._interrupt

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
@@ -909,6 +913,7 @@ class StableDiffusionPipeline(
        self._guidance_rescale = guidance_rescale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
@@ -986,6 +991,9 @@ class StableDiffusionPipeline(
        self._num_timesteps = len(timesteps)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
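
The same few lines recur in each pipeline file below: a read-only `interrupt` property, an `_interrupt` flag reset at the top of `__call__`, and a per-step check that skips the remaining denoising work. A rough standalone sketch of that control flow, with a hypothetical `ToyPipeline` standing in for the real classes:

```python
class ToyPipeline:
    """Hypothetical stand-in illustrating the interrupt pattern in this diff."""

    def __init__(self):
        self._interrupt = False

    @property
    def interrupt(self):
        return self._interrupt

    def __call__(self, num_inference_steps, callback_on_step_end=None):
        self._interrupt = False  # reset on every call, as the diff does
        for i, t in enumerate(range(num_inference_steps)):
            if self.interrupt:
                # `continue` (rather than `break`) skips the body of this and
                # every later iteration, since the flag stays set once raised
                continue
            # ... the denoising step would run here ...
            if callback_on_step_end is not None:
                callback_on_step_end(self, i, t, {})
```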

View File

@@ -832,6 +832,10 @@ class StableDiffusionImg2ImgPipeline(
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def interrupt(self):
        return self._interrupt

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
@@ -963,6 +967,7 @@ class StableDiffusionImg2ImgPipeline(
        self._guidance_scale = guidance_scale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
@@ -1041,6 +1046,9 @@ class StableDiffusionImg2ImgPipeline(
        self._num_timesteps = len(timesteps)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

View File

@@ -958,6 +958,10 @@ class StableDiffusionInpaintPipeline(
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def interrupt(self):
        return self._interrupt

    @torch.no_grad()
    def __call__(
        self,
@@ -1144,6 +1148,7 @@ class StableDiffusionInpaintPipeline(
        self._guidance_scale = guidance_scale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
@@ -1288,6 +1293,9 @@ class StableDiffusionInpaintPipeline(
        self._num_timesteps = len(timesteps)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents

View File

@@ -849,6 +849,10 @@ class StableDiffusionXLPipeline(
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def interrupt(self):
        return self._interrupt

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
@@ -1067,6 +1071,7 @@ class StableDiffusionXLPipeline(
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs
        self._denoising_end = denoising_end
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
@@ -1196,6 +1201,9 @@ class StableDiffusionXLPipeline(
        self._num_timesteps = len(timesteps)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents

View File

@@ -990,6 +990,10 @@ class StableDiffusionXLImg2ImgPipeline(
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def interrupt(self):
        return self._interrupt

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
@@ -1221,6 +1225,7 @@ class StableDiffusionXLImg2ImgPipeline(
        self._cross_attention_kwargs = cross_attention_kwargs
        self._denoising_end = denoising_end
        self._denoising_start = denoising_start
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
@@ -1376,6 +1381,9 @@ class StableDiffusionXLImg2ImgPipeline(
        self._num_timesteps = len(timesteps)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents

View File

@@ -1210,6 +1210,10 @@ class StableDiffusionXLInpaintPipeline(
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def interrupt(self):
        return self._interrupt

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
@@ -1462,6 +1466,7 @@ class StableDiffusionXLInpaintPipeline(
        self._cross_attention_kwargs = cross_attention_kwargs
        self._denoising_end = denoising_end
        self._denoising_start = denoising_start
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
@@ -1684,6 +1689,8 @@ class StableDiffusionXLInpaintPipeline(
        self._num_timesteps = len(timesteps)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents

View File

@@ -692,6 +692,58 @@ class StableDiffusionPipelineFastTests(
            original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
        ), "Original outputs should match when fused QKV projections are disabled."

    def test_pipeline_interrupt(self):
        components = self.get_dummy_components()
        sd_pipe = StableDiffusionPipeline(**components)
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        prompt = "hey"
        num_inference_steps = 3

        # store intermediate latents from the generation process
        class PipelineState:
            def __init__(self):
                self.state = []

            def apply(self, pipe, i, t, callback_kwargs):
                self.state.append(callback_kwargs["latents"])
                return callback_kwargs

        pipe_state = PipelineState()
        sd_pipe(
            prompt,
            num_inference_steps=num_inference_steps,
            output_type="np",
            generator=torch.Generator("cpu").manual_seed(0),
            callback_on_step_end=pipe_state.apply,
        ).images

        # interrupt generation at step index
        interrupt_step_idx = 1

        def callback_on_step_end(pipe, i, t, callback_kwargs):
            if i == interrupt_step_idx:
                pipe._interrupt = True
            return callback_kwargs

        output_interrupted = sd_pipe(
            prompt,
            num_inference_steps=num_inference_steps,
            output_type="latent",
            generator=torch.Generator("cpu").manual_seed(0),
            callback_on_step_end=callback_on_step_end,
        ).images

        # fetch intermediate latents at the interrupted step
        # from the completed generation process
        intermediate_latent = pipe_state.state[interrupt_step_idx]

        # compare the intermediate latent to the output of the interrupted process
        # they should be the same
        assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4)


@slow
@require_torch_gpu

View File

@@ -320,6 +320,62 @@ class StableDiffusionImg2ImgPipelineFastTests(
    def test_float16_inference(self):
        super().test_float16_inference(expected_max_diff=5e-1)

    def test_pipeline_interrupt(self):
        components = self.get_dummy_components()
        sd_pipe = StableDiffusionImg2ImgPipeline(**components)
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)

        prompt = "hey"
        num_inference_steps = 3

        # store intermediate latents from the generation process
        class PipelineState:
            def __init__(self):
                self.state = []

            def apply(self, pipe, i, t, callback_kwargs):
                self.state.append(callback_kwargs["latents"])
                return callback_kwargs

        pipe_state = PipelineState()
        sd_pipe(
            prompt,
            image=inputs["image"],
            num_inference_steps=num_inference_steps,
            output_type="np",
            generator=torch.Generator("cpu").manual_seed(0),
            callback_on_step_end=pipe_state.apply,
        ).images

        # interrupt generation at step index
        interrupt_step_idx = 1

        def callback_on_step_end(pipe, i, t, callback_kwargs):
            if i == interrupt_step_idx:
                pipe._interrupt = True
            return callback_kwargs

        output_interrupted = sd_pipe(
            prompt,
            image=inputs["image"],
            num_inference_steps=num_inference_steps,
            output_type="latent",
            generator=torch.Generator("cpu").manual_seed(0),
            callback_on_step_end=callback_on_step_end,
        ).images

        # fetch intermediate latents at the interrupted step
        # from the completed generation process
        intermediate_latent = pipe_state.state[interrupt_step_idx]

        # compare the intermediate latent to the output of the interrupted process
        # they should be the same
        assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4)


@slow
@require_torch_gpu

View File

@@ -319,6 +319,64 @@ class StableDiffusionInpaintPipelineFastTests(
        out_1 = sd_pipe(**inputs).images
        assert np.abs(out_0 - out_1).max() < 1e-2

    def test_pipeline_interrupt(self):
        components = self.get_dummy_components()
        sd_pipe = StableDiffusionInpaintPipeline(**components)
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)

        prompt = "hey"
        num_inference_steps = 3

        # store intermediate latents from the generation process
        class PipelineState:
            def __init__(self):
                self.state = []

            def apply(self, pipe, i, t, callback_kwargs):
                self.state.append(callback_kwargs["latents"])
                return callback_kwargs

        pipe_state = PipelineState()
        sd_pipe(
            prompt,
            image=inputs["image"],
            mask_image=inputs["mask_image"],
            num_inference_steps=num_inference_steps,
            output_type="np",
            generator=torch.Generator("cpu").manual_seed(0),
            callback_on_step_end=pipe_state.apply,
        ).images

        # interrupt generation at step index
        interrupt_step_idx = 1

        def callback_on_step_end(pipe, i, t, callback_kwargs):
            if i == interrupt_step_idx:
                pipe._interrupt = True
            return callback_kwargs

        output_interrupted = sd_pipe(
            prompt,
            image=inputs["image"],
            mask_image=inputs["mask_image"],
            num_inference_steps=num_inference_steps,
            output_type="latent",
            generator=torch.Generator("cpu").manual_seed(0),
            callback_on_step_end=callback_on_step_end,
        ).images

        # fetch intermediate latents at the interrupted step
        # from the completed generation process
        intermediate_latent = pipe_state.state[interrupt_step_idx]

        # compare the intermediate latent to the output of the interrupted process
        # they should be the same
        assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4)


class StableDiffusionSimpleInpaintPipelineFastTests(StableDiffusionInpaintPipelineFastTests):
    pipeline_class = StableDiffusionInpaintPipeline

View File

@@ -969,6 +969,58 @@ class StableDiffusionXLPipelineFastTests(
            original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
        ), "Original outputs should match when fused QKV projections are disabled."

    def test_pipeline_interrupt(self):
        components = self.get_dummy_components()
        sd_pipe = StableDiffusionXLPipeline(**components)
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        prompt = "hey"
        num_inference_steps = 3

        # store intermediate latents from the generation process
        class PipelineState:
            def __init__(self):
                self.state = []

            def apply(self, pipe, i, t, callback_kwargs):
                self.state.append(callback_kwargs["latents"])
                return callback_kwargs

        pipe_state = PipelineState()
        sd_pipe(
            prompt,
            num_inference_steps=num_inference_steps,
            output_type="np",
            generator=torch.Generator("cpu").manual_seed(0),
            callback_on_step_end=pipe_state.apply,
        ).images

        # interrupt generation at step index
        interrupt_step_idx = 1

        def callback_on_step_end(pipe, i, t, callback_kwargs):
            if i == interrupt_step_idx:
                pipe._interrupt = True
            return callback_kwargs

        output_interrupted = sd_pipe(
            prompt,
            num_inference_steps=num_inference_steps,
            output_type="latent",
            generator=torch.Generator("cpu").manual_seed(0),
            callback_on_step_end=callback_on_step_end,
        ).images

        # fetch intermediate latents at the interrupted step
        # from the completed generation process
        intermediate_latent = pipe_state.state[interrupt_step_idx]

        # compare the intermediate latent to the output of the interrupted process
        # they should be the same
        assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4)


@slow
class StableDiffusionXLPipelineIntegrationTests(unittest.TestCase):

View File

@@ -439,6 +439,64 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
            > 1e-4
        )

    def test_pipeline_interrupt(self):
        components = self.get_dummy_components()
        sd_pipe = StableDiffusionXLImg2ImgPipeline(**components)
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)

        prompt = "hey"
        num_inference_steps = 5

        # store intermediate latents from the generation process
        class PipelineState:
            def __init__(self):
                self.state = []

            def apply(self, pipe, i, t, callback_kwargs):
                self.state.append(callback_kwargs["latents"])
                return callback_kwargs

        pipe_state = PipelineState()
        sd_pipe(
            prompt,
            image=inputs["image"],
            strength=0.8,
            num_inference_steps=num_inference_steps,
            output_type="np",
            generator=torch.Generator("cpu").manual_seed(0),
            callback_on_step_end=pipe_state.apply,
        ).images

        # interrupt generation at step index
        interrupt_step_idx = 1

        def callback_on_step_end(pipe, i, t, callback_kwargs):
            if i == interrupt_step_idx:
                pipe._interrupt = True
            return callback_kwargs

        output_interrupted = sd_pipe(
            prompt,
            image=inputs["image"],
            strength=0.8,
            num_inference_steps=num_inference_steps,
            output_type="latent",
            generator=torch.Generator("cpu").manual_seed(0),
            callback_on_step_end=callback_on_step_end,
        ).images

        # fetch intermediate latents at the interrupted step
        # from the completed generation process
        intermediate_latent = pipe_state.state[interrupt_step_idx]

        # compare the intermediate latent to the output of the interrupted process
        # they should be the same
        assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4)


class StableDiffusionXLImg2ImgRefinerOnlyPipelineFastTests(
    PipelineLatentTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase

View File

@@ -746,3 +746,63 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
        image_slice1 = images[0, -3:, -3:, -1]
        image_slice2 = images[1, -3:, -3:, -1]
        assert np.abs(image_slice1.flatten() - image_slice2.flatten()).max() > 1e-2

    def test_pipeline_interrupt(self):
        components = self.get_dummy_components()
        sd_pipe = StableDiffusionXLInpaintPipeline(**components)
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)

        prompt = "hey"
        num_inference_steps = 5

        # store intermediate latents from the generation process
        class PipelineState:
            def __init__(self):
                self.state = []

            def apply(self, pipe, i, t, callback_kwargs):
                self.state.append(callback_kwargs["latents"])
                return callback_kwargs

        pipe_state = PipelineState()
        sd_pipe(
            prompt,
            image=inputs["image"],
            mask_image=inputs["mask_image"],
            strength=0.8,
            num_inference_steps=num_inference_steps,
            output_type="np",
            generator=torch.Generator("cpu").manual_seed(0),
            callback_on_step_end=pipe_state.apply,
        ).images

        # interrupt generation at step index
        interrupt_step_idx = 1

        def callback_on_step_end(pipe, i, t, callback_kwargs):
            if i == interrupt_step_idx:
                pipe._interrupt = True
            return callback_kwargs

        output_interrupted = sd_pipe(
            prompt,
            image=inputs["image"],
            mask_image=inputs["mask_image"],
            strength=0.8,
            num_inference_steps=num_inference_steps,
            output_type="latent",
            generator=torch.Generator("cpu").manual_seed(0),
            callback_on_step_end=callback_on_step_end,
        ).images

        # fetch intermediate latents at the interrupted step
        # from the completed generation process
        intermediate_latent = pipe_state.state[interrupt_step_idx]

        # compare the intermediate latent to the output of the interrupted process
        # they should be the same
        assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4)