mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-06 20:44:33 +08:00
Compare commits
163 Commits
fix-model-
...
feat/workf
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
601c506918 | ||
|
|
f8a5e172cf | ||
|
|
47e9219450 | ||
|
|
04d83c209d | ||
|
|
0d81b2dab4 | ||
|
|
e67ddf8d13 | ||
|
|
cdbbc7d5b7 | ||
|
|
81fb265a08 | ||
|
|
21c8c433a3 | ||
|
|
3479e5311d | ||
|
|
f8cad5dc4a | ||
|
|
d75f8a537c | ||
|
|
4b94889652 | ||
|
|
35a6538343 | ||
|
|
1b65ff770c | ||
|
|
857c65bb56 | ||
|
|
b5752ec4bf | ||
|
|
157405436b | ||
|
|
d69f3079a8 | ||
|
|
407e669fca | ||
|
|
5e00fcc153 | ||
|
|
1dd8cf5abe | ||
|
|
10739166e2 | ||
|
|
c1d8b882ee | ||
|
|
6697144a7d | ||
|
|
6af69f5639 | ||
|
|
c28ea5e6c6 | ||
|
|
cb8902394d | ||
|
|
5c8d5df564 | ||
|
|
474df7a6b1 | ||
|
|
047cede64a | ||
|
|
ce9e17c547 | ||
|
|
abdb4ffbc6 | ||
|
|
87174fc208 | ||
|
|
425e75bc79 | ||
|
|
f14cf8f4aa | ||
|
|
2f7725edf4 | ||
|
|
a11ce647de | ||
|
|
6f65f3ad3e | ||
|
|
cbc835b870 | ||
|
|
8550a86a17 | ||
|
|
477cc9a82a | ||
|
|
28c8e93179 | ||
|
|
3f90e07228 | ||
|
|
67f7757048 | ||
|
|
0e40d6ffd6 | ||
|
|
3eab48f883 | ||
|
|
b3a675288d | ||
|
|
9099f51c5e | ||
|
|
5dcd8c541e | ||
|
|
ac73c86610 | ||
|
|
9836b61fa8 | ||
|
|
1829b9485c | ||
|
|
6e514c2b6a | ||
|
|
b5888b4704 | ||
|
|
003220ba36 | ||
|
|
f221631d3c | ||
|
|
adcaba0a23 | ||
|
|
bd1f78e6cd | ||
|
|
ecfa79b673 | ||
|
|
ab1d58872b | ||
|
|
020b4a4ad7 | ||
|
|
e510c3d4d5 | ||
|
|
f0418e8896 | ||
|
|
c4eebd9c1a | ||
|
|
84de851116 | ||
|
|
eae28adf40 | ||
|
|
450198061e | ||
|
|
be3ff851c6 | ||
|
|
c3792ba3e0 | ||
|
|
7933f2ac18 | ||
|
|
1b7a7c27d3 | ||
|
|
4689d759fd | ||
|
|
d06cc7eb6f | ||
|
|
b5780adf46 | ||
|
|
d07d3f1642 | ||
|
|
e5a69ff497 | ||
|
|
43846e14a1 | ||
|
|
7fa7259ded | ||
|
|
b576a1dc47 | ||
|
|
9fc37d9dc7 | ||
|
|
0a298f55fc | ||
|
|
7382344fed | ||
|
|
93ce75f4e5 | ||
|
|
b42bcb86ea | ||
|
|
47fe2d0d2e | ||
|
|
6a59219e81 | ||
|
|
f8eff79b82 | ||
|
|
d01f2f678a | ||
|
|
ae0a268f8e | ||
|
|
fce889f19b | ||
|
|
01b3f64549 | ||
|
|
53c65e3d19 | ||
|
|
6d48af1a46 | ||
|
|
66d3fd6732 | ||
|
|
0cc943b757 | ||
|
|
41a74f8474 | ||
|
|
9a8b5f7cd8 | ||
|
|
ebf2addb86 | ||
|
|
03bfdff59a | ||
|
|
ff5cd58aa1 | ||
|
|
c1c11a6747 | ||
|
|
49e06fdd2c | ||
|
|
f874578a4d | ||
|
|
c0e1c6348f | ||
|
|
231e8314dd | ||
|
|
21e5bb65e5 | ||
|
|
74e766c0b8 | ||
|
|
452bf4fa05 | ||
|
|
55c47bc751 | ||
|
|
eaae2df25c | ||
|
|
d612b5435b | ||
|
|
1dc9854968 | ||
|
|
18a756f0ad | ||
|
|
4993c8ba63 | ||
|
|
ed9acd6426 | ||
|
|
a69e3d15c1 | ||
|
|
4731f65ed4 | ||
|
|
9adaa1739a | ||
|
|
af282b7f4b | ||
|
|
3d6637b65d | ||
|
|
800b7a0fda | ||
|
|
91c1c1f1f6 | ||
|
|
21d19bbc44 | ||
|
|
9d0bcd48f1 | ||
|
|
9ee8b0a070 | ||
|
|
b5fd337875 | ||
|
|
f08f40bde1 | ||
|
|
c5ff8cd943 | ||
|
|
319456049a | ||
|
|
f6c0878fc6 | ||
|
|
e590b73cc1 | ||
|
|
aa7839c1c7 | ||
|
|
fc609e308f | ||
|
|
7b85bfe3e5 | ||
|
|
eff03fd054 | ||
|
|
73dcc17ff1 | ||
|
|
b149800269 | ||
|
|
2d1cd20afe | ||
|
|
45c5656bad | ||
|
|
2b48d8572d | ||
|
|
e710121a9b | ||
|
|
50769e058b | ||
|
|
0bd97735dc | ||
|
|
930ca765f4 | ||
|
|
807c2ca13f | ||
|
|
ad725977cc | ||
|
|
97ae043f8a | ||
|
|
1ab81a6db4 | ||
|
|
a6a0277713 | ||
|
|
ba0b1e857c | ||
|
|
d8e6f38db4 | ||
|
|
ef94a008d2 | ||
|
|
29d0aa887c | ||
|
|
5f19b66d5a | ||
|
|
96c55d4c5a | ||
|
|
e3611e325b | ||
|
|
d5d31e0ae3 | ||
|
|
a62b77ff6e | ||
|
|
a8a1378987 | ||
|
|
e8e09e48ea | ||
|
|
ac295055ce | ||
|
|
43e4e841f9 |
@@ -19,6 +19,8 @@
|
||||
title: Train a diffusion model
|
||||
- local: tutorials/using_peft_for_inference
|
||||
title: Inference with PEFT
|
||||
- local: tutorials/workflows
|
||||
title: Working with workflows
|
||||
title: Tutorials
|
||||
- sections:
|
||||
- sections:
|
||||
@@ -178,6 +180,8 @@
|
||||
title: Logging
|
||||
- local: api/outputs
|
||||
title: Outputs
|
||||
- local: api/workflows
|
||||
title: Shareable workflows
|
||||
title: Main Classes
|
||||
- sections:
|
||||
- local: api/models/overview
|
||||
|
||||
7
docs/source/en/api/workflows.md
Normal file
7
docs/source/en/api/workflows.md
Normal file
@@ -0,0 +1,7 @@
|
||||
# Shareable workflows
|
||||
|
||||
Workflows provide a simple mechanism to share your 🤗 Diffusers pipeline call arguments and scheduler configuration, making it easier to reproduce results.
|
||||
|
||||
## Workflow
|
||||
|
||||
[[autodoc]] workflow_utils.Workflow
|
||||
333
docs/source/en/tutorials/workflows.md
Normal file
333
docs/source/en/tutorials/workflows.md
Normal file
@@ -0,0 +1,333 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Working with workflows
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
🧪 Workflow is experimental and its APIs can change in the future.
|
||||
|
||||
</Tip>
|
||||
|
||||
Workflows provide a simple mechanism to share your pipeline call arguments and scheduler configuration, making it easier to reproduce results.
|
||||
|
||||
## Serializing a workflow
|
||||
|
||||
A [`Workflow`] object captures all the argument values passed to a pipeline's `__call__()`. Add `return_workflow=True` to the pipeline call to return a `Workflow` object.
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, safety_checker=None
|
||||
).to("cuda")
|
||||
|
||||
outputs = pipeline(
|
||||
"A painting of a horse",
|
||||
num_inference_steps=15,
|
||||
generator=torch.manual_seed(0),
|
||||
return_workflow=True
|
||||
)
|
||||
workflow = outputs.workflow
|
||||
```
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
It's mandatory to specify the `generator` when `return_workflow` is set to True.
|
||||
|
||||
</Tip>
|
||||
|
||||
If you look at this specific workflow, you'll see values like the number of inference steps, guidance scale, and height and width as well as the scheduler details:
|
||||
|
||||
```bash
|
||||
{'prompt': 'A painting of a horse',
|
||||
'height': None,
|
||||
'width': None,
|
||||
'num_inference_steps': 15,
|
||||
'guidance_scale': 7.5,
|
||||
'negative_prompt': None,
|
||||
'eta': 0.0,
|
||||
'latents': None,
|
||||
'prompt_embeds': None,
|
||||
'negative_prompt_embeds': None,
|
||||
'output_type': 'pil',
|
||||
'return_dict': True,
|
||||
'callback': None,
|
||||
'callback_steps': 1,
|
||||
'cross_attention_kwargs': None,
|
||||
'guidance_rescale': 0.0,
|
||||
'clip_skip': None,
|
||||
'generator_seed': 0,
|
||||
'generator_device': device(type='cpu'),
|
||||
'_name_or_path': 'runwayml/stable-diffusion-v1-5',
|
||||
'scheduler_config': FrozenDict([('num_train_timesteps', 1000),
|
||||
('beta_start', 0.00085),
|
||||
('beta_end', 0.012),
|
||||
('beta_schedule', 'scaled_linear'),
|
||||
('trained_betas', None),
|
||||
('skip_prk_steps', True),
|
||||
('set_alpha_to_one', False),
|
||||
('prediction_type', 'epsilon'),
|
||||
('timestep_spacing', 'leading'),
|
||||
('steps_offset', 1),
|
||||
('_use_default_values', ['prediction_type', 'timestep_spacing']),
|
||||
('_class_name', 'PNDMScheduler'),
|
||||
('_diffusers_version', '0.6.0'),
|
||||
('clip_sample', False)])}
|
||||
```
|
||||
|
||||
Once you have generated a workflow object, you can serialize it with [`~Workflow.save_workflow`]:
|
||||
|
||||
```python
|
||||
outputs.workflow.save_workflow("my-simple-workflow-sd")
|
||||
```
|
||||
|
||||
By default, your workflows are saved as `diffusion_workflow.json`, but you can give them a specific name with the `filename` argument:
|
||||
|
||||
```python
|
||||
outputs.workflow.save_workflow("my-simple-workflow-sd", filename="my_workflow.json")
|
||||
```
|
||||
|
||||
You can also set `push_to_hub=True` in [`~Workflow.save_workflow`] to directly push the workflow object to the Hub.
|
||||
|
||||
## Loading a workflow
|
||||
|
||||
You can load a workflow in a pipeline with [`~DiffusionPipeline.load_workflow`]:
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
|
||||
pipeline.load_workflow("sayakpaul/my-simple-workflow-sd")
|
||||
```
|
||||
|
||||
Once the pipeline is loaded with the desired workflow, it's ready to be called:
|
||||
|
||||
```python
|
||||
image = pipeline().images[0]
|
||||
```
|
||||
|
||||
By default, while loading a workflow, the scheduler of the underlying pipeline from the workflow isn't modified but you can change it by adding `load_scheduler=True`:
|
||||
|
||||
```
|
||||
pipeline.load_workflow("sayakpaul/my-simple-workflow-sd", load_scheduler=True)
|
||||
```
|
||||
|
||||
This is particularly useful if you have changed the scheduler after loading a pipeline.
|
||||
|
||||
You can also override the pipeline call arguments. For example, to add a `negative_prompt`:
|
||||
|
||||
```python
|
||||
image = pipeline(negative_prompt="bad quality").images[0]
|
||||
```
|
||||
|
||||
Loading a specific workflow file is possible by specifying the `filename` argument of the [`DiffusionPipeline.load_workflow`] method.
|
||||
|
||||
A workflow doesn't necessarily have to be used with the same pipeline that generated it. You can use it with a different pipeline too:
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
|
||||
pipeline.load_workflow("sayakpaul/my-simple-workflow-sd")
|
||||
image = pipeline().images[0]
|
||||
```
|
||||
|
||||
However, make sure to thoroughly inspect the values you are calling the pipeline with, in this case.
|
||||
|
||||
Loading from a local workflow is also possible:
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
|
||||
pipeline.load_workflow("path_to_local_dir")
|
||||
image = pipeline().images[0]
|
||||
```
|
||||
|
||||
Alternatively, if you want to load a workflow file and populate the pipeline arguments manually:
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
import json
|
||||
import torch
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
|
||||
with open("path_to_workflow_file.json") as f:
|
||||
workflow = json.load(f)
|
||||
|
||||
pipeline.load_workflow(workflow)
|
||||
image = pipeline().images[0]
|
||||
```
|
||||
|
||||
## Unsupported serialization types
|
||||
|
||||
Image-to-image pipelines like [`StableDiffusionControlNetPipeline`] accept one or more images in their `__call__` method. Currently, workflows don't support serializing `__call__` arguments that are of type `PIL.Image.Image` or `List[PIL.Image.Image]`. To make those pipelines work with workflows, you need to pass the images manually.
|
||||
|
||||
Let's say you generated the workflow below:
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
|
||||
from diffusers.utils import load_image
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import cv2
|
||||
from PIL import Image
|
||||
|
||||
# download an image
|
||||
image = load_image(
|
||||
"https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
|
||||
)
|
||||
image = np.array(image)
|
||||
|
||||
# get canny image
|
||||
image = cv2.Canny(image, 100, 200)
|
||||
image = image[:, :, None]
|
||||
image = np.concatenate([image, image, image], axis=2)
|
||||
canny_image = Image.fromarray(image)
|
||||
|
||||
# load control net and stable diffusion v1-5
|
||||
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
|
||||
pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
# generate image
|
||||
generator = torch.manual_seed(0)
|
||||
outputs = pipe(
|
||||
prompt="futuristic-looking office",
|
||||
image=canny_image,
|
||||
num_inference_steps=20,
|
||||
generator=generator,
|
||||
return_workflow=True
|
||||
)
|
||||
workflow = outputs.workflow
|
||||
```
|
||||
|
||||
If you look at the workflow, you'll see the image that was passed to the pipeline isn't included:
|
||||
|
||||
```bash
|
||||
{'prompt': 'futuristic-looking office',
|
||||
'height': None,
|
||||
'width': None,
|
||||
'num_inference_steps': 20,
|
||||
'guidance_scale': 7.5,
|
||||
'negative_prompt': None,
|
||||
'eta': 0.0,
|
||||
'latents': None,
|
||||
'prompt_embeds': None,
|
||||
'negative_prompt_embeds': None,
|
||||
'output_type': 'pil',
|
||||
'return_dict': True,
|
||||
'callback': None,
|
||||
'callback_steps': 1,
|
||||
'cross_attention_kwargs': None,
|
||||
'controlnet_conditioning_scale': 1.0,
|
||||
'guess_mode': False,
|
||||
'control_guidance_start': 0.0,
|
||||
'control_guidance_end': 1.0,
|
||||
'clip_skip': None,
|
||||
'generator_seed': 0,
|
||||
'generator_device': 'cpu',
|
||||
'_name_or_path': 'runwayml/stable-diffusion-v1-5',
|
||||
'scheduler_config': FrozenDict([('num_train_timesteps', 1000),
|
||||
('beta_start', 0.00085),
|
||||
('beta_end', 0.012),
|
||||
('beta_schedule', 'scaled_linear'),
|
||||
('trained_betas', None),
|
||||
('solver_order', 2),
|
||||
('prediction_type', 'epsilon'),
|
||||
('thresholding', False),
|
||||
('dynamic_thresholding_ratio', 0.995),
|
||||
('sample_max_value', 1.0),
|
||||
('predict_x0', True),
|
||||
('solver_type', 'bh2'),
|
||||
('lower_order_final', True),
|
||||
('disable_corrector', []),
|
||||
('solver_p', None),
|
||||
('use_karras_sigmas', False),
|
||||
('timestep_spacing', 'linspace'),
|
||||
('steps_offset', 1),
|
||||
('_use_default_values',
|
||||
['lower_order_final',
|
||||
'sample_max_value',
|
||||
'solver_p',
|
||||
'dynamic_thresholding_ratio',
|
||||
'thresholding',
|
||||
'solver_type',
|
||||
'prediction_type',
|
||||
'predict_x0',
|
||||
'use_karras_sigmas',
|
||||
'disable_corrector',
|
||||
'timestep_spacing',
|
||||
'solver_order']),
|
||||
('skip_prk_steps', True),
|
||||
('set_alpha_to_one', False),
|
||||
('_class_name', 'PNDMScheduler'),
|
||||
('_diffusers_version', '0.6.0'),
|
||||
('clip_sample', False)])}
|
||||
```
|
||||
|
||||
|
||||
Let's serialize the workflow and reload the pipeline to see what happens when you try to use it.
|
||||
|
||||
```python
|
||||
workflow.save_workflow("my-simple-workflow-sd", filename="controlnet_simple.json", push_to_hub=True)
|
||||
```
|
||||
|
||||
Then load the workflow into [`StableDiffusionControlNetPipeline`]:
|
||||
|
||||
```python
|
||||
# load control net and stable diffusion v1-5
|
||||
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
|
||||
pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
pipe.load_workflow("sayakpaul/my-simple-workflow-sd", filename="controlnet_simple.json")
|
||||
```
|
||||
|
||||
If you try to generate an image now, it'll return the following error:
|
||||
|
||||
```bash
|
||||
TypeError: image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is <class 'NoneType'>
|
||||
```
|
||||
|
||||
To resolve the error, manually pass the conditioning image `canny_image`:
|
||||
|
||||
```python
|
||||
image = pipe(image=canny_image).images[0]
|
||||
```
|
||||
|
||||
Other unsupported serialization types include:
|
||||
|
||||
* LoRA checkpoints: any information from LoRA checkpoints that might be loaded into a pipeline isn't serialized. Workflows generated from pipelines loaded with a LoRA checkpoint should be handled cautiously! You should ensure the LoRA checkpoint is loaded into the pipeline first before loading the corresponding workflow.
|
||||
* Call arguments including the following types: `torch.Tensor`, `np.ndarray`, `Callable`, `PIL.Image.Image`, and `List[PIL.Image.Image]`.
|
||||
@@ -752,6 +752,7 @@ class StableDiffusionControlNetPipeline(
|
||||
guess_mode: bool = False,
|
||||
control_guidance_start: Union[float, List[float]] = 0.0,
|
||||
control_guidance_end: Union[float, List[float]] = 1.0,
|
||||
return_workflow: bool = False,
|
||||
clip_skip: Optional[int] = None,
|
||||
):
|
||||
r"""
|
||||
@@ -824,6 +825,8 @@ class StableDiffusionControlNetPipeline(
|
||||
The percentage of total steps at which the ControlNet starts applying.
|
||||
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
|
||||
The percentage of total steps at which the ControlNet stops applying.
|
||||
return_workflow (`bool`, *optional*, defaults to `False`):
|
||||
Whether to return used pipeline call arguments.
|
||||
clip_skip (`int`, *optional*):
|
||||
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
||||
the output of the pre-final layer will be used for computing the prompt embeddings.
|
||||
@@ -837,6 +840,14 @@ class StableDiffusionControlNetPipeline(
|
||||
second element is a list of `bool`s indicating whether the corresponding generated image contains
|
||||
"not-safe-for-work" (nsfw) content.
|
||||
"""
|
||||
# We do this first to capture the "True" call values. If we do this at a later point in time,
|
||||
# we cannot ensure that the call values weren't changed during the process.
|
||||
workflow = None
|
||||
if return_workflow:
|
||||
if generator is None:
|
||||
raise ValueError(f"`generator` cannot be None when `return_workflow` is {return_workflow}.")
|
||||
workflow = self.populate_workflow_from_pipeline()
|
||||
|
||||
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
|
||||
|
||||
# align format for control guidance
|
||||
@@ -1075,6 +1086,11 @@ class StableDiffusionControlNetPipeline(
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image, has_nsfw_concept)
|
||||
outputs = (image, has_nsfw_concept)
|
||||
|
||||
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
||||
if return_workflow:
|
||||
outputs += (workflow,)
|
||||
|
||||
return outputs
|
||||
|
||||
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept, workflow=workflow)
|
||||
|
||||
@@ -22,6 +22,7 @@ import re
|
||||
import sys
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
@@ -54,7 +55,9 @@ from ..utils import (
|
||||
logging,
|
||||
numpy_to_pil,
|
||||
)
|
||||
from ..utils.constants import WORKFLOW_NAME
|
||||
from ..utils.torch_utils import is_compiled_module
|
||||
from ..workflow_utils import _NON_CALL_ARGUMENTS, Workflow
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
@@ -64,6 +67,7 @@ if is_transformers_available():
|
||||
from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
|
||||
from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
|
||||
|
||||
|
||||
from ..utils import FLAX_WEIGHTS_NAME, ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, PushToHubMixin
|
||||
|
||||
|
||||
@@ -2075,3 +2079,117 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
for module in modules:
|
||||
module.set_attention_slice(slice_size)
|
||||
|
||||
def populate_workflow_from_pipeline(self) -> Dict:
    r"""Populates the pipeline call arguments into a [`Workflow`].

    NOTE(review): this must be invoked directly from the pipeline's `__call__` —
    it reads the caller's local variables via `inspect.currentframe().f_back`.
    Calling it from anywhere else silently captures the wrong frame.

    Returns:
        [`Workflow`]: A dictionary-like object containing the serializable pipeline
        call arguments, the generator seed/device/state, and the scheduler
        configuration.
    """
    # A `Workflow` object is an extended Python dictionary. So, all regular
    # dictionary methods apply to it.
    workflow = Workflow()

    signature = inspect.signature(self.__call__)
    argument_names = [param.name for param in signature.parameters.values()]
    call_arg_values = inspect.getargvalues(inspect.currentframe().f_back).locals

    # Populate the serializable call arguments. `return_workflow` itself, image
    # arguments, and non-JSON-serializable values (tensors, arrays, callables)
    # are deliberately excluded.
    call_arguments = {
        arg: call_arg_values[arg]
        for arg in argument_names
        if arg != "return_workflow"
        and "image" not in arg
        and not isinstance(call_arg_values[arg], (torch.Tensor, np.ndarray, Callable))
    }
    workflow.update(call_arguments)

    # Handle generator device, seed, and state. A `torch.Generator` is not
    # JSON-serializable, so we store everything needed to reconstruct it.
    # For a list of generators, each entry is stored as one parallel list so the
    # structure is uniform (the previous implementation produced a flat state
    # list for the first generator and nested lists for the rest, and stored
    # raw `torch.device` objects instead of strings for generators after the
    # first).
    generator = workflow.pop("generator")
    if isinstance(generator, list):
        workflow["generator_seed"] = [g.initial_seed() for g in generator]
        workflow["generator_device"] = [str(g.device) for g in generator]
        workflow["generator_state"] = [g.get_state().numpy().tolist() for g in generator]
    else:
        workflow["generator_seed"] = generator.initial_seed()
        workflow["generator_device"] = str(generator.device)
        workflow["generator_state"] = generator.get_state().numpy().tolist()

    # Handle pipeline-level things.
    if hasattr(self, "config") and hasattr(self.config, "_name_or_path"):
        pipeline_config_name_or_path = self.config._name_or_path
    else:
        pipeline_config_name_or_path = None
    workflow["_name_or_path"] = pipeline_config_name_or_path
    workflow["scheduler_config"] = self.scheduler.config

    return workflow
|
||||
|
||||
def load_workflow(
    self,
    workflow_id_or_path: Union[str, dict],
    filename: Optional[str] = None,
):
    r"""Loads a workflow from the Hub or from a local path. Also patches the pipeline call arguments with values from the
    workflow.

    Args:
        workflow_id_or_path (`str` or `dict`):
            Can be either:

            - A string, the workflow id (for example `sayakpaul/sdxl-workflow`) of a workflow hosted on the
              Hub.
            - A path to a directory (for example `./my_workflow_directory`) containing the workflow file with
              [`Workflow.save_workflow`] or [`Workflow.push_to_hub`].
            - A Python dictionary.

        filename (`str`, *optional*):
            Optional name of the workflow file to load. Especially useful when working with multiple workflow
            files.
    """
    filename = filename or WORKFLOW_NAME

    # Load workflow: local directory, local file, Hub repo, or an in-memory dict.
    if not isinstance(workflow_id_or_path, dict):
        if os.path.isdir(workflow_id_or_path):
            workflow_filepath = os.path.join(workflow_id_or_path, filename)
        elif os.path.isfile(workflow_id_or_path):
            workflow_filepath = workflow_id_or_path
        else:
            workflow_filepath = hf_hub_download(repo_id=workflow_id_or_path, filename=filename)
        workflow = self._dict_from_json_file(workflow_filepath)
    else:
        workflow = workflow_id_or_path

    # We make a copy of the original workflow and operate on it.
    workflow_copy = dict(workflow.items())

    # Reconstruct the generator(s) from the stored seed, device, and RNG state.
    seed = workflow_copy.pop("generator_seed")
    device = workflow_copy.pop("generator_device", "cpu")
    last_known_state = workflow_copy.pop("generator_state")
    if isinstance(seed, list):
        # If `generator_device` was absent, `device` is the scalar string "cpu".
        # Zipping a string with the seed list would iterate its *characters*
        # ("c", "p", "u") as device names — broadcast the scalar instead.
        if not isinstance(device, list):
            device = [device] * len(seed)
        generator = [
            torch.Generator(device=d).manual_seed(s).set_state(torch.from_numpy(np.array(lst)).byte())
            for s, d, lst in zip(seed, device, last_known_state)
        ]
    else:
        last_known_state = torch.from_numpy(np.array(last_known_state)).byte()
        generator = torch.Generator(device=device).manual_seed(seed).set_state(last_known_state)
    workflow_copy.update({"generator": generator})

    # Drop non-call metadata (scheduler config, class name, version, etc.).
    final_call_args = {k: v for k, v in workflow_copy.items() if k not in _NON_CALL_ARGUMENTS}

    # Bind the workflow's arguments as defaults of `__call__`.
    # NOTE(review): this mutates the pipeline *class*, so every live instance of
    # this class (not just `self`) picks up the patched call — confirm intended.
    partial_call = partial(self.__call__, **final_call_args)
    setattr(self.__class__, "__call__", partial_call)
|
||||
|
||||
@@ -19,10 +19,13 @@ class StableDiffusionPipelineOutput(BaseOutput):
|
||||
nsfw_content_detected (`List[bool]`)
|
||||
List indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content or
|
||||
`None` if safety checking could not be performed.
|
||||
workflow (`dict`):
|
||||
Dictionary containing pipeline component configurations and call arguments
|
||||
"""
|
||||
|
||||
images: Union[List[PIL.Image.Image], np.ndarray]
|
||||
nsfw_content_detected: Optional[List[bool]]
|
||||
workflow: Optional[dict] = None
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
|
||||
@@ -623,6 +623,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
|
||||
return_dict: bool = True,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
guidance_rescale: float = 0.0,
|
||||
return_workflow: bool = False,
|
||||
clip_skip: Optional[int] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
@@ -677,6 +678,8 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
|
||||
Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
|
||||
using zero terminal SNR.
|
||||
return_workflow (`bool`, *optional*, defaults to `False`):
|
||||
Whether to return used pipeline call arguments.
|
||||
clip_skip (`int`, *optional*):
|
||||
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
||||
the output of the pre-final layer will be used for computing the prompt embeddings.
|
||||
@@ -699,6 +702,13 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
|
||||
second element is a list of `bool`s indicating whether the corresponding generated image contains
|
||||
"not-safe-for-work" (nsfw) content.
|
||||
"""
|
||||
# We do this first to capture the "True" call values. If we do this at a later point in time,
|
||||
# we cannot ensure that the call values weren't changed during the process.
|
||||
workflow = None
|
||||
if return_workflow:
|
||||
if generator is None:
|
||||
raise ValueError(f"`generator` cannot be None when `return_workflow` is {return_workflow}.")
|
||||
workflow = self.populate_workflow_from_pipeline()
|
||||
|
||||
callback = kwargs.pop("callback", None)
|
||||
callback_steps = kwargs.pop("callback_steps", None)
|
||||
@@ -855,6 +865,11 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image, has_nsfw_concept)
|
||||
outputs = (image, has_nsfw_concept)
|
||||
|
||||
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
||||
if return_workflow:
|
||||
outputs += (workflow,)
|
||||
|
||||
return outputs
|
||||
|
||||
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept, workflow=workflow)
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
import importlib
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE, hf_cache_home
|
||||
from packaging import version
|
||||
|
||||
@@ -32,11 +33,13 @@ FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack"
|
||||
ONNX_WEIGHTS_NAME = "model.onnx"
|
||||
SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors"
|
||||
ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb"
|
||||
WORKFLOW_NAME = "diffusion_workflow.json"
|
||||
HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co")
|
||||
DIFFUSERS_CACHE = default_cache_path
|
||||
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
|
||||
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
|
||||
DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
|
||||
MAX_SEED = np.iinfo(np.int32).max
|
||||
|
||||
# Below should be `True` if the current version of `peft` and `transformers` are compatible with
|
||||
# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
|
||||
|
||||
161
src/diffusers/workflow_utils.py
Normal file
161
src/diffusers/workflow_utils.py
Normal file
@@ -0,0 +1,161 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Module for managing workflows."""
|
||||
import json
|
||||
import os
|
||||
from pathlib import PosixPath
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
from huggingface_hub import create_repo
|
||||
|
||||
from . import __version__
|
||||
from .utils import PushToHubMixin, logging
|
||||
from .utils.constants import WORKFLOW_NAME
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
_NON_CALL_ARGUMENTS = {"_name_or_path", "scheduler_config", "_class_name", "_diffusers_version"}
|
||||
|
||||
|
||||
class Workflow(dict, PushToHubMixin):
|
||||
"""Class sub-classing from native Python dict to have support for interacting with the Hub."""
|
||||
|
||||
config_name = None
|
||||
|
||||
def __init__(self, **kwargs):
    """Initializes the workflow, optionally seeded with keyword arguments.

    Any `kwargs` are mirrored into `_internal_dict` so that values supplied at
    construction time are included when the workflow is serialized (previously
    only entries added via `__setitem__`/`update` were serialized).
    """
    super().__init__(**kwargs)
    self.config_name = WORKFLOW_NAME
    # `_internal_dict` backs JSON serialization; keep it in sync with the
    # dict contents from the very start.
    self._internal_dict = dict(kwargs)
|
||||
|
||||
def __setitem__(self, __key, __value):
    """Assigns `__value` to `__key`, mirroring the entry into the internal serializable dict."""
    result = super().__setitem__(__key, __value)
    self._internal_dict[__key] = __value
    return result
|
||||
|
||||
def update(self, *args, **kwargs):
    """Updates the workflow like `dict.update`, mirroring into the internal serializable dict.

    Accepts the full `dict.update` calling conventions — a mapping/iterable
    positional argument, keyword arguments, or nothing at all (the previous
    signature made the positional argument mandatory, so `w.update(a=1)` and
    `w.update()` raised `TypeError` unlike a plain `dict`).
    """
    self._internal_dict.update(*args, **kwargs)
    super().update(*args, **kwargs)
|
||||
|
||||
def pop(self, key, *args):
    """Removes `key` and returns its value, keeping the internal serializable dict in sync.

    An optional default may be supplied via `*args`, mirroring `dict.pop`.
    """
    # Drop the mirror entry first; with no default this raises KeyError for a
    # missing key, exactly like `dict.pop`.
    self._internal_dict.pop(key, *args)
    return super().pop(key, *args)
|
||||
|
||||
def to_json_string(self) -> str:
    """
    Serializes the workflow instance to a JSON string.

    Returns:
        `str`:
            String containing all the attributes that make up the configuration instance in JSON format.
    """
    # Work on a *copy*: the previous implementation aliased `_internal_dict`
    # directly, so adding the metadata keys below permanently injected
    # `_class_name` / `_diffusers_version` into the live instance as a side
    # effect of serialization.
    config_dict = dict(self._internal_dict) if hasattr(self, "_internal_dict") else {}
    config_dict["_class_name"] = self.__class__.__name__
    config_dict["_diffusers_version"] = __version__

    def to_json_saveable(value):
        # Convert values that `json` can't handle natively.
        if isinstance(value, np.ndarray):
            value = value.tolist()
        elif isinstance(value, PosixPath):
            value = str(value)
        return value

    config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
    # Don't save "_ignore_files" or "_use_default_values"
    config_dict.pop("_ignore_files", None)
    config_dict.pop("_use_default_values", None)

    return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
|
||||
|
||||
def save_workflow(
|
||||
self,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
push_to_hub: bool = False,
|
||||
filename: str = WORKFLOW_NAME,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Saves a workflow to a directory.
|
||||
|
||||
Args:
|
||||
save_directory (`str` or `os.PathLike`):
|
||||
Directory where the workflow JSON file will be saved (will be created if it does not exist).
|
||||
push_to_hub (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
|
||||
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
|
||||
namespace).
|
||||
filename (`str`, *optional*, defaults to `workflow.json`):
|
||||
Optional filename to use to serialize the workflow JSON.
|
||||
kwargs (`Dict[str, Any]`, *optional*):
|
||||
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
||||
"""
|
||||
self.config_name = filename
|
||||
|
||||
if os.path.isfile(save_directory):
|
||||
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
|
||||
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
|
||||
output_config_file = os.path.join(save_directory, self.config_name)
|
||||
with open(output_config_file, "w", encoding="utf-8") as writer:
|
||||
writer.write(self.to_json_string())
|
||||
logger.info(f"Configuration saved in {output_config_file}")
|
||||
|
||||
if push_to_hub:
|
||||
commit_message = kwargs.pop("commit_message", None)
|
||||
private = kwargs.pop("private", False)
|
||||
create_pr = kwargs.pop("create_pr", False)
|
||||
token = kwargs.pop("token", None)
|
||||
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
|
||||
repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
|
||||
|
||||
self._upload_folder(
|
||||
save_directory,
|
||||
repo_id,
|
||||
token=token,
|
||||
commit_message=commit_message,
|
||||
create_pr=create_pr,
|
||||
)
|
||||
|
||||
def save_pretrained(
|
||||
self,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
push_to_hub: bool = False,
|
||||
filename: str = WORKFLOW_NAME,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Saves a workflow to a directory. This internally calls [`Workflow.save_workflow`], This method exists to have
|
||||
feature parity with [`PushToHubMixin.push_to_hub`].
|
||||
|
||||
Args:
|
||||
save_directory (`str` or `os.PathLike`):
|
||||
Directory where the workflow JSON file will be saved (will be created if it does not exist).
|
||||
push_to_hub (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
|
||||
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
|
||||
namespace).
|
||||
filename (`str`, *optional*, defaults to `workflow.json`):
|
||||
Optional filename to use to serialize the workflow JSON.
|
||||
kwargs (`Dict[str, Any]`, *optional*):
|
||||
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
||||
"""
|
||||
self.save_workflow(
|
||||
save_directory=save_directory,
|
||||
push_to_hub=push_to_hub,
|
||||
filename=filename,
|
||||
**kwargs,
|
||||
)
|
||||
171
tests/others/test_workflows.py
Normal file
171
tests/others/test_workflows.py
Normal file
@@ -0,0 +1,171 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import json
|
||||
import tempfile
|
||||
import unittest
|
||||
import uuid
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from huggingface_hub import delete_repo, hf_hub_download
|
||||
from test_utils import TOKEN, USER, is_staging_test
|
||||
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDIMScheduler,
|
||||
StableDiffusionPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.utils.constants import WORKFLOW_NAME
|
||||
from diffusers.utils.testing_utils import torch_device
|
||||
from diffusers.workflow_utils import Workflow
|
||||
|
||||
|
||||
class WorkflowFastTests(unittest.TestCase):
    """Fast CPU-scale tests checking that a workflow captured from a pipeline call
    can be saved, reloaded into a fresh pipeline, and reproduce the same images."""

    def get_dummy_components(self):
        """Return a minimal set of Stable Diffusion components for fast testing.

        ``torch.manual_seed(0)`` is reset before each randomly initialized module
        so the weights are deterministic across calls.
        """
        torch.manual_seed(0)
        unet = UNet2DConditionModel(
            block_out_channels=(4, 8),
            layers_per_block=1,
            sample_size=32,
            in_channels=4,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=32,
            norm_num_groups=2,
        )
        scheduler = DDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
        )
        torch.manual_seed(0)
        vae = AutoencoderKL(
            block_out_channels=[4, 8],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
            norm_num_groups=2,
        )
        torch.manual_seed(0)
        text_encoder_config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=64,
            layer_norm_eps=1e-05,
            num_attention_heads=8,
            num_hidden_layers=3,
            pad_token_id=1,
            vocab_size=1000,
        )
        text_encoder = CLIPTextModel(text_encoder_config)
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        components = {
            "unet": unet,
            "scheduler": scheduler,
            "vae": vae,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "safety_checker": None,
            "feature_extractor": None,
        }
        return components

    def get_dummy_inputs(self, device, seed: int = 0):
        """Return deterministic pipeline call kwargs for the given device/seed."""
        if str(device).startswith("mps"):
            # MPS does not support device-bound generators; fall back to the
            # global generator seeded deterministically.
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 6.0,
            "output_type": "np",
        }
        return inputs

    def test_workflow_with_stable_diffusion(self):
        """Round-trip: run a pipeline with `return_workflow=True`, save the workflow,
        load it into a freshly built pipeline, and check the outputs match."""
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator

        components = self.get_dummy_components()
        sd_pipe = StableDiffusionPipeline(**components)
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        output = sd_pipe(**inputs, return_workflow=True)
        image = output.images
        # Compare only a corner slice of the last channel — sufficient for equality checks.
        image_slice = image[0, -3:, -3:, -1]

        with tempfile.TemporaryDirectory() as tmpdirname:
            output.workflow.save_pretrained(tmpdirname)

            # Rebuild the pipeline from scratch so no state carries over, then
            # apply the saved workflow.
            components = self.get_dummy_components()
            sd_pipe = StableDiffusionPipeline(**components)
            sd_pipe = sd_pipe.to(torch_device)
            sd_pipe.set_progress_bar_config(disable=None)
            sd_pipe.load_workflow(tmpdirname)

            inputs = self.get_dummy_inputs(device)
            output = sd_pipe(**inputs)
        image = output.images
        workflow_image_slice = image[0, -3:, -3:, -1]

        # The workflow-driven run must reproduce the original generation.
        self.assertTrue(np.allclose(image_slice, workflow_image_slice))
|
||||
|
||||
|
||||
@is_staging_test
class WorkflowPushToHubTester(unittest.TestCase):
    """Staging-only tests that round-trip a `Workflow` through the Hugging Face Hub."""

    # Unique suffix so concurrent CI runs do not collide on repo names.
    identifier = uuid.uuid4()
    repo_id = f"test-workflow-{identifier}"
    org_repo_id = f"valid_org/{repo_id}-org"

    def compare_workflow_values(self, repo_id: str, actual_workflow: dict):
        """Download the serialized workflow from `repo_id` and verify that every
        entry of `actual_workflow` survived the round trip unchanged."""
        downloaded_path = hf_hub_download(repo_id=repo_id, filename=WORKFLOW_NAME, token=TOKEN)
        with open(downloaded_path) as fp:
            remote_workflow = json.load(fp)
        for key, value in actual_workflow.items():
            assert remote_workflow[key] == value

    def test_push_to_hub(self):
        workflow = Workflow()
        workflow.update({"prompt": "hey", "num_inference_steps": 25})

        workflow.push_to_hub(self.repo_id, token=TOKEN)
        self.compare_workflow_values(repo_id=f"{USER}/{self.repo_id}", actual_workflow=workflow)

        # Reset repo
        delete_repo(token=TOKEN, repo_id=self.repo_id)

    def test_push_to_hub_in_organization(self):
        workflow = Workflow()
        workflow.update({"prompt": "hey", "num_inference_steps": 25})

        workflow.push_to_hub(self.org_repo_id, token=TOKEN)
        self.compare_workflow_values(repo_id=self.org_repo_id, actual_workflow=workflow)

        # Reset repo
        delete_repo(token=TOKEN, repo_id=self.org_repo_id)
|
||||
Reference in New Issue
Block a user