7
7
8
8
from luisa_lang .utils import IdentityDict , check_type , is_generic_class
9
9
import luisa_lang .hir as hir
10
- from hir import PyTreeStructure
11
- from typing import Any , Callable , Dict , List , Mapping , Optional , Sequence , Tuple , Union , cast
10
+ from luisa_lang . hir import PyTreeStructure
11
+ from typing import Any , Callable , Dict , List , Mapping , Optional , Sequence , Tuple , Type , Union , cast
12
12
13
13
14
14
class Scope :
@@ -174,11 +174,11 @@ def __init__(self, node: hir.Value, scope: Scope | None = None):
174
174
175
175
176
176
class FlattenedTree :
177
- metadata : Tuple [type , Tuple [Any ], Any ] # (type, type_args, Any)
177
+ metadata : Tuple [Type [ Any ] , Tuple [Any ], Any ] # (type, type_args, Any)
178
178
children : List ["FlattenedTree" ]
179
179
180
180
def __init__ (
181
- self , metadata : Tuple [type , Tuple [Any ], Any ], children : List ["FlattenedTree" ]
181
+ self , metadata : Tuple [Type [ Any ] , Tuple [Any ], Any ], children : List ["FlattenedTree" ]
182
182
):
183
183
self .metadata = metadata
184
184
self .children = children
@@ -342,13 +342,13 @@ def tree_flatten(obj: Any, allow_non_pytree_objects: bool) -> FlattenedTree:
342
342
if isinstance (obj , PyTree ):
343
343
return obj ._flatten ()
344
344
if allow_non_pytree_objects and not PyTreeRegistry .is_registered (type (obj )):
345
- return FlattenedTree ((type (obj ), tuple (), obj ), [])
345
+ return FlattenedTree ((type (obj ), cast ( Tuple [ Any , ...], tuple () ), obj ), [])
346
346
flatten_func , _ = PyTreeRegistry .get (type (obj ))
347
347
return flatten_func (obj )
348
348
349
349
350
350
def tree_unflatten (obj : FlattenedTree , allow_non_pytree_objects : bool ) -> Any :
351
- typ = obj .metadata [0 ]
351
+ typ : Type [ Any ] = obj .metadata [0 ]
352
352
if issubclass (typ , JitVar ):
353
353
_type_args , v = obj .metadata [1 :]
354
354
assert isinstance (v , JitVar )
@@ -409,7 +409,7 @@ def is_registered(typ: type) -> bool:
409
409
@staticmethod
410
410
def __register_default_types () -> None :
411
411
def flatten_primitive (obj : Any ) -> FlattenedTree :
412
- return FlattenedTree ((type (obj ), tuple (), obj ), [])
412
+ return FlattenedTree ((type (obj ), cast ( Tuple [ Any , ...], tuple () ), obj ), [])
413
413
414
414
def unflatten_primitive (tree : FlattenedTree ) -> Any :
415
415
assert len (tree .children ) == 0
@@ -423,7 +423,7 @@ def unflatten_primitive(tree: FlattenedTree) -> Any:
423
423
424
424
def flatten_list (obj : List [Any ]) -> FlattenedTree :
425
425
return FlattenedTree (
426
- (list , tuple (), None ), [tree_flatten (o , True ) for o in obj ]
426
+ (list , cast ( Tuple [ Any , ...], tuple () ), None ), [tree_flatten (o , True ) for o in obj ]
427
427
)
428
428
429
429
def unflatten_list (tree : FlattenedTree ) -> List [Any ]:
@@ -435,7 +435,7 @@ def unflatten_list(tree: FlattenedTree) -> List[Any]:
435
435
436
436
def flatten_tuple (obj : Tuple [Any , ...]) -> FlattenedTree :
437
437
return FlattenedTree (
438
- (tuple , tuple (), None ), [tree_flatten (o , True ) for o in obj ]
438
+ (tuple , cast ( Tuple [ Any , ...], tuple () ), None ), [tree_flatten (o , True ) for o in obj ]
439
439
)
440
440
441
441
def unflatten_tuple (tree : FlattenedTree ) -> Tuple [Any , ...]:
@@ -447,14 +447,15 @@ def unflatten_tuple(tree: FlattenedTree) -> Tuple[Any, ...]:
447
447
448
448
def flatten_dict (obj : Dict [Any , Any ]) -> FlattenedTree :
449
449
return FlattenedTree (
450
- (dict , tuple (), (len (obj .keys ()))),
450
+ (dict , cast ( Tuple [ Any , ...], tuple () ), (len (obj .keys ()))),
451
451
[tree_flatten (k , True ) for k in obj .keys ()]
452
452
+ [tree_flatten (v , True ) for v in obj .values ()],
453
453
)
454
454
455
455
def unflatten_dict (tree : FlattenedTree ) -> Dict [Any , Any ]:
456
456
assert tree .metadata [0 ] is dict
457
- length = tree .metadata [1 ]
457
+ length = tree .metadata [2 ][0 ]
458
+ assert isinstance (length , int ), "Invalid length for dict unflattening"
458
459
assert len (tree .children ) == length * 2
459
460
keys = tree .children [:length ]
460
461
values = tree .children [length :]
@@ -561,7 +562,7 @@ class ControlFlowFrame:
561
562
parent : Optional ["ControlFlowFrame" ]
562
563
is_static : bool
563
564
564
- def __init__ (self , parent : Optional ["ControlFlowFrame" ]):
565
+ def __init__ (self , * args , parent : Optional ["ControlFlowFrame" ]):
565
566
self .parent = parent
566
567
self .is_static = False
567
568
@@ -590,7 +591,7 @@ class IfFrame(ControlFlowFrame):
590
591
false_bb : Optional [hir .BasicBlock ]
591
592
592
593
def __init__ (self , cond : Any , parent : ControlFlowFrame ):
593
- super ().__init__ (parent )
594
+ super ().__init__ (parent = parent )
594
595
self .cond = cond
595
596
self .is_static = not isinstance (cond , JitVar )
596
597
self .static_cond = bool (cond ) if self .is_static else None
@@ -712,7 +713,7 @@ class TraceContext:
712
713
top_level_func : Optional [hir .Function ]
713
714
714
715
def __init__ (self , is_top_level ):
715
- self .cf_frame = ControlFlowFrame (None )
716
+ self .cf_frame = ControlFlowFrame (parent = None )
716
717
self .is_top_level = is_top_level
717
718
self .top_level_func = None
718
719
0 commit comments