Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-10 22:44:38 +08:00)

Comparing `add-uv-scr...hidream-si` (7 commits)
Commits:

- 4e2c346992
- 5de8cb7ddb
- 406a656d9a
- 971794454f
- 4436f54a80
- 2a19302e65
- 721375bb81
The HiDream-I1 model documentation gains a section on loading GGUF quantized checkpoints:

````diff
@@ -21,6 +21,22 @@ from diffusers import HiDreamImageTransformer2DModel
 transformer = HiDreamImageTransformer2DModel.from_pretrained("HiDream-ai/HiDream-I1-Full", subfolder="transformer", torch_dtype=torch.bfloat16)
 ```
+
+## Loading GGUF quantized checkpoints for HiDream-I1
+
+GGUF checkpoints for the `HiDreamImageTransformer2DModel` can be loaded using `~FromOriginalModelMixin.from_single_file`
+
+```python
+import torch
+from diffusers import GGUFQuantizationConfig, HiDreamImageTransformer2DModel
+
+ckpt_path = "https://huggingface.co/city96/HiDream-I1-Dev-gguf/blob/main/hidream-i1-dev-Q2_K.gguf"
+transformer = HiDreamImageTransformer2DModel.from_single_file(
+    ckpt_path,
+    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
+    torch_dtype=torch.bfloat16
+)
+```
+
 ## HiDreamImageTransformer2DModel
 
 [[autodoc]] HiDreamImageTransformer2DModel
````
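To go from the quantized transformer to an end-to-end pipeline, the following sketch combines the docs snippet with the component names from this PR's pipeline docstring (`tokenizer_4`, `text_encoder_4`). It is a minimal sketch, assuming `HiDream-ai/HiDream-I1-Dev` hosts the remaining Diffusers-format components and that the gated `meta-llama/Meta-Llama-3.1-8B-Instruct` repo is accessible; the pipeline call itself is not part of this diff.

```python
import torch
from transformers import AutoTokenizer, LlamaForCausalLM
from diffusers import GGUFQuantizationConfig, HiDreamImagePipeline, HiDreamImageTransformer2DModel

# GGUF-quantized transformer, loaded exactly as in the docs section above.
transformer = HiDreamImageTransformer2DModel.from_single_file(
    "https://huggingface.co/city96/HiDream-I1-Dev-gguf/blob/main/hidream-i1-dev-Q2_K.gguf",
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)

# The Llama text encoder lives outside the diffusers repo and is passed in explicitly,
# mirroring the pipeline's example docstring.
tokenizer_4 = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
text_encoder_4 = LlamaForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    output_hidden_states=True,
    torch_dtype=torch.bfloat16,
)

pipe = HiDreamImagePipeline.from_pretrained(
    "HiDream-ai/HiDream-I1-Dev",
    transformer=transformer,
    tokenizer_4=tokenizer_4,
    text_encoder_4=text_encoder_4,
    torch_dtype=torch.bfloat16,
).to("cuda")

image = pipe("a photo of an astronaut riding a horse").images[0]
```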
In the single-file loader, the new HiDream conversion function is imported:

```diff
@@ -31,6 +31,7 @@ from .single_file_utils import (
     convert_autoencoder_dc_checkpoint_to_diffusers,
     convert_controlnet_checkpoint,
     convert_flux_transformer_checkpoint_to_diffusers,
+    convert_hidream_transformer_to_diffusers,
     convert_hunyuan_video_transformer_to_diffusers,
     convert_ldm_unet_checkpoint,
     convert_ldm_vae_checkpoint,
```
`HiDreamImageTransformer2DModel` is registered as single-file loadable, with its conversion function and default config subfolder:

```diff
@@ -133,6 +134,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
         "checkpoint_mapping_fn": convert_wan_vae_to_diffusers,
         "default_subfolder": "vae",
     },
+    "HiDreamImageTransformer2DModel": {
+        "checkpoint_mapping_fn": convert_hidream_transformer_to_diffusers,
+        "default_subfolder": "transformer",
+    },
 }
 
 
```
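For orientation, here is a minimal self-contained sketch of how such a registry entry is typically consumed: `from_single_file` looks up the target class name, runs the registered `checkpoint_mapping_fn` over the raw state dict, and uses `default_subfolder` to locate the config. The helper below is illustrative only, not the actual diffusers implementation.

```python
from typing import Any, Dict, Tuple

# Illustrative registry with the same shape as SINGLE_FILE_LOADABLE_CLASSES.
REGISTRY: Dict[str, Dict[str, Any]] = {
    "HiDreamImageTransformer2DModel": {
        # Stand-in for convert_hidream_transformer_to_diffusers.
        "checkpoint_mapping_fn": lambda checkpoint, **kwargs: checkpoint,
        "default_subfolder": "transformer",
    },
}


def resolve_single_file(class_name: str, checkpoint: Dict[str, Any]) -> Tuple[Dict[str, Any], str]:
    """Apply the registered conversion and report where the config lives."""
    entry = REGISTRY[class_name]
    converted = entry["checkpoint_mapping_fn"](checkpoint)
    return converted, entry["default_subfolder"]


state_dict, subfolder = resolve_single_file("HiDreamImageTransformer2DModel", {})
print(subfolder)  # transformer
```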
A state-dict key that identifies HiDream checkpoints is added to `CHECKPOINT_KEY_NAMES`:

```diff
@@ -126,6 +126,7 @@ CHECKPOINT_KEY_NAMES = {
     ],
     "wan": ["model.diffusion_model.head.modulation", "head.modulation"],
     "wan_vae": "decoder.middle.0.residual.0.gamma",
+    "hidream": "double_stream_blocks.0.block.adaLN_modulation.1.bias",
 }
 
 DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
```
`DIFFUSERS_DEFAULT_PIPELINE_PATHS` maps the detected model type to a default Diffusers-format repository used to fetch the config:

```diff
@@ -190,6 +191,7 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
     "wan-t2v-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"},
     "wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"},
     "wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"},
+    "hidream": {"pretrained_model_name_or_path": "HiDream-ai/HiDream-I1-Dev"},
 }
 
 # Use to configure model sample size when original config is provided
```
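Together with `default_subfolder="transformer"` from the loader registry, this entry means a raw HiDream checkpoint resolves its config to the `transformer` subfolder of `HiDream-ai/HiDream-I1-Dev`. A small sketch of that resolution, assuming network access and using `ConfigMixin.load_config`, which the model inherits:

```python
from diffusers import HiDreamImageTransformer2DModel

# Sketch: fetch the diffusers-format config from the default repo and subfolder that the
# single-file path falls back to once the model type is detected as "hidream".
config = HiDreamImageTransformer2DModel.load_config(
    "HiDream-ai/HiDream-I1-Dev", subfolder="transformer"
)
print(sorted(config)[:5])  # a peek at the config keys
```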
`infer_diffusers_model_type` learns to recognize HiDream checkpoints by that key:

```diff
@@ -701,6 +703,8 @@ def infer_diffusers_model_type(checkpoint):
     elif CHECKPOINT_KEY_NAMES["wan_vae"] in checkpoint:
         # All Wan models use the same VAE so we can use the same default model repo to fetch the config
         model_type = "wan-t2v-14B"
+    elif CHECKPOINT_KEY_NAMES["hidream"] in checkpoint:
+        model_type = "hidream"
     else:
         model_type = "v1"
 
```
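Detection is a plain membership test against the raw state-dict keys: if `double_stream_blocks.0.block.adaLN_modulation.1.bias` is present, the checkpoint is treated as HiDream. A toy illustration, keeping only the branches relevant here and stubbing out tensor values:

```python
CHECKPOINT_KEY_NAMES = {"hidream": "double_stream_blocks.0.block.adaLN_modulation.1.bias"}


def infer_model_type(checkpoint):
    # Simplified version of infer_diffusers_model_type, reduced to the hidream branch
    # and the fallback.
    if CHECKPOINT_KEY_NAMES["hidream"] in checkpoint:
        return "hidream"
    return "v1"


fake_hidream = {"double_stream_blocks.0.block.adaLN_modulation.1.bias": None}
assert infer_model_type(fake_hidream) == "hidream"
assert infer_model_type({"some.other.key": None}) == "v1"
```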
The conversion function itself is a key rename that strips the original `model.diffusion_model.` prefix:

```diff
@@ -3293,3 +3297,12 @@ def convert_wan_vae_to_diffusers(checkpoint, **kwargs):
         converted_state_dict[key] = value
 
     return converted_state_dict
+
+
+def convert_hidream_transformer_to_diffusers(checkpoint, **kwargs):
+    keys = list(checkpoint.keys())
+    for k in keys:
+        if "model.diffusion_model." in k:
+            checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
+
+    return checkpoint
```
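Since the function body is short, its behaviour can be demonstrated on a toy state dict: keys carrying the `model.diffusion_model.` prefix are renamed in place, and everything else passes through untouched.

```python
# Toy demonstration on a fake state dict; the function is the one added above.
def convert_hidream_transformer_to_diffusers(checkpoint, **kwargs):
    keys = list(checkpoint.keys())
    for k in keys:
        if "model.diffusion_model." in k:
            checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)

    return checkpoint


ckpt = {
    "model.diffusion_model.double_stream_blocks.0.block.adaLN_modulation.1.bias": 0.0,
    "caption_projection.0.linear.weight": 1.0,  # hypothetical key without the prefix
}
converted = convert_hidream_transformer_to_diffusers(ckpt)
assert "double_stream_blocks.0.block.adaLN_modulation.1.bias" in converted
assert "model.diffusion_model.double_stream_blocks.0.block.adaLN_modulation.1.bias" not in converted
assert "caption_projection.0.linear.weight" in converted  # untouched
```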
In the HiDream transformer module, `FromOriginalModelMixin` is imported:

```diff
@@ -5,7 +5,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
-from ...loaders import PeftAdapterMixin
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...models.modeling_outputs import Transformer2DModelOutput
 from ...models.modeling_utils import ModelMixin
 from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
```
and added to the model's base classes:

```diff
@@ -602,7 +602,7 @@ class HiDreamBlock(nn.Module):
     )
 
 
-class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
     _supports_gradient_checkpointing = True
     _no_split_modules = ["HiDreamImageTransformerBlock", "HiDreamImageSingleTransformerBlock"]
 
```
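Adding `FromOriginalModelMixin` to the base classes is what surfaces `from_single_file` on the model. A quick sanity check, assuming a diffusers build that includes this change:

```python
from diffusers import HiDreamImageTransformer2DModel
from diffusers.loaders import FromOriginalModelMixin

# The mixin contributes the from_single_file entry point used throughout this PR.
assert issubclass(HiDreamImageTransformer2DModel, FromOriginalModelMixin)
assert hasattr(HiDreamImageTransformer2DModel, "from_single_file")
```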
The pipeline's example docstring switches from `PreTrainedTokenizerFast` to `AutoTokenizer` and drops the `UniPCMultistepScheduler` import:

````diff
@@ -36,11 +36,11 @@ EXAMPLE_DOC_STRING = """
     Examples:
         ```py
         >>> import torch
-        >>> from transformers import PreTrainedTokenizerFast, LlamaForCausalLM
-        >>> from diffusers import UniPCMultistepScheduler, HiDreamImagePipeline
+        >>> from transformers import AutoTokenizer, LlamaForCausalLM
+        >>> from diffusers import HiDreamImagePipeline
 
 
-        >>> tokenizer_4 = PreTrainedTokenizerFast.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
+        >>> tokenizer_4 = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
         >>> text_encoder_4 = LlamaForCausalLM.from_pretrained(
         ...     "meta-llama/Meta-Llama-3.1-8B-Instruct",
         ...     output_hidden_states=True,
````
The GGUF test suite imports the model:

```diff
@@ -12,6 +12,7 @@ from diffusers import (
     FluxPipeline,
     FluxTransformer2DModel,
     GGUFQuantizationConfig,
+    HiDreamImageTransformer2DModel,
     SD3Transformer2DModel,
     StableDiffusion3Pipeline,
 )
```
and adds a single-file test case with dummy inputs matching the transformer's expected shapes:

```diff
@@ -549,3 +550,30 @@ class FluxControlLoRAGGUFTests(unittest.TestCase):
 
         max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
         self.assertTrue(max_diff < 1e-3)
+
+
+class HiDreamGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
+    ckpt_path = "https://huggingface.co/city96/HiDream-I1-Dev-gguf/blob/main/hidream-i1-dev-Q2_K.gguf"
+    torch_dtype = torch.bfloat16
+    model_cls = HiDreamImageTransformer2DModel
+    expected_memory_use_in_gb = 8
+
+    def get_dummy_inputs(self):
+        return {
+            "hidden_states": torch.randn((1, 16, 128, 128), generator=torch.Generator("cpu").manual_seed(0)).to(
+                torch_device, self.torch_dtype
+            ),
+            "encoder_hidden_states_t5": torch.randn(
+                (1, 128, 4096),
+                generator=torch.Generator("cpu").manual_seed(0),
+            ).to(torch_device, self.torch_dtype),
+            "encoder_hidden_states_llama3": torch.randn(
+                (32, 1, 128, 4096),
+                generator=torch.Generator("cpu").manual_seed(0),
+            ).to(torch_device, self.torch_dtype),
+            "pooled_embeds": torch.randn(
+                (1, 2048),
+                generator=torch.Generator("cpu").manual_seed(0),
+            ).to(torch_device, self.torch_dtype),
+            "timesteps": torch.tensor([1]).to(torch_device, self.torch_dtype),
+        }
```
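The dummy-input shapes above document the transformer's interface. Below is a sketch of the kind of smoke test the mixin runs, written standalone with the same keyword names and shapes; the mixin's exact assertions are not shown in this diff, and a GPU with enough memory (the test expects roughly 8 GB) plus the `gguf` dependency are assumed.

```python
import torch
from diffusers import GGUFQuantizationConfig, HiDreamImageTransformer2DModel

torch_device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16

# Load the quantized transformer from the same checkpoint the test targets.
model = HiDreamImageTransformer2DModel.from_single_file(
    "https://huggingface.co/city96/HiDream-I1-Dev-gguf/blob/main/hidream-i1-dev-Q2_K.gguf",
    quantization_config=GGUFQuantizationConfig(compute_dtype=dtype),
    torch_dtype=dtype,
).to(torch_device)

# Same keyword names and shapes as get_dummy_inputs above.
inputs = {
    "hidden_states": torch.randn(1, 16, 128, 128, dtype=dtype, device=torch_device),
    "encoder_hidden_states_t5": torch.randn(1, 128, 4096, dtype=dtype, device=torch_device),
    "encoder_hidden_states_llama3": torch.randn(32, 1, 128, 4096, dtype=dtype, device=torch_device),
    "pooled_embeds": torch.randn(1, 2048, dtype=dtype, device=torch_device),
    "timesteps": torch.tensor([1], dtype=dtype, device=torch_device),
}
with torch.no_grad():
    out = model(**inputs)  # a Transformer2DModelOutput, per the module's imports
```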