mirror of
https://github.com/huggingface/diffusers.git
synced 2026-02-03 09:25:06 +08:00
Compare commits
38 Commits
qwen-test-
...
transforme
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c3249d7b2e | ||
|
|
85682000a9 | ||
|
|
5fefef9bc9 | ||
|
|
ea815e5bb0 | ||
|
|
7eb51e932f | ||
|
|
079e0e31b7 | ||
|
|
f9bdc09534 | ||
|
|
2bee621229 | ||
|
|
7a0739ccd3 | ||
|
|
b4b707e585 | ||
|
|
fefd0f4e45 | ||
|
|
6e8e7bad9e | ||
|
|
0eaa35fdca | ||
|
|
4dff31871c | ||
|
|
515dd06db5 | ||
|
|
5274ffdd7f | ||
|
|
a21a6ac565 | ||
|
|
c2d8273891 | ||
|
|
e1249d2640 | ||
|
|
2fe9f9868d | ||
|
|
387befd6de | ||
|
|
351316328f | ||
|
|
62bf2b0ab9 | ||
|
|
7f2cd5b6fc | ||
|
|
4ea43ee6ab | ||
|
|
084c959bdf | ||
|
|
3dcb97c9ea | ||
|
|
7b55da8846 | ||
|
|
cec020988b | ||
|
|
926db24add | ||
|
|
37cfceef0d | ||
|
|
ea90a74ed4 | ||
|
|
96f08043a3 | ||
|
|
d0f279ce76 | ||
|
|
c5e023fbe6 | ||
|
|
f8e50fab75 | ||
|
|
c152b1831c | ||
|
|
039324ae16 |
27
.github/workflows/pr_tests.yml
vendored
27
.github/workflows/pr_tests.yml
vendored
@@ -92,8 +92,9 @@ jobs:
|
||||
runner: aws-general-8-plus
|
||||
image: diffusers/diffusers-pytorch-cpu
|
||||
report: torch_example_cpu
|
||||
transformers_version: ["main"]
|
||||
|
||||
name: ${{ matrix.config.name }}
|
||||
name: ${{ matrix.config.name }} (transformers ${{ matrix.transformers_version }})
|
||||
|
||||
runs-on:
|
||||
group: ${{ matrix.config.runner }}
|
||||
@@ -115,8 +116,11 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv pip install -e ".[quality]"
|
||||
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
if [ "${{ matrix.transformers_version }}" = "main" ]; then
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
else
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==${{ matrix.transformers_version }}
|
||||
fi
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
|
||||
|
||||
- name: Environment
|
||||
@@ -155,7 +159,7 @@ jobs:
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: pr_${{ matrix.config.framework }}_${{ matrix.config.report }}_test_reports
|
||||
name: pr_${{ matrix.config.framework }}_${{ matrix.config.report }}_transformers_${{ matrix.transformers_version }}_test_reports
|
||||
path: reports
|
||||
|
||||
run_staging_tests:
|
||||
@@ -220,8 +224,10 @@ jobs:
|
||||
needs: [check_code_quality, check_repository_consistency]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
transformers_version: ["main"]
|
||||
|
||||
name: LoRA tests with PEFT main
|
||||
name: LoRA tests with PEFT main (transformers ${{ matrix.transformers_version }})
|
||||
|
||||
runs-on:
|
||||
group: aws-general-8-plus
|
||||
@@ -247,9 +253,12 @@ jobs:
|
||||
uv pip install -U peft@git+https://github.com/huggingface/peft.git --no-deps
|
||||
uv pip install -U tokenizers
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
|
||||
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
|
||||
if [ "${{ matrix.transformers_version }}" = "main" ]; then
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
else
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==${{ matrix.transformers_version }}
|
||||
fi
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
@@ -275,6 +284,6 @@ jobs:
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: pr_main_test_reports
|
||||
name: pr_lora_transformers_${{ matrix.transformers_version }}_test_reports
|
||||
path: reports
|
||||
|
||||
|
||||
42
.github/workflows/pr_tests_gpu.yml
vendored
42
.github/workflows/pr_tests_gpu.yml
vendored
@@ -14,6 +14,7 @@ on:
|
||||
- "tests/pipelines/test_pipelines_common.py"
|
||||
- "tests/models/test_modeling_common.py"
|
||||
- "examples/**/*.py"
|
||||
- ".github/**.yml"
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
@@ -106,13 +107,14 @@ jobs:
|
||||
path: reports
|
||||
|
||||
torch_pipelines_cuda_tests:
|
||||
name: Torch Pipelines CUDA Tests
|
||||
name: Torch Pipelines CUDA Tests (transformers ${{ matrix.transformers_version }})
|
||||
needs: setup_torch_cuda_pipeline_matrix
|
||||
strategy:
|
||||
fail-fast: false
|
||||
max-parallel: 8
|
||||
matrix:
|
||||
module: ${{ fromJson(needs.setup_torch_cuda_pipeline_matrix.outputs.pipeline_test_matrix) }}
|
||||
transformers_version: ["main"]
|
||||
runs-on:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
@@ -131,8 +133,12 @@ jobs:
|
||||
run: |
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
if [ "${{ matrix.transformers_version }}" = "main" ]; then
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
else
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==${{ matrix.transformers_version }}
|
||||
fi
|
||||
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
@@ -172,11 +178,11 @@ jobs:
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: pipeline_${{ matrix.module }}_test_reports
|
||||
name: pipeline_${{ matrix.module }}_transformers_${{ matrix.transformers_version }}_test_reports
|
||||
path: reports
|
||||
|
||||
torch_cuda_tests:
|
||||
name: Torch CUDA Tests
|
||||
name: Torch CUDA Tests (transformers ${{ matrix.transformers_version }})
|
||||
needs: [check_code_quality, check_repository_consistency]
|
||||
runs-on:
|
||||
group: aws-g4dn-2xlarge
|
||||
@@ -191,6 +197,7 @@ jobs:
|
||||
max-parallel: 4
|
||||
matrix:
|
||||
module: [models, schedulers, lora, others]
|
||||
transformers_version: ["main"]
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v6
|
||||
@@ -202,8 +209,12 @@ jobs:
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip install peft@git+https://github.com/huggingface/peft.git
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
if [ "${{ matrix.transformers_version }}" = "main" ]; then
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
else
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==${{ matrix.transformers_version }}
|
||||
fi
|
||||
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
@@ -241,12 +252,16 @@ jobs:
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: torch_cuda_test_reports_${{ matrix.module }}
|
||||
name: torch_cuda_test_reports_${{ matrix.module }}_transformers_${{ matrix.transformers_version }}
|
||||
path: reports
|
||||
|
||||
run_examples_tests:
|
||||
name: Examples PyTorch CUDA tests on Ubuntu
|
||||
name: Examples PyTorch CUDA tests on Ubuntu (transformers ${{ matrix.transformers_version }})
|
||||
needs: [check_code_quality, check_repository_consistency]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
transformers_version: ["main"]
|
||||
runs-on:
|
||||
group: aws-g4dn-2xlarge
|
||||
|
||||
@@ -264,8 +279,11 @@ jobs:
|
||||
nvidia-smi
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
if [ "${{ matrix.transformers_version }}" = "main" ]; then
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
else
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==${{ matrix.transformers_version }}
|
||||
fi
|
||||
uv pip install -e ".[quality,training]"
|
||||
|
||||
- name: Environment
|
||||
@@ -289,6 +307,6 @@ jobs:
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: examples_test_reports
|
||||
name: examples_transformers_${{ matrix.transformers_version }}_test_reports
|
||||
path: reports
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ The Modular Diffusers docs are organized as shown below.
|
||||
|
||||
## Quickstart
|
||||
|
||||
- The [quickstart](./quickstart) shows you how to run a modular pipeline, understand its structure, and customize it by modifying the blocks that compose it.
|
||||
- A [quickstart](./quickstart) demonstrating how to implement an example workflow with Modular Diffusers.
|
||||
|
||||
## ModularPipelineBlocks
|
||||
|
||||
|
||||
@@ -12,248 +12,333 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# Quickstart
|
||||
|
||||
Modular Diffusers is a framework for quickly building flexible and customizable pipelines. These pipelines can go beyond what standard `DiffusionPipeline`s can do. At the core of Modular Diffusers are [`ModularPipelineBlocks`] that can be combined with other blocks to adapt to new workflows. The blocks are converted into a [`ModularPipeline`], a friendly user-facing interface for running generation tasks.
|
||||
Modular Diffusers is a framework for quickly building flexible and customizable pipelines. At the core of Modular Diffusers are [`ModularPipelineBlocks`] that can be combined with other blocks to adapt to new workflows. The blocks are converted into a [`ModularPipeline`], a friendly user-facing interface developers can use.
|
||||
|
||||
This guide shows you how to run a modular pipeline, understand its structure, and customize it by modifying the blocks that compose it.
|
||||
This doc will show you how to implement a [Differential Diffusion](https://differential-diffusion.github.io/) pipeline with the modular framework.
|
||||
|
||||
## Run a pipeline
|
||||
## ModularPipelineBlocks
|
||||
|
||||
[`ModularPipelineBlocks`] are *definitions* that specify the components, inputs, outputs, and computation logic for a single step in a pipeline. There are four types of blocks.
|
||||
|
||||
- [`ModularPipelineBlocks`] is the most basic block for a single step.
|
||||
- [`SequentialPipelineBlocks`] is a multi-block that composes other blocks linearly. The outputs of one block are the inputs to the next block.
|
||||
- [`LoopSequentialPipelineBlocks`] is a multi-block that runs iteratively and is designed for iterative workflows.
|
||||
- [`AutoPipelineBlocks`] is a collection of blocks for different workflows and it selects which block to run based on the input. It is designed to conveniently package multiple workflows into a single pipeline.
|
||||
|
||||
[Differential Diffusion](https://differential-diffusion.github.io/) is an image-to-image workflow. Start with the `IMAGE2IMAGE_BLOCKS` preset, a collection of `ModularPipelineBlocks` for image-to-image generation.
|
||||
|
||||
```py
|
||||
from diffusers.modular_pipelines.stable_diffusion_xl import IMAGE2IMAGE_BLOCKS
|
||||
IMAGE2IMAGE_BLOCKS = InsertableDict([
|
||||
("text_encoder", StableDiffusionXLTextEncoderStep),
|
||||
("image_encoder", StableDiffusionXLVaeEncoderStep),
|
||||
("input", StableDiffusionXLInputStep),
|
||||
("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep),
|
||||
("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep),
|
||||
("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep),
|
||||
("denoise", StableDiffusionXLDenoiseStep),
|
||||
("decode", StableDiffusionXLDecodeStep)
|
||||
])
|
||||
```
|
||||
|
||||
## Pipeline and block states
|
||||
|
||||
Modular Diffusers uses *state* to communicate data between blocks. There are two types of states.
|
||||
|
||||
- [`PipelineState`] is a global state that can be used to track all inputs and outputs across all blocks.
|
||||
- [`BlockState`] is a local view of relevant variables from [`PipelineState`] for an individual block.
|
||||
|
||||
## Customizing blocks
|
||||
|
||||
[Differential Diffusion](https://differential-diffusion.github.io/) differs from standard image-to-image in its `prepare_latents` and `denoise` blocks. All the other blocks can be reused, but you'll need to modify these two.
|
||||
|
||||
Create placeholder `ModularPipelineBlocks` for `prepare_latents` and `denoise` by copying and modifying the existing ones.
|
||||
|
||||
Print the `denoise` block to see that it is composed of [`LoopSequentialPipelineBlocks`] with three sub-blocks, `before_denoiser`, `denoiser`, and `after_denoiser`. Only the `before_denoiser` sub-block needs to be modified to prepare the latent input for the denoiser based on the change map.
|
||||
|
||||
```py
|
||||
denoise_blocks = IMAGE2IMAGE_BLOCKS["denoise"]()
|
||||
print(denoise_blocks)
|
||||
```
|
||||
|
||||
Replace the `StableDiffusionXLLoopBeforeDenoiser` sub-block with the new `SDXLDiffDiffLoopBeforeDenoiser` block.
|
||||
|
||||
```py
|
||||
# Copy existing blocks as placeholders
|
||||
class SDXLDiffDiffPrepareLatentsStep(ModularPipelineBlocks):
|
||||
"""Copied from StableDiffusionXLImg2ImgPrepareLatentsStep - will modify later"""
|
||||
# ... same implementation as StableDiffusionXLImg2ImgPrepareLatentsStep
|
||||
|
||||
class SDXLDiffDiffDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
|
||||
block_classes = [SDXLDiffDiffLoopBeforeDenoiser, StableDiffusionXLLoopDenoiser, StableDiffusionXLLoopAfterDenoiser]
|
||||
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
|
||||
```
|
||||
|
||||
### prepare_latents
|
||||
|
||||
The `prepare_latents` block requires the following changes.
|
||||
|
||||
- a processor to process the change map
|
||||
- a new `inputs` to accept the user-provided change map, `timestep` for precomputing all the latents and `num_inference_steps` to create the mask for updating the image regions
|
||||
- update the computation in the `__call__` method for processing the change map and creating the masks, and storing it in the [`BlockState`]
|
||||
|
||||
```diff
|
||||
class SDXLDiffDiffPrepareLatentsStep(ModularPipelineBlocks):
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("vae", AutoencoderKL),
|
||||
ComponentSpec("scheduler", EulerDiscreteScheduler),
|
||||
+ ComponentSpec("mask_processor", VaeImageProcessor, config=FrozenDict({"do_normalize": False, "do_convert_grayscale": True}))
|
||||
]
|
||||
@property
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
+ InputParam("diffdiff_map", required=True),
|
||||
- InputParam("latent_timestep", required=True, type_hint=torch.Tensor),
|
||||
+ InputParam("timesteps", type_hint=torch.Tensor),
|
||||
+ InputParam("num_inference_steps", type_hint=int),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
+ OutputParam("original_latents", type_hint=torch.Tensor),
|
||||
+ OutputParam("diffdiff_masks", type_hint=torch.Tensor),
|
||||
]
|
||||
def __call__(self, components, state: PipelineState):
|
||||
# ... existing logic ...
|
||||
+ # Process change map and create masks
|
||||
+ diffdiff_map = components.mask_processor.preprocess(block_state.diffdiff_map, height=latent_height, width=latent_width)
|
||||
+ thresholds = torch.arange(block_state.num_inference_steps, dtype=diffdiff_map.dtype) / block_state.num_inference_steps
|
||||
+ block_state.diffdiff_masks = diffdiff_map > (thresholds + (block_state.denoising_start or 0))
|
||||
+ block_state.original_latents = block_state.latents
|
||||
```
|
||||
|
||||
### denoise
|
||||
|
||||
The `before_denoiser` sub-block requires the following changes.
|
||||
|
||||
- a new `inputs` to accept a `denoising_start` parameter, `original_latents` and `diffdiff_masks` from the `prepare_latents` block
|
||||
- update the computation in the `__call__` method for applying Differential Diffusion
|
||||
|
||||
```diff
|
||||
class SDXLDiffDiffLoopBeforeDenoiser(ModularPipelineBlocks):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Step within the denoising loop for differential diffusion that prepare the latent input for the denoiser"
|
||||
)
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam("latents", required=True, type_hint=torch.Tensor),
|
||||
+ InputParam("denoising_start"),
|
||||
+ InputParam("original_latents", type_hint=torch.Tensor),
|
||||
+ InputParam("diffdiff_masks", type_hint=torch.Tensor),
|
||||
]
|
||||
|
||||
def __call__(self, components, block_state, i, t):
|
||||
+ # Apply differential diffusion logic
|
||||
+ if i == 0 and block_state.denoising_start is None:
|
||||
+ block_state.latents = block_state.original_latents[:1]
|
||||
+ else:
|
||||
+ block_state.mask = block_state.diffdiff_masks[i].unsqueeze(0).unsqueeze(1)
|
||||
+ block_state.latents = block_state.original_latents[i] * block_state.mask + block_state.latents * (1 - block_state.mask)
|
||||
|
||||
# ... rest of existing logic ...
|
||||
```
|
||||
|
||||
## Assembling the blocks
|
||||
|
||||
You should have all the blocks you need at this point to create a [`ModularPipeline`].
|
||||
|
||||
Copy the existing `IMAGE2IMAGE_BLOCKS` preset and for the `set_timesteps` block, use the `set_timesteps` from the `TEXT2IMAGE_BLOCKS` because Differential Diffusion doesn't require a `strength` parameter.
|
||||
|
||||
Set the `prepare_latents` and `denoise` blocks to the `SDXLDiffDiffPrepareLatentsStep` and `SDXLDiffDiffDenoiseStep` blocks you just modified.
|
||||
|
||||
Call [`SequentialPipelineBlocks.from_blocks_dict`] on the blocks to create a `SequentialPipelineBlocks`.
|
||||
|
||||
```py
|
||||
DIFFDIFF_BLOCKS = IMAGE2IMAGE_BLOCKS.copy()
|
||||
DIFFDIFF_BLOCKS["set_timesteps"] = TEXT2IMAGE_BLOCKS["set_timesteps"]
|
||||
DIFFDIFF_BLOCKS["prepare_latents"] = SDXLDiffDiffPrepareLatentsStep
|
||||
DIFFDIFF_BLOCKS["denoise"] = SDXLDiffDiffDenoiseStep
|
||||
|
||||
dd_blocks = SequentialPipelineBlocks.from_blocks_dict(DIFFDIFF_BLOCKS)
|
||||
print(dd_blocks)
|
||||
```
|
||||
|
||||
## ModularPipeline
|
||||
|
||||
Convert the [`SequentialPipelineBlocks`] into a [`ModularPipeline`] with the [`ModularPipeline.init_pipeline`] method. This initializes the expected components to load from a `modular_model_index.json` file. Explicitly load the components by calling [`ModularPipeline.load_components`].
|
||||
|
||||
It is a good idea to initialize the [`ComponentManager`] with the pipeline to help manage the different components. Once you call [`~ModularPipeline.load_components`], the components are registered to the [`ComponentManager`] and can be shared between workflows. The example below uses the `collection` argument to assign the components a `"diffdiff"` label for better organization.
|
||||
|
||||
```py
|
||||
from diffusers.modular_pipelines import ComponentsManager
|
||||
|
||||
components = ComponentManager()
|
||||
|
||||
dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", components_manager=components, collection="diffdiff")
|
||||
dd_pipeline.load_componenets(torch_dtype=torch.float16)
|
||||
dd_pipeline.to("cuda")
|
||||
```
|
||||
|
||||
## Adding workflows
|
||||
|
||||
Other workflows can be added to the [`ModularPipeline`] to support additional features without rewriting the entire pipeline from scratch.
|
||||
|
||||
This section demonstrates how to add an IP-Adapter or ControlNet.
|
||||
|
||||
### IP-Adapter
|
||||
|
||||
Stable Diffusion XL already has a preset IP-Adapter block that you can use and doesn't require any changes to the existing Differential Diffusion pipeline.
|
||||
|
||||
```py
|
||||
from diffusers.modular_pipelines.stable_diffusion_xl.encoders import StableDiffusionXLAutoIPAdapterStep
|
||||
|
||||
ip_adapter_block = StableDiffusionXLAutoIPAdapterStep()
|
||||
```
|
||||
|
||||
Use the [`sub_blocks.insert`] method to insert it into the [`ModularPipeline`]. The example below inserts the `ip_adapter_block` at position `0`. Print the pipeline to see that the `ip_adapter_block` is added and it requires an `ip_adapter_image`. This also added two components to the pipeline, the `image_encoder` and `feature_extractor`.
|
||||
|
||||
```py
|
||||
dd_blocks.sub_blocks.insert("ip_adapter", ip_adapter_block, 0)
|
||||
```
|
||||
|
||||
Call [`~ModularPipeline.init_pipeline`] to initialize a [`ModularPipeline`] and use [`~ModularPipeline.load_components`] to load the model components. Load and set the IP-Adapter to run the pipeline.
|
||||
|
||||
```py
|
||||
dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff")
|
||||
dd_pipeline.load_components(torch_dtype=torch.float16)
|
||||
dd_pipeline.loader.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
|
||||
dd_pipeline.loader.set_ip_adapter_scale(0.6)
|
||||
dd_pipeline = dd_pipeline.to(device)
|
||||
|
||||
ip_adapter_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/diffdiff_orange.jpeg")
|
||||
image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png?download=true")
|
||||
mask = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask.png?download=true")
|
||||
|
||||
prompt = "a green pear"
|
||||
negative_prompt = "blurry"
|
||||
generator = torch.Generator(device=device).manual_seed(42)
|
||||
|
||||
image = dd_pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
num_inference_steps=25,
|
||||
generator=generator,
|
||||
ip_adapter_image=ip_adapter_image,
|
||||
diffdiff_map=mask,
|
||||
image=image,
|
||||
output="images"
|
||||
)[0]
|
||||
```
|
||||
|
||||
### ControlNet
|
||||
|
||||
Stable Diffusion XL already has a preset ControlNet block that can readily be used.
|
||||
|
||||
```py
|
||||
from diffusers.modular_pipelines.stable_diffusion_xl.modular_blocks import StableDiffusionXLAutoControlNetInputStep
|
||||
|
||||
control_input_block = StableDiffusionXLAutoControlNetInputStep()
|
||||
```
|
||||
|
||||
However, it requires modifying the `denoise` block because that's where the ControlNet injects the control information into the UNet.
|
||||
|
||||
Modify the `denoise` block by replacing the `StableDiffusionXLLoopDenoiser` sub-block with the `StableDiffusionXLControlNetLoopDenoiser`.
|
||||
|
||||
```py
|
||||
class SDXLDiffDiffControlNetDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
|
||||
block_classes = [SDXLDiffDiffLoopBeforeDenoiser, StableDiffusionXLControlNetLoopDenoiser, StableDiffusionXLDenoiseLoopAfterDenoiser]
|
||||
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
|
||||
|
||||
controlnet_denoise_block = SDXLDiffDiffControlNetDenoiseStep()
|
||||
```
|
||||
|
||||
Insert the `controlnet_input` block and replace the `denoise` block with the new `controlnet_denoise_block`. Initialize a [`ModularPipeline`] and [`~ModularPipeline.load_components`] into it.
|
||||
|
||||
```py
|
||||
dd_blocks.sub_blocks.insert("controlnet_input", control_input_block, 7)
|
||||
dd_blocks.sub_blocks["denoise"] = controlnet_denoise_block
|
||||
|
||||
dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff")
|
||||
dd_pipeline.load_components(torch_dtype=torch.float16)
|
||||
dd_pipeline = dd_pipeline.to(device)
|
||||
|
||||
control_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/diffdiff_tomato_canny.jpeg")
|
||||
image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png?download=true")
|
||||
mask = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask.png?download=true")
|
||||
|
||||
prompt = "a green pear"
|
||||
negative_prompt = "blurry"
|
||||
generator = torch.Generator(device=device).manual_seed(42)
|
||||
|
||||
image = dd_pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
num_inference_steps=25,
|
||||
generator=generator,
|
||||
control_image=control_image,
|
||||
controlnet_conditioning_scale=0.5,
|
||||
diffdiff_map=mask,
|
||||
image=image,
|
||||
output="images"
|
||||
)[0]
|
||||
```
|
||||
|
||||
### AutoPipelineBlocks
|
||||
|
||||
The Differential Diffusion, IP-Adapter, and ControlNet workflows can be bundled into a single [`ModularPipeline`] by using [`AutoPipelineBlocks`]. This allows automatically selecting which sub-blocks to run based on the inputs like `control_image` or `ip_adapter_image`. If none of these inputs are passed, then it defaults to the Differential Diffusion.
|
||||
|
||||
Use `block_trigger_inputs` to only run the `SDXLDiffDiffControlNetDenoiseStep` block if a `control_image` input is provided. Otherwise, the `SDXLDiffDiffDenoiseStep` is used.
|
||||
|
||||
```py
|
||||
class SDXLDiffDiffAutoDenoiseStep(AutoPipelineBlocks):
|
||||
block_classes = [SDXLDiffDiffControlNetDenoiseStep, SDXLDiffDiffDenoiseStep]
|
||||
block_names = ["controlnet_denoise", "denoise"]
|
||||
block_trigger_inputs = ["controlnet_cond", None]
|
||||
```
|
||||
|
||||
Add the `ip_adapter` and `controlnet_input` blocks.
|
||||
|
||||
```py
|
||||
DIFFDIFF_AUTO_BLOCKS = IMAGE2IMAGE_BLOCKS.copy()
|
||||
DIFFDIFF_AUTO_BLOCKS["prepare_latents"] = SDXLDiffDiffPrepareLatentsStep
|
||||
DIFFDIFF_AUTO_BLOCKS["set_timesteps"] = TEXT2IMAGE_BLOCKS["set_timesteps"]
|
||||
DIFFDIFF_AUTO_BLOCKS["denoise"] = SDXLDiffDiffAutoDenoiseStep
|
||||
DIFFDIFF_AUTO_BLOCKS.insert("ip_adapter", StableDiffusionXLAutoIPAdapterStep, 0)
|
||||
DIFFDIFF_AUTO_BLOCKS.insert("controlnet_input",StableDiffusionXLControlNetAutoInput, 7)
|
||||
```
|
||||
|
||||
Call [`SequentialPipelineBlocks.from_blocks_dict`] to create a [`SequentialPipelineBlocks`] and create a [`ModularPipeline`] and load in the model components to run.
|
||||
|
||||
```py
|
||||
dd_auto_blocks = SequentialPipelineBlocks.from_blocks_dict(DIFFDIFF_AUTO_BLOCKS)
|
||||
dd_pipeline = dd_auto_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff")
|
||||
dd_pipeline.load_components(torch_dtype=torch.float16)
|
||||
```
|
||||
|
||||
## Share
|
||||
|
||||
Add your [`ModularPipeline`] to the Hub with [`~ModularPipeline.save_pretrained`] and set `push_to_hub` argument to `True`.
|
||||
|
||||
```py
|
||||
dd_pipeline.save_pretrained("YiYiXu/test_modular_doc", push_to_hub=True)
|
||||
```
|
||||
|
||||
Other users can load the [`ModularPipeline`] with [`~ModularPipeline.from_pretrained`].
|
||||
|
||||
[`ModularPipeline`] is the main interface for loading, running, and managing modular pipelines.
|
||||
```py
|
||||
import torch
|
||||
from diffusers import ModularPipeline, ComponentsManager
|
||||
from diffusers.modular_pipelines import ModularPipeline, ComponentsManager
|
||||
|
||||
# Use ComponentsManager to enable auto CPU offloading for memory efficiency
|
||||
manager = ComponentsManager()
|
||||
manager.enable_auto_cpu_offload(device="cuda:0")
|
||||
components = ComponentsManager()
|
||||
|
||||
pipe = ModularPipeline.from_pretrained("Qwen/Qwen-Image", components_manager=manager)
|
||||
pipe.load_components(torch_dtype=torch.bfloat16)
|
||||
|
||||
image = pipe(
|
||||
prompt="cat wizard with red hat, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney",
|
||||
).images[0]
|
||||
image
|
||||
diffdiff_pipeline = ModularPipeline.from_pretrained("YiYiXu/modular-diffdiff-0704", trust_remote_code=True, components_manager=components, collection="diffdiff")
|
||||
diffdiff_pipeline.load_components(torch_dtype=torch.float16)
|
||||
```
|
||||
|
||||
[`~ModularPipeline.from_pretrained`] uses lazy loading - it reads the configuration to learn where to load each component from, but doesn't actually load the model weights until you call [`~ModularPipeline.load_components`]. This gives you control over when and how components are loaded.
|
||||
|
||||
> [!TIP]
|
||||
> [`ComponentsManager`] with `enable_auto_cpu_offload` automatically moves models between CPU and GPU as needed, reducing memory usage for large models like Qwen-Image. Learn more in the [ComponentsManager](./components_manager) guide.
|
||||
|
||||
Learn more about creating and loading pipelines in the [Creating a pipeline](https://huggingface.co/docs/diffusers/modular_diffusers/modular_pipeline#creating-a-pipeline) and [Loading components](https://huggingface.co/docs/diffusers/modular_diffusers/modular_pipeline#loading-components) guides.
|
||||
|
||||
## Understand the structure
|
||||
|
||||
A [`ModularPipeline`] has two parts:
|
||||
- **State**: the loaded components (models, schedulers, processors) and configuration
|
||||
- **Definition**: the [`ModularPipelineBlocks`] that specify inputs, outputs, expected components and computation logic
|
||||
|
||||
The blocks define *what* the pipeline does. Access them through `pipe.blocks`.
|
||||
```py
|
||||
print(pipe.blocks)
|
||||
```
|
||||
```
|
||||
QwenImageAutoBlocks(
|
||||
Class: SequentialPipelineBlocks
|
||||
|
||||
Description: Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage.
|
||||
|
||||
Supported workflows:
|
||||
- `text2image`: requires `prompt`
|
||||
- `image2image`: requires `prompt`, `image`
|
||||
- `inpainting`: requires `prompt`, `mask_image`, `image`
|
||||
- `controlnet_text2image`: requires `prompt`, `control_image`
|
||||
...
|
||||
|
||||
Components:
|
||||
text_encoder (`Qwen2_5_VLForConditionalGeneration`)
|
||||
vae (`AutoencoderKLQwenImage`)
|
||||
transformer (`QwenImageTransformer2DModel`)
|
||||
...
|
||||
|
||||
Sub-Blocks:
|
||||
[0] text_encoder (QwenImageAutoTextEncoderStep)
|
||||
[1] vae_encoder (QwenImageAutoVaeEncoderStep)
|
||||
[2] controlnet_vae_encoder (QwenImageOptionalControlNetVaeEncoderStep)
|
||||
[3] denoise (QwenImageAutoCoreDenoiseStep)
|
||||
[4] decode (QwenImageAutoDecodeStep)
|
||||
)
|
||||
```
|
||||
|
||||
The output returns:
|
||||
- The supported workflows (text2image, image2image, inpainting, etc.)
|
||||
- The Sub-Blocks it's composed of (text_encoder, vae_encoder, denoise, decode)
|
||||
|
||||
### Workflows
|
||||
|
||||
`QwenImageAutoBlocks` is a [`ConditionalPipelineBlocks`], so this pipeline supports multiple workflows and adapts its behavior based on the inputs you provide. For example, if you pass `image` to the pipeline, it runs an image-to-image workflow instead of text-to-image. Let's see this in action with an example.
|
||||
```py
|
||||
from diffusers.utils import load_image
|
||||
|
||||
input_image = load_image("https://github.com/Trgtuan10/Image_storage/blob/main/cute_cat.png?raw=true")
|
||||
|
||||
image = pipe(
|
||||
prompt="cat wizard with red hat, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney",
|
||||
image=input_image,
|
||||
).images[0]
|
||||
```
|
||||
|
||||
Use `get_workflow()` to extract the blocks for a specific workflow. Pass the workflow name (e.g., `"image2image"`, `"inpainting"`, `"controlnet_text2image"`) to get only the blocks relevant to that workflow.
|
||||
```py
|
||||
img2img_blocks = pipe.blocks.get_workflow("image2image")
|
||||
```
|
||||
|
||||
Conditional blocks are convenient for users, but their conditional logic adds complexity when customizing or debugging. Extracting a workflow gives you the specific blocks relevant to your workflow, making it easier to work with. Learn more in the [AutoPipelineBlocks](https://huggingface.co/docs/diffusers/modular_diffusers/auto_pipeline_blocks) guide.
|
||||
|
||||
### Sub-blocks
|
||||
|
||||
Blocks can contain other blocks. `pipe.blocks` gives you the top-level block definition (here, `QwenImageAutoBlocks`), while `sub_blocks` lets you access the smaller blocks inside it.
|
||||
|
||||
`QwenImageAutoBlocks` is composed of: `text_encoder`, `vae_encoder`, `controlnet_vae_encoder`, `denoise`, and `decode`. Access them through the `sub_blocks` property.
|
||||
|
||||
The `doc` property is useful for seeing the full documentation of any block, including its inputs, outputs, and components.
|
||||
```py
|
||||
vae_encoder_block = pipe.blocks.sub_blocks["vae_encoder"]
|
||||
print(vae_encoder_block.doc)
|
||||
```
|
||||
|
||||
This block can be converted to a pipeline so that it can run on its own with [`~ModularPipelineBlocks.init_pipeline`].
|
||||
```py
|
||||
vae_encoder_pipe = vae_encoder_block.init_pipeline()
|
||||
|
||||
# Reuse the VAE we already loaded, we can reuse it with update_components() method
|
||||
vae_encoder_pipe.update_components(vae=pipe.vae)
|
||||
|
||||
# Run just this block
|
||||
image_latents = vae_encoder_pipe(image=input_image).image_latents
|
||||
print(image_latents.shape)
|
||||
```
|
||||
|
||||
It reuses the VAE from our original pipeline instead of reloading it, keeping memory usage efficient. Learn more in the [Loading components](https://huggingface.co/docs/diffusers/modular_diffusers/modular_pipeline#loading-components) guide.
|
||||
|
||||
Since blocks are composable, you can modify the pipeline's definition by adding, removing, or swapping blocks to create new workflows. In the next section, we'll add a canny edge detection block to a ControlNet pipeline, so you can pass a regular image instead of a pre-processed canny edge map.
|
||||
|
||||
## Compose new workflows
|
||||
|
||||
Let's add a canny edge detection block to a ControlNet pipeline. First, load a pre-built canny block from the Hub (see [Building Custom Blocks](https://huggingface.co/docs/diffusers/modular_diffusers/custom_blocks) to create your own).
|
||||
```py
|
||||
from diffusers.modular_pipelines import ModularPipelineBlocks
|
||||
|
||||
# Load a canny block from the Hub
|
||||
canny_block = ModularPipelineBlocks.from_pretrained(
|
||||
"diffusers-internal-dev/canny-filtering",
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
print(canny_block.doc)
|
||||
```
|
||||
```
|
||||
class CannyBlock
|
||||
|
||||
Inputs:
|
||||
image (`Union[Image, ndarray]`):
|
||||
Image to compute canny filter on
|
||||
low_threshold (`int`, *optional*, defaults to 50):
|
||||
Low threshold for the canny filter.
|
||||
high_threshold (`int`, *optional*, defaults to 200):
|
||||
High threshold for the canny filter.
|
||||
...
|
||||
|
||||
Outputs:
|
||||
control_image (`PIL.Image`):
|
||||
Canny map for input image
|
||||
```
|
||||
|
||||
UUse `get_workflow` to extract the ControlNet workflow from [`QwenImageAutoBlocks`].
|
||||
```py
|
||||
# Get the controlnet workflow that we want to work with
|
||||
blocks = pipe.blocks.get_workflow("controlnet_text2image")
|
||||
print(blocks.doc)
|
||||
```
|
||||
```
|
||||
class SequentialPipelineBlocks
|
||||
|
||||
Inputs:
|
||||
prompt (`str`):
|
||||
The prompt or prompts to guide image generation.
|
||||
control_image (`Image`):
|
||||
Control image for ControlNet conditioning.
|
||||
...
|
||||
```
|
||||
|
||||
The extracted workflow is a [`SequentialPipelineBlocks`](./sequential_pipeline_blocks) - a multi-block type where blocks run one after another and data flows linearly from one block to the next. Each block's `intermediate_outputs` become available as `inputs` to subsequent blocks.
|
||||
|
||||
Currently this workflow requires `control_image` as input. Let's insert the canny block at the beginning so the pipeline accepts a regular image instead.
|
||||
```py
|
||||
# Insert canny at the beginning
|
||||
blocks.sub_blocks.insert("canny", canny_block, 0)
|
||||
|
||||
# Check the updated structure: CannyBlock is now listed as first sub-block
|
||||
print(blocks)
|
||||
# Check the updated doc
|
||||
print(blocks.doc)
|
||||
```
|
||||
```
|
||||
class SequentialPipelineBlocks
|
||||
|
||||
Inputs:
|
||||
image (`Union[Image, ndarray]`):
|
||||
Image to compute canny filter on
|
||||
low_threshold (`int`, *optional*, defaults to 50):
|
||||
Low threshold for the canny filter.
|
||||
high_threshold (`int`, *optional*, defaults to 200):
|
||||
High threshold for the canny filter.
|
||||
prompt (`str`):
|
||||
The prompt or prompts to guide image generation.
|
||||
...
|
||||
```
|
||||
|
||||
Now the pipeline takes `image` as input instead of `control_image`. Because blocks in a sequence share data automatically, the canny block's output (`control_image`) flows to the denoise block that needs it, and the canny block's input (`image`) becomes a pipeline input since no earlier block provides it.
|
||||
|
||||
Create a pipeline from the modified blocks and load a ControlNet model.
|
||||
```py
|
||||
pipeline = blocks.init_pipeline("Qwen/Qwen-Image", components_manager=manager)
|
||||
|
||||
pipeline.load_components(torch_dtype=torch.bfloat16)
|
||||
|
||||
# Load the ControlNet model
|
||||
controlnet_spec = pipeline.get_component_spec("controlnet")
|
||||
controlnet_spec.pretrained_model_name_or_path = "InstantX/Qwen-Image-ControlNet-Union"
|
||||
controlnet = controlnet_spec.load(torch_dtype=torch.bfloat16)
|
||||
pipeline.update_components(controlnet=controlnet)
|
||||
```
|
||||
|
||||
Now run the pipeline - the canny block preprocesses the image for ControlNet.
|
||||
```py
|
||||
from diffusers.utils import load_image
|
||||
|
||||
prompt = "cat wizard with red hat, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney"
|
||||
image = load_image("https://github.com/Trgtuan10/Image_storage/blob/main/cute_cat.png?raw=true")
|
||||
|
||||
output = pipeline(
|
||||
prompt=prompt,
|
||||
image=image,
|
||||
).images[0]
|
||||
output
|
||||
```
|
||||
|
||||
## Next steps
|
||||
|
||||
<hfoptions id="next">
|
||||
<hfoption id="Build custom blocks">
|
||||
|
||||
Learn how to create your own blocks with custom logic in the [Building Custom Blocks](./custom_blocks) guide.
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Share components">
|
||||
|
||||
Use [`ComponentsManager`](./components_manager) to share models across multiple pipelines and manage memory efficiently.
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Visual interface">
|
||||
|
||||
Connect modular pipelines to [Mellon](https://github.com/cubiq/Mellon), a visual node-based interface for building workflows. Custom blocks built with Modular Diffusers work out of the box with Mellon - no UI code required. Read more in the Mellon guide.
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
@@ -343,34 +343,6 @@ We ran a benchmark with Ulysess, Ring, and Unified Attention with [this script](
|
||||
|
||||
From the above table, it's clear that Ulysses provides better throughput, but the number of devices it can use remains limited to the number of attention heads, a limitation that is solved by unified attention.
|
||||
|
||||
|
||||
### Ulysses Anything Attention
|
||||
|
||||
The default Ulysses Attention mechanism requires that the sequence length of hidden states must be divisible by the number of devices. This imposes significant limitations on the practical application of Ulysses Attention. [Ulysses Anything Attention](https://github.com/huggingface/diffusers/pull/12996) is a variant of Ulysses Attention that supports arbitrary sequence lengths and arbitrary numbers of attention heads, thereby enhancing the versatility of Ulysses Attention in practical use.
|
||||
|
||||
[`ContextParallelConfig`] supports Ulysses Anything Attention by specifying both `ulysses_degree` and `ulysses_anything`. Please note that Ulysses Anything Attention is not currently supported by Unified Attention. Pass the [`ContextParallelConfig`] with both `ulysses_degree` set to bigger than 1 and `ulysses_anything=True` to [`~ModelMixin.enable_parallelism`].
|
||||
|
||||
```py
|
||||
pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2, ulysses_anything=True))
|
||||
```
|
||||
|
||||
> [!TIP] To avoid multiple forced CUDA sync caused by H2D and D2H transfers, please add the **gloo** backend in `init_process_group`. This will significantly reduce communication latency.
|
||||
|
||||
We ran a benchmark for FLUX.1-dev with Ulysses, Ring, Unified Attention and Ulysses Anything Attention with [this script](https://github.com/huggingface/diffusers/pull/12996#issuecomment-3797695999) on a node of 4 L20 GPUs. The results are summarized as follows:
|
||||
|
||||
| CP Backend | Time / Iter (ms) | Steps / Sec | Peak Memory (GB) | Shape (HxW)|
|
||||
|--------------------|------------------|-------------|------------------|------------|
|
||||
| ulysses | 281.07 | 3.56 | 37.11 | 1024x1024 |
|
||||
| ring | 351.34 | 2.85 | 37.01 | 1024x1024 |
|
||||
| unified_balanced | 324.37 | 3.08 | 37.16 | 1024x1024 |
|
||||
| ulysses_anything | 280.94 | 3.56 | 37.11 | 1024x1024 |
|
||||
| ulysses | failed | failed | failed | 1008x1008 |
|
||||
| ring | failed | failed | failed | 1008x1008 |
|
||||
| unified_balanced | failed | failed | failed | 1008x1008 |
|
||||
| ulysses_anything | 278.40 | 3.59 | 36.99 | 1008x1008 |
|
||||
|
||||
From the above table, it is clear that Ulysses Anything Attention offers better compatibility with arbitrary sequence lengths while maintaining the same performance as the standard Ulysses Attention.
|
||||
|
||||
### parallel_config
|
||||
|
||||
Pass `parallel_config` during model initialization to enable context parallelism.
|
||||
|
||||
@@ -17,6 +17,9 @@ import logging
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from diffusers.utils import is_transformers_version
|
||||
|
||||
|
||||
sys.path.append("..")
|
||||
@@ -30,6 +33,7 @@ stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
|
||||
@unittest.skipIf(is_transformers_version(">=", "4.57.5"), "Size mismatch")
|
||||
class CustomDiffusion(ExamplesTestsAccelerate):
|
||||
def test_custom_diffusion(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
|
||||
@@ -35,8 +35,8 @@ from . import BaseDiffusersCLICommand
|
||||
def conversion_command_factory(args: Namespace):
|
||||
if args.use_auth_token:
|
||||
warnings.warn(
|
||||
"The `--use_auth_token` flag is deprecated and will be removed in a future version."
|
||||
"Authentication is now handled automatically if the user is logged in."
|
||||
"The `--use_auth_token` flag is deprecated and will be removed in a future version. Authentication is now"
|
||||
" handled automatically if user is logged in."
|
||||
)
|
||||
return FP16SafetensorsCommand(args.ckpt_id, args.fp16, args.use_safetensors)
|
||||
|
||||
@@ -92,8 +92,8 @@ class FP16SafetensorsCommand(BaseDiffusersCLICommand):
|
||||
pipeline_class = getattr(import_module("diffusers"), pipeline_class_name)
|
||||
self.logger.info(f"Pipeline class imported: {pipeline_class_name}.")
|
||||
|
||||
# Load the appropriate pipeline. We could have used `DiffusionPipeline`
|
||||
# here, but just to avoid potential edge cases.
|
||||
# Load the appropriate pipeline. We could have use `DiffusionPipeline`
|
||||
# here, but just to avoid any rough edge cases.
|
||||
pipeline = pipeline_class.from_pretrained(
|
||||
self.ckpt_id, torch_dtype=torch.float16 if self.fp16 else torch.float32
|
||||
)
|
||||
|
||||
@@ -44,6 +44,7 @@ _GO_LC_SUPPORTED_PYTORCH_LAYERS = (
|
||||
torch.nn.ConvTranspose2d,
|
||||
torch.nn.ConvTranspose3d,
|
||||
torch.nn.Linear,
|
||||
torch.nn.Embedding,
|
||||
# TODO(aryan): look into torch.nn.LayerNorm, torch.nn.GroupNorm later, seems to be causing some issues with CogVideoX
|
||||
# because of double invocation of the same norm layer in CogVideoXLayerNorm
|
||||
)
|
||||
|
||||
@@ -11,14 +11,12 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import copy
|
||||
import functools
|
||||
|
||||
import inspect
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Tuple, Type, Union
|
||||
from typing import Dict, List, Type, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
if torch.distributed.is_available():
|
||||
@@ -29,10 +27,9 @@ from ..models._modeling_parallel import (
|
||||
ContextParallelInput,
|
||||
ContextParallelModelPlan,
|
||||
ContextParallelOutput,
|
||||
gather_size_by_comm,
|
||||
)
|
||||
from ..utils import get_logger
|
||||
from ..utils.torch_utils import maybe_allow_in_graph, unwrap_module
|
||||
from ..utils.torch_utils import unwrap_module
|
||||
from .hooks import HookRegistry, ModelHook
|
||||
|
||||
|
||||
@@ -211,10 +208,6 @@ class ContextParallelSplitHook(ModelHook):
|
||||
)
|
||||
return x
|
||||
else:
|
||||
if self.parallel_config.ulysses_anything:
|
||||
return PartitionAnythingSharder.shard_anything(
|
||||
x, cp_input.split_dim, self.parallel_config._flattened_mesh
|
||||
)
|
||||
return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
|
||||
|
||||
|
||||
@@ -240,14 +233,7 @@ class ContextParallelGatherHook(ModelHook):
|
||||
for i, cpm in enumerate(self.metadata):
|
||||
if cpm is None:
|
||||
continue
|
||||
if self.parallel_config.ulysses_anything:
|
||||
output[i] = PartitionAnythingSharder.unshard_anything(
|
||||
output[i], cpm.gather_dim, self.parallel_config._flattened_mesh
|
||||
)
|
||||
else:
|
||||
output[i] = EquipartitionSharder.unshard(
|
||||
output[i], cpm.gather_dim, self.parallel_config._flattened_mesh
|
||||
)
|
||||
output[i] = EquipartitionSharder.unshard(output[i], cpm.gather_dim, self.parallel_config._flattened_mesh)
|
||||
|
||||
return output[0] if is_tensor else tuple(output)
|
||||
|
||||
@@ -288,73 +274,6 @@ class EquipartitionSharder:
|
||||
return tensor
|
||||
|
||||
|
||||
class AllGatherAnythingFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, tensor: torch.Tensor, dim: int, group: dist.device_mesh.DeviceMesh):
|
||||
ctx.dim = dim
|
||||
ctx.group = group
|
||||
ctx.world_size = dist.get_world_size(group)
|
||||
ctx.rank = dist.get_rank(group)
|
||||
gathered_tensor = _all_gather_anything(tensor, dim, group)
|
||||
return gathered_tensor
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
# NOTE: We use `tensor_split` instead of chunk, because the `chunk`
|
||||
# function may return fewer than the specified number of chunks!
|
||||
grad_splits = torch.tensor_split(grad_output, ctx.world_size, dim=ctx.dim)
|
||||
return grad_splits[ctx.rank], None, None
|
||||
|
||||
|
||||
class PartitionAnythingSharder:
|
||||
@classmethod
|
||||
def shard_anything(
|
||||
cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh
|
||||
) -> torch.Tensor:
|
||||
assert tensor.size()[dim] >= mesh.size(), (
|
||||
f"Cannot shard tensor of size {tensor.size()} along dim {dim} across mesh of size {mesh.size()}."
|
||||
)
|
||||
# NOTE: We use `tensor_split` instead of chunk, because the `chunk`
|
||||
# function may return fewer than the specified number of chunks!
|
||||
return tensor.tensor_split(mesh.size(), dim=dim)[dist.get_rank(mesh.get_group())]
|
||||
|
||||
@classmethod
|
||||
def unshard_anything(
|
||||
cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh
|
||||
) -> torch.Tensor:
|
||||
tensor = tensor.contiguous()
|
||||
tensor = AllGatherAnythingFunction.apply(tensor, dim, mesh.get_group())
|
||||
return tensor
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=64)
|
||||
def _fill_gather_shapes(shape: Tuple[int], gather_dims: Tuple[int], dim: int, world_size: int) -> List[List[int]]:
|
||||
gather_shapes = []
|
||||
for i in range(world_size):
|
||||
rank_shape = list(copy.deepcopy(shape))
|
||||
rank_shape[dim] = gather_dims[i]
|
||||
gather_shapes.append(rank_shape)
|
||||
return gather_shapes
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
def _all_gather_anything(tensor: torch.Tensor, dim: int, group: dist.device_mesh.DeviceMesh) -> torch.Tensor:
|
||||
world_size = dist.get_world_size(group=group)
|
||||
|
||||
tensor = tensor.contiguous()
|
||||
shape = tensor.shape
|
||||
rank_dim = shape[dim]
|
||||
gather_dims = gather_size_by_comm(rank_dim, group)
|
||||
|
||||
gather_shapes = _fill_gather_shapes(tuple(shape), tuple(gather_dims), dim, world_size)
|
||||
|
||||
gathered_tensors = [torch.empty(shape, device=tensor.device, dtype=tensor.dtype) for shape in gather_shapes]
|
||||
|
||||
dist.all_gather(gathered_tensors, tensor, group=group)
|
||||
gathered_tensor = torch.cat(gathered_tensors, dim=dim)
|
||||
return gathered_tensor
|
||||
|
||||
|
||||
def _get_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]:
|
||||
if name.count("*") > 1:
|
||||
raise ValueError("Wildcard '*' can only be used once in the name")
|
||||
|
||||
@@ -21,7 +21,12 @@ from tokenizers import Tokenizer as TokenizerFast
|
||||
from torch import nn
|
||||
|
||||
from ..models.modeling_utils import load_state_dict
|
||||
from ..utils import _get_model_file, is_accelerate_available, is_transformers_available, logging
|
||||
from ..utils import (
|
||||
_get_model_file,
|
||||
is_accelerate_available,
|
||||
is_transformers_available,
|
||||
logging,
|
||||
)
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
|
||||
@@ -19,7 +19,6 @@ from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from ..utils import get_logger
|
||||
|
||||
@@ -68,9 +67,6 @@ class ContextParallelConfig:
|
||||
convert_to_fp32: bool = True
|
||||
# TODO: support alltoall
|
||||
rotate_method: Literal["allgather", "alltoall"] = "allgather"
|
||||
# Whether to enable ulysses anything attention to support
|
||||
# any sequence lengths and any head numbers.
|
||||
ulysses_anything: bool = False
|
||||
|
||||
_rank: int = None
|
||||
_world_size: int = None
|
||||
@@ -98,11 +94,6 @@ class ContextParallelConfig:
|
||||
raise NotImplementedError(
|
||||
f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
|
||||
)
|
||||
if self.ulysses_anything:
|
||||
if self.ulysses_degree == 1:
|
||||
raise ValueError("ulysses_degree must be greater than 1 for ulysses_anything to be enabled.")
|
||||
if self.ring_degree > 1:
|
||||
raise ValueError("ulysses_anything cannot be enabled when ring_degree > 1.")
|
||||
|
||||
@property
|
||||
def mesh_shape(self) -> Tuple[int, int]:
|
||||
@@ -266,39 +257,3 @@ ContextParallelModelPlan = Dict[str, Union[ContextParallelInputType, ContextPara
|
||||
#
|
||||
# ContextParallelOutput:
|
||||
# specifies how to gather the input tensor in the post-forward hook in the layer it is attached to
|
||||
|
||||
|
||||
# Below are utility functions for distributed communication in context parallelism.
|
||||
def gather_size_by_comm(size: int, group: dist.ProcessGroup) -> List[int]:
|
||||
r"""Gather the local size from all ranks.
|
||||
size: int, local size return: List[int], list of size from all ranks
|
||||
"""
|
||||
# NOTE(Serving/CP Safety):
|
||||
# Do NOT cache this collective result.
|
||||
#
|
||||
# In "Ulysses Anything" mode, `size` (e.g. per-rank local seq_len / S_LOCAL)
|
||||
# may legitimately differ across ranks. If we cache based on the *local* `size`,
|
||||
# different ranks can have different cache hit/miss patterns across time.
|
||||
#
|
||||
# That can lead to a catastrophic distributed hang:
|
||||
# - some ranks hit cache and *skip* dist.all_gather()
|
||||
# - other ranks miss cache and *enter* dist.all_gather()
|
||||
# This mismatched collective participation will stall the process group and
|
||||
# eventually trigger NCCL watchdog timeouts (often surfacing later as ALLTOALL
|
||||
# timeouts in Ulysses attention).
|
||||
world_size = dist.get_world_size(group=group)
|
||||
# HACK: Use Gloo backend for all_gather to avoid H2D and D2H overhead
|
||||
comm_backends = str(dist.get_backend(group=group))
|
||||
# NOTE: e.g., dist.init_process_group(backend="cpu:gloo,cuda:nccl")
|
||||
gather_device = "cpu" if "cpu" in comm_backends else torch.accelerator.current_accelerator()
|
||||
gathered_sizes = [torch.empty((1,), device=gather_device, dtype=torch.int64) for _ in range(world_size)]
|
||||
dist.all_gather(
|
||||
gathered_sizes,
|
||||
torch.tensor([size], device=gather_device, dtype=torch.int64),
|
||||
group=group,
|
||||
)
|
||||
|
||||
gathered_sizes = [s[0].item() for s in gathered_sizes]
|
||||
# NOTE: DON'T use tolist here due to graph break - Explanation:
|
||||
# Backend compiler `inductor` failed with aten._local_scalar_dense.default
|
||||
return gathered_sizes
|
||||
|
||||
@@ -21,8 +21,6 @@ from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
if torch.distributed.is_available():
|
||||
@@ -46,8 +44,6 @@ from ..utils import (
|
||||
is_xformers_version,
|
||||
)
|
||||
from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
|
||||
from ..utils.torch_utils import maybe_allow_in_graph
|
||||
from ._modeling_parallel import gather_size_by_comm
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -1283,154 +1279,6 @@ class SeqAllToAllDim(torch.autograd.Function):
|
||||
return (None, grad_input, None, None)
|
||||
|
||||
|
||||
# Below are helper functions to handle abritrary head num and abritrary sequence length for Ulysses Anything Attention.
|
||||
def _maybe_pad_qkv_head(x: torch.Tensor, H: int, group: dist.ProcessGroup) -> Tuple[torch.Tensor, int]:
|
||||
r"""Maybe pad the head dimension to be divisible by world_size.
|
||||
x: torch.Tensor, shape (B, S_LOCAL, H, D) H: int, original global head num return: Tuple[torch.Tensor, int], padded
|
||||
tensor (B, S_LOCAL, H + H_PAD, D) and H_PAD
|
||||
"""
|
||||
world_size = dist.get_world_size(group=group)
|
||||
H_PAD = 0
|
||||
if H % world_size != 0:
|
||||
H_PAD = world_size - (H % world_size)
|
||||
NEW_H_LOCAL = (H + H_PAD) // world_size
|
||||
# e.g., Allow: H=30, world_size=8 -> NEW_H_LOCAL=4, H_PAD=2.
|
||||
# NOT ALLOW: H=30, world_size=16 -> NEW_H_LOCAL=2, H_PAD=14.
|
||||
assert H_PAD < NEW_H_LOCAL, f"Padding head num {H_PAD} should be less than new local head num {NEW_H_LOCAL}"
|
||||
x = F.pad(x, (0, 0, 0, H_PAD)).contiguous()
|
||||
return x, H_PAD
|
||||
|
||||
|
||||
def _maybe_unpad_qkv_head(x: torch.Tensor, H_PAD: int, group: dist.ProcessGroup) -> torch.Tensor:
|
||||
r"""Maybe unpad the head dimension.
|
||||
x: torch.Tensor, shape (B, S_GLOBAL, H_LOCAL + H_PAD, D) H_PAD: int, head padding num return: torch.Tensor,
|
||||
unpadded tensor (B, S_GLOBAL, H_LOCAL, D)
|
||||
"""
|
||||
rank = dist.get_rank(group=group)
|
||||
world_size = dist.get_world_size(group=group)
|
||||
# Only the last rank may have padding
|
||||
if H_PAD > 0 and rank == world_size - 1:
|
||||
x = x[:, :, :-H_PAD, :]
|
||||
return x.contiguous()
|
||||
|
||||
|
||||
def _maybe_pad_o_head(x: torch.Tensor, H: int, group: dist.ProcessGroup) -> Tuple[torch.Tensor, int]:
|
||||
r"""Maybe pad the head dimension to be divisible by world_size.
|
||||
x: torch.Tensor, shape (B, S_GLOBAL, H_LOCAL, D) H: int, original global head num return: Tuple[torch.Tensor, int],
|
||||
padded tensor (B, S_GLOBAL, H_LOCAL + H_PAD, D) and H_PAD
|
||||
"""
|
||||
if H is None:
|
||||
return x, 0
|
||||
|
||||
rank = dist.get_rank(group=group)
|
||||
world_size = dist.get_world_size(group=group)
|
||||
H_PAD = 0
|
||||
# Only the last rank may need padding
|
||||
if H % world_size != 0:
|
||||
# We need to broadcast H_PAD to all ranks to keep consistency
|
||||
# in unpadding step later for all ranks.
|
||||
H_PAD = world_size - (H % world_size)
|
||||
NEW_H_LOCAL = (H + H_PAD) // world_size
|
||||
assert H_PAD < NEW_H_LOCAL, f"Padding head num {H_PAD} should be less than new local head num {NEW_H_LOCAL}"
|
||||
if rank == world_size - 1:
|
||||
x = F.pad(x, (0, 0, 0, H_PAD)).contiguous()
|
||||
return x, H_PAD
|
||||
|
||||
|
||||
def _maybe_unpad_o_head(x: torch.Tensor, H_PAD: int, group: dist.ProcessGroup) -> torch.Tensor:
|
||||
r"""Maybe unpad the head dimension.
|
||||
x: torch.Tensor, shape (B, S_LOCAL, H_GLOBAL + H_PAD, D) H_PAD: int, head padding num return: torch.Tensor,
|
||||
unpadded tensor (B, S_LOCAL, H_GLOBAL, D)
|
||||
"""
|
||||
if H_PAD > 0:
|
||||
x = x[:, :, :-H_PAD, :]
|
||||
return x.contiguous()
|
||||
|
||||
|
||||
def ulysses_anything_metadata(query: torch.Tensor, **kwargs) -> dict:
|
||||
# query: (B, S_LOCAL, H_GLOBAL, D)
|
||||
assert len(query.shape) == 4, "Query tensor must be 4-dimensional of shape (B, S_LOCAL, H_GLOBAL, D)"
|
||||
extra_kwargs = {}
|
||||
extra_kwargs["NUM_QO_HEAD"] = query.shape[2]
|
||||
extra_kwargs["Q_S_LOCAL"] = query.shape[1]
|
||||
# Add other kwargs if needed in future
|
||||
return extra_kwargs
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
def all_to_all_single_any_qkv_async(
|
||||
x: torch.Tensor, group: dist.ProcessGroup, **kwargs
|
||||
) -> Callable[..., torch.Tensor]:
|
||||
r"""
|
||||
x: torch.Tensor, shape (B, S_LOCAL, H, D) return: Callable that returns (B, S_GLOBAL, H_LOCAL, D)
|
||||
"""
|
||||
world_size = dist.get_world_size(group=group)
|
||||
B, S_LOCAL, H, D = x.shape
|
||||
x, H_PAD = _maybe_pad_qkv_head(x, H, group)
|
||||
H_LOCAL = (H + H_PAD) // world_size
|
||||
# (world_size, S_LOCAL, B, H_LOCAL, D)
|
||||
x = x.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
|
||||
|
||||
input_split_sizes = [S_LOCAL] * world_size
|
||||
# S_LOCAL maybe not equal for all ranks in dynamic shape case,
|
||||
# since we don't know the actual shape before this timing, thus,
|
||||
# we have to use all gather to collect the S_LOCAL first.
|
||||
output_split_sizes = gather_size_by_comm(S_LOCAL, group)
|
||||
x = x.flatten(0, 1) # (world_size * S_LOCAL, B, H_LOCAL, D)
|
||||
x = funcol.all_to_all_single(x, output_split_sizes, input_split_sizes, group)
|
||||
|
||||
def wait() -> torch.Tensor:
|
||||
nonlocal x, H_PAD
|
||||
x = _wait_tensor(x) # (S_GLOBAL, B, H_LOCAL, D)
|
||||
# (S_GLOBAL, B, H_LOCAL, D)
|
||||
# -> (B, S_GLOBAL, H_LOCAL, D)
|
||||
x = x.permute(1, 0, 2, 3).contiguous()
|
||||
x = _maybe_unpad_qkv_head(x, H_PAD, group)
|
||||
return x
|
||||
|
||||
return wait
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
def all_to_all_single_any_o_async(x: torch.Tensor, group: dist.ProcessGroup, **kwargs) -> Callable[..., torch.Tensor]:
|
||||
r"""
|
||||
x: torch.Tensor, shape (B, S_GLOBAL, H_LOCAL, D) return: Callable that returns (B, S_LOCAL, H_GLOBAL, D)
|
||||
"""
|
||||
# Assume H is provided in kwargs, since we can't infer H from x's shape.
|
||||
# The padding logic needs H to determine if padding is necessary.
|
||||
H = kwargs.get("NUM_QO_HEAD", None)
|
||||
world_size = dist.get_world_size(group=group)
|
||||
|
||||
x, H_PAD = _maybe_pad_o_head(x, H, group)
|
||||
shape = x.shape # (B, S_GLOBAL, H_LOCAL, D)
|
||||
(B, S_GLOBAL, H_LOCAL, D) = shape
|
||||
|
||||
# input_split: e.g, S_GLOBAL=9 input splits across ranks [[5,4], [5,4],..]
|
||||
# output_split: e.g, S_GLOBAL=9 output splits across ranks [[5,5], [4,4],..]
|
||||
|
||||
# WARN: In some cases, e.g, joint attn in Qwen-Image, the S_LOCAL can not infer
|
||||
# from tensor split due to: if c = torch.cat((a, b)), world_size=4, then,
|
||||
# c.tensor_split(4)[0].shape[1] may != to (a.tensor_split(4)[0].shape[1] +
|
||||
# b.tensor_split(4)[0].shape[1])
|
||||
|
||||
S_LOCAL = kwargs.get("Q_S_LOCAL")
|
||||
input_split_sizes = gather_size_by_comm(S_LOCAL, group)
|
||||
x = x.permute(1, 0, 2, 3).contiguous() # (S_GLOBAL, B, H_LOCAL, D)
|
||||
output_split_sizes = [S_LOCAL] * world_size
|
||||
x = funcol.all_to_all_single(x, output_split_sizes, input_split_sizes, group)
|
||||
|
||||
def wait() -> torch.Tensor:
|
||||
nonlocal x, H_PAD
|
||||
x = _wait_tensor(x) # (S_GLOBAL, B, H_LOCAL, D)
|
||||
x = x.reshape(world_size, S_LOCAL, B, H_LOCAL, D)
|
||||
x = x.permute(2, 1, 0, 3, 4).contiguous()
|
||||
x = x.reshape(B, S_LOCAL, world_size * H_LOCAL, D)
|
||||
x = _maybe_unpad_o_head(x, H_PAD, group)
|
||||
return x
|
||||
|
||||
return wait
|
||||
|
||||
|
||||
class TemplatedRingAttention(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
@@ -1653,82 +1501,6 @@ class TemplatedUlyssesAttention(torch.autograd.Function):
|
||||
return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
class TemplatedUlyssesAnythingAttention(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor],
|
||||
dropout_p: float,
|
||||
is_causal: bool,
|
||||
scale: Optional[float],
|
||||
enable_gqa: bool,
|
||||
return_lse: bool,
|
||||
forward_op,
|
||||
backward_op,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
**kwargs,
|
||||
):
|
||||
ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh
|
||||
group = ulysses_mesh.get_group()
|
||||
|
||||
ctx.forward_op = forward_op
|
||||
ctx.backward_op = backward_op
|
||||
ctx._parallel_config = _parallel_config
|
||||
|
||||
metadata = ulysses_anything_metadata(query)
|
||||
query_wait = all_to_all_single_any_qkv_async(query, group, **metadata)
|
||||
key_wait = all_to_all_single_any_qkv_async(key, group, **metadata)
|
||||
value_wait = all_to_all_single_any_qkv_async(value, group, **metadata)
|
||||
|
||||
query = query_wait() # type: torch.Tensor
|
||||
key = key_wait() # type: torch.Tensor
|
||||
value = value_wait() # type: torch.Tensor
|
||||
|
||||
out = forward_op(
|
||||
ctx,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask,
|
||||
dropout_p,
|
||||
is_causal,
|
||||
scale,
|
||||
enable_gqa,
|
||||
return_lse,
|
||||
_save_ctx=False, # ulysses anything only support forward pass now.
|
||||
_parallel_config=_parallel_config,
|
||||
)
|
||||
if return_lse:
|
||||
out, lse, *_ = out
|
||||
|
||||
# out: (B, S_Q_GLOBAL, H_LOCAL, D) -> (B, S_Q_LOCAL, H_GLOBAL, D)
|
||||
out_wait = all_to_all_single_any_o_async(out, group, **metadata)
|
||||
|
||||
if return_lse:
|
||||
# lse: (B, S_Q_GLOBAL, H_LOCAL)
|
||||
lse = lse.unsqueeze(-1) # (B, S_Q_GLOBAL, H_LOCAL, D=1)
|
||||
lse_wait = all_to_all_single_any_o_async(lse, group, **metadata)
|
||||
out = out_wait() # type: torch.Tensor
|
||||
lse = lse_wait() # type: torch.Tensor
|
||||
lse = lse.squeeze(-1).contiguous() # (B, S_Q_LOCAL, H_GLOBAL)
|
||||
else:
|
||||
out = out_wait() # type: torch.Tensor
|
||||
lse = None
|
||||
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
@staticmethod
|
||||
def backward(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
grad_out: torch.Tensor,
|
||||
*args,
|
||||
):
|
||||
raise NotImplementedError("Backward pass for Ulysses Anything Attention in diffusers is not implemented yet.")
|
||||
|
||||
|
||||
def _templated_unified_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
@@ -1846,37 +1618,20 @@ def _templated_context_parallel_attention(
|
||||
_parallel_config,
|
||||
)
|
||||
elif _parallel_config.context_parallel_config.ulysses_degree > 1:
|
||||
if _parallel_config.context_parallel_config.ulysses_anything:
|
||||
# For Any sequence lengths and Any head num support
|
||||
return TemplatedUlyssesAnythingAttention.apply(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask,
|
||||
dropout_p,
|
||||
is_causal,
|
||||
scale,
|
||||
enable_gqa,
|
||||
return_lse,
|
||||
forward_op,
|
||||
backward_op,
|
||||
_parallel_config,
|
||||
)
|
||||
else:
|
||||
return TemplatedUlyssesAttention.apply(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask,
|
||||
dropout_p,
|
||||
is_causal,
|
||||
scale,
|
||||
enable_gqa,
|
||||
return_lse,
|
||||
forward_op,
|
||||
backward_op,
|
||||
_parallel_config,
|
||||
)
|
||||
return TemplatedUlyssesAttention.apply(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask,
|
||||
dropout_p,
|
||||
is_causal,
|
||||
scale,
|
||||
enable_gqa,
|
||||
return_lse,
|
||||
forward_op,
|
||||
backward_op,
|
||||
_parallel_config,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Reaching this branch of code is unexpected. Please report a bug.")
|
||||
|
||||
|
||||
@@ -287,6 +287,9 @@ class Cosmos2_5_PredictBasePipeline(DiffusionPipeline):
|
||||
truncation=True,
|
||||
padding="max_length",
|
||||
)
|
||||
input_ids = (
|
||||
input_ids["input_ids"] if not isinstance(input_ids, list) and "input_ids" in input_ids else input_ids
|
||||
)
|
||||
input_ids = torch.LongTensor(input_ids)
|
||||
input_ids_batch.append(input_ids)
|
||||
|
||||
|
||||
@@ -407,8 +407,8 @@ class GlmImagePipeline(DiffusionPipeline):
|
||||
|
||||
if len(source_grids) > 0:
|
||||
prior_token_image_embed = self.vision_language_encoder.get_image_features(
|
||||
inputs["pixel_values"], source_grids
|
||||
).pooler_output
|
||||
inputs["pixel_values"], source_grids, return_dict=False
|
||||
)
|
||||
prior_token_image_embed = torch.cat(prior_token_image_embed, dim=0)
|
||||
prior_token_image_ids_d32 = self.vision_language_encoder.get_image_tokens(
|
||||
prior_token_image_embed, source_grids
|
||||
|
||||
@@ -20,6 +20,8 @@ class MultilingualCLIP(PreTrainedModel):
|
||||
self.LinearTransformation = torch.nn.Linear(
|
||||
in_features=config.transformerDimensions, out_features=config.numDims
|
||||
)
|
||||
if hasattr(self, "post_init"):
|
||||
self.post_init()
|
||||
|
||||
def forward(self, input_ids, attention_mask):
|
||||
embs = self.transformer(input_ids=input_ids, attention_mask=attention_mask)[0]
|
||||
|
||||
@@ -782,6 +782,9 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
||||
self.prefix_encoder = PrefixEncoder(config)
|
||||
self.dropout = torch.nn.Dropout(0.1)
|
||||
|
||||
if hasattr(self, "post_init"):
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embedding.word_embeddings
|
||||
|
||||
@@ -811,7 +814,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
use_cache = use_cache if use_cache is not None else getattr(self.config, "use_cache", None)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
batch_size, seq_length = input_ids.shape
|
||||
|
||||
@@ -340,6 +340,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
save_method_accept_safe = "safe_serialization" in save_method_signature.parameters
|
||||
save_method_accept_variant = "variant" in save_method_signature.parameters
|
||||
save_method_accept_max_shard_size = "max_shard_size" in save_method_signature.parameters
|
||||
save_method_accept_peft_format = "save_peft_format" in save_method_signature.parameters
|
||||
|
||||
save_kwargs = {}
|
||||
if save_method_accept_safe:
|
||||
@@ -349,6 +350,11 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
if save_method_accept_max_shard_size and max_shard_size is not None:
|
||||
# max_shard_size is expected to not be None in ModelMixin
|
||||
save_kwargs["max_shard_size"] = max_shard_size
|
||||
if save_method_accept_peft_format:
|
||||
# Set save_peft_format=False for transformers>=5.0.0 compatibility
|
||||
# In transformers 5.0.0+, the default save_peft_format=True adds "base_model.model" prefix
|
||||
# to adapter keys, but from_pretrained expects keys without this prefix
|
||||
save_kwargs["save_peft_format"] = False
|
||||
|
||||
save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)
|
||||
|
||||
|
||||
@@ -496,13 +496,8 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
|
||||
num_frames = max(num_frames, 1)
|
||||
|
||||
patch_size = (
|
||||
self.transformer.config.patch_size
|
||||
if self.transformer is not None
|
||||
else self.transformer_2.config.patch_size
|
||||
)
|
||||
h_multiple_of = self.vae_scale_factor_spatial * patch_size[1]
|
||||
w_multiple_of = self.vae_scale_factor_spatial * patch_size[2]
|
||||
h_multiple_of = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
|
||||
w_multiple_of = self.vae_scale_factor_spatial * self.transformer.config.patch_size[2]
|
||||
calc_height = height // h_multiple_of * h_multiple_of
|
||||
calc_width = width // w_multiple_of * w_multiple_of
|
||||
if height != calc_height or width != calc_width:
|
||||
|
||||
@@ -637,13 +637,8 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
|
||||
num_frames = max(num_frames, 1)
|
||||
|
||||
patch_size = (
|
||||
self.transformer.config.patch_size
|
||||
if self.transformer is not None
|
||||
else self.transformer_2.config.patch_size
|
||||
)
|
||||
h_multiple_of = self.vae_scale_factor_spatial * patch_size[1]
|
||||
w_multiple_of = self.vae_scale_factor_spatial * patch_size[2]
|
||||
h_multiple_of = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
|
||||
w_multiple_of = self.vae_scale_factor_spatial * self.transformer.config.patch_size[2]
|
||||
calc_height = height // h_multiple_of * h_multiple_of
|
||||
calc_width = width // w_multiple_of * w_multiple_of
|
||||
if height != calc_height or width != calc_width:
|
||||
|
||||
@@ -227,7 +227,7 @@ _cosmos_guardrail_available, _cosmos_guardrail_version = _is_package_available("
|
||||
_sageattention_available, _sageattention_version = _is_package_available("sageattention")
|
||||
_flash_attn_available, _flash_attn_version = _is_package_available("flash_attn")
|
||||
_flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_attn_3")
|
||||
_aiter_available, _aiter_version = _is_package_available("aiter", get_dist_name=True)
|
||||
_aiter_available, _aiter_version = _is_package_available("aiter")
|
||||
_kornia_available, _kornia_version = _is_package_available("kornia")
|
||||
_nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("modelopt", get_dist_name=True)
|
||||
_av_available, _av_version = _is_package_available("av")
|
||||
|
||||
@@ -32,22 +32,6 @@ warnings.simplefilter(action="ignore", category=FutureWarning)
|
||||
|
||||
def pytest_configure(config):
|
||||
config.addinivalue_line("markers", "big_accelerator: marks tests as requiring big accelerator resources")
|
||||
config.addinivalue_line("markers", "lora: marks tests for LoRA/PEFT functionality")
|
||||
config.addinivalue_line("markers", "ip_adapter: marks tests for IP Adapter functionality")
|
||||
config.addinivalue_line("markers", "training: marks tests for training functionality")
|
||||
config.addinivalue_line("markers", "attention: marks tests for attention processor functionality")
|
||||
config.addinivalue_line("markers", "memory: marks tests for memory optimization functionality")
|
||||
config.addinivalue_line("markers", "cpu_offload: marks tests for CPU offloading functionality")
|
||||
config.addinivalue_line("markers", "group_offload: marks tests for group offloading functionality")
|
||||
config.addinivalue_line("markers", "compile: marks tests for torch.compile functionality")
|
||||
config.addinivalue_line("markers", "single_file: marks tests for single file checkpoint loading")
|
||||
config.addinivalue_line("markers", "quantization: marks tests for quantization functionality")
|
||||
config.addinivalue_line("markers", "bitsandbytes: marks tests for BitsAndBytes quantization functionality")
|
||||
config.addinivalue_line("markers", "quanto: marks tests for Quanto quantization functionality")
|
||||
config.addinivalue_line("markers", "torchao: marks tests for TorchAO quantization functionality")
|
||||
config.addinivalue_line("markers", "gguf: marks tests for GGUF quantization functionality")
|
||||
config.addinivalue_line("markers", "modelopt: marks tests for NVIDIA ModelOpt quantization functionality")
|
||||
config.addinivalue_line("markers", "context_parallel: marks tests for context parallel inference functionality")
|
||||
config.addinivalue_line("markers", "slow: mark test as slow")
|
||||
config.addinivalue_line("markers", "nightly: mark test as nightly")
|
||||
|
||||
|
||||
@@ -20,7 +20,9 @@ class TestAutoModel(unittest.TestCase):
|
||||
side_effect=[EnvironmentError("File not found"), {"model_type": "clip_text_model"}],
|
||||
)
|
||||
def test_load_from_config_transformers_with_subfolder(self, mock_load_config):
|
||||
model = AutoModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="text_encoder")
|
||||
model = AutoModel.from_pretrained(
|
||||
"hf-internal-testing/tiny-stable-diffusion-torch", subfolder="text_encoder", use_safetensors=False
|
||||
)
|
||||
assert isinstance(model, CLIPTextModel)
|
||||
|
||||
def test_load_from_config_without_subfolder(self):
|
||||
@@ -28,5 +30,7 @@ class TestAutoModel(unittest.TestCase):
|
||||
assert isinstance(model, LongformerModel)
|
||||
|
||||
def test_load_from_model_index(self):
|
||||
model = AutoModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="text_encoder")
|
||||
model = AutoModel.from_pretrained(
|
||||
"hf-internal-testing/tiny-stable-diffusion-torch", subfolder="text_encoder", use_safetensors=False
|
||||
)
|
||||
assert isinstance(model, CLIPTextModel)
|
||||
|
||||
@@ -1,79 +0,0 @@
|
||||
from .attention import AttentionTesterMixin
|
||||
from .cache import (
|
||||
CacheTesterMixin,
|
||||
FasterCacheConfigMixin,
|
||||
FasterCacheTesterMixin,
|
||||
FirstBlockCacheConfigMixin,
|
||||
FirstBlockCacheTesterMixin,
|
||||
PyramidAttentionBroadcastConfigMixin,
|
||||
PyramidAttentionBroadcastTesterMixin,
|
||||
)
|
||||
from .common import BaseModelTesterConfig, ModelTesterMixin
|
||||
from .compile import TorchCompileTesterMixin
|
||||
from .ip_adapter import IPAdapterTesterMixin
|
||||
from .lora import LoraHotSwappingForModelTesterMixin, LoraTesterMixin
|
||||
from .memory import CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin, MemoryTesterMixin
|
||||
from .parallelism import ContextParallelTesterMixin
|
||||
from .quantization import (
|
||||
BitsAndBytesCompileTesterMixin,
|
||||
BitsAndBytesConfigMixin,
|
||||
BitsAndBytesTesterMixin,
|
||||
GGUFCompileTesterMixin,
|
||||
GGUFConfigMixin,
|
||||
GGUFTesterMixin,
|
||||
ModelOptCompileTesterMixin,
|
||||
ModelOptConfigMixin,
|
||||
ModelOptTesterMixin,
|
||||
QuantizationCompileTesterMixin,
|
||||
QuantizationTesterMixin,
|
||||
QuantoCompileTesterMixin,
|
||||
QuantoConfigMixin,
|
||||
QuantoTesterMixin,
|
||||
TorchAoCompileTesterMixin,
|
||||
TorchAoConfigMixin,
|
||||
TorchAoTesterMixin,
|
||||
)
|
||||
from .single_file import SingleFileTesterMixin
|
||||
from .training import TrainingTesterMixin
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AttentionTesterMixin",
|
||||
"BaseModelTesterConfig",
|
||||
"BitsAndBytesCompileTesterMixin",
|
||||
"BitsAndBytesConfigMixin",
|
||||
"BitsAndBytesTesterMixin",
|
||||
"CacheTesterMixin",
|
||||
"ContextParallelTesterMixin",
|
||||
"CPUOffloadTesterMixin",
|
||||
"FasterCacheConfigMixin",
|
||||
"FasterCacheTesterMixin",
|
||||
"FirstBlockCacheConfigMixin",
|
||||
"FirstBlockCacheTesterMixin",
|
||||
"GGUFCompileTesterMixin",
|
||||
"GGUFConfigMixin",
|
||||
"GGUFTesterMixin",
|
||||
"GroupOffloadTesterMixin",
|
||||
"IPAdapterTesterMixin",
|
||||
"LayerwiseCastingTesterMixin",
|
||||
"LoraHotSwappingForModelTesterMixin",
|
||||
"LoraTesterMixin",
|
||||
"MemoryTesterMixin",
|
||||
"ModelOptCompileTesterMixin",
|
||||
"ModelOptConfigMixin",
|
||||
"ModelOptTesterMixin",
|
||||
"ModelTesterMixin",
|
||||
"PyramidAttentionBroadcastConfigMixin",
|
||||
"PyramidAttentionBroadcastTesterMixin",
|
||||
"QuantizationCompileTesterMixin",
|
||||
"QuantizationTesterMixin",
|
||||
"QuantoCompileTesterMixin",
|
||||
"QuantoConfigMixin",
|
||||
"QuantoTesterMixin",
|
||||
"SingleFileTesterMixin",
|
||||
"TorchAoCompileTesterMixin",
|
||||
"TorchAoConfigMixin",
|
||||
"TorchAoTesterMixin",
|
||||
"TorchCompileTesterMixin",
|
||||
"TrainingTesterMixin",
|
||||
]
|
||||
@@ -1,181 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from diffusers.models.attention import AttentionModuleMixin
|
||||
from diffusers.models.attention_processor import (
|
||||
AttnProcessor,
|
||||
)
|
||||
|
||||
from ...testing_utils import (
|
||||
assert_tensors_close,
|
||||
backend_empty_cache,
|
||||
is_attention,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
@is_attention
|
||||
class AttentionTesterMixin:
|
||||
"""
|
||||
Mixin class for testing attention processor and module functionality on models.
|
||||
|
||||
Tests functionality from AttentionModuleMixin including:
|
||||
- Attention processor management (set/get)
|
||||
- QKV projection fusion/unfusion
|
||||
- Attention backends (XFormers, NPU, etc.)
|
||||
|
||||
Expected from config mixin:
|
||||
- model_class: The model class to test
|
||||
|
||||
Expected methods from config mixin:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest mark: attention
|
||||
Use `pytest -m "not attention"` to skip these tests
|
||||
"""
|
||||
|
||||
def setup_method(self):
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def teardown_method(self):
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
@torch.no_grad()
|
||||
def test_fuse_unfuse_qkv_projections(self, atol=1e-3, rtol=0):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
if not hasattr(model, "fuse_qkv_projections"):
|
||||
pytest.skip("Model does not support QKV projection fusion.")
|
||||
|
||||
output_before_fusion = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
model.fuse_qkv_projections()
|
||||
|
||||
has_fused_projections = False
|
||||
for module in model.modules():
|
||||
if isinstance(module, AttentionModuleMixin):
|
||||
if hasattr(module, "to_qkv") or hasattr(module, "to_kv"):
|
||||
has_fused_projections = True
|
||||
assert module.fused_projections, "fused_projections flag should be True"
|
||||
break
|
||||
|
||||
if has_fused_projections:
|
||||
output_after_fusion = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
assert_tensors_close(
|
||||
output_before_fusion,
|
||||
output_after_fusion,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
msg="Output should not change after fusing projections",
|
||||
)
|
||||
|
||||
model.unfuse_qkv_projections()
|
||||
|
||||
for module in model.modules():
|
||||
if isinstance(module, AttentionModuleMixin):
|
||||
assert not hasattr(module, "to_qkv"), "to_qkv should be removed after unfusing"
|
||||
assert not hasattr(module, "to_kv"), "to_kv should be removed after unfusing"
|
||||
assert not module.fused_projections, "fused_projections flag should be False"
|
||||
|
||||
output_after_unfusion = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
assert_tensors_close(
|
||||
output_before_fusion,
|
||||
output_after_unfusion,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
msg="Output should match original after unfusing projections",
|
||||
)
|
||||
|
||||
def test_get_set_processor(self):
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
# Check if model has attention processors
|
||||
if not hasattr(model, "attn_processors"):
|
||||
pytest.skip("Model does not have attention processors.")
|
||||
|
||||
# Test getting processors
|
||||
processors = model.attn_processors
|
||||
assert isinstance(processors, dict), "attn_processors should return a dict"
|
||||
assert len(processors) > 0, "Model should have at least one attention processor"
|
||||
|
||||
# Test that all processors can be retrieved via get_processor
|
||||
for module in model.modules():
|
||||
if isinstance(module, AttentionModuleMixin):
|
||||
processor = module.get_processor()
|
||||
assert processor is not None, "get_processor should return a processor"
|
||||
|
||||
# Test setting a new processor
|
||||
new_processor = AttnProcessor()
|
||||
module.set_processor(new_processor)
|
||||
retrieved_processor = module.get_processor()
|
||||
assert retrieved_processor is new_processor, "Retrieved processor should be the same as the one set"
|
||||
|
||||
def test_attention_processor_dict(self):
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
if not hasattr(model, "set_attn_processor"):
|
||||
pytest.skip("Model does not support setting attention processors.")
|
||||
|
||||
# Get current processors
|
||||
current_processors = model.attn_processors
|
||||
|
||||
# Create a dict of new processors
|
||||
new_processors = {key: AttnProcessor() for key in current_processors.keys()}
|
||||
|
||||
# Set processors using dict
|
||||
model.set_attn_processor(new_processors)
|
||||
|
||||
# Verify all processors were set
|
||||
updated_processors = model.attn_processors
|
||||
for key in current_processors.keys():
|
||||
assert type(updated_processors[key]) == AttnProcessor, f"Processor {key} should be AttnProcessor"
|
||||
|
||||
def test_attention_processor_count_mismatch_raises_error(self):
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
if not hasattr(model, "set_attn_processor"):
|
||||
pytest.skip("Model does not support setting attention processors.")
|
||||
|
||||
# Get current processors
|
||||
current_processors = model.attn_processors
|
||||
|
||||
# Create a dict with wrong number of processors
|
||||
wrong_processors = {list(current_processors.keys())[0]: AttnProcessor()}
|
||||
|
||||
# Verify error is raised
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
model.set_attn_processor(wrong_processors)
|
||||
|
||||
assert "number of processors" in str(exc_info.value).lower(), "Error should mention processor count mismatch"
|
||||
@@ -1,556 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from diffusers.hooks import FasterCacheConfig, FirstBlockCacheConfig, PyramidAttentionBroadcastConfig
|
||||
from diffusers.hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
|
||||
from diffusers.hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
|
||||
from diffusers.hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
|
||||
from diffusers.models.cache_utils import CacheMixin
|
||||
|
||||
from ...testing_utils import assert_tensors_close, backend_empty_cache, is_cache, torch_device
|
||||
|
||||
|
||||
def require_cache_mixin(func):
|
||||
"""Decorator to skip tests if model doesn't use CacheMixin."""
|
||||
|
||||
def wrapper(self, *args, **kwargs):
|
||||
if not issubclass(self.model_class, CacheMixin):
|
||||
pytest.skip(f"{self.model_class.__name__} does not use CacheMixin.")
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class CacheTesterMixin:
|
||||
"""
|
||||
Base mixin class providing common test implementations for cache testing.
|
||||
|
||||
Cache-specific mixins should:
|
||||
1. Inherit from their respective config mixin (e.g., PyramidAttentionBroadcastConfigMixin)
|
||||
2. Inherit from this mixin
|
||||
3. Define the cache config to use for tests
|
||||
|
||||
Expected class attributes:
|
||||
- model_class: The model class to test (must use CacheMixin)
|
||||
|
||||
Expected methods in test classes:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Optional overrides:
|
||||
- cache_input_key: Property returning the input tensor key to vary between passes (default: "hidden_states")
|
||||
"""
|
||||
|
||||
@property
|
||||
def cache_input_key(self):
|
||||
return "hidden_states"
|
||||
|
||||
def setup_method(self):
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def teardown_method(self):
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def _get_cache_config(self):
|
||||
"""
|
||||
Get the cache config for testing.
|
||||
Should be implemented by subclasses.
|
||||
"""
|
||||
raise NotImplementedError("Subclass must implement _get_cache_config")
|
||||
|
||||
def _get_hook_names(self):
|
||||
"""
|
||||
Get the hook names to check for this cache type.
|
||||
Should be implemented by subclasses.
|
||||
Returns a list of hook name strings.
|
||||
"""
|
||||
raise NotImplementedError("Subclass must implement _get_hook_names")
|
||||
|
||||
def _test_cache_enable_disable_state(self):
|
||||
"""Test that cache enable/disable updates the is_cache_enabled state correctly."""
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
# Initially cache should not be enabled
|
||||
assert not model.is_cache_enabled, "Cache should not be enabled initially."
|
||||
|
||||
config = self._get_cache_config()
|
||||
|
||||
# Enable cache
|
||||
model.enable_cache(config)
|
||||
assert model.is_cache_enabled, "Cache should be enabled after enable_cache()."
|
||||
|
||||
# Disable cache
|
||||
model.disable_cache()
|
||||
assert not model.is_cache_enabled, "Cache should not be enabled after disable_cache()."
|
||||
|
||||
def _test_cache_double_enable_raises_error(self):
|
||||
"""Test that enabling cache twice raises an error."""
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
config = self._get_cache_config()
|
||||
|
||||
model.enable_cache(config)
|
||||
|
||||
# Trying to enable again should raise ValueError
|
||||
with pytest.raises(ValueError, match="Caching has already been enabled"):
|
||||
model.enable_cache(config)
|
||||
|
||||
# Cleanup
|
||||
model.disable_cache()
|
||||
|
||||
def _test_cache_hooks_registered(self):
|
||||
"""Test that cache hooks are properly registered and removed."""
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
config = self._get_cache_config()
|
||||
hook_names = self._get_hook_names()
|
||||
|
||||
model.enable_cache(config)
|
||||
|
||||
# Check that at least one hook was registered
|
||||
hook_count = 0
|
||||
for module in model.modules():
|
||||
if hasattr(module, "_diffusers_hook"):
|
||||
for hook_name in hook_names:
|
||||
hook = module._diffusers_hook.get_hook(hook_name)
|
||||
if hook is not None:
|
||||
hook_count += 1
|
||||
|
||||
assert hook_count > 0, f"At least one cache hook should be registered. Hook names: {hook_names}"
|
||||
|
||||
# Disable and verify hooks are removed
|
||||
model.disable_cache()
|
||||
|
||||
hook_count_after = 0
|
||||
for module in model.modules():
|
||||
if hasattr(module, "_diffusers_hook"):
|
||||
for hook_name in hook_names:
|
||||
hook = module._diffusers_hook.get_hook(hook_name)
|
||||
if hook is not None:
|
||||
hook_count_after += 1
|
||||
|
||||
assert hook_count_after == 0, "Cache hooks should be removed after disable_cache()."
|
||||
|
||||
@torch.no_grad()
|
||||
def _test_cache_inference(self):
|
||||
"""Test that model can run inference with cache enabled."""
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
config = self._get_cache_config()
|
||||
|
||||
model.enable_cache(config)
|
||||
|
||||
# First pass populates the cache
|
||||
_ = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
# Create modified inputs for second pass (vary input tensor to simulate denoising)
|
||||
inputs_dict_step2 = inputs_dict.copy()
|
||||
if self.cache_input_key in inputs_dict_step2:
|
||||
inputs_dict_step2[self.cache_input_key] = inputs_dict_step2[self.cache_input_key] + torch.randn_like(
|
||||
inputs_dict_step2[self.cache_input_key]
|
||||
)
|
||||
|
||||
# Second pass uses cached attention with different inputs (produces approximated output)
|
||||
output_with_cache = model(**inputs_dict_step2, return_dict=False)[0]
|
||||
|
||||
assert output_with_cache is not None, "Model output should not be None with cache enabled."
|
||||
assert not torch.isnan(output_with_cache).any(), "Model output contains NaN with cache enabled."
|
||||
|
||||
# Run same inputs without cache to compare
|
||||
model.disable_cache()
|
||||
output_without_cache = model(**inputs_dict_step2, return_dict=False)[0]
|
||||
|
||||
# Cached output should be different from non-cached output (due to approximation)
|
||||
assert not torch.allclose(output_without_cache, output_with_cache, atol=1e-5), (
|
||||
"Cached output should be different from non-cached output due to cache approximation."
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def _test_cache_context_manager(self, atol=1e-5, rtol=0):
|
||||
"""Test the cache_context context manager properly isolates cache state."""
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
config = self._get_cache_config()
|
||||
model.enable_cache(config)
|
||||
|
||||
# Run inference in first context
|
||||
with model.cache_context("context_1"):
|
||||
output_ctx1 = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
# Run same inference in second context (cache should be reset)
|
||||
with model.cache_context("context_2"):
|
||||
output_ctx2 = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
# Both contexts should produce the same output (first pass in each)
|
||||
assert_tensors_close(
|
||||
output_ctx1,
|
||||
output_ctx2,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
msg="First pass in different cache contexts should produce the same output.",
|
||||
)
|
||||
|
||||
model.disable_cache()
|
||||
|
||||
@torch.no_grad()
|
||||
def _test_reset_stateful_cache(self):
|
||||
"""Test that _reset_stateful_cache resets the cache state."""
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
config = self._get_cache_config()
|
||||
|
||||
model.enable_cache(config)
|
||||
|
||||
_ = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
model._reset_stateful_cache()
|
||||
|
||||
model.disable_cache()
|
||||
|
||||
|
||||
@is_cache
|
||||
class PyramidAttentionBroadcastConfigMixin:
|
||||
"""
|
||||
Base mixin providing PyramidAttentionBroadcast cache config.
|
||||
|
||||
Expected class attributes:
|
||||
- model_class: The model class to test (must use CacheMixin)
|
||||
"""
|
||||
|
||||
# Default PAB config - can be overridden by subclasses
|
||||
PAB_CONFIG = {
|
||||
"spatial_attention_block_skip_range": 2,
|
||||
}
|
||||
|
||||
# Store timestep for callback (must be within default range (100, 800) for skipping to trigger)
|
||||
_current_timestep = 500
|
||||
|
||||
def _get_cache_config(self):
|
||||
config_kwargs = self.PAB_CONFIG.copy()
|
||||
config_kwargs["current_timestep_callback"] = lambda: self._current_timestep
|
||||
return PyramidAttentionBroadcastConfig(**config_kwargs)
|
||||
|
||||
def _get_hook_names(self):
|
||||
return [_PYRAMID_ATTENTION_BROADCAST_HOOK]
|
||||
|
||||
|
||||
@is_cache
|
||||
class PyramidAttentionBroadcastTesterMixin(PyramidAttentionBroadcastConfigMixin, CacheTesterMixin):
|
||||
"""
|
||||
Mixin class for testing PyramidAttentionBroadcast caching on models.
|
||||
|
||||
Expected class attributes:
|
||||
- model_class: The model class to test (must use CacheMixin)
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest mark: cache
|
||||
Use `pytest -m "not cache"` to skip these tests
|
||||
"""
|
||||
|
||||
@require_cache_mixin
|
||||
def test_pab_cache_enable_disable_state(self):
|
||||
self._test_cache_enable_disable_state()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_pab_cache_double_enable_raises_error(self):
|
||||
self._test_cache_double_enable_raises_error()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_pab_cache_hooks_registered(self):
|
||||
self._test_cache_hooks_registered()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_pab_cache_inference(self):
|
||||
self._test_cache_inference()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_pab_cache_context_manager(self):
|
||||
self._test_cache_context_manager()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_pab_reset_stateful_cache(self):
|
||||
self._test_reset_stateful_cache()
|
||||
|
||||
|
||||
@is_cache
|
||||
class FirstBlockCacheConfigMixin:
|
||||
"""
|
||||
Base mixin providing FirstBlockCache config.
|
||||
|
||||
Expected class attributes:
|
||||
- model_class: The model class to test (must use CacheMixin)
|
||||
"""
|
||||
|
||||
# Default FBC config - can be overridden by subclasses
|
||||
# Higher threshold makes FBC more aggressive about caching (skips more often)
|
||||
FBC_CONFIG = {
|
||||
"threshold": 1.0,
|
||||
}
|
||||
|
||||
def _get_cache_config(self):
|
||||
return FirstBlockCacheConfig(**self.FBC_CONFIG)
|
||||
|
||||
def _get_hook_names(self):
|
||||
return [_FBC_LEADER_BLOCK_HOOK, _FBC_BLOCK_HOOK]
|
||||
|
||||
|
||||
@is_cache
|
||||
class FirstBlockCacheTesterMixin(FirstBlockCacheConfigMixin, CacheTesterMixin):
|
||||
"""
|
||||
Mixin class for testing FirstBlockCache on models.
|
||||
|
||||
Expected class attributes:
|
||||
- model_class: The model class to test (must use CacheMixin)
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest mark: cache
|
||||
Use `pytest -m "not cache"` to skip these tests
|
||||
"""
|
||||
|
||||
@torch.no_grad()
|
||||
def _test_cache_inference(self):
|
||||
"""Test that model can run inference with FBC cache enabled (requires cache_context)."""
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
config = self._get_cache_config()
|
||||
model.enable_cache(config)
|
||||
|
||||
# FBC requires cache_context to be set for inference
|
||||
with model.cache_context("fbc_test"):
|
||||
# First pass populates the cache
|
||||
_ = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
# Create modified inputs for second pass
|
||||
inputs_dict_step2 = inputs_dict.copy()
|
||||
if self.cache_input_key in inputs_dict_step2:
|
||||
inputs_dict_step2[self.cache_input_key] = inputs_dict_step2[self.cache_input_key] + torch.randn_like(
|
||||
inputs_dict_step2[self.cache_input_key]
|
||||
)
|
||||
|
||||
# Second pass - FBC should skip remaining blocks and use cached residuals
|
||||
output_with_cache = model(**inputs_dict_step2, return_dict=False)[0]
|
||||
|
||||
assert output_with_cache is not None, "Model output should not be None with cache enabled."
|
||||
assert not torch.isnan(output_with_cache).any(), "Model output contains NaN with cache enabled."
|
||||
|
||||
# Run same inputs without cache to compare
|
||||
model.disable_cache()
|
||||
output_without_cache = model(**inputs_dict_step2, return_dict=False)[0]
|
||||
|
||||
# Cached output should be different from non-cached output (due to approximation)
|
||||
assert not torch.allclose(output_without_cache, output_with_cache, atol=1e-5), (
|
||||
"Cached output should be different from non-cached output due to cache approximation."
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def _test_reset_stateful_cache(self):
|
||||
"""Test that _reset_stateful_cache resets the FBC cache state (requires cache_context)."""
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
config = self._get_cache_config()
|
||||
model.enable_cache(config)
|
||||
|
||||
with model.cache_context("fbc_test"):
|
||||
_ = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
model._reset_stateful_cache()
|
||||
|
||||
model.disable_cache()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_fbc_cache_enable_disable_state(self):
|
||||
self._test_cache_enable_disable_state()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_fbc_cache_double_enable_raises_error(self):
|
||||
self._test_cache_double_enable_raises_error()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_fbc_cache_hooks_registered(self):
|
||||
self._test_cache_hooks_registered()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_fbc_cache_inference(self):
|
||||
self._test_cache_inference()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_fbc_cache_context_manager(self):
|
||||
self._test_cache_context_manager()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_fbc_reset_stateful_cache(self):
|
||||
self._test_reset_stateful_cache()
|
||||
|
||||
|
||||
@is_cache
|
||||
class FasterCacheConfigMixin:
|
||||
"""
|
||||
Base mixin providing FasterCache config.
|
||||
|
||||
Expected class attributes:
|
||||
- model_class: The model class to test (must use CacheMixin)
|
||||
"""
|
||||
|
||||
# Default FasterCache config - can be overridden by subclasses
|
||||
FASTER_CACHE_CONFIG = {
|
||||
"spatial_attention_block_skip_range": 2,
|
||||
"spatial_attention_timestep_skip_range": (-1, 901),
|
||||
"tensor_format": "BCHW",
|
||||
}
|
||||
|
||||
def _get_cache_config(self, current_timestep_callback=None):
|
||||
config_kwargs = self.FASTER_CACHE_CONFIG.copy()
|
||||
if current_timestep_callback is None:
|
||||
current_timestep_callback = lambda: 1000 # noqa: E731
|
||||
config_kwargs["current_timestep_callback"] = current_timestep_callback
|
||||
return FasterCacheConfig(**config_kwargs)
|
||||
|
||||
def _get_hook_names(self):
|
||||
return [_FASTER_CACHE_DENOISER_HOOK, _FASTER_CACHE_BLOCK_HOOK]
|
||||
|
||||
|
||||
@is_cache
|
||||
class FasterCacheTesterMixin(FasterCacheConfigMixin, CacheTesterMixin):
|
||||
"""
|
||||
Mixin class for testing FasterCache on models.
|
||||
|
||||
Note: FasterCache is designed for pipeline-level inference with proper CFG batch handling
|
||||
and timestep management. Inference tests are skipped at model level - FasterCache should
|
||||
be tested via pipeline tests (e.g., FluxPipeline, HunyuanVideoPipeline).
|
||||
|
||||
Expected class attributes:
|
||||
- model_class: The model class to test (must use CacheMixin)
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest mark: cache
|
||||
Use `pytest -m "not cache"` to skip these tests
|
||||
"""
|
||||
|
||||
@torch.no_grad()
|
||||
def _test_cache_inference(self):
|
||||
"""Test that model can run inference with FasterCache enabled."""
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
current_timestep = [1000]
|
||||
config = self._get_cache_config(current_timestep_callback=lambda: current_timestep[0])
|
||||
|
||||
model.enable_cache(config)
|
||||
|
||||
# First pass with timestep outside skip range - computes and populates cache
|
||||
current_timestep[0] = 1000
|
||||
_ = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
# Move timestep inside skip range so subsequent passes use cache
|
||||
current_timestep[0] = 500
|
||||
|
||||
# Create modified inputs for second pass
|
||||
inputs_dict_step2 = inputs_dict.copy()
|
||||
if self.cache_input_key in inputs_dict_step2:
|
||||
inputs_dict_step2[self.cache_input_key] = inputs_dict_step2[self.cache_input_key] + torch.randn_like(
|
||||
inputs_dict_step2[self.cache_input_key]
|
||||
)
|
||||
|
||||
# Second pass uses cached attention with different inputs
|
||||
output_with_cache = model(**inputs_dict_step2, return_dict=False)[0]
|
||||
|
||||
assert output_with_cache is not None, "Model output should not be None with cache enabled."
|
||||
assert not torch.isnan(output_with_cache).any(), "Model output contains NaN with cache enabled."
|
||||
|
||||
# Run same inputs without cache to compare
|
||||
model.disable_cache()
|
||||
output_without_cache = model(**inputs_dict_step2, return_dict=False)[0]
|
||||
|
||||
# Cached output should be different from non-cached output (due to approximation)
|
||||
assert not torch.allclose(output_without_cache, output_with_cache, atol=1e-5), (
|
||||
"Cached output should be different from non-cached output due to cache approximation."
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def _test_reset_stateful_cache(self):
|
||||
"""Test that _reset_stateful_cache resets the FasterCache state."""
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
config = self._get_cache_config()
|
||||
model.enable_cache(config)
|
||||
|
||||
_ = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
model._reset_stateful_cache()
|
||||
|
||||
model.disable_cache()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_faster_cache_enable_disable_state(self):
|
||||
self._test_cache_enable_disable_state()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_faster_cache_double_enable_raises_error(self):
|
||||
self._test_cache_double_enable_raises_error()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_faster_cache_hooks_registered(self):
|
||||
self._test_cache_hooks_registered()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_faster_cache_inference(self):
|
||||
self._test_cache_inference()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_faster_cache_context_manager(self):
|
||||
self._test_cache_context_manager()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_faster_cache_reset_stateful_cache(self):
|
||||
self._test_reset_stateful_cache()
|
||||
@@ -1,666 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from typing import Any, Dict, Optional, Type
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size
|
||||
|
||||
from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME, _add_variant, logging
|
||||
from diffusers.utils.testing_utils import require_accelerator, require_torch_multi_accelerator
|
||||
|
||||
from ...testing_utils import assert_tensors_close, torch_device
|
||||
|
||||
|
||||
def named_persistent_module_tensors(
|
||||
module: nn.Module,
|
||||
recurse: bool = False,
|
||||
):
|
||||
"""
|
||||
A helper function that gathers all the tensors (parameters + persistent buffers) of a given module.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The module we want the tensors on.
|
||||
recurse (`bool`, *optional`, defaults to `False`):
|
||||
Whether or not to go look in every submodule or just return the direct parameters and buffers.
|
||||
"""
|
||||
yield from module.named_parameters(recurse=recurse)
|
||||
|
||||
for named_buffer in module.named_buffers(recurse=recurse):
|
||||
name, _ = named_buffer
|
||||
# Get parent by splitting on dots and traversing the model
|
||||
parent = module
|
||||
if "." in name:
|
||||
parent_name = name.rsplit(".", 1)[0]
|
||||
for part in parent_name.split("."):
|
||||
parent = getattr(parent, part)
|
||||
name = name.split(".")[-1]
|
||||
if name not in parent._non_persistent_buffers_set:
|
||||
yield named_buffer
|
||||
|
||||
|
||||
def compute_module_persistent_sizes(
|
||||
model: nn.Module,
|
||||
dtype: str | torch.device | None = None,
|
||||
special_dtypes: dict[str, str | torch.device] | None = None,
|
||||
):
|
||||
"""
|
||||
Compute the size of each submodule of a given model (parameters + persistent buffers).
|
||||
"""
|
||||
if dtype is not None:
|
||||
dtype = _get_proper_dtype(dtype)
|
||||
dtype_size = dtype_byte_size(dtype)
|
||||
if special_dtypes is not None:
|
||||
special_dtypes = {key: _get_proper_dtype(dtyp) for key, dtyp in special_dtypes.items()}
|
||||
special_dtypes_size = {key: dtype_byte_size(dtyp) for key, dtyp in special_dtypes.items()}
|
||||
module_sizes = defaultdict(int)
|
||||
|
||||
module_list = []
|
||||
|
||||
module_list = named_persistent_module_tensors(model, recurse=True)
|
||||
|
||||
for name, tensor in module_list:
|
||||
if special_dtypes is not None and name in special_dtypes:
|
||||
size = tensor.numel() * special_dtypes_size[name]
|
||||
elif dtype is None:
|
||||
size = tensor.numel() * dtype_byte_size(tensor.dtype)
|
||||
elif str(tensor.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
|
||||
# According to the code in set_module_tensor_to_device, these types won't be converted
|
||||
# so use their original size here
|
||||
size = tensor.numel() * dtype_byte_size(tensor.dtype)
|
||||
else:
|
||||
size = tensor.numel() * min(dtype_size, dtype_byte_size(tensor.dtype))
|
||||
name_parts = name.split(".")
|
||||
for idx in range(len(name_parts) + 1):
|
||||
module_sizes[".".join(name_parts[:idx])] += size
|
||||
|
||||
return module_sizes
|
||||
|
||||
|
||||
def calculate_expected_num_shards(index_map_path):
|
||||
"""
|
||||
Calculate expected number of shards from index file.
|
||||
|
||||
Args:
|
||||
index_map_path: Path to the sharded checkpoint index file
|
||||
|
||||
Returns:
|
||||
int: Expected number of shards
|
||||
"""
|
||||
with open(index_map_path) as f:
|
||||
weight_map_dict = json.load(f)["weight_map"]
|
||||
first_key = list(weight_map_dict.keys())[0]
|
||||
weight_loc = weight_map_dict[first_key] # e.g., diffusion_pytorch_model-00001-of-00002.safetensors
|
||||
expected_num_shards = int(weight_loc.split("-")[-1].split(".")[0])
|
||||
return expected_num_shards
|
||||
|
||||
|
||||
def check_device_map_is_respected(model, device_map):
|
||||
for param_name, param in model.named_parameters():
|
||||
# Find device in device_map
|
||||
while len(param_name) > 0 and param_name not in device_map:
|
||||
param_name = ".".join(param_name.split(".")[:-1])
|
||||
if param_name not in device_map:
|
||||
raise ValueError("device map is incomplete, it does not contain any device for `param_name`.")
|
||||
|
||||
param_device = device_map[param_name]
|
||||
if param_device in ["cpu", "disk"]:
|
||||
assert param.device == torch.device("meta"), f"Expected device 'meta' for {param_name}, got {param.device}"
|
||||
else:
|
||||
assert param.device == torch.device(param_device), (
|
||||
f"Expected device {param_device} for {param_name}, got {param.device}"
|
||||
)
|
||||
|
||||
|
||||
def cast_inputs_to_dtype(inputs, current_dtype, target_dtype):
|
||||
if torch.is_tensor(inputs):
|
||||
return inputs.to(target_dtype) if inputs.dtype == current_dtype else inputs
|
||||
if isinstance(inputs, dict):
|
||||
return {k: cast_inputs_to_dtype(v, current_dtype, target_dtype) for k, v in inputs.items()}
|
||||
if isinstance(inputs, list):
|
||||
return [cast_inputs_to_dtype(v, current_dtype, target_dtype) for v in inputs]
|
||||
|
||||
return inputs
|
||||
|
||||
|
||||
class BaseModelTesterConfig:
|
||||
"""
|
||||
Base class defining the configuration interface for model testing.
|
||||
|
||||
This class defines the contract that all model test classes must implement.
|
||||
It provides a consistent interface for accessing model configuration, initialization
|
||||
parameters, and test inputs across all testing mixins.
|
||||
|
||||
Required properties (must be implemented by subclasses):
|
||||
- model_class: The model class to test
|
||||
|
||||
Optional properties (can be overridden, have sensible defaults):
|
||||
- pretrained_model_name_or_path: Hub repository ID for pretrained model (default: None)
|
||||
- pretrained_model_kwargs: Additional kwargs for from_pretrained (default: {})
|
||||
- output_shape: Expected output shape for output validation tests (default: None)
|
||||
- model_split_percents: Percentages for model parallelism tests (default: [0.5, 0.7])
|
||||
|
||||
Required methods (must be implemented by subclasses):
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Example usage:
|
||||
class MyModelTestConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return MyModel
|
||||
|
||||
@property
|
||||
def pretrained_model_name_or_path(self):
|
||||
return "org/my-model"
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (1, 3, 32, 32)
|
||||
|
||||
def get_init_dict(self):
|
||||
return {"in_channels": 3, "out_channels": 3}
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
return {"sample": torch.randn(1, 3, 32, 32, device=torch_device)}
|
||||
|
||||
class TestMyModel(MyModelTestConfig, ModelTesterMixin, QuantizationTesterMixin):
|
||||
pass
|
||||
"""
|
||||
|
||||
# ==================== Required Properties ====================
|
||||
|
||||
@property
|
||||
def model_class(self) -> Type[nn.Module]:
|
||||
"""The model class to test. Must be implemented by subclasses."""
|
||||
raise NotImplementedError("Subclasses must implement the `model_class` property.")
|
||||
|
||||
# ==================== Optional Properties ====================
|
||||
|
||||
@property
|
||||
def pretrained_model_name_or_path(self) -> Optional[str]:
|
||||
"""Hub repository ID for the pretrained model (used for quantization and hub tests)."""
|
||||
return None
|
||||
|
||||
@property
|
||||
def pretrained_model_kwargs(self) -> Dict[str, Any]:
|
||||
"""Additional kwargs to pass to from_pretrained (e.g., subfolder, variant)."""
|
||||
return {}
|
||||
|
||||
@property
|
||||
def output_shape(self) -> Optional[tuple]:
|
||||
"""Expected output shape for output validation tests."""
|
||||
return None
|
||||
|
||||
@property
|
||||
def model_split_percents(self) -> list:
|
||||
"""Percentages for model parallelism tests."""
|
||||
return [0.9]
|
||||
|
||||
# ==================== Required Methods ====================
|
||||
|
||||
def get_init_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Returns dict of arguments to initialize the model.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Initialization arguments for the model constructor.
|
||||
|
||||
Example:
|
||||
return {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"sample_size": 32,
|
||||
}
|
||||
"""
|
||||
raise NotImplementedError("Subclasses must implement `get_init_dict()`.")
|
||||
|
||||
def get_dummy_inputs(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Returns dict of inputs to pass to the model forward pass.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Input tensors/values for model.forward().
|
||||
|
||||
Example:
|
||||
return {
|
||||
"sample": torch.randn(1, 3, 32, 32, device=torch_device),
|
||||
"timestep": torch.tensor([1], device=torch_device),
|
||||
}
|
||||
"""
|
||||
raise NotImplementedError("Subclasses must implement `get_dummy_inputs()`.")
|
||||
|
||||
|
||||
class ModelTesterMixin:
|
||||
"""
|
||||
Base mixin class for model testing with common test methods.
|
||||
|
||||
This mixin expects the test class to also inherit from BaseModelTesterConfig
|
||||
(or implement its interface) which provides:
|
||||
- model_class: The model class to test
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Example:
|
||||
class MyModelTestConfig(BaseModelTesterConfig):
|
||||
model_class = MyModel
|
||||
def get_init_dict(self): ...
|
||||
def get_dummy_inputs(self): ...
|
||||
|
||||
class TestMyModel(MyModelTestConfig, ModelTesterMixin):
|
||||
pass
|
||||
"""
|
||||
|
||||
@torch.no_grad()
|
||||
def test_from_save_pretrained(self, tmp_path, atol=5e-5, rtol=5e-5):
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
model.save_pretrained(tmp_path)
|
||||
new_model = self.model_class.from_pretrained(tmp_path)
|
||||
new_model.to(torch_device)
|
||||
|
||||
for param_name in model.state_dict().keys():
|
||||
param_1 = model.state_dict()[param_name]
|
||||
param_2 = new_model.state_dict()[param_name]
|
||||
assert param_1.shape == param_2.shape, (
|
||||
f"Parameter shape mismatch for {param_name}. Original: {param_1.shape}, loaded: {param_2.shape}"
|
||||
)
|
||||
|
||||
image = model(**self.get_dummy_inputs(), return_dict=False)[0]
|
||||
new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]
|
||||
|
||||
assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")
|
||||
|
||||
@torch.no_grad()
|
||||
def test_from_save_pretrained_variant(self, tmp_path, atol=5e-5, rtol=0):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
model.save_pretrained(tmp_path, variant="fp16")
|
||||
new_model = self.model_class.from_pretrained(tmp_path, variant="fp16")
|
||||
|
||||
with pytest.raises(OSError) as exc_info:
|
||||
self.model_class.from_pretrained(tmp_path)
|
||||
|
||||
assert "Error no file named diffusion_pytorch_model.bin found in directory" in str(exc_info.value)
|
||||
|
||||
new_model.to(torch_device)
|
||||
|
||||
image = model(**self.get_dummy_inputs(), return_dict=False)[0]
|
||||
new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]
|
||||
|
||||
assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"])
|
||||
def test_from_save_pretrained_dtype(self, tmp_path, dtype):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
if torch_device == "mps" and dtype == torch.bfloat16:
|
||||
pytest.skip(reason=f"{dtype} is not supported on {torch_device}")
|
||||
|
||||
model.to(dtype)
|
||||
model.save_pretrained(tmp_path)
|
||||
new_model = self.model_class.from_pretrained(tmp_path, low_cpu_mem_usage=True, torch_dtype=dtype)
|
||||
assert new_model.dtype == dtype
|
||||
if hasattr(self.model_class, "_keep_in_fp32_modules") and self.model_class._keep_in_fp32_modules is None:
|
||||
# When loading without accelerate dtype == torch.float32 if _keep_in_fp32_modules is not None
|
||||
new_model = self.model_class.from_pretrained(tmp_path, low_cpu_mem_usage=False, torch_dtype=dtype)
|
||||
assert new_model.dtype == dtype
|
||||
|
||||
@torch.no_grad()
|
||||
def test_determinism(self, atol=1e-5, rtol=0):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
first = model(**self.get_dummy_inputs(), return_dict=False)[0]
|
||||
second = model(**self.get_dummy_inputs(), return_dict=False)[0]
|
||||
|
||||
first_flat = first.flatten()
|
||||
second_flat = second.flatten()
|
||||
mask = ~(torch.isnan(first_flat) | torch.isnan(second_flat))
|
||||
first_filtered = first_flat[mask]
|
||||
second_filtered = second_flat[mask]
|
||||
|
||||
assert_tensors_close(
|
||||
first_filtered, second_filtered, atol=atol, rtol=rtol, msg="Model outputs are not deterministic"
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def test_output(self, expected_output_shape=None):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
output = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
assert output is not None, "Model output is None"
|
||||
assert output[0].shape == expected_output_shape or self.output_shape, (
|
||||
f"Output shape does not match expected. Expected {expected_output_shape}, got {output.shape}"
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def test_outputs_equivalence(self, atol=1e-5, rtol=0):
|
||||
def set_nan_tensor_to_zero(t):
|
||||
device = t.device
|
||||
if device.type == "mps":
|
||||
t = t.to("cpu")
|
||||
t[t != t] = 0
|
||||
return t.to(device)
|
||||
|
||||
def recursive_check(tuple_object, dict_object):
|
||||
if isinstance(tuple_object, (list, tuple)):
|
||||
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
|
||||
recursive_check(tuple_iterable_value, dict_iterable_value)
|
||||
elif isinstance(tuple_object, dict):
|
||||
for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
|
||||
recursive_check(tuple_iterable_value, dict_iterable_value)
|
||||
elif tuple_object is None:
|
||||
return
|
||||
else:
|
||||
assert_tensors_close(
|
||||
set_nan_tensor_to_zero(tuple_object),
|
||||
set_nan_tensor_to_zero(dict_object),
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
msg="Tuple and dict output are not equal",
|
||||
)
|
||||
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
outputs_dict = model(**self.get_dummy_inputs())
|
||||
outputs_tuple = model(**self.get_dummy_inputs(), return_dict=False)
|
||||
|
||||
recursive_check(outputs_tuple, outputs_dict)
|
||||
|
||||
def test_getattr_is_correct(self, caplog):
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
# save some things to test
|
||||
model.dummy_attribute = 5
|
||||
model.register_to_config(test_attribute=5)
|
||||
|
||||
logger_name = "diffusers.models.modeling_utils"
|
||||
with caplog.at_level(logging.WARNING, logger=logger_name):
|
||||
caplog.clear()
|
||||
assert hasattr(model, "dummy_attribute")
|
||||
assert getattr(model, "dummy_attribute") == 5
|
||||
assert model.dummy_attribute == 5
|
||||
|
||||
# no warning should be thrown
|
||||
assert caplog.text == ""
|
||||
|
||||
with caplog.at_level(logging.WARNING, logger=logger_name):
|
||||
caplog.clear()
|
||||
assert hasattr(model, "save_pretrained")
|
||||
fn = model.save_pretrained
|
||||
fn_1 = getattr(model, "save_pretrained")
|
||||
|
||||
assert fn == fn_1
|
||||
|
||||
# no warning should be thrown
|
||||
assert caplog.text == ""
|
||||
|
||||
# warning should be thrown for config attributes accessed directly
|
||||
with pytest.warns(FutureWarning):
|
||||
assert model.test_attribute == 5
|
||||
|
||||
with pytest.warns(FutureWarning):
|
||||
assert getattr(model, "test_attribute") == 5
|
||||
|
||||
with pytest.raises(AttributeError) as error:
|
||||
model.does_not_exist
|
||||
|
||||
assert str(error.value) == f"'{type(model).__name__}' object has no attribute 'does_not_exist'"
|
||||
|
||||
@require_accelerator
|
||||
@pytest.mark.skipif(
|
||||
torch_device not in ["cuda", "xpu"],
|
||||
reason="float16 and bfloat16 can only be used with an accelerator",
|
||||
)
|
||||
def test_keep_in_fp32_modules(self):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
fp32_modules = model._keep_in_fp32_modules
|
||||
|
||||
if fp32_modules is None or len(fp32_modules) == 0:
|
||||
pytest.skip("Model does not have _keep_in_fp32_modules defined.")
|
||||
|
||||
# Test with float16
|
||||
model.to(torch_device)
|
||||
model.to(torch.float16)
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules):
|
||||
assert param.dtype == torch.float32, f"Parameter {name} should be float32 but got {param.dtype}"
|
||||
else:
|
||||
assert param.dtype == torch.float16, f"Parameter {name} should be float16 but got {param.dtype}"
|
||||
|
||||
@require_accelerator
|
||||
@pytest.mark.skipif(
|
||||
torch_device not in ["cuda", "xpu"],
|
||||
reason="float16 and bfloat16 can only be use for inference with an accelerator",
|
||||
)
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
|
||||
@torch.no_grad()
|
||||
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
fp32_modules = model._keep_in_fp32_modules or []
|
||||
|
||||
model.to(dtype).save_pretrained(tmp_path)
|
||||
model_loaded = self.model_class.from_pretrained(tmp_path, torch_dtype=dtype).to(torch_device)
|
||||
|
||||
for name, param in model_loaded.named_parameters():
|
||||
if fp32_modules and any(
|
||||
module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules
|
||||
):
|
||||
assert param.data.dtype == torch.float32
|
||||
else:
|
||||
assert param.data.dtype == dtype
|
||||
|
||||
inputs = cast_inputs_to_dtype(self.get_dummy_inputs(), torch.float32, dtype)
|
||||
output = model(**inputs, return_dict=False)[0]
|
||||
output_loaded = model_loaded(**inputs, return_dict=False)[0]
|
||||
|
||||
self._check_dtype_inference_output(output, output_loaded, dtype)
|
||||
|
||||
def _check_dtype_inference_output(self, output, output_loaded, dtype, atol=1e-4, rtol=0):
|
||||
"""Check dtype inference output with configurable tolerance."""
|
||||
assert_tensors_close(
|
||||
output, output_loaded, atol=atol, rtol=rtol, msg=f"Loaded model output differs for {dtype}"
|
||||
)
|
||||
|
||||
@require_accelerator
|
||||
@torch.no_grad()
|
||||
def test_sharded_checkpoints(self, tmp_path, atol=1e-5, rtol=0):
|
||||
torch.manual_seed(0)
|
||||
config = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**config).eval()
|
||||
model = model.to(torch_device)
|
||||
|
||||
base_output = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
model_size = compute_module_persistent_sizes(model)[""]
|
||||
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small
|
||||
|
||||
model.cpu().save_pretrained(tmp_path, max_shard_size=f"{max_shard_size}KB")
|
||||
assert os.path.exists(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME)), "Index file should exist"
|
||||
|
||||
# Check if the right number of shards exists
|
||||
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME))
|
||||
actual_num_shards = len([file for file in os.listdir(tmp_path) if file.endswith(".safetensors")])
|
||||
assert actual_num_shards == expected_num_shards, (
|
||||
f"Expected {expected_num_shards} shards, got {actual_num_shards}"
|
||||
)
|
||||
|
||||
new_model = self.model_class.from_pretrained(tmp_path).eval()
|
||||
new_model = new_model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
inputs_dict_new = self.get_dummy_inputs()
|
||||
new_output = new_model(**inputs_dict_new, return_dict=False)[0]
|
||||
|
||||
assert_tensors_close(
|
||||
base_output, new_output, atol=atol, rtol=rtol, msg="Output should match after sharded save/load"
|
||||
)
|
||||
|
||||
@require_accelerator
|
||||
@torch.no_grad()
|
||||
def test_sharded_checkpoints_with_variant(self, tmp_path, atol=1e-5, rtol=0):
|
||||
torch.manual_seed(0)
|
||||
config = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**config).eval()
|
||||
model = model.to(torch_device)
|
||||
|
||||
base_output = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
model_size = compute_module_persistent_sizes(model)[""]
|
||||
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small
|
||||
variant = "fp16"
|
||||
|
||||
model.cpu().save_pretrained(tmp_path, max_shard_size=f"{max_shard_size}KB", variant=variant)
|
||||
|
||||
index_filename = _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
|
||||
assert os.path.exists(os.path.join(tmp_path, index_filename)), (
|
||||
f"Variant index file {index_filename} should exist"
|
||||
)
|
||||
|
||||
# Check if the right number of shards exists
|
||||
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_path, index_filename))
|
||||
actual_num_shards = len([file for file in os.listdir(tmp_path) if file.endswith(".safetensors")])
|
||||
assert actual_num_shards == expected_num_shards, (
|
||||
f"Expected {expected_num_shards} shards, got {actual_num_shards}"
|
||||
)
|
||||
|
||||
new_model = self.model_class.from_pretrained(tmp_path, variant=variant).eval()
|
||||
new_model = new_model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
inputs_dict_new = self.get_dummy_inputs()
|
||||
new_output = new_model(**inputs_dict_new, return_dict=False)[0]
|
||||
|
||||
assert_tensors_close(
|
||||
base_output, new_output, atol=atol, rtol=rtol, msg="Output should match after variant sharded save/load"
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def test_sharded_checkpoints_with_parallel_loading(self, tmp_path, atol=1e-5, rtol=0):
|
||||
from diffusers.utils import constants
|
||||
|
||||
torch.manual_seed(0)
|
||||
config = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**config).eval()
|
||||
model = model.to(torch_device)
|
||||
|
||||
base_output = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
model_size = compute_module_persistent_sizes(model)[""]
|
||||
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small
|
||||
|
||||
# Save original values to restore after test
|
||||
original_parallel_loading = constants.HF_ENABLE_PARALLEL_LOADING
|
||||
original_parallel_workers = getattr(constants, "HF_PARALLEL_WORKERS", None)
|
||||
|
||||
try:
|
||||
model.cpu().save_pretrained(tmp_path, max_shard_size=f"{max_shard_size}KB")
|
||||
assert os.path.exists(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME)), "Index file should exist"
|
||||
|
||||
# Check if the right number of shards exists
|
||||
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME))
|
||||
actual_num_shards = len([file for file in os.listdir(tmp_path) if file.endswith(".safetensors")])
|
||||
assert actual_num_shards == expected_num_shards, (
|
||||
f"Expected {expected_num_shards} shards, got {actual_num_shards}"
|
||||
)
|
||||
|
||||
# Load without parallel loading
|
||||
constants.HF_ENABLE_PARALLEL_LOADING = False
|
||||
model_sequential = self.model_class.from_pretrained(tmp_path).eval()
|
||||
model_sequential = model_sequential.to(torch_device)
|
||||
|
||||
# Load with parallel loading
|
||||
constants.HF_ENABLE_PARALLEL_LOADING = True
|
||||
constants.DEFAULT_HF_PARALLEL_LOADING_WORKERS = 2
|
||||
|
||||
torch.manual_seed(0)
|
||||
model_parallel = self.model_class.from_pretrained(tmp_path).eval()
|
||||
model_parallel = model_parallel.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
inputs_dict_parallel = self.get_dummy_inputs()
|
||||
output_parallel = model_parallel(**inputs_dict_parallel, return_dict=False)[0]
|
||||
|
||||
assert_tensors_close(
|
||||
base_output, output_parallel, atol=atol, rtol=rtol, msg="Output should match with parallel loading"
|
||||
)
|
||||
|
||||
finally:
|
||||
# Restore original values
|
||||
constants.HF_ENABLE_PARALLEL_LOADING = original_parallel_loading
|
||||
if original_parallel_workers is not None:
|
||||
constants.HF_PARALLEL_WORKERS = original_parallel_workers
|
||||
|
||||
@require_torch_multi_accelerator
|
||||
@torch.no_grad()
|
||||
def test_model_parallelism(self, tmp_path, atol=1e-5, rtol=0):
|
||||
if self.model_class._no_split_modules is None:
|
||||
pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")
|
||||
|
||||
config = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**config).eval()
|
||||
|
||||
model = model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
base_output = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
model_size = compute_module_sizes(model)[""]
|
||||
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents]
|
||||
|
||||
model.cpu().save_pretrained(tmp_path)
|
||||
|
||||
for max_size in max_gpu_sizes:
|
||||
max_memory = {0: max_size, 1: model_size * 2, "cpu": model_size * 2}
|
||||
new_model = self.model_class.from_pretrained(tmp_path, device_map="auto", max_memory=max_memory)
|
||||
# Making sure part of the model will be on GPU 0 and GPU 1
|
||||
assert set(new_model.hf_device_map.values()) == {0, 1}, "Model should be split across GPUs"
|
||||
|
||||
check_device_map_is_respected(new_model, new_model.hf_device_map)
|
||||
|
||||
torch.manual_seed(0)
|
||||
new_output = new_model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
assert_tensors_close(
|
||||
base_output, new_output, atol=atol, rtol=rtol, msg="Output should match with model parallelism"
|
||||
)
|
||||
@@ -1,166 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from ...testing_utils import (
|
||||
backend_empty_cache,
|
||||
is_torch_compile,
|
||||
require_accelerator,
|
||||
require_torch_version_greater,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
@is_torch_compile
|
||||
@require_accelerator
|
||||
@require_torch_version_greater("2.7.1")
|
||||
class TorchCompileTesterMixin:
|
||||
"""
|
||||
Mixin class for testing torch.compile functionality on models.
|
||||
|
||||
Expected from config mixin:
|
||||
- model_class: The model class to test
|
||||
|
||||
Optional properties:
|
||||
- different_shapes_for_compilation: List of (height, width) tuples for dynamic shape testing (default: None)
|
||||
|
||||
Expected methods from config mixin:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest mark: compile
|
||||
Use `pytest -m "not compile"` to skip these tests
|
||||
"""
|
||||
|
||||
@property
|
||||
def different_shapes_for_compilation(self) -> list[tuple[int, int]] | None:
|
||||
"""Optional list of (height, width) tuples for dynamic shape testing."""
|
||||
return None
|
||||
|
||||
def setup_method(self):
|
||||
torch.compiler.reset()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def teardown_method(self):
|
||||
torch.compiler.reset()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
@torch.no_grad()
|
||||
def test_torch_compile_recompilation_and_graph_break(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.eval()
|
||||
model = torch.compile(model, fullgraph=True)
|
||||
|
||||
with (
|
||||
torch._inductor.utils.fresh_inductor_cache(),
|
||||
torch._dynamo.config.patch(error_on_recompile=True),
|
||||
):
|
||||
_ = model(**inputs_dict)
|
||||
_ = model(**inputs_dict)
|
||||
|
||||
@torch.no_grad()
|
||||
def test_torch_compile_repeated_blocks(self):
|
||||
if self.model_class._repeated_blocks is None:
|
||||
pytest.skip("Skipping test as the model class doesn't have `_repeated_blocks` set.")
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.eval()
|
||||
model.compile_repeated_blocks(fullgraph=True)
|
||||
|
||||
recompile_limit = 1
|
||||
if self.model_class.__name__ == "UNet2DConditionModel":
|
||||
recompile_limit = 2
|
||||
|
||||
with (
|
||||
torch._inductor.utils.fresh_inductor_cache(),
|
||||
torch._dynamo.config.patch(recompile_limit=recompile_limit),
|
||||
):
|
||||
_ = model(**inputs_dict)
|
||||
_ = model(**inputs_dict)
|
||||
|
||||
@torch.no_grad()
|
||||
def test_compile_with_group_offloading(self):
|
||||
if not self.model_class._supports_group_offloading:
|
||||
pytest.skip("Model does not support group offloading.")
|
||||
|
||||
torch._dynamo.config.cache_size_limit = 10000
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict)
|
||||
model.eval()
|
||||
|
||||
group_offload_kwargs = {
|
||||
"onload_device": torch_device,
|
||||
"offload_device": "cpu",
|
||||
"offload_type": "block_level",
|
||||
"num_blocks_per_group": 1,
|
||||
"use_stream": True,
|
||||
"non_blocking": True,
|
||||
}
|
||||
model.enable_group_offload(**group_offload_kwargs)
|
||||
model.compile()
|
||||
|
||||
_ = model(**inputs_dict)
|
||||
_ = model(**inputs_dict)
|
||||
|
||||
@torch.no_grad()
|
||||
def test_compile_on_different_shapes(self):
|
||||
if self.different_shapes_for_compilation is None:
|
||||
pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
|
||||
torch.fx.experimental._config.use_duck_shape = False
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.eval()
|
||||
model = torch.compile(model, fullgraph=True, dynamic=True)
|
||||
|
||||
for height, width in self.different_shapes_for_compilation:
|
||||
with torch._dynamo.config.patch(error_on_recompile=True):
|
||||
inputs_dict = self.get_dummy_inputs(height=height, width=width)
|
||||
_ = model(**inputs_dict)
|
||||
|
||||
@torch.no_grad()
|
||||
def test_compile_works_with_aot(self, tmp_path):
|
||||
from torch._inductor.package import load_package
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
exported_model = torch.export.export(model, args=(), kwargs=inputs_dict)
|
||||
|
||||
package_path = os.path.join(str(tmp_path), f"{self.model_class.__name__}.pt2")
|
||||
_ = torch._inductor.aoti_compile_and_package(exported_model, package_path=package_path)
|
||||
assert os.path.exists(package_path), f"Package file not created at {package_path}"
|
||||
loaded_binary = load_package(package_path, run_single_threaded=True)
|
||||
|
||||
model.forward = loaded_binary
|
||||
|
||||
_ = model(**inputs_dict)
|
||||
_ = model(**inputs_dict)
|
||||
@@ -1,158 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from ...testing_utils import backend_empty_cache, is_ip_adapter, torch_device
|
||||
|
||||
|
||||
def check_if_ip_adapter_correctly_set(model, processor_cls) -> bool:
|
||||
"""
|
||||
Check if IP Adapter processors are correctly set in the model.
|
||||
|
||||
Args:
|
||||
model: The model to check
|
||||
|
||||
Returns:
|
||||
bool: True if IP Adapter is correctly set, False otherwise
|
||||
"""
|
||||
for module in model.attn_processors.values():
|
||||
if isinstance(module, processor_cls):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@is_ip_adapter
|
||||
class IPAdapterTesterMixin:
|
||||
"""
|
||||
Mixin class for testing IP Adapter functionality on models.
|
||||
|
||||
Expected from config mixin:
|
||||
- model_class: The model class to test
|
||||
|
||||
Required properties (must be implemented by subclasses):
|
||||
- ip_adapter_processor_cls: The IP Adapter processor class to use
|
||||
|
||||
Required methods (must be implemented by subclasses):
|
||||
- create_ip_adapter_state_dict(): Creates IP Adapter state dict for testing
|
||||
- modify_inputs_for_ip_adapter(): Modifies inputs to include IP Adapter data
|
||||
|
||||
Expected methods from config mixin:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest mark: ip_adapter
|
||||
Use `pytest -m "not ip_adapter"` to skip these tests
|
||||
"""
|
||||
|
||||
def setup_method(self):
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def teardown_method(self):
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
@property
|
||||
def ip_adapter_processor_cls(self):
|
||||
"""IP Adapter processor class to use for testing. Must be implemented by subclasses."""
|
||||
raise NotImplementedError("Subclasses must implement the `ip_adapter_processor_cls` property.")
|
||||
|
||||
def create_ip_adapter_state_dict(self, model):
|
||||
raise NotImplementedError("child class must implement method to create IPAdapter State Dict")
|
||||
|
||||
def modify_inputs_for_ip_adapter(self, model, inputs_dict):
|
||||
raise NotImplementedError("child class must implement method to create IPAdapter model inputs")
|
||||
|
||||
@torch.no_grad()
|
||||
def test_load_ip_adapter(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
output_no_adapter = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
ip_adapter_state_dict = self.create_ip_adapter_state_dict(model)
|
||||
|
||||
model._load_ip_adapter_weights([ip_adapter_state_dict])
|
||||
assert check_if_ip_adapter_correctly_set(model, self.ip_adapter_processor_cls), (
|
||||
"IP Adapter processors not set correctly"
|
||||
)
|
||||
|
||||
inputs_dict_with_adapter = self.modify_inputs_for_ip_adapter(model, inputs_dict.copy())
|
||||
outputs_with_adapter = model(**inputs_dict_with_adapter, return_dict=False)[0]
|
||||
|
||||
assert not torch.allclose(output_no_adapter, outputs_with_adapter, atol=1e-4, rtol=1e-4), (
|
||||
"Output should differ with IP Adapter enabled"
|
||||
)
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Setting IP Adapter scale is not defined at the model level. Enable this test after refactoring"
|
||||
)
|
||||
def test_ip_adapter_scale(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
ip_adapter_state_dict = self.create_ip_adapter_state_dict(model)
|
||||
model._load_ip_adapter_weights([ip_adapter_state_dict])
|
||||
|
||||
inputs_dict_with_adapter = self.modify_inputs_for_ip_adapter(model, inputs_dict.copy())
|
||||
|
||||
# Test scale = 0.0 (no effect)
|
||||
model.set_ip_adapter_scale(0.0)
|
||||
torch.manual_seed(0)
|
||||
output_scale_zero = model(**inputs_dict_with_adapter, return_dict=False)[0]
|
||||
|
||||
# Test scale = 1.0 (full effect)
|
||||
model.set_ip_adapter_scale(1.0)
|
||||
torch.manual_seed(0)
|
||||
output_scale_one = model(**inputs_dict_with_adapter, return_dict=False)[0]
|
||||
|
||||
# Outputs should differ with different scales
|
||||
assert not torch.allclose(output_scale_zero, output_scale_one, atol=1e-4, rtol=1e-4), (
|
||||
"Output should differ with different IP Adapter scales"
|
||||
)
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Unloading IP Adapter is not defined at the model level. Enable this test after refactoring"
|
||||
)
|
||||
def test_unload_ip_adapter(self):
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
# Save original processors
|
||||
original_processors = {k: type(v).__name__ for k, v in model.attn_processors.items()}
|
||||
|
||||
# Create and load IP adapter
|
||||
ip_adapter_state_dict = self.create_ip_adapter_state_dict(model)
|
||||
model._load_ip_adapter_weights([ip_adapter_state_dict])
|
||||
|
||||
assert check_if_ip_adapter_correctly_set(model, self.ip_adapter_processor_cls), "IP Adapter should be set"
|
||||
|
||||
# Unload IP adapter
|
||||
model.unload_ip_adapter()
|
||||
|
||||
assert not check_if_ip_adapter_correctly_set(model, self.ip_adapter_processor_cls), (
|
||||
"IP Adapter should be unloaded"
|
||||
)
|
||||
|
||||
# Verify processors are restored
|
||||
current_processors = {k: type(v).__name__ for k, v in model.attn_processors.items()}
|
||||
assert original_processors == current_processors, "Processors should be restored after unload"
|
||||
@@ -1,555 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
|
||||
import pytest
|
||||
import safetensors.torch
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from diffusers.utils.import_utils import is_peft_available
|
||||
from diffusers.utils.testing_utils import check_if_dicts_are_equal
|
||||
|
||||
from ...testing_utils import (
|
||||
assert_tensors_close,
|
||||
backend_empty_cache,
|
||||
is_lora,
|
||||
is_torch_compile,
|
||||
require_peft_backend,
|
||||
require_peft_version_greater,
|
||||
require_torch_accelerator,
|
||||
require_torch_version_greater,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
if is_peft_available():
|
||||
from diffusers.loaders.peft import PeftAdapterMixin
|
||||
|
||||
|
||||
def check_if_lora_correctly_set(model) -> bool:
|
||||
"""
|
||||
Check if LoRA layers are correctly set in the model.
|
||||
|
||||
Args:
|
||||
model: The model to check
|
||||
|
||||
Returns:
|
||||
bool: True if LoRA is correctly set, False otherwise
|
||||
"""
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
for module in model.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@is_lora
|
||||
@require_peft_backend
|
||||
class LoraTesterMixin:
|
||||
"""
|
||||
Mixin class for testing LoRA/PEFT functionality on models.
|
||||
|
||||
Expected from config mixin:
|
||||
- model_class: The model class to test
|
||||
|
||||
Expected methods from config mixin:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest mark: lora
|
||||
Use `pytest -m "not lora"` to skip these tests
|
||||
"""
|
||||
|
||||
def setup_method(self):
|
||||
if not issubclass(self.model_class, PeftAdapterMixin):
|
||||
pytest.skip(f"PEFT is not supported for this model ({self.model_class.__name__}).")
|
||||
|
||||
@torch.no_grad()
|
||||
def test_save_load_lora_adapter(self, tmp_path, rank=4, lora_alpha=4, use_dora=False, atol=1e-4, rtol=1e-4):
|
||||
from peft import LoraConfig
|
||||
from peft.utils import get_peft_model_state_dict
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
output_no_lora = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
denoiser_lora_config = LoraConfig(
|
||||
r=rank,
|
||||
lora_alpha=lora_alpha,
|
||||
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
|
||||
init_lora_weights=False,
|
||||
use_dora=use_dora,
|
||||
)
|
||||
model.add_adapter(denoiser_lora_config)
|
||||
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
|
||||
|
||||
torch.manual_seed(0)
|
||||
outputs_with_lora = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
assert not torch.allclose(output_no_lora, outputs_with_lora, atol=atol, rtol=rtol), (
|
||||
"Output should differ with LoRA enabled"
|
||||
)
|
||||
|
||||
model.save_lora_adapter(tmp_path)
|
||||
assert os.path.isfile(os.path.join(tmp_path, "pytorch_lora_weights.safetensors")), (
|
||||
"LoRA weights file not created"
|
||||
)
|
||||
|
||||
state_dict_loaded = safetensors.torch.load_file(os.path.join(tmp_path, "pytorch_lora_weights.safetensors"))
|
||||
|
||||
model.unload_lora()
|
||||
assert not check_if_lora_correctly_set(model), "LoRA should be unloaded"
|
||||
|
||||
model.load_lora_adapter(tmp_path, prefix=None, use_safetensors=True)
|
||||
state_dict_retrieved = get_peft_model_state_dict(model, adapter_name="default_0")
|
||||
|
||||
for k in state_dict_loaded:
|
||||
loaded_v = state_dict_loaded[k]
|
||||
retrieved_v = state_dict_retrieved[k].to(loaded_v.device)
|
||||
assert_tensors_close(loaded_v, retrieved_v, atol=atol, rtol=rtol, msg=f"Mismatch in LoRA weight {k}")
|
||||
|
||||
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly after reload"
|
||||
|
||||
torch.manual_seed(0)
|
||||
outputs_with_lora_2 = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
assert not torch.allclose(output_no_lora, outputs_with_lora_2, atol=atol, rtol=rtol), (
|
||||
"Output should differ with LoRA enabled"
|
||||
)
|
||||
assert_tensors_close(
|
||||
outputs_with_lora,
|
||||
outputs_with_lora_2,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
msg="Outputs should match before and after save/load",
|
||||
)
|
||||
|
||||
def test_lora_wrong_adapter_name_raises_error(self, tmp_path):
|
||||
from peft import LoraConfig
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
denoiser_lora_config = LoraConfig(
|
||||
r=4,
|
||||
lora_alpha=4,
|
||||
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
|
||||
init_lora_weights=False,
|
||||
use_dora=False,
|
||||
)
|
||||
model.add_adapter(denoiser_lora_config)
|
||||
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
|
||||
|
||||
wrong_name = "foo"
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
model.save_lora_adapter(tmp_path, adapter_name=wrong_name)
|
||||
|
||||
assert f"Adapter name {wrong_name} not found in the model." in str(exc_info.value)
|
||||
|
||||
def test_lora_adapter_metadata_is_loaded_correctly(self, tmp_path, rank=4, lora_alpha=4, use_dora=False):
|
||||
from peft import LoraConfig
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
denoiser_lora_config = LoraConfig(
|
||||
r=rank,
|
||||
lora_alpha=lora_alpha,
|
||||
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
|
||||
init_lora_weights=False,
|
||||
use_dora=use_dora,
|
||||
)
|
||||
model.add_adapter(denoiser_lora_config)
|
||||
metadata = model.peft_config["default"].to_dict()
|
||||
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
|
||||
|
||||
model.save_lora_adapter(tmp_path)
|
||||
model_file = os.path.join(tmp_path, "pytorch_lora_weights.safetensors")
|
||||
assert os.path.isfile(model_file), "LoRA weights file not created"
|
||||
|
||||
model.unload_lora()
|
||||
assert not check_if_lora_correctly_set(model), "LoRA should be unloaded"
|
||||
|
||||
model.load_lora_adapter(tmp_path, prefix=None, use_safetensors=True)
|
||||
parsed_metadata = model.peft_config["default_0"].to_dict()
|
||||
check_if_dicts_are_equal(metadata, parsed_metadata)
|
||||
|
||||
def test_lora_adapter_wrong_metadata_raises_error(self, tmp_path):
|
||||
from peft import LoraConfig
|
||||
|
||||
from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
denoiser_lora_config = LoraConfig(
|
||||
r=4,
|
||||
lora_alpha=4,
|
||||
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
|
||||
init_lora_weights=False,
|
||||
use_dora=False,
|
||||
)
|
||||
model.add_adapter(denoiser_lora_config)
|
||||
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
|
||||
|
||||
model.save_lora_adapter(tmp_path)
|
||||
model_file = os.path.join(tmp_path, "pytorch_lora_weights.safetensors")
|
||||
assert os.path.isfile(model_file), "LoRA weights file not created"
|
||||
|
||||
# Perturb the metadata in the state dict
|
||||
loaded_state_dict = safetensors.torch.load_file(model_file)
|
||||
metadata = {"format": "pt"}
|
||||
lora_adapter_metadata = denoiser_lora_config.to_dict()
|
||||
lora_adapter_metadata.update({"foo": 1, "bar": 2})
|
||||
for key, value in lora_adapter_metadata.items():
|
||||
if isinstance(value, set):
|
||||
lora_adapter_metadata[key] = list(value)
|
||||
metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
|
||||
safetensors.torch.save_file(loaded_state_dict, model_file, metadata=metadata)
|
||||
|
||||
model.unload_lora()
|
||||
assert not check_if_lora_correctly_set(model), "LoRA should be unloaded"
|
||||
|
||||
with pytest.raises(TypeError) as exc_info:
|
||||
model.load_lora_adapter(tmp_path, prefix=None, use_safetensors=True)
|
||||
assert "`LoraConfig` class could not be instantiated" in str(exc_info.value)
|
||||
|
||||
|
||||
@is_lora
|
||||
@is_torch_compile
|
||||
@require_peft_backend
|
||||
@require_peft_version_greater("0.14.0")
|
||||
@require_torch_version_greater("2.7.1")
|
||||
@require_torch_accelerator
|
||||
class LoraHotSwappingForModelTesterMixin:
|
||||
"""
|
||||
Mixin class for testing LoRA hot swapping functionality on models.
|
||||
|
||||
Test that hotswapping does not result in recompilation on the model directly.
|
||||
We're not extensively testing the hotswapping functionality since it is implemented in PEFT
|
||||
and is extensively tested there. The goal of this test is specifically to ensure that
|
||||
hotswapping with diffusers does not require recompilation.
|
||||
|
||||
See https://github.com/huggingface/peft/blob/eaab05e18d51fb4cce20a73c9acd82a00c013b83/tests/test_gpu_examples.py#L4252
|
||||
for the analogous PEFT test.
|
||||
|
||||
Expected from config mixin:
|
||||
- model_class: The model class to test
|
||||
|
||||
Optional properties:
|
||||
- different_shapes_for_compilation: List of (height, width) tuples for dynamic compilation tests (default: None)
|
||||
|
||||
Expected methods from config mixin:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest marks: lora, torch_compile
|
||||
Use `pytest -m "not lora"` or `pytest -m "not torch_compile"` to skip these tests
|
||||
"""
|
||||
|
||||
@property
|
||||
def different_shapes_for_compilation(self) -> list[tuple[int, int]] | None:
|
||||
"""Optional list of (height, width) tuples for dynamic compilation tests."""
|
||||
return None
|
||||
|
||||
def setup_method(self):
|
||||
if not issubclass(self.model_class, PeftAdapterMixin):
|
||||
pytest.skip(f"PEFT is not supported for this model ({self.model_class.__name__}).")
|
||||
|
||||
def teardown_method(self):
|
||||
# It is critical that the dynamo cache is reset for each test. Otherwise, if the test re-uses the same model,
|
||||
# there will be recompilation errors, as torch caches the model when run in the same process.
|
||||
torch.compiler.reset()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def _get_lora_config(self, lora_rank, lora_alpha, target_modules):
|
||||
from peft import LoraConfig
|
||||
|
||||
lora_config = LoraConfig(
|
||||
r=lora_rank,
|
||||
lora_alpha=lora_alpha,
|
||||
target_modules=target_modules,
|
||||
init_lora_weights=False,
|
||||
use_dora=False,
|
||||
)
|
||||
return lora_config
|
||||
|
||||
def _get_linear_module_name_other_than_attn(self, model):
|
||||
linear_names = [
|
||||
name for name, module in model.named_modules() if isinstance(module, nn.Linear) and "to_" not in name
|
||||
]
|
||||
return linear_names[0]
|
||||
|
||||
def _check_model_hotswap(
|
||||
self, tmp_path, do_compile, rank0, rank1, target_modules0, target_modules1=None, atol=5e-3, rtol=5e-3
|
||||
):
|
||||
"""
|
||||
Check that hotswapping works on a model.
|
||||
|
||||
Steps:
|
||||
- create 2 LoRA adapters and save them
|
||||
- load the first adapter
|
||||
- hotswap the second adapter
|
||||
- check that the outputs are correct
|
||||
- optionally compile the model
|
||||
- optionally check if recompilations happen on different shapes
|
||||
|
||||
Note: We set rank == alpha here because save_lora_adapter does not save the alpha scalings, thus the test would
|
||||
fail if the values are different. Since rank != alpha does not matter for the purpose of this test, this is
|
||||
fine.
|
||||
"""
|
||||
different_shapes = self.different_shapes_for_compilation
|
||||
# create 2 adapters with different ranks and alphas
|
||||
torch.manual_seed(0)
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
alpha0, alpha1 = rank0, rank1
|
||||
max_rank = max([rank0, rank1])
|
||||
if target_modules1 is None:
|
||||
target_modules1 = target_modules0[:]
|
||||
lora_config0 = self._get_lora_config(rank0, alpha0, target_modules0)
|
||||
lora_config1 = self._get_lora_config(rank1, alpha1, target_modules1)
|
||||
|
||||
model.add_adapter(lora_config0, adapter_name="adapter0")
|
||||
with torch.inference_mode():
|
||||
torch.manual_seed(0)
|
||||
output0_before = model(**inputs_dict)["sample"]
|
||||
|
||||
model.add_adapter(lora_config1, adapter_name="adapter1")
|
||||
model.set_adapter("adapter1")
|
||||
with torch.inference_mode():
|
||||
torch.manual_seed(0)
|
||||
output1_before = model(**inputs_dict)["sample"]
|
||||
|
||||
# sanity checks:
|
||||
assert not torch.allclose(output0_before, output1_before, atol=atol, rtol=rtol)
|
||||
assert not (output0_before == 0).all()
|
||||
assert not (output1_before == 0).all()
|
||||
|
||||
# save the adapter checkpoints
|
||||
model.save_lora_adapter(os.path.join(tmp_path, "0"), safe_serialization=True, adapter_name="adapter0")
|
||||
model.save_lora_adapter(os.path.join(tmp_path, "1"), safe_serialization=True, adapter_name="adapter1")
|
||||
del model
|
||||
|
||||
# load the first adapter
|
||||
torch.manual_seed(0)
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
if do_compile or (rank0 != rank1):
|
||||
# no need to prepare if the model is not compiled or if the ranks are identical
|
||||
model.enable_lora_hotswap(target_rank=max_rank)
|
||||
|
||||
file_name0 = os.path.join(os.path.join(tmp_path, "0"), "pytorch_lora_weights.safetensors")
|
||||
file_name1 = os.path.join(os.path.join(tmp_path, "1"), "pytorch_lora_weights.safetensors")
|
||||
model.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None)
|
||||
|
||||
if do_compile:
|
||||
model = torch.compile(model, mode="reduce-overhead", dynamic=different_shapes is not None)
|
||||
|
||||
with torch.inference_mode():
|
||||
# additionally check if dynamic compilation works.
|
||||
if different_shapes is not None:
|
||||
for height, width in different_shapes:
|
||||
new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
|
||||
_ = model(**new_inputs_dict)
|
||||
else:
|
||||
output0_after = model(**inputs_dict)["sample"]
|
||||
assert_tensors_close(
|
||||
output0_before, output0_after, atol=atol, rtol=rtol, msg="Output mismatch after loading adapter0"
|
||||
)
|
||||
|
||||
# hotswap the 2nd adapter
|
||||
model.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None)
|
||||
|
||||
# we need to call forward to potentially trigger recompilation
|
||||
with torch.inference_mode():
|
||||
if different_shapes is not None:
|
||||
for height, width in different_shapes:
|
||||
new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
|
||||
_ = model(**new_inputs_dict)
|
||||
else:
|
||||
output1_after = model(**inputs_dict)["sample"]
|
||||
assert_tensors_close(
|
||||
output1_before,
|
||||
output1_after,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
msg="Output mismatch after hotswapping to adapter1",
|
||||
)
|
||||
|
||||
# check error when not passing valid adapter name
|
||||
name = "does-not-exist"
|
||||
msg = f"Trying to hotswap LoRA adapter '{name}' but there is no existing adapter by that name"
|
||||
with pytest.raises(ValueError, match=re.escape(msg)):
|
||||
model.load_lora_adapter(file_name1, adapter_name=name, hotswap=True, prefix=None)
|
||||
|
||||
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
|
||||
def test_hotswapping_model(self, tmp_path, rank0, rank1):
|
||||
self._check_model_hotswap(
|
||||
tmp_path, do_compile=False, rank0=rank0, rank1=rank1, target_modules0=["to_q", "to_k", "to_v", "to_out.0"]
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
|
||||
def test_hotswapping_compiled_model_linear(self, tmp_path, rank0, rank1):
|
||||
# It's important to add this context to raise an error on recompilation
|
||||
target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
|
||||
with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
|
||||
self._check_model_hotswap(
|
||||
tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
|
||||
def test_hotswapping_compiled_model_conv2d(self, tmp_path, rank0, rank1):
|
||||
if "unet" not in self.model_class.__name__.lower():
|
||||
pytest.skip("Test only applies to UNet.")
|
||||
|
||||
# It's important to add this context to raise an error on recompilation
|
||||
target_modules = ["conv", "conv1", "conv2"]
|
||||
with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
|
||||
self._check_model_hotswap(
|
||||
tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
|
||||
def test_hotswapping_compiled_model_both_linear_and_conv2d(self, tmp_path, rank0, rank1):
|
||||
if "unet" not in self.model_class.__name__.lower():
|
||||
pytest.skip("Test only applies to UNet.")
|
||||
|
||||
# It's important to add this context to raise an error on recompilation
|
||||
target_modules = ["to_q", "conv"]
|
||||
with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
|
||||
self._check_model_hotswap(
|
||||
tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
|
||||
def test_hotswapping_compiled_model_both_linear_and_other(self, tmp_path, rank0, rank1):
|
||||
# In `test_hotswapping_compiled_model_both_linear_and_conv2d()`, we check if we can do hotswapping
|
||||
# with `torch.compile()` for models that have both linear and conv layers. In this test, we check
|
||||
# if we can target a linear layer from the transformer blocks and another linear layer from non-attention
|
||||
# block.
|
||||
target_modules = ["to_q"]
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
target_modules.append(self._get_linear_module_name_other_than_attn(model))
|
||||
del model
|
||||
|
||||
# It's important to add this context to raise an error on recompilation
|
||||
with torch._dynamo.config.patch(error_on_recompile=True):
|
||||
self._check_model_hotswap(
|
||||
tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules
|
||||
)
|
||||
|
||||
def test_enable_lora_hotswap_called_after_adapter_added_raises(self):
|
||||
# ensure that enable_lora_hotswap is called before loading the first adapter
|
||||
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.add_adapter(lora_config)
|
||||
|
||||
msg = re.escape("Call `enable_lora_hotswap` before loading the first adapter.")
|
||||
with pytest.raises(RuntimeError, match=msg):
|
||||
model.enable_lora_hotswap(target_rank=32)
|
||||
|
||||
def test_enable_lora_hotswap_called_after_adapter_added_warning(self, caplog):
|
||||
# ensure that enable_lora_hotswap is called before loading the first adapter
|
||||
import logging
|
||||
|
||||
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.add_adapter(lora_config)
|
||||
msg = (
|
||||
"It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
|
||||
)
|
||||
with caplog.at_level(logging.WARNING):
|
||||
model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
|
||||
assert any(msg in record.message for record in caplog.records)
|
||||
|
||||
def test_enable_lora_hotswap_called_after_adapter_added_ignore(self, caplog):
|
||||
# check possibility to ignore the error/warning
|
||||
import logging
|
||||
|
||||
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.add_adapter(lora_config)
|
||||
with caplog.at_level(logging.WARNING):
|
||||
model.enable_lora_hotswap(target_rank=32, check_compiled="ignore")
|
||||
assert len(caplog.records) == 0
|
||||
|
||||
def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self):
|
||||
# check that wrong argument value raises an error
|
||||
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.add_adapter(lora_config)
|
||||
msg = re.escape("check_compiles should be one of 'error', 'warn', or 'ignore', got 'wrong-argument' instead.")
|
||||
with pytest.raises(ValueError, match=msg):
|
||||
model.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument")
|
||||
|
||||
def test_hotswap_second_adapter_targets_more_layers_raises(self, tmp_path, caplog):
|
||||
# check the error and log
|
||||
import logging
|
||||
|
||||
# at the moment, PEFT requires the 2nd adapter to target the same or a subset of layers
|
||||
target_modules0 = ["to_q"]
|
||||
target_modules1 = ["to_q", "to_k"]
|
||||
with pytest.raises(RuntimeError): # peft raises RuntimeError
|
||||
with caplog.at_level(logging.ERROR):
|
||||
self._check_model_hotswap(
|
||||
tmp_path,
|
||||
do_compile=True,
|
||||
rank0=8,
|
||||
rank1=8,
|
||||
target_modules0=target_modules0,
|
||||
target_modules1=target_modules1,
|
||||
)
|
||||
assert any("Hotswapping adapter0 was unsuccessful" in record.message for record in caplog.records)
|
||||
|
||||
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
|
||||
@require_torch_version_greater("2.7.1")
|
||||
def test_hotswapping_compile_on_different_shapes(self, tmp_path, rank0, rank1):
|
||||
different_shapes_for_compilation = self.different_shapes_for_compilation
|
||||
if different_shapes_for_compilation is None:
|
||||
pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
|
||||
# Specifying `use_duck_shape=False` instructs the compiler if it should use the same symbolic
|
||||
# variable to represent input sizes that are the same. For more details,
|
||||
# check out this [comment](https://github.com/huggingface/diffusers/pull/11327#discussion_r2047659790).
|
||||
torch.fx.experimental._config.use_duck_shape = False
|
||||
|
||||
target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
|
||||
with torch._dynamo.config.patch(error_on_recompile=True):
|
||||
self._check_model_hotswap(
|
||||
tmp_path,
|
||||
do_compile=True,
|
||||
rank0=rank0,
|
||||
rank1=rank1,
|
||||
target_modules0=target_modules,
|
||||
)
|
||||
@@ -1,498 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import glob
|
||||
import inspect
|
||||
from functools import wraps
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from accelerate.utils.modeling import compute_module_sizes
|
||||
|
||||
from diffusers.utils.testing_utils import _check_safetensors_serialization
|
||||
from diffusers.utils.torch_utils import get_torch_cuda_device_capability
|
||||
|
||||
from ...testing_utils import (
|
||||
assert_tensors_close,
|
||||
backend_empty_cache,
|
||||
backend_max_memory_allocated,
|
||||
backend_reset_peak_memory_stats,
|
||||
backend_synchronize,
|
||||
is_cpu_offload,
|
||||
is_group_offload,
|
||||
is_memory,
|
||||
require_accelerator,
|
||||
torch_device,
|
||||
)
|
||||
from .common import cast_inputs_to_dtype, check_device_map_is_respected
|
||||
|
||||
|
||||
def require_offload_support(func):
|
||||
"""
|
||||
Decorator to skip tests if model doesn't support offloading (requires _no_split_modules).
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
if self.model_class._no_split_modules is None:
|
||||
pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def require_group_offload_support(func):
|
||||
"""
|
||||
Decorator to skip tests if model doesn't support group offloading.
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
if not self.model_class._supports_group_offloading:
|
||||
pytest.skip("Model does not support group offloading.")
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@is_cpu_offload
|
||||
class CPUOffloadTesterMixin:
|
||||
"""
|
||||
Mixin class for testing CPU offloading functionality.
|
||||
|
||||
Expected from config mixin:
|
||||
- model_class: The model class to test
|
||||
|
||||
Optional properties:
|
||||
- model_split_percents: List of percentages for splitting model across devices (default: [0.5, 0.7])
|
||||
|
||||
Expected methods from config mixin:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest mark: cpu_offload
|
||||
Use `pytest -m "not cpu_offload"` to skip these tests
|
||||
"""
|
||||
|
||||
@property
|
||||
def model_split_percents(self) -> list[float]:
|
||||
"""List of percentages for splitting model across devices during offloading tests."""
|
||||
return [0.5, 0.7]
|
||||
|
||||
@require_offload_support
|
||||
@torch.no_grad()
|
||||
def test_cpu_offload(self, tmp_path, atol=1e-5, rtol=0):
|
||||
config = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**config).eval()
|
||||
|
||||
model = model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
base_output = model(**inputs_dict)
|
||||
|
||||
model_size = compute_module_sizes(model)[""]
|
||||
# We test several splits of sizes to make sure it works
|
||||
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents]
|
||||
model.cpu().save_pretrained(str(tmp_path))
|
||||
|
||||
for max_size in max_gpu_sizes:
|
||||
max_memory = {0: max_size, "cpu": model_size * 2}
|
||||
new_model = self.model_class.from_pretrained(str(tmp_path), device_map="auto", max_memory=max_memory)
|
||||
# Making sure part of the model will actually end up offloaded
|
||||
assert set(new_model.hf_device_map.values()) == {0, "cpu"}, "Model should be split between GPU and CPU"
|
||||
|
||||
check_device_map_is_respected(new_model, new_model.hf_device_map)
|
||||
torch.manual_seed(0)
|
||||
new_output = new_model(**inputs_dict)
|
||||
|
||||
assert_tensors_close(
|
||||
base_output[0], new_output[0], atol=atol, rtol=rtol, msg="Output should match with CPU offloading"
|
||||
)
|
||||
|
||||
@require_offload_support
|
||||
@torch.no_grad()
|
||||
def test_disk_offload_without_safetensors(self, tmp_path, atol=1e-5, rtol=0):
|
||||
config = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**config).eval()
|
||||
|
||||
model = model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
base_output = model(**inputs_dict)
|
||||
|
||||
model_size = compute_module_sizes(model)[""]
|
||||
max_size = int(self.model_split_percents[0] * model_size)
|
||||
# Force disk offload by setting very small CPU memory
|
||||
max_memory = {0: max_size, "cpu": int(0.1 * max_size)}
|
||||
|
||||
model.cpu().save_pretrained(str(tmp_path), safe_serialization=False)
|
||||
# This errors out because it's missing an offload folder
|
||||
with pytest.raises(ValueError):
|
||||
new_model = self.model_class.from_pretrained(str(tmp_path), device_map="auto", max_memory=max_memory)
|
||||
|
||||
new_model = self.model_class.from_pretrained(
|
||||
str(tmp_path), device_map="auto", max_memory=max_memory, offload_folder=str(tmp_path)
|
||||
)
|
||||
|
||||
check_device_map_is_respected(new_model, new_model.hf_device_map)
|
||||
torch.manual_seed(0)
|
||||
new_output = new_model(**inputs_dict)
|
||||
|
||||
assert_tensors_close(
|
||||
base_output[0], new_output[0], atol=atol, rtol=rtol, msg="Output should match with disk offloading"
|
||||
)
|
||||
|
||||
@require_offload_support
|
||||
@torch.no_grad()
|
||||
def test_disk_offload_with_safetensors(self, tmp_path, atol=1e-5, rtol=0):
|
||||
config = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**config).eval()
|
||||
|
||||
model = model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
base_output = model(**inputs_dict)
|
||||
|
||||
model_size = compute_module_sizes(model)[""]
|
||||
model.cpu().save_pretrained(str(tmp_path))
|
||||
|
||||
max_size = int(self.model_split_percents[0] * model_size)
|
||||
max_memory = {0: max_size, "cpu": max_size}
|
||||
new_model = self.model_class.from_pretrained(
|
||||
str(tmp_path), device_map="auto", offload_folder=str(tmp_path), max_memory=max_memory
|
||||
)
|
||||
|
||||
check_device_map_is_respected(new_model, new_model.hf_device_map)
|
||||
torch.manual_seed(0)
|
||||
new_output = new_model(**inputs_dict)
|
||||
|
||||
assert_tensors_close(
|
||||
base_output[0],
|
||||
new_output[0],
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
msg="Output should match with disk offloading (safetensors)",
|
||||
)
|
||||
|
||||
|
||||
@is_group_offload
|
||||
class GroupOffloadTesterMixin:
|
||||
"""
|
||||
Mixin class for testing group offloading functionality.
|
||||
|
||||
Expected from config mixin:
|
||||
- model_class: The model class to test
|
||||
|
||||
Expected methods from config mixin:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest mark: group_offload
|
||||
Use `pytest -m "not group_offload"` to skip these tests
|
||||
"""
|
||||
|
||||
@require_group_offload_support
|
||||
@pytest.mark.parametrize("record_stream", [False, True])
|
||||
def test_group_offloading(self, record_stream, atol=1e-5, rtol=0):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
torch.manual_seed(0)
|
||||
|
||||
@torch.no_grad()
|
||||
def run_forward(model):
|
||||
assert all(
|
||||
module._diffusers_hook.get_hook("group_offloading") is not None
|
||||
for module in model.modules()
|
||||
if hasattr(module, "_diffusers_hook")
|
||||
), "Group offloading hook should be set"
|
||||
model.eval()
|
||||
return model(**inputs_dict)[0]
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
model.to(torch_device)
|
||||
output_without_group_offloading = run_forward(model)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1)
|
||||
output_with_group_offloading1 = run_forward(model)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, non_blocking=True)
|
||||
output_with_group_offloading2 = run_forward(model)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
model.enable_group_offload(torch_device, offload_type="leaf_level")
|
||||
output_with_group_offloading3 = run_forward(model)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
model.enable_group_offload(
|
||||
torch_device, offload_type="leaf_level", use_stream=True, record_stream=record_stream
|
||||
)
|
||||
output_with_group_offloading4 = run_forward(model)
|
||||
|
||||
assert_tensors_close(
|
||||
output_without_group_offloading,
|
||||
output_with_group_offloading1,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
msg="Output should match with block-level offloading",
|
||||
)
|
||||
assert_tensors_close(
|
||||
output_without_group_offloading,
|
||||
output_with_group_offloading2,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
msg="Output should match with non-blocking block-level offloading",
|
||||
)
|
||||
assert_tensors_close(
|
||||
output_without_group_offloading,
|
||||
output_with_group_offloading3,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
msg="Output should match with leaf-level offloading",
|
||||
)
|
||||
assert_tensors_close(
|
||||
output_without_group_offloading,
|
||||
output_with_group_offloading4,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
msg="Output should match with leaf-level offloading with stream",
|
||||
)
|
||||
|
||||
@require_group_offload_support
|
||||
@pytest.mark.parametrize("record_stream", [False, True])
|
||||
@pytest.mark.parametrize("offload_type", ["block_level", "leaf_level"])
|
||||
@torch.no_grad()
|
||||
def test_group_offloading_with_layerwise_casting(self, record_stream, offload_type):
|
||||
torch.manual_seed(0)
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
_ = model(**inputs_dict)[0]
|
||||
|
||||
torch.manual_seed(0)
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
storage_dtype, compute_dtype = torch.float16, torch.float32
|
||||
inputs_dict = cast_inputs_to_dtype(inputs_dict, torch.float32, compute_dtype)
|
||||
model = self.model_class(**init_dict)
|
||||
model.eval()
|
||||
additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": 1}
|
||||
model.enable_group_offload(
|
||||
torch_device, offload_type=offload_type, use_stream=True, record_stream=record_stream, **additional_kwargs
|
||||
)
|
||||
model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
|
||||
_ = model(**inputs_dict)[0]
|
||||
|
||||
@require_group_offload_support
|
||||
@pytest.mark.parametrize("record_stream", [False, True])
|
||||
@pytest.mark.parametrize("offload_type", ["block_level", "leaf_level"])
|
||||
@torch.no_grad()
|
||||
@torch.inference_mode()
|
||||
def test_group_offloading_with_disk(self, tmp_path, record_stream, offload_type, atol=1e-5, rtol=0):
|
||||
def _has_generator_arg(model):
|
||||
sig = inspect.signature(model.forward)
|
||||
params = sig.parameters
|
||||
return "generator" in params
|
||||
|
||||
def _run_forward(model, inputs_dict):
|
||||
accepts_generator = _has_generator_arg(model)
|
||||
if accepts_generator:
|
||||
inputs_dict["generator"] = torch.manual_seed(0)
|
||||
torch.manual_seed(0)
|
||||
return model(**inputs_dict)[0]
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
model.eval()
|
||||
model.to(torch_device)
|
||||
output_without_group_offloading = _run_forward(model, inputs_dict)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
model.eval()
|
||||
|
||||
num_blocks_per_group = None if offload_type == "leaf_level" else 1
|
||||
additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": num_blocks_per_group}
|
||||
tmpdir = str(tmp_path)
|
||||
model.enable_group_offload(
|
||||
torch_device,
|
||||
offload_type=offload_type,
|
||||
offload_to_disk_path=tmpdir,
|
||||
use_stream=True,
|
||||
record_stream=record_stream,
|
||||
**additional_kwargs,
|
||||
)
|
||||
has_safetensors = glob.glob(f"{tmpdir}/*.safetensors")
|
||||
assert has_safetensors, "No safetensors found in the directory."
|
||||
|
||||
# For "leaf-level", there is a prefetching hook which makes this check a bit non-deterministic
|
||||
# in nature. So, skip it.
|
||||
if offload_type != "leaf_level":
|
||||
is_correct, extra_files, missing_files = _check_safetensors_serialization(
|
||||
module=model,
|
||||
offload_to_disk_path=tmpdir,
|
||||
offload_type=offload_type,
|
||||
num_blocks_per_group=num_blocks_per_group,
|
||||
)
|
||||
if not is_correct:
|
||||
if extra_files:
|
||||
raise ValueError(f"Found extra files: {', '.join(extra_files)}")
|
||||
elif missing_files:
|
||||
raise ValueError(f"Following files are missing: {', '.join(missing_files)}")
|
||||
|
||||
output_with_group_offloading = _run_forward(model, inputs_dict)
|
||||
assert_tensors_close(
|
||||
output_without_group_offloading,
|
||||
output_with_group_offloading,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
msg="Output should match with disk-based group offloading",
|
||||
)
|
||||
|
||||
|
||||
class LayerwiseCastingTesterMixin:
|
||||
"""
|
||||
Mixin class for testing layerwise dtype casting for memory optimization.
|
||||
|
||||
Expected from config mixin:
|
||||
- model_class: The model class to test
|
||||
|
||||
Expected methods from config mixin:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
"""
|
||||
|
||||
@torch.no_grad()
|
||||
def test_layerwise_casting_memory(self):
|
||||
MB_TOLERANCE = 0.2
|
||||
LEAST_COMPUTE_CAPABILITY = 8.0
|
||||
|
||||
def reset_memory_stats():
|
||||
gc.collect()
|
||||
backend_synchronize(torch_device)
|
||||
backend_empty_cache(torch_device)
|
||||
backend_reset_peak_memory_stats(torch_device)
|
||||
|
||||
def get_memory_usage(storage_dtype, compute_dtype):
|
||||
torch.manual_seed(0)
|
||||
config = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
inputs_dict = cast_inputs_to_dtype(inputs_dict, torch.float32, compute_dtype)
|
||||
model = self.model_class(**config).eval()
|
||||
model = model.to(torch_device, dtype=compute_dtype)
|
||||
model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
|
||||
|
||||
reset_memory_stats()
|
||||
model(**inputs_dict)
|
||||
model_memory_footprint = model.get_memory_footprint()
|
||||
peak_inference_memory_allocated_mb = backend_max_memory_allocated(torch_device) / 1024**2
|
||||
|
||||
return model_memory_footprint, peak_inference_memory_allocated_mb
|
||||
|
||||
fp32_memory_footprint, fp32_max_memory = get_memory_usage(torch.float32, torch.float32)
|
||||
fp8_e4m3_fp32_memory_footprint, fp8_e4m3_fp32_max_memory = get_memory_usage(torch.float8_e4m3fn, torch.float32)
|
||||
fp8_e4m3_bf16_memory_footprint, fp8_e4m3_bf16_max_memory = get_memory_usage(
|
||||
torch.float8_e4m3fn, torch.bfloat16
|
||||
)
|
||||
|
||||
compute_capability = get_torch_cuda_device_capability() if torch_device == "cuda" else None
|
||||
assert fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint < fp32_memory_footprint, (
|
||||
"Memory footprint should decrease with lower precision storage"
|
||||
)
|
||||
|
||||
# NOTE: the following assertion would fail on our CI (running Tesla T4) due to bf16 using more memory than fp32.
|
||||
# On other devices, such as DGX (Ampere) and Audace (Ada), the test passes. So, we conditionally check it.
|
||||
if compute_capability and compute_capability >= LEAST_COMPUTE_CAPABILITY:
|
||||
assert fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory, (
|
||||
"Peak memory should be lower with bf16 compute on newer GPUs"
|
||||
)
|
||||
|
||||
# On this dummy test case with a small model, sometimes fp8_e4m3_fp32 max memory usage is higher than fp32 by a few
|
||||
# bytes. This only happens for some models, so we allow a small tolerance.
|
||||
# For any real model being tested, the order would be fp8_e4m3_bf16 < fp8_e4m3_fp32 < fp32.
|
||||
assert (
|
||||
fp8_e4m3_fp32_max_memory < fp32_max_memory
|
||||
or abs(fp8_e4m3_fp32_max_memory - fp32_max_memory) < MB_TOLERANCE
|
||||
), "Peak memory should be lower or within tolerance with fp8 storage"
|
||||
|
||||
def test_layerwise_casting_training(self):
|
||||
def test_fn(storage_dtype, compute_dtype):
|
||||
if torch.device(torch_device).type == "cpu" and compute_dtype == torch.bfloat16:
|
||||
pytest.skip("Skipping test because CPU doesn't go well with bfloat16.")
|
||||
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model = model.to(torch_device, dtype=compute_dtype)
|
||||
model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
|
||||
model.train()
|
||||
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
inputs_dict = cast_inputs_to_dtype(inputs_dict, torch.float32, compute_dtype)
|
||||
with torch.amp.autocast(device_type=torch.device(torch_device).type):
|
||||
output = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
input_tensor = inputs_dict[self.main_input_name]
|
||||
noise = torch.randn((input_tensor.shape[0],) + self.output_shape).to(torch_device)
|
||||
noise = cast_inputs_to_dtype(noise, torch.float32, compute_dtype)
|
||||
loss = torch.nn.functional.mse_loss(output, noise)
|
||||
|
||||
loss.backward()
|
||||
|
||||
test_fn(torch.float16, torch.float32)
|
||||
test_fn(torch.float8_e4m3fn, torch.float32)
|
||||
test_fn(torch.float8_e5m2, torch.float32)
|
||||
test_fn(torch.float8_e4m3fn, torch.bfloat16)
|
||||
|
||||
|
||||
@is_memory
|
||||
@require_accelerator
|
||||
class MemoryTesterMixin(CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin):
|
||||
"""
|
||||
Combined mixin class for all memory optimization tests including CPU/disk offloading,
|
||||
group offloading, and layerwise dtype casting.
|
||||
|
||||
This mixin inherits from:
|
||||
- CPUOffloadTesterMixin: CPU and disk offloading tests
|
||||
- GroupOffloadTesterMixin: Group offloading tests (block-level and leaf-level)
|
||||
- LayerwiseCastingTesterMixin: Layerwise dtype casting tests
|
||||
|
||||
Expected from config mixin:
|
||||
- model_class: The model class to test
|
||||
|
||||
Optional properties:
|
||||
- model_split_percents: List of percentages for splitting model across devices (default: [0.5, 0.7])
|
||||
|
||||
Expected methods from config mixin:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest mark: memory
|
||||
Use `pytest -m "not memory"` to skip these tests
|
||||
"""
|
||||
@@ -1,128 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import socket
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from diffusers.models._modeling_parallel import ContextParallelConfig
|
||||
|
||||
from ...testing_utils import (
|
||||
is_context_parallel,
|
||||
require_torch_multi_accelerator,
|
||||
)
|
||||
|
||||
|
||||
def _find_free_port():
|
||||
"""Find a free port on localhost."""
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(("", 0))
|
||||
s.listen(1)
|
||||
port = s.getsockname()[1]
|
||||
return port
|
||||
|
||||
|
||||
def _context_parallel_worker(rank, world_size, master_port, model_class, init_dict, cp_dict, inputs_dict, return_dict):
|
||||
"""Worker function for context parallel testing."""
|
||||
try:
|
||||
# Set up distributed environment
|
||||
os.environ["MASTER_ADDR"] = "localhost"
|
||||
os.environ["MASTER_PORT"] = str(master_port)
|
||||
os.environ["RANK"] = str(rank)
|
||||
os.environ["WORLD_SIZE"] = str(world_size)
|
||||
|
||||
# Initialize process group
|
||||
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
|
||||
|
||||
# Set device for this process
|
||||
torch.cuda.set_device(rank)
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
|
||||
# Create model
|
||||
model = model_class(**init_dict)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
# Move inputs to device
|
||||
inputs_on_device = {}
|
||||
for key, value in inputs_dict.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
inputs_on_device[key] = value.to(device)
|
||||
else:
|
||||
inputs_on_device[key] = value
|
||||
|
||||
# Enable context parallelism
|
||||
cp_config = ContextParallelConfig(**cp_dict)
|
||||
model.enable_parallelism(config=cp_config)
|
||||
|
||||
# Run forward pass
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_on_device, return_dict=False)[0]
|
||||
|
||||
# Only rank 0 reports results
|
||||
if rank == 0:
|
||||
return_dict["status"] = "success"
|
||||
return_dict["output_shape"] = list(output.shape)
|
||||
|
||||
except Exception as e:
|
||||
if rank == 0:
|
||||
return_dict["status"] = "error"
|
||||
return_dict["error"] = str(e)
|
||||
finally:
|
||||
if dist.is_initialized():
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
@is_context_parallel
|
||||
@require_torch_multi_accelerator
|
||||
class ContextParallelTesterMixin:
|
||||
@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
|
||||
def test_context_parallel_inference(self, cp_type):
|
||||
if not torch.distributed.is_available():
|
||||
pytest.skip("torch.distributed is not available.")
|
||||
|
||||
if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
|
||||
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
|
||||
|
||||
world_size = 2
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
# Move all tensors to CPU for multiprocessing
|
||||
inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
|
||||
cp_dict = {cp_type: world_size}
|
||||
|
||||
# Find a free port for distributed communication
|
||||
master_port = _find_free_port()
|
||||
|
||||
# Use multiprocessing manager for cross-process communication
|
||||
manager = mp.Manager()
|
||||
return_dict = manager.dict()
|
||||
|
||||
# Spawn worker processes
|
||||
mp.spawn(
|
||||
_context_parallel_worker,
|
||||
args=(world_size, master_port, self.model_class, init_dict, cp_dict, inputs_dict, return_dict),
|
||||
nprocs=world_size,
|
||||
join=True,
|
||||
)
|
||||
|
||||
assert return_dict.get("status") == "success", (
|
||||
f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
|
||||
)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,272 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download, snapshot_download
|
||||
|
||||
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
|
||||
|
||||
from ...testing_utils import (
|
||||
backend_empty_cache,
|
||||
is_single_file,
|
||||
nightly,
|
||||
require_torch_accelerator,
|
||||
torch_device,
|
||||
)
|
||||
from .common import check_device_map_is_respected
|
||||
|
||||
|
||||
def download_single_file_checkpoint(pretrained_model_name_or_path, filename, tmpdir):
|
||||
"""Download a single file checkpoint from the Hub to a temporary directory."""
|
||||
path = hf_hub_download(pretrained_model_name_or_path, filename=filename, local_dir=tmpdir)
|
||||
return path
|
||||
|
||||
|
||||
def download_diffusers_config(pretrained_model_name_or_path, tmpdir):
|
||||
"""Download diffusers config files (excluding weights) from a repository."""
|
||||
path = snapshot_download(
|
||||
pretrained_model_name_or_path,
|
||||
ignore_patterns=[
|
||||
"**/*.ckpt",
|
||||
"*.ckpt",
|
||||
"**/*.bin",
|
||||
"*.bin",
|
||||
"**/*.pt",
|
||||
"*.pt",
|
||||
"**/*.safetensors",
|
||||
"*.safetensors",
|
||||
],
|
||||
allow_patterns=["**/*.json", "*.json", "*.txt", "**/*.txt"],
|
||||
local_dir=tmpdir,
|
||||
)
|
||||
return path
|
||||
|
||||
|
||||
@nightly
|
||||
@require_torch_accelerator
|
||||
@is_single_file
|
||||
class SingleFileTesterMixin:
|
||||
"""
|
||||
Mixin class for testing single file loading for models.
|
||||
|
||||
Required properties (must be implemented by subclasses):
|
||||
- ckpt_path: Path or Hub path to the single file checkpoint
|
||||
|
||||
Optional properties:
|
||||
- torch_dtype: torch dtype to use for testing (default: None)
|
||||
- alternate_ckpt_paths: List of alternate checkpoint paths for variant testing (default: None)
|
||||
|
||||
Expected from config mixin:
|
||||
- model_class: The model class to test
|
||||
- pretrained_model_name_or_path: Hub repository ID for the pretrained model
|
||||
- pretrained_model_kwargs: Additional kwargs for from_pretrained (e.g., subfolder)
|
||||
|
||||
Pytest mark: single_file
|
||||
Use `pytest -m "not single_file"` to skip these tests
|
||||
"""
|
||||
|
||||
# ==================== Required Properties ====================
|
||||
|
||||
@property
|
||||
def ckpt_path(self) -> str:
|
||||
"""Path or Hub path to the single file checkpoint. Must be implemented by subclasses."""
|
||||
raise NotImplementedError("Subclasses must implement the `ckpt_path` property.")
|
||||
|
||||
# ==================== Optional Properties ====================
|
||||
|
||||
@property
|
||||
def torch_dtype(self) -> torch.dtype | None:
|
||||
"""torch dtype to use for single file testing."""
|
||||
return None
|
||||
|
||||
@property
|
||||
def alternate_ckpt_paths(self) -> list[str] | None:
|
||||
"""List of alternate checkpoint paths for variant testing."""
|
||||
return None
|
||||
|
||||
def setup_method(self):
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def teardown_method(self):
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_single_file_model_config(self):
|
||||
pretrained_kwargs = {"device": torch_device, **self.pretrained_model_kwargs}
|
||||
single_file_kwargs = {"device": torch_device}
|
||||
|
||||
if self.torch_dtype:
|
||||
pretrained_kwargs["torch_dtype"] = self.torch_dtype
|
||||
single_file_kwargs["torch_dtype"] = self.torch_dtype
|
||||
|
||||
model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, **pretrained_kwargs)
|
||||
model_single_file = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs)
|
||||
|
||||
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
|
||||
for param_name, param_value in model_single_file.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert model.config[param_name] == param_value, (
|
||||
f"{param_name} differs between pretrained loading and single file loading: "
|
||||
f"pretrained={model.config[param_name]}, single_file={param_value}"
|
||||
)
|
||||
|
||||
def test_single_file_model_parameters(self):
|
||||
pretrained_kwargs = {"device_map": str(torch_device), **self.pretrained_model_kwargs}
|
||||
single_file_kwargs = {"device": torch_device}
|
||||
|
||||
if self.torch_dtype:
|
||||
pretrained_kwargs["torch_dtype"] = self.torch_dtype
|
||||
single_file_kwargs["torch_dtype"] = self.torch_dtype
|
||||
|
||||
# Load pretrained model, get state dict on CPU, then free GPU memory
|
||||
model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, **pretrained_kwargs)
|
||||
state_dict = {k: v.cpu() for k, v in model.state_dict().items()}
|
||||
del model
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
# Load single file model, get state dict on CPU
|
||||
model_single_file = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs)
|
||||
state_dict_single_file = {k: v.cpu() for k, v in model_single_file.state_dict().items()}
|
||||
del model_single_file
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
assert set(state_dict.keys()) == set(state_dict_single_file.keys()), (
|
||||
"Model parameters keys differ between pretrained and single file loading. "
|
||||
f"Missing in single file: {set(state_dict.keys()) - set(state_dict_single_file.keys())}. "
|
||||
f"Extra in single file: {set(state_dict_single_file.keys()) - set(state_dict.keys())}"
|
||||
)
|
||||
|
||||
for key in state_dict.keys():
|
||||
param = state_dict[key]
|
||||
param_single_file = state_dict_single_file[key]
|
||||
|
||||
assert param.shape == param_single_file.shape, (
|
||||
f"Parameter shape mismatch for {key}: "
|
||||
f"pretrained {param.shape} vs single file {param_single_file.shape}"
|
||||
)
|
||||
|
||||
assert torch.equal(param, param_single_file), f"Parameter values differ for {key}"
|
||||
|
||||
def test_single_file_loading_local_files_only(self, tmp_path):
|
||||
single_file_kwargs = {}
|
||||
|
||||
if self.torch_dtype:
|
||||
single_file_kwargs["torch_dtype"] = self.torch_dtype
|
||||
|
||||
pretrained_model_name_or_path, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
|
||||
local_ckpt_path = download_single_file_checkpoint(pretrained_model_name_or_path, weight_name, str(tmp_path))
|
||||
|
||||
model_single_file = self.model_class.from_single_file(
|
||||
local_ckpt_path, local_files_only=True, **single_file_kwargs
|
||||
)
|
||||
|
||||
assert model_single_file is not None, "Failed to load model with local_files_only=True"
|
||||
|
||||
def test_single_file_loading_with_diffusers_config(self):
|
||||
single_file_kwargs = {}
|
||||
|
||||
if self.torch_dtype:
|
||||
single_file_kwargs["torch_dtype"] = self.torch_dtype
|
||||
single_file_kwargs.update(self.pretrained_model_kwargs)
|
||||
|
||||
# Load with config parameter
|
||||
model_single_file = self.model_class.from_single_file(
|
||||
self.ckpt_path, config=self.pretrained_model_name_or_path, **single_file_kwargs
|
||||
)
|
||||
|
||||
# Load pretrained for comparison
|
||||
pretrained_kwargs = {**self.pretrained_model_kwargs}
|
||||
if self.torch_dtype:
|
||||
pretrained_kwargs["torch_dtype"] = self.torch_dtype
|
||||
|
||||
model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, **pretrained_kwargs)
|
||||
|
||||
# Compare configs
|
||||
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
|
||||
for param_name, param_value in model_single_file.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert model.config[param_name] == param_value, (
|
||||
f"{param_name} differs: pretrained={model.config[param_name]}, single_file={param_value}"
|
||||
)
|
||||
|
||||
def test_single_file_loading_with_diffusers_config_local_files_only(self, tmp_path):
|
||||
single_file_kwargs = {}
|
||||
|
||||
if self.torch_dtype:
|
||||
single_file_kwargs["torch_dtype"] = self.torch_dtype
|
||||
single_file_kwargs.update(self.pretrained_model_kwargs)
|
||||
|
||||
pretrained_model_name_or_path, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
|
||||
local_ckpt_path = download_single_file_checkpoint(pretrained_model_name_or_path, weight_name, str(tmp_path))
|
||||
local_diffusers_config = download_diffusers_config(self.pretrained_model_name_or_path, str(tmp_path))
|
||||
|
||||
model_single_file = self.model_class.from_single_file(
|
||||
local_ckpt_path, config=local_diffusers_config, local_files_only=True, **single_file_kwargs
|
||||
)
|
||||
|
||||
assert model_single_file is not None, "Failed to load model with config and local_files_only=True"
|
||||
|
||||
def test_single_file_loading_dtype(self):
|
||||
for dtype in [torch.float32, torch.float16]:
|
||||
if torch_device == "mps" and dtype == torch.bfloat16:
|
||||
continue
|
||||
|
||||
model_single_file = self.model_class.from_single_file(self.ckpt_path, torch_dtype=dtype)
|
||||
|
||||
assert model_single_file.dtype == dtype, f"Expected dtype {dtype}, got {model_single_file.dtype}"
|
||||
|
||||
# Cleanup
|
||||
del model_single_file
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_checkpoint_variant_loading(self):
|
||||
if not self.alternate_ckpt_paths:
|
||||
return
|
||||
|
||||
for ckpt_path in self.alternate_ckpt_paths:
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
single_file_kwargs = {}
|
||||
if self.torch_dtype:
|
||||
single_file_kwargs["torch_dtype"] = self.torch_dtype
|
||||
|
||||
model = self.model_class.from_single_file(ckpt_path, **single_file_kwargs)
|
||||
|
||||
assert model is not None, f"Failed to load checkpoint from {ckpt_path}"
|
||||
|
||||
del model
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_single_file_loading_with_device_map(self):
|
||||
single_file_kwargs = {"device_map": torch_device}
|
||||
|
||||
if self.torch_dtype:
|
||||
single_file_kwargs["torch_dtype"] = self.torch_dtype
|
||||
|
||||
model = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs)
|
||||
|
||||
assert model is not None, "Failed to load model with device_map"
|
||||
assert hasattr(model, "hf_device_map"), "Model should have hf_device_map attribute when loaded with device_map"
|
||||
assert model.hf_device_map is not None, "hf_device_map should not be None when loaded with device_map"
|
||||
check_device_map_is_respected(model, model.hf_device_map)
|
||||
@@ -1,220 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import gc
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from diffusers.training_utils import EMAModel
|
||||
|
||||
from ...testing_utils import (
|
||||
backend_empty_cache,
|
||||
is_training,
|
||||
require_torch_accelerator_with_training,
|
||||
torch_all_close,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
@is_training
|
||||
@require_torch_accelerator_with_training
|
||||
class TrainingTesterMixin:
|
||||
"""
|
||||
Mixin class for testing training functionality on models.
|
||||
|
||||
Expected from config mixin:
|
||||
- model_class: The model class to test
|
||||
- output_shape: Tuple defining the expected output shape
|
||||
|
||||
Expected methods from config mixin:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest mark: training
|
||||
Use `pytest -m "not training"` to skip these tests
|
||||
"""
|
||||
|
||||
def setup_method(self):
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def teardown_method(self):
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_training(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.train()
|
||||
output = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
|
||||
loss = torch.nn.functional.mse_loss(output, noise)
|
||||
loss.backward()
|
||||
|
||||
def test_training_with_ema(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.train()
|
||||
ema_model = EMAModel(model.parameters())
|
||||
|
||||
output = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
|
||||
loss = torch.nn.functional.mse_loss(output, noise)
|
||||
loss.backward()
|
||||
ema_model.step(model.parameters())
|
||||
|
||||
def test_gradient_checkpointing(self):
|
||||
if not self.model_class._supports_gradient_checkpointing:
|
||||
pytest.skip("Gradient checkpointing is not supported.")
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
|
||||
# at init model should have gradient checkpointing disabled
|
||||
model = self.model_class(**init_dict)
|
||||
assert not model.is_gradient_checkpointing, "Gradient checkpointing should be disabled at init"
|
||||
|
||||
# check enable works
|
||||
model.enable_gradient_checkpointing()
|
||||
assert model.is_gradient_checkpointing, "Gradient checkpointing should be enabled"
|
||||
|
||||
# check disable works
|
||||
model.disable_gradient_checkpointing()
|
||||
assert not model.is_gradient_checkpointing, "Gradient checkpointing should be disabled"
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self, expected_set=None):
|
||||
if not self.model_class._supports_gradient_checkpointing:
|
||||
pytest.skip("Gradient checkpointing is not supported.")
|
||||
|
||||
if expected_set is None:
|
||||
pytest.skip("expected_set must be provided to verify gradient checkpointing is applied.")
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
|
||||
model_class_copy = copy.copy(self.model_class)
|
||||
model = model_class_copy(**init_dict)
|
||||
model.enable_gradient_checkpointing()
|
||||
|
||||
modules_with_gc_enabled = {}
|
||||
for submodule in model.modules():
|
||||
if hasattr(submodule, "gradient_checkpointing"):
|
||||
assert submodule.gradient_checkpointing, f"{submodule.__class__.__name__} should have GC enabled"
|
||||
modules_with_gc_enabled[submodule.__class__.__name__] = True
|
||||
|
||||
assert set(modules_with_gc_enabled.keys()) == expected_set, (
|
||||
f"Modules with GC enabled {set(modules_with_gc_enabled.keys())} do not match expected set {expected_set}"
|
||||
)
|
||||
assert all(modules_with_gc_enabled.values()), "All modules should have GC enabled"
|
||||
|
||||
def test_gradient_checkpointing_equivalence(self, loss_tolerance=1e-5, param_grad_tol=5e-5, skip=None):
|
||||
if not self.model_class._supports_gradient_checkpointing:
|
||||
pytest.skip("Gradient checkpointing is not supported.")
|
||||
|
||||
if skip is None:
|
||||
skip = set()
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
inputs_dict_copy = copy.deepcopy(inputs_dict)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
assert not model.is_gradient_checkpointing and model.training
|
||||
|
||||
out = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
# run the backwards pass on the model
|
||||
model.zero_grad()
|
||||
|
||||
labels = torch.randn_like(out)
|
||||
loss = (out - labels).mean()
|
||||
loss.backward()
|
||||
|
||||
# re-instantiate the model now enabling gradient checkpointing
|
||||
torch.manual_seed(0)
|
||||
model_2 = self.model_class(**init_dict)
|
||||
# clone model
|
||||
model_2.load_state_dict(model.state_dict())
|
||||
model_2.to(torch_device)
|
||||
model_2.enable_gradient_checkpointing()
|
||||
|
||||
assert model_2.is_gradient_checkpointing and model_2.training
|
||||
|
||||
out_2 = model_2(**inputs_dict_copy, return_dict=False)[0]
|
||||
|
||||
# run the backwards pass on the model
|
||||
model_2.zero_grad()
|
||||
loss_2 = (out_2 - labels).mean()
|
||||
loss_2.backward()
|
||||
|
||||
# compare the output and parameters gradients
|
||||
assert (loss - loss_2).abs() < loss_tolerance, (
|
||||
f"Loss difference {(loss - loss_2).abs()} exceeds tolerance {loss_tolerance}"
|
||||
)
|
||||
|
||||
named_params = dict(model.named_parameters())
|
||||
named_params_2 = dict(model_2.named_parameters())
|
||||
|
||||
for name, param in named_params.items():
|
||||
if "post_quant_conv" in name:
|
||||
continue
|
||||
if name in skip:
|
||||
continue
|
||||
if param.grad is None:
|
||||
continue
|
||||
|
||||
assert torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=param_grad_tol), (
|
||||
f"Gradient mismatch for {name}"
|
||||
)
|
||||
|
||||
def test_mixed_precision_training(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.train()
|
||||
|
||||
# Test with float16
|
||||
if torch.device(torch_device).type != "cpu":
|
||||
with torch.amp.autocast(device_type=torch.device(torch_device).type, dtype=torch.float16):
|
||||
output = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
|
||||
loss = torch.nn.functional.mse_loss(output, noise)
|
||||
|
||||
loss.backward()
|
||||
|
||||
# Test with bfloat16
|
||||
if torch.device(torch_device).type != "cpu":
|
||||
model.zero_grad()
|
||||
with torch.amp.autocast(device_type=torch.device(torch_device).type, dtype=torch.bfloat16):
|
||||
output = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
|
||||
loss = torch.nn.functional.mse_loss(output, noise)
|
||||
|
||||
loss.backward()
|
||||
@@ -13,52 +13,23 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from diffusers import FluxTransformer2DModel
|
||||
from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0
|
||||
from diffusers.models.embeddings import ImageProjection
|
||||
from diffusers.models.transformers.transformer_flux import FluxIPAdapterAttnProcessor
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..testing_utils import (
|
||||
AttentionTesterMixin,
|
||||
BaseModelTesterConfig,
|
||||
BitsAndBytesCompileTesterMixin,
|
||||
BitsAndBytesTesterMixin,
|
||||
ContextParallelTesterMixin,
|
||||
FasterCacheTesterMixin,
|
||||
FirstBlockCacheTesterMixin,
|
||||
GGUFCompileTesterMixin,
|
||||
GGUFTesterMixin,
|
||||
IPAdapterTesterMixin,
|
||||
LoraHotSwappingForModelTesterMixin,
|
||||
LoraTesterMixin,
|
||||
MemoryTesterMixin,
|
||||
ModelOptCompileTesterMixin,
|
||||
ModelOptTesterMixin,
|
||||
ModelTesterMixin,
|
||||
PyramidAttentionBroadcastTesterMixin,
|
||||
QuantoCompileTesterMixin,
|
||||
QuantoTesterMixin,
|
||||
SingleFileTesterMixin,
|
||||
TorchAoCompileTesterMixin,
|
||||
TorchAoTesterMixin,
|
||||
TorchCompileTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
)
|
||||
from ...testing_utils import enable_full_determinism, is_peft_available, torch_device
|
||||
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
# TODO: This standalone function maintains backward compatibility with pipeline tests
|
||||
# (tests/pipelines/test_pipelines_common.py) and will be refactored.
|
||||
def create_flux_ip_adapter_state_dict(model) -> dict[str, dict[str, Any]]:
|
||||
"""Create a dummy IP Adapter state dict for Flux transformer testing."""
|
||||
def create_flux_ip_adapter_state_dict(model):
|
||||
# "ip_adapter" (cross-attention weights)
|
||||
ip_cross_attn_state_dict = {}
|
||||
key_id = 0
|
||||
|
||||
@@ -68,7 +39,7 @@ def create_flux_ip_adapter_state_dict(model) -> dict[str, dict[str, Any]]:
|
||||
|
||||
joint_attention_dim = model.config["joint_attention_dim"]
|
||||
hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"]
|
||||
sd = FluxIPAdapterAttnProcessor(
|
||||
sd = FluxIPAdapterJointAttnProcessor2_0(
|
||||
hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0
|
||||
).state_dict()
|
||||
ip_cross_attn_state_dict.update(
|
||||
@@ -79,8 +50,11 @@ def create_flux_ip_adapter_state_dict(model) -> dict[str, dict[str, Any]]:
|
||||
f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
key_id += 1
|
||||
|
||||
# "image_proj" (ImageProjection layer weights)
|
||||
|
||||
image_projection = ImageProjection(
|
||||
cross_attention_dim=model.config["joint_attention_dim"],
|
||||
image_embed_dim=(
|
||||
@@ -101,45 +75,57 @@ def create_flux_ip_adapter_state_dict(model) -> dict[str, dict[str, Any]]:
|
||||
)
|
||||
|
||||
del sd
|
||||
return {"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict}
|
||||
ip_state_dict = {}
|
||||
ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict})
|
||||
return ip_state_dict
|
||||
|
||||
|
||||
class FluxTransformerTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return FluxTransformer2DModel
|
||||
class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = FluxTransformer2DModel
|
||||
main_input_name = "hidden_states"
|
||||
# We override the items here because the transformer under consideration is small.
|
||||
model_split_percents = [0.7, 0.6, 0.6]
|
||||
|
||||
# Skip setting testing with default: AttnProcessor
|
||||
uses_custom_attn_processor = True
|
||||
|
||||
@property
|
||||
def pretrained_model_name_or_path(self):
|
||||
return "hf-internal-testing/tiny-flux-pipe"
|
||||
def dummy_input(self):
|
||||
return self.prepare_dummy_input()
|
||||
|
||||
@property
|
||||
def pretrained_model_kwargs(self):
|
||||
return {"subfolder": "transformer"}
|
||||
|
||||
@property
|
||||
def output_shape(self) -> tuple[int, int]:
|
||||
def input_shape(self):
|
||||
return (16, 4)
|
||||
|
||||
@property
|
||||
def input_shape(self) -> tuple[int, int]:
|
||||
def output_shape(self):
|
||||
return (16, 4)
|
||||
|
||||
@property
|
||||
def model_split_percents(self) -> list:
|
||||
return [0.9]
|
||||
def prepare_dummy_input(self, height=4, width=4):
|
||||
batch_size = 1
|
||||
num_latent_channels = 4
|
||||
num_image_channels = 3
|
||||
sequence_length = 48
|
||||
embedding_dim = 32
|
||||
|
||||
@property
|
||||
def main_input_name(self) -> str:
|
||||
return "hidden_states"
|
||||
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
|
||||
pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(torch_device)
|
||||
text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device)
|
||||
image_ids = torch.randn((height * width, num_image_channels)).to(torch_device)
|
||||
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self) -> dict[str, int | list[int]]:
|
||||
"""Return Flux model initialization arguments."""
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"img_ids": image_ids,
|
||||
"txt_ids": text_ids,
|
||||
"pooled_projections": pooled_prompt_embeds,
|
||||
"timestep": timestep,
|
||||
}
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"patch_size": 1,
|
||||
"in_channels": 4,
|
||||
"num_layers": 1,
|
||||
@@ -151,40 +137,11 @@ class FluxTransformerTesterConfig(BaseModelTesterConfig):
|
||||
"axes_dims_rope": [4, 4, 8],
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
|
||||
batch_size = 1
|
||||
height = width = 4
|
||||
num_latent_channels = 4
|
||||
num_image_channels = 3
|
||||
sequence_length = 48
|
||||
embedding_dim = 32
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
|
||||
),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
|
||||
),
|
||||
"pooled_projections": randn_tensor(
|
||||
(batch_size, embedding_dim), generator=self.generator, device=torch_device
|
||||
),
|
||||
"img_ids": randn_tensor(
|
||||
(height * width, num_image_channels), generator=self.generator, device=torch_device
|
||||
),
|
||||
"txt_ids": randn_tensor(
|
||||
(sequence_length, num_image_channels), generator=self.generator, device=torch_device
|
||||
),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
|
||||
}
|
||||
|
||||
|
||||
class TestFluxTransformer(FluxTransformerTesterConfig, ModelTesterMixin):
|
||||
def test_deprecated_inputs_img_txt_ids_3d(self):
|
||||
"""Test that deprecated 3D img_ids and txt_ids still work."""
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
@@ -205,228 +162,63 @@ class TestFluxTransformer(FluxTransformerTesterConfig, ModelTesterMixin):
|
||||
with torch.no_grad():
|
||||
output_2 = model(**inputs_dict).to_tuple()[0]
|
||||
|
||||
assert output_1.shape == output_2.shape
|
||||
assert torch.allclose(output_1, output_2, atol=1e-5), (
|
||||
"output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) "
|
||||
"are not equal as them as 2d inputs"
|
||||
self.assertEqual(output_1.shape, output_2.shape)
|
||||
self.assertTrue(
|
||||
torch.allclose(output_1, output_2, atol=1e-5),
|
||||
msg="output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) are not equal as them as 2d inputs",
|
||||
)
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"FluxTransformer2DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
class TestFluxTransformerMemory(FluxTransformerTesterConfig, MemoryTesterMixin):
|
||||
"""Memory optimization tests for Flux Transformer."""
|
||||
# The test exists for cases like
|
||||
# https://github.com/huggingface/diffusers/issues/11874
|
||||
@unittest.skipIf(not is_peft_available(), "Only with PEFT")
|
||||
def test_lora_exclude_modules(self):
|
||||
from peft import LoraConfig, get_peft_model_state_dict, inject_adapter_in_model, set_peft_model_state_dict
|
||||
|
||||
lora_rank = 4
|
||||
target_module = "single_transformer_blocks.0.proj_out"
|
||||
adapter_name = "foo"
|
||||
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
class TestFluxTransformerTraining(FluxTransformerTesterConfig, TrainingTesterMixin):
|
||||
"""Training tests for Flux Transformer."""
|
||||
|
||||
|
||||
class TestFluxTransformerAttention(FluxTransformerTesterConfig, AttentionTesterMixin):
|
||||
"""Attention processor tests for Flux Transformer."""
|
||||
|
||||
|
||||
class TestFluxTransformerContextParallel(FluxTransformerTesterConfig, ContextParallelTesterMixin):
|
||||
"""Context Parallel inference tests for Flux Transformer"""
|
||||
|
||||
|
||||
class TestFluxTransformerIPAdapter(FluxTransformerTesterConfig, IPAdapterTesterMixin):
|
||||
"""IP Adapter tests for Flux Transformer."""
|
||||
|
||||
@property
|
||||
def ip_adapter_processor_cls(self):
|
||||
return FluxIPAdapterAttnProcessor
|
||||
|
||||
def modify_inputs_for_ip_adapter(self, model, inputs_dict):
|
||||
torch.manual_seed(0)
|
||||
# Create dummy image embeds for IP adapter
|
||||
cross_attention_dim = getattr(model.config, "joint_attention_dim", 32)
|
||||
image_embeds = torch.randn(1, 1, cross_attention_dim).to(torch_device)
|
||||
|
||||
inputs_dict.update({"joint_attention_kwargs": {"ip_adapter_image_embeds": image_embeds}})
|
||||
|
||||
return inputs_dict
|
||||
|
||||
def create_ip_adapter_state_dict(self, model: Any) -> dict[str, dict[str, Any]]:
|
||||
return create_flux_ip_adapter_state_dict(model)
|
||||
|
||||
|
||||
class TestFluxTransformerLoRA(FluxTransformerTesterConfig, LoraTesterMixin):
|
||||
"""LoRA adapter tests for Flux Transformer."""
|
||||
|
||||
|
||||
class TestFluxTransformerLoRAHotSwap(FluxTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
|
||||
"""LoRA hot-swapping tests for Flux Transformer."""
|
||||
|
||||
@property
|
||||
def different_shapes_for_compilation(self):
|
||||
return [(4, 4), (4, 8), (8, 8)]
|
||||
|
||||
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
|
||||
"""Override to support dynamic height/width for LoRA hotswap tests."""
|
||||
batch_size = 1
|
||||
num_latent_channels = 4
|
||||
num_image_channels = 3
|
||||
sequence_length = 24
|
||||
embedding_dim = 32
|
||||
|
||||
return {
|
||||
"hidden_states": randn_tensor((batch_size, height * width, num_latent_channels), device=torch_device),
|
||||
"encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim), device=torch_device),
|
||||
"pooled_projections": randn_tensor((batch_size, embedding_dim), device=torch_device),
|
||||
"img_ids": randn_tensor((height * width, num_image_channels), device=torch_device),
|
||||
"txt_ids": randn_tensor((sequence_length, num_image_channels), device=torch_device),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
|
||||
state_dict = model.state_dict()
|
||||
target_mod_shape = state_dict[f"{target_module}.weight"].shape
|
||||
lora_state_dict = {
|
||||
f"{target_module}.lora_A.weight": torch.ones(lora_rank, target_mod_shape[1]) * 22,
|
||||
f"{target_module}.lora_B.weight": torch.ones(target_mod_shape[0], lora_rank) * 33,
|
||||
}
|
||||
# Passing exclude_modules should no longer be necessary (or even passing target_modules, for that matter).
|
||||
config = LoraConfig(
|
||||
r=lora_rank, target_modules=["single_transformer_blocks.0.proj_out"], exclude_modules=["proj_out"]
|
||||
)
|
||||
inject_adapter_in_model(config, model, adapter_name=adapter_name, state_dict=lora_state_dict)
|
||||
set_peft_model_state_dict(model, lora_state_dict, adapter_name)
|
||||
retrieved_lora_state_dict = get_peft_model_state_dict(model, adapter_name=adapter_name)
|
||||
assert len(retrieved_lora_state_dict) == len(lora_state_dict)
|
||||
assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_A.weight"] == 22).all()
|
||||
assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_B.weight"] == 33).all()
|
||||
|
||||
|
||||
class TestFluxTransformerCompile(FluxTransformerTesterConfig, TorchCompileTesterMixin):
|
||||
@property
|
||||
def different_shapes_for_compilation(self):
|
||||
return [(4, 4), (4, 8), (8, 8)]
|
||||
class FluxTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
|
||||
model_class = FluxTransformer2DModel
|
||||
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
|
||||
|
||||
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
|
||||
"""Override to support dynamic height/width for compilation tests."""
|
||||
batch_size = 1
|
||||
num_latent_channels = 4
|
||||
num_image_channels = 3
|
||||
sequence_length = 24
|
||||
embedding_dim = 32
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return FluxTransformerTests().prepare_init_args_and_inputs_for_common()
|
||||
|
||||
return {
|
||||
"hidden_states": randn_tensor((batch_size, height * width, num_latent_channels), device=torch_device),
|
||||
"encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim), device=torch_device),
|
||||
"pooled_projections": randn_tensor((batch_size, embedding_dim), device=torch_device),
|
||||
"img_ids": randn_tensor((height * width, num_image_channels), device=torch_device),
|
||||
"txt_ids": randn_tensor((sequence_length, num_image_channels), device=torch_device),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
|
||||
}
|
||||
def prepare_dummy_input(self, height, width):
|
||||
return FluxTransformerTests().prepare_dummy_input(height=height, width=width)
|
||||
|
||||
|
||||
class TestFluxSingleFile(FluxTransformerTesterConfig, SingleFileTesterMixin):
|
||||
@property
|
||||
def ckpt_path(self):
|
||||
return "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
|
||||
class FluxTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
|
||||
model_class = FluxTransformer2DModel
|
||||
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
|
||||
|
||||
@property
|
||||
def alternate_ckpt_paths(self):
|
||||
return ["https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"]
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return FluxTransformerTests().prepare_init_args_and_inputs_for_common()
|
||||
|
||||
@property
|
||||
def pretrained_model_name_or_path(self):
|
||||
return "black-forest-labs/FLUX.1-dev"
|
||||
|
||||
|
||||
class TestFluxTransformerBitsAndBytes(FluxTransformerTesterConfig, BitsAndBytesTesterMixin):
|
||||
"""BitsAndBytes quantization tests for Flux Transformer."""
|
||||
|
||||
|
||||
class TestFluxTransformerQuanto(FluxTransformerTesterConfig, QuantoTesterMixin):
|
||||
"""Quanto quantization tests for Flux Transformer."""
|
||||
|
||||
@property
|
||||
def pretrained_model_name_or_path(self):
|
||||
return "hf-internal-testing/tiny-flux-transformer"
|
||||
|
||||
@property
|
||||
def pretrained_model_kwargs(self):
|
||||
return {}
|
||||
|
||||
|
||||
class TestFluxTransformerTorchAo(FluxTransformerTesterConfig, TorchAoTesterMixin):
|
||||
"""TorchAO quantization tests for Flux Transformer."""
|
||||
|
||||
|
||||
class TestFluxTransformerGGUF(FluxTransformerTesterConfig, GGUFTesterMixin):
|
||||
@property
|
||||
def gguf_filename(self):
|
||||
return "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
|
||||
|
||||
@property
|
||||
def torch_dtype(self):
|
||||
return torch.bfloat16
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
"""Override to provide inputs matching the real FLUX model dimensions."""
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(1, 4096, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"pooled_projections": randn_tensor(
|
||||
(1, 768), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
|
||||
"img_ids": randn_tensor((4096, 3), generator=self.generator, device=torch_device, dtype=self.torch_dtype),
|
||||
"txt_ids": randn_tensor((512, 3), generator=self.generator, device=torch_device, dtype=self.torch_dtype),
|
||||
"guidance": torch.tensor([3.5]).to(torch_device, self.torch_dtype),
|
||||
}
|
||||
|
||||
|
||||
class TestFluxTransformerQuantoCompile(FluxTransformerTesterConfig, QuantoCompileTesterMixin):
|
||||
"""Quanto + compile tests for Flux Transformer."""
|
||||
|
||||
|
||||
class TestFluxTransformerTorchAoCompile(FluxTransformerTesterConfig, TorchAoCompileTesterMixin):
|
||||
"""TorchAO + compile tests for Flux Transformer."""
|
||||
|
||||
|
||||
class TestFluxTransformerGGUFCompile(FluxTransformerTesterConfig, GGUFCompileTesterMixin):
|
||||
@property
|
||||
def gguf_filename(self):
|
||||
return "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
|
||||
|
||||
@property
|
||||
def torch_dtype(self):
|
||||
return torch.bfloat16
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
"""Override to provide inputs matching the real FLUX model dimensions."""
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(1, 4096, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"pooled_projections": randn_tensor(
|
||||
(1, 768), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
|
||||
"img_ids": randn_tensor((4096, 3), generator=self.generator, device=torch_device, dtype=self.torch_dtype),
|
||||
"txt_ids": randn_tensor((512, 3), generator=self.generator, device=torch_device, dtype=self.torch_dtype),
|
||||
"guidance": torch.tensor([3.5]).to(torch_device, self.torch_dtype),
|
||||
}
|
||||
|
||||
|
||||
class TestFluxTransformerModelOpt(FluxTransformerTesterConfig, ModelOptTesterMixin):
|
||||
"""ModelOpt quantization tests for Flux Transformer."""
|
||||
|
||||
|
||||
class TestFluxTransformerModelOptCompile(FluxTransformerTesterConfig, ModelOptCompileTesterMixin):
|
||||
"""ModelOpt + compile tests for Flux Transformer."""
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="torch.compile is not supported by BitsAndBytes")
|
||||
class TestFluxTransformerBitsAndBytesCompile(FluxTransformerTesterConfig, BitsAndBytesCompileTesterMixin):
|
||||
"""BitsAndBytes + compile tests for Flux Transformer."""
|
||||
|
||||
|
||||
class TestFluxTransformerPABCache(FluxTransformerTesterConfig, PyramidAttentionBroadcastTesterMixin):
|
||||
"""PyramidAttentionBroadcast cache tests for Flux Transformer."""
|
||||
|
||||
|
||||
class TestFluxTransformerFBCCache(FluxTransformerTesterConfig, FirstBlockCacheTesterMixin):
|
||||
"""FirstBlockCache tests for Flux Transformer."""
|
||||
|
||||
|
||||
class TestFluxTransformerFasterCache(FluxTransformerTesterConfig, FasterCacheTesterMixin):
|
||||
"""FasterCache tests for Flux Transformer."""
|
||||
|
||||
# Flux is guidance distilled, so we can test at model level without CFG batch handling
|
||||
FASTER_CACHE_CONFIG = {
|
||||
"spatial_attention_block_skip_range": 2,
|
||||
"spatial_attention_timestep_skip_range": (-1, 901),
|
||||
"tensor_format": "BCHW",
|
||||
"is_guidance_distilled": True,
|
||||
}
|
||||
def prepare_dummy_input(self, height, width):
|
||||
return FluxTransformerTests().prepare_dummy_input(height=height, width=width)
|
||||
|
||||
@@ -13,86 +13,49 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import QwenImageTransformer2DModel
|
||||
from diffusers.models.transformers.transformer_qwenimage import compute_text_seq_len_from_mask
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..testing_utils import (
|
||||
AttentionTesterMixin,
|
||||
BaseModelTesterConfig,
|
||||
BitsAndBytesTesterMixin,
|
||||
ContextParallelTesterMixin,
|
||||
LoraTesterMixin,
|
||||
MemoryTesterMixin,
|
||||
ModelTesterMixin,
|
||||
TorchAoTesterMixin,
|
||||
TorchCompileTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
)
|
||||
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class QwenImageTransformerTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return QwenImageTransformer2DModel
|
||||
class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = QwenImageTransformer2DModel
|
||||
main_input_name = "hidden_states"
|
||||
# We override the items here because the transformer under consideration is small.
|
||||
model_split_percents = [0.7, 0.6, 0.6]
|
||||
|
||||
# Skip setting testing with default: AttnProcessor
|
||||
uses_custom_attn_processor = True
|
||||
|
||||
@property
|
||||
def output_shape(self) -> tuple[int, int]:
|
||||
def dummy_input(self):
|
||||
return self.prepare_dummy_input()
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (16, 16)
|
||||
|
||||
@property
|
||||
def input_shape(self) -> tuple[int, int]:
|
||||
def output_shape(self):
|
||||
return (16, 16)
|
||||
|
||||
@property
|
||||
def model_split_percents(self) -> list:
|
||||
# We override the items here because the transformer under consideration is small.
|
||||
return [0.7, 0.6, 0.6]
|
||||
|
||||
@property
|
||||
def main_input_name(self) -> str:
|
||||
return "hidden_states"
|
||||
|
||||
@property
|
||||
def uses_custom_attn_processor(self) -> bool:
|
||||
# Skip setting testing with default: AttnProcessor
|
||||
return True
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self) -> dict[str, int | list[int]]:
|
||||
return {
|
||||
"patch_size": 2,
|
||||
"in_channels": 16,
|
||||
"out_channels": 4,
|
||||
"num_layers": 2,
|
||||
"attention_head_dim": 16,
|
||||
"num_attention_heads": 4, # Must be divisible by 2 for Ulysses context parallel
|
||||
"joint_attention_dim": 16,
|
||||
"guidance_embeds": False,
|
||||
"axes_dims_rope": (8, 4, 4),
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
|
||||
def prepare_dummy_input(self, height=4, width=4):
|
||||
batch_size = 1
|
||||
num_latent_channels = embedding_dim = 16
|
||||
sequence_length = 8 # Must be divisible by 2 for context parallel tests
|
||||
sequence_length = 7
|
||||
vae_scale_factor = 4
|
||||
|
||||
hidden_states = randn_tensor(
|
||||
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
|
||||
)
|
||||
encoder_hidden_states = randn_tensor(
|
||||
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
|
||||
)
|
||||
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
|
||||
encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
|
||||
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
orig_height = height * 2 * vae_scale_factor
|
||||
@@ -107,12 +70,29 @@ class QwenImageTransformerTesterConfig(BaseModelTesterConfig):
|
||||
"img_shapes": img_shapes,
|
||||
}
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"patch_size": 2,
|
||||
"in_channels": 16,
|
||||
"out_channels": 4,
|
||||
"num_layers": 2,
|
||||
"attention_head_dim": 16,
|
||||
"num_attention_heads": 3,
|
||||
"joint_attention_dim": 16,
|
||||
"guidance_embeds": False,
|
||||
"axes_dims_rope": (8, 4, 4),
|
||||
}
|
||||
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"QwenImageTransformer2DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixin):
|
||||
def test_infers_text_seq_len_from_mask(self):
|
||||
"""Test that compute_text_seq_len_from_mask correctly infers sequence lengths and returns tensors."""
|
||||
init_dict = self.get_init_dict()
|
||||
inputs = self.get_dummy_inputs()
|
||||
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
# Test 1: Contiguous mask with padding at the end (only first 2 tokens valid)
|
||||
@@ -124,24 +104,24 @@ class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixi
|
||||
)
|
||||
|
||||
# Verify rope_text_seq_len is returned as an int (for torch.compile compatibility)
|
||||
assert isinstance(rope_text_seq_len, int)
|
||||
self.assertIsInstance(rope_text_seq_len, int)
|
||||
|
||||
# Verify per_sample_len is computed correctly (max valid position + 1 = 2)
|
||||
assert isinstance(per_sample_len, torch.Tensor)
|
||||
assert int(per_sample_len.max().item()) == 2
|
||||
self.assertIsInstance(per_sample_len, torch.Tensor)
|
||||
self.assertEqual(int(per_sample_len.max().item()), 2)
|
||||
|
||||
# Verify mask is normalized to bool dtype
|
||||
assert normalized_mask.dtype == torch.bool
|
||||
assert normalized_mask.sum().item() == 2 # Only 2 True values
|
||||
self.assertTrue(normalized_mask.dtype == torch.bool)
|
||||
self.assertEqual(normalized_mask.sum().item(), 2) # Only 2 True values
|
||||
|
||||
# Verify rope_text_seq_len is at least the sequence length
|
||||
assert rope_text_seq_len >= inputs["encoder_hidden_states"].shape[1]
|
||||
self.assertGreaterEqual(rope_text_seq_len, inputs["encoder_hidden_states"].shape[1])
|
||||
|
||||
# Test 2: Verify model runs successfully with inferred values
|
||||
inputs["encoder_hidden_states_mask"] = normalized_mask
|
||||
with torch.no_grad():
|
||||
output = model(**inputs)
|
||||
assert output.sample.shape[1] == inputs["hidden_states"].shape[1]
|
||||
self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])
|
||||
|
||||
# Test 3: Different mask pattern (padding at beginning)
|
||||
encoder_hidden_states_mask2 = inputs["encoder_hidden_states_mask"].clone()
|
||||
@@ -153,22 +133,21 @@ class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixi
|
||||
)
|
||||
|
||||
# Max valid position is 6 (last token), so per_sample_len should be 7
|
||||
assert int(per_sample_len2.max().item()) == 7
|
||||
assert normalized_mask2.sum().item() == 4 # 4 True values
|
||||
self.assertEqual(int(per_sample_len2.max().item()), 7)
|
||||
self.assertEqual(normalized_mask2.sum().item(), 4) # 4 True values
|
||||
|
||||
# Test 4: No mask provided (None case)
|
||||
rope_text_seq_len_none, per_sample_len_none, normalized_mask_none = compute_text_seq_len_from_mask(
|
||||
inputs["encoder_hidden_states"], None
|
||||
)
|
||||
assert rope_text_seq_len_none == inputs["encoder_hidden_states"].shape[1]
|
||||
assert isinstance(rope_text_seq_len_none, int)
|
||||
assert per_sample_len_none is None
|
||||
assert normalized_mask_none is None
|
||||
self.assertEqual(rope_text_seq_len_none, inputs["encoder_hidden_states"].shape[1])
|
||||
self.assertIsInstance(rope_text_seq_len_none, int)
|
||||
self.assertIsNone(per_sample_len_none)
|
||||
self.assertIsNone(normalized_mask_none)
|
||||
|
||||
def test_non_contiguous_attention_mask(self):
|
||||
"""Test that non-contiguous masks work correctly (e.g., [1, 0, 1, 0, 1, 0, 0])"""
|
||||
init_dict = self.get_init_dict()
|
||||
inputs = self.get_dummy_inputs()
|
||||
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
# Create a non-contiguous mask pattern: valid, padding, valid, padding, etc.
|
||||
@@ -181,22 +160,21 @@ class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixi
|
||||
inferred_rope_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask(
|
||||
inputs["encoder_hidden_states"], encoder_hidden_states_mask
|
||||
)
|
||||
assert int(per_sample_len.max().item()) == 5
|
||||
assert inferred_rope_len == inputs["encoder_hidden_states"].shape[1]
|
||||
assert isinstance(inferred_rope_len, int)
|
||||
assert normalized_mask.dtype == torch.bool
|
||||
self.assertEqual(int(per_sample_len.max().item()), 5)
|
||||
self.assertEqual(inferred_rope_len, inputs["encoder_hidden_states"].shape[1])
|
||||
self.assertIsInstance(inferred_rope_len, int)
|
||||
self.assertTrue(normalized_mask.dtype == torch.bool)
|
||||
|
||||
inputs["encoder_hidden_states_mask"] = normalized_mask
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**inputs)
|
||||
|
||||
assert output.sample.shape[1] == inputs["hidden_states"].shape[1]
|
||||
self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])
|
||||
|
||||
def test_txt_seq_lens_deprecation(self):
|
||||
"""Test that passing txt_seq_lens raises a deprecation warning."""
|
||||
init_dict = self.get_init_dict()
|
||||
inputs = self.get_dummy_inputs()
|
||||
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
# Prepare inputs with txt_seq_lens (deprecated parameter)
|
||||
@@ -208,24 +186,18 @@ class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixi
|
||||
inputs_with_deprecated["txt_seq_lens"] = txt_seq_lens
|
||||
|
||||
# Test that deprecation warning is raised
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
with self.assertWarns(FutureWarning) as warning_context:
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_with_deprecated)
|
||||
|
||||
# Verify a FutureWarning was raised
|
||||
future_warnings = [x for x in w if issubclass(x.category, FutureWarning)]
|
||||
assert len(future_warnings) > 0, "Expected FutureWarning to be raised"
|
||||
|
||||
# Verify the warning message mentions the deprecation
|
||||
warning_message = str(future_warnings[0].message)
|
||||
assert "txt_seq_lens" in warning_message
|
||||
assert "deprecated" in warning_message
|
||||
# Verify the warning message mentions the deprecation
|
||||
warning_message = str(warning_context.warning)
|
||||
self.assertIn("txt_seq_lens", warning_message)
|
||||
self.assertIn("deprecated", warning_message)
|
||||
self.assertIn("encoder_hidden_states_mask", warning_message)
|
||||
|
||||
# Verify the model still works correctly despite the deprecation
|
||||
assert output.sample.shape[1] == inputs["hidden_states"].shape[1]
|
||||
self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])
|
||||
|
||||
def test_layered_model_with_mask(self):
|
||||
"""Test QwenImageTransformer2DModel with use_layer3d_rope=True (layered model)."""
|
||||
@@ -236,7 +208,7 @@ class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixi
|
||||
"out_channels": 4,
|
||||
"num_layers": 2,
|
||||
"attention_head_dim": 16,
|
||||
"num_attention_heads": 4, # Must be divisible by 2 for Ulysses context parallel
|
||||
"num_attention_heads": 3,
|
||||
"joint_attention_dim": 16,
|
||||
"axes_dims_rope": (8, 4, 4), # Must match attention_head_dim (8+4+4=16)
|
||||
"use_layer3d_rope": True, # Enable layered RoPE
|
||||
@@ -248,11 +220,11 @@ class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixi
|
||||
# Verify the model uses QwenEmbedLayer3DRope
|
||||
from diffusers.models.transformers.transformer_qwenimage import QwenEmbedLayer3DRope
|
||||
|
||||
assert isinstance(model.pos_embed, QwenEmbedLayer3DRope)
|
||||
self.assertIsInstance(model.pos_embed, QwenEmbedLayer3DRope)
|
||||
|
||||
# Test single generation with layered structure
|
||||
batch_size = 1
|
||||
text_seq_len = 8
|
||||
text_seq_len = 7
|
||||
img_h, img_w = 4, 4
|
||||
layers = 4
|
||||
|
||||
@@ -290,69 +262,24 @@ class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixi
|
||||
additional_t_cond=addition_t_cond,
|
||||
)
|
||||
|
||||
assert output.sample.shape[1] == hidden_states.shape[1]
|
||||
self.assertEqual(output.sample.shape[1], hidden_states.shape[1])
|
||||
|
||||
|
||||
class TestQwenImageTransformerMemory(QwenImageTransformerTesterConfig, MemoryTesterMixin):
|
||||
"""Memory optimization tests for QwenImage Transformer."""
|
||||
class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
|
||||
model_class = QwenImageTransformer2DModel
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return QwenImageTransformerTests().prepare_init_args_and_inputs_for_common()
|
||||
|
||||
class TestQwenImageTransformerTraining(QwenImageTransformerTesterConfig, TrainingTesterMixin):
|
||||
"""Training tests for QwenImage Transformer."""
|
||||
def prepare_dummy_input(self, height, width):
|
||||
return QwenImageTransformerTests().prepare_dummy_input(height=height, width=width)
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"QwenImageTransformer2DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class TestQwenImageTransformerAttention(QwenImageTransformerTesterConfig, AttentionTesterMixin):
|
||||
"""Attention processor tests for QwenImage Transformer."""
|
||||
|
||||
|
||||
class TestQwenImageTransformerContextParallel(QwenImageTransformerTesterConfig, ContextParallelTesterMixin):
|
||||
"""Context Parallel inference tests for QwenImage Transformer."""
|
||||
|
||||
|
||||
class TestQwenImageTransformerLoRA(QwenImageTransformerTesterConfig, LoraTesterMixin):
|
||||
"""LoRA adapter tests for QwenImage Transformer."""
|
||||
|
||||
|
||||
class TestQwenImageTransformerCompile(QwenImageTransformerTesterConfig, TorchCompileTesterMixin):
|
||||
@property
|
||||
def different_shapes_for_compilation(self):
|
||||
return [(4, 4), (4, 8), (8, 8)]
|
||||
|
||||
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
|
||||
"""Override to support dynamic height/width for compilation tests."""
|
||||
batch_size = 1
|
||||
num_latent_channels = embedding_dim = 16
|
||||
sequence_length = 8 # Must be divisible by 2 for context parallel tests
|
||||
vae_scale_factor = 4
|
||||
|
||||
hidden_states = randn_tensor(
|
||||
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
|
||||
)
|
||||
encoder_hidden_states = randn_tensor(
|
||||
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
|
||||
)
|
||||
encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
|
||||
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
orig_height = height * 2 * vae_scale_factor
|
||||
orig_width = width * 2 * vae_scale_factor
|
||||
img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"encoder_hidden_states_mask": encoder_hidden_states_mask,
|
||||
"timestep": timestep,
|
||||
"img_shapes": img_shapes,
|
||||
}
|
||||
def test_torch_compile_recompilation_and_graph_break(self):
|
||||
super().test_torch_compile_recompilation_and_graph_break()
|
||||
|
||||
def test_torch_compile_with_and_without_mask(self):
|
||||
"""Test that torch.compile works with both None mask and padding mask."""
|
||||
init_dict = self.get_init_dict()
|
||||
inputs = self.get_dummy_inputs()
|
||||
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.eval()
|
||||
model.compile(mode="default", fullgraph=True)
|
||||
@@ -373,13 +300,13 @@ class TestQwenImageTransformerCompile(QwenImageTransformerTesterConfig, TorchCom
|
||||
):
|
||||
output_no_mask_2 = model(**inputs_no_mask)
|
||||
|
||||
assert output_no_mask.sample.shape[1] == inputs["hidden_states"].shape[1]
|
||||
assert output_no_mask_2.sample.shape[1] == inputs["hidden_states"].shape[1]
|
||||
self.assertEqual(output_no_mask.sample.shape[1], inputs["hidden_states"].shape[1])
|
||||
self.assertEqual(output_no_mask_2.sample.shape[1], inputs["hidden_states"].shape[1])
|
||||
|
||||
# Test 2: Run with all-ones mask (should behave like None)
|
||||
inputs_all_ones = inputs.copy()
|
||||
# Keep the all-ones mask
|
||||
assert inputs_all_ones["encoder_hidden_states_mask"].all().item()
|
||||
self.assertTrue(inputs_all_ones["encoder_hidden_states_mask"].all().item())
|
||||
|
||||
# First run to allow compilation
|
||||
with torch.no_grad():
|
||||
@@ -393,8 +320,8 @@ class TestQwenImageTransformerCompile(QwenImageTransformerTesterConfig, TorchCom
|
||||
):
|
||||
output_all_ones_2 = model(**inputs_all_ones)
|
||||
|
||||
assert output_all_ones.sample.shape[1] == inputs["hidden_states"].shape[1]
|
||||
assert output_all_ones_2.sample.shape[1] == inputs["hidden_states"].shape[1]
|
||||
self.assertEqual(output_all_ones.sample.shape[1], inputs["hidden_states"].shape[1])
|
||||
self.assertEqual(output_all_ones_2.sample.shape[1], inputs["hidden_states"].shape[1])
|
||||
|
||||
# Test 3: Run with actual padding mask (has zeros)
|
||||
inputs_with_padding = inputs.copy()
|
||||
@@ -415,16 +342,8 @@ class TestQwenImageTransformerCompile(QwenImageTransformerTesterConfig, TorchCom
|
||||
):
|
||||
output_with_padding_2 = model(**inputs_with_padding)
|
||||
|
||||
assert output_with_padding.sample.shape[1] == inputs["hidden_states"].shape[1]
|
||||
assert output_with_padding_2.sample.shape[1] == inputs["hidden_states"].shape[1]
|
||||
self.assertEqual(output_with_padding.sample.shape[1], inputs["hidden_states"].shape[1])
|
||||
self.assertEqual(output_with_padding_2.sample.shape[1], inputs["hidden_states"].shape[1])
|
||||
|
||||
# Verify that outputs are different (mask should affect results)
|
||||
assert not torch.allclose(output_no_mask.sample, output_with_padding.sample, atol=1e-3)
|
||||
|
||||
|
||||
class TestQwenImageTransformerBitsAndBytes(QwenImageTransformerTesterConfig, BitsAndBytesTesterMixin):
|
||||
"""BitsAndBytes quantization tests for QwenImage Transformer."""
|
||||
|
||||
|
||||
class TestQwenImageTransformerTorchAo(QwenImageTransformerTesterConfig, TorchAoTesterMixin):
|
||||
"""TorchAO quantization tests for QwenImage Transformer."""
|
||||
self.assertFalse(torch.allclose(output_no_mask.sample, output_with_padding.sample, atol=1e-3))
|
||||
|
||||
@@ -19,7 +19,7 @@ import unittest
|
||||
import numpy as np
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import T5EncoderModel, T5TokenizerFast
|
||||
from transformers import AutoConfig, T5EncoderModel, T5TokenizerFast
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
@@ -89,7 +89,8 @@ class BriaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
scheduler = FlowMatchEulerDiscreteScheduler()
|
||||
|
||||
torch.manual_seed(0)
|
||||
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder = T5EncoderModel(config)
|
||||
tokenizer = T5TokenizerFast.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
components = {
|
||||
|
||||
@@ -2,7 +2,7 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
from transformers import AutoConfig, AutoTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import AutoencoderKL, ChromaPipeline, ChromaTransformer2DModel, FlowMatchEulerDiscreteScheduler
|
||||
|
||||
@@ -41,7 +41,8 @@ class ChromaPipelineFastTests(
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder = T5EncoderModel(config)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
from transformers import AutoConfig, AutoTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import AutoencoderKL, ChromaImg2ImgPipeline, ChromaTransformer2DModel, FlowMatchEulerDiscreteScheduler
|
||||
|
||||
@@ -42,7 +42,8 @@ class ChromaImg2ImgPipelineFastTests(
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder = T5EncoderModel(config)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ import unittest
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoTokenizer,
|
||||
CLIPImageProcessor,
|
||||
CLIPVisionConfig,
|
||||
@@ -71,7 +72,8 @@ class ChronoEditPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
torch.manual_seed(0)
|
||||
# TODO: impl FlowDPMSolverMultistepScheduler
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
|
||||
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder = T5EncoderModel(config)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
@@ -18,7 +18,7 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
from transformers import AutoConfig, AutoTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import AutoencoderKLCogVideoX, CogVideoXPipeline, CogVideoXTransformer3DModel, DDIMScheduler
|
||||
|
||||
@@ -117,7 +117,8 @@ class CogVideoXPipelineFastTests(
|
||||
|
||||
torch.manual_seed(0)
|
||||
scheduler = DDIMScheduler()
|
||||
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder = T5EncoderModel(config)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
components = {
|
||||
|
||||
@@ -18,7 +18,7 @@ import unittest
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
from transformers import AutoConfig, AutoTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import AutoencoderKLCogVideoX, CogVideoXFunControlPipeline, CogVideoXTransformer3DModel, DDIMScheduler
|
||||
|
||||
@@ -104,7 +104,8 @@ class CogVideoXFunControlPipelineFastTests(PipelineTesterMixin, unittest.TestCas
|
||||
|
||||
torch.manual_seed(0)
|
||||
scheduler = DDIMScheduler()
|
||||
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder = T5EncoderModel(config)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
components = {
|
||||
|
||||
@@ -19,7 +19,7 @@ import unittest
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
from transformers import AutoConfig, AutoTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import AutoencoderKLCogVideoX, CogVideoXImageToVideoPipeline, CogVideoXTransformer3DModel, DDIMScheduler
|
||||
from diffusers.utils import load_image
|
||||
@@ -113,7 +113,8 @@ class CogVideoXImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestC
|
||||
|
||||
torch.manual_seed(0)
|
||||
scheduler = DDIMScheduler()
|
||||
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder = T5EncoderModel(config)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
components = {
|
||||
|
||||
@@ -18,7 +18,7 @@ import unittest
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
from transformers import AutoConfig, AutoTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel, CogVideoXVideoToVideoPipeline, DDIMScheduler
|
||||
|
||||
@@ -99,7 +99,8 @@ class CogVideoXVideoToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestC
|
||||
|
||||
torch.manual_seed(0)
|
||||
scheduler = DDIMScheduler()
|
||||
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder = T5EncoderModel(config)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
components = {
|
||||
|
||||
@@ -18,7 +18,7 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
from transformers import AutoConfig, AutoTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import AutoencoderKL, CogVideoXDDIMScheduler, CogView3PlusPipeline, CogView3PlusTransformer2DModel
|
||||
|
||||
@@ -89,7 +89,8 @@ class CogView3PlusPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
torch.manual_seed(0)
|
||||
scheduler = CogVideoXDDIMScheduler()
|
||||
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder = T5EncoderModel(config)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
components = {
|
||||
|
||||
@@ -108,7 +108,7 @@ class CogView4PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
inputs = {
|
||||
"prompt": "dance monkey",
|
||||
"negative_prompt": "",
|
||||
"negative_prompt": "bad",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 6.0,
|
||||
|
||||
@@ -19,7 +19,7 @@ import unittest
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
from transformers import AutoConfig, AutoTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import AutoencoderKLCogVideoX, ConsisIDPipeline, ConsisIDTransformer3DModel, DDIMScheduler
|
||||
from diffusers.utils import load_image
|
||||
@@ -122,7 +122,8 @@ class ConsisIDPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
torch.manual_seed(0)
|
||||
scheduler = DDIMScheduler()
|
||||
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder = T5EncoderModel(config)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
components = {
|
||||
|
||||
@@ -19,7 +19,7 @@ import unittest
|
||||
import numpy as np
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
|
||||
from transformers import AutoConfig, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
@@ -97,7 +97,8 @@ class FluxControlNetPipelineFastTests(unittest.TestCase, PipelineTesterMixin, Fl
|
||||
text_encoder = CLIPTextModel(clip_text_encoder_config)
|
||||
|
||||
torch.manual_seed(0)
|
||||
text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder_2 = T5EncoderModel(config)
|
||||
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
tokenizer_2 = T5TokenizerFast.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
@@ -2,7 +2,7 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
|
||||
from transformers import AutoConfig, AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
@@ -13,9 +13,7 @@ from diffusers import (
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import (
|
||||
torch_device,
|
||||
)
|
||||
from ...testing_utils import torch_device
|
||||
from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist
|
||||
|
||||
|
||||
@@ -70,7 +68,8 @@ class FluxControlNetImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMi
|
||||
text_encoder = CLIPTextModel(clip_text_encoder_config)
|
||||
|
||||
torch.manual_seed(0)
|
||||
text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder_2 = T5EncoderModel(config)
|
||||
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
@@ -3,15 +3,7 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
# torch_device, # {{ edit_1 }} Removed unused import
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
CLIPTextConfig,
|
||||
CLIPTextModel,
|
||||
CLIPTokenizer,
|
||||
T5EncoderModel,
|
||||
)
|
||||
from transformers import AutoConfig, AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
@@ -22,11 +14,7 @@ from diffusers import (
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import (
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
torch_device,
|
||||
)
|
||||
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
@@ -85,7 +73,8 @@ class FluxControlNetInpaintPipelineTests(unittest.TestCase, PipelineTesterMixin)
|
||||
text_encoder = CLIPTextModel(clip_text_encoder_config)
|
||||
|
||||
torch.manual_seed(0)
|
||||
text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder_2 = T5EncoderModel(config)
|
||||
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
@@ -18,7 +18,7 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer, BertModel, T5EncoderModel
|
||||
from transformers import AutoConfig, AutoTokenizer, BertModel, T5EncoderModel
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
@@ -96,7 +96,10 @@ class HunyuanDiTControlNetPipelineFastTests(unittest.TestCase, PipelineTesterMix
|
||||
scheduler = DDPMScheduler()
|
||||
text_encoder = BertModel.from_pretrained("hf-internal-testing/tiny-random-BertModel")
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BertModel")
|
||||
text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
torch.manual_seed(0)
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder_2 = T5EncoderModel(config)
|
||||
tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
components = {
|
||||
|
||||
@@ -17,7 +17,14 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoTokenizer,
|
||||
CLIPTextConfig,
|
||||
CLIPTextModelWithProjection,
|
||||
CLIPTokenizer,
|
||||
T5EncoderModel,
|
||||
)
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
@@ -28,10 +35,7 @@ from diffusers import (
|
||||
from diffusers.models import SD3ControlNetModel
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import (
|
||||
enable_full_determinism,
|
||||
torch_device,
|
||||
)
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
@@ -103,7 +107,8 @@ class StableDiffusion3ControlInpaintNetPipelineFastTests(unittest.TestCase, Pipe
|
||||
text_encoder_2 = CLIPTextModelWithProjection(clip_text_encoder_config)
|
||||
|
||||
torch.manual_seed(0)
|
||||
text_encoder_3 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder_3 = T5EncoderModel(config)
|
||||
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
@@ -19,7 +19,14 @@ from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoTokenizer,
|
||||
CLIPTextConfig,
|
||||
CLIPTextModelWithProjection,
|
||||
CLIPTokenizer,
|
||||
T5EncoderModel,
|
||||
)
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
@@ -118,7 +125,8 @@ class StableDiffusion3ControlNetPipelineFastTests(unittest.TestCase, PipelineTes
|
||||
text_encoder_2 = CLIPTextModelWithProjection(clip_text_encoder_config)
|
||||
|
||||
torch.manual_seed(0)
|
||||
text_encoder_3 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder_3 = T5EncoderModel(config)
|
||||
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
@@ -20,7 +20,7 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
from transformers import AutoConfig, AutoTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import AutoencoderKLCosmos, CosmosTextToWorldPipeline, CosmosTransformer3DModel, EDMEulerScheduler
|
||||
|
||||
@@ -107,7 +107,8 @@ class CosmosTextToWorldPipelineFastTests(PipelineTesterMixin, unittest.TestCase)
|
||||
rho=7.0,
|
||||
final_sigmas_type="sigma_min",
|
||||
)
|
||||
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder = T5EncoderModel(config)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
components = {
|
||||
|
||||
@@ -20,7 +20,7 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
from transformers import AutoConfig, AutoTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLWan,
|
||||
@@ -95,7 +95,8 @@ class Cosmos2TextToImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase
|
||||
|
||||
torch.manual_seed(0)
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(use_karras_sigmas=True)
|
||||
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder = T5EncoderModel(config)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
components = {
|
||||
|
||||
@@ -21,7 +21,7 @@ import unittest
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
from transformers import AutoConfig, AutoTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLWan,
|
||||
@@ -96,7 +96,8 @@ class Cosmos2VideoToWorldPipelineFastTests(PipelineTesterMixin, unittest.TestCas
|
||||
|
||||
torch.manual_seed(0)
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(use_karras_sigmas=True)
|
||||
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder = T5EncoderModel(config)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
components = {
|
||||
|
||||
@@ -21,7 +21,7 @@ import unittest
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
from transformers import AutoConfig, AutoTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import AutoencoderKLCosmos, CosmosTransformer3DModel, CosmosVideoToWorldPipeline, EDMEulerScheduler
|
||||
|
||||
@@ -108,7 +108,8 @@ class CosmosVideoToWorldPipelineFastTests(PipelineTesterMixin, unittest.TestCase
|
||||
rho=7.0,
|
||||
final_sigmas_type="sigma_min",
|
||||
)
|
||||
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder = T5EncoderModel(config)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
components = {
|
||||
|
||||
@@ -2,7 +2,7 @@ import tempfile
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
from transformers import AutoConfig, AutoTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import DDPMScheduler, UNet2DConditionModel
|
||||
from diffusers.models.attention_processor import AttnAddedKVProcessor
|
||||
@@ -18,7 +18,8 @@ from ..test_pipelines_common import to_np
|
||||
class IFPipelineTesterMixin:
|
||||
def _get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder = T5EncoderModel(config)
|
||||
|
||||
torch.manual_seed(0)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
@@ -75,7 +76,8 @@ class IFPipelineTesterMixin:
|
||||
|
||||
def _get_superresolution_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder = T5EncoderModel(config)
|
||||
|
||||
torch.manual_seed(0)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
@@ -18,9 +18,7 @@ import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import (
|
||||
IFPipeline,
|
||||
)
|
||||
from diffusers import IFPipeline
|
||||
from diffusers.models.attention_processor import AttnAddedKVProcessor
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import unittest
|
||||
import numpy as np
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
|
||||
from transformers import AutoConfig, AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
@@ -91,7 +91,8 @@ class FluxPipelineFastTests(
|
||||
text_encoder = CLIPTextModel(clip_text_encoder_config)
|
||||
|
||||
torch.manual_seed(0)
|
||||
text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder_2 = T5EncoderModel(config)
|
||||
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
@@ -3,7 +3,7 @@ import unittest
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
|
||||
from transformers import AutoConfig, AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxControlPipeline, FluxTransformer2DModel
|
||||
|
||||
@@ -53,7 +53,8 @@ class FluxControlPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
|
||||
text_encoder = CLIPTextModel(clip_text_encoder_config)
|
||||
|
||||
torch.manual_seed(0)
|
||||
text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder_2 = T5EncoderModel(config)
|
||||
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
@@ -3,7 +3,7 @@ import unittest
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
|
||||
from transformers import AutoConfig, AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
@@ -57,7 +57,8 @@ class FluxControlImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin
|
||||
text_encoder = CLIPTextModel(clip_text_encoder_config)
|
||||
|
||||
torch.manual_seed(0)
|
||||
text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder_2 = T5EncoderModel(config)
|
||||
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
@@ -3,7 +3,7 @@ import unittest
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
|
||||
from transformers import AutoConfig, AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
@@ -58,7 +58,8 @@ class FluxControlInpaintPipelineFastTests(unittest.TestCase, PipelineTesterMixin
|
||||
text_encoder = CLIPTextModel(clip_text_encoder_config)
|
||||
|
||||
torch.manual_seed(0)
|
||||
text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder_2 = T5EncoderModel(config)
|
||||
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
@@ -3,7 +3,7 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
|
||||
from transformers import AutoConfig, AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxFillPipeline, FluxTransformer2DModel
|
||||
|
||||
@@ -58,7 +58,8 @@ class FluxFillPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
|
||||
text_encoder = CLIPTextModel(clip_text_encoder_config)
|
||||
|
||||
torch.manual_seed(0)
|
||||
text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder_2 = T5EncoderModel(config)
|
||||
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
@@ -3,7 +3,7 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
|
||||
from transformers import AutoConfig, AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxImg2ImgPipeline, FluxTransformer2DModel
|
||||
|
||||
@@ -55,7 +55,8 @@ class FluxImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin, FluxI
|
||||
text_encoder = CLIPTextModel(clip_text_encoder_config)
|
||||
|
||||
torch.manual_seed(0)
|
||||
text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder_2 = T5EncoderModel(config)
|
||||
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
@@ -3,7 +3,7 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
|
||||
from transformers import AutoConfig, AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxInpaintPipeline, FluxTransformer2DModel
|
||||
|
||||
@@ -55,7 +55,8 @@ class FluxInpaintPipelineFastTests(unittest.TestCase, PipelineTesterMixin, FluxI
|
||||
text_encoder = CLIPTextModel(clip_text_encoder_config)
|
||||
|
||||
torch.manual_seed(0)
|
||||
text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder_2 = T5EncoderModel(config)
|
||||
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
@@ -3,7 +3,7 @@ import unittest
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
|
||||
from transformers import AutoConfig, AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
@@ -79,7 +79,8 @@ class FluxKontextPipelineFastTests(
|
||||
text_encoder = CLIPTextModel(clip_text_encoder_config)
|
||||
|
||||
torch.manual_seed(0)
|
||||
text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder_2 = T5EncoderModel(config)
|
||||
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
@@ -3,7 +3,7 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
|
||||
from transformers import AutoConfig, AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
@@ -79,7 +79,8 @@ class FluxKontextInpaintPipelineFastTests(
|
||||
text_encoder = CLIPTextModel(clip_text_encoder_config)
|
||||
|
||||
torch.manual_seed(0)
|
||||
text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder_2 = T5EncoderModel(config)
|
||||
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
@@ -16,7 +16,7 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
from transformers import AutoConfig, AutoTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, GlmImagePipeline, GlmImageTransformer2DModel
|
||||
from diffusers.utils import is_transformers_version
|
||||
@@ -57,7 +57,8 @@ class GlmImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder = T5EncoderModel(config)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
glm_config = GlmImageConfig(
|
||||
|
||||
@@ -18,6 +18,7 @@ import unittest
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoTokenizer,
|
||||
CLIPTextConfig,
|
||||
CLIPTextModelWithProjection,
|
||||
@@ -94,7 +95,8 @@ class HiDreamImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
text_encoder_2 = CLIPTextModelWithProjection(clip_text_encoder_config)
|
||||
|
||||
torch.manual_seed(0)
|
||||
text_encoder_3 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder_3 = T5EncoderModel(config)
|
||||
|
||||
torch.manual_seed(0)
|
||||
text_encoder_4 = LlamaForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
|
||||
@@ -149,7 +151,7 @@ class HiDreamImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
self.assertEqual(generated_image.shape, (128, 128, 3))
|
||||
|
||||
# fmt: off
|
||||
expected_slice = np.array([0.4507, 0.5256, 0.4205, 0.5791, 0.4848, 0.4831, 0.4443, 0.5107, 0.6586, 0.3163, 0.7318, 0.5933, 0.6252, 0.5512, 0.5357, 0.5983])
|
||||
expected_slice = np.array([0.4501, 0.5256, 0.4207, 0.5783, 0.4842, 0.4833, 0.4441, 0.5112, 0.6587, 0.3169, 0.7308, 0.5927, 0.6251, 0.5509, 0.5355, 0.5969])
|
||||
# fmt: on
|
||||
|
||||
generated_slice = generated_image.flatten()
|
||||
|
||||
@@ -233,7 +233,7 @@ class HunyuanVideoImageToVideoPipelineFastTests(
|
||||
self.assertEqual(generated_video.shape, (5, 3, 16, 16))
|
||||
|
||||
# fmt: off
|
||||
expected_slice = torch.tensor([0.444, 0.479, 0.4485, 0.5752, 0.3539, 0.1548, 0.2706, 0.3593, 0.5323, 0.6635, 0.6795, 0.5255, 0.5091, 0.345, 0.4276, 0.4128])
|
||||
expected_slice = torch.tensor([0.4441, 0.4790, 0.4485, 0.5748, 0.3539, 0.1553, 0.2707, 0.3594, 0.5331, 0.6645, 0.6799, 0.5257, 0.5092, 0.3450, 0.4276, 0.4127])
|
||||
# fmt: on
|
||||
|
||||
generated_slice = generated_video.flatten()
|
||||
|
||||
@@ -15,7 +15,14 @@
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from transformers import ByT5Tokenizer, Qwen2_5_VLTextConfig, Qwen2_5_VLTextModel, Qwen2Tokenizer, T5EncoderModel
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
ByT5Tokenizer,
|
||||
Qwen2_5_VLTextConfig,
|
||||
Qwen2_5_VLTextModel,
|
||||
Qwen2Tokenizer,
|
||||
T5EncoderModel,
|
||||
)
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLHunyuanVideo15,
|
||||
@@ -114,7 +121,8 @@ class HunyuanVideo15PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")
|
||||
|
||||
torch.manual_seed(0)
|
||||
text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder_2 = T5EncoderModel(config)
|
||||
tokenizer_2 = ByT5Tokenizer()
|
||||
|
||||
guider = ClassifierFreeGuidance(guidance_scale=1.0)
|
||||
|
||||
@@ -19,7 +19,7 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer, BertModel, T5EncoderModel
|
||||
from transformers import AutoConfig, AutoTokenizer, BertModel, T5EncoderModel
|
||||
|
||||
from diffusers import AutoencoderKL, DDPMScheduler, HunyuanDiT2DModel, HunyuanDiTPipeline
|
||||
|
||||
@@ -74,7 +74,9 @@ class HunyuanDiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
scheduler = DDPMScheduler()
|
||||
text_encoder = BertModel.from_pretrained("hf-internal-testing/tiny-random-BertModel")
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BertModel")
|
||||
text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
torch.manual_seed(0)
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder_2 = T5EncoderModel(config)
|
||||
tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
components = {
|
||||
|
||||
@@ -19,7 +19,7 @@ import unittest
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
from transformers import AutoConfig, AutoTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import (
|
||||
AutoPipelineForImage2Image,
|
||||
@@ -108,7 +108,8 @@ class Kandinsky3PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
torch.manual_seed(0)
|
||||
movq = self.dummy_movq
|
||||
torch.manual_seed(0)
|
||||
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder = T5EncoderModel(config)
|
||||
|
||||
torch.manual_seed(0)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
@@ -20,7 +20,7 @@ import unittest
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
from transformers import AutoConfig, AutoTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import (
|
||||
AutoPipelineForImage2Image,
|
||||
@@ -119,7 +119,8 @@ class Kandinsky3Img2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase)
|
||||
torch.manual_seed(0)
|
||||
movq = self.dummy_movq
|
||||
torch.manual_seed(0)
|
||||
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder = T5EncoderModel(config)
|
||||
|
||||
torch.manual_seed(0)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
@@ -20,7 +20,7 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
from transformers import AutoConfig, AutoTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
@@ -109,7 +109,8 @@ class LattePipelineFastTests(
|
||||
vae = AutoencoderKL()
|
||||
|
||||
scheduler = DDIMScheduler()
|
||||
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder = T5EncoderModel(config)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
from transformers import AutoConfig, AutoTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import AutoencoderKLLTXVideo, FlowMatchEulerDiscreteScheduler, LTXPipeline, LTXVideoTransformer3DModel
|
||||
|
||||
@@ -88,7 +88,8 @@ class LTXPipelineFastTests(PipelineTesterMixin, FirstBlockCacheTesterMixin, unit
|
||||
|
||||
torch.manual_seed(0)
|
||||
scheduler = FlowMatchEulerDiscreteScheduler()
|
||||
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder = T5EncoderModel(config)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
components = {
|
||||
|
||||
@@ -17,7 +17,7 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
from transformers import AutoConfig, AutoTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLLTXVideo,
|
||||
@@ -92,7 +92,8 @@ class LTXConditionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
torch.manual_seed(0)
|
||||
scheduler = FlowMatchEulerDiscreteScheduler()
|
||||
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder = T5EncoderModel(config)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
components = {
|
||||
|
||||
@@ -17,7 +17,7 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
from transformers import AutoConfig, AutoTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLLTXVideo,
|
||||
@@ -91,7 +91,8 @@ class LTXImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
torch.manual_seed(0)
|
||||
scheduler = FlowMatchEulerDiscreteScheduler()
|
||||
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder = T5EncoderModel(config)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
components = {
|
||||
|
||||
@@ -18,7 +18,7 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
from transformers import AutoConfig, AutoTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import AutoencoderKLMochi, FlowMatchEulerDiscreteScheduler, MochiPipeline, MochiTransformer3DModel
|
||||
|
||||
@@ -89,7 +89,8 @@ class MochiPipelineFastTests(
|
||||
|
||||
torch.manual_seed(0)
|
||||
scheduler = FlowMatchEulerDiscreteScheduler()
|
||||
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder = T5EncoderModel(config)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
components = {
|
||||
|
||||
@@ -19,7 +19,7 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer, BertModel, T5EncoderModel
|
||||
from transformers import AutoConfig, AutoTokenizer, BertModel, T5EncoderModel
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
@@ -67,7 +67,9 @@ class HunyuanDiTPAGPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
scheduler = DDPMScheduler()
|
||||
text_encoder = BertModel.from_pretrained("hf-internal-testing/tiny-random-BertModel")
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BertModel")
|
||||
text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
torch.manual_seed(0)
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder_2 = T5EncoderModel(config)
|
||||
tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
components = {
|
||||
|
||||
@@ -19,7 +19,7 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
from transformers import AutoConfig, AutoTokenizer, T5EncoderModel
|
||||
|
||||
import diffusers
|
||||
from diffusers import (
|
||||
@@ -80,7 +80,8 @@ class PixArtSigmaPAGPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
vae = AutoencoderKL()
|
||||
|
||||
scheduler = DDIMScheduler()
|
||||
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder = T5EncoderModel(config)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
|
||||
@@ -3,7 +3,14 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoTokenizer,
|
||||
CLIPTextConfig,
|
||||
CLIPTextModelWithProjection,
|
||||
CLIPTokenizer,
|
||||
T5EncoderModel,
|
||||
)
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
@@ -73,7 +80,9 @@ class StableDiffusion3PAGPipelineFastTests(unittest.TestCase, PipelineTesterMixi
|
||||
torch.manual_seed(0)
|
||||
text_encoder_2 = CLIPTextModelWithProjection(clip_text_encoder_config)
|
||||
|
||||
text_encoder_3 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
torch.manual_seed(0)
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder_3 = T5EncoderModel(config)
|
||||
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
@@ -5,7 +5,14 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoTokenizer,
|
||||
CLIPTextConfig,
|
||||
CLIPTextModelWithProjection,
|
||||
CLIPTokenizer,
|
||||
T5EncoderModel,
|
||||
)
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
@@ -84,7 +91,9 @@ class StableDiffusion3PAGImg2ImgPipelineFastTests(unittest.TestCase, PipelineTes
|
||||
torch.manual_seed(0)
|
||||
text_encoder_2 = CLIPTextModelWithProjection(clip_text_encoder_config)
|
||||
|
||||
text_encoder_3 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
torch.manual_seed(0)
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder_3 = T5EncoderModel(config)
|
||||
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
@@ -19,7 +19,7 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
from transformers import AutoConfig, AutoTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
@@ -77,7 +77,10 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
vae = AutoencoderKL()
|
||||
|
||||
scheduler = DDIMScheduler()
|
||||
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
torch.manual_seed(0)
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder = T5EncoderModel(config)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
from transformers import AutoConfig, AutoTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
@@ -83,7 +83,10 @@ class PixArtSigmaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
vae = AutoencoderKL()
|
||||
|
||||
scheduler = DDIMScheduler()
|
||||
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
torch.manual_seed(0)
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder = T5EncoderModel(config)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
|
||||
@@ -160,7 +160,7 @@ class QwenImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
self.assertEqual(generated_image.shape, (3, 32, 32))
|
||||
|
||||
# fmt: off
|
||||
expected_slice = torch.tensor([0.56331, 0.63677, 0.6015, 0.56369, 0.58166, 0.55277, 0.57176, 0.63261, 0.41466, 0.35561, 0.56229, 0.48334, 0.49714, 0.52622, 0.40872, 0.50208])
|
||||
expected_slice = torch.tensor([0.5646, 0.6369, 0.6019, 0.5640, 0.5830, 0.5520, 0.5717, 0.6315, 0.4167, 0.3563, 0.5640, 0.4849, 0.4961, 0.5237, 0.4084, 0.5014])
|
||||
# fmt: on
|
||||
|
||||
generated_slice = generated_image.flatten()
|
||||
|
||||
@@ -163,7 +163,7 @@ class QwenImageEditPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
self.assertEqual(generated_image.shape, (3, 32, 32))
|
||||
|
||||
# fmt: off
|
||||
expected_slice = torch.tensor([[0.5637, 0.6341, 0.6001, 0.5620, 0.5794, 0.5498, 0.5757, 0.6389, 0.4174, 0.3597, 0.5649, 0.4894, 0.4969, 0.5255, 0.4083, 0.4986]])
|
||||
expected_slice = torch.tensor([0.5640, 0.6350, 0.6003, 0.5606, 0.5801, 0.5502, 0.5757, 0.6388, 0.4174, 0.3590, 0.5647, 0.4891, 0.4975, 0.5256, 0.4088, 0.4991])
|
||||
# fmt: on
|
||||
|
||||
generated_slice = generated_image.flatten()
|
||||
|
||||
@@ -164,7 +164,7 @@ class QwenImageEditPlusPipelineFastTests(PipelineTesterMixin, unittest.TestCase)
|
||||
self.assertEqual(generated_image.shape, (3, 32, 32))
|
||||
|
||||
# fmt: off
|
||||
expected_slice = torch.tensor([[0.5637, 0.6341, 0.6001, 0.5620, 0.5794, 0.5498, 0.5757, 0.6389, 0.4174, 0.3597, 0.5649, 0.4894, 0.4969, 0.5255, 0.4083, 0.4986]])
|
||||
expected_slice = torch.tensor([0.5640, 0.6339, 0.5997, 0.5607, 0.5799, 0.5496, 0.5760, 0.6393, 0.4172, 0.3595, 0.5655, 0.4896, 0.4971, 0.5255, 0.4088, 0.4987])
|
||||
# fmt: on
|
||||
|
||||
generated_slice = generated_image.flatten()
|
||||
|
||||
@@ -16,7 +16,7 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
from transformers import AutoConfig, AutoTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLWan,
|
||||
@@ -68,7 +68,8 @@ class SkyReelsV2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
torch.manual_seed(0)
|
||||
scheduler = UniPCMultistepScheduler(flow_shift=8.0, use_flow_sigmas=True)
|
||||
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder = T5EncoderModel(config)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
@@ -16,7 +16,7 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
from transformers import AutoConfig, AutoTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLWan,
|
||||
@@ -68,7 +68,8 @@ class SkyReelsV2DiffusionForcingPipelineFastTests(PipelineTesterMixin, unittest.
|
||||
|
||||
torch.manual_seed(0)
|
||||
scheduler = UniPCMultistepScheduler(flow_shift=8.0, use_flow_sigmas=True)
|
||||
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder = T5EncoderModel(config)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
@@ -18,6 +18,7 @@ import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoTokenizer,
|
||||
T5EncoderModel,
|
||||
)
|
||||
@@ -68,7 +69,8 @@ class SkyReelsV2DiffusionForcingImageToVideoPipelineFastTests(PipelineTesterMixi
|
||||
|
||||
torch.manual_seed(0)
|
||||
scheduler = UniPCMultistepScheduler(flow_shift=5.0, use_flow_sigmas=True)
|
||||
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder = T5EncoderModel(config)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
torch.manual_seed(0)
|
||||
@@ -159,7 +161,8 @@ class SkyReelsV2DiffusionForcingImageToVideoPipelineFastTests(SkyReelsV2Diffusio
|
||||
|
||||
torch.manual_seed(0)
|
||||
scheduler = UniPCMultistepScheduler(flow_shift=5.0, use_flow_sigmas=True)
|
||||
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder = T5EncoderModel(config)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
@@ -18,7 +18,7 @@ import unittest
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
from transformers import AutoConfig, AutoTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLWan,
|
||||
@@ -70,7 +70,8 @@ class SkyReelsV2DiffusionForcingVideoToVideoPipelineFastTests(PipelineTesterMixi
|
||||
|
||||
torch.manual_seed(0)
|
||||
scheduler = UniPCMultistepScheduler(flow_shift=5.0, use_flow_sigmas=True)
|
||||
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder = T5EncoderModel(config)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
@@ -18,6 +18,7 @@ import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoTokenizer,
|
||||
CLIPImageProcessor,
|
||||
CLIPVisionConfig,
|
||||
@@ -71,7 +72,8 @@ class SkyReelsV2ImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.Test
|
||||
|
||||
torch.manual_seed(0)
|
||||
scheduler = UniPCMultistepScheduler(flow_shift=5.0, use_flow_sigmas=True)
|
||||
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder = T5EncoderModel(config)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
@@ -19,10 +19,7 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import (
|
||||
T5EncoderModel,
|
||||
T5Tokenizer,
|
||||
)
|
||||
from transformers import AutoConfig, T5EncoderModel, T5Tokenizer
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderOobleck,
|
||||
@@ -111,7 +108,8 @@ class StableAudioPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
t5_repo_id = "hf-internal-testing/tiny-random-T5ForConditionalGeneration"
|
||||
text_encoder = T5EncoderModel.from_pretrained(t5_repo_id)
|
||||
config = AutoConfig.from_pretrained(t5_repo_id)
|
||||
text_encoder = T5EncoderModel(config)
|
||||
tokenizer = T5Tokenizer.from_pretrained(t5_repo_id, truncation=True, model_max_length=25)
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
@@ -3,7 +3,14 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoTokenizer,
|
||||
CLIPTextConfig,
|
||||
CLIPTextModelWithProjection,
|
||||
CLIPTokenizer,
|
||||
T5EncoderModel,
|
||||
)
|
||||
|
||||
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, SD3Transformer2DModel, StableDiffusion3Pipeline
|
||||
|
||||
@@ -72,7 +79,9 @@ class StableDiffusion3PipelineFastTests(unittest.TestCase, PipelineTesterMixin):
|
||||
torch.manual_seed(0)
|
||||
text_encoder_2 = CLIPTextModelWithProjection(clip_text_encoder_config)
|
||||
|
||||
text_encoder_3 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
torch.manual_seed(0)
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder_3 = T5EncoderModel(config)
|
||||
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
@@ -4,7 +4,14 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoTokenizer,
|
||||
CLIPTextConfig,
|
||||
CLIPTextModelWithProjection,
|
||||
CLIPTokenizer,
|
||||
T5EncoderModel,
|
||||
)
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
@@ -73,7 +80,9 @@ class StableDiffusion3Img2ImgPipelineFastTests(PipelineLatentTesterMixin, unitte
|
||||
torch.manual_seed(0)
|
||||
text_encoder_2 = CLIPTextModelWithProjection(clip_text_encoder_config)
|
||||
|
||||
text_encoder_3 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
torch.manual_seed(0)
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder_3 = T5EncoderModel(config)
|
||||
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
@@ -3,7 +3,14 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoTokenizer,
|
||||
CLIPTextConfig,
|
||||
CLIPTextModelWithProjection,
|
||||
CLIPTokenizer,
|
||||
T5EncoderModel,
|
||||
)
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
@@ -73,7 +80,9 @@ class StableDiffusion3InpaintPipelineFastTests(PipelineLatentTesterMixin, unitte
|
||||
torch.manual_seed(0)
|
||||
text_encoder_2 = CLIPTextModelWithProjection(clip_text_encoder_config)
|
||||
|
||||
text_encoder_3 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
torch.manual_seed(0)
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder_3 = T5EncoderModel(config)
|
||||
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
@@ -5,7 +5,7 @@ import unittest
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
|
||||
from transformers import AutoConfig, AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
|
||||
|
||||
import diffusers
|
||||
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxTransformer2DModel, VisualClozePipeline
|
||||
@@ -77,7 +77,8 @@ class VisualClozePipelineFastTests(unittest.TestCase, PipelineTesterMixin):
|
||||
text_encoder = CLIPTextModel(clip_text_encoder_config)
|
||||
|
||||
torch.manual_seed(0)
|
||||
text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder_2 = T5EncoderModel(config)
|
||||
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
@@ -5,7 +5,7 @@ import unittest
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
|
||||
from transformers import AutoConfig, AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
|
||||
|
||||
import diffusers
|
||||
from diffusers import (
|
||||
@@ -79,7 +79,8 @@ class VisualClozeGenerationPipelineFastTests(unittest.TestCase, PipelineTesterMi
|
||||
text_encoder = CLIPTextModel(clip_text_encoder_config)
|
||||
|
||||
torch.manual_seed(0)
|
||||
text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
text_encoder_2 = T5EncoderModel(config)
|
||||
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user