Skip to content

Commit dfbdf3b

Browse files
committed
good
1 parent c87945b commit dfbdf3b

File tree

4 files changed

+184
-20
lines changed

4 files changed

+184
-20
lines changed

luisa_lang/codegen/cpp.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -415,8 +415,10 @@ def impl() -> None:
415415
ty = self.base.type_cache.gen(expr.type)
416416
self.body.writeln(
417417
f"{ty} v{vid}{{ {','.join(self.gen_expr(e) for e in expr.args)} }};")
418-
case hir.Intrinsic() as intrin:
418+
case hir.Intrinsic() as intrin:
419419
def do():
420+
assert intrin.type
421+
intrin_ty_s = self.base.type_cache.gen(intrin.type)
420422
intrin_name = intrin.name
421423
comps = intrin_name.split('.')
422424
gened_args = [self.gen_value_or_ref(
@@ -426,6 +428,12 @@ def do():
426428
ty = self.base.type_cache.gen(expr.type)
427429
self.body.writeln(
428430
f"{ty} v{vid}{{ {','.join(gened_args)} }};")
431+
elif comps[0] == 'cast':
432+
self.body.writeln(
433+
f"auto v{vid} = static_cast<{intrin_ty_s}>({gened_args[0]});")
434+
elif comps[0] == 'bitcast':
435+
self.body.writeln(
436+
f"auto v{vid} = lc_bit_cast<{intrin_ty_s}>({gened_args[0]});")
429437
elif comps[0] == 'cmp':
430438
cmp_dict = {
431439
'__eq__': '==',
@@ -592,11 +600,19 @@ def gen_node(self, node: hir.Node) -> Optional[hir.BasicBlock]:
592600
ty = self.base.type_cache.gen(alloca.type.remove_ref())
593601
self.body.writeln(f"{ty} v{vid}{{}}; // alloca")
594602
self.node_map[alloca] = f"v{vid}"
595-
case hir.AggregateInit() | hir.Intrinsic() | hir.Call() | hir.Constant() | hir.Load() | hir.Index() | hir.Member() | hir.TypeValue() | hir.FunctionValue():
603+
case hir.Print() as print_stmt:
604+
raise NotImplementedError("print statement")
605+
case hir.Assert() as assert_stmt:
606+
raise NotImplementedError("assert statement")
607+
case hir.AggregateInit() | hir.Intrinsic() | hir.Call() | hir.Constant() | hir.Load() | hir.Index() | hir.Member() | hir.TypeValue() | hir.FunctionValue() | hir.VarValue():
596608
if isinstance(node, hir.TypedNode) and node.is_ref():
597609
pass
598610
else:
599611
self.gen_expr(node)
612+
case hir.VarRef():
613+
pass
614+
case _:
615+
raise NotImplementedError(f"unsupported node: {node}")
600616
return None
601617

602618
def gen_bb(self, bb: hir.BasicBlock):

luisa_lang/hir.py

Lines changed: 75 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def method(self, name: str) -> Optional[Union["Function", FunctionTemplate]]:
142142

143143
def is_concrete(self) -> bool:
144144
return True
145-
145+
146146
def is_addressable(self) -> bool:
147147
return True
148148

@@ -163,8 +163,10 @@ class RefType(Type):
163163

164164
def __init__(self, element: Type) -> None:
165165
super().__init__()
166-
assert element.is_addressable(), f"RefType element {element} is not addressable"
167-
assert not isinstance(element, (OpaqueType, RefType, FunctionType,TypeConstructorType))
166+
assert element.is_addressable(), f"RefType element {
167+
element} is not addressable"
168+
assert not isinstance(
169+
element, (OpaqueType, RefType, FunctionType, TypeConstructorType))
168170
self.element = element
169171
self.methods = element.methods
170172

@@ -189,18 +191,19 @@ def member(self, field: Any) -> Optional['Type']:
189191
ty = self.element.member(field)
190192
if ty is None:
191193
return None
192-
if isinstance(ty,FunctionType):
194+
if isinstance(ty, FunctionType):
193195
return ty
194196
return RefType(ty)
195197

196198
@override
197199
def method(self, name: str) -> Optional[Union["Function", FunctionTemplate]]:
198200
return self.element.method(name)
199-
201+
200202
@override
201203
def is_addressable(self) -> bool:
202204
return False
203205

206+
204207
class LiteralType(Type):
205208
value: Any
206209

@@ -221,7 +224,7 @@ def is_concrete(self) -> bool:
221224
@override
222225
def is_addressable(self) -> bool:
223226
return False
224-
227+
225228
def __eq__(self, value: object) -> bool:
226229
return isinstance(value, LiteralType) and value.value == self.value
227230

@@ -349,6 +352,7 @@ def is_concrete(self) -> bool:
349352
def is_addressable(self) -> bool:
350353
return False
351354

355+
352356
class GenericIntType(ScalarType):
353357
@override
354358
def __eq__(self, value: object) -> bool:
@@ -382,6 +386,7 @@ def is_concrete(self) -> bool:
382386
def is_addressable(self) -> bool:
383387
return False
384388

389+
385390
class FloatType(ScalarType):
386391
bits: int
387392

@@ -695,6 +700,7 @@ def __repr__(self) -> str:
695700
def __str__(self) -> str:
696701
return f"~{self.name}@{self.ctx_name}"
697702

703+
698704
class OpaqueType(Type):
699705
name: str
700706
extra_args: List[Any]
@@ -722,7 +728,7 @@ def __str__(self) -> str:
722728
@override
723729
def is_concrete(self) -> bool:
724730
return False
725-
731+
726732
@override
727733
def is_addressable(self) -> bool:
728734
return False
@@ -800,18 +806,19 @@ def __eq__(self, value: object) -> bool:
800806

801807
def __hash__(self) -> int:
802808
return hash((ParametricType, tuple(self.params), self.body))
803-
809+
804810
def __str__(self) -> str:
805811
return f"{self.body}[{', '.join(str(p) for p in self.params)}]"
806812

807813
@override
808814
def is_concrete(self) -> bool:
809815
return self.body.is_concrete()
810-
816+
811817
@override
812818
def is_addressable(self) -> bool:
813819
return self.body.is_addressable()
814820

821+
815822
class BoundType(Type):
816823
"""
817824
An instance of a parametric type, e.g. Foo[int]
@@ -841,7 +848,7 @@ def __eq__(self, value: object) -> bool:
841848

842849
def __hash__(self):
843850
return hash((BoundType, self.generic, tuple(self.args)))
844-
851+
845852
def __str__(self) -> str:
846853
return f"{self.generic}[{', '.join(str(a) for a in self.args)}]"
847854

@@ -862,11 +869,12 @@ def method(self, name) -> Optional[Union["Function", FunctionTemplate]]:
862869
@override
863870
def is_addressable(self) -> bool:
864871
return self.generic.is_addressable()
865-
872+
866873
@override
867874
def is_concrete(self) -> bool:
868875
return self.generic.is_concrete()
869876

877+
870878
class TypeConstructorType(Type):
871879
inner: Type
872880

@@ -910,6 +918,7 @@ def size(self) -> int:
910918
def align(self) -> int:
911919
raise RuntimeError("FunctionType has no align")
912920

921+
913922
class Node:
914923
"""
915924
Base class for all nodes in the HIR. A node could be a value, a reference, or a statement.
@@ -999,13 +1008,15 @@ def __init__(
9991008
self.name = name
10001009
self.semantic = semantic
10011010

1011+
10021012
class VarValue(Value):
10031013
var: Var
10041014

10051015
def __init__(self, var: Var, span: Optional[Span]) -> None:
10061016
super().__init__(var.type, span)
10071017
self.var = var
10081018

1019+
10091020
class VarRef(Value):
10101021
var: Var
10111022

@@ -1155,6 +1166,8 @@ def __str__(self) -> str:
11551166
return f"Template matching error:\n\t{self.message}"
11561167
return f"Template matching error at {self.span}:\n\t{self.message}"
11571168

1169+
class ComptimeCallStack:
1170+
pass
11581171

11591172
class SpannedError(Exception):
11601173
span: Span | None
@@ -1200,7 +1213,8 @@ class Assign(Node):
12001213
value: Value
12011214

12021215
def __init__(self, ref: Value, value: Value, span: Optional[Span] = None) -> None:
1203-
assert not isinstance(value.type, (FunctionType, TypeConstructorType, RefType))
1216+
assert not isinstance(
1217+
value.type, (FunctionType, TypeConstructorType, RefType))
12041218
if not isinstance(ref.type, RefType):
12051219
raise ParsingError(
12061220
ref, f"cannot assign to a non-reference variable")
@@ -1209,6 +1223,24 @@ def __init__(self, ref: Value, value: Value, span: Optional[Span] = None) -> Non
12091223
self.value = value
12101224

12111225

1226+
class Assert(Node):
1227+
cond: Value
1228+
msg: List[Union[Value, str]]
1229+
1230+
def __init__(self, cond: Value, msg: List[Union[Value, str]], span: Optional[Span] = None) -> None:
1231+
super().__init__(span)
1232+
self.cond = cond
1233+
self.msg = msg
1234+
1235+
1236+
class Print(Node):
1237+
args: List[Union[Value, str]]
1238+
1239+
def __init__(self, args: List[Union[Value, str]], span: Optional[Span] = None) -> None:
1240+
super().__init__(span)
1241+
self.args = args
1242+
1243+
12121244
class Terminator(Node):
12131245
pass
12141246

@@ -1559,6 +1591,7 @@ def __init__(self, func: Function, args: List[Value], body: BasicBlock, span: Op
15591591
self.mapping[param] = arg
15601592
for v in func.locals:
15611593
if v in self.mapping:
1594+
# skip function parameters
15621595
continue
15631596
assert v.type
15641597
assert v.type.is_addressable()
@@ -1631,6 +1664,33 @@ def do():
16311664
self.mapping[intrin] = body.append(
16321665
Intrinsic(intrin.name, args, intrin.type, node.span))
16331666
do()
1667+
case If():
1668+
cond = self.mapping.get(node.cond)
1669+
assert isinstance(cond, Value)
1670+
then_body = BasicBlock()
1671+
else_body = BasicBlock()
1672+
merge = BasicBlock()
1673+
body.append(If(cond, then_body, else_body, merge))
1674+
self.do_inline(node.then_body, then_body)
1675+
if node.else_body:
1676+
self.do_inline(node.else_body, else_body)
1677+
body.append(merge)
1678+
case Loop():
1679+
prepare = BasicBlock()
1680+
if node.cond:
1681+
cond = self.mapping.get(node.cond)
1682+
else:
1683+
cond = None
1684+
assert cond is None or isinstance(cond, Value)
1685+
body_ = BasicBlock()
1686+
update = BasicBlock()
1687+
merge = BasicBlock()
1688+
body.append(Loop(prepare, cond, body_, update, merge))
1689+
self.do_inline(node.prepare, prepare)
1690+
self.do_inline(node.body, body_)
1691+
if node.update:
1692+
self.do_inline(node.update, update)
1693+
body.append(merge)
16341694
case Return():
16351695
if self.ret is not None:
16361696
raise InlineError(node, "multiple return statement")
@@ -1646,6 +1706,9 @@ def do():
16461706
@staticmethod
16471707
def inline(func: Function, args: List[Value], body: BasicBlock, span: Optional[Span] = None) -> Value:
16481708
inliner = FunctionInliner(func, args, body, span)
1709+
assert func.return_type
1710+
if func.return_type == UnitType():
1711+
return Unit()
16491712
assert inliner.ret
16501713
return inliner.ret
16511714

luisa_lang/lang_builtins.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
Any,
2222
Annotated
2323
)
24-
from luisa_lang._builtin_decor import func, intrinsic, opaque, builtin_generic_type, byref
24+
from luisa_lang._builtin_decor import func, intrinsic, opaque, builtin_generic_type, byref, struct
2525
from luisa_lang import parse
2626

2727
T = TypeVar("T")
@@ -317,6 +317,37 @@ def __sub__(self, offset: i32 | i64 | u32 | u64) -> 'Pointer[T]':
317317
return intrinsic("pointer.sub", Pointer[T], self, offset)
318318

319319

320+
@struct
321+
class RtxRay:
322+
o: float3
323+
d: float3
324+
tmin: float
325+
tmax: float
326+
327+
def __init__(self, o: float3, d: float3, tmin: float, tmax: float) -> None:
328+
self.o = o
329+
self.d = d
330+
self.tmin = tmin
331+
self.tmax = tmax
332+
333+
334+
@struct
335+
class RtxHit:
336+
inst_id: u32
337+
prim_id: u32
338+
bary: float2
339+
340+
def __init__(self, inst_id: u32, prim_id: u32, bary: float2) -> None:
341+
self.inst_id = inst_id
342+
self.prim_id = prim_id
343+
self.bary = bary
344+
345+
346+
@func
347+
def ray_query_pipeline(ray: RtxRay, on_surface_hit, on_procedural_hit) -> RtxHit:
348+
return intrinsic("ray_query_pipeline", RtxHit, ray, on_surface_hit, on_procedural_hit)
349+
350+
320351
__all__: List[str] = [
321352
# 'Pointer',
322353
'Buffer',

0 commit comments

Comments
 (0)