Compare commits

..

1 Commits

Author SHA1 Message Date
DN6
5c99566bab update 2026-03-01 12:46:45 +05:30
6 changed files with 176 additions and 50 deletions

View File

@@ -54,6 +54,7 @@ jobs:
python -m pip install --upgrade pip
pip install -U setuptools wheel twine
pip install -U torch --index-url https://download.pytorch.org/whl/cpu
pip install -U transformers
- name: Build the dist files
run: python setup.py bdist_wheel && python setup.py sdist
@@ -68,8 +69,6 @@ jobs:
run: |
pip install diffusers && pip uninstall diffusers -y
pip install -i https://test.pypi.org/simple/ diffusers
pip install -U transformers
python utils/print_env.py
python -c "from diffusers import __version__; print(__version__)"
python -c "from diffusers import DiffusionPipeline; pipe = DiffusionPipeline.from_pretrained('fusing/unet-ldm-dummy-update'); pipe()"
python -c "from diffusers import DiffusionPipeline; pipe = DiffusionPipeline.from_pretrained('hf-internal-testing/tiny-stable-diffusion-pipe', safety_checker=None); pipe('ah suh du')"

View File

@@ -648,6 +648,28 @@ class ConfigMixin:
)
return config_file
@classmethod
def _get_dataclass_from_config(cls, config_dict: dict[str, Any]):
sig = inspect.signature(cls.__init__)
fields = []
for name, param in sig.parameters.items():
if name == "self" or name == "kwargs" or name in cls.ignore_for_config:
continue
annotation = param.annotation if param.annotation is not inspect.Parameter.empty else Any
if param.default is not inspect.Parameter.empty:
fields.append((name, annotation, dataclasses.field(default=param.default)))
else:
fields.append((name, annotation))
dc_cls = dataclasses.make_dataclass(
f"{cls.__name__}Config",
fields,
frozen=True,
)
valid_fields = {f.name for f in dataclasses.fields(dc_cls)}
init_kwargs = {k: v for k, v in config_dict.items() if k in valid_fields}
return dc_cls(**init_kwargs)
def register_to_config(init):
r"""

View File

@@ -14,6 +14,7 @@
# limitations under the License.
import random
import tempfile
import numpy as np
import PIL
@@ -128,16 +129,18 @@ class TestFluxImg2ImgModularPipelineFast(ModularPipelineTesterMixin):
return inputs
def test_save_from_pretrained(self, tmp_path):
def test_save_from_pretrained(self):
pipes = []
base_pipe = self.get_pipeline().to(torch_device)
pipes.append(base_pipe)
base_pipe.save_pretrained(tmp_path)
pipe = ModularPipeline.from_pretrained(tmp_path).to(torch_device)
pipe.load_components(torch_dtype=torch.float32)
pipe.to(torch_device)
pipe.image_processor = VaeImageProcessor(vae_scale_factor=2)
with tempfile.TemporaryDirectory() as tmpdirname:
base_pipe.save_pretrained(tmpdirname)
pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device)
pipe.load_components(torch_dtype=torch.float32)
pipe.to(torch_device)
pipe.image_processor = VaeImageProcessor(vae_scale_factor=2)
pipes.append(pipe)
@@ -209,16 +212,18 @@ class TestFluxKontextModularPipelineFast(ModularPipelineTesterMixin):
return inputs
def test_save_from_pretrained(self, tmp_path):
def test_save_from_pretrained(self):
pipes = []
base_pipe = self.get_pipeline().to(torch_device)
pipes.append(base_pipe)
base_pipe.save_pretrained(tmp_path)
pipe = ModularPipeline.from_pretrained(tmp_path).to(torch_device)
pipe.load_components(torch_dtype=torch.float32)
pipe.to(torch_device)
pipe.image_processor = VaeImageProcessor(vae_scale_factor=2)
with tempfile.TemporaryDirectory() as tmpdirname:
base_pipe.save_pretrained(tmpdirname)
pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device)
pipe.load_components(torch_dtype=torch.float32)
pipe.to(torch_device)
pipe.image_processor = VaeImageProcessor(vae_scale_factor=2)
pipes.append(pipe)

View File

@@ -1,5 +1,7 @@
import gc
import json
import os
import tempfile
from typing import Callable
import pytest
@@ -328,15 +330,16 @@ class ModularPipelineTesterMixin:
assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
def test_save_from_pretrained(self, tmp_path):
def test_save_from_pretrained(self):
pipes = []
base_pipe = self.get_pipeline().to(torch_device)
pipes.append(base_pipe)
base_pipe.save_pretrained(tmp_path)
pipe = ModularPipeline.from_pretrained(tmp_path).to(torch_device)
pipe.load_components(torch_dtype=torch.float32)
pipe.to(torch_device)
with tempfile.TemporaryDirectory() as tmpdirname:
base_pipe.save_pretrained(tmpdirname)
pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device)
pipe.load_components(torch_dtype=torch.float32)
pipe.to(torch_device)
pipes.append(pipe)
@@ -348,31 +351,32 @@ class ModularPipelineTesterMixin:
assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
def test_modular_index_consistency(self, tmp_path):
def test_modular_index_consistency(self):
pipe = self.get_pipeline()
components_spec = pipe._component_specs
components = sorted(components_spec.keys())
pipe.save_pretrained(tmp_path)
index_file = tmp_path / "modular_model_index.json"
assert index_file.exists()
with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir)
index_file = os.path.join(tmpdir, "modular_model_index.json")
assert os.path.exists(index_file)
with open(index_file) as f:
index_contents = json.load(f)
with open(index_file) as f:
index_contents = json.load(f)
compulsory_keys = {"_blocks_class_name", "_class_name", "_diffusers_version"}
for k in compulsory_keys:
assert k in index_contents
compulsory_keys = {"_blocks_class_name", "_class_name", "_diffusers_version"}
for k in compulsory_keys:
assert k in index_contents
to_check_attrs = {"pretrained_model_name_or_path", "revision", "subfolder"}
for component in components:
spec = components_spec[component]
for attr in to_check_attrs:
if getattr(spec, "pretrained_model_name_or_path", None) is not None:
for attr in to_check_attrs:
assert component in index_contents, f"{component} should be present in index but isn't."
attr_value_from_index = index_contents[component][2][attr]
assert getattr(spec, attr) == attr_value_from_index
to_check_attrs = {"pretrained_model_name_or_path", "revision", "subfolder"}
for component in components:
spec = components_spec[component]
for attr in to_check_attrs:
if getattr(spec, "pretrained_model_name_or_path", None) is not None:
for attr in to_check_attrs:
assert component in index_contents, f"{component} should be present in index but isn't."
attr_value_from_index = index_contents[component][2][attr]
assert getattr(spec, attr) == attr_value_from_index
def test_workflow_map(self):
blocks = self.pipeline_blocks_class()

View File

@@ -14,6 +14,7 @@
import json
import os
import tempfile
from collections import deque
from typing import List
@@ -152,24 +153,25 @@ class TestModularCustomBlocks:
output_prompt = output.values["output_prompt"]
assert output_prompt.startswith("Modular diffusers + ")
def test_custom_block_saving_loading(self, tmp_path):
def test_custom_block_saving_loading(self):
custom_block = DummyCustomBlockSimple()
custom_block.save_pretrained(tmp_path)
assert any("modular_config.json" in k for k in os.listdir(tmp_path))
with tempfile.TemporaryDirectory() as tmpdir:
custom_block.save_pretrained(tmpdir)
assert any("modular_config.json" in k for k in os.listdir(tmpdir))
with open(os.path.join(tmp_path, "modular_config.json"), "r") as f:
config = json.load(f)
auto_map = config["auto_map"]
assert auto_map == {"ModularPipelineBlocks": "test_modular_pipelines_custom_blocks.DummyCustomBlockSimple"}
with open(os.path.join(tmpdir, "modular_config.json"), "r") as f:
config = json.load(f)
auto_map = config["auto_map"]
assert auto_map == {"ModularPipelineBlocks": "test_modular_pipelines_custom_blocks.DummyCustomBlockSimple"}
# For now, the Python script that implements the custom block has to be manually pushed to the Hub.
# This is why, we have to separately save the Python script here.
code_path = os.path.join(tmp_path, "test_modular_pipelines_custom_blocks.py")
with open(code_path, "w") as f:
f.write(CODE_STR)
# For now, the Python script that implements the custom block has to be manually pushed to the Hub.
# This is why, we have to separately save the Python script here.
code_path = os.path.join(tmpdir, "test_modular_pipelines_custom_blocks.py")
with open(code_path, "w") as f:
f.write(CODE_STR)
loaded_custom_block = ModularPipelineBlocks.from_pretrained(tmp_path, trust_remote_code=True)
loaded_custom_block = ModularPipelineBlocks.from_pretrained(tmpdir, trust_remote_code=True)
pipe = loaded_custom_block.init_pipeline()
prompt = "Diffusers is nice"

View File

@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import json
import tempfile
import unittest
@@ -305,3 +306,96 @@ class ConfigTester(unittest.TestCase):
result = json.loads(json_string)
assert result["test_file_1"] == config.config.test_file_1.as_posix()
assert result["test_file_2"] == config.config.test_file_2.as_posix()
class SampleObjectTyped(ConfigMixin):
config_name = "config.json"
@register_to_config
def __init__(
self,
a: int = 2,
b: int = 5,
c: str = "hello",
):
pass
class SampleObjectWithIgnore(ConfigMixin):
config_name = "config.json"
ignore_for_config = ["secret"]
@register_to_config
def __init__(
self,
a: int = 2,
secret: str = "hidden",
):
pass
class DataclassFromConfigTester(unittest.TestCase):
def test_get_dataclass_from_config_returns_frozen_dataclass(self):
obj = SampleObject()
tc = SampleObject._get_dataclass_from_config(dict(obj.config))
assert dataclasses.is_dataclass(tc)
with self.assertRaises(dataclasses.FrozenInstanceError):
tc.a = 99
def test_get_dataclass_from_config_class_name(self):
obj = SampleObject()
tc = SampleObject._get_dataclass_from_config(dict(obj.config))
assert type(tc).__name__ == "SampleObjectConfig"
def test_get_dataclass_from_config_values_match_config(self):
obj = SampleObject(a=10, b=20)
tc = SampleObject._get_dataclass_from_config(dict(obj.config))
assert tc.a == 10
assert tc.b == 20
assert tc.c == (2, 5)
assert tc.d == "for diffusion"
assert tc.e == [1, 3]
def test_get_dataclass_from_config_from_raw_dict(self):
tc = SampleObjectTyped._get_dataclass_from_config({"a": 7, "b": 3, "c": "world"})
assert tc.a == 7
assert tc.b == 3
assert tc.c == "world"
def test_get_dataclass_from_config_annotations(self):
tc = SampleObjectTyped._get_dataclass_from_config({"a": 1, "b": 2, "c": "hi"})
fields = {f.name: f.type for f in dataclasses.fields(tc)}
assert fields["a"] is int
assert fields["b"] is int
assert fields["c"] is str
def test_get_dataclass_from_config_asdict_roundtrip(self):
tc = SampleObjectTyped._get_dataclass_from_config({"a": 7, "b": 3, "c": "world"})
d = dataclasses.asdict(tc)
assert d == {"a": 7, "b": 3, "c": "world"}
def test_get_dataclass_from_config_ignores_extra_keys(self):
tc = SampleObjectTyped._get_dataclass_from_config(
{"a": 1, "b": 2, "c": "hi", "_class_name": "Foo", "extra": 99}
)
assert tc.a == 1
assert not hasattr(tc, "_class_name")
assert not hasattr(tc, "extra")
def test_get_dataclass_from_config_respects_ignore_for_config(self):
tc = SampleObjectWithIgnore._get_dataclass_from_config({"a": 5})
assert not hasattr(tc, "secret")
assert tc.a == 5
def test_get_dataclass_from_config_works_for_scheduler(self):
scheduler = DDIMScheduler()
tc = DDIMScheduler._get_dataclass_from_config(dict(scheduler.config))
assert dataclasses.is_dataclass(tc)
assert type(tc).__name__ == "DDIMSchedulerConfig"
assert tc.num_train_timesteps == scheduler.config.num_train_timesteps
def test_get_dataclass_from_config_different_values(self):
tc1 = SampleObjectTyped._get_dataclass_from_config({"a": 1, "b": 2, "c": "x"})
tc2 = SampleObjectTyped._get_dataclass_from_config({"a": 9, "b": 8, "c": "y"})
assert tc1.a == 1
assert tc2.a == 9