mirror of
https://github.com/huggingface/diffusers.git
synced 2026-03-02 06:40:40 +08:00
Compare commits
1 Commits
main
...
dataclass-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5c99566bab |
@@ -648,6 +648,28 @@ class ConfigMixin:
|
||||
)
|
||||
return config_file
|
||||
|
||||
@classmethod
|
||||
def _get_dataclass_from_config(cls, config_dict: dict[str, Any]):
|
||||
sig = inspect.signature(cls.__init__)
|
||||
fields = []
|
||||
for name, param in sig.parameters.items():
|
||||
if name == "self" or name == "kwargs" or name in cls.ignore_for_config:
|
||||
continue
|
||||
annotation = param.annotation if param.annotation is not inspect.Parameter.empty else Any
|
||||
if param.default is not inspect.Parameter.empty:
|
||||
fields.append((name, annotation, dataclasses.field(default=param.default)))
|
||||
else:
|
||||
fields.append((name, annotation))
|
||||
|
||||
dc_cls = dataclasses.make_dataclass(
|
||||
f"{cls.__name__}Config",
|
||||
fields,
|
||||
frozen=True,
|
||||
)
|
||||
valid_fields = {f.name for f in dataclasses.fields(dc_cls)}
|
||||
init_kwargs = {k: v for k, v in config_dict.items() if k in valid_fields}
|
||||
return dc_cls(**init_kwargs)
|
||||
|
||||
|
||||
def register_to_config(init):
|
||||
r"""
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import dataclasses
|
||||
import json
|
||||
import tempfile
|
||||
import unittest
|
||||
@@ -305,3 +306,96 @@ class ConfigTester(unittest.TestCase):
|
||||
result = json.loads(json_string)
|
||||
assert result["test_file_1"] == config.config.test_file_1.as_posix()
|
||||
assert result["test_file_2"] == config.config.test_file_2.as_posix()
|
||||
|
||||
|
||||
class SampleObjectTyped(ConfigMixin):
|
||||
config_name = "config.json"
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
a: int = 2,
|
||||
b: int = 5,
|
||||
c: str = "hello",
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class SampleObjectWithIgnore(ConfigMixin):
|
||||
config_name = "config.json"
|
||||
ignore_for_config = ["secret"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
a: int = 2,
|
||||
secret: str = "hidden",
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class DataclassFromConfigTester(unittest.TestCase):
|
||||
def test_get_dataclass_from_config_returns_frozen_dataclass(self):
|
||||
obj = SampleObject()
|
||||
tc = SampleObject._get_dataclass_from_config(dict(obj.config))
|
||||
assert dataclasses.is_dataclass(tc)
|
||||
with self.assertRaises(dataclasses.FrozenInstanceError):
|
||||
tc.a = 99
|
||||
|
||||
def test_get_dataclass_from_config_class_name(self):
|
||||
obj = SampleObject()
|
||||
tc = SampleObject._get_dataclass_from_config(dict(obj.config))
|
||||
assert type(tc).__name__ == "SampleObjectConfig"
|
||||
|
||||
def test_get_dataclass_from_config_values_match_config(self):
|
||||
obj = SampleObject(a=10, b=20)
|
||||
tc = SampleObject._get_dataclass_from_config(dict(obj.config))
|
||||
assert tc.a == 10
|
||||
assert tc.b == 20
|
||||
assert tc.c == (2, 5)
|
||||
assert tc.d == "for diffusion"
|
||||
assert tc.e == [1, 3]
|
||||
|
||||
def test_get_dataclass_from_config_from_raw_dict(self):
|
||||
tc = SampleObjectTyped._get_dataclass_from_config({"a": 7, "b": 3, "c": "world"})
|
||||
assert tc.a == 7
|
||||
assert tc.b == 3
|
||||
assert tc.c == "world"
|
||||
|
||||
def test_get_dataclass_from_config_annotations(self):
|
||||
tc = SampleObjectTyped._get_dataclass_from_config({"a": 1, "b": 2, "c": "hi"})
|
||||
fields = {f.name: f.type for f in dataclasses.fields(tc)}
|
||||
assert fields["a"] is int
|
||||
assert fields["b"] is int
|
||||
assert fields["c"] is str
|
||||
|
||||
def test_get_dataclass_from_config_asdict_roundtrip(self):
|
||||
tc = SampleObjectTyped._get_dataclass_from_config({"a": 7, "b": 3, "c": "world"})
|
||||
d = dataclasses.asdict(tc)
|
||||
assert d == {"a": 7, "b": 3, "c": "world"}
|
||||
|
||||
def test_get_dataclass_from_config_ignores_extra_keys(self):
|
||||
tc = SampleObjectTyped._get_dataclass_from_config(
|
||||
{"a": 1, "b": 2, "c": "hi", "_class_name": "Foo", "extra": 99}
|
||||
)
|
||||
assert tc.a == 1
|
||||
assert not hasattr(tc, "_class_name")
|
||||
assert not hasattr(tc, "extra")
|
||||
|
||||
def test_get_dataclass_from_config_respects_ignore_for_config(self):
|
||||
tc = SampleObjectWithIgnore._get_dataclass_from_config({"a": 5})
|
||||
assert not hasattr(tc, "secret")
|
||||
assert tc.a == 5
|
||||
|
||||
def test_get_dataclass_from_config_works_for_scheduler(self):
|
||||
scheduler = DDIMScheduler()
|
||||
tc = DDIMScheduler._get_dataclass_from_config(dict(scheduler.config))
|
||||
assert dataclasses.is_dataclass(tc)
|
||||
assert type(tc).__name__ == "DDIMSchedulerConfig"
|
||||
assert tc.num_train_timesteps == scheduler.config.num_train_timesteps
|
||||
|
||||
def test_get_dataclass_from_config_different_values(self):
|
||||
tc1 = SampleObjectTyped._get_dataclass_from_config({"a": 1, "b": 2, "c": "x"})
|
||||
tc2 = SampleObjectTyped._get_dataclass_from_config({"a": 9, "b": 8, "c": "y"})
|
||||
assert tc1.a == 1
|
||||
assert tc2.a == 9
|
||||
|
||||
Reference in New Issue
Block a user