@@ -46,12 +46,14 @@ class DefinitionError(StreamableError):
46
46
47
47
ParseFunctionType = Callable [[BinaryIO ], object ]
48
48
StreamFunctionType = Callable [[object , BinaryIO ], None ]
49
+ ConvertFunctionType = Callable [[object ], object ]
49
50
50
51
51
52
# Caches to store the fields and (de)serialization methods for all available streamable classes.
52
53
FIELDS_FOR_STREAMABLE_CLASS : Dict [Type [object ], Dict [str , Type [object ]]] = {}
53
54
STREAM_FUNCTIONS_FOR_STREAMABLE_CLASS : Dict [Type [object ], List [StreamFunctionType ]] = {}
54
55
PARSE_FUNCTIONS_FOR_STREAMABLE_CLASS : Dict [Type [object ], List [ParseFunctionType ]] = {}
56
+ CONVERT_FUNCTIONS_FOR_STREAMABLE_CLASS : Dict [Type [object ], List [ConvertFunctionType ]] = {}
55
57
56
58
57
59
def is_type_List (f_type : object ) -> bool :
@@ -69,45 +71,105 @@ def is_type_Tuple(f_type: object) -> bool:
69
71
return get_origin (f_type ) == tuple or f_type == tuple
70
72
71
73
72
- def dataclass_from_dict (klass : Type [Any ], d : Any ) -> Any :
74
+ def convert_optional (convert_func : ConvertFunctionType , item : Any ) -> Any :
75
+ if item is None :
76
+ return None
77
+ return convert_func (item )
78
+
79
+
80
+ def convert_tuple (convert_funcs : List [ConvertFunctionType ], items : Tuple [Any , ...]) -> Tuple [Any , ...]:
81
+ tuple_data = []
82
+ for i in range (len (items )):
83
+ tuple_data .append (convert_funcs [i ](items [i ]))
84
+ return tuple (tuple_data )
85
+
86
+
87
+ def convert_list (convert_func : ConvertFunctionType , items : List [Any ]) -> List [Any ]:
88
+ list_data = []
89
+ for item in items :
90
+ list_data .append (convert_func (item ))
91
+ return list_data
92
+
93
+
94
+ def convert_byte_type (f_type : Type [Any ], item : Any ) -> Any :
95
+ if type (item ) == f_type :
96
+ return item
97
+ return f_type (hexstr_to_bytes (item ))
98
+
99
+
100
+ def convert_unhashable_type (f_type : Type [Any ], item : Any ) -> Any :
101
+ if type (item ) == f_type :
102
+ return item
103
+ if hasattr (f_type , "from_bytes_unchecked" ):
104
+ from_bytes_method = f_type .from_bytes_unchecked
105
+ else :
106
+ from_bytes_method = f_type .from_bytes
107
+ return from_bytes_method (hexstr_to_bytes (item ))
108
+
109
+
110
+ def convert_primitive (f_type : Type [Any ], item : Any ) -> Any :
111
+ if type (item ) == f_type :
112
+ return item
113
+ return f_type (item )
114
+
115
+
116
+ def dataclass_from_dict (klass : Type [Any ], item : Any ) -> Any :
73
117
"""
74
118
Converts a dictionary based on a dataclass, into an instance of that dataclass.
75
119
Recursively goes through lists, optionals, and dictionaries.
76
120
"""
77
- if is_type_SpecificOptional (klass ):
78
- # Type is optional, data is either None, or Any
79
- if d is None :
80
- return None
81
- return dataclass_from_dict (get_args (klass )[0 ], d )
82
- elif is_type_Tuple (klass ):
83
- # Type is tuple, can have multiple different types inside
84
- i = 0
85
- klass_properties = []
86
- for item in d :
87
- klass_properties .append (dataclass_from_dict (klass .__args__ [i ], item ))
88
- i = i + 1
89
- return tuple (klass_properties )
90
- elif dataclasses .is_dataclass (klass ):
91
- # Type is a dataclass, data is a dictionary
121
+ if type (item ) == klass :
122
+ return item
123
+ obj = object .__new__ (klass )
124
+ if klass not in CONVERT_FUNCTIONS_FOR_STREAMABLE_CLASS :
125
+ # For non-streamable dataclasses we can't populate the cache on startup, so we do it here for convert
126
+ # functions only.
127
+ convert_funcs = []
92
128
hints = get_type_hints (klass )
93
- fieldtypes = {f .name : hints .get (f .name , f .type ) for f in dataclasses .fields (klass )}
94
- return klass (** {f : dataclass_from_dict (fieldtypes [f ], d [f ]) for f in d })
95
- elif is_type_List (klass ):
96
- # Type is a list, data is a list
97
- return [dataclass_from_dict (get_args (klass )[0 ], item ) for item in d ]
98
- elif issubclass (klass , bytes ):
99
- # Type is bytes, data is a hex string
100
- return klass (hexstr_to_bytes (d ))
101
- elif klass .__name__ in unhashable_types :
129
+ fields = {field .name : hints .get (field .name , field .type ) for field in dataclasses .fields (klass )}
130
+
131
+ for _ , f_type in fields .items ():
132
+ convert_funcs .append (function_to_convert_one_item (f_type ))
133
+
134
+ FIELDS_FOR_STREAMABLE_CLASS [klass ] = fields
135
+ CONVERT_FUNCTIONS_FOR_STREAMABLE_CLASS [klass ] = convert_funcs
136
+ else :
137
+ fields = FIELDS_FOR_STREAMABLE_CLASS [klass ]
138
+ convert_funcs = CONVERT_FUNCTIONS_FOR_STREAMABLE_CLASS [klass ]
139
+
140
+ for field , convert_func in zip (fields , convert_funcs ):
141
+ object .__setattr__ (obj , field , convert_func (item [field ]))
142
+ return obj
143
+
144
+
145
+ def function_to_convert_one_item (f_type : Type [Any ]) -> ConvertFunctionType :
146
+ if is_type_SpecificOptional (f_type ):
147
+ convert_inner_func = function_to_convert_one_item (get_args (f_type )[0 ])
148
+ return lambda item : convert_optional (convert_inner_func , item )
149
+ elif is_type_Tuple (f_type ):
150
+ args = get_args (f_type )
151
+ convert_inner_tuple_funcs = []
152
+ for arg in args :
153
+ convert_inner_tuple_funcs .append (function_to_convert_one_item (arg ))
154
+ # Ignoring for now as the proper solution isn't obvious
155
+ return lambda items : convert_tuple (convert_inner_tuple_funcs , items ) # type: ignore[arg-type]
156
+ elif is_type_List (f_type ):
157
+ inner_type = get_args (f_type )[0 ]
158
+ convert_inner_func = function_to_convert_one_item (inner_type )
159
+ # Ignoring for now as the proper solution isn't obvious
160
+ return lambda items : convert_list (convert_inner_func , items ) # type: ignore[arg-type]
161
+ elif dataclasses .is_dataclass (f_type ):
162
+ # Type is a dataclass, data is a dictionary
163
+ return lambda item : dataclass_from_dict (f_type , item )
164
+ elif issubclass (f_type , bytes ):
165
+ # Type is bytes, data is a hex string or bytes
166
+ return lambda item : convert_byte_type (f_type , item )
167
+ elif f_type .__name__ in unhashable_types :
102
168
# Type is unhashable (bls type), so cast from hex string
103
- if hasattr (klass , "from_bytes_unchecked" ):
104
- from_bytes_method : Callable [[bytes ], Any ] = klass .from_bytes_unchecked
105
- else :
106
- from_bytes_method = klass .from_bytes
107
- return from_bytes_method (hexstr_to_bytes (d ))
169
+ return lambda item : convert_unhashable_type (f_type , item )
108
170
else :
109
171
# Type is a primitive, cast with correct class
110
- return klass ( d )
172
+ return lambda item : convert_primitive ( f_type , item )
111
173
112
174
113
175
def recurse_jsonify (d : Any ) -> Any :
@@ -295,6 +357,7 @@ class Example(Streamable):
295
357
296
358
stream_functions = []
297
359
parse_functions = []
360
+ convert_functions = []
298
361
try :
299
362
hints = get_type_hints (cls )
300
363
fields = {field .name : hints .get (field .name , field .type ) for field in dataclasses .fields (cls )}
@@ -306,9 +369,11 @@ class Example(Streamable):
306
369
for _ , f_type in fields .items ():
307
370
stream_functions .append (cls .function_to_stream_one_item (f_type ))
308
371
parse_functions .append (cls .function_to_parse_one_item (f_type ))
372
+ convert_functions .append (function_to_convert_one_item (f_type ))
309
373
310
374
STREAM_FUNCTIONS_FOR_STREAMABLE_CLASS [cls ] = stream_functions
311
375
PARSE_FUNCTIONS_FOR_STREAMABLE_CLASS [cls ] = parse_functions
376
+ CONVERT_FUNCTIONS_FOR_STREAMABLE_CLASS [cls ] = convert_functions
312
377
return cls
313
378
314
379
0 commit comments