@@ -64,6 +64,11 @@ testTrees =
64
64
, testCase " gatherCond2" testGatherCond2
65
65
, testCase " gatherCondBuild2" testGatherCondBuild2
66
66
, testCase " gatherSimpCond" testGatherSimpCond
67
+ , testCase " gatherCond3" testGatherCond3
68
+ , testCase " gatherCondBuild3" testGatherCondBuild3
69
+ , testCase " gatherCond4" testGatherCond4
70
+ , testCase " gatherCondBuild4" testGatherCondBuild4
71
+ , testCase " gatherSimpCond3" testGatherSimpCond3
67
72
68
73
, testCase " scatterNested1" testScatterNested1
69
74
, testCase " scatterNestedBuild1" testScatterNestedBuild1
@@ -723,6 +728,84 @@ testGatherSimpCond = do
723
728
@?= interpretAstPrimal @ Concrete env t2n
724
729
725
730
731
+ gatherCond3 :: forall target r . (ADReady target , GoodScalar r )
732
+ => target (TKR 2 r ) -> target (TKR 2 r )
733
+ gatherCond3 u =
734
+ let v = rtranspose [2 , 0 , 1 ] $ rreplicate (2 * rwidth u) u
735
+ in rgather [rwidth u, 2 ] v (\ (i :.: j :.: ZIR ) ->
736
+ 2 * i :.: i :.: ifH (i ==. 3 ) 0 j :.: ZIR )
737
+
738
+ testGatherCond3 :: Assertion
739
+ testGatherCond3 =
740
+ assertEqualUpToEpsilon' 1e-10
741
+ (ringestData [7 ,2 ]
742
+ [1.0 ,0.0 ,1.0 ,0.0 ,0.0 ,0.0 ,0.0 ,0.0 ,0.0 ,0.0 ,0.0 ,0.0 ,0.0 ,0.0 ])
743
+ (rev' @ Double @ 2 gatherCond3 (rreplicate 7 $ ringestData [2 ] [0 , 1 ]))
744
+
745
+ testGatherCondBuild3 :: Assertion
746
+ testGatherCondBuild3 =
747
+ assertEqualUpToEpsilon' 1e-10
748
+ (ringestData [7 ,2 ]
749
+ [6.0 ,0.0 ,6.0 ,0.0 ,0.0 ,0.0 ,0.0 ,0.0 ,0.0 ,0.0 ,0.0 ,0.0 ,0.0 ,0.0 ])
750
+ (rev' @ Double @ 3
751
+ (\ t -> rbuild1 4 (\ i ->
752
+ gatherCond3 (t * rreplicate0N [7 , 2 ] (rfromIndex0 i))))
753
+ (rreplicate 7 $ ringestData [2 ] [0 , 1 ]))
754
+
755
+ gatherCond4 :: forall target r . (ADReady target , GoodScalar r )
756
+ => target (TKR 2 r ) -> target (TKR 2 r )
757
+ gatherCond4 u =
758
+ let v = rreplicate (2 * rwidth u) u
759
+ in rtr $ rgather [2 , rwidth u] v (\ (j :.: i :.: ZIR ) ->
760
+ i :.: ifH (i ==. 3 ) 0 j :.: 2 * i :.: ZIR )
761
+
762
+ testGatherCond4 :: Assertion
763
+ testGatherCond4 =
764
+ assertEqualUpToEpsilon' 1e-10
765
+ (ringestData [7 ,2 ]
766
+ [1.0 ,0.0 ,1.0 ,0.0 ,0.0 ,0.0 ,0.0 ,0.0 ,0.0 ,0.0 ,0.0 ,0.0 ,0.0 ,0.0 ])
767
+ (rev' @ Double @ 2 gatherCond4 (rreplicate 7 $ ringestData [2 ] [0 , 1 ]))
768
+
769
+ testGatherCondBuild4 :: Assertion
770
+ testGatherCondBuild4 =
771
+ assertEqualUpToEpsilon' 1e-10
772
+ (ringestData [7 ,2 ]
773
+ [6.0 ,0.0 ,6.0 ,0.0 ,0.0 ,0.0 ,0.0 ,0.0 ,0.0 ,0.0 ,0.0 ,0.0 ,0.0 ,0.0 ])
774
+ (rev' @ Double @ 3
775
+ (\ t -> rbuild1 4 (\ i ->
776
+ gatherCond4 (t * rreplicate0N [7 , 2 ] (rfromIndex0 i))))
777
+ (rreplicate 7 $ ringestData [2 ] [0 , 1 ]))
778
+
779
+ testGatherSimpCond3 :: Assertion
780
+ testGatherSimpCond3 = do
781
+ let varName = mkAstVarName (FTKR [7 , 2 ] FTKScalar ) . intToAstVarId $ 100000000
782
+ var = AstVar varName
783
+ vals = [- 1 , 0 , 2.0 ,5.0 ,11.0 ,- 17.0 ,23.0 ,29.0 ,- 35.0 ,41.0 ,47.0 ,33.0 , 0.1 , 0.007 ]
784
+ env = extendEnv varName (ringestData [7 , 2 ] vals) emptyEnv
785
+ let ! t1 = gatherCond3 @ (AstTensor AstMethodLet PrimalSpan ) var
786
+ let ! t2 = gatherCond4 (ringestData [7 , 2 ] vals)
787
+ let ! t1n = unAstNoSimplify $ gatherCond3 $ AstNoSimplify var
788
+ let ! t2n = unAstNoSimplify $ gatherCond4 $ AstNoSimplify var
789
+ interpretAstPrimal @ Concrete env t1
790
+ @?= interpretAstPrimal @ Concrete env t1n
791
+ interpretAstPrimal @ Concrete env t1n
792
+ @?= interpretAstPrimal @ Concrete emptyEnv t2
793
+ interpretAstPrimal @ Concrete emptyEnv t2
794
+ @?= interpretAstPrimal @ Concrete env t2n
795
+ interpretAstPrimal @ Concrete env
796
+ (simplifyInlineContract @ (TKR 2 Float ) t1)
797
+ @?= interpretAstPrimal @ Concrete env t1
798
+ interpretAstPrimal @ Concrete env
799
+ (simplifyInlineContract @ (TKR 2 Float ) t1n)
800
+ @?= interpretAstPrimal @ Concrete env t1n
801
+ interpretAstPrimal @ Concrete emptyEnv
802
+ (simplifyInlineContract @ (TKR 2 Float ) t2)
803
+ @?= interpretAstPrimal @ Concrete emptyEnv t2
804
+ interpretAstPrimal @ Concrete env
805
+ (simplifyInlineContract @ (TKR 2 Float ) t2n)
806
+ @?= interpretAstPrimal @ Concrete env t2n
807
+
808
+
726
809
-- * Scatters instead of gathers
727
810
728
811
scatterNested1 :: forall target r . (ADReady target , GoodScalar r )
0 commit comments