Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-12 23:44:30 +08:00)
Compare commits: modular-do...support-di (11 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | d10207e0c9 |  |
|  | 1c788a9234 |  |
|  | 410ea4455f |  |
|  | a85f597ede |  |
|  | 0dd7817cb8 |  |
|  | ace1d4a33b |  |
|  | 896316216b |  |
|  | 3f67ed08b4 |  |
|  | bf1ac4af24 |  |
|  | 4ced879930 |  |
|  | e5ca3a61b4 |  |
@@ -153,9 +153,17 @@ SINGLE_FILE_LOADABLE_CLASSES = {
         "checkpoint_mapping_fn": convert_cosmos_transformer_checkpoint_to_diffusers,
         "default_subfolder": "transformer",
     },
+    "QwenImageTransformer2DModel": {
+        "checkpoint_mapping_fn": lambda x: x,
+        "default_subfolder": "transformer",
+    },
 }
 
 
+def _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint_state_dict):
+    return not set(model_state_dict.keys()).issubset(set(checkpoint_state_dict.keys()))
+
+
 def _get_single_file_loadable_mapping_class(cls):
     diffusers_module = importlib.import_module(__name__.split(".")[0])
     for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES:
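The new `_should_convert_state_dict_to_diffusers` helper is just a key-set check: if every key the instantiated diffusers model expects is already present in the checkpoint, the checkpoint is treated as diffusers layout and the mapping function is skipped. A minimal sketch of the two outcomes (the toy key names below are hypothetical, chosen only for illustration):

```python
# Same subset check as the helper added in the hunk above; toy key names are made up.
def _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint_state_dict):
    return not set(model_state_dict.keys()).issubset(set(checkpoint_state_dict.keys()))


model_keys = {"proj.weight": None, "proj.bias": None}
diffusers_layout_ckpt = {"proj.weight": 0, "proj.bias": 0, "extra.scale": 0}
original_layout_ckpt = {"model.diffusion_model.proj.weight": 0}

# All expected keys already present -> no conversion needed.
assert _should_convert_state_dict_to_diffusers(model_keys, diffusers_layout_ckpt) is False
# Keys differ -> run the checkpoint_mapping_fn to convert.
assert _should_convert_state_dict_to_diffusers(model_keys, original_layout_ckpt) is True
```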
@@ -381,19 +389,23 @@ class FromOriginalModelMixin:
         model_kwargs = {k: kwargs.get(k) for k in kwargs if k in expected_kwargs or k in optional_kwargs}
         diffusers_model_config.update(model_kwargs)
 
-        checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
-        diffusers_format_checkpoint = checkpoint_mapping_fn(
-            config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
-        )
-        if not diffusers_format_checkpoint:
-            raise SingleFileComponentError(
-                f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
-            )
-
         ctx = init_empty_weights if is_accelerate_available() else nullcontext
         with ctx():
             model = cls.from_config(diffusers_model_config)
 
+        checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
+
+        if _should_convert_state_dict_to_diffusers(model.state_dict(), checkpoint):
+            diffusers_format_checkpoint = checkpoint_mapping_fn(
+                config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
+            )
+        else:
+            diffusers_format_checkpoint = checkpoint
+
+        if not diffusers_format_checkpoint:
+            raise SingleFileComponentError(
+                f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
+            )
         # Check if `_keep_in_fp32_modules` is not None
         use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
             (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
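Read straight through, the reordered path first materializes the model (on the meta device when accelerate is available, so no real weight memory is allocated) and only then decides whether the checkpoint needs key conversion. A condensed, standalone sketch of that flow, assuming the `_should_convert_state_dict_to_diffusers` helper from the earlier hunk is in scope; `load_component` is a hypothetical wrapper name, not the library's API:

```python
# Condensed sketch of the reordered load path; the real logic lives in
# FromOriginalModelMixin.from_single_file.
from contextlib import nullcontext

try:
    from accelerate import init_empty_weights  # parameters are created on the meta device
    _init_ctx = init_empty_weights
except ImportError:
    _init_ctx = nullcontext


def load_component(cls, diffusers_model_config, checkpoint, checkpoint_mapping_fn, **mapping_kwargs):
    # Instantiate the architecture first so its expected state-dict keys are known.
    with _init_ctx():
        model = cls.from_config(diffusers_model_config)

    # Convert only when the checkpoint keys do not already match the model's keys.
    if _should_convert_state_dict_to_diffusers(model.state_dict(), checkpoint):
        checkpoint = checkpoint_mapping_fn(
            config=diffusers_model_config, checkpoint=checkpoint, **mapping_kwargs
        )
    return model, checkpoint
```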
@@ -60,6 +60,7 @@ if is_accelerate_available():
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 CHECKPOINT_KEY_NAMES = {
+    "v1": "model.diffusion_model.output_blocks.11.0.skip_connection.weight",
     "v2": "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
     "xl_base": "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias",
     "xl_refiner": "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias",
@@ -212,6 +212,7 @@ class GGUFSingleFileTesterMixin:
 
 class FluxGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
     ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
+    diffusers_ckpt_path = "https://huggingface.co/sayakpaul/flux-diffusers-gguf/blob/main/model-Q4_0.gguf"
     torch_dtype = torch.bfloat16
     model_cls = FluxTransformer2DModel
     expected_memory_use_in_gb = 5
@@ -296,6 +297,16 @@ class FluxGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
         max_diff = numpy_cosine_similarity_distance(expected_slice, output_slice)
         assert max_diff < 1e-4
 
+    def test_loading_gguf_diffusers_format(self):
+        model = self.model_cls.from_single_file(
+            self.diffusers_ckpt_path,
+            subfolder="transformer",
+            quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
+            config="black-forest-labs/FLUX.1-dev",
+        )
+        model.to("cuda")
+        model(**self.get_dummy_inputs())
+
 
 class SD35LargeGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
     ckpt_path = "https://huggingface.co/city96/stable-diffusion-3.5-large-gguf/blob/main/sd3.5_large-Q4_0.gguf"
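For reference, the end-user call that the new test exercises looks roughly like the following sketch. The checkpoint URL and config repo are the ones used in the test above; a CUDA device plus the `gguf` and `accelerate` dependencies are assumed:

```python
# Usage sketch mirroring test_loading_gguf_diffusers_format: loading a GGUF checkpoint
# that is already in diffusers layout via from_single_file.
import torch
from diffusers import FluxTransformer2DModel, GGUFQuantizationConfig

transformer = FluxTransformer2DModel.from_single_file(
    "https://huggingface.co/sayakpaul/flux-diffusers-gguf/blob/main/model-Q4_0.gguf",
    subfolder="transformer",
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    config="black-forest-labs/FLUX.1-dev",
)
transformer.to("cuda")
```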