mirror of https://github.com/huggingface/diffusers.git
synced 2025-12-14 00:14:23 +08:00

Compare commits: custom-blo...lumina-sf (2 commits)

| Author | SHA1 | Date |
|---|---|---|
|  | 0cd95a601e |  |
|  | 9f62ef23c6 |  |

@@ -26,6 +26,56 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)

</Tip>

## Using Single File Loading with Lumina Image 2.0

Single file loading for Lumina Image 2.0 is available for the `Lumina2Transformer2DModel`.

```python
import torch
from diffusers import Lumina2Transformer2DModel, Lumina2Text2ImgPipeline

ckpt_path = "https://huggingface.co/Alpha-VLLM/Lumina-Image-2.0/blob/main/consolidated.00-of-01.pth"
transformer = Lumina2Transformer2DModel.from_single_file(
    ckpt_path, torch_dtype=torch.bfloat16
)

pipe = Lumina2Text2ImgPipeline.from_pretrained(
    "Alpha-VLLM/Lumina-Image-2.0", transformer=transformer, torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()
image = pipe(
    "a cat holding a sign that says hello",
    generator=torch.Generator("cpu").manual_seed(0),
).images[0]
image.save("lumina-single-file.png")
```

## Using GGUF Quantized Checkpoints with Lumina Image 2.0

GGUF quantized checkpoints for the `Lumina2Transformer2DModel` can be loaded via `from_single_file` with the `GGUFQuantizationConfig`. The weights stay in their quantized GGUF layout and are dequantized on the fly to `compute_dtype` during inference.

```python
import torch
from diffusers import Lumina2Transformer2DModel, Lumina2Text2ImgPipeline, GGUFQuantizationConfig

ckpt_path = "https://huggingface.co/calcuis/lumina-gguf/blob/main/lumina2-q4_0.gguf"
transformer = Lumina2Transformer2DModel.from_single_file(
    ckpt_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)

pipe = Lumina2Text2ImgPipeline.from_pretrained(
    "Alpha-VLLM/Lumina-Image-2.0", transformer=transformer, torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()
image = pipe(
    "a cat holding a sign that says hello",
    generator=torch.Generator("cpu").manual_seed(0),
).images[0]
image.save("lumina-gguf.png")
```

## Lumina2Text2ImgPipeline

[[autodoc]] Lumina2Text2ImgPipeline

@@ -34,6 +34,7 @@ from .single_file_utils import (
    convert_ldm_vae_checkpoint,
    convert_ltx_transformer_checkpoint_to_diffusers,
    convert_ltx_vae_checkpoint_to_diffusers,
    convert_lumina2_to_diffusers,
    convert_mochi_transformer_checkpoint_to_diffusers,
    convert_sd3_transformer_checkpoint_to_diffusers,
    convert_stable_cascade_unet_single_file_to_diffusers,

@@ -111,6 +112,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
        "checkpoint_mapping_fn": convert_auraflow_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
    "Lumina2Transformer2DModel": {
        "checkpoint_mapping_fn": convert_lumina2_to_diffusers,
        "default_subfolder": "transformer",
    },
}
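
For context, here is a minimal, hypothetical sketch (the helper name is ours, not the library's) of how an entry in this registry is consumed: the model class name selects the checkpoint-conversion function and the default subfolder that holds the diffusers config.

```python
# Hypothetical sketch of the registry lookup; not the actual diffusers internals.
SINGLE_FILE_LOADABLE_CLASSES = {
    "Lumina2Transformer2DModel": {
        "checkpoint_mapping_fn": lambda checkpoint, **kwargs: checkpoint,  # stand-in
        "default_subfolder": "transformer",
    },
}

def resolve_single_file_entry(model_class_name: str):
    entry = SINGLE_FILE_LOADABLE_CLASSES.get(model_class_name)
    if entry is None:
        raise ValueError(f"Single-file loading is not supported for {model_class_name}")
    return entry["checkpoint_mapping_fn"], entry.get("default_subfolder")

mapping_fn, subfolder = resolve_single_file_entry("Lumina2Transformer2DModel")
assert subfolder == "transformer"
```
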
@@ -116,6 +116,7 @@ CHECKPOINT_KEY_NAMES = {
    "mochi-1-preview": ["model.diffusion_model.blocks.0.attn.qkv_x.weight", "blocks.0.attn.qkv_x.weight"],
    "hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
    "instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight",
    "lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"],
}

DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -174,6 +175,7 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
    "mochi-1-preview": {"pretrained_model_name_or_path": "genmo/mochi-1-preview"},
    "hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"},
    "instruct-pix2pix": {"pretrained_model_name_or_path": "timbrooks/instruct-pix2pix"},
    "lumina2": {"pretrained_model_name_or_path": "Alpha-VLLM/Lumina-Image-2.0"},
}

# Use to configure model sample size when original config is provided
@@ -657,6 +659,9 @@ def infer_diffusers_model_type(checkpoint):
    ):
        model_type = "instruct-pix2pix"

    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["lumina2"]):
        model_type = "lumina2"

    else:
        model_type = "v1"
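
Taken together, here is a self-contained sketch (trimmed to the lumina2 entries above; the helper function is ours) of how checkpoint keys drive both the type inference and the choice of hub repo used for configs. The two key spellings cover raw exports and ComfyUI-style exports that keep the `model.diffusion_model.` prefix.

```python
from typing import Optional

CHECKPOINT_KEY_NAMES = {
    "lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"],
}
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
    "lumina2": {"pretrained_model_name_or_path": "Alpha-VLLM/Lumina-Image-2.0"},
}

def lumina2_config_repo(checkpoint: dict) -> Optional[str]:
    # Mirrors the `any(key in checkpoint ...)` test in infer_diffusers_model_type.
    if any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["lumina2"]):
        return DIFFUSERS_DEFAULT_PIPELINE_PATHS["lumina2"]["pretrained_model_name_or_path"]
    return None

assert lumina2_config_repo({"cap_embedder.0.weight": 0}) == "Alpha-VLLM/Lumina-Image-2.0"
assert lumina2_config_repo({"model.diffusion_model.cap_embedder.0.weight": 0}) == "Alpha-VLLM/Lumina-Image-2.0"
assert lumina2_config_repo({"unrelated.weight": 0}) is None
```
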
@@ -2798,3 +2803,75 @@ def convert_auraflow_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
    converted_state_dict["pos_embed.proj.bias"] = checkpoint.pop("init_x_linear.bias")

    return converted_state_dict


def convert_lumina2_to_diffusers(checkpoint, **kwargs):
    converted_state_dict = {}

    # Original Lumina-Image-2 has an extra norm parameter that is unused
    # We just remove it here
    checkpoint.pop("norm_final.weight", None)

    # Comfy checkpoints add this prefix
    keys = list(checkpoint.keys())
    for k in keys:
        if "model.diffusion_model." in k:
            checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)

    LUMINA_KEY_MAP = {
        "cap_embedder": "time_caption_embed.caption_embedder",
        "t_embedder.mlp.0": "time_caption_embed.timestep_embedder.linear_1",
        "t_embedder.mlp.2": "time_caption_embed.timestep_embedder.linear_2",
        "attention": "attn",
        ".out.": ".to_out.0.",
        "k_norm": "norm_k",
        "q_norm": "norm_q",
        "w1": "linear_1",
        "w2": "linear_2",
        "w3": "linear_3",
        "adaLN_modulation.1": "norm1.linear",
    }
    ATTENTION_NORM_MAP = {
        "attention_norm1": "norm1.norm",
        "attention_norm2": "norm2",
    }
    CONTEXT_REFINER_MAP = {
        "context_refiner.0.attention_norm1": "context_refiner.0.norm1",
        "context_refiner.0.attention_norm2": "context_refiner.0.norm2",
        "context_refiner.1.attention_norm1": "context_refiner.1.norm1",
        "context_refiner.1.attention_norm2": "context_refiner.1.norm2",
    }
    FINAL_LAYER_MAP = {
        "final_layer.adaLN_modulation.1": "norm_out.linear_1",
        "final_layer.linear": "norm_out.linear_2",
    }
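    # Note on ordering: in the loop below, the fully qualified context-refiner and
    # final-layer maps are applied before the generic attention-norm and key maps,
    # so e.g. "context_refiner.0.attention_norm1" is rewritten as a whole and never
    # caught by the broader "attention_norm1" -> "norm1.norm" substitution.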
    def convert_lumina_attn_to_diffusers(tensor, diffusers_key):
        q_dim = 2304
        k_dim = v_dim = 768

        to_q, to_k, to_v = torch.split(tensor, [q_dim, k_dim, v_dim], dim=0)

        return {
            diffusers_key.replace("qkv", "to_q"): to_q,
            diffusers_key.replace("qkv", "to_k"): to_k,
            diffusers_key.replace("qkv", "to_v"): to_v,
        }

    # Re-list the keys here: the Comfy prefix stripping above may have renamed
    # entries, so the `keys` list captured earlier can be stale.
    for key in list(checkpoint.keys()):
        diffusers_key = key
        for k, v in CONTEXT_REFINER_MAP.items():
            diffusers_key = diffusers_key.replace(k, v)
        for k, v in FINAL_LAYER_MAP.items():
            diffusers_key = diffusers_key.replace(k, v)
        for k, v in ATTENTION_NORM_MAP.items():
            diffusers_key = diffusers_key.replace(k, v)
        for k, v in LUMINA_KEY_MAP.items():
            diffusers_key = diffusers_key.replace(k, v)

        if "qkv" in diffusers_key:
            converted_state_dict.update(convert_lumina_attn_to_diffusers(checkpoint.pop(key), diffusers_key))
        else:
            converted_state_dict[diffusers_key] = checkpoint.pop(key)

    return converted_state_dict
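
As a quick sanity check of the qkv split in `convert_lumina_attn_to_diffusers`, the snippet below reproduces the 2304/768/768 row split on a toy tensor. The asymmetry is consistent with a grouped-query-attention layout (more query rows than key/value rows), though the exact head counts are not stated in this diff.

```python
import torch

# Toy fused qkv weight: rows are [q | k | v] stacked along dim 0.
fused_qkv = torch.randn(2304 + 768 + 768, 16)
to_q, to_k, to_v = torch.split(fused_qkv, [2304, 768, 768], dim=0)

assert to_q.shape == (2304, 16)
assert to_k.shape == (768, 16) and to_v.shape == (768, 16)
# Concatenating the pieces restores the original tensor exactly.
assert torch.equal(torch.cat([to_q, to_k, to_v], dim=0), fused_qkv)
```
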
@@ -21,6 +21,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import logging
from ..attention import LuminaFeedForward
from ..attention_processor import Attention
@@ -333,7 +334,7 @@ class Lumina2RotaryPosEmbed(nn.Module):
    )


-class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
    r"""
    Lumina2NextDiT: Diffusion model with a Transformer backbone.
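
Adding `FromOriginalModelMixin` to the base classes is what exposes the `from_single_file` classmethod used throughout this change; for illustration (checkpoint URL taken from the tests below):

```python
import torch
from diffusers import Lumina2Transformer2DModel

# Works because FromOriginalModelMixin is now among the class's bases.
transformer = Lumina2Transformer2DModel.from_single_file(
    "https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/blob/main/split_files/diffusion_models/lumina_2_model_bf16.safetensors",
    torch_dtype=torch.bfloat16,
)
```
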
tests/single_file/test_lumina2_transformer.py (new file, +74 lines)
@@ -0,0 +1,74 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import unittest

import torch

from diffusers import (
    Lumina2Transformer2DModel,
)
from diffusers.utils.testing_utils import (
    backend_empty_cache,
    enable_full_determinism,
    require_torch_accelerator,
    torch_device,
)


enable_full_determinism()


@require_torch_accelerator
class Lumina2Transformer2DModelSingleFileTests(unittest.TestCase):
    model_class = Lumina2Transformer2DModel
    ckpt_path = "https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/blob/main/split_files/diffusion_models/lumina_2_model_bf16.safetensors"
    alternate_keys_ckpt_paths = [
        "https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/blob/main/split_files/diffusion_models/lumina_2_model_bf16.safetensors"
    ]

    repo_id = "Alpha-VLLM/Lumina-Image-2.0"

    def setUp(self):
        super().setUp()
        gc.collect()
        backend_empty_cache(torch_device)

    def tearDown(self):
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)

    def test_single_file_components(self):
        model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer")
        model_single_file = self.model_class.from_single_file(self.ckpt_path)

        PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
        for param_name, param_value in model_single_file.config.items():
            if param_name in PARAMS_TO_IGNORE:
                continue
            assert (
                model.config[param_name] == param_value
            ), f"{param_name} differs between single file loading and pretrained loading"

    def test_checkpoint_loading(self):
        for ckpt_path in self.alternate_keys_ckpt_paths:
            torch.cuda.empty_cache()
            model = self.model_class.from_single_file(ckpt_path)

            del model
            gc.collect()
            torch.cuda.empty_cache()