mirror of
https://github.com/huggingface/diffusers.git
synced 2026-03-20 07:28:13 +08:00
Compare commits
3 Commits
main
...
type-hint-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d90a4dfe57 | ||
|
|
5219182752 | ||
|
|
189491a4f2 |
@@ -31,7 +31,41 @@ from diffusers.modular_pipelines import (
|
||||
WanModularPipeline,
|
||||
)
|
||||
|
||||
from ..testing_utils import nightly, require_torch, slow
|
||||
from ..testing_utils import nightly, require_torch, require_torch_accelerator, slow, torch_device
|
||||
|
||||
|
||||
def _create_tiny_model_dir(model_dir):
|
||||
TINY_MODEL_CODE = (
|
||||
"import torch\n"
|
||||
"from diffusers import ModelMixin, ConfigMixin\n"
|
||||
"from diffusers.configuration_utils import register_to_config\n"
|
||||
"\n"
|
||||
"class TinyModel(ModelMixin, ConfigMixin):\n"
|
||||
" @register_to_config\n"
|
||||
" def __init__(self, hidden_size=4):\n"
|
||||
" super().__init__()\n"
|
||||
" self.linear = torch.nn.Linear(hidden_size, hidden_size)\n"
|
||||
"\n"
|
||||
" def forward(self, x):\n"
|
||||
" return self.linear(x)\n"
|
||||
)
|
||||
|
||||
with open(os.path.join(model_dir, "modeling.py"), "w") as f:
|
||||
f.write(TINY_MODEL_CODE)
|
||||
|
||||
config = {
|
||||
"_class_name": "TinyModel",
|
||||
"_diffusers_version": "0.0.0",
|
||||
"auto_map": {"AutoModel": "modeling.TinyModel"},
|
||||
"hidden_size": 4,
|
||||
}
|
||||
with open(os.path.join(model_dir, "config.json"), "w") as f:
|
||||
json.dump(config, f)
|
||||
|
||||
torch.save(
|
||||
{"linear.weight": torch.randn(4, 4), "linear.bias": torch.randn(4)},
|
||||
os.path.join(model_dir, "diffusion_pytorch_model.bin"),
|
||||
)
|
||||
|
||||
|
||||
class DummyCustomBlockSimple(ModularPipelineBlocks):
|
||||
@@ -341,6 +375,81 @@ class TestModularCustomBlocks:
|
||||
loaded_pipe.update_components(custom_model=custom_model)
|
||||
assert getattr(loaded_pipe, "custom_model", None) is not None
|
||||
|
||||
def test_automodel_type_hint_preserves_torch_dtype(self, tmp_path):
|
||||
"""Regression test for #13271: torch_dtype was incorrectly removed when type_hint is AutoModel."""
|
||||
from diffusers import AutoModel
|
||||
|
||||
model_dir = str(tmp_path / "model")
|
||||
os.makedirs(model_dir)
|
||||
_create_tiny_model_dir(model_dir)
|
||||
|
||||
class DtypeTestBlock(ModularPipelineBlocks):
|
||||
@property
|
||||
def expected_components(self):
|
||||
return [ComponentSpec("model", AutoModel, pretrained_model_name_or_path=model_dir)]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [InputParam("prompt", type_hint=str, required=True)]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return []
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [OutputParam("output", type_hint=str)]
|
||||
|
||||
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
block_state.output = "test"
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
block = DtypeTestBlock()
|
||||
pipe = block.init_pipeline()
|
||||
pipe.load_components(torch_dtype=torch.float16, trust_remote_code=True)
|
||||
|
||||
assert pipe.model.dtype == torch.float16
|
||||
|
||||
@require_torch_accelerator
|
||||
def test_automodel_type_hint_preserves_device(self, tmp_path):
|
||||
"""Test that ComponentSpec with AutoModel type_hint correctly passes device_map."""
|
||||
from diffusers import AutoModel
|
||||
|
||||
model_dir = str(tmp_path / "model")
|
||||
os.makedirs(model_dir)
|
||||
_create_tiny_model_dir(model_dir)
|
||||
|
||||
class DeviceTestBlock(ModularPipelineBlocks):
|
||||
@property
|
||||
def expected_components(self):
|
||||
return [ComponentSpec("model", AutoModel, pretrained_model_name_or_path=model_dir)]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [InputParam("prompt", type_hint=str, required=True)]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return []
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [OutputParam("output", type_hint=str)]
|
||||
|
||||
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
block_state.output = "test"
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
block = DeviceTestBlock()
|
||||
pipe = block.init_pipeline()
|
||||
pipe.load_components(device_map=torch_device, trust_remote_code=True)
|
||||
|
||||
assert pipe.model.device.type == torch_device
|
||||
|
||||
def test_custom_block_loads_from_hub(self):
|
||||
repo_id = "hf-internal-testing/tiny-modular-diffusers-block"
|
||||
block = ModularPipelineBlocks.from_pretrained(repo_id, trust_remote_code=True)
|
||||
|
||||
Reference in New Issue
Block a user