Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-09 22:14:43 +08:00)

Compare commits: `fix/lora-l`...`docs-flax-` (1 commit)

Commit `53352312e8`
@@ -2,9 +2,9 @@
[[open-in-colab]]

-🤗 Hugging Face [Diffusers](https://github.com/huggingface/diffusers) supports Flax since version `0.5.1`! This allows for super fast inference on Google TPUs, such as those available in Colab, Kaggle, or Google Cloud Platform.
+🤗 Hugging Face [Diffusers](https://github.com/huggingface/diffusers) supports Flax! This allows for super fast inference on Google TPUs, such as those available in Colab, Kaggle, or Google Cloud Platform.

-This notebook shows how to run inference using JAX / Flax. If you want more details about how Stable Diffusion works or want to run it on a GPU, please refer to [this notebook](https://huggingface.co/docs/diffusers/stable_diffusion).
+This notebook shows how to run inference using JAX / Flax. If you want more details about how Stable Diffusion works or want to run it on a GPU, please refer to [the documentation](https://huggingface.co/docs/diffusers/stable_diffusion).

First, make sure you are using a TPU backend. If you are running this notebook in Colab, select `Runtime` in the menu above, then select the option "Change runtime type", and then select `TPU` under the `Hardware accelerator` setting.
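Before moving on, it helps to confirm that JAX actually sees the accelerator; `jax.devices()` lists the available devices, and a Colab TPU v2/v3 runtime should report 8 cores:

```python
import jax

print(jax.devices())       # expect 8 TpuDevice entries on a TPU v2/v3 runtime
print(jax.device_count())  # 8
```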
@@ -16,10 +16,12 @@ First make sure diffusers is installed.

```py
# uncomment to install the necessary libraries in Colab
-#!pip install jax==0.3.25 jaxlib==0.3.25 flax transformers ftfy
+#!pip install jax jaxlib flax transformers ftfy
#!pip install diffusers
```

When running in Colab, run the following cell to set up the TPU environment.

```python
import jax.tools.colab_tpu
@@ -93,6 +95,14 @@ prompt_ids.shape
(8, 77)
```
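For context, the `(8, 77)` shape above is one row of 77 CLIP token ids per TPU core. A minimal sketch of how such ids are produced, assuming the `pipeline` object created earlier in the notebook (the prompt text is purely illustrative):

```python
import jax

prompt = "a photograph of an astronaut riding a horse"  # illustrative prompt
prompt = [prompt] * jax.device_count()                   # one copy per device

# prepare_inputs tokenizes and pads each prompt to CLIP's 77-token maximum
prompt_ids = pipeline.prepare_inputs(prompt)
prompt_ids.shape  # (8, 77)
```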

+We can also use negative prompts to specify concepts we want to steer away from during generation.
+
+```python
+neg_prompt = "ugly, low-res, malformed, oversaturated"
+neg_prompt = [neg_prompt] * jax.device_count()
+neg_prompt_ids = pipeline.prepare_inputs(neg_prompt)
+```
+
### Replication and parallelization
Model parameters and inputs have to be replicated across the 8 parallel devices. The parameters dictionary is replicated with `flax.jax_utils.replicate`, which traverses the dictionary and changes the shape of the weights so they are repeated 8 times. Input arrays are split across the devices using `shard`, as the sketch below shows.
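A minimal sketch of what the two helpers do to array shapes, assuming an 8-core TPU and toy data:

```python
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard

params = {"w": np.zeros((77, 768))}        # toy parameter tree
p_params = replicate(params)               # each leaf gains a device axis: (8, 77, 768)

batch = np.zeros((8, 77), dtype=np.int32)  # one example per device
sharded = shard(batch)                     # leading axis split across devices: (8, 1, 77)
```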
@@ -103,6 +113,7 @@ p_params = replicate(params)
```python
prompt_ids = shard(prompt_ids)
+neg_prompt_ids = shard(neg_prompt_ids)
prompt_ids.shape
```
@@ -136,7 +147,7 @@ The first time we run the following cell it will take a long time to compile, bu

```
%%time
-images = pipeline(prompt_ids, p_params, rng, jit=True)[0]
+images = pipeline(prompt_ids, p_params, rng, neg_prompt_ids=neg_prompt_ids, jit=True)[0]
```
```python out
@@ -202,50 +213,3 @@ image_grid(images, 2, 4)
```

## How does parallelization work?
We said before that the `diffusers` Flax pipeline automatically compiles the model and runs it in parallel on all available devices. We'll now briefly look inside that process to show how it works.

JAX parallelization can be done in multiple ways. The easiest one revolves around using the `jax.pmap` function to achieve single-program, multiple-data (SPMD) parallelization: we run several copies of the same code, each on different data inputs. More sophisticated approaches are possible; if you are interested, we invite you to explore the [JAX documentation](https://jax.readthedocs.io/en/latest/index.html) and the [`pjit` pages](https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html?highlight=pjit).

`jax.pmap` does two things for us:
- Compiles (or `jit`s) the code, as if we had invoked `jax.jit()`. This does not happen when we call `pmap`, but only the first time the pmapped function is invoked.
- Ensures the compiled code runs in parallel on all the available devices (see the standalone sketch after this list).
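Here is a minimal, self-contained illustration of `jax.pmap`, separate from the pipeline; the function and data are toy placeholders:

```python
import jax
import jax.numpy as jnp

def square(x):
    # ordinary single-device code; pmap takes care of the parallelism
    return x ** 2

# the leading axis of the input must match the number of devices
xs = jnp.arange(jax.device_count())
ys = jax.pmap(square)(xs)  # compiled on first call, then runs on all devices at once
print(ys)                  # [ 0  1  4  9 16 25 36 49] on an 8-core TPU
```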

To show how it works, we `pmap` the `_generate` method of the pipeline, which is the private method that generates images. Please note that this method may be renamed or removed in future releases of `diffusers`.

```python
p_generate = pmap(pipeline._generate)
```
After we use `pmap`, the prepared function `p_generate` will conceptually do the following:
* Invoke a copy of the underlying function `pipeline._generate` on each device.
* Send each device a different portion of the input arguments. That's what sharding is used for. In our case, `prompt_ids` has shape `(8, 1, 77, 768)`. This array will be split into `8`, and each copy of `_generate` will receive an input with shape `(1, 77, 768)`.
We can code `_generate` completely ignoring the fact that it will be invoked in parallel. We just care about our batch size (`1` in this example) and the dimensions that make sense for our code, and don't have to change anything to make it work in parallel.
Just as with the pipeline call, the first time we run the following cell it will take a while to compile, but subsequent runs will be much faster.
```
%%time
images = p_generate(prompt_ids, p_params, rng)
images = images.block_until_ready()
images.shape
```
```python out
CPU times: user 1min 15s, sys: 18.2 s, total: 1min 34s
Wall time: 1min 15s
```
```python
images.shape
```
```python out
(8, 1, 512, 512, 3)
```
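To turn these arrays into viewable images, the device and batch axes can be folded together; a brief sketch, assuming the pipeline's `numpy_to_pil` helper:

```python
import numpy as np

# (8, 1, 512, 512, 3) -> (8, 512, 512, 3): merge the device and batch axes
flat = np.asarray(images).reshape((-1,) + images.shape[-3:])
pil_images = pipeline.numpy_to_pil(flat)  # a list of 8 PIL images
```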
We use `block_until_ready()` to correctly measure inference time because JAX uses asynchronous dispatch and returns control to the Python loop as soon as it can. You don't need to call it in your own code; blocking occurs automatically when you use the result of a computation that has not yet been materialized.
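A small self-contained demonstration of why the explicit block matters when timing; the matrix size is arbitrary:

```python
import time
import jax.numpy as jnp

x = jnp.ones((4096, 4096))

start = time.perf_counter()
y = x @ x                       # returns almost immediately: dispatch is asynchronous
print(f"dispatch: {time.perf_counter() - start:.4f}s")

start = time.perf_counter()
y.block_until_ready()           # waits until the product is actually computed
print(f"compute:  {time.perf_counter() - start:.4f}s")
```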