mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-11 06:54:32 +08:00
Compare commits
163 Commits
qwen-image
...
feat/workf
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
601c506918 | ||
|
|
f8a5e172cf | ||
|
|
47e9219450 | ||
|
|
04d83c209d | ||
|
|
0d81b2dab4 | ||
|
|
e67ddf8d13 | ||
|
|
cdbbc7d5b7 | ||
|
|
81fb265a08 | ||
|
|
21c8c433a3 | ||
|
|
3479e5311d | ||
|
|
f8cad5dc4a | ||
|
|
d75f8a537c | ||
|
|
4b94889652 | ||
|
|
35a6538343 | ||
|
|
1b65ff770c | ||
|
|
857c65bb56 | ||
|
|
b5752ec4bf | ||
|
|
157405436b | ||
|
|
d69f3079a8 | ||
|
|
407e669fca | ||
|
|
5e00fcc153 | ||
|
|
1dd8cf5abe | ||
|
|
10739166e2 | ||
|
|
c1d8b882ee | ||
|
|
6697144a7d | ||
|
|
6af69f5639 | ||
|
|
c28ea5e6c6 | ||
|
|
cb8902394d | ||
|
|
5c8d5df564 | ||
|
|
474df7a6b1 | ||
|
|
047cede64a | ||
|
|
ce9e17c547 | ||
|
|
abdb4ffbc6 | ||
|
|
87174fc208 | ||
|
|
425e75bc79 | ||
|
|
f14cf8f4aa | ||
|
|
2f7725edf4 | ||
|
|
a11ce647de | ||
|
|
6f65f3ad3e | ||
|
|
cbc835b870 | ||
|
|
8550a86a17 | ||
|
|
477cc9a82a | ||
|
|
28c8e93179 | ||
|
|
3f90e07228 | ||
|
|
67f7757048 | ||
|
|
0e40d6ffd6 | ||
|
|
3eab48f883 | ||
|
|
b3a675288d | ||
|
|
9099f51c5e | ||
|
|
5dcd8c541e | ||
|
|
ac73c86610 | ||
|
|
9836b61fa8 | ||
|
|
1829b9485c | ||
|
|
6e514c2b6a | ||
|
|
b5888b4704 | ||
|
|
003220ba36 | ||
|
|
f221631d3c | ||
|
|
adcaba0a23 | ||
|
|
bd1f78e6cd | ||
|
|
ecfa79b673 | ||
|
|
ab1d58872b | ||
|
|
020b4a4ad7 | ||
|
|
e510c3d4d5 | ||
|
|
f0418e8896 | ||
|
|
c4eebd9c1a | ||
|
|
84de851116 | ||
|
|
eae28adf40 | ||
|
|
450198061e | ||
|
|
be3ff851c6 | ||
|
|
c3792ba3e0 | ||
|
|
7933f2ac18 | ||
|
|
1b7a7c27d3 | ||
|
|
4689d759fd | ||
|
|
d06cc7eb6f | ||
|
|
b5780adf46 | ||
|
|
d07d3f1642 | ||
|
|
e5a69ff497 | ||
|
|
43846e14a1 | ||
|
|
7fa7259ded | ||
|
|
b576a1dc47 | ||
|
|
9fc37d9dc7 | ||
|
|
0a298f55fc | ||
|
|
7382344fed | ||
|
|
93ce75f4e5 | ||
|
|
b42bcb86ea | ||
|
|
47fe2d0d2e | ||
|
|
6a59219e81 | ||
|
|
f8eff79b82 | ||
|
|
d01f2f678a | ||
|
|
ae0a268f8e | ||
|
|
fce889f19b | ||
|
|
01b3f64549 | ||
|
|
53c65e3d19 | ||
|
|
6d48af1a46 | ||
|
|
66d3fd6732 | ||
|
|
0cc943b757 | ||
|
|
41a74f8474 | ||
|
|
9a8b5f7cd8 | ||
|
|
ebf2addb86 | ||
|
|
03bfdff59a | ||
|
|
ff5cd58aa1 | ||
|
|
c1c11a6747 | ||
|
|
49e06fdd2c | ||
|
|
f874578a4d | ||
|
|
c0e1c6348f | ||
|
|
231e8314dd | ||
|
|
21e5bb65e5 | ||
|
|
74e766c0b8 | ||
|
|
452bf4fa05 | ||
|
|
55c47bc751 | ||
|
|
eaae2df25c | ||
|
|
d612b5435b | ||
|
|
1dc9854968 | ||
|
|
18a756f0ad | ||
|
|
4993c8ba63 | ||
|
|
ed9acd6426 | ||
|
|
a69e3d15c1 | ||
|
|
4731f65ed4 | ||
|
|
9adaa1739a | ||
|
|
af282b7f4b | ||
|
|
3d6637b65d | ||
|
|
800b7a0fda | ||
|
|
91c1c1f1f6 | ||
|
|
21d19bbc44 | ||
|
|
9d0bcd48f1 | ||
|
|
9ee8b0a070 | ||
|
|
b5fd337875 | ||
|
|
f08f40bde1 | ||
|
|
c5ff8cd943 | ||
|
|
319456049a | ||
|
|
f6c0878fc6 | ||
|
|
e590b73cc1 | ||
|
|
aa7839c1c7 | ||
|
|
fc609e308f | ||
|
|
7b85bfe3e5 | ||
|
|
eff03fd054 | ||
|
|
73dcc17ff1 | ||
|
|
b149800269 | ||
|
|
2d1cd20afe | ||
|
|
45c5656bad | ||
|
|
2b48d8572d | ||
|
|
e710121a9b | ||
|
|
50769e058b | ||
|
|
0bd97735dc | ||
|
|
930ca765f4 | ||
|
|
807c2ca13f | ||
|
|
ad725977cc | ||
|
|
97ae043f8a | ||
|
|
1ab81a6db4 | ||
|
|
a6a0277713 | ||
|
|
ba0b1e857c | ||
|
|
d8e6f38db4 | ||
|
|
ef94a008d2 | ||
|
|
29d0aa887c | ||
|
|
5f19b66d5a | ||
|
|
96c55d4c5a | ||
|
|
e3611e325b | ||
|
|
d5d31e0ae3 | ||
|
|
a62b77ff6e | ||
|
|
a8a1378987 | ||
|
|
e8e09e48ea | ||
|
|
ac295055ce | ||
|
|
43e4e841f9 |
@@ -19,6 +19,8 @@
|
|||||||
title: Train a diffusion model
|
title: Train a diffusion model
|
||||||
- local: tutorials/using_peft_for_inference
|
- local: tutorials/using_peft_for_inference
|
||||||
title: Inference with PEFT
|
title: Inference with PEFT
|
||||||
|
- local: tutorials/workflows
|
||||||
|
title: Working with workflows
|
||||||
title: Tutorials
|
title: Tutorials
|
||||||
- sections:
|
- sections:
|
||||||
- sections:
|
- sections:
|
||||||
@@ -178,6 +180,8 @@
|
|||||||
title: Logging
|
title: Logging
|
||||||
- local: api/outputs
|
- local: api/outputs
|
||||||
title: Outputs
|
title: Outputs
|
||||||
|
- local: api/workflows
|
||||||
|
title: Shareable workflows
|
||||||
title: Main Classes
|
title: Main Classes
|
||||||
- sections:
|
- sections:
|
||||||
- local: api/models/overview
|
- local: api/models/overview
|
||||||
|
|||||||
7
docs/source/en/api/workflows.md
Normal file
7
docs/source/en/api/workflows.md
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
# Shareable workflows
|
||||||
|
|
||||||
|
Workflows provide a simple mechanism to share your 🤗 Diffusers pipeline call arguments and scheduler configuration, making it easier to reproduce results.
|
||||||
|
|
||||||
|
## Workflow
|
||||||
|
|
||||||
|
[[autodoc]] workflow_utils.Workflow
|
||||||
333
docs/source/en/tutorials/workflows.md
Normal file
333
docs/source/en/tutorials/workflows.md
Normal file
@@ -0,0 +1,333 @@
|
|||||||
|
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||||
|
the License. You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||||
|
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||||
|
specific language governing permissions and limitations under the License.
|
||||||
|
-->
|
||||||
|
|
||||||
|
# Working with workflows
|
||||||
|
|
||||||
|
<Tip warning={true}>
|
||||||
|
|
||||||
|
🧪 Workflow is experimental and its APIs can change in the future.
|
||||||
|
|
||||||
|
</Tip>
|
||||||
|
|
||||||
|
Workflows provide a simple mechanism to share your pipeline call arguments and scheduler configuration, making it easier to reproduce results.
|
||||||
|
|
||||||
|
## Serializing a workflow
|
||||||
|
|
||||||
|
A [`Workflow`] object provides all the argument values in the `__call__()` of a pipeline. Add `return_workflow=True` to return a `Workflow` object.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from diffusers import DiffusionPipeline
|
||||||
|
import torch
|
||||||
|
|
||||||
|
pipeline = DiffusionPipeline.from_pretrained(
|
||||||
|
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, safety_checker=None
|
||||||
|
).to("cuda")
|
||||||
|
|
||||||
|
outputs = pipeline(
|
||||||
|
"A painting of a horse",
|
||||||
|
num_inference_steps=15,
|
||||||
|
generator=torch.manual_seed(0),
|
||||||
|
return_workflow=True
|
||||||
|
)
|
||||||
|
workflow = outputs.workflow
|
||||||
|
```
|
||||||
|
|
||||||
|
<Tip warning={true}>
|
||||||
|
|
||||||
|
It's mandatory to specify the `generator` when `return_workflow` is set to True.
|
||||||
|
|
||||||
|
</Tip>
|
||||||
|
|
||||||
|
If you look at this specific workflow, you'll see values like the number of inference steps, guidance scale, and height and width as well as the scheduler details:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
{'prompt': 'A painting of a horse',
|
||||||
|
'height': None,
|
||||||
|
'width': None,
|
||||||
|
'num_inference_steps': 15,
|
||||||
|
'guidance_scale': 7.5,
|
||||||
|
'negative_prompt': None,
|
||||||
|
'eta': 0.0,
|
||||||
|
'latents': None,
|
||||||
|
'prompt_embeds': None,
|
||||||
|
'negative_prompt_embeds': None,
|
||||||
|
'output_type': 'pil',
|
||||||
|
'return_dict': True,
|
||||||
|
'callback': None,
|
||||||
|
'callback_steps': 1,
|
||||||
|
'cross_attention_kwargs': None,
|
||||||
|
'guidance_rescale': 0.0,
|
||||||
|
'clip_skip': None,
|
||||||
|
'generator_seed': 0,
|
||||||
|
'generator_device': device(type='cpu'),
|
||||||
|
'_name_or_path': 'runwayml/stable-diffusion-v1-5',
|
||||||
|
'scheduler_config': FrozenDict([('num_train_timesteps', 1000),
|
||||||
|
('beta_start', 0.00085),
|
||||||
|
('beta_end', 0.012),
|
||||||
|
('beta_schedule', 'scaled_linear'),
|
||||||
|
('trained_betas', None),
|
||||||
|
('skip_prk_steps', True),
|
||||||
|
('set_alpha_to_one', False),
|
||||||
|
('prediction_type', 'epsilon'),
|
||||||
|
('timestep_spacing', 'leading'),
|
||||||
|
('steps_offset', 1),
|
||||||
|
('_use_default_values', ['prediction_type', 'timestep_spacing']),
|
||||||
|
('_class_name', 'PNDMScheduler'),
|
||||||
|
('_diffusers_version', '0.6.0'),
|
||||||
|
('clip_sample', False)])}
|
||||||
|
```
|
||||||
|
|
||||||
|
Once you have generated a workflow object, you can serialize it with [`~Workflow.save_workflow`]:
|
||||||
|
|
||||||
|
```python
|
||||||
|
outputs.workflow.save_workflow("my-simple-workflow-sd")
|
||||||
|
```
|
||||||
|
|
||||||
|
By default, your workflows are saved as `diffusion_workflow.json`, but you can give them a specific name with the `filename` argument:
|
||||||
|
|
||||||
|
```python
|
||||||
|
outputs.workflow.save_workflow("my-simple-workflow-sd", filename="my_workflow.json")
|
||||||
|
```
|
||||||
|
|
||||||
|
You can also set `push_to_hub=True` in [`~Workflow.save_workflow`] to directly push the workflow object to the Hub.
|
||||||
|
|
||||||
|
## Loading a workflow
|
||||||
|
|
||||||
|
You can load a workflow in a pipeline with [`~DiffusionPipeline.load_workflow`]:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from diffusers import DiffusionPipeline
|
||||||
|
import torch
|
||||||
|
|
||||||
|
pipeline = DiffusionPipeline.from_pretrained(
|
||||||
|
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
|
||||||
|
).to("cuda")
|
||||||
|
|
||||||
|
pipeline.load_workflow("sayakpaul/my-simple-workflow-sd")
|
||||||
|
```
|
||||||
|
|
||||||
|
Once the pipeline is loaded with the desired workflow, it's ready to be called:
|
||||||
|
|
||||||
|
```python
|
||||||
|
image = pipeline().images[0]
|
||||||
|
```
|
||||||
|
|
||||||
|
By default, while loading a workflow, the scheduler of the underlying pipeline from the workflow isn't modified but you can change it by adding `load_scheduler=True`:
|
||||||
|
|
||||||
|
```
|
||||||
|
pipeline.load_workflow("sayakpaul/my-simple-workflow-sd", load_scheduler=True)
|
||||||
|
```
|
||||||
|
|
||||||
|
This is particularly useful if you have changed the scheduler after loading a pipeline.
|
||||||
|
|
||||||
|
You can also override the pipeline call arguments. For example, to add a `negative_prompt`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
image = pipeline(negative_prompt="bad quality").images[0]
|
||||||
|
```
|
||||||
|
|
||||||
|
Loading from a workflow is possible by specifying the `filename` argument inside the [`DiffusionPipeline.load_workflow`] method.
|
||||||
|
|
||||||
|
A workflow doesn't necessarily have to be used with the same pipeline that generated it. You can use it with a different pipeline too:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from diffusers import DiffusionPipeline
|
||||||
|
import torch
|
||||||
|
|
||||||
|
pipeline = DiffusionPipeline.from_pretrained(
|
||||||
|
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
||||||
|
).to("cuda")
|
||||||
|
|
||||||
|
pipeline.load_workflow("sayakpaul/my-simple-workflow-sd")
|
||||||
|
image = pipeline().images[0]
|
||||||
|
```
|
||||||
|
|
||||||
|
However, make sure to thoroughly inspect the values you are calling the pipeline with, in this case.
|
||||||
|
|
||||||
|
Loading from a local workflow is also possible:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from diffusers import DiffusionPipeline
|
||||||
|
import torch
|
||||||
|
|
||||||
|
pipeline = DiffusionPipeline.from_pretrained(
|
||||||
|
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
||||||
|
).to("cuda")
|
||||||
|
|
||||||
|
pipeline.load_workflow("path_to_local_dir")
|
||||||
|
image = pipeline().images[0]
|
||||||
|
```
|
||||||
|
|
||||||
|
Alternatively, if you want to load a workflow file and populate the pipeline arguments manually:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from diffusers import DiffusionPipeline
|
||||||
|
import json
|
||||||
|
import torch
|
||||||
|
|
||||||
|
pipeline = DiffusionPipeline.from_pretrained(
|
||||||
|
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
||||||
|
).to("cuda")
|
||||||
|
|
||||||
|
with open("path_to_workflow_file.json") as f:
|
||||||
|
workflow = json.load(f)
|
||||||
|
|
||||||
|
pipeline.load_workflow(workflow)
|
||||||
|
images = pipeline().images[0]
|
||||||
|
```
|
||||||
|
|
||||||
|
## Unsupported serialization types
|
||||||
|
|
||||||
|
Image-to-image pipelines like [`StableDiffusionControlNetPipeline`] accept one or more images in their `call` method. Currently, workflows don't support serializing `call` arguments that are of type `PIL.Image.Image` or `List[PIL.Image.Image]`. To make those pipelines work with workflows, you need to pass the images manually.
|
||||||
|
|
||||||
|
Let's say you generated the workflow below:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
|
||||||
|
from diffusers.utils import load_image
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
# download an image
|
||||||
|
image = load_image(
|
||||||
|
"https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
|
||||||
|
)
|
||||||
|
image = np.array(image)
|
||||||
|
|
||||||
|
# get canny image
|
||||||
|
image = cv2.Canny(image, 100, 200)
|
||||||
|
image = image[:, :, None]
|
||||||
|
image = np.concatenate([image, image, image], axis=2)
|
||||||
|
canny_image = Image.fromarray(image)
|
||||||
|
|
||||||
|
# load control net and stable diffusion v1-5
|
||||||
|
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
|
||||||
|
pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
||||||
|
"runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
|
||||||
|
)
|
||||||
|
pipe.enable_model_cpu_offload()
|
||||||
|
|
||||||
|
# generate image
|
||||||
|
generator = torch.manual_seed(0)
|
||||||
|
outputs = pipe(
|
||||||
|
prompt="futuristic-looking office",
|
||||||
|
image=canny_image,
|
||||||
|
num_inference_steps=20,
|
||||||
|
generator=generator,
|
||||||
|
return_workflow=True
|
||||||
|
)
|
||||||
|
workflow = outputs.workflow
|
||||||
|
```
|
||||||
|
|
||||||
|
If you look at the workflow, you'll see the image that was passed to the pipeline isn't included:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
{'prompt': 'futuristic-looking office',
|
||||||
|
'height': None,
|
||||||
|
'width': None,
|
||||||
|
'num_inference_steps': 20,
|
||||||
|
'guidance_scale': 7.5,
|
||||||
|
'negative_prompt': None,
|
||||||
|
'eta': 0.0,
|
||||||
|
'latents': None,
|
||||||
|
'prompt_embeds': None,
|
||||||
|
'negative_prompt_embeds': None,
|
||||||
|
'output_type': 'pil',
|
||||||
|
'return_dict': True,
|
||||||
|
'callback': None,
|
||||||
|
'callback_steps': 1,
|
||||||
|
'cross_attention_kwargs': None,
|
||||||
|
'controlnet_conditioning_scale': 1.0,
|
||||||
|
'guess_mode': False,
|
||||||
|
'control_guidance_start': 0.0,
|
||||||
|
'control_guidance_end': 1.0,
|
||||||
|
'clip_skip': None,
|
||||||
|
'generator_seed': 0,
|
||||||
|
'generator_device': 'cpu',
|
||||||
|
'_name_or_path': 'runwayml/stable-diffusion-v1-5',
|
||||||
|
'scheduler_config': FrozenDict([('num_train_timesteps', 1000),
|
||||||
|
('beta_start', 0.00085),
|
||||||
|
('beta_end', 0.012),
|
||||||
|
('beta_schedule', 'scaled_linear'),
|
||||||
|
('trained_betas', None),
|
||||||
|
('solver_order', 2),
|
||||||
|
('prediction_type', 'epsilon'),
|
||||||
|
('thresholding', False),
|
||||||
|
('dynamic_thresholding_ratio', 0.995),
|
||||||
|
('sample_max_value', 1.0),
|
||||||
|
('predict_x0', True),
|
||||||
|
('solver_type', 'bh2'),
|
||||||
|
('lower_order_final', True),
|
||||||
|
('disable_corrector', []),
|
||||||
|
('solver_p', None),
|
||||||
|
('use_karras_sigmas', False),
|
||||||
|
('timestep_spacing', 'linspace'),
|
||||||
|
('steps_offset', 1),
|
||||||
|
('_use_default_values',
|
||||||
|
['lower_order_final',
|
||||||
|
'sample_max_value',
|
||||||
|
'solver_p',
|
||||||
|
'dynamic_thresholding_ratio',
|
||||||
|
'thresholding',
|
||||||
|
'solver_type',
|
||||||
|
'prediction_type',
|
||||||
|
'predict_x0',
|
||||||
|
'use_karras_sigmas',
|
||||||
|
'disable_corrector',
|
||||||
|
'timestep_spacing',
|
||||||
|
'solver_order']),
|
||||||
|
('skip_prk_steps', True),
|
||||||
|
('set_alpha_to_one', False),
|
||||||
|
('_class_name', 'PNDMScheduler'),
|
||||||
|
('_diffusers_version', '0.6.0'),
|
||||||
|
('clip_sample', False)])}
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
Let's serialize the workflow and reload the pipeline to see what happens when you try to use it.
|
||||||
|
|
||||||
|
```python
|
||||||
|
workflow.save_workflow("my-simple-workflow-sd", filename="controlnet_simple.json", push_to_hub=True)
|
||||||
|
```
|
||||||
|
|
||||||
|
Then load the workflow into [`StableDiffusionControlNetPipeline`]:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# load control net and stable diffusion v1-5
|
||||||
|
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
|
||||||
|
pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
||||||
|
"runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
|
||||||
|
)
|
||||||
|
pipe.enable_model_cpu_offload()
|
||||||
|
|
||||||
|
pipe.load_workflow("sayakpaul/my-simple-workflow-sd", filename="controlnet_simple.json")
|
||||||
|
```
|
||||||
|
|
||||||
|
If you try to generate an image now, it'll return the following error:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
TypeError: image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is <class 'NoneType'>
|
||||||
|
```
|
||||||
|
|
||||||
|
To resolve the error, manually pass the conditioning image `canny_image`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
image = pipe(image=canny_image).images[0]
|
||||||
|
```
|
||||||
|
|
||||||
|
Other unsupported serialization types include:
|
||||||
|
|
||||||
|
* LoRA checkpoints: any information from LoRA checkpoints that might be loaded into a pipeline isn't serialized. Workflows generated from pipelines loaded with a LoRA checkpoint should be handled cautiously! You should ensure the LoRA checkpoint is loaded into the pipeline first before loading the corresponding workflow.
|
||||||
|
* Call arguments including the following types: `torch.Tensor`, `np.ndarray`, `Callable`, `PIL.Image.Image`, and `List[PIL.Image.Image]`.
|
||||||
@@ -752,6 +752,7 @@ class StableDiffusionControlNetPipeline(
|
|||||||
guess_mode: bool = False,
|
guess_mode: bool = False,
|
||||||
control_guidance_start: Union[float, List[float]] = 0.0,
|
control_guidance_start: Union[float, List[float]] = 0.0,
|
||||||
control_guidance_end: Union[float, List[float]] = 1.0,
|
control_guidance_end: Union[float, List[float]] = 1.0,
|
||||||
|
return_workflow: bool = False,
|
||||||
clip_skip: Optional[int] = None,
|
clip_skip: Optional[int] = None,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
@@ -824,6 +825,8 @@ class StableDiffusionControlNetPipeline(
|
|||||||
The percentage of total steps at which the ControlNet starts applying.
|
The percentage of total steps at which the ControlNet starts applying.
|
||||||
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
|
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
|
||||||
The percentage of total steps at which the ControlNet stops applying.
|
The percentage of total steps at which the ControlNet stops applying.
|
||||||
|
return_workflow (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to return used pipeline call arguments.
|
||||||
clip_skip (`int`, *optional*):
|
clip_skip (`int`, *optional*):
|
||||||
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
||||||
the output of the pre-final layer will be used for computing the prompt embeddings.
|
the output of the pre-final layer will be used for computing the prompt embeddings.
|
||||||
@@ -837,6 +840,14 @@ class StableDiffusionControlNetPipeline(
|
|||||||
second element is a list of `bool`s indicating whether the corresponding generated image contains
|
second element is a list of `bool`s indicating whether the corresponding generated image contains
|
||||||
"not-safe-for-work" (nsfw) content.
|
"not-safe-for-work" (nsfw) content.
|
||||||
"""
|
"""
|
||||||
|
# We do this first to capture the "True" call values. If we do this at a later point in time,
|
||||||
|
# we cannot ensure that the call values weren't changed during the process.
|
||||||
|
workflow = None
|
||||||
|
if return_workflow:
|
||||||
|
if generator is None:
|
||||||
|
raise ValueError(f"`generator` cannot be None when `return_workflow` is {return_workflow}.")
|
||||||
|
workflow = self.populate_workflow_from_pipeline()
|
||||||
|
|
||||||
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
|
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
|
||||||
|
|
||||||
# align format for control guidance
|
# align format for control guidance
|
||||||
@@ -1075,6 +1086,11 @@ class StableDiffusionControlNetPipeline(
|
|||||||
self.maybe_free_model_hooks()
|
self.maybe_free_model_hooks()
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return (image, has_nsfw_concept)
|
outputs = (image, has_nsfw_concept)
|
||||||
|
|
||||||
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
if return_workflow:
|
||||||
|
outputs += (workflow,)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept, workflow=workflow)
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ import re
|
|||||||
import sys
|
import sys
|
||||||
import warnings
|
import warnings
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from functools import partial
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Dict, List, Optional, Union
|
from typing import Any, Callable, Dict, List, Optional, Union
|
||||||
|
|
||||||
@@ -54,7 +55,9 @@ from ..utils import (
|
|||||||
logging,
|
logging,
|
||||||
numpy_to_pil,
|
numpy_to_pil,
|
||||||
)
|
)
|
||||||
|
from ..utils.constants import WORKFLOW_NAME
|
||||||
from ..utils.torch_utils import is_compiled_module
|
from ..utils.torch_utils import is_compiled_module
|
||||||
|
from ..workflow_utils import _NON_CALL_ARGUMENTS, Workflow
|
||||||
|
|
||||||
|
|
||||||
if is_transformers_available():
|
if is_transformers_available():
|
||||||
@@ -64,6 +67,7 @@ if is_transformers_available():
|
|||||||
from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
|
from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
|
||||||
from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
|
from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
|
||||||
|
|
||||||
|
|
||||||
from ..utils import FLAX_WEIGHTS_NAME, ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, PushToHubMixin
|
from ..utils import FLAX_WEIGHTS_NAME, ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, PushToHubMixin
|
||||||
|
|
||||||
|
|
||||||
@@ -2075,3 +2079,117 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|||||||
|
|
||||||
for module in modules:
|
for module in modules:
|
||||||
module.set_attention_slice(slice_size)
|
module.set_attention_slice(slice_size)
|
||||||
|
|
||||||
|
def populate_workflow_from_pipeline(self) -> Dict:
|
||||||
|
r"""Populates the call arguments in a dictionary.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[`Workflow`]: A dictionary containing the details of the pipeline call arguments and (optionally) LoRA
|
||||||
|
checkpoint details.
|
||||||
|
"""
|
||||||
|
# A `Workflow` object is an extended Python dictionary. So, all regular dictionary methods
|
||||||
|
# apply to it.
|
||||||
|
workflow = Workflow()
|
||||||
|
|
||||||
|
signature = inspect.signature(self.__call__)
|
||||||
|
argument_names = [param.name for param in signature.parameters.values()]
|
||||||
|
call_arg_values = inspect.getargvalues(inspect.currentframe().f_back).locals
|
||||||
|
|
||||||
|
# Populate call arguments.
|
||||||
|
call_arguments = {
|
||||||
|
arg: call_arg_values[arg]
|
||||||
|
for arg in argument_names
|
||||||
|
if arg != "return_workflow"
|
||||||
|
and "image" not in arg
|
||||||
|
and not isinstance(call_arg_values[arg], (torch.Tensor, np.ndarray, Callable))
|
||||||
|
}
|
||||||
|
workflow.update(call_arguments)
|
||||||
|
|
||||||
|
# Handle generator device and seed.
|
||||||
|
generator = workflow["generator"]
|
||||||
|
if isinstance(generator, list):
|
||||||
|
for g in generator:
|
||||||
|
if "generator_seed" not in workflow:
|
||||||
|
workflow.update({"generator_seed": [g.initial_seed()]})
|
||||||
|
workflow.update({"generator_device": [str(g.device)]})
|
||||||
|
workflow.update({"generator_state": g.get_state().numpy().tolist()})
|
||||||
|
else:
|
||||||
|
workflow["generator_seed"].append(g.initial_seed())
|
||||||
|
workflow["generator_device"].append(g.device)
|
||||||
|
workflow["generator_state"].append(g.get_state().numpy().tolist())
|
||||||
|
else:
|
||||||
|
workflow.update({"generator_seed": generator.initial_seed()})
|
||||||
|
workflow.update({"generator_device": str(generator.device)})
|
||||||
|
workflow.update({"generator_state": generator.get_state().numpy().tolist()})
|
||||||
|
|
||||||
|
workflow.pop("generator")
|
||||||
|
|
||||||
|
# Handle pipeline-level things.
|
||||||
|
if hasattr(self, "config") and hasattr(self.config, "_name_or_path"):
|
||||||
|
pipeline_config_name_or_path = self.config._name_or_path
|
||||||
|
else:
|
||||||
|
pipeline_config_name_or_path = None
|
||||||
|
workflow["_name_or_path"] = pipeline_config_name_or_path
|
||||||
|
workflow["scheduler_config"] = self.scheduler.config
|
||||||
|
|
||||||
|
return workflow
|
||||||
|
|
||||||
|
def load_workflow(
|
||||||
|
self,
|
||||||
|
workflow_id_or_path: Union[str, dict],
|
||||||
|
filename: Optional[str] = None,
|
||||||
|
):
|
||||||
|
r"""Loads a workflow from the Hub or from a local path. Also patches the pipeline call arguments with values from the
|
||||||
|
workflow.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workflow_id_or_path (`str` or `dict`):
|
||||||
|
Can be either:
|
||||||
|
|
||||||
|
- A string, the workflow id (for example `sayakpaul/sdxl-workflow`) of a workflow hosted on the
|
||||||
|
Hub.
|
||||||
|
- A path to a directory (for example `./my_workflow_directory`) containing the workflow file with
|
||||||
|
[`Workflow.save_workflow`] or [`Workflow.push_to_hub`].
|
||||||
|
- A Python dictionary.
|
||||||
|
|
||||||
|
filename (`str`, *optional*):
|
||||||
|
Optional name of the workflow file to load. Especially useful when working with multiple workflow
|
||||||
|
files.
|
||||||
|
"""
|
||||||
|
filename = filename or WORKFLOW_NAME
|
||||||
|
|
||||||
|
# Load workflow.
|
||||||
|
if not isinstance(workflow_id_or_path, dict):
|
||||||
|
if os.path.isdir(workflow_id_or_path):
|
||||||
|
workflow_filepath = os.path.join(workflow_id_or_path, filename)
|
||||||
|
elif os.path.isfile(workflow_id_or_path):
|
||||||
|
workflow_filepath = workflow_id_or_path
|
||||||
|
else:
|
||||||
|
workflow_filepath = hf_hub_download(repo_id=workflow_id_or_path, filename=filename)
|
||||||
|
workflow = self._dict_from_json_file(workflow_filepath)
|
||||||
|
else:
|
||||||
|
workflow = workflow_id_or_path
|
||||||
|
|
||||||
|
# We make a copy of the original workflow and operate on it.
|
||||||
|
workflow_copy = dict(workflow.items())
|
||||||
|
|
||||||
|
# Handle generator.
|
||||||
|
seed = workflow_copy.pop("generator_seed")
|
||||||
|
device = workflow_copy.pop("generator_device", "cpu")
|
||||||
|
last_known_state = workflow_copy.pop("generator_state")
|
||||||
|
if isinstance(seed, list):
|
||||||
|
generator = [
|
||||||
|
torch.Generator(device=d).manual_seed(s).set_state(torch.from_numpy(np.array(lst)).byte())
|
||||||
|
for s, d, lst in zip(seed, device, last_known_state)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
last_known_state = torch.from_numpy(np.array(last_known_state)).byte()
|
||||||
|
generator = torch.Generator(device=device).manual_seed(seed).set_state(last_known_state)
|
||||||
|
workflow_copy.update({"generator": generator})
|
||||||
|
|
||||||
|
# Handle non-call arguments.
|
||||||
|
final_call_args = {k: v for k, v in workflow_copy.items() if k not in _NON_CALL_ARGUMENTS}
|
||||||
|
|
||||||
|
# Handle the call here.
|
||||||
|
partial_call = partial(self.__call__, **final_call_args)
|
||||||
|
setattr(self.__class__, "__call__", partial_call)
|
||||||
|
|||||||
@@ -19,10 +19,13 @@ class StableDiffusionPipelineOutput(BaseOutput):
|
|||||||
nsfw_content_detected (`List[bool]`)
|
nsfw_content_detected (`List[bool]`)
|
||||||
List indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content or
|
List indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content or
|
||||||
`None` if safety checking could not be performed.
|
`None` if safety checking could not be performed.
|
||||||
|
workflow (`dict`):
|
||||||
|
Dictionary containing pipeline component configurations and call arguments
|
||||||
"""
|
"""
|
||||||
|
|
||||||
images: Union[List[PIL.Image.Image], np.ndarray]
|
images: Union[List[PIL.Image.Image], np.ndarray]
|
||||||
nsfw_content_detected: Optional[List[bool]]
|
nsfw_content_detected: Optional[List[bool]]
|
||||||
|
workflow: Optional[dict] = None
|
||||||
|
|
||||||
|
|
||||||
if is_flax_available():
|
if is_flax_available():
|
||||||
|
|||||||
@@ -623,6 +623,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
|
|||||||
return_dict: bool = True,
|
return_dict: bool = True,
|
||||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
guidance_rescale: float = 0.0,
|
guidance_rescale: float = 0.0,
|
||||||
|
return_workflow: bool = False,
|
||||||
clip_skip: Optional[int] = None,
|
clip_skip: Optional[int] = None,
|
||||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||||
@@ -677,6 +678,8 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
|
|||||||
Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
|
Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
|
||||||
Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
|
Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
|
||||||
using zero terminal SNR.
|
using zero terminal SNR.
|
||||||
|
return_workflow (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to return used pipeline call arguments.
|
||||||
clip_skip (`int`, *optional*):
|
clip_skip (`int`, *optional*):
|
||||||
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
||||||
the output of the pre-final layer will be used for computing the prompt embeddings.
|
the output of the pre-final layer will be used for computing the prompt embeddings.
|
||||||
@@ -699,6 +702,13 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
|
|||||||
second element is a list of `bool`s indicating whether the corresponding generated image contains
|
second element is a list of `bool`s indicating whether the corresponding generated image contains
|
||||||
"not-safe-for-work" (nsfw) content.
|
"not-safe-for-work" (nsfw) content.
|
||||||
"""
|
"""
|
||||||
|
# We do this first to capture the "True" call values. If we do this at a later point in time,
|
||||||
|
# we cannot ensure that the call values weren't changed during the process.
|
||||||
|
workflow = None
|
||||||
|
if return_workflow:
|
||||||
|
if generator is None:
|
||||||
|
raise ValueError(f"`generator` cannot be None when `return_workflow` is {return_workflow}.")
|
||||||
|
workflow = self.populate_workflow_from_pipeline()
|
||||||
|
|
||||||
callback = kwargs.pop("callback", None)
|
callback = kwargs.pop("callback", None)
|
||||||
callback_steps = kwargs.pop("callback_steps", None)
|
callback_steps = kwargs.pop("callback_steps", None)
|
||||||
@@ -855,6 +865,11 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
|
|||||||
self.maybe_free_model_hooks()
|
self.maybe_free_model_hooks()
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return (image, has_nsfw_concept)
|
outputs = (image, has_nsfw_concept)
|
||||||
|
|
||||||
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
if return_workflow:
|
||||||
|
outputs += (workflow,)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept, workflow=workflow)
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
import importlib
|
import importlib
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE, hf_cache_home
|
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE, hf_cache_home
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
@@ -32,11 +33,13 @@ FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack"
|
|||||||
ONNX_WEIGHTS_NAME = "model.onnx"
|
ONNX_WEIGHTS_NAME = "model.onnx"
|
||||||
SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors"
|
SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors"
|
||||||
ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb"
|
ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb"
|
||||||
|
WORKFLOW_NAME = "diffusion_workflow.json"
|
||||||
HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co")
|
HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co")
|
||||||
DIFFUSERS_CACHE = default_cache_path
|
DIFFUSERS_CACHE = default_cache_path
|
||||||
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
|
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
|
||||||
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
|
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
|
||||||
DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
|
DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
|
||||||
|
MAX_SEED = np.iinfo(np.int32).max
|
||||||
|
|
||||||
# Below should be `True` if the current version of `peft` and `transformers` are compatible with
|
# Below should be `True` if the current version of `peft` and `transformers` are compatible with
|
||||||
# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
|
# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
|
||||||
|
|||||||
161
src/diffusers/workflow_utils.py
Normal file
161
src/diffusers/workflow_utils.py
Normal file
@@ -0,0 +1,161 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Module for managing workflows."""
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from pathlib import PosixPath
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from huggingface_hub import create_repo
|
||||||
|
|
||||||
|
from . import __version__
|
||||||
|
from .utils import PushToHubMixin, logging
|
||||||
|
from .utils.constants import WORKFLOW_NAME
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
_NON_CALL_ARGUMENTS = {"_name_or_path", "scheduler_config", "_class_name", "_diffusers_version"}
|
||||||
|
|
||||||
|
|
||||||
|
class Workflow(dict, PushToHubMixin):
|
||||||
|
"""Class sub-classing from native Python dict to have support for interacting with the Hub."""
|
||||||
|
|
||||||
|
config_name = None
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.config_name = WORKFLOW_NAME
|
||||||
|
self._internal_dict = {}
|
||||||
|
|
||||||
|
def __setitem__(self, __key, __value):
|
||||||
|
self._internal_dict[__key] = __value
|
||||||
|
return super().__setitem__(__key, __value)
|
||||||
|
|
||||||
|
def update(self, __m, **kwargs):
|
||||||
|
self._internal_dict.update(__m, **kwargs)
|
||||||
|
super().update(__m, **kwargs)
|
||||||
|
|
||||||
|
def pop(self, key, *args):
|
||||||
|
self._internal_dict.pop(key, *args)
|
||||||
|
return super().pop(key, *args)
|
||||||
|
|
||||||
|
# Copied from diffusers.configuration_utils.ConfigMixin.to_json_string
|
||||||
|
def to_json_string(self) -> str:
|
||||||
|
"""
|
||||||
|
Serializes the configuration instance to a JSON string.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`str`:
|
||||||
|
String containing all the attributes that make up the configuration instance in JSON format.
|
||||||
|
"""
|
||||||
|
config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
|
||||||
|
config_dict["_class_name"] = self.__class__.__name__
|
||||||
|
config_dict["_diffusers_version"] = __version__
|
||||||
|
|
||||||
|
def to_json_saveable(value):
|
||||||
|
if isinstance(value, np.ndarray):
|
||||||
|
value = value.tolist()
|
||||||
|
elif isinstance(value, PosixPath):
|
||||||
|
value = str(value)
|
||||||
|
return value
|
||||||
|
|
||||||
|
config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
|
||||||
|
# Don't save "_ignore_files" or "_use_default_values"
|
||||||
|
config_dict.pop("_ignore_files", None)
|
||||||
|
config_dict.pop("_use_default_values", None)
|
||||||
|
|
||||||
|
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
|
||||||
|
|
||||||
|
def save_workflow(
|
||||||
|
self,
|
||||||
|
save_directory: Union[str, os.PathLike],
|
||||||
|
push_to_hub: bool = False,
|
||||||
|
filename: str = WORKFLOW_NAME,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Saves a workflow to a directory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
save_directory (`str` or `os.PathLike`):
|
||||||
|
Directory where the workflow JSON file will be saved (will be created if it does not exist).
|
||||||
|
push_to_hub (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
|
||||||
|
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
|
||||||
|
namespace).
|
||||||
|
filename (`str`, *optional*, defaults to `workflow.json`):
|
||||||
|
Optional filename to use to serialize the workflow JSON.
|
||||||
|
kwargs (`Dict[str, Any]`, *optional*):
|
||||||
|
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
||||||
|
"""
|
||||||
|
self.config_name = filename
|
||||||
|
|
||||||
|
if os.path.isfile(save_directory):
|
||||||
|
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
|
||||||
|
|
||||||
|
os.makedirs(save_directory, exist_ok=True)
|
||||||
|
|
||||||
|
output_config_file = os.path.join(save_directory, self.config_name)
|
||||||
|
with open(output_config_file, "w", encoding="utf-8") as writer:
|
||||||
|
writer.write(self.to_json_string())
|
||||||
|
logger.info(f"Configuration saved in {output_config_file}")
|
||||||
|
|
||||||
|
if push_to_hub:
|
||||||
|
commit_message = kwargs.pop("commit_message", None)
|
||||||
|
private = kwargs.pop("private", False)
|
||||||
|
create_pr = kwargs.pop("create_pr", False)
|
||||||
|
token = kwargs.pop("token", None)
|
||||||
|
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
|
||||||
|
repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
|
||||||
|
|
||||||
|
self._upload_folder(
|
||||||
|
save_directory,
|
||||||
|
repo_id,
|
||||||
|
token=token,
|
||||||
|
commit_message=commit_message,
|
||||||
|
create_pr=create_pr,
|
||||||
|
)
|
||||||
|
|
||||||
|
def save_pretrained(
|
||||||
|
self,
|
||||||
|
save_directory: Union[str, os.PathLike],
|
||||||
|
push_to_hub: bool = False,
|
||||||
|
filename: str = WORKFLOW_NAME,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Saves a workflow to a directory. This internally calls [`Workflow.save_workflow`], This method exists to have
|
||||||
|
feature parity with [`PushToHubMixin.push_to_hub`].
|
||||||
|
|
||||||
|
Args:
|
||||||
|
save_directory (`str` or `os.PathLike`):
|
||||||
|
Directory where the workflow JSON file will be saved (will be created if it does not exist).
|
||||||
|
push_to_hub (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
|
||||||
|
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
|
||||||
|
namespace).
|
||||||
|
filename (`str`, *optional*, defaults to `workflow.json`):
|
||||||
|
Optional filename to use to serialize the workflow JSON.
|
||||||
|
kwargs (`Dict[str, Any]`, *optional*):
|
||||||
|
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
||||||
|
"""
|
||||||
|
self.save_workflow(
|
||||||
|
save_directory=save_directory,
|
||||||
|
push_to_hub=push_to_hub,
|
||||||
|
filename=filename,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
171
tests/others/test_workflows.py
Normal file
171
tests/others/test_workflows.py
Normal file
@@ -0,0 +1,171 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2023 HuggingFace Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import json
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from huggingface_hub import delete_repo, hf_hub_download
|
||||||
|
from test_utils import TOKEN, USER, is_staging_test
|
||||||
|
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
||||||
|
|
||||||
|
from diffusers import (
|
||||||
|
AutoencoderKL,
|
||||||
|
DDIMScheduler,
|
||||||
|
StableDiffusionPipeline,
|
||||||
|
UNet2DConditionModel,
|
||||||
|
)
|
||||||
|
from diffusers.utils.constants import WORKFLOW_NAME
|
||||||
|
from diffusers.utils.testing_utils import torch_device
|
||||||
|
from diffusers.workflow_utils import Workflow
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowFastTests(unittest.TestCase):
|
||||||
|
def get_dummy_components(self):
|
||||||
|
torch.manual_seed(0)
|
||||||
|
unet = UNet2DConditionModel(
|
||||||
|
block_out_channels=(4, 8),
|
||||||
|
layers_per_block=1,
|
||||||
|
sample_size=32,
|
||||||
|
in_channels=4,
|
||||||
|
out_channels=4,
|
||||||
|
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||||
|
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||||
|
cross_attention_dim=32,
|
||||||
|
norm_num_groups=2,
|
||||||
|
)
|
||||||
|
scheduler = DDIMScheduler(
|
||||||
|
beta_start=0.00085,
|
||||||
|
beta_end=0.012,
|
||||||
|
beta_schedule="scaled_linear",
|
||||||
|
clip_sample=False,
|
||||||
|
set_alpha_to_one=False,
|
||||||
|
)
|
||||||
|
torch.manual_seed(0)
|
||||||
|
vae = AutoencoderKL(
|
||||||
|
block_out_channels=[4, 8],
|
||||||
|
in_channels=3,
|
||||||
|
out_channels=3,
|
||||||
|
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||||
|
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||||
|
latent_channels=4,
|
||||||
|
norm_num_groups=2,
|
||||||
|
)
|
||||||
|
torch.manual_seed(0)
|
||||||
|
text_encoder_config = CLIPTextConfig(
|
||||||
|
bos_token_id=0,
|
||||||
|
eos_token_id=2,
|
||||||
|
hidden_size=32,
|
||||||
|
intermediate_size=64,
|
||||||
|
layer_norm_eps=1e-05,
|
||||||
|
num_attention_heads=8,
|
||||||
|
num_hidden_layers=3,
|
||||||
|
pad_token_id=1,
|
||||||
|
vocab_size=1000,
|
||||||
|
)
|
||||||
|
text_encoder = CLIPTextModel(text_encoder_config)
|
||||||
|
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||||
|
|
||||||
|
components = {
|
||||||
|
"unet": unet,
|
||||||
|
"scheduler": scheduler,
|
||||||
|
"vae": vae,
|
||||||
|
"text_encoder": text_encoder,
|
||||||
|
"tokenizer": tokenizer,
|
||||||
|
"safety_checker": None,
|
||||||
|
"feature_extractor": None,
|
||||||
|
}
|
||||||
|
return components
|
||||||
|
|
||||||
|
def get_dummy_inputs(self, device, seed=0):
|
||||||
|
if str(device).startswith("mps"):
|
||||||
|
generator = torch.manual_seed(seed)
|
||||||
|
else:
|
||||||
|
generator = torch.Generator(device=device).manual_seed(seed)
|
||||||
|
inputs = {
|
||||||
|
"prompt": "A painting of a squirrel eating a burger",
|
||||||
|
"generator": generator,
|
||||||
|
"num_inference_steps": 2,
|
||||||
|
"guidance_scale": 6.0,
|
||||||
|
"output_type": "np",
|
||||||
|
}
|
||||||
|
return inputs
|
||||||
|
|
||||||
|
def test_workflow_with_stable_diffusion(self):
|
||||||
|
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||||
|
|
||||||
|
components = self.get_dummy_components()
|
||||||
|
sd_pipe = StableDiffusionPipeline(**components)
|
||||||
|
sd_pipe = sd_pipe.to(torch_device)
|
||||||
|
sd_pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
|
inputs = self.get_dummy_inputs(device)
|
||||||
|
output = sd_pipe(**inputs, return_workflow=True)
|
||||||
|
image = output.images
|
||||||
|
image_slice = image[0, -3:, -3:, -1]
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
output.workflow.save_pretrained(tmpdirname)
|
||||||
|
|
||||||
|
components = self.get_dummy_components()
|
||||||
|
sd_pipe = StableDiffusionPipeline(**components)
|
||||||
|
sd_pipe = sd_pipe.to(torch_device)
|
||||||
|
sd_pipe.set_progress_bar_config(disable=None)
|
||||||
|
sd_pipe.load_workflow(tmpdirname)
|
||||||
|
|
||||||
|
inputs = self.get_dummy_inputs(device)
|
||||||
|
output = sd_pipe(**inputs)
|
||||||
|
image = output.images
|
||||||
|
workflow_image_slice = image[0, -3:, -3:, -1]
|
||||||
|
|
||||||
|
self.assertTrue(np.allclose(image_slice, workflow_image_slice))
|
||||||
|
|
||||||
|
|
||||||
|
@is_staging_test
|
||||||
|
class WorkflowPushToHubTester(unittest.TestCase):
|
||||||
|
identifier = uuid.uuid4()
|
||||||
|
repo_id = f"test-workflow-{identifier}"
|
||||||
|
org_repo_id = f"valid_org/{repo_id}-org"
|
||||||
|
|
||||||
|
def compare_workflow_values(self, repo_id: str, actual_workflow: dict):
|
||||||
|
local_path = hf_hub_download(repo_id=repo_id, filename=WORKFLOW_NAME, token=TOKEN)
|
||||||
|
with open(local_path) as f:
|
||||||
|
locally_loaded_workflow = json.load(f)
|
||||||
|
for k in actual_workflow:
|
||||||
|
assert actual_workflow[k] == locally_loaded_workflow[k]
|
||||||
|
|
||||||
|
def test_push_to_hub(self):
|
||||||
|
workflow = Workflow()
|
||||||
|
workflow.update({"prompt": "hey", "num_inference_steps": 25})
|
||||||
|
|
||||||
|
workflow.push_to_hub(self.repo_id, token=TOKEN)
|
||||||
|
self.compare_workflow_values(repo_id=f"{USER}/{self.repo_id}", actual_workflow=workflow)
|
||||||
|
|
||||||
|
# Reset repo
|
||||||
|
delete_repo(token=TOKEN, repo_id=self.repo_id)
|
||||||
|
|
||||||
|
def test_push_to_hub_in_organization(self):
|
||||||
|
workflow = Workflow()
|
||||||
|
workflow.update({"prompt": "hey", "num_inference_steps": 25})
|
||||||
|
|
||||||
|
workflow.push_to_hub(self.org_repo_id, token=TOKEN)
|
||||||
|
self.compare_workflow_values(repo_id=self.org_repo_id, actual_workflow=workflow)
|
||||||
|
|
||||||
|
# Reset repo
|
||||||
|
delete_repo(token=TOKEN, repo_id=self.org_repo_id)
|
||||||
Reference in New Issue
Block a user