Skip to content

Commit dda517e

Browse files
streamable: Cache convert functions from dataclass_from_dict (#10561)
* streamable: Cache convert functions for dict -> dataclass conversion * tests: Test `dataclass_from_dict` with non-streamable classes * `Any` -> `object` Co-authored-by: Kyle Altendorf <[email protected]> * Move comment into `dataclass_from_dict` Co-authored-by: Kyle Altendorf <[email protected]>
1 parent c400d81 commit dda517e

File tree

2 files changed

+151
-32
lines changed

2 files changed

+151
-32
lines changed

chia/util/streamable.py

Lines changed: 96 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,14 @@ class DefinitionError(StreamableError):
4646

4747
ParseFunctionType = Callable[[BinaryIO], object]
4848
StreamFunctionType = Callable[[object, BinaryIO], None]
49+
ConvertFunctionType = Callable[[object], object]
4950

5051

5152
# Caches to store the fields and (de)serialization methods for all available streamable classes.
5253
FIELDS_FOR_STREAMABLE_CLASS: Dict[Type[object], Dict[str, Type[object]]] = {}
5354
STREAM_FUNCTIONS_FOR_STREAMABLE_CLASS: Dict[Type[object], List[StreamFunctionType]] = {}
5455
PARSE_FUNCTIONS_FOR_STREAMABLE_CLASS: Dict[Type[object], List[ParseFunctionType]] = {}
56+
CONVERT_FUNCTIONS_FOR_STREAMABLE_CLASS: Dict[Type[object], List[ConvertFunctionType]] = {}
5557

5658

5759
def is_type_List(f_type: object) -> bool:
@@ -69,45 +71,105 @@ def is_type_Tuple(f_type: object) -> bool:
6971
return get_origin(f_type) == tuple or f_type == tuple
7072

7173

72-
def dataclass_from_dict(klass: Type[Any], d: Any) -> Any:
74+
def convert_optional(convert_func: ConvertFunctionType, item: Any) -> Any:
75+
if item is None:
76+
return None
77+
return convert_func(item)
78+
79+
80+
def convert_tuple(convert_funcs: List[ConvertFunctionType], items: Tuple[Any, ...]) -> Tuple[Any, ...]:
81+
tuple_data = []
82+
for i in range(len(items)):
83+
tuple_data.append(convert_funcs[i](items[i]))
84+
return tuple(tuple_data)
85+
86+
87+
def convert_list(convert_func: ConvertFunctionType, items: List[Any]) -> List[Any]:
88+
list_data = []
89+
for item in items:
90+
list_data.append(convert_func(item))
91+
return list_data
92+
93+
94+
def convert_byte_type(f_type: Type[Any], item: Any) -> Any:
95+
if type(item) == f_type:
96+
return item
97+
return f_type(hexstr_to_bytes(item))
98+
99+
100+
def convert_unhashable_type(f_type: Type[Any], item: Any) -> Any:
101+
if type(item) == f_type:
102+
return item
103+
if hasattr(f_type, "from_bytes_unchecked"):
104+
from_bytes_method = f_type.from_bytes_unchecked
105+
else:
106+
from_bytes_method = f_type.from_bytes
107+
return from_bytes_method(hexstr_to_bytes(item))
108+
109+
110+
def convert_primitive(f_type: Type[Any], item: Any) -> Any:
111+
if type(item) == f_type:
112+
return item
113+
return f_type(item)
114+
115+
116+
def dataclass_from_dict(klass: Type[Any], item: Any) -> Any:
73117
"""
74118
Converts a dictionary based on a dataclass, into an instance of that dataclass.
75119
Recursively goes through lists, optionals, and dictionaries.
76120
"""
77-
if is_type_SpecificOptional(klass):
78-
# Type is optional, data is either None, or Any
79-
if d is None:
80-
return None
81-
return dataclass_from_dict(get_args(klass)[0], d)
82-
elif is_type_Tuple(klass):
83-
# Type is tuple, can have multiple different types inside
84-
i = 0
85-
klass_properties = []
86-
for item in d:
87-
klass_properties.append(dataclass_from_dict(klass.__args__[i], item))
88-
i = i + 1
89-
return tuple(klass_properties)
90-
elif dataclasses.is_dataclass(klass):
91-
# Type is a dataclass, data is a dictionary
121+
if type(item) == klass:
122+
return item
123+
obj = object.__new__(klass)
124+
if klass not in CONVERT_FUNCTIONS_FOR_STREAMABLE_CLASS:
125+
# For non-streamable dataclasses we can't populate the cache on startup, so we do it here for convert
126+
# functions only.
127+
convert_funcs = []
92128
hints = get_type_hints(klass)
93-
fieldtypes = {f.name: hints.get(f.name, f.type) for f in dataclasses.fields(klass)}
94-
return klass(**{f: dataclass_from_dict(fieldtypes[f], d[f]) for f in d})
95-
elif is_type_List(klass):
96-
# Type is a list, data is a list
97-
return [dataclass_from_dict(get_args(klass)[0], item) for item in d]
98-
elif issubclass(klass, bytes):
99-
# Type is bytes, data is a hex string
100-
return klass(hexstr_to_bytes(d))
101-
elif klass.__name__ in unhashable_types:
129+
fields = {field.name: hints.get(field.name, field.type) for field in dataclasses.fields(klass)}
130+
131+
for _, f_type in fields.items():
132+
convert_funcs.append(function_to_convert_one_item(f_type))
133+
134+
FIELDS_FOR_STREAMABLE_CLASS[klass] = fields
135+
CONVERT_FUNCTIONS_FOR_STREAMABLE_CLASS[klass] = convert_funcs
136+
else:
137+
fields = FIELDS_FOR_STREAMABLE_CLASS[klass]
138+
convert_funcs = CONVERT_FUNCTIONS_FOR_STREAMABLE_CLASS[klass]
139+
140+
for field, convert_func in zip(fields, convert_funcs):
141+
object.__setattr__(obj, field, convert_func(item[field]))
142+
return obj
143+
144+
145+
def function_to_convert_one_item(f_type: Type[Any]) -> ConvertFunctionType:
146+
if is_type_SpecificOptional(f_type):
147+
convert_inner_func = function_to_convert_one_item(get_args(f_type)[0])
148+
return lambda item: convert_optional(convert_inner_func, item)
149+
elif is_type_Tuple(f_type):
150+
args = get_args(f_type)
151+
convert_inner_tuple_funcs = []
152+
for arg in args:
153+
convert_inner_tuple_funcs.append(function_to_convert_one_item(arg))
154+
# Ignoring for now as the proper solution isn't obvious
155+
return lambda items: convert_tuple(convert_inner_tuple_funcs, items) # type: ignore[arg-type]
156+
elif is_type_List(f_type):
157+
inner_type = get_args(f_type)[0]
158+
convert_inner_func = function_to_convert_one_item(inner_type)
159+
# Ignoring for now as the proper solution isn't obvious
160+
return lambda items: convert_list(convert_inner_func, items) # type: ignore[arg-type]
161+
elif dataclasses.is_dataclass(f_type):
162+
# Type is a dataclass, data is a dictionary
163+
return lambda item: dataclass_from_dict(f_type, item)
164+
elif issubclass(f_type, bytes):
165+
# Type is bytes, data is a hex string or bytes
166+
return lambda item: convert_byte_type(f_type, item)
167+
elif f_type.__name__ in unhashable_types:
102168
# Type is unhashable (bls type), so cast from hex string
103-
if hasattr(klass, "from_bytes_unchecked"):
104-
from_bytes_method: Callable[[bytes], Any] = klass.from_bytes_unchecked
105-
else:
106-
from_bytes_method = klass.from_bytes
107-
return from_bytes_method(hexstr_to_bytes(d))
169+
return lambda item: convert_unhashable_type(f_type, item)
108170
else:
109171
# Type is a primitive, cast with correct class
110-
return klass(d)
172+
return lambda item: convert_primitive(f_type, item)
111173

112174

113175
def recurse_jsonify(d: Any) -> Any:
@@ -295,6 +357,7 @@ class Example(Streamable):
295357

296358
stream_functions = []
297359
parse_functions = []
360+
convert_functions = []
298361
try:
299362
hints = get_type_hints(cls)
300363
fields = {field.name: hints.get(field.name, field.type) for field in dataclasses.fields(cls)}
@@ -306,9 +369,11 @@ class Example(Streamable):
306369
for _, f_type in fields.items():
307370
stream_functions.append(cls.function_to_stream_one_item(f_type))
308371
parse_functions.append(cls.function_to_parse_one_item(f_type))
372+
convert_functions.append(function_to_convert_one_item(f_type))
309373

310374
STREAM_FUNCTIONS_FOR_STREAMABLE_CLASS[cls] = stream_functions
311375
PARSE_FUNCTIONS_FOR_STREAMABLE_CLASS[cls] = parse_functions
376+
CONVERT_FUNCTIONS_FOR_STREAMABLE_CLASS[cls] = convert_functions
312377
return cls
313378

314379

tests/core/util/test_streamable.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22

33
import io
44
from dataclasses import dataclass
5-
from typing import Dict, List, Optional, Tuple
5+
from typing import Any, Dict, List, Optional, Tuple, Type
66

77
import pytest
8+
from blspy import G1Element
89
from clvm_tools import binutils
910
from typing_extensions import Literal
1011

@@ -18,6 +19,7 @@
1819
from chia.util.streamable import (
1920
DefinitionError,
2021
Streamable,
22+
dataclass_from_dict,
2123
is_type_List,
2224
is_type_SpecificOptional,
2325
parse_bool,
@@ -91,6 +93,58 @@ class TestClassPlain(Streamable):
9193
a: PlainClass
9294

9395

96+
@dataclass
97+
class TestDataclassFromDict1:
98+
a: int
99+
b: str
100+
c: G1Element
101+
102+
103+
@dataclass
104+
class TestDataclassFromDict2:
105+
a: TestDataclassFromDict1
106+
b: TestDataclassFromDict1
107+
c: float
108+
109+
110+
def test_pure_dataclasses_in_dataclass_from_dict() -> None:
111+
112+
d1_dict = {"a": 1, "b": "2", "c": str(G1Element())}
113+
114+
d1: TestDataclassFromDict1 = dataclass_from_dict(TestDataclassFromDict1, d1_dict)
115+
assert d1.a == 1
116+
assert d1.b == "2"
117+
assert d1.c == G1Element()
118+
119+
d2_dict = {"a": d1, "b": d1_dict, "c": 1.2345}
120+
121+
d2: TestDataclassFromDict2 = dataclass_from_dict(TestDataclassFromDict2, d2_dict)
122+
assert d2.a == d1
123+
assert d2.b == d1
124+
assert d2.c == 1.2345
125+
126+
127+
@pytest.mark.parametrize(
128+
"test_class, input_dict, error",
129+
[
130+
[TestDataclassFromDict1, {"a": "asdf", "b": "2", "c": G1Element()}, ValueError],
131+
[TestDataclassFromDict1, {"a": 1, "b": "2"}, KeyError],
132+
[TestDataclassFromDict1, {"a": 1, "b": "2", "c": "asd"}, ValueError],
133+
[TestDataclassFromDict1, {"a": 1, "b": "2", "c": "00" * G1Element.SIZE}, ValueError],
134+
[TestDataclassFromDict1, {"a": [], "b": "2", "c": G1Element()}, TypeError],
135+
[TestDataclassFromDict1, {"a": {}, "b": "2", "c": G1Element()}, TypeError],
136+
[TestDataclassFromDict2, {"a": "asdf", "b": 1.2345, "c": 1.2345}, TypeError],
137+
[TestDataclassFromDict2, {"a": 1.2345, "b": {"a": 1, "b": "2"}, "c": 1.2345}, TypeError],
138+
[TestDataclassFromDict2, {"a": {"a": 1, "b": "2", "c": G1Element()}, "b": {"a": 1, "b": "2"}}, KeyError],
139+
[TestDataclassFromDict2, {"a": {"a": 1, "b": "2"}, "b": {"a": 1, "b": "2"}, "c": 1.2345}, KeyError],
140+
],
141+
)
142+
def test_dataclass_from_dict_failures(test_class: Type[Any], input_dict: Dict[str, Any], error: Any) -> None:
143+
144+
with pytest.raises(error):
145+
dataclass_from_dict(test_class, input_dict)
146+
147+
94148
def test_basic_list() -> None:
95149
a = [1, 2, 3]
96150
assert is_type_List(type(a))

0 commit comments

Comments
 (0)