mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-06 12:34:13 +08:00
* Flux IP-Adapter * test cfg * make style * temp remove copied from * fix test * fix test * v2 * fix * make style * temp remove copied from * Apply suggestions from code review Co-authored-by: YiYi Xu <yixu310@gmail.com> * Move encoder_hid_proj to inside FluxTransformer2DModel * merge * separate encode_prompt, add copied from, image_encoder offload * make * fix test * fix * Update src/diffusers/pipelines/flux/pipeline_flux.py * test_flux_prompt_embeds change not needed * true_cfg -> true_cfg_scale * fix merge conflict * test_flux_ip_adapter_inference * add fast test * FluxIPAdapterMixin not test mixin * Update pipeline_flux.py Co-authored-by: YiYi Xu <yixu310@gmail.com> --------- Co-authored-by: YiYi Xu <yixu310@gmail.com>
98 lines
3.6 KiB
Python
98 lines
3.6 KiB
Python
import argparse
|
|
from contextlib import nullcontext
|
|
|
|
import safetensors.torch
|
|
from accelerate import init_empty_weights
|
|
from huggingface_hub import hf_hub_download
|
|
|
|
from diffusers.utils.import_utils import is_accelerate_available, is_transformers_available
|
|
|
|
|
|
if is_transformers_available():
|
|
from transformers import CLIPVisionModelWithProjection
|
|
|
|
vision = True
|
|
else:
|
|
vision = False
|
|
|
|
"""
|
|
python scripts/convert_flux_xlabs_ipadapter_to_diffusers.py \
|
|
--original_state_dict_repo_id "XLabs-AI/flux-ip-adapter" \
|
|
--filename "flux-ip-adapter.safetensors"
|
|
--output_path "flux-ip-adapter-hf/"
|
|
"""
|
|
|
|
|
|
CTX = init_empty_weights if is_accelerate_available else nullcontext
|
|
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
|
|
parser.add_argument("--filename", default="flux.safetensors", type=str)
|
|
parser.add_argument("--checkpoint_path", default=None, type=str)
|
|
parser.add_argument("--output_path", type=str)
|
|
parser.add_argument("--vision_pretrained_or_path", default="openai/clip-vit-large-patch14", type=str)
|
|
|
|
args = parser.parse_args()
|
|
|
|
|
|
def load_original_checkpoint(args):
|
|
if args.original_state_dict_repo_id is not None:
|
|
ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename)
|
|
elif args.checkpoint_path is not None:
|
|
ckpt_path = args.checkpoint_path
|
|
else:
|
|
raise ValueError(" please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")
|
|
|
|
original_state_dict = safetensors.torch.load_file(ckpt_path)
|
|
return original_state_dict
|
|
|
|
|
|
def convert_flux_ipadapter_checkpoint_to_diffusers(original_state_dict, num_layers):
|
|
converted_state_dict = {}
|
|
|
|
# image_proj
|
|
## norm
|
|
converted_state_dict["image_proj.norm.weight"] = original_state_dict.pop("ip_adapter_proj_model.norm.weight")
|
|
converted_state_dict["image_proj.norm.bias"] = original_state_dict.pop("ip_adapter_proj_model.norm.bias")
|
|
## proj
|
|
converted_state_dict["image_proj.proj.weight"] = original_state_dict.pop("ip_adapter_proj_model.norm.weight")
|
|
converted_state_dict["image_proj.proj.bias"] = original_state_dict.pop("ip_adapter_proj_model.norm.bias")
|
|
|
|
# double transformer blocks
|
|
for i in range(num_layers):
|
|
block_prefix = f"ip_adapter.{i}."
|
|
# to_k_ip
|
|
converted_state_dict[f"{block_prefix}to_k_ip.bias"] = original_state_dict.pop(
|
|
f"double_blocks.{i}.processor.ip_adapter_double_stream_k_proj.bias"
|
|
)
|
|
converted_state_dict[f"{block_prefix}to_k_ip.weight"] = original_state_dict.pop(
|
|
f"double_blocks.{i}.processor.ip_adapter_double_stream_k_proj.weight"
|
|
)
|
|
# to_v_ip
|
|
converted_state_dict[f"{block_prefix}to_v_ip.bias"] = original_state_dict.pop(
|
|
f"double_blocks.{i}.processor.ip_adapter_double_stream_v_proj.bias"
|
|
)
|
|
converted_state_dict[f"{block_prefix}to_k_ip.weight"] = original_state_dict.pop(
|
|
f"double_blocks.{i}.processor.ip_adapter_double_stream_v_proj.weight"
|
|
)
|
|
|
|
return converted_state_dict
|
|
|
|
|
|
def main(args):
|
|
original_ckpt = load_original_checkpoint(args)
|
|
|
|
num_layers = 19
|
|
converted_ip_adapter_state_dict = convert_flux_ipadapter_checkpoint_to_diffusers(original_ckpt, num_layers)
|
|
|
|
print("Saving Flux IP-Adapter in Diffusers format.")
|
|
safetensors.torch.save_file(converted_ip_adapter_state_dict, f"{args.output_path}/model.safetensors")
|
|
|
|
if vision:
|
|
model = CLIPVisionModelWithProjection.from_pretrained(args.vision_pretrained_or_path)
|
|
model.save_pretrained(f"{args.output_path}/image_encoder")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main(args)
|