Skip to content

Commit 5c5ab11

Browse files
authored
streamable: Fix default value assignments for dataclass_from_dict (#11732)
* streamable: Use constructor in `dataclass_from_dict` This fixes default value assignments after #10561 but also leads to less perfomance due to `__post_init__` being called which at least gets mitigated by #11730. * tests: Test default values with `from_json_dict` * Convert to `str`, then compare.
1 parent 1dccb68 commit 5c5ab11

File tree

2 files changed

+34
-8
lines changed

2 files changed

+34
-8
lines changed

chia/util/streamable.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def dataclass_from_dict(klass: Type[Any], item: Any) -> Any:
132132
"""
133133
if type(item) == klass:
134134
return item
135-
obj = object.__new__(klass)
135+
136136
if klass not in CONVERT_FUNCTIONS_FOR_STREAMABLE_CLASS:
137137
# For non-streamable dataclasses we can't populate the cache on startup, so we do it here for convert
138138
# functions only.
@@ -144,9 +144,13 @@ def dataclass_from_dict(klass: Type[Any], item: Any) -> Any:
144144
fields = FIELDS_FOR_STREAMABLE_CLASS[klass]
145145
convert_funcs = CONVERT_FUNCTIONS_FOR_STREAMABLE_CLASS[klass]
146146

147-
for field, convert_func in zip(fields, convert_funcs):
148-
object.__setattr__(obj, field.name, convert_func(item[field.name]))
149-
return obj
147+
return klass(
148+
**{
149+
field.name: convert_func(item[field.name])
150+
for field, convert_func in zip(fields, convert_funcs)
151+
if field.name in item
152+
}
153+
)
150154

151155

152156
def function_to_convert_one_item(f_type: Type[Any]) -> ConvertFunctionType:

tests/core/util/test_streamable.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import io
4-
from dataclasses import dataclass
4+
from dataclasses import dataclass, field
55
from typing import Any, Dict, List, Optional, Tuple, Type
66

77
import pytest
@@ -128,15 +128,15 @@ def test_pure_dataclasses_in_dataclass_from_dict() -> None:
128128
"test_class, input_dict, error",
129129
[
130130
[TestDataclassFromDict1, {"a": "asdf", "b": "2", "c": G1Element()}, ValueError],
131-
[TestDataclassFromDict1, {"a": 1, "b": "2"}, KeyError],
131+
[TestDataclassFromDict1, {"a": 1, "b": "2"}, TypeError],
132132
[TestDataclassFromDict1, {"a": 1, "b": "2", "c": "asd"}, ValueError],
133133
[TestDataclassFromDict1, {"a": 1, "b": "2", "c": "00" * G1Element.SIZE}, ValueError],
134134
[TestDataclassFromDict1, {"a": [], "b": "2", "c": G1Element()}, TypeError],
135135
[TestDataclassFromDict1, {"a": {}, "b": "2", "c": G1Element()}, TypeError],
136136
[TestDataclassFromDict2, {"a": "asdf", "b": 1.2345, "c": 1.2345}, TypeError],
137137
[TestDataclassFromDict2, {"a": 1.2345, "b": {"a": 1, "b": "2"}, "c": 1.2345}, TypeError],
138-
[TestDataclassFromDict2, {"a": {"a": 1, "b": "2", "c": G1Element()}, "b": {"a": 1, "b": "2"}}, KeyError],
139-
[TestDataclassFromDict2, {"a": {"a": 1, "b": "2"}, "b": {"a": 1, "b": "2"}, "c": 1.2345}, KeyError],
138+
[TestDataclassFromDict2, {"a": {"a": 1, "b": "2", "c": G1Element()}, "b": {"a": 1, "b": "2"}}, TypeError],
139+
[TestDataclassFromDict2, {"a": {"a": 1, "b": "2"}, "b": {"a": 1, "b": "2"}, "c": 1.2345}, TypeError],
140140
],
141141
)
142142
def test_dataclass_from_dict_failures(test_class: Type[Any], input_dict: Dict[str, Any], error: Any) -> None:
@@ -145,6 +145,28 @@ def test_dataclass_from_dict_failures(test_class: Type[Any], input_dict: Dict[st
145145
dataclass_from_dict(test_class, input_dict)
146146

147147

148+
@streamable
149+
@dataclass(frozen=True)
150+
class TestFromJsonDictDefaultValues(Streamable):
151+
a: uint64 = uint64(1)
152+
b: str = "default"
153+
c: List[uint64] = field(default_factory=list)
154+
155+
156+
@pytest.mark.parametrize(
157+
"input_dict, output_dict",
158+
[
159+
[{}, {"a": 1, "b": "default", "c": []}],
160+
[{"a": 2}, {"a": 2, "b": "default", "c": []}],
161+
[{"b": "not_default"}, {"a": 1, "b": "not_default", "c": []}],
162+
[{"c": [1, 2]}, {"a": 1, "b": "default", "c": [1, 2]}],
163+
[{"a": 2, "b": "not_default", "c": [1, 2]}, {"a": 2, "b": "not_default", "c": [1, 2]}],
164+
],
165+
)
166+
def test_from_json_dict_default_values(input_dict: Dict[str, object], output_dict: Dict[str, object]) -> None:
167+
assert str(TestFromJsonDictDefaultValues.from_json_dict(input_dict).to_json_dict()) == str(output_dict)
168+
169+
148170
def test_basic_list() -> None:
149171
a = [1, 2, 3]
150172
assert is_type_List(type(a))

0 commit comments

Comments
 (0)