@@ -890,12 +890,12 @@ testVstackBuildAstPP2 = do
890
890
(printAstPretty
891
891
(vstackBuild @ (AstTensor AstMethodLet FullSpan ) @ Double
892
892
(replIota2 10 )))
893
- @?= " rfromS (let x14 = sfromIntegral (sscalar 0) + sscalar 2.0 * sfromIntegral (sscalar 1) ; v15 = let x11 = sfromIntegral (sscalar 9) + sscalar 3.0 * sfromIntegral (sscalar 8) ; v12 = (siota (SNat @10) + (let v2 = sconcrete (sreplicate [10] 1) + siota (SNat @10) in sgather (sconcrete (sreplicate [10] 2.0)) (\\ [i8] -> [kfromS (v2 !$ [i8])]) * sgather (sfromVector (fromList [sfromIntegral v2, sconcrete (sreplicate [10] 0.0)])) (\\ [i9] -> [ifH ((-9) <=. negate (kfromS ( v2 !$ [i9]) )) 0 1, i9]))) + (let v3 = sconcrete (sreplicate [10] (-1)) + siota (SNat @10) in sgather (sconcrete (sreplicate [10] 3.0)) (\\ [i6] -> [kfromS (v3 !$ [i6])]) * sgather (sfromVector (fromList [sfromIntegral v3, sconcrete (sreplicate [10] 0.0)])) (\\ [i7] -> [ifH (sscalar 0 <=. v3 !$ [i7]) 0 1, i7])) in sappend (sslice (SNat @0) (SNat @9) v12) (sreplicate @1 x11) in sappend (sreplicate @1 x14) (sslice (SNat @1) (SNat @9) v15))"
893
+ @?= " rfromS (let x14 = sfromIntegral (sscalar 0) + sscalar 2.0 * sfromIntegral (sscalar 1) ; v15 = let x11 = sfromIntegral (sscalar 9) + sscalar 3.0 * sfromIntegral (sscalar 8) ; v12 = (siota (SNat @10) + (let v2 = sconcrete (sreplicate [10] 1) + siota (SNat @10) in sgather (sconcrete (sreplicate [10] 2.0)) (\\ [i8] -> [kfromS (v2 !$ [i8])]) * sgather (sfromVector (fromList [sfromIntegral v2, sconcrete (sreplicate [10] 0.0)])) (\\ [i9] -> [ifH (sscalar (-9) <=. negate (v2 !$ [i9])) 0 1, i9]))) + (let v3 = sconcrete (sreplicate [10] (-1)) + siota (SNat @10) in sgather (sconcrete (sreplicate [10] 3.0)) (\\ [i6] -> [kfromS (v3 !$ [i6])]) * sgather (sfromVector (fromList [sfromIntegral v3, sconcrete (sreplicate [10] 0.0)])) (\\ [i7] -> [ifH (sscalar 0 <=. v3 !$ [i7]) 0 1, i7])) in sappend (sslice (SNat @0) (SNat @9) v12) (sreplicate @1 x11) in sappend (sreplicate @1 x14) (sslice (SNat @1) (SNat @9) v15))"
894
894
(printAstPretty
895
895
(simplifyInlineContract
896
896
(vstackBuild @ (AstTensor AstMethodLet FullSpan ) @ Double
897
897
(replIota2 10 ))))
898
- @?= " rfromS (sappend (sconcrete (sfromListLinear [1] [2.0])) (sappend (sconcrete (sfromListLinear [8] [1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0]) + (sgather (sconcrete (sreplicate [10] 2.0)) (\\ [i23] -> [kfromS (sconcrete (sfromListLinear [10] [1,2,3,4,5,6,7,8,9,10]) !$ [1 + i23])]) * sgather (sfromVector (fromList [sfromIntegral (sconcrete (sfromListLinear [8] [2,3,4,5,6,7,8,9])), sconcrete (sreplicate [8] 0.0)])) (\\ [i24] -> [ifH ((-9) <=. negate (kfromS ( sconcrete (sfromListLinear [10] [1,2,3,4,5,6,7,8,9,10]) !$ [1 + i24]) )) 0 1, i24]) + sgather (sconcrete (sreplicate [10] 3.0)) (\\ [i21] -> [kfromS (sconcrete (sfromListLinear [10] [-1,0,1,2,3,4,5,6,7,8]) !$ [1 + i21])]) * sgather (sfromVector (fromList [sfromIntegral (sconcrete (sfromListLinear [8] [0,1,2,3,4,5,6,7])), sconcrete (sreplicate [8] 0.0)])) (\\ [i22] -> [ifH (sscalar 0 <=. sconcrete (sfromListLinear [10] [-1,0,1,2,3,4,5,6,7,8]) !$ [1 + i22]) 0 1, i22]))) (sconcrete (sfromListLinear [1] [33.0]))))"
898
+ @?= " rfromS (sappend (sconcrete (sfromListLinear [1] [2.0])) (sappend (sconcrete (sfromListLinear [8] [1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0]) + (sgather (sconcrete (sreplicate [10] 2.0)) (\\ [i23] -> [kfromS (sconcrete (sfromListLinear [10] [1,2,3,4,5,6,7,8,9,10]) !$ [1 + i23])]) * sgather (sfromVector (fromList [sfromIntegral (sconcrete (sfromListLinear [8] [2,3,4,5,6,7,8,9])), sconcrete (sreplicate [8] 0.0)])) (\\ [i24] -> [ifH (sscalar (-9) <=. negate (sconcrete (sfromListLinear [10] [1,2,3,4,5,6,7,8,9,10]) !$ [1 + i24])) 0 1, i24]) + sgather (sconcrete (sreplicate [10] 3.0)) (\\ [i21] -> [kfromS (sconcrete (sfromListLinear [10] [-1,0,1,2,3,4,5,6,7,8]) !$ [1 + i21])]) * sgather (sfromVector (fromList [sfromIntegral (sconcrete (sfromListLinear [8] [0,1,2,3,4,5,6,7])), sconcrete (sreplicate [8] 0.0)])) (\\ [i22] -> [ifH (sscalar 0 <=. sconcrete (sfromListLinear [10] [-1,0,1,2,3,4,5,6,7,8]) !$ [1 + i22]) 0 1, i22]))) (sconcrete (sfromListLinear [1] [33.0]))))"
899
899
-- TODO: this reduced to "rfromS (sconcrete (sfromListLinear [10] [2.0,5.0,11.0,17.0,23.0,29.0,35.0,41.0,47.0,33.0]))" when tshare/tlet were inserted less often
900
900
901
901
testFooPP :: Assertion
@@ -1190,7 +1190,7 @@ testReluPP2 = do
1190
1190
(var3, ast3) = funToAst (FTKR [5 ] FTKScalar ) Nothing (\ t -> reluT2 (t, rscalar 7 ))
1191
1191
" \\ " ++ printAstVarName var3
1192
1192
++ " -> " ++ printAstSimple ast3
1193
- @?= " \\ v1 -> rfromS (tfromPrimal (STKS [5] STKScalar) (sgather [] (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\ [i2] -> [ifH (sscalar 7.0 * tprimalPart (sfromR v1) !$ [i2] <=. sscalar 0.0 ) 0 1])) * (sfromR v1 * tfromPrimal (STKS [5] STKScalar) (sconcrete (sreplicate [5] 7.0))))"
1193
+ @?= " \\ v1 -> rfromS (tfromPrimal (STKS [5] STKScalar) (sgather [] (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\ [i2] -> [ifH (sscalar -0.0 <=. sscalar (- 7.0) * tprimalPart (sfromR v1) !$ [i2]) 0 1])) * (sfromR v1 * tfromPrimal (STKS [5] STKScalar) (sconcrete (sreplicate [5] 7.0))))"
1194
1194
resetVarCounter
1195
1195
let (artifactRev, _deltas) = revArtifactDelta UseIncomingCotangent reluT2 (FTKProduct (FTKR [5 ] FTKScalar ) (FTKR ZSR FTKScalar ))
1196
1196
printArtifactPretty (simplifyArtifact artifactRev)
@@ -1231,7 +1231,7 @@ testReluSimplerPP2 = do
1231
1231
(var3, ast3) = funToAst (FTKR [5 ] FTKScalar ) Nothing (\ t -> reluT2 (t, rscalar 7 ))
1232
1232
" \\ " ++ printAstVarName var3
1233
1233
++ " -> " ++ printAstSimple ast3
1234
- @?= " \\ v1 -> rfromS (tlet (sfromR v1 * tfromPrimal (STKS [5] STKScalar) (sconcrete (sreplicate [5] 7.0))) (\\ v2 -> tfromPrimal (STKS [5] STKScalar) (sgather [] (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\ [i3] -> [ifH (sfromR (tprimalPart (rfromS (v2 !$ [i3]))) <=. sscalar 0.0 ) 0 1])) * v2))"
1234
+ @?= " \\ v1 -> rfromS (tlet (sfromR v1 * tfromPrimal (STKS [5] STKScalar) (sconcrete (sreplicate [5] 7.0))) (\\ v2 -> tfromPrimal (STKS [5] STKScalar) (sgather [] (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\ [i3] -> [ifH (sscalar -0.0 <=. negate ( sfromR (tprimalPart (rfromS (v2 !$ [i3])))) ) 0 1])) * v2))"
1235
1235
resetVarCounter
1236
1236
let (artifactRev, _deltas) = revArtifactDelta UseIncomingCotangent reluT2 (FTKProduct (FTKR [5 ] FTKScalar ) (FTKR ZSR FTKScalar ))
1237
1237
printArtifactPretty (simplifyArtifact artifactRev)
0 commit comments