Compare commits


1 Commit

| Author | SHA1 | Message | Date |
| --- | --- | --- | --- |
| Pedro Cuenca | 53352312e8 | Use negative prompts in JAX docs. | 2023-07-24 11:52:24 +02:00 |


@@ -2,9 +2,9 @@
[[open-in-colab]]
🤗 Hugging Face [Diffusers](https://github.com/huggingface/diffusers) supports Flax since version `0.5.1`! This allows for super fast inference on Google TPUs, such as those available in Colab, Kaggle or Google Cloud Platform.
🤗 Hugging Face [Diffusers](https://github.com/huggingface/diffusers) supports Flax! This allows for super fast inference on Google TPUs, such as those available in Colab, Kaggle or Google Cloud Platform.
This notebook shows how to run inference using JAX / Flax. If you want more details about how Stable Diffusion works or want to run it on a GPU, please refer to [this notebook](https://huggingface.co/docs/diffusers/stable_diffusion).
This notebook shows how to run inference using JAX / Flax. If you want more details about how Stable Diffusion works or want to run it on a GPU, please refer to [the documentation](https://huggingface.co/docs/diffusers/stable_diffusion).
First, make sure you are using a TPU backend. If you are running this notebook in Colab, select `Runtime` in the menu above, then select the option "Change runtime type" and then select `TPU` under the `Hardware accelerator` setting.
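Once the libraries in the next cell are installed, a quick way to confirm that JAX actually sees the TPU (a minimal sketch, not part of the original notebook):
```python
import jax

# On a Colab TPU runtime this should report 8 devices.
print(jax.device_count())
print(jax.devices())
```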
@@ -16,10 +16,12 @@ First make sure diffusers is installed.
```py
# uncomment to install the necessary libraries in Colab
#!pip install jax==0.3.25 jaxlib==0.3.25 flax transformers ftfy
#!pip install jax jaxlib flax transformers ftfy
#!pip install diffusers
```
When running in Colab, run the following cell to set up the TPU environment.
```python
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
```
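For reference, a Flax Stable Diffusion pipeline and its parameters are typically loaded along these lines (the checkpoint name, `revision`, and `dtype` below are assumptions for illustration):
```python
import jax.numpy as jnp
from diffusers import FlaxStableDiffusionPipeline

# Assumed checkpoint and precision; adjust to match the notebook you are following.
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="bf16",
    dtype=jnp.bfloat16,
)
```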
@@ -93,6 +95,14 @@ prompt_ids.shape
(8, 77)
```
We can also use negative prompts to specify concepts we want to steer away from during generation.
```python
neg_prompt = "ugly, low-res, malformed, oversaturated"
neg_prompt = [neg_prompt] * jax.device_count()
neg_prompt_ids = pipeline.prepare_inputs(neg_prompt)
```
### Replication and parallelization
Model parameters and inputs have to be replicated across the 8 parallel devices we have. The parameters dictionary is replicated using `flax.jax_utils.replicate`, which traverses the dictionary and changes the shape of the weights so they are repeated 8 times. Arrays are replicated using `shard`.
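To make those shape changes concrete, here is a minimal, self-contained sketch with dummy objects (`dummy_params` and `dummy_batch` are placeholders, assuming an 8-device TPU):
```python
import numpy as np
import jax
from flax.jax_utils import replicate
from flax.training.common_utils import shard

# Placeholder parameter tree and input batch, standing in for the real pipeline objects.
dummy_params = {"w": np.ones((320, 4), dtype=np.float32)}
dummy_batch = np.zeros((jax.device_count(), 77), dtype=np.int32)

replicated = replicate(dummy_params)  # each leaf gains a leading device axis: (8, 320, 4)
sharded = shard(dummy_batch)          # (8, 77) -> (8, 1, 77): one slice per device

print(jax.tree_util.tree_map(lambda x: x.shape, replicated))
print(sharded.shape)
```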
@@ -103,6 +113,7 @@ p_params = replicate(params)
```python
prompt_ids = shard(prompt_ids)
neg_prompt_ids = shard(neg_prompt_ids)
prompt_ids.shape
```
@@ -136,7 +147,7 @@ The first time we run the following cell it will take a long time to compile, bu
```
%%time
images = pipeline(prompt_ids, p_params, rng, jit=True)[0]
images = pipeline(prompt_ids, p_params, rng, neg_prompt_ids=neg_prompt_ids, jit=True)[0]
```
```python out
@@ -202,50 +213,3 @@ image_grid(images, 2, 4)
```
![img](https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/stable_diffusion_jax_how_to_cell_43_output_0.jpeg)
## How does parallelization work?
We said before that the `diffusers` Flax pipeline automatically compiles the model and runs it in parallel on all available devices. We'll now briefly look inside that process to show how it works.
JAX parallelization can be done in multiple ways. The easiest one revolves around using the `jax.pmap` function to achieve single-program, multiple-data (SPMD) parallelization: we run several copies of the same code, each on different data inputs. More sophisticated approaches are possible; if you are interested, we invite you to go over the [JAX documentation](https://jax.readthedocs.io/en/latest/index.html) and the [`pjit` pages](https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html?highlight=pjit) to explore this topic!
`jax.pmap` does two things for us:
- Compiles (or `jit`s) the code, as if we had invoked `jax.jit()`. Compilation does not happen when we call `pmap`, but only the first time the pmapped function is invoked.
- Ensures the compiled code runs in parallel on all the available devices, as the small standalone sketch after this list illustrates.
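Here is a minimal, standalone illustration of the SPMD pattern (independent of the pipeline; `scale` is just a toy function):
```python
import jax
import jax.numpy as jnp

# The same function runs on every device, each working on its own slice of the leading axis.
@jax.pmap
def scale(x):
    return x * 2.0

n = jax.device_count()
data = jnp.arange(n * 3, dtype=jnp.float32).reshape(n, 3)
print(scale(data).shape)  # (n, 3): one row of results per device
```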
To show how it works, we `pmap` the `_generate` method of the pipeline, which is the private method that generates images. Please note that this method may be renamed or removed in future releases of `diffusers`.
```python
p_generate = pmap(pipeline._generate)
```
After we use `pmap`, the prepared function `p_generate` will conceptually do the following:
* Invoke a copy of the underlying function `pipeline._generate` in each device.
* Send each device a different portion of the input arguments. That's what sharding is used for. In our case, `prompt_ids` has shape `(8, 1, 77, 768)`. This array will be split into 8 pieces, and each copy of `_generate` will receive an input with shape `(1, 77, 768)`.
We can write `_generate` while completely ignoring the fact that it will be invoked in parallel. We just care about our batch size (`1` in this example) and the dimensions that make sense for our code, and we don't have to change anything to make it work in parallel.
Just as with the pipeline call, the first time we run the following cell it will take a while, but subsequent runs will be much faster.
```
%%time
images = p_generate(prompt_ids, p_params, rng)
images = images.block_until_ready()
images.shape
```
```python out
CPU times: user 1min 15s, sys: 18.2 s, total: 1min 34s
Wall time: 1min 15s
```
```python
images.shape
```
```python out
(8, 1, 512, 512, 3)
```
We use `block_until_ready()` to correctly measure inference time, because JAX uses asynchronous dispatch and returns control to the Python loop as soon as it can. You don't need to use that in your code; blocking will occur automatically when you want to use the result of a computation that has not yet been materialized.
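A small sketch of that behavior, using an unrelated matrix multiplication (the sizes are arbitrary):
```python
import time
import jax.numpy as jnp

x = jnp.ones((4096, 4096))

start = time.perf_counter()
y = jnp.dot(x, x)                 # returns quickly: the work is dispatched asynchronously
print("dispatch:", time.perf_counter() - start)

start = time.perf_counter()
y.block_until_ready()             # waits until the result has actually been computed
print("compute :", time.perf_counter() - start)
```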