Compare commits

...

7 Commits

Author SHA1 Message Date
Dhruv Nair
4e2c346992 update 2025-05-13 16:08:01 +02:00
Dhruv Nair
5de8cb7ddb update 2025-05-13 15:51:09 +02:00
Dhruv Nair
406a656d9a update 2025-05-13 14:10:56 +02:00
DN6
971794454f update 2025-05-13 15:44:14 +05:30
DN6
4436f54a80 update 2025-05-13 14:56:12 +05:30
DN6
2a19302e65 update 2025-05-13 14:50:13 +05:30
DN6
721375bb81 update 2025-05-13 14:48:20 +05:30
6 changed files with 67 additions and 5 deletions

View File

@@ -21,6 +21,22 @@ from diffusers import HiDreamImageTransformer2DModel
transformer = HiDreamImageTransformer2DModel.from_pretrained("HiDream-ai/HiDream-I1-Full", subfolder="transformer", torch_dtype=torch.bfloat16)
```
## Loading GGUF quantized checkpoints for HiDream-I1
GGUF checkpoints for the `HiDreamImageTransformer2DModel` can be loaded using `~FromOriginalModelMixin.from_single_file`
```python
import torch
from diffusers import GGUFQuantizationConfig, HiDreamImageTransformer2DModel
ckpt_path = "https://huggingface.co/city96/HiDream-I1-Dev-gguf/blob/main/hidream-i1-dev-Q2_K.gguf"
transformer = HiDreamImageTransformer2DModel.from_single_file(
ckpt_path,
quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
torch_dtype=torch.bfloat16
)
```
## HiDreamImageTransformer2DModel
[[autodoc]] HiDreamImageTransformer2DModel

View File

@@ -31,6 +31,7 @@ from .single_file_utils import (
convert_autoencoder_dc_checkpoint_to_diffusers,
convert_controlnet_checkpoint,
convert_flux_transformer_checkpoint_to_diffusers,
convert_hidream_transformer_to_diffusers,
convert_hunyuan_video_transformer_to_diffusers,
convert_ldm_unet_checkpoint,
convert_ldm_vae_checkpoint,
@@ -133,6 +134,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
"checkpoint_mapping_fn": convert_wan_vae_to_diffusers,
"default_subfolder": "vae",
},
"HiDreamImageTransformer2DModel": {
"checkpoint_mapping_fn": convert_hidream_transformer_to_diffusers,
"default_subfolder": "transformer",
},
}

View File

@@ -126,6 +126,7 @@ CHECKPOINT_KEY_NAMES = {
],
"wan": ["model.diffusion_model.head.modulation", "head.modulation"],
"wan_vae": "decoder.middle.0.residual.0.gamma",
"hidream": "double_stream_blocks.0.block.adaLN_modulation.1.bias",
}
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -190,6 +191,7 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
"wan-t2v-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"},
"wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"},
"wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"},
"hidream": {"pretrained_model_name_or_path": "HiDream-ai/HiDream-I1-Dev"},
}
# Use to configure model sample size when original config is provided
@@ -701,6 +703,8 @@ def infer_diffusers_model_type(checkpoint):
elif CHECKPOINT_KEY_NAMES["wan_vae"] in checkpoint:
# All Wan models use the same VAE so we can use the same default model repo to fetch the config
model_type = "wan-t2v-14B"
elif CHECKPOINT_KEY_NAMES["hidream"] in checkpoint:
model_type = "hidream"
else:
model_type = "v1"
@@ -3293,3 +3297,12 @@ def convert_wan_vae_to_diffusers(checkpoint, **kwargs):
converted_state_dict[key] = value
return converted_state_dict
def convert_hidream_transformer_to_diffusers(checkpoint, **kwargs):
keys = list(checkpoint.keys())
for k in keys:
if "model.diffusion_model." in k:
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
return checkpoint

View File

@@ -5,7 +5,7 @@ import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...models.modeling_outputs import Transformer2DModelOutput
from ...models.modeling_utils import ModelMixin
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
@@ -602,7 +602,7 @@ class HiDreamBlock(nn.Module):
)
class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
_supports_gradient_checkpointing = True
_no_split_modules = ["HiDreamImageTransformerBlock", "HiDreamImageSingleTransformerBlock"]

View File

@@ -36,11 +36,11 @@ EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from transformers import PreTrainedTokenizerFast, LlamaForCausalLM
>>> from diffusers import UniPCMultistepScheduler, HiDreamImagePipeline
>>> from transformers import AutoTokenizer, LlamaForCausalLM
>>> from diffusers import HiDreamImagePipeline
>>> tokenizer_4 = PreTrainedTokenizerFast.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
>>> tokenizer_4 = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
>>> text_encoder_4 = LlamaForCausalLM.from_pretrained(
... "meta-llama/Meta-Llama-3.1-8B-Instruct",
... output_hidden_states=True,

View File

@@ -12,6 +12,7 @@ from diffusers import (
FluxPipeline,
FluxTransformer2DModel,
GGUFQuantizationConfig,
HiDreamImageTransformer2DModel,
SD3Transformer2DModel,
StableDiffusion3Pipeline,
)
@@ -549,3 +550,30 @@ class FluxControlLoRAGGUFTests(unittest.TestCase):
max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
self.assertTrue(max_diff < 1e-3)
class HiDreamGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
ckpt_path = "https://huggingface.co/city96/HiDream-I1-Dev-gguf/blob/main/hidream-i1-dev-Q2_K.gguf"
torch_dtype = torch.bfloat16
model_cls = HiDreamImageTransformer2DModel
expected_memory_use_in_gb = 8
def get_dummy_inputs(self):
return {
"hidden_states": torch.randn((1, 16, 128, 128), generator=torch.Generator("cpu").manual_seed(0)).to(
torch_device, self.torch_dtype
),
"encoder_hidden_states_t5": torch.randn(
(1, 128, 4096),
generator=torch.Generator("cpu").manual_seed(0),
).to(torch_device, self.torch_dtype),
"encoder_hidden_states_llama3": torch.randn(
(32, 1, 128, 4096),
generator=torch.Generator("cpu").manual_seed(0),
).to(torch_device, self.torch_dtype),
"pooled_embeds": torch.randn(
(1, 2048),
generator=torch.Generator("cpu").manual_seed(0),
).to(torch_device, self.torch_dtype),
"timesteps": torch.tensor([1]).to(torch_device, self.torch_dtype),
}