mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-08 21:44:27 +08:00
Compare commits
5 Commits
dynamic-te
...
flux-sf-fi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
71c1893f32 | ||
|
|
50bda8536a | ||
|
|
8c0632aa8e | ||
|
|
8cb29999fc | ||
|
|
982aa5932a |
@@ -605,10 +605,14 @@ def infer_diffusers_model_type(checkpoint):
|
||||
if any(
|
||||
g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"]
|
||||
):
|
||||
if checkpoint["img_in.weight"].shape[1] == 384:
|
||||
model_type = "flux-fill"
|
||||
if "model.diffusion_model.img_in.weight" in checkpoint:
|
||||
key = "model.diffusion_model.img_in.weight"
|
||||
else:
|
||||
key = "img_in.weight"
|
||||
|
||||
elif checkpoint["img_in.weight"].shape[1] == 128:
|
||||
if checkpoint[key].shape[1] == 384:
|
||||
model_type = "flux-fill"
|
||||
elif checkpoint[key].shape[1] == 128:
|
||||
model_type = "flux-depth"
|
||||
else:
|
||||
model_type = "flux-dev"
|
||||
|
||||
72
tests/single_file/test_model_flux_transformer_single_file.py
Normal file
72
tests/single_file/test_model_flux_transformer_single_file.py
Normal file
@@ -0,0 +1,72 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import (
|
||||
FluxTransformer2DModel,
|
||||
)
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
enable_full_determinism,
|
||||
require_torch_accelerator,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
@require_torch_accelerator
class FluxTransformer2DModelSingleFileTests(unittest.TestCase):
    """Single-file loading tests for ``FluxTransformer2DModel``.

    Verifies that loading the FLUX transformer from a raw ``.safetensors``
    checkpoint via ``from_single_file`` produces the same configuration as
    the standard ``from_pretrained`` path, and that checkpoints with
    alternate key layouts (e.g. the Comfy-Org fp8 export) load without
    error.
    """

    model_class = FluxTransformer2DModel
    # Canonical single-file checkpoint for FLUX.1-dev.
    ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
    # Checkpoints whose state-dict keys differ from the canonical layout.
    alternate_keys_ckpt_paths = ["https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"]

    repo_id = "black-forest-labs/FLUX.1-dev"

    def setUp(self):
        super().setUp()
        gc.collect()
        backend_empty_cache(torch_device)

    def tearDown(self):
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)

    def test_single_file_components(self):
        """Config from single-file loading must match pretrained loading."""
        model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer")
        model_single_file = self.model_class.from_single_file(self.ckpt_path)

        # Keys that legitimately differ between the two loading paths.
        PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
        for param_name, param_value in model_single_file.config.items():
            if param_name in PARAMS_TO_IGNORE:
                continue
            assert (
                model.config[param_name] == param_value
            ), f"{param_name} differs between single file loading and pretrained loading"

    def test_checkpoint_loading(self):
        """Alternate-key checkpoints must load without raising."""
        for ckpt_path in self.alternate_keys_ckpt_paths:
            # Use the backend-agnostic cache helper (consistent with
            # setUp/tearDown) instead of torch.cuda.empty_cache(), which is
            # wrong or a no-op on non-CUDA accelerators (XPU, MPS, ...).
            backend_empty_cache(torch_device)
            model = self.model_class.from_single_file(ckpt_path)

            del model
            gc.collect()
            backend_empty_cache(torch_device)
|
||||
Reference in New Issue
Block a user