mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-20 19:34:48 +08:00
Compare commits
5 Commits
cache-docs
...
flux-sf-fi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
71c1893f32 | ||
|
|
50bda8536a | ||
|
|
8c0632aa8e | ||
|
|
8cb29999fc | ||
|
|
982aa5932a |
@@ -605,10 +605,14 @@ def infer_diffusers_model_type(checkpoint):
|
|||||||
if any(
|
if any(
|
||||||
g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"]
|
g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"]
|
||||||
):
|
):
|
||||||
if checkpoint["img_in.weight"].shape[1] == 384:
|
if "model.diffusion_model.img_in.weight" in checkpoint:
|
||||||
model_type = "flux-fill"
|
key = "model.diffusion_model.img_in.weight"
|
||||||
|
else:
|
||||||
|
key = "img_in.weight"
|
||||||
|
|
||||||
elif checkpoint["img_in.weight"].shape[1] == 128:
|
if checkpoint[key].shape[1] == 384:
|
||||||
|
model_type = "flux-fill"
|
||||||
|
elif checkpoint[key].shape[1] == 128:
|
||||||
model_type = "flux-depth"
|
model_type = "flux-depth"
|
||||||
else:
|
else:
|
||||||
model_type = "flux-dev"
|
model_type = "flux-dev"
|
||||||
|
|||||||
72
tests/single_file/test_model_flux_transformer_single_file.py
Normal file
72
tests/single_file/test_model_flux_transformer_single_file.py
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024 HuggingFace Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import gc
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from diffusers import (
|
||||||
|
FluxTransformer2DModel,
|
||||||
|
)
|
||||||
|
from diffusers.utils.testing_utils import (
|
||||||
|
backend_empty_cache,
|
||||||
|
enable_full_determinism,
|
||||||
|
require_torch_accelerator,
|
||||||
|
torch_device,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
enable_full_determinism()
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch_accelerator
|
||||||
|
class FluxTransformer2DModelSingleFileTests(unittest.TestCase):
|
||||||
|
model_class = FluxTransformer2DModel
|
||||||
|
ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
|
||||||
|
alternate_keys_ckpt_paths = ["https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"]
|
||||||
|
|
||||||
|
repo_id = "black-forest-labs/FLUX.1-dev"
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
|
gc.collect()
|
||||||
|
backend_empty_cache(torch_device)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
super().tearDown()
|
||||||
|
gc.collect()
|
||||||
|
backend_empty_cache(torch_device)
|
||||||
|
|
||||||
|
def test_single_file_components(self):
|
||||||
|
model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer")
|
||||||
|
model_single_file = self.model_class.from_single_file(self.ckpt_path)
|
||||||
|
|
||||||
|
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
|
||||||
|
for param_name, param_value in model_single_file.config.items():
|
||||||
|
if param_name in PARAMS_TO_IGNORE:
|
||||||
|
continue
|
||||||
|
assert (
|
||||||
|
model.config[param_name] == param_value
|
||||||
|
), f"{param_name} differs between single file loading and pretrained loading"
|
||||||
|
|
||||||
|
def test_checkpoint_loading(self):
|
||||||
|
for ckpt_path in self.alternate_keys_ckpt_paths:
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
model = self.model_class.from_single_file(ckpt_path)
|
||||||
|
|
||||||
|
del model
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
Reference in New Issue
Block a user