mirror of
https://github.com/huggingface/diffusers.git
synced 2026-04-11 18:22:04 +08:00
Compare commits
1 Commits
make-tiny-
...
dataclass-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5c99566bab |
@@ -648,6 +648,28 @@ class ConfigMixin:
|
|||||||
)
|
)
|
||||||
return config_file
|
return config_file
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_dataclass_from_config(cls, config_dict: dict[str, Any]):
|
||||||
|
sig = inspect.signature(cls.__init__)
|
||||||
|
fields = []
|
||||||
|
for name, param in sig.parameters.items():
|
||||||
|
if name == "self" or name == "kwargs" or name in cls.ignore_for_config:
|
||||||
|
continue
|
||||||
|
annotation = param.annotation if param.annotation is not inspect.Parameter.empty else Any
|
||||||
|
if param.default is not inspect.Parameter.empty:
|
||||||
|
fields.append((name, annotation, dataclasses.field(default=param.default)))
|
||||||
|
else:
|
||||||
|
fields.append((name, annotation))
|
||||||
|
|
||||||
|
dc_cls = dataclasses.make_dataclass(
|
||||||
|
f"{cls.__name__}Config",
|
||||||
|
fields,
|
||||||
|
frozen=True,
|
||||||
|
)
|
||||||
|
valid_fields = {f.name for f in dataclasses.fields(dc_cls)}
|
||||||
|
init_kwargs = {k: v for k, v in config_dict.items() if k in valid_fields}
|
||||||
|
return dc_cls(**init_kwargs)
|
||||||
|
|
||||||
|
|
||||||
def register_to_config(init):
|
def register_to_config(init):
|
||||||
r"""
|
r"""
|
||||||
|
|||||||
@@ -13,6 +13,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
import json
|
import json
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
@@ -305,3 +306,96 @@ class ConfigTester(unittest.TestCase):
|
|||||||
result = json.loads(json_string)
|
result = json.loads(json_string)
|
||||||
assert result["test_file_1"] == config.config.test_file_1.as_posix()
|
assert result["test_file_1"] == config.config.test_file_1.as_posix()
|
||||||
assert result["test_file_2"] == config.config.test_file_2.as_posix()
|
assert result["test_file_2"] == config.config.test_file_2.as_posix()
|
||||||
|
|
||||||
|
|
||||||
|
class SampleObjectTyped(ConfigMixin):
|
||||||
|
config_name = "config.json"
|
||||||
|
|
||||||
|
@register_to_config
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
a: int = 2,
|
||||||
|
b: int = 5,
|
||||||
|
c: str = "hello",
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SampleObjectWithIgnore(ConfigMixin):
|
||||||
|
config_name = "config.json"
|
||||||
|
ignore_for_config = ["secret"]
|
||||||
|
|
||||||
|
@register_to_config
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
a: int = 2,
|
||||||
|
secret: str = "hidden",
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DataclassFromConfigTester(unittest.TestCase):
|
||||||
|
def test_get_dataclass_from_config_returns_frozen_dataclass(self):
|
||||||
|
obj = SampleObject()
|
||||||
|
tc = SampleObject._get_dataclass_from_config(dict(obj.config))
|
||||||
|
assert dataclasses.is_dataclass(tc)
|
||||||
|
with self.assertRaises(dataclasses.FrozenInstanceError):
|
||||||
|
tc.a = 99
|
||||||
|
|
||||||
|
def test_get_dataclass_from_config_class_name(self):
|
||||||
|
obj = SampleObject()
|
||||||
|
tc = SampleObject._get_dataclass_from_config(dict(obj.config))
|
||||||
|
assert type(tc).__name__ == "SampleObjectConfig"
|
||||||
|
|
||||||
|
def test_get_dataclass_from_config_values_match_config(self):
|
||||||
|
obj = SampleObject(a=10, b=20)
|
||||||
|
tc = SampleObject._get_dataclass_from_config(dict(obj.config))
|
||||||
|
assert tc.a == 10
|
||||||
|
assert tc.b == 20
|
||||||
|
assert tc.c == (2, 5)
|
||||||
|
assert tc.d == "for diffusion"
|
||||||
|
assert tc.e == [1, 3]
|
||||||
|
|
||||||
|
def test_get_dataclass_from_config_from_raw_dict(self):
|
||||||
|
tc = SampleObjectTyped._get_dataclass_from_config({"a": 7, "b": 3, "c": "world"})
|
||||||
|
assert tc.a == 7
|
||||||
|
assert tc.b == 3
|
||||||
|
assert tc.c == "world"
|
||||||
|
|
||||||
|
def test_get_dataclass_from_config_annotations(self):
|
||||||
|
tc = SampleObjectTyped._get_dataclass_from_config({"a": 1, "b": 2, "c": "hi"})
|
||||||
|
fields = {f.name: f.type for f in dataclasses.fields(tc)}
|
||||||
|
assert fields["a"] is int
|
||||||
|
assert fields["b"] is int
|
||||||
|
assert fields["c"] is str
|
||||||
|
|
||||||
|
def test_get_dataclass_from_config_asdict_roundtrip(self):
|
||||||
|
tc = SampleObjectTyped._get_dataclass_from_config({"a": 7, "b": 3, "c": "world"})
|
||||||
|
d = dataclasses.asdict(tc)
|
||||||
|
assert d == {"a": 7, "b": 3, "c": "world"}
|
||||||
|
|
||||||
|
def test_get_dataclass_from_config_ignores_extra_keys(self):
|
||||||
|
tc = SampleObjectTyped._get_dataclass_from_config(
|
||||||
|
{"a": 1, "b": 2, "c": "hi", "_class_name": "Foo", "extra": 99}
|
||||||
|
)
|
||||||
|
assert tc.a == 1
|
||||||
|
assert not hasattr(tc, "_class_name")
|
||||||
|
assert not hasattr(tc, "extra")
|
||||||
|
|
||||||
|
def test_get_dataclass_from_config_respects_ignore_for_config(self):
|
||||||
|
tc = SampleObjectWithIgnore._get_dataclass_from_config({"a": 5})
|
||||||
|
assert not hasattr(tc, "secret")
|
||||||
|
assert tc.a == 5
|
||||||
|
|
||||||
|
def test_get_dataclass_from_config_works_for_scheduler(self):
|
||||||
|
scheduler = DDIMScheduler()
|
||||||
|
tc = DDIMScheduler._get_dataclass_from_config(dict(scheduler.config))
|
||||||
|
assert dataclasses.is_dataclass(tc)
|
||||||
|
assert type(tc).__name__ == "DDIMSchedulerConfig"
|
||||||
|
assert tc.num_train_timesteps == scheduler.config.num_train_timesteps
|
||||||
|
|
||||||
|
def test_get_dataclass_from_config_different_values(self):
|
||||||
|
tc1 = SampleObjectTyped._get_dataclass_from_config({"a": 1, "b": 2, "c": "x"})
|
||||||
|
tc2 = SampleObjectTyped._get_dataclass_from_config({"a": 9, "b": 8, "c": "y"})
|
||||||
|
assert tc1.a == 1
|
||||||
|
assert tc2.a == 9
|
||||||
|
|||||||
Reference in New Issue
Block a user