Compare commits

...

163 Commits

Author SHA1 Message Date
sayakpaul
601c506918 up 2023-11-06 12:56:06 +05:30
sayakpaul
f8a5e172cf up 2023-11-06 12:49:57 +05:30
sayakpaul
47e9219450 pop 2023-11-06 12:34:47 +05:30
sayakpaul
04d83c209d byte tensor 2023-11-06 12:20:38 +05:30
sayakpaul
0d81b2dab4 debug 2023-11-06 12:18:22 +05:30
sayakpaul
e67ddf8d13 debug 2023-11-06 12:17:09 +05:30
sayakpaul
cdbbc7d5b7 fix 2023-11-06 12:10:02 +05:30
sayakpaul
81fb265a08 also serialize the states. 2023-11-06 12:09:09 +05:30
sayakpaul
21c8c433a3 fix: scheduler assignment. 2023-11-06 11:58:35 +05:30
sayakpaul
3479e5311d workflow_copy 2023-11-06 11:57:22 +05:30
sayakpaul
f8cad5dc4a debug 2023-11-06 11:50:24 +05:30
sayakpaul
d75f8a537c debug 2023-11-06 11:36:01 +05:30
sayakpaul
4b94889652 generator 2023-11-06 11:33:00 +05:30
sayakpaul
35a6538343 debug generator 2023-11-06 11:30:25 +05:30
sayakpaul
1b65ff770c disable custom pop 2023-11-06 11:25:30 +05:30
sayakpaul
857c65bb56 make Workflow more lightweight. 2023-11-06 11:19:28 +05:30
sayakpaul
b5752ec4bf resolve conflicts 2023-11-06 11:04:34 +05:30
sayakpaul
157405436b update doc. 2023-11-02 22:26:04 +05:30
sayakpaul
d69f3079a8 Merge branch 'main' into feat/workflows 2023-11-02 22:20:45 +05:30
sayakpaul
407e669fca better error message 2023-11-02 22:16:42 +05:30
sayakpaul
5e00fcc153 generator 2023-11-02 13:20:17 +05:30
sayakpaul
1dd8cf5abe generator 2023-11-02 13:03:32 +05:30
sayakpaul
10739166e2 fix: tests. 2023-11-02 11:55:51 +05:30
sayakpaul
c1d8b882ee use return_worflow at the end. 2023-11-02 10:59:51 +05:30
sayakpaul
6697144a7d fix: toc 2023-11-02 10:58:29 +05:30
sayakpaul
6af69f5639 Merge branch 'main' into feat/workflows 2023-11-02 10:40:28 +05:30
sayakpaul
c28ea5e6c6 separate api page 2023-11-02 10:39:12 +05:30
sayakpaul
cb8902394d clean docs. 2023-11-02 10:35:41 +05:30
Sayak Paul
5c8d5df564 Apply suggestions from code review
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2023-11-02 10:31:30 +05:30
sayakpaul
474df7a6b1 clean 2023-11-02 10:27:28 +05:30
sayakpaul
047cede64a clean up tests 2023-10-29 17:28:03 +05:30
sayakpaul
ce9e17c547 handle list of generators. 2023-10-29 17:20:48 +05:30
sayakpaul
abdb4ffbc6 more fixes. 2023-10-29 17:11:39 +05:30
sayakpaul
87174fc208 load from a workflow directly 2023-10-29 17:09:36 +05:30
sayakpaul
425e75bc79 remove prompt_embeds. 2023-10-29 17:06:28 +05:30
sayakpaul
f14cf8f4aa boom 2023-10-29 13:47:05 +05:30
sayakpaul
2f7725edf4 style 2023-10-29 11:21:40 +05:30
sayakpaul
a11ce647de fix 2023-10-29 11:21:19 +05:30
sayakpaul
6f65f3ad3e fix 2023-10-29 11:07:02 +05:30
sayakpaul
cbc835b870 correct output. 2023-10-29 10:59:31 +05:30
sayakpaul
8550a86a17 done 2023-10-29 10:57:52 +05:30
sayakpaul
477cc9a82a remove print 2023-10-29 10:56:23 +05:30
sayakpaul
28c8e93179 fix 2023-10-29 10:48:24 +05:30
sayakpaul
3f90e07228 fix: 2023-10-29 10:46:46 +05:30
sayakpaul
67f7757048 fix 2023-10-29 10:39:40 +05:30
sayakpaul
0e40d6ffd6 fix 2023-10-29 10:37:52 +05:30
sayakpaul
3eab48f883 stringify device. 2023-10-29 10:31:58 +05:30
sayakpaul
b3a675288d corrections to doc. 2023-10-29 10:31:13 +05:30
sayakpaul
9099f51c5e fix? 2023-10-29 10:21:34 +05:30
sayakpaul
5dcd8c541e debug 2023-10-29 10:20:20 +05:30
sayakpaul
ac73c86610 debug 2023-10-29 10:19:58 +05:30
sayakpaul
9836b61fa8 debug 2023-10-29 10:19:09 +05:30
sayakpaul
1829b9485c debug 2023-10-29 10:18:07 +05:30
sayakpaul
6e514c2b6a debug 2023-10-29 10:13:37 +05:30
sayakpaul
b5888b4704 typo 2023-10-29 10:11:43 +05:30
sayakpaul
003220ba36 typo 2023-10-29 10:10:14 +05:30
sayakpaul
f221631d3c typo 2023-10-29 10:08:55 +05:30
sayakpaul
adcaba0a23 typo 2023-10-29 10:08:14 +05:30
sayakpaul
bd1f78e6cd typo 2023-10-29 10:05:19 +05:30
sayakpaul
ecfa79b673 fix: generator 2023-10-29 10:04:48 +05:30
sayakpaul
ab1d58872b debug generator 2023-10-29 09:56:57 +05:30
sayakpaul
020b4a4ad7 fix: generator update 2023-10-29 09:38:47 +05:30
sayakpaul
e510c3d4d5 fix import 2023-10-29 09:35:41 +05:30
sayakpaul
f0418e8896 serialize scheduler info too. 2023-10-29 09:27:27 +05:30
sayakpaul
c4eebd9c1a randomize seed. 2023-10-29 09:05:08 +05:30
sayakpaul
84de851116 change the path 2023-10-28 10:11:55 +05:30
sayakpaul
eae28adf40 some more notes 2023-10-28 10:08:54 +05:30
Sayak Paul
450198061e Apply suggestions from code review
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2023-10-28 10:04:17 +05:30
sayakpaul
be3ff851c6 fix: test 2023-10-27 18:48:07 +05:30
sayakpaul
c3792ba3e0 fix: 2023-10-27 15:52:44 +05:30
sayakpaul
7933f2ac18 fix: 2023-10-27 15:44:55 +05:30
sayakpaul
1b7a7c27d3 fix: 2023-10-27 15:39:31 +05:30
sayakpaul
4689d759fd fix test 2023-10-27 15:37:42 +05:30
sayakpaul
d06cc7eb6f fix: test 2023-10-27 15:30:54 +05:30
sayakpaul
b5780adf46 fix: test 2023-10-27 15:29:09 +05:30
sayakpaul
d07d3f1642 fix: tokenizer 2023-10-27 15:28:30 +05:30
sayakpaul
e5a69ff497 fix: import 2023-10-27 15:25:14 +05:30
sayakpaul
43846e14a1 fix: import 2023-10-27 15:22:57 +05:30
sayakpaul
7fa7259ded fix: doc 2023-10-27 15:17:38 +05:30
sayakpaul
b576a1dc47 fix: doc 2023-10-27 15:06:27 +05:30
sayakpaul
9fc37d9dc7 add: hub related staging tests 2023-10-27 14:54:45 +05:30
sayakpaul
0a298f55fc fix: _name_or_path 2023-10-27 14:44:14 +05:30
sayakpaul
7382344fed add: test 2023-10-27 14:33:08 +05:30
sayakpaul
93ce75f4e5 add: entry to toctree 2023-10-27 14:25:29 +05:30
sayakpaul
b42bcb86ea add: doc 2023-10-27 14:22:48 +05:30
sayakpaul
47fe2d0d2e fix pop call. 2023-10-27 13:42:41 +05:30
sayakpaul
6a59219e81 more rigorous 2023-10-27 13:41:14 +05:30
sayakpaul
f8eff79b82 fxi 2023-10-27 13:31:13 +05:30
sayakpaul
d01f2f678a rigid call arguments. 2023-10-27 13:30:10 +05:30
sayakpaul
ae0a268f8e add to controlnet 2023-10-27 13:24:16 +05:30
sayakpaul
fce889f19b make push_to_hub False by default 2023-10-27 13:00:27 +05:30
sayakpaul
01b3f64549 fix: save 2023-10-27 12:58:51 +05:30
sayakpaul
53c65e3d19 add: support for serializing generator device. 2023-10-27 12:12:34 +05:30
sayakpaul
6d48af1a46 add back the comment 2023-10-27 12:08:05 +05:30
sayakpaul
66d3fd6732 remove filename stuff from config utils. 2023-10-27 12:06:58 +05:30
sayakpaul
0cc943b757 Merge branch 'main' into feat/workflows 2023-10-27 12:00:28 +05:30
Sayak Paul
41a74f8474 Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2023-10-27 12:00:01 +05:30
sayakpaul
9a8b5f7cd8 address patrick's comments./ 2023-10-19 10:59:58 +05:30
sayakpaul
ebf2addb86 Merge branch 'main' into feat/workflows 2023-10-19 10:56:12 +05:30
sayakpaul
03bfdff59a Empty-Commit 2023-10-18 14:40:48 +05:30
sayakpaul
ff5cd58aa1 support basic lora only for non-peft for now. 2023-10-18 12:39:41 +05:30
sayakpaul
c1c11a6747 fix: lora 2023-10-18 12:25:15 +05:30
sayakpaul
49e06fdd2c fix: lora population. 2023-10-18 12:23:41 +05:30
sayakpaul
f874578a4d fix: lora population. 2023-10-18 12:23:03 +05:30
Sayak Paul
c0e1c6348f Merge branch 'main' into feat/workflows 2023-10-18 12:15:17 +05:30
sayakpaul
231e8314dd feat: support passing filename 2023-10-18 11:50:43 +05:30
sayakpaul
21e5bb65e5 replace the __call__ attribute of the class, not the instance 2023-10-18 11:40:09 +05:30
sayakpaul
74e766c0b8 quality 2023-10-18 11:20:36 +05:30
sayakpaul
452bf4fa05 hmm almost 2023-10-18 11:16:44 +05:30
sayakpaul
55c47bc751 let's see 2023-10-18 10:56:25 +05:30
sayakpaul
eaae2df25c debugging. 2023-10-17 18:52:35 +05:30
sayakpaul
d612b5435b debug 2023-10-17 18:49:03 +05:30
sayakpaul
1dc9854968 copying helps? 2023-10-17 18:44:59 +05:30
sayakpaul
18a756f0ad style. 2023-10-17 16:57:30 +05:30
Sayak Paul
4993c8ba63 Merge branch 'main' into feat/workflows 2023-10-17 16:54:09 +05:30
sayakpaul
ed9acd6426 remove print 2023-10-17 16:36:05 +05:30
sayakpaul
a69e3d15c1 debug. 2023-10-17 16:34:09 +05:30
sayakpaul
4731f65ed4 debug. 2023-10-17 16:29:14 +05:30
sayakpaul
9adaa1739a debug 2023-10-17 16:16:10 +05:30
sayakpaul
af282b7f4b debug 2023-10-17 16:13:32 +05:30
sayakpaul
3d6637b65d partial 2023-10-17 16:10:43 +05:30
sayakpaul
800b7a0fda morr 2023-10-17 16:04:53 +05:30
sayakpaul
91c1c1f1f6 apply styling 2023-10-17 16:03:44 +05:30
sayakpaul
21d19bbc44 workflow_filename -> filename 2023-10-17 16:03:25 +05:30
sayakpaul
9d0bcd48f1 debug 2023-10-17 15:50:45 +05:30
sayakpaul
9ee8b0a070 debug 2023-10-17 15:49:34 +05:30
sayakpaul
b5fd337875 debug 2023-10-17 15:47:41 +05:30
sayakpaul
f08f40bde1 debug 2023-10-17 15:46:18 +05:30
sayakpaul
c5ff8cd943 debug 2023-10-17 15:45:05 +05:30
sayakpaul
319456049a seed. 2023-10-17 15:28:03 +05:30
sayakpaul
f6c0878fc6 callables should not be serialized too. 2023-10-17 15:24:26 +05:30
sayakpaul
e590b73cc1 override pop too for feature compatibility 2023-10-17 15:22:59 +05:30
sayakpaul
aa7839c1c7 pop from internal dict too. 2023-10-17 15:22:22 +05:30
sayakpaul
fc609e308f more fix 2023-10-17 15:19:55 +05:30
sayakpaul
7b85bfe3e5 fix: signature 2023-10-17 15:18:31 +05:30
sayakpaul
eff03fd054 override method 2023-10-17 15:16:21 +05:30
sayakpaul
73dcc17ff1 remove unneeded comment 2023-10-17 15:05:54 +05:30
sayakpaul
b149800269 remove unneeded comment 2023-10-17 15:04:01 +05:30
sayakpaul
2d1cd20afe stronger check 2023-10-17 15:03:32 +05:30
sayakpaul
45c5656bad make config_name a part of the dict. 2023-10-17 14:41:25 +05:30
sayakpaul
2b48d8572d save_pretrained() to workflow so that it has push_to_hub 2023-10-17 14:31:15 +05:30
sayakpaul
e710121a9b save_pretrained() to workflow so that it has push_to_hub 2023-10-17 14:28:42 +05:30
sayakpaul
50769e058b remove torch tensor warning as it might complicate things 2023-10-17 14:19:34 +05:30
sayakpaul
0bd97735dc debug 2023-10-17 14:16:43 +05:30
sayakpaul
930ca765f4 debug 2023-10-17 14:13:34 +05:30
sayakpaul
807c2ca13f debug 2023-10-17 14:04:22 +05:30
sayakpaul
ad725977cc patch call. 2023-10-17 13:52:54 +05:30
sayakpaul
97ae043f8a update docstrings. 2023-10-17 13:46:49 +05:30
sayakpaul
1ab81a6db4 update progress. 2023-10-17 13:05:14 +05:30
sayakpaul
a6a0277713 include pipeline name in the workflow 2023-10-17 12:28:45 +05:30
sayakpaul
ba0b1e857c improve docstring 2023-10-17 12:24:35 +05:30
sayakpaul
d8e6f38db4 change method desc. 2023-10-17 12:22:40 +05:30
sayakpaul
ef94a008d2 handle torch.tensor. 2023-10-17 12:22:02 +05:30
sayakpaul
29d0aa887c remove components from workflows. 2023-10-17 12:14:43 +05:30
sayakpaul
5f19b66d5a resolve conflicts 2023-10-15 12:04:25 +05:30
sayakpaul
96c55d4c5a fix 2023-10-15 11:58:20 +05:30
sayakpaul
e3611e325b properly set lora_info 2023-10-15 11:40:00 +05:30
sayakpaul
d5d31e0ae3 add: support for lora. 2023-10-15 11:09:03 +05:30
sayakpaul
a62b77ff6e include todos. 2023-10-15 10:52:11 +05:30
sayakpaul
a8a1378987 fix 2023-10-15 10:48:40 +05:30
Sayak Paul
e8e09e48ea Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2023-08-30 14:14:26 +05:30
Sayak Paul
ac295055ce add unifinished implementation of _update_call() 2023-08-29 14:45:09 +05:30
Sayak Paul
43e4e841f9 add: workflows. 2023-08-29 13:59:22 +05:30
10 changed files with 835 additions and 4 deletions

View File

@@ -19,6 +19,8 @@
title: Train a diffusion model
- local: tutorials/using_peft_for_inference
title: Inference with PEFT
- local: tutorials/workflows
title: Working with workflows
title: Tutorials
- sections:
- sections:
@@ -178,6 +180,8 @@
title: Logging
- local: api/outputs
title: Outputs
- local: api/workflows
title: Shareable workflows
title: Main Classes
- sections:
- local: api/models/overview

View File

@@ -0,0 +1,7 @@
# Shareable workflows
Workflows provide a simple mechanism to share your 🤗 Diffusers pipeline call arguments and scheduler configuration, making it easier to reproduce results.
## Workflow
[[autodoc]] workflow_utils.Workflow

View File

@@ -0,0 +1,333 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Working with workflows
<Tip warning={true}>
🧪 Workflow is experimental and its APIs can change in the future.
</Tip>
Workflows provide a simple mechanism to share your pipeline call arguments and scheduler configuration, making it easier to reproduce results.
## Serializing a workflow
A [`Workflow`] object provides all the argument values in the `__call__()` of a pipeline. Add `return_workflow=True` to return a `Workflow` object.
```python
from diffusers import DiffusionPipeline
import torch
pipeline = DiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, safety_checker=None
).to("cuda")
outputs = pipeline(
"A painting of a horse",
num_inference_steps=15,
generator=torch.manual_seed(0),
return_workflow=True
)
workflow = outputs.workflow
```
<Tip warning={true}>
It's mandatory to specify the `generator` when `return_workflow` is set to True.
</Tip>
If you look at this specific workflow, you'll see values like the number of inference steps, guidance scale, and height and width as well as the scheduler details:
```bash
{'prompt': 'A painting of a horse',
'height': None,
'width': None,
'num_inference_steps': 15,
'guidance_scale': 7.5,
'negative_prompt': None,
'eta': 0.0,
'latents': None,
'prompt_embeds': None,
'negative_prompt_embeds': None,
'output_type': 'pil',
'return_dict': True,
'callback': None,
'callback_steps': 1,
'cross_attention_kwargs': None,
'guidance_rescale': 0.0,
'clip_skip': None,
'generator_seed': 0,
'generator_device': device(type='cpu'),
'_name_or_path': 'runwayml/stable-diffusion-v1-5',
'scheduler_config': FrozenDict([('num_train_timesteps', 1000),
('beta_start', 0.00085),
('beta_end', 0.012),
('beta_schedule', 'scaled_linear'),
('trained_betas', None),
('skip_prk_steps', True),
('set_alpha_to_one', False),
('prediction_type', 'epsilon'),
('timestep_spacing', 'leading'),
('steps_offset', 1),
('_use_default_values', ['prediction_type', 'timestep_spacing']),
('_class_name', 'PNDMScheduler'),
('_diffusers_version', '0.6.0'),
('clip_sample', False)])}
```
Once you have generated a workflow object, you can serialize it with [`~Workflow.save_workflow`]:
```python
outputs.workflow.save_workflow("my-simple-workflow-sd")
```
By default, your workflows are saved as `diffusion_workflow.json`, but you can give them a specific name with the `filename` argument:
```python
outputs.workflow.save_workflow("my-simple-workflow-sd", filename="my_workflow.json")
```
You can also set `push_to_hub=True` in [`~Workflow.save_workflow`] to directly push the workflow object to the Hub.
## Loading a workflow
You can load a workflow in a pipeline with [`~DiffusionPipeline.load_workflow`]:
```python
from diffusers import DiffusionPipeline
import torch
pipeline = DiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipeline.load_workflow("sayakpaul/my-simple-workflow-sd")
```
Once the pipeline is loaded with the desired workflow, it's ready to be called:
```python
image = pipeline().images[0]
```
By default, loading a workflow doesn't modify the scheduler of the underlying pipeline, but you can load the scheduler configuration stored in the workflow by adding `load_scheduler=True`:
```
pipeline.load_workflow("sayakpaul/my-simple-workflow-sd", load_scheduler=True)
```
This is particularly useful if you have changed the scheduler after loading a pipeline.
You can also override the pipeline call arguments. For example, to add a `negative_prompt`:
```python
image = pipeline(negative_prompt="bad quality").images[0]
```
To load a specific workflow file, pass the `filename` argument to the [`DiffusionPipeline.load_workflow`] method.
A workflow doesn't necessarily have to be used with the same pipeline that generated it. You can use it with a different pipeline too:
```python
from diffusers import DiffusionPipeline
import torch
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_workflow("sayakpaul/my-simple-workflow-sd")
image = pipeline().images[0]
```
In this case, however, make sure to thoroughly inspect the argument values the workflow supplies to the pipeline.
Loading from a local workflow is also possible:
```python
from diffusers import DiffusionPipeline
import torch
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_workflow("path_to_local_dir")
image = pipeline().images[0]
```
Alternatively, if you want to load a workflow file and populate the pipeline arguments manually:
```python
from diffusers import DiffusionPipeline
import json
import torch
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
with open("path_to_workflow_file.json") as f:
workflow = json.load(f)
pipeline.load_workflow(workflow)
images = pipeline().images[0]
```
## Unsupported serialization types
Image-to-image pipelines like [`StableDiffusionControlNetPipeline`] accept one or more images in their `call` method. Currently, workflows don't support serializing `call` arguments that are of type `PIL.Image.Image` or `List[PIL.Image.Image]`. To make those pipelines work with workflows, you need to pass the images manually.
Let's say you generated the workflow below:
```python
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
from diffusers.utils import load_image
import numpy as np
import torch
import cv2
from PIL import Image
# download an image
image = load_image(
"https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
)
image = np.array(image)
# get canny image
image = cv2.Canny(image, 100, 200)
image = image[:, :, None]
image = np.concatenate([image, image, image], axis=2)
canny_image = Image.fromarray(image)
# load control net and stable diffusion v1-5
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()
# generate image
generator = torch.manual_seed(0)
outputs = pipe(
prompt="futuristic-looking office",
image=canny_image,
num_inference_steps=20,
generator=generator,
return_workflow=True
)
workflow = outputs.workflow
```
If you look at the workflow, you'll see the image that was passed to the pipeline isn't included:
```bash
{'prompt': 'futuristic-looking office',
'height': None,
'width': None,
'num_inference_steps': 20,
'guidance_scale': 7.5,
'negative_prompt': None,
'eta': 0.0,
'latents': None,
'prompt_embeds': None,
'negative_prompt_embeds': None,
'output_type': 'pil',
'return_dict': True,
'callback': None,
'callback_steps': 1,
'cross_attention_kwargs': None,
'controlnet_conditioning_scale': 1.0,
'guess_mode': False,
'control_guidance_start': 0.0,
'control_guidance_end': 1.0,
'clip_skip': None,
'generator_seed': 0,
'generator_device': 'cpu',
'_name_or_path': 'runwayml/stable-diffusion-v1-5',
'scheduler_config': FrozenDict([('num_train_timesteps', 1000),
('beta_start', 0.00085),
('beta_end', 0.012),
('beta_schedule', 'scaled_linear'),
('trained_betas', None),
('solver_order', 2),
('prediction_type', 'epsilon'),
('thresholding', False),
('dynamic_thresholding_ratio', 0.995),
('sample_max_value', 1.0),
('predict_x0', True),
('solver_type', 'bh2'),
('lower_order_final', True),
('disable_corrector', []),
('solver_p', None),
('use_karras_sigmas', False),
('timestep_spacing', 'linspace'),
('steps_offset', 1),
('_use_default_values',
['lower_order_final',
'sample_max_value',
'solver_p',
'dynamic_thresholding_ratio',
'thresholding',
'solver_type',
'prediction_type',
'predict_x0',
'use_karras_sigmas',
'disable_corrector',
'timestep_spacing',
'solver_order']),
('skip_prk_steps', True),
('set_alpha_to_one', False),
('_class_name', 'PNDMScheduler'),
('_diffusers_version', '0.6.0'),
('clip_sample', False)])}
```
Let's serialize the workflow and reload the pipeline to see what happens when you try to use it.
```python
workflow.save_workflow("my-simple-workflow-sd", filename="controlnet_simple.json", push_to_hub=True)
```
Then load the workflow into [`StableDiffusionControlNetPipeline`]:
```python
# load control net and stable diffusion v1-5
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()
pipe.load_workflow("sayakpaul/my-simple-workflow-sd", filename="controlnet_simple.json")
```
If you try to generate an image now, it'll return the following error:
```bash
TypeError: image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is <class 'NoneType'>
```
To resolve the error, manually pass the conditioning image `canny_image`:
```python
image = pipe(image=canny_image).images[0]
```
Other unsupported serialization types include:
* LoRA checkpoints: any information from LoRA checkpoints that might be loaded into a pipeline isn't serialized. Workflows generated from pipelines loaded with a LoRA checkpoint should be handled cautiously! You should ensure the LoRA checkpoint is loaded into the pipeline first before loading the corresponding workflow.
* Call arguments including the following types: `torch.Tensor`, `np.ndarray`, `Callable`, `PIL.Image.Image`, and `List[PIL.Image.Image]`.

View File

@@ -752,6 +752,7 @@ class StableDiffusionControlNetPipeline(
guess_mode: bool = False,
control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0,
return_workflow: bool = False,
clip_skip: Optional[int] = None,
):
r"""
@@ -824,6 +825,8 @@ class StableDiffusionControlNetPipeline(
The percentage of total steps at which the ControlNet starts applying.
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
The percentage of total steps at which the ControlNet stops applying.
return_workflow (`bool`, *optional*, defaults to `False`):
Whether to return used pipeline call arguments.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
@@ -837,6 +840,14 @@ class StableDiffusionControlNetPipeline(
second element is a list of `bool`s indicating whether the corresponding generated image contains
"not-safe-for-work" (nsfw) content.
"""
# We do this first to capture the "True" call values. If we do this at a later point in time,
# we cannot ensure that the call values weren't changed during the process.
workflow = None
if return_workflow:
if generator is None:
raise ValueError(f"`generator` cannot be None when `return_workflow` is {return_workflow}.")
workflow = self.populate_workflow_from_pipeline()
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
# align format for control guidance
@@ -1075,6 +1086,11 @@ class StableDiffusionControlNetPipeline(
self.maybe_free_model_hooks()
if not return_dict:
return (image, has_nsfw_concept)
outputs = (image, has_nsfw_concept)
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
if return_workflow:
outputs += (workflow,)
return outputs
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept, workflow=workflow)

View File

@@ -22,6 +22,7 @@ import re
import sys
import warnings
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union
@@ -54,7 +55,9 @@ from ..utils import (
logging,
numpy_to_pil,
)
from ..utils.constants import WORKFLOW_NAME
from ..utils.torch_utils import is_compiled_module
from ..workflow_utils import _NON_CALL_ARGUMENTS, Workflow
if is_transformers_available():
@@ -64,6 +67,7 @@ if is_transformers_available():
from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
from ..utils import FLAX_WEIGHTS_NAME, ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, PushToHubMixin
@@ -2075,3 +2079,117 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
for module in modules:
module.set_attention_slice(slice_size)
def populate_workflow_from_pipeline(self) -> Dict:
    r"""Populates the call arguments in a dictionary.

    Must be invoked directly from the pipeline's `__call__` frame: it reads the
    caller's local variables via `inspect.currentframe().f_back` to capture the
    argument values the pipeline was actually called with.

    Returns:
        [`Workflow`]: A dictionary containing the details of the pipeline call arguments and (optionally) LoRA
        checkpoint details.
    """
    # A `Workflow` object is an extended Python dictionary. So, all regular dictionary methods
    # apply to it.
    workflow = Workflow()

    signature = inspect.signature(self.__call__)
    argument_names = [param.name for param in signature.parameters.values()]
    # NOTE(review): `f_back` must be the pipeline `__call__` frame whose locals hold
    # the argument values — calling this helper through another wrapper breaks it.
    call_arg_values = inspect.getargvalues(inspect.currentframe().f_back).locals

    # Populate call arguments, skipping unserializable values (tensors, arrays,
    # callables) and any image-like inputs.
    call_arguments = {
        arg: call_arg_values[arg]
        for arg in argument_names
        if arg != "return_workflow"
        and "image" not in arg
        and not isinstance(call_arg_values[arg], (torch.Tensor, np.ndarray, Callable))
    }
    workflow.update(call_arguments)

    # Handle generator device and seed. The generator object itself is not
    # JSON-serializable, so store its seed, device, and RNG state instead.
    generator = workflow.pop("generator")
    if isinstance(generator, list):
        # Fix: serialize every entry consistently — devices as strings and states as a
        # list of per-generator state lists. The previous implementation stored the
        # first device as a string but appended raw `torch.device` objects for the
        # rest, and flattened the first state instead of nesting it, which broke the
        # `zip(seed, device, state)` reload path in `load_workflow`.
        workflow["generator_seed"] = [g.initial_seed() for g in generator]
        workflow["generator_device"] = [str(g.device) for g in generator]
        workflow["generator_state"] = [g.get_state().numpy().tolist() for g in generator]
    else:
        workflow["generator_seed"] = generator.initial_seed()
        workflow["generator_device"] = str(generator.device)
        workflow["generator_state"] = generator.get_state().numpy().tolist()

    # Handle pipeline-level things.
    if hasattr(self, "config") and hasattr(self.config, "_name_or_path"):
        pipeline_config_name_or_path = self.config._name_or_path
    else:
        pipeline_config_name_or_path = None
    workflow["_name_or_path"] = pipeline_config_name_or_path
    workflow["scheduler_config"] = self.scheduler.config

    return workflow
def load_workflow(
    self,
    workflow_id_or_path: Union[str, dict],
    filename: Optional[str] = None,
):
    r"""Loads a workflow from the Hub or from a local path. Also patches the pipeline call arguments with values from the
    workflow.

    Args:
        workflow_id_or_path (`str` or `dict`):
            Can be either:
                - A string, the workflow id (for example `sayakpaul/sdxl-workflow`) of a workflow hosted on the
                  Hub.
                - A path to a directory (for example `./my_workflow_directory`) containing the workflow file with
                  [`Workflow.save_workflow`] or [`Workflow.push_to_hub`].
                - A Python dictionary.
        filename (`str`, *optional*):
            Optional name of the workflow file to load. Especially useful when working with multiple workflow
            files.
    """
    filename = filename or WORKFLOW_NAME

    # Load workflow.
    if not isinstance(workflow_id_or_path, dict):
        if os.path.isdir(workflow_id_or_path):
            workflow_filepath = os.path.join(workflow_id_or_path, filename)
        elif os.path.isfile(workflow_id_or_path):
            workflow_filepath = workflow_id_or_path
        else:
            workflow_filepath = hf_hub_download(repo_id=workflow_id_or_path, filename=filename)
        workflow = self._dict_from_json_file(workflow_filepath)
    else:
        workflow = workflow_id_or_path

    # We make a copy of the original workflow and operate on it.
    workflow_copy = dict(workflow.items())

    # Handle generator. Pop with defaults so that workflows serialized without
    # generator information don't raise a `KeyError` here.
    seed = workflow_copy.pop("generator_seed", None)
    device = workflow_copy.pop("generator_device", "cpu")
    last_known_state = workflow_copy.pop("generator_state", None)
    if seed is not None:
        if isinstance(seed, list):
            states = last_known_state if last_known_state is not None else [None] * len(seed)
            generator = []
            for s, d, lst in zip(seed, device, states):
                g = torch.Generator(device=d).manual_seed(s)
                if lst is not None:
                    # The serialized state is a plain list of ints; restore it as a byte tensor.
                    g.set_state(torch.from_numpy(np.array(lst)).byte())
                generator.append(g)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
            if last_known_state is not None:
                generator.set_state(torch.from_numpy(np.array(last_known_state)).byte())
        workflow_copy.update({"generator": generator})

    # Drop metadata entries that are not pipeline call arguments.
    final_call_args = {k: v for k, v in workflow_copy.items() if k not in _NON_CALL_ARGUMENTS}

    # Patch `__call__` so subsequent `pipeline(...)` invocations default to the
    # workflow's arguments (explicitly passed kwargs still override them).
    # NOTE(review): this patches the *class* (Python resolves `__call__` on the
    # type, so instance-level patching would not take effect), which means every
    # live instance of this pipeline class is affected and the partial keeps a
    # reference to `self` — confirm this is acceptable.
    partial_call = partial(self.__call__, **final_call_args)
    setattr(self.__class__, "__call__", partial_call)

View File

@@ -19,10 +19,13 @@ class StableDiffusionPipelineOutput(BaseOutput):
nsfw_content_detected (`List[bool]`)
List indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content or
`None` if safety checking could not be performed.
workflow (`dict`):
Dictionary containing pipeline component configurations and call arguments
"""
images: Union[List[PIL.Image.Image], np.ndarray]
nsfw_content_detected: Optional[List[bool]]
workflow: Optional[dict] = None
if is_flax_available():

View File

@@ -623,6 +623,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
return_dict: bool = True,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
return_workflow: bool = False,
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
@@ -677,6 +678,8 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
using zero terminal SNR.
return_workflow (`bool`, *optional*, defaults to `False`):
Whether to return used pipeline call arguments.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
@@ -699,6 +702,13 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
second element is a list of `bool`s indicating whether the corresponding generated image contains
"not-safe-for-work" (nsfw) content.
"""
# We do this first to capture the "True" call values. If we do this at a later point in time,
# we cannot ensure that the call values weren't changed during the process.
workflow = None
if return_workflow:
if generator is None:
raise ValueError(f"`generator` cannot be None when `return_workflow` is {return_workflow}.")
workflow = self.populate_workflow_from_pipeline()
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)
@@ -855,6 +865,11 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
self.maybe_free_model_hooks()
if not return_dict:
return (image, has_nsfw_concept)
outputs = (image, has_nsfw_concept)
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
if return_workflow:
outputs += (workflow,)
return outputs
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept, workflow=workflow)

View File

@@ -14,6 +14,7 @@
import importlib
import os
import numpy as np
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE, hf_cache_home
from packaging import version
@@ -32,11 +33,13 @@ FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack"
ONNX_WEIGHTS_NAME = "model.onnx"
SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors"
ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb"
WORKFLOW_NAME = "diffusion_workflow.json"
HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co")
DIFFUSERS_CACHE = default_cache_path
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
MAX_SEED = np.iinfo(np.int32).max
# Below should be `True` if the current version of `peft` and `transformers` are compatible with
# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are

View File

@@ -0,0 +1,161 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for managing workflows."""
import json
import os
from pathlib import PosixPath
from typing import Union
import numpy as np
from huggingface_hub import create_repo
from . import __version__
from .utils import PushToHubMixin, logging
from .utils.constants import WORKFLOW_NAME
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
_NON_CALL_ARGUMENTS = {"_name_or_path", "scheduler_config", "_class_name", "_diffusers_version"}
class Workflow(dict, PushToHubMixin):
    """A native Python ``dict`` subclass that records pipeline call arguments and adds
    Hub interoperability (JSON serialization, ``save_pretrained`` and ``push_to_hub``).

    A mirror dict (``_internal_dict``) shadows every mutation so that serialization can
    exclude non-call metadata while the object still behaves like a plain ``dict``.
    """

    # Set per-instance in `__init__`; kept as a class attribute for parity with ConfigMixin.
    config_name = None

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.config_name = WORKFLOW_NAME
        # `dict.__init__` bypasses `__setitem__`, so seed the mirror dict with the
        # constructor kwargs explicitly — otherwise they would be silently dropped
        # from `to_json_string()`.
        self._internal_dict = dict(kwargs)

    def __setitem__(self, __key, __value):
        # Mirror every assignment so the serialized view stays in sync.
        self._internal_dict[__key] = __value
        return super().__setitem__(__key, __value)

    def update(self, __m, **kwargs):
        self._internal_dict.update(__m, **kwargs)
        super().update(__m, **kwargs)

    def pop(self, key, *args):
        # Tolerate keys absent from the mirror (e.g. removed out-of-band); the
        # authoritative KeyError/default semantics come from `dict.pop` below.
        self._internal_dict.pop(key, None)
        return super().pop(key, *args)

    # Copied from diffusers.configuration_utils.ConfigMixin.to_json_string
    def to_json_string(self) -> str:
        """
        Serializes the configuration instance to a JSON string.

        Returns:
            `str`:
                String containing all the attributes that make up the configuration instance in JSON format.
        """
        # Work on a shallow copy — injecting the metadata keys below must not
        # permanently mutate the workflow itself.
        config_dict = dict(self._internal_dict) if hasattr(self, "_internal_dict") else {}
        config_dict["_class_name"] = self.__class__.__name__
        config_dict["_diffusers_version"] = __version__

        def to_json_saveable(value):
            # JSON cannot encode ndarrays or Path objects directly.
            if isinstance(value, np.ndarray):
                value = value.tolist()
            elif isinstance(value, PosixPath):
                value = str(value)
            return value

        config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
        # Don't save "_ignore_files" or "_use_default_values"
        config_dict.pop("_ignore_files", None)
        config_dict.pop("_use_default_values", None)
        return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

    def save_workflow(
        self,
        save_directory: Union[str, os.PathLike],
        push_to_hub: bool = False,
        filename: str = WORKFLOW_NAME,
        **kwargs,
    ):
        """
        Saves a workflow to a directory.

        Args:
            save_directory (`str` or `os.PathLike`):
                Directory where the workflow JSON file will be saved (will be created if it does not exist).
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                namespace).
            filename (`str`, *optional*, defaults to `WORKFLOW_NAME`):
                Optional filename to use to serialize the workflow JSON.
            kwargs (`Dict[str, Any]`, *optional*):
                Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.

        Raises:
            AssertionError: If `save_directory` points to an existing file.
        """
        self.config_name = filename

        if os.path.isfile(save_directory):
            raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")

        os.makedirs(save_directory, exist_ok=True)
        output_config_file = os.path.join(save_directory, self.config_name)

        with open(output_config_file, "w", encoding="utf-8") as writer:
            writer.write(self.to_json_string())
        logger.info(f"Configuration saved in {output_config_file}")

        if push_to_hub:
            commit_message = kwargs.pop("commit_message", None)
            private = kwargs.pop("private", False)
            create_pr = kwargs.pop("create_pr", False)
            token = kwargs.pop("token", None)
            # `os.fspath` makes the default repo_id derivation work for Path objects too.
            repo_id = kwargs.pop("repo_id", os.fspath(save_directory).split(os.path.sep)[-1])
            repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
            self._upload_folder(
                save_directory,
                repo_id,
                token=token,
                commit_message=commit_message,
                create_pr=create_pr,
            )

    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        push_to_hub: bool = False,
        filename: str = WORKFLOW_NAME,
        **kwargs,
    ):
        """
        Saves a workflow to a directory. This internally calls [`Workflow.save_workflow`]. This method exists to have
        feature parity with [`PushToHubMixin.push_to_hub`].

        Args:
            save_directory (`str` or `os.PathLike`):
                Directory where the workflow JSON file will be saved (will be created if it does not exist).
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                namespace).
            filename (`str`, *optional*, defaults to `WORKFLOW_NAME`):
                Optional filename to use to serialize the workflow JSON.
            kwargs (`Dict[str, Any]`, *optional*):
                Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
        """
        self.save_workflow(
            save_directory=save_directory,
            push_to_hub=push_to_hub,
            filename=filename,
            **kwargs,
        )

View File

@@ -0,0 +1,171 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import tempfile
import unittest
import uuid
import numpy as np
import torch
from huggingface_hub import delete_repo, hf_hub_download
from test_utils import TOKEN, USER, is_staging_test
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from diffusers import (
AutoencoderKL,
DDIMScheduler,
StableDiffusionPipeline,
UNet2DConditionModel,
)
from diffusers.utils.constants import WORKFLOW_NAME
from diffusers.utils.testing_utils import torch_device
from diffusers.workflow_utils import Workflow
class WorkflowFastTests(unittest.TestCase):
    """Fast (CPU-sized) checks that a workflow serialized from one pipeline run
    reproduces identical images when loaded into a fresh pipeline."""

    def get_dummy_components(self):
        # NOTE: the `torch.manual_seed(0)` calls below are interleaved on purpose —
        # each model's random init must start from the same RNG state.
        torch.manual_seed(0)
        unet = UNet2DConditionModel(
            block_out_channels=(4, 8),
            layers_per_block=1,
            sample_size=32,
            in_channels=4,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=32,
            norm_num_groups=2,
        )
        scheduler = DDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
        )
        torch.manual_seed(0)
        vae = AutoencoderKL(
            block_out_channels=[4, 8],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
            norm_num_groups=2,
        )
        torch.manual_seed(0)
        clip_config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=64,
            layer_norm_eps=1e-05,
            num_attention_heads=8,
            num_hidden_layers=3,
            pad_token_id=1,
            vocab_size=1000,
        )
        text_encoder = CLIPTextModel(clip_config)
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        return {
            "unet": unet,
            "scheduler": scheduler,
            "vae": vae,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "safety_checker": None,
            "feature_extractor": None,
        }

    def get_dummy_inputs(self, device, seed=0):
        # MPS does not support device-bound generators; fall back to the global one.
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
        return {
            "prompt": "A painting of a squirrel eating a burger",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 6.0,
            "output_type": "np",
        }

    def test_workflow_with_stable_diffusion(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator

        # First run: generate an image and capture the workflow alongside it.
        pipe = StableDiffusionPipeline(**self.get_dummy_components())
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        first_output = pipe(**self.get_dummy_inputs(device), return_workflow=True)
        reference_slice = first_output.images[0, -3:, -3:, -1]

        with tempfile.TemporaryDirectory() as tmpdirname:
            first_output.workflow.save_pretrained(tmpdirname)

            # Second run: a fresh pipeline that only knows the serialized workflow.
            pipe = StableDiffusionPipeline(**self.get_dummy_components())
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            pipe.load_workflow(tmpdirname)

            second_output = pipe(**self.get_dummy_inputs(device))
            roundtrip_slice = second_output.images[0, -3:, -3:, -1]

        # Identical workflows must yield (numerically) identical images.
        self.assertTrue(np.allclose(reference_slice, roundtrip_slice))
@is_staging_test
class WorkflowPushToHubTester(unittest.TestCase):
    """Staging-Hub round-trip tests: push a workflow, download it back, compare values."""

    identifier = uuid.uuid4()
    repo_id = f"test-workflow-{identifier}"
    org_repo_id = f"valid_org/{repo_id}-org"

    def compare_workflow_values(self, repo_id: str, actual_workflow: dict):
        # Pull the serialized JSON from the Hub and check it contains every
        # key/value pair of the in-memory workflow.
        local_path = hf_hub_download(repo_id=repo_id, filename=WORKFLOW_NAME, token=TOKEN)
        with open(local_path) as f:
            downloaded = json.load(f)
        for key in actual_workflow:
            assert actual_workflow[key] == downloaded[key]

    def test_push_to_hub(self):
        workflow = Workflow()
        workflow.update({"prompt": "hey", "num_inference_steps": 25})
        workflow.push_to_hub(self.repo_id, token=TOKEN)

        self.compare_workflow_values(repo_id=f"{USER}/{self.repo_id}", actual_workflow=workflow)

        # Reset repo
        delete_repo(token=TOKEN, repo_id=self.repo_id)

    def test_push_to_hub_in_organization(self):
        workflow = Workflow()
        workflow.update({"prompt": "hey", "num_inference_steps": 25})
        workflow.push_to_hub(self.org_repo_id, token=TOKEN)

        self.compare_workflow_values(repo_id=self.org_repo_id, actual_workflow=workflow)

        # Reset repo
        delete_repo(token=TOKEN, repo_id=self.org_repo_id)