Skip to content

Commit 3b37bc7

Browse files
committed
Fix test results
1 parent 151b17d commit 3b37bc7

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

test/simplified/TestAdaptorSimplified.hs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -890,12 +890,12 @@ testVstackBuildAstPP2 = do
890890
(printAstPretty
891891
(vstackBuild @(AstTensor AstMethodLet FullSpan) @Double
892892
(replIota2 10)))
893-
@?= "rfromS (let x14 = sfromIntegral (sscalar 0) + sscalar 2.0 * sfromIntegral (sscalar 1) ; v15 = let x11 = sfromIntegral (sscalar 9) + sscalar 3.0 * sfromIntegral (sscalar 8) ; v12 = (siota (SNat @10) + (let v2 = sconcrete (sreplicate [10] 1) + siota (SNat @10) in sgather (sconcrete (sreplicate [10] 2.0)) (\\[i8] -> [kfromS (v2 !$ [i8])]) * sgather (sfromVector (fromList [sfromIntegral v2, sconcrete (sreplicate [10] 0.0)])) (\\[i9] -> [ifH ((-9) <=. negate (kfromS (v2 !$ [i9]))) 0 1, i9]))) + (let v3 = sconcrete (sreplicate [10] (-1)) + siota (SNat @10) in sgather (sconcrete (sreplicate [10] 3.0)) (\\[i6] -> [kfromS (v3 !$ [i6])]) * sgather (sfromVector (fromList [sfromIntegral v3, sconcrete (sreplicate [10] 0.0)])) (\\[i7] -> [ifH (sscalar 0 <=. v3 !$ [i7]) 0 1, i7])) in sappend (sslice (SNat @0) (SNat @9) v12) (sreplicate @1 x11) in sappend (sreplicate @1 x14) (sslice (SNat @1) (SNat @9) v15))"
893+
@?= "rfromS (let x14 = sfromIntegral (sscalar 0) + sscalar 2.0 * sfromIntegral (sscalar 1) ; v15 = let x11 = sfromIntegral (sscalar 9) + sscalar 3.0 * sfromIntegral (sscalar 8) ; v12 = (siota (SNat @10) + (let v2 = sconcrete (sreplicate [10] 1) + siota (SNat @10) in sgather (sconcrete (sreplicate [10] 2.0)) (\\[i8] -> [kfromS (v2 !$ [i8])]) * sgather (sfromVector (fromList [sfromIntegral v2, sconcrete (sreplicate [10] 0.0)])) (\\[i9] -> [ifH (sscalar (-9) <=. negate (v2 !$ [i9])) 0 1, i9]))) + (let v3 = sconcrete (sreplicate [10] (-1)) + siota (SNat @10) in sgather (sconcrete (sreplicate [10] 3.0)) (\\[i6] -> [kfromS (v3 !$ [i6])]) * sgather (sfromVector (fromList [sfromIntegral v3, sconcrete (sreplicate [10] 0.0)])) (\\[i7] -> [ifH (sscalar 0 <=. v3 !$ [i7]) 0 1, i7])) in sappend (sslice (SNat @0) (SNat @9) v12) (sreplicate @1 x11) in sappend (sreplicate @1 x14) (sslice (SNat @1) (SNat @9) v15))"
894894
(printAstPretty
895895
(simplifyInlineContract
896896
(vstackBuild @(AstTensor AstMethodLet FullSpan) @Double
897897
(replIota2 10))))
898-
@?= "rfromS (sappend (sconcrete (sfromListLinear [1] [2.0])) (sappend (sconcrete (sfromListLinear [8] [1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0]) + (sgather (sconcrete (sreplicate [10] 2.0)) (\\[i23] -> [kfromS (sconcrete (sfromListLinear [10] [1,2,3,4,5,6,7,8,9,10]) !$ [1 + i23])]) * sgather (sfromVector (fromList [sfromIntegral (sconcrete (sfromListLinear [8] [2,3,4,5,6,7,8,9])), sconcrete (sreplicate [8] 0.0)])) (\\[i24] -> [ifH ((-9) <=. negate (kfromS (sconcrete (sfromListLinear [10] [1,2,3,4,5,6,7,8,9,10]) !$ [1 + i24]))) 0 1, i24]) + sgather (sconcrete (sreplicate [10] 3.0)) (\\[i21] -> [kfromS (sconcrete (sfromListLinear [10] [-1,0,1,2,3,4,5,6,7,8]) !$ [1 + i21])]) * sgather (sfromVector (fromList [sfromIntegral (sconcrete (sfromListLinear [8] [0,1,2,3,4,5,6,7])), sconcrete (sreplicate [8] 0.0)])) (\\[i22] -> [ifH (sscalar 0 <=. sconcrete (sfromListLinear [10] [-1,0,1,2,3,4,5,6,7,8]) !$ [1 + i22]) 0 1, i22]))) (sconcrete (sfromListLinear [1] [33.0]))))"
898+
@?= "rfromS (sappend (sconcrete (sfromListLinear [1] [2.0])) (sappend (sconcrete (sfromListLinear [8] [1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0]) + (sgather (sconcrete (sreplicate [10] 2.0)) (\\[i23] -> [kfromS (sconcrete (sfromListLinear [10] [1,2,3,4,5,6,7,8,9,10]) !$ [1 + i23])]) * sgather (sfromVector (fromList [sfromIntegral (sconcrete (sfromListLinear [8] [2,3,4,5,6,7,8,9])), sconcrete (sreplicate [8] 0.0)])) (\\[i24] -> [ifH (sscalar (-9) <=. negate (sconcrete (sfromListLinear [10] [1,2,3,4,5,6,7,8,9,10]) !$ [1 + i24])) 0 1, i24]) + sgather (sconcrete (sreplicate [10] 3.0)) (\\[i21] -> [kfromS (sconcrete (sfromListLinear [10] [-1,0,1,2,3,4,5,6,7,8]) !$ [1 + i21])]) * sgather (sfromVector (fromList [sfromIntegral (sconcrete (sfromListLinear [8] [0,1,2,3,4,5,6,7])), sconcrete (sreplicate [8] 0.0)])) (\\[i22] -> [ifH (sscalar 0 <=. sconcrete (sfromListLinear [10] [-1,0,1,2,3,4,5,6,7,8]) !$ [1 + i22]) 0 1, i22]))) (sconcrete (sfromListLinear [1] [33.0]))))"
899899
-- TODO: this reduced to "rfromS (sconcrete (sfromListLinear [10] [2.0,5.0,11.0,17.0,23.0,29.0,35.0,41.0,47.0,33.0]))" when tshare/tlet were inserted less often
900900

901901
testFooPP :: Assertion
@@ -1190,7 +1190,7 @@ testReluPP2 = do
11901190
(var3, ast3) = funToAst (FTKR [5] FTKScalar) Nothing (\t -> reluT2 (t, rscalar 7))
11911191
"\\" ++ printAstVarName var3
11921192
++ " -> " ++ printAstSimple ast3
1193-
@?= "\\v1 -> rfromS (tfromPrimal (STKS [5] STKScalar) (sgather [] (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i2] -> [ifH (sscalar 7.0 * tprimalPart (sfromR v1) !$ [i2] <=. sscalar 0.0) 0 1])) * (sfromR v1 * tfromPrimal (STKS [5] STKScalar) (sconcrete (sreplicate [5] 7.0))))"
1193+
@?= "\\v1 -> rfromS (tfromPrimal (STKS [5] STKScalar) (sgather [] (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i2] -> [ifH (sscalar -0.0 <=. sscalar (-7.0) * tprimalPart (sfromR v1) !$ [i2]) 0 1])) * (sfromR v1 * tfromPrimal (STKS [5] STKScalar) (sconcrete (sreplicate [5] 7.0))))"
11941194
resetVarCounter
11951195
let (artifactRev, _deltas) = revArtifactDelta UseIncomingCotangent reluT2 (FTKProduct (FTKR [5] FTKScalar) (FTKR ZSR FTKScalar))
11961196
printArtifactPretty (simplifyArtifact artifactRev)
@@ -1231,7 +1231,7 @@ testReluSimplerPP2 = do
12311231
(var3, ast3) = funToAst (FTKR [5] FTKScalar) Nothing (\t -> reluT2 (t, rscalar 7))
12321232
"\\" ++ printAstVarName var3
12331233
++ " -> " ++ printAstSimple ast3
1234-
@?= "\\v1 -> rfromS (tlet (sfromR v1 * tfromPrimal (STKS [5] STKScalar) (sconcrete (sreplicate [5] 7.0))) (\\v2 -> tfromPrimal (STKS [5] STKScalar) (sgather [] (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i3] -> [ifH (sfromR (tprimalPart (rfromS (v2 !$ [i3]))) <=. sscalar 0.0) 0 1])) * v2))"
1234+
@?= "\\v1 -> rfromS (tlet (sfromR v1 * tfromPrimal (STKS [5] STKScalar) (sconcrete (sreplicate [5] 7.0))) (\\v2 -> tfromPrimal (STKS [5] STKScalar) (sgather [] (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i3] -> [ifH (sscalar -0.0 <=. negate (sfromR (tprimalPart (rfromS (v2 !$ [i3]))))) 0 1])) * v2))"
12351235
resetVarCounter
12361236
let (artifactRev, _deltas) = revArtifactDelta UseIncomingCotangent reluT2 (FTKProduct (FTKR [5] FTKScalar) (FTKR ZSR FTKScalar))
12371237
printArtifactPretty (simplifyArtifact artifactRev)

test/simplified/TestMnistPP.hs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -320,11 +320,11 @@ testVT2OPPNonLin3 = do
320320
artifactRevnonLin =
321321
revArtifactAdapt UseIncomingCotangent afcnn2TnonLin ftk
322322
printArtifactPrimalPretty (simplifyArtifact artifactRevnonLin)
323-
@?= "\\m1 -> let v23 = exp (sdot1In (sreplicate @2 (recip (sconcrete (sreplicate [5] 1.0) + exp (negate (scast (sdot1In (sreplicate @5 (scast (recip (sconcrete (sreplicate [4] 1.0) + exp (negate (sdot1In (sconcrete (sreplicate [4,3] 7.0)) (sfromR (tproject1 (tproject1 (tproject1 m1))))) + negate (sfromR (tproject2 (tproject1 (tproject1 m1))))))))) (sfromR (tproject1 (tproject2 (tproject1 m1)))))) + negate (sfromR (tproject2 (tproject2 (tproject1 m1)))))))) (sfromR (tproject1 (tproject2 m1))) + sfromR (tproject2 (tproject2 m1))) in negate (kfromS (sdot0 (sconcrete (sreplicate [2] 8.0)) (log (sreplicate @2 (recip (ssum0 v23)) * v23))))"
323+
@?= "\\m1 -> let v23 = exp (sdot1In (sreplicate @2 (recip (sconcrete (sreplicate [5] 1.0) + exp (negate (scast (sdot1In (sreplicate @5 (scast (recip (sconcrete (sreplicate [4] 1.0) + exp (negate (sdot1In (sconcrete (sreplicate [4,3] 7.0)) (sfromR (tproject1 (tproject1 (tproject1 m1))))) + negate (sfromR (tproject2 (tproject1 (tproject1 m1))))))))) (sfromR (tproject1 (tproject2 (tproject1 m1)))))) + negate (sfromR (tproject2 (tproject2 (tproject1 m1)))))))) (sfromR (tproject1 (tproject2 m1))) + sfromR (tproject2 (tproject2 m1))) in kfromS (negate (sdot0 (sconcrete (sreplicate [2] 8.0)) (log (sreplicate @2 (recip (ssum0 v23)) * v23))))"
324324
printArtifactPrimalPretty artifactRevnonLin
325-
@?= "\\m1 -> let v10 = ssum @3 (sconcrete (sreplicate [3,4] 7.0) * str (sfromR (tproject1 (tproject1 (tproject1 m1))))) + sfromR (tproject2 (tproject1 (tproject1 m1))) ; v11 = exp (negate v10) ; v12 = sconcrete (sreplicate [4] 1.0) + v11 ; v13 = recip v12 ; m16 = str (sreplicate @5 (scast v13)) ; v17 = scast (ssum @4 (m16 * str (sfromR (tproject1 (tproject2 (tproject1 m1)))))) + sfromR (tproject2 (tproject2 (tproject1 m1))) ; v18 = exp (negate v17) ; v19 = sconcrete (sreplicate [5] 1.0) + v18 ; v20 = recip v19 ; v23 = exp (ssum @5 (str (sreplicate @2 v20) * str (sfromR (tproject1 (tproject2 m1)))) + sfromR (tproject2 (tproject2 m1))) ; x24 = ssum @2 v23 ; v25 = sreplicate @2 (recip x24) ; v26 = v25 * v23 ; v27 = log v26 in negate (kfromS (ssum @2 (sconcrete (sreplicate [2] 8.0) * v27)))"
325+
@?= "\\m1 -> let v10 = ssum @3 (sconcrete (sreplicate [3,4] 7.0) * str (sfromR (tproject1 (tproject1 (tproject1 m1))))) + sfromR (tproject2 (tproject1 (tproject1 m1))) ; v11 = exp (negate v10) ; v12 = sconcrete (sreplicate [4] 1.0) + v11 ; v13 = recip v12 ; m16 = str (sreplicate @5 (scast v13)) ; v17 = scast (ssum @4 (m16 * str (sfromR (tproject1 (tproject2 (tproject1 m1)))))) + sfromR (tproject2 (tproject2 (tproject1 m1))) ; v18 = exp (negate v17) ; v19 = sconcrete (sreplicate [5] 1.0) + v18 ; v20 = recip v19 ; v23 = exp (ssum @5 (str (sreplicate @2 v20) * str (sfromR (tproject1 (tproject2 m1)))) + sfromR (tproject2 (tproject2 m1))) ; x24 = ssum @2 v23 ; v25 = sreplicate @2 (recip x24) ; v26 = v25 * v23 ; v27 = log v26 in kfromS (negate (ssum @2 (sconcrete (sreplicate [2] 8.0) * v27)))"
326326
printArtifactPretty artifactRevnonLin
327-
@?= "\\dret m1 -> let v10 = ssum @3 (sconcrete (sreplicate [3,4] 7.0) * str (sfromR (tproject1 (tproject1 (tproject1 m1))))) + sfromR (tproject2 (tproject1 (tproject1 m1))) ; v11 = exp (negate v10) ; v12 = sconcrete (sreplicate [4] 1.0) + v11 ; v13 = recip v12 ; v14 = sconcrete (sreplicate [4] 1.0) + negate v13 ; v15 = v13 * v14 ; m16 = str (sreplicate @5 (scast v13)) ; v17 = scast (ssum @4 (m16 * str (sfromR (tproject1 (tproject2 (tproject1 m1)))))) + sfromR (tproject2 (tproject2 (tproject1 m1))) ; v18 = exp (negate v17) ; v19 = sconcrete (sreplicate [5] 1.0) + v18 ; v20 = recip v19 ; v21 = sconcrete (sreplicate [5] 1.0) + negate v20 ; v22 = v20 * v21 ; v23 = exp (ssum @5 (str (sreplicate @2 v20) * str (sfromR (tproject1 (tproject2 m1)))) + sfromR (tproject2 (tproject2 m1))) ; x24 = ssum @2 v23 ; v25 = sreplicate @2 (recip x24) ; v26 = v25 * v23 ; v29 = sconcrete (sreplicate [2] 8.0) * (recip v26 * sreplicate @2 (sfromK ((-1.0) * dret))) ; v30 = v23 * (sreplicate @2 (negate (recip (x24 * x24)) * ssum @2 (v23 * v29)) + v25 * v29) ; v31 = v22 * ssum @2 (str (str (sfromR (tproject1 (tproject2 m1))) * sreplicate @5 v30)) ; m32 = sreplicate @4 (scast v31) ; v33 = v15 * scast (ssum @5 (str (str (sfromR (tproject1 (tproject2 (tproject1 m1)))) * m32))) in tpair (tpair (tpair (rfromS (str (sconcrete (sreplicate [3,4] 7.0) * sreplicate @3 v33))) (rfromS v33)) (tpair (rfromS (str (m16 * m32))) (rfromS v31))) (tpair (rfromS (str (str (sreplicate @2 v20) * sreplicate @5 v30))) (rfromS v30))"
327+
@?= "\\dret m1 -> let v10 = ssum @3 (sconcrete (sreplicate [3,4] 7.0) * str (sfromR (tproject1 (tproject1 (tproject1 m1))))) + sfromR (tproject2 (tproject1 (tproject1 m1))) ; v11 = exp (negate v10) ; v12 = sconcrete (sreplicate [4] 1.0) + v11 ; v13 = recip v12 ; v14 = sconcrete (sreplicate [4] 1.0) + negate v13 ; v15 = v13 * v14 ; m16 = str (sreplicate @5 (scast v13)) ; v17 = scast (ssum @4 (m16 * str (sfromR (tproject1 (tproject2 (tproject1 m1)))))) + sfromR (tproject2 (tproject2 (tproject1 m1))) ; v18 = exp (negate v17) ; v19 = sconcrete (sreplicate [5] 1.0) + v18 ; v20 = recip v19 ; v21 = sconcrete (sreplicate [5] 1.0) + negate v20 ; v22 = v20 * v21 ; v23 = exp (ssum @5 (str (sreplicate @2 v20) * str (sfromR (tproject1 (tproject2 m1)))) + sfromR (tproject2 (tproject2 m1))) ; x24 = ssum @2 v23 ; v25 = sreplicate @2 (recip x24) ; v26 = v25 * v23 ; v29 = sconcrete (sreplicate [2] 8.0) * (recip v26 * sreplicate @2 (sscalar (-1.0) * sfromK dret)) ; v30 = v23 * (sreplicate @2 (negate (recip (x24 * x24)) * ssum @2 (v23 * v29)) + v25 * v29) ; v31 = v22 * ssum @2 (str (str (sfromR (tproject1 (tproject2 m1))) * sreplicate @5 v30)) ; m32 = sreplicate @4 (scast v31) ; v33 = v15 * scast (ssum @5 (str (str (sfromR (tproject1 (tproject2 (tproject1 m1)))) * m32))) in tpair (tpair (tpair (rfromS (str (sconcrete (sreplicate [3,4] 7.0) * sreplicate @3 v33))) (rfromS v33)) (tpair (rfromS (str (m16 * m32))) (rfromS v31))) (tpair (rfromS (str (str (sreplicate @2 v20) * sreplicate @5 v30))) (rfromS v30))"
328328
printArtifactPretty (simplifyArtifact artifactRevnonLin)
329329
@?= "\\dret m1 -> tconvert (ConvT2 (ConvT2 (ConvT2 (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [4,3] FTKScalar)) ConvSX)) (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [4] FTKScalar)) ConvSX))) (ConvT2 (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [5,4] FTKScalar)) ConvSX)) (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [5] FTKScalar)) ConvSX)))) (ConvT2 (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [2,5] FTKScalar)) ConvSX)) (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [2] FTKScalar)) ConvSX)))) (STKProduct (STKProduct (STKProduct (STKS [4,3] STKScalar) (STKS [4] STKScalar)) (STKProduct (STKS [5,4] STKScalar) (STKS [5] STKScalar))) (STKProduct (STKS [2,5] STKScalar) (STKS [2] STKScalar))) (let v13 = recip (sconcrete (sreplicate [4] 1.0) + exp (negate (sdot1In (sconcrete (sreplicate [4,3] 7.0)) (sfromR (tproject1 (tproject1 (tproject1 m1))))) + negate (sfromR (tproject2 (tproject1 (tproject1 m1)))))) ; v16 = scast v13 ; v20 = recip (sconcrete (sreplicate [5] 1.0) + exp (negate (scast (sdot1In (sreplicate @5 v16) (sfromR (tproject1 (tproject2 (tproject1 m1)))))) + negate (sfromR (tproject2 (tproject2 (tproject1 m1)))))) ; v23 = exp (sdot1In (sreplicate @2 v20) (sfromR (tproject1 (tproject2 m1))) + sfromR (tproject2 (tproject2 m1))) ; x24 = ssum0 v23 ; x25 = recip x24 ; v29 = sconcrete (sreplicate [2] 8.0) * (recip (sreplicate @2 x25 * v23) * sreplicate @2 (sscalar (-1.0) * sfromK dret)) ; v30 = v23 * (sreplicate @2 (negate (recip (x24 * x24)) * sdot0 v23 v29) + sreplicate @2 x25 * v29) ; v31 = (v20 * (sconcrete (sreplicate [5] 1.0) + negate v20)) * sdot1In (str (sfromR (tproject1 (tproject2 m1)))) (sreplicate @5 v30) ; v32 = scast v31 ; v33 = (v13 * (sconcrete (sreplicate [4] 1.0) + negate v13)) * scast (sdot1In (str (sfromR (tproject1 (tproject2 (tproject1 m1))))) (sreplicate @4 v32)) in tpair (tpair (tpair (sconcrete (sreplicate [4,3] 7.0) * str (sreplicate @3 v33)) v33) (tpair (sreplicate @5 v16 * str (sreplicate @4 v32)) v31)) (tpair (sreplicate @2 v20 * str (sreplicate @5 v30)) v30))"
330330

0 commit comments

Comments
 (0)