Compare commits

..

3 Commits

Author SHA1 Message Date
Dhruv Nair
5095916699 update 2023-12-26 15:36:38 +00:00
Dhruv Nair
40ea8ce23e update 2023-12-26 14:45:13 +00:00
Dhruv Nair
76c1007ae0 update 2023-12-26 14:42:14 +00:00
14 changed files with 40 additions and 438 deletions

View File

@@ -38,16 +38,21 @@ The following example demonstrates how to use a *MotionAdapter* checkpoint with
```python
import torch
from diffusers import MotionAdapter, AnimateDiffPipeline, DDIMScheduler
from diffusers import AnimateDiffPipeline, DDIMScheduler, MotionAdapter
from diffusers.utils import export_to_gif
# Load the motion adapter
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
# load SD 1.5 based finetuned model
model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
pipe = AnimateDiffPipeline.from_pretrained(model_id, motion_adapter=adapter)
pipe = AnimateDiffPipeline.from_pretrained(model_id, motion_adapter=adapter, torch_dtype=torch.float16)
scheduler = DDIMScheduler.from_pretrained(
model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1
model_id,
subfolder="scheduler",
clip_sample=False,
timestep_spacing="linspace",
beta_schedule="linear",
steps_offset=1,
)
pipe.scheduler = scheduler
@@ -70,6 +75,7 @@ output = pipe(
)
frames = output.frames[0]
export_to_gif(frames, "animation.gif")
```
Here are some sample outputs:
@@ -88,7 +94,7 @@ Here are some sample outputs:
<Tip>
AnimateDiff tends to work better with finetuned Stable Diffusion models. If you plan on using a scheduler that can clip samples, make sure to disable it by setting `clip_sample=False` in the scheduler as this can also have an adverse effect on generated samples.
AnimateDiff tends to work better with finetuned Stable Diffusion models. If you plan on using a scheduler that can clip samples, make sure to disable it by setting `clip_sample=False` in the scheduler as this can also have an adverse effect on generated samples. Additionally, the AnimateDiff checkpoints can be sensitive to the beta schedule of the scheduler. We recommend setting this to `linear`.
</Tip>
@@ -98,18 +104,25 @@ Motion LoRAs are a collection of LoRAs that work with the `guoyww/animatediff-mo
```python
import torch
from diffusers import MotionAdapter, AnimateDiffPipeline, DDIMScheduler
from diffusers import AnimateDiffPipeline, DDIMScheduler, MotionAdapter
from diffusers.utils import export_to_gif
# Load the motion adapter
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
# load SD 1.5 based finetuned model
model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
pipe = AnimateDiffPipeline.from_pretrained(model_id, motion_adapter=adapter)
pipe.load_lora_weights("guoyww/animatediff-motion-lora-zoom-out", adapter_name="zoom-out")
pipe = AnimateDiffPipeline.from_pretrained(model_id, motion_adapter=adapter, torch_dtype=torch.float16)
pipe.load_lora_weights(
"guoyww/animatediff-motion-lora-zoom-out", adapter_name="zoom-out"
)
scheduler = DDIMScheduler.from_pretrained(
model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1
model_id,
subfolder="scheduler",
clip_sample=False,
beta_schedule="linear",
timestep_spacing="linspace",
steps_offset=1,
)
pipe.scheduler = scheduler
@@ -132,6 +145,7 @@ output = pipe(
)
frames = output.frames[0]
export_to_gif(frames, "animation.gif")
```
<table>
@@ -160,21 +174,30 @@ Then you can use the following code to combine Motion LoRAs.
```python
import torch
from diffusers import MotionAdapter, AnimateDiffPipeline, DDIMScheduler
from diffusers import AnimateDiffPipeline, DDIMScheduler, MotionAdapter
from diffusers.utils import export_to_gif
# Load the motion adapter
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
# load SD 1.5 based finetuned model
model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
pipe = AnimateDiffPipeline.from_pretrained(model_id, motion_adapter=adapter)
pipe = AnimateDiffPipeline.from_pretrained(model_id, motion_adapter=adapter, torch_dtype=torch.float16)
pipe.load_lora_weights("diffusers/animatediff-motion-lora-zoom-out", adapter_name="zoom-out")
pipe.load_lora_weights("diffusers/animatediff-motion-lora-pan-left", adapter_name="pan-left")
pipe.load_lora_weights(
"diffusers/animatediff-motion-lora-zoom-out", adapter_name="zoom-out",
)
pipe.load_lora_weights(
"diffusers/animatediff-motion-lora-pan-left", adapter_name="pan-left",
)
pipe.set_adapters(["zoom-out", "pan-left"], adapter_weights=[1.0, 1.0])
scheduler = DDIMScheduler.from_pretrained(
model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1
model_id,
subfolder="scheduler",
clip_sample=False,
timestep_spacing="linspace",
beta_schedule="linear",
steps_offset=1,
)
pipe.scheduler = scheduler
@@ -197,6 +220,7 @@ output = pipe(
)
frames = output.frames[0]
export_to_gif(frames, "animation.gif")
```
<table>

View File

@@ -63,42 +63,3 @@ With callbacks, you can implement features such as dynamic CFG without having to
🤗 Diffusers currently only supports `callback_on_step_end`, but feel free to open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) if you have a cool use-case and require a callback function with a different execution point!
</Tip>
## Using Callbacks to interrupt the Diffusion Process
The following Pipelines support interrupting the diffusion process via callback
- [StableDiffusionPipeline](../api/pipelines/stable_diffusion/overview.md)
- [StableDiffusionImg2ImgPipeline](..api/pipelines/stable_diffusion/img2img.md)
- [StableDiffusionInpaintPipeline](..api/pipelines/stable_diffusion/inpaint.md)
- [StableDiffusionXLPipeline](../api/pipelines/stable_diffusion/stable_diffusion_xl.md)
- [StableDiffusionXLImg2ImgPipeline](../api/pipelines/stable_diffusion/stable_diffusion_xl.md)
- [StableDiffusionXLInpaintPipeline](../api/pipelines/stable_diffusion/stable_diffusion_xl.md)
Interrupting the diffusion process is particularly useful when building UIs that work with Diffusers because it allows users to stop the generation process if they're unhappy with the intermediate results. You can incorporate this into your pipeline with a callback.
This callback function should take the following arguments: `pipe`, `i`, `t`, and `callback_kwargs` (this must be returned). Set the pipeline's `_interrupt` attribute to `True` to stop the diffusion process after a certain number of steps. You are also free to implement your own custom stopping logic inside the callback.
In this example, the diffusion process is stopped after 10 steps even though `num_inference_steps` is set to 50.
```python
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.enable_model_cpu_offload()
num_inference_steps = 50
def interrupt_callback(pipe, i, t, callback_kwargs):
stop_idx = 10
if i == stop_idx:
pipe._interrupt = True
return callback_kwargs
pipe(
"A photo of a cat",
num_inference_steps=num_inference_steps,
callback_on_step_end=interrupt_callback,
)
```

View File

@@ -768,10 +768,6 @@ class StableDiffusionPipeline(
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -913,7 +909,6 @@ class StableDiffusionPipeline(
self._guidance_rescale = guidance_rescale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
@@ -991,9 +986,6 @@ class StableDiffusionPipeline(
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

View File

@@ -832,10 +832,6 @@ class StableDiffusionImg2ImgPipeline(
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -967,7 +963,6 @@ class StableDiffusionImg2ImgPipeline(
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
@@ -1046,9 +1041,6 @@ class StableDiffusionImg2ImgPipeline(
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

View File

@@ -958,10 +958,6 @@ class StableDiffusionInpaintPipeline(
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
def __call__(
self,
@@ -1148,7 +1144,6 @@ class StableDiffusionInpaintPipeline(
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
@@ -1293,9 +1288,6 @@ class StableDiffusionInpaintPipeline(
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents

View File

@@ -849,10 +849,6 @@ class StableDiffusionXLPipeline(
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -1071,7 +1067,6 @@ class StableDiffusionXLPipeline(
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
self._denoising_end = denoising_end
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
@@ -1201,9 +1196,6 @@ class StableDiffusionXLPipeline(
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents

View File

@@ -990,10 +990,6 @@ class StableDiffusionXLImg2ImgPipeline(
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -1225,7 +1221,6 @@ class StableDiffusionXLImg2ImgPipeline(
self._cross_attention_kwargs = cross_attention_kwargs
self._denoising_end = denoising_end
self._denoising_start = denoising_start
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
@@ -1381,9 +1376,6 @@ class StableDiffusionXLImg2ImgPipeline(
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents

View File

@@ -1210,10 +1210,6 @@ class StableDiffusionXLInpaintPipeline(
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -1466,7 +1462,6 @@ class StableDiffusionXLInpaintPipeline(
self._cross_attention_kwargs = cross_attention_kwargs
self._denoising_end = denoising_end
self._denoising_start = denoising_start
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
@@ -1689,8 +1684,6 @@ class StableDiffusionXLInpaintPipeline(
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents

View File

@@ -692,58 +692,6 @@ class StableDiffusionPipelineFastTests(
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
), "Original outputs should match when fused QKV projections are disabled."
def test_pipeline_interrupt(self):
components = self.get_dummy_components()
sd_pipe = StableDiffusionPipeline(**components)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
prompt = "hey"
num_inference_steps = 3
# store intermediate latents from the generation process
class PipelineState:
def __init__(self):
self.state = []
def apply(self, pipe, i, t, callback_kwargs):
self.state.append(callback_kwargs["latents"])
return callback_kwargs
pipe_state = PipelineState()
sd_pipe(
prompt,
num_inference_steps=num_inference_steps,
output_type="np",
generator=torch.Generator("cpu").manual_seed(0),
callback_on_step_end=pipe_state.apply,
).images
# interrupt generation at step index
interrupt_step_idx = 1
def callback_on_step_end(pipe, i, t, callback_kwargs):
if i == interrupt_step_idx:
pipe._interrupt = True
return callback_kwargs
output_interrupted = sd_pipe(
prompt,
num_inference_steps=num_inference_steps,
output_type="latent",
generator=torch.Generator("cpu").manual_seed(0),
callback_on_step_end=callback_on_step_end,
).images
# fetch intermediate latents at the interrupted step
# from the completed generation process
intermediate_latent = pipe_state.state[interrupt_step_idx]
# compare the intermediate latent to the output of the interrupted process
# they should be the same
assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4)
@slow
@require_torch_gpu

View File

@@ -320,62 +320,6 @@ class StableDiffusionImg2ImgPipelineFastTests(
def test_float16_inference(self):
super().test_float16_inference(expected_max_diff=5e-1)
def test_pipeline_interrupt(self):
components = self.get_dummy_components()
sd_pipe = StableDiffusionImg2ImgPipeline(**components)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
prompt = "hey"
num_inference_steps = 3
# store intermediate latents from the generation process
class PipelineState:
def __init__(self):
self.state = []
def apply(self, pipe, i, t, callback_kwargs):
self.state.append(callback_kwargs["latents"])
return callback_kwargs
pipe_state = PipelineState()
sd_pipe(
prompt,
image=inputs["image"],
num_inference_steps=num_inference_steps,
output_type="np",
generator=torch.Generator("cpu").manual_seed(0),
callback_on_step_end=pipe_state.apply,
).images
# interrupt generation at step index
interrupt_step_idx = 1
def callback_on_step_end(pipe, i, t, callback_kwargs):
if i == interrupt_step_idx:
pipe._interrupt = True
return callback_kwargs
output_interrupted = sd_pipe(
prompt,
image=inputs["image"],
num_inference_steps=num_inference_steps,
output_type="latent",
generator=torch.Generator("cpu").manual_seed(0),
callback_on_step_end=callback_on_step_end,
).images
# fetch intermediate latents at the interrupted step
# from the completed generation process
intermediate_latent = pipe_state.state[interrupt_step_idx]
# compare the intermediate latent to the output of the interrupted process
# they should be the same
assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4)
@slow
@require_torch_gpu

View File

@@ -319,64 +319,6 @@ class StableDiffusionInpaintPipelineFastTests(
out_1 = sd_pipe(**inputs).images
assert np.abs(out_0 - out_1).max() < 1e-2
def test_pipeline_interrupt(self):
components = self.get_dummy_components()
sd_pipe = StableDiffusionInpaintPipeline(**components)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
prompt = "hey"
num_inference_steps = 3
# store intermediate latents from the generation process
class PipelineState:
def __init__(self):
self.state = []
def apply(self, pipe, i, t, callback_kwargs):
self.state.append(callback_kwargs["latents"])
return callback_kwargs
pipe_state = PipelineState()
sd_pipe(
prompt,
image=inputs["image"],
mask_image=inputs["mask_image"],
num_inference_steps=num_inference_steps,
output_type="np",
generator=torch.Generator("cpu").manual_seed(0),
callback_on_step_end=pipe_state.apply,
).images
# interrupt generation at step index
interrupt_step_idx = 1
def callback_on_step_end(pipe, i, t, callback_kwargs):
if i == interrupt_step_idx:
pipe._interrupt = True
return callback_kwargs
output_interrupted = sd_pipe(
prompt,
image=inputs["image"],
mask_image=inputs["mask_image"],
num_inference_steps=num_inference_steps,
output_type="latent",
generator=torch.Generator("cpu").manual_seed(0),
callback_on_step_end=callback_on_step_end,
).images
# fetch intermediate latents at the interrupted step
# from the completed generation process
intermediate_latent = pipe_state.state[interrupt_step_idx]
# compare the intermediate latent to the output of the interrupted process
# they should be the same
assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4)
class StableDiffusionSimpleInpaintPipelineFastTests(StableDiffusionInpaintPipelineFastTests):
pipeline_class = StableDiffusionInpaintPipeline

View File

@@ -969,58 +969,6 @@ class StableDiffusionXLPipelineFastTests(
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
), "Original outputs should match when fused QKV projections are disabled."
def test_pipeline_interrupt(self):
components = self.get_dummy_components()
sd_pipe = StableDiffusionXLPipeline(**components)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
prompt = "hey"
num_inference_steps = 3
# store intermediate latents from the generation process
class PipelineState:
def __init__(self):
self.state = []
def apply(self, pipe, i, t, callback_kwargs):
self.state.append(callback_kwargs["latents"])
return callback_kwargs
pipe_state = PipelineState()
sd_pipe(
prompt,
num_inference_steps=num_inference_steps,
output_type="np",
generator=torch.Generator("cpu").manual_seed(0),
callback_on_step_end=pipe_state.apply,
).images
# interrupt generation at step index
interrupt_step_idx = 1
def callback_on_step_end(pipe, i, t, callback_kwargs):
if i == interrupt_step_idx:
pipe._interrupt = True
return callback_kwargs
output_interrupted = sd_pipe(
prompt,
num_inference_steps=num_inference_steps,
output_type="latent",
generator=torch.Generator("cpu").manual_seed(0),
callback_on_step_end=callback_on_step_end,
).images
# fetch intermediate latents at the interrupted step
# from the completed generation process
intermediate_latent = pipe_state.state[interrupt_step_idx]
# compare the intermediate latent to the output of the interrupted process
# they should be the same
assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4)
@slow
class StableDiffusionXLPipelineIntegrationTests(unittest.TestCase):

View File

@@ -439,64 +439,6 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
> 1e-4
)
def test_pipeline_interrupt(self):
components = self.get_dummy_components()
sd_pipe = StableDiffusionXLImg2ImgPipeline(**components)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
prompt = "hey"
num_inference_steps = 5
# store intermediate latents from the generation process
class PipelineState:
def __init__(self):
self.state = []
def apply(self, pipe, i, t, callback_kwargs):
self.state.append(callback_kwargs["latents"])
return callback_kwargs
pipe_state = PipelineState()
sd_pipe(
prompt,
image=inputs["image"],
strength=0.8,
num_inference_steps=num_inference_steps,
output_type="np",
generator=torch.Generator("cpu").manual_seed(0),
callback_on_step_end=pipe_state.apply,
).images
# interrupt generation at step index
interrupt_step_idx = 1
def callback_on_step_end(pipe, i, t, callback_kwargs):
if i == interrupt_step_idx:
pipe._interrupt = True
return callback_kwargs
output_interrupted = sd_pipe(
prompt,
image=inputs["image"],
strength=0.8,
num_inference_steps=num_inference_steps,
output_type="latent",
generator=torch.Generator("cpu").manual_seed(0),
callback_on_step_end=callback_on_step_end,
).images
# fetch intermediate latents at the interrupted step
# from the completed generation process
intermediate_latent = pipe_state.state[interrupt_step_idx]
# compare the intermediate latent to the output of the interrupted process
# they should be the same
assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4)
class StableDiffusionXLImg2ImgRefinerOnlyPipelineFastTests(
PipelineLatentTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase

View File

@@ -746,63 +746,3 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
image_slice1 = images[0, -3:, -3:, -1]
image_slice2 = images[1, -3:, -3:, -1]
assert np.abs(image_slice1.flatten() - image_slice2.flatten()).max() > 1e-2
def test_pipeline_interrupt(self):
components = self.get_dummy_components()
sd_pipe = StableDiffusionXLInpaintPipeline(**components)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
prompt = "hey"
num_inference_steps = 5
# store intermediate latents from the generation process
class PipelineState:
def __init__(self):
self.state = []
def apply(self, pipe, i, t, callback_kwargs):
self.state.append(callback_kwargs["latents"])
return callback_kwargs
pipe_state = PipelineState()
sd_pipe(
prompt,
image=inputs["image"],
mask_image=inputs["mask_image"],
strength=0.8,
num_inference_steps=num_inference_steps,
output_type="np",
generator=torch.Generator("cpu").manual_seed(0),
callback_on_step_end=pipe_state.apply,
).images
# interrupt generation at step index
interrupt_step_idx = 1
def callback_on_step_end(pipe, i, t, callback_kwargs):
if i == interrupt_step_idx:
pipe._interrupt = True
return callback_kwargs
output_interrupted = sd_pipe(
prompt,
image=inputs["image"],
mask_image=inputs["mask_image"],
strength=0.8,
num_inference_steps=num_inference_steps,
output_type="latent",
generator=torch.Generator("cpu").manual_seed(0),
callback_on_step_end=callback_on_step_end,
).images
# fetch intermediate latents at the interrupted step
# from the completed generation process
intermediate_latent = pipe_state.state[interrupt_step_idx]
# compare the intermediate latent to the output of the interrupted process
# they should be the same
assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4)