Skip to content

Commit 146d46d

Browse files
committed
fixed mypy errors
1 parent 33c399c commit 146d46d

File tree

4 files changed

+17
-17
lines changed

4 files changed

+17
-17
lines changed

luisa_lang/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,4 @@
66

77
from luisa_lang.lang import *
88
from luisa_lang.lang_builtins import *
9-
109
bool = boolean

luisa_lang/ast_rewrite.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -474,7 +474,7 @@ def rewrite_function[F: Callable[..., Any]](f: F, decorator_name: str) -> F:
474474
tree, filename = retrieve_ast_and_filename(f)
475475
tree = FuncRewriter(decorator_name, filename).visit(tree)
476476
ast.fix_missing_locations(tree)
477-
print(ast.unparse(tree))
477+
# print(ast.unparse(tree))
478478
code = compile(tree, filename="<ast>", mode="exec")
479479
local_dict: dict[Any, Any] = {}
480480
exec(code, f.__globals__, local_dict)

luisa_lang/lang_builtins.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import types
22

3-
from _builtin_decor import func
3+
from luisa_lang._builtin_decor import func
44
from luisa_lang.math_types import *
55
from luisa_lang.core_types import Ref
66
import luisa_lang.hir as hir

luisa_lang/lang_runtime.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77

88
from luisa_lang.utils import IdentityDict, check_type, is_generic_class
99
import luisa_lang.hir as hir
10-
from hir import PyTreeStructure
11-
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union, cast
10+
from luisa_lang.hir import PyTreeStructure
11+
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union, cast
1212

1313

1414
class Scope:
@@ -174,11 +174,11 @@ def __init__(self, node: hir.Value, scope: Scope | None = None):
174174

175175

176176
class FlattenedTree:
177-
metadata: Tuple[type, Tuple[Any], Any] # (type, type_args, Any)
177+
metadata: Tuple[Type[Any], Tuple[Any], Any] # (type, type_args, Any)
178178
children: List["FlattenedTree"]
179179

180180
def __init__(
181-
self, metadata: Tuple[type, Tuple[Any], Any], children: List["FlattenedTree"]
181+
self, metadata: Tuple[Type[Any], Tuple[Any], Any], children: List["FlattenedTree"]
182182
):
183183
self.metadata = metadata
184184
self.children = children
@@ -342,13 +342,13 @@ def tree_flatten(obj: Any, allow_non_pytree_objects: bool) -> FlattenedTree:
342342
if isinstance(obj, PyTree):
343343
return obj._flatten()
344344
if allow_non_pytree_objects and not PyTreeRegistry.is_registered(type(obj)):
345-
return FlattenedTree((type(obj), tuple(), obj), [])
345+
return FlattenedTree((type(obj), cast(Tuple[Any, ...], tuple()), obj), [])
346346
flatten_func, _ = PyTreeRegistry.get(type(obj))
347347
return flatten_func(obj)
348348

349349

350350
def tree_unflatten(obj: FlattenedTree, allow_non_pytree_objects: bool) -> Any:
351-
typ = obj.metadata[0]
351+
typ: Type[Any] = obj.metadata[0]
352352
if issubclass(typ, JitVar):
353353
_type_args, v = obj.metadata[1:]
354354
assert isinstance(v, JitVar)
@@ -409,7 +409,7 @@ def is_registered(typ: type) -> bool:
409409
@staticmethod
410410
def __register_default_types() -> None:
411411
def flatten_primitive(obj: Any) -> FlattenedTree:
412-
return FlattenedTree((type(obj), tuple(), obj), [])
412+
return FlattenedTree((type(obj), cast(Tuple[Any, ...], tuple()), obj), [])
413413

414414
def unflatten_primitive(tree: FlattenedTree) -> Any:
415415
assert len(tree.children) == 0
@@ -423,7 +423,7 @@ def unflatten_primitive(tree: FlattenedTree) -> Any:
423423

424424
def flatten_list(obj: List[Any]) -> FlattenedTree:
425425
return FlattenedTree(
426-
(list, tuple(), None), [tree_flatten(o, True) for o in obj]
426+
(list, cast(Tuple[Any, ...], tuple()), None), [tree_flatten(o, True) for o in obj]
427427
)
428428

429429
def unflatten_list(tree: FlattenedTree) -> List[Any]:
@@ -435,7 +435,7 @@ def unflatten_list(tree: FlattenedTree) -> List[Any]:
435435

436436
def flatten_tuple(obj: Tuple[Any, ...]) -> FlattenedTree:
437437
return FlattenedTree(
438-
(tuple, tuple(), None), [tree_flatten(o, True) for o in obj]
438+
(tuple, cast(Tuple[Any, ...], tuple()), None), [tree_flatten(o, True) for o in obj]
439439
)
440440

441441
def unflatten_tuple(tree: FlattenedTree) -> Tuple[Any, ...]:
@@ -447,14 +447,15 @@ def unflatten_tuple(tree: FlattenedTree) -> Tuple[Any, ...]:
447447

448448
def flatten_dict(obj: Dict[Any, Any]) -> FlattenedTree:
449449
return FlattenedTree(
450-
(dict, tuple(), (len(obj.keys()))),
450+
(dict, cast(Tuple[Any, ...], tuple()), (len(obj.keys()))),
451451
[tree_flatten(k, True) for k in obj.keys()]
452452
+ [tree_flatten(v, True) for v in obj.values()],
453453
)
454454

455455
def unflatten_dict(tree: FlattenedTree) -> Dict[Any, Any]:
456456
assert tree.metadata[0] is dict
457-
length = tree.metadata[1]
457+
length = tree.metadata[2][0]
458+
assert isinstance(length, int), "Invalid length for dict unflattening"
458459
assert len(tree.children) == length * 2
459460
keys = tree.children[:length]
460461
values = tree.children[length:]
@@ -561,7 +562,7 @@ class ControlFlowFrame:
561562
parent: Optional["ControlFlowFrame"]
562563
is_static: bool
563564

564-
def __init__(self, parent: Optional["ControlFlowFrame"]):
565+
def __init__(self, *args, parent: Optional["ControlFlowFrame"]):
565566
self.parent = parent
566567
self.is_static = False
567568

@@ -590,7 +591,7 @@ class IfFrame(ControlFlowFrame):
590591
false_bb: Optional[hir.BasicBlock]
591592

592593
def __init__(self, cond: Any, parent: ControlFlowFrame):
593-
super().__init__(parent)
594+
super().__init__(parent=parent)
594595
self.cond = cond
595596
self.is_static = not isinstance(cond, JitVar)
596597
self.static_cond = bool(cond) if self.is_static else None
@@ -712,7 +713,7 @@ class TraceContext:
712713
top_level_func: Optional[hir.Function]
713714

714715
def __init__(self, is_top_level):
715-
self.cf_frame = ControlFlowFrame(None)
716+
self.cf_frame = ControlFlowFrame(parent=None)
716717
self.is_top_level = is_top_level
717718
self.top_level_func = None
718719

0 commit comments

Comments
 (0)