Skip to content

Commit 0a1214f

Browse files
committed
need to revise assignment behavior
1 parent b21edf1 commit 0a1214f

File tree

4 files changed

+79
-29
lines changed

4 files changed

+79
-29
lines changed

README.md

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -149,18 +149,18 @@ However, not all references can be implemented on GPU. LuisaCompute would detect
149149

150150
The behavior can be summarize in the following table:
151151

152-
| Type | Assignment | Field/Index Assignment | Function Argument Passing | Function Return |
153-
|-------------|------------|------------------------|---------------------------|-----------------|
154-
| Python Object | Reference | Reference | Reference | Reference |
155-
| Scalar (e.g. lc.int) | Value | Value | Value | Value |
156-
| Compound Type (e.g. lc.float3, lc.float4x4) | Reference | Copy | Reference | Reference |
152+
| Type | Assignment | Field/Index Assignment | Function Argument Passing | `@lc.trace` Return | `@lc.func` Return |
153+
|-------------|------------|------------------------|---------------------------|-----------------|-------------|
154+
| Python Object | Reference | Reference | Reference | Reference | Reference |
155+
| Scalar (e.g. lc.int) | Value | N/A | Value | Value | Value |
156+
| Compound Type (e.g. lc.float3, lc.float4x4) | Reference | Copy | Reference | Reference | Value |
157157

158158

159159
Let's take a look at an example:
160160

161161
```python
162162
@lc.kernel
163-
def kernel_example():
163+
def assignment_example():
164164
s = MyStruct(10, lc.float3(1.0, 2.0, 3.0))
165165
v = lc.float3(4.0, 5.0, 6.0)
166166
t = s # t is a reference to s as in Python
@@ -171,28 +171,45 @@ def kernel_example():
171171
t2 = lc.copy(s) # t2 is a copy of s, not a reference
172172
t2.a += 1
173173
lc.print(t2.a, s.a) # should print 11, 10
174+
```
175+
176+
#### Transient vs Persistent Values
177+
Values in LuisaCompute can be categorized into transient and persistent values. Transient values are similar to rvalues in C++, meaning that they are temporarily created and hasn't bind to any variable yet. Persistent values are similar to lvalues in C++, meaning that they are bound to a variable and can be used as assignment target.
174178

175-
# the following code is not allowed since such dynamically created reference cannot be implemented on GPU:
179+
Since phyiscal reference might be supported on GPU, it is not possible to dynamically create reference to persistent values,
180+
for example
181+
```python
182+
@lc.kernel
183+
def transient_vs_persistent():
184+
v1 = lc.float3(1.0, 2.0, 3.0)
185+
v2 = lc.float3(4.0, 5.0, 6.0)
186+
187+
# the following code is allowed since both `v1 + 1.0` and `v2 + 1.0` are transient values.
188+
if dynamic_cond:
189+
dynamic = v1 + 1.0
190+
else:
191+
dynamic = v2 + 1.0
192+
193+
# the following code is not allowed since such dynamically created reference to persitent values cannot be implemented on GPU:
176194
if dynamic_cond:
177-
dynamic = s
195+
dynamic = v1
178196
else:
179-
dynamic = t2
197+
dynamic = v2
180198

181199
# instead, you can either use lc.copy() to create a copy of the struct:
182200
if dynamic_cond:
183-
dynamic = s
201+
dynamic = lc.copy(v1)
184202
else:
185-
dynamic = t2
203+
dynamic = lc.copy(v2)
186204

187205
# or use a static condition:
188206
if lc.comptime(cond):
189-
dynamic = s
207+
dynamic = v1
190208
else:
191-
dynamic = t2
209+
dynamic = v2
192210

193211
```
194212

195-
196213
### Functions and Methods
197214
Functions and methods in LuisaCompute are defined using the `@lc.func` or `@lc.trace` decorators. Both decorator transforms the python function into a LuisaCompute function that can be executed on both host (native Python) and device (LuisaCompute backend). The difference is that `@lc.trace` **inline**s the function body into the caller each time it is called, while `@lc.func` creates a separate function on the device.
198215

luisa_lang/hir.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,8 @@ def size(self) -> int:
8585
raise RuntimeError("RefTypes are logical and thus do not have a size")
8686

8787
def align(self) -> int:
88-
raise RuntimeError("RefTypes are logical and thus do not have an align")
88+
raise RuntimeError(
89+
"RefTypes are logical and thus do not have an align")
8990

9091
def __eq__(self, value: object) -> bool:
9192
return isinstance(value, RefType) and value.element == self.element
@@ -246,7 +247,8 @@ def __init__(self, element: Type, count: int, align: int | None = None) -> None:
246247
self.count = count
247248
self._align = align
248249
assert (self.element.size() * self.count) % self._align == 0
249-
self._size = round_to_align(self.element.size() * self.count, self._align)
250+
self._size = round_to_align(
251+
self.element.size() * self.count, self._align)
250252

251253
def size(self) -> int:
252254
return self._size
@@ -506,10 +508,11 @@ class TypeTemplate(Template[Type, TypeTemplateArgs]):
506508
pass
507509

508510

509-
class PyTreeStructure: # TODO: refactor this into another file
511+
class PyTreeStructure: # TODO: refactor this into another file
510512
metadata: (
511513
Tuple[type, Tuple[Any], Any] | None
512-
) # for JitVars, this is (type, type_args, hir.Type), for other types, this is (type, (), Any)
514+
# for JitVars, this is (type, type_args, hir.Type), for other types, this is (type, (), Any)
515+
)
513516
children: List["PyTreeStructure"]
514517

515518
def __init__(
@@ -764,7 +767,21 @@ def __init__(
764767
self.span = span
765768

766769

770+
class ValueCategory(Enum):
771+
NONE = auto() # not applicable. such as for opaque types or references
772+
TRANSIENT = auto() # or r-values
773+
PERSISTENT = auto() # or l-values
774+
775+
767776
class Value(TypedNode):
777+
category: ValueCategory
778+
779+
def __init__(self, type: Optional[Type] = None,
780+
span: Optional[Span] = None,
781+
category: ValueCategory = ValueCategory.NONE) -> None:
782+
super().__init__(type, span)
783+
self.category = category
784+
768785
def is_ref(self) -> bool:
769786
assert self.type is not None
770787
return isinstance(self.type, RefType)
@@ -791,7 +808,7 @@ def __init__(
791808
span: Optional[Span],
792809
semantic: ParameterSemantic = ParameterSemantic.BYVAL,
793810
) -> None:
794-
assert name != ""
811+
# assert name != ""
795812
assert not isinstance(type, RefType)
796813
super().__init__(type, span)
797814
self.name = name
@@ -801,7 +818,7 @@ def __init__(
801818
class VarValue(Value):
802819
var: Var
803820

804-
def __init__(self, var: Var, span: Optional[Span]) -> None:
821+
def __init__(self, var: Var, span: Optional[Span] = None) -> None:
805822
super().__init__(var.type, span)
806823
self.var = var
807824

@@ -878,7 +895,7 @@ class Alloca(Value):
878895

879896
def __init__(self, ty: Type, span: Optional[Span] = None) -> None:
880897
# assert isinstance(ty, RefType), f"expected a RefType but got {ty}"
881-
super().__init__(RefType(ty), span)
898+
super().__init__(ty, span)
882899

883900

884901
# class Init(Value):
@@ -939,9 +956,11 @@ class Assign(Node):
939956
value: Value
940957

941958
def __init__(self, ref: Value, value: Value, span: Optional[Span] = None) -> None:
942-
assert not isinstance(value.type, (RefType)), f"expecting a non-reference value, but got {value}"
959+
assert not isinstance(
960+
value.type, (RefType)), f"expecting a non-reference value, but got {value}"
943961
if not isinstance(ref.type, RefType):
944-
raise TypeCheckError(ref, f"cannot assign to a non-reference variable")
962+
raise TypeCheckError(
963+
ref, f"cannot assign to a non-reference variable")
945964
super().__init__(span)
946965
self.ref = ref
947966
self.value = value

luisa_lang/lang_runtime.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -326,10 +326,11 @@ def _init_symbolic(self):
326326
dsl_type = hir.get_dsl_type(self.__dtype__)
327327
if dsl_type is None:
328328
raise ValueError(f"{self.__dtype__} is not a valid DSL type")
329-
self.__symbolic__ = Symbolic(
330-
hir.VarRef(current_func().create_var(
331-
"", dsl_type.default(), False))
332-
)
329+
# self.__symbolic__ = Symbolic(
330+
# hir.VarValue(current_func().create_var(
331+
# "", dsl_type.default(), False))
332+
# )
333+
self.__symbolic__ = Symbolic(hir.Alloca(dsl_type.default(), span=None))
333334

334335
def _destroy_symbolic(self):
335336
self.__symbolic__ = None

tests/test_compile_simple.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,33 @@
33
import subprocess
44
from luisa_lang.compile import Compiler
55

6+
@lc.struct
7+
class Point:
8+
x: lc.f32
9+
y: lc.f32
10+
11+
@lc.trace
12+
def __init__(self, x: lc.f32, y: lc.f32):
13+
self.x = x
14+
self.y = y
15+
616
@lc.func
717
def sqr(x):
818
return x * x
919

1020
@lc.func
1121
def foo(a, b):
22+
# p = Point(lc.f32(1.0), lc.f32(2.0))
23+
p = Point(a, b)
24+
z = sqr(a)
1225
if a < b:
13-
return sqr(a + b)
26+
return z
1427
else:
1528
return a - b
1629

1730

1831
compiler = Compiler('cpp')
19-
compiler.compile(foo, example_inputs=(lc.f32(1.0),lc.f32(3.0)), name="foo_example")
32+
compiler.compile(foo, example_inputs=(lc.f32(1.0),lc.f32(3)), name="foo_example")
2033
output_code = compiler.output()
2134
with open("test.cpp", "w") as f:
2235
f.write(output_code)

0 commit comments

Comments
 (0)