Skip to content

Commit 8c3c920

Browse files
committed
Add another set of variants of Tom's example and fix rules that break them
1 parent 4ff0891 commit 8c3c920

File tree

2 files changed

+91
-14
lines changed

2 files changed

+91
-14
lines changed

src/HordeAd/Core/AstSimplify.hs

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3193,16 +3193,13 @@ contractAst t = case t of
31933193
gcastWith (unsafeCoerceRefl :: (j + 1 <=? m) :~: True) $
31943194
contractAst
31953195
$ Ast.AstGatherS
3196-
shn (Ast.AstIndexS shn v (i2 :.$ ZIS))
3197-
((::$) @j (Const varm) mrest, ZIS)
3196+
shn v ((::$) @j (Const varm) mrest, i2 :.$ ZIS)
31983197
`Ast.AstAppendS`
31993198
Ast.AstGatherS
3200-
shn (Ast.AstIndexS shn v (i1 :.$ ZIS))
3201-
((::$) @1 (Const varm) mrest, ZIS)
3199+
shn v ((::$) @1 (Const varm) mrest, i1 :.$ ZIS)
32023200
`Ast.AstAppendS`
32033201
Ast.AstGatherS
3204-
shn (Ast.AstIndexS shn v (i2 :.$ ZIS))
3205-
((::$) @(m - (j + 1)) (Const varm) mrest, ZIS)
3202+
shn v ((::$) @(m - (j + 1)) (Const varm) mrest, i2 :.$ ZIS)
32063203
-- TODO: fix AstIntVar to be usable here (maybe look at SNat'?),
32073204
Ast.AstGatherS
32083205
shn v ( vars@((::$) @m (Const varm) mrest)
@@ -3219,18 +3216,15 @@ contractAst t = case t of
32193216
gcastWith (unsafeCoerceRefl :: (j + 1 <=? m) :~: True) $
32203217
contractAst
32213218
$ Ast.AstGatherS
3222-
shn (Ast.AstIndexS (ixsToShS prest `shsAppend` shn) v (i2 :.$ ZIS))
3223-
((::$) @j (Const varm) mrest, prest)
3219+
shn v ((::$) @j (Const varm) mrest, i2 :.$ prest)
32243220
`Ast.AstAppendS`
32253221
Ast.AstGatherS
3226-
shn (Ast.AstIndexS (ixsToShS prest `shsAppend` shn) v (i1 :.$ ZIS))
3227-
( (::$) @1 (Const varm) mrest
3228-
, AstPlusK (AstConcreteK j) i3 :.$ prest3)
3222+
shn v ( (::$) @1 (Const varm) mrest
3223+
, i1 :.$ AstPlusK (AstConcreteK j) i3 :.$ prest3)
32293224
`Ast.AstAppendS`
32303225
Ast.AstGatherS
3231-
shn (Ast.AstIndexS (ixsToShS prest `shsAppend` shn) v (i2 :.$ ZIS))
3232-
( (::$) @(m - (j + 1)) (Const varm) mrest
3233-
, AstPlusK (AstConcreteK $ j + 1) i3 :.$ prest3 )
3226+
shn v ( (::$) @(m - (j + 1)) (Const varm) mrest
3227+
, i2 :.$ AstPlusK (AstConcreteK $ j + 1) i3 :.$ prest3 )
32343228
Ast.AstGatherS shn v ( (::$) @m (Const varm) mrest
32353229
, (:.$) @p (AstIntVar varp) prest )
32363230
| varm == varp

test/simplified/TestGatherSimplified.hs

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,11 @@ testTrees =
6464
, testCase "gatherCond2" testGatherCond2
6565
, testCase "gatherCondBuild2" testGatherCondBuild2
6666
, testCase "gatherSimpCond" testGatherSimpCond
67+
, testCase "gatherCond3" testGatherCond3
68+
, testCase "gatherCondBuild3" testGatherCondBuild3
69+
, testCase "gatherCond4" testGatherCond4
70+
, testCase "gatherCondBuild4" testGatherCondBuild4
71+
, testCase "gatherSimpCond3" testGatherSimpCond3
6772

6873
, testCase "scatterNested1" testScatterNested1
6974
, testCase "scatterNestedBuild1" testScatterNestedBuild1
@@ -723,6 +728,84 @@ testGatherSimpCond = do
723728
@?= interpretAstPrimal @Concrete env t2n
724729

725730

731+
gatherCond3 :: forall target r. (ADReady target, GoodScalar r)
732+
=> target (TKR 2 r) -> target (TKR 2 r)
733+
gatherCond3 u =
734+
let v = rtranspose [2, 0, 1] $ rreplicate (2 * rwidth u) u
735+
in rgather [rwidth u, 2] v (\(i :.: j :.: ZIR) ->
736+
2 * i :.: i :.: ifH (i ==. 3) 0 j :.: ZIR)
737+
738+
testGatherCond3 :: Assertion
739+
testGatherCond3 =
740+
assertEqualUpToEpsilon' 1e-10
741+
(ringestData [7,2]
742+
[1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0])
743+
(rev' @Double @2 gatherCond3 (rreplicate 7 $ ringestData [2] [0, 1]))
744+
745+
testGatherCondBuild3 :: Assertion
746+
testGatherCondBuild3 =
747+
assertEqualUpToEpsilon' 1e-10
748+
(ringestData [7,2]
749+
[6.0,0.0,6.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0])
750+
(rev' @Double @3
751+
(\t -> rbuild1 4 (\i ->
752+
gatherCond3 (t * rreplicate0N [7, 2] (rfromIndex0 i))))
753+
(rreplicate 7 $ ringestData [2] [0, 1]))
754+
755+
gatherCond4 :: forall target r. (ADReady target, GoodScalar r)
756+
=> target (TKR 2 r) -> target (TKR 2 r)
757+
gatherCond4 u =
758+
let v = rreplicate (2 * rwidth u) u
759+
in rtr $ rgather [2, rwidth u] v (\(j :.: i :.: ZIR) ->
760+
i :.: ifH (i ==. 3) 0 j :.: 2 * i :.: ZIR)
761+
762+
testGatherCond4 :: Assertion
763+
testGatherCond4 =
764+
assertEqualUpToEpsilon' 1e-10
765+
(ringestData [7,2]
766+
[1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0])
767+
(rev' @Double @2 gatherCond4 (rreplicate 7 $ ringestData [2] [0, 1]))
768+
769+
testGatherCondBuild4 :: Assertion
770+
testGatherCondBuild4 =
771+
assertEqualUpToEpsilon' 1e-10
772+
(ringestData [7,2]
773+
[6.0,0.0,6.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0])
774+
(rev' @Double @3
775+
(\t -> rbuild1 4 (\i ->
776+
gatherCond4 (t * rreplicate0N [7, 2] (rfromIndex0 i))))
777+
(rreplicate 7 $ ringestData [2] [0, 1]))
778+
779+
testGatherSimpCond3 :: Assertion
780+
testGatherSimpCond3 = do
781+
let varName = mkAstVarName (FTKR [7, 2] FTKScalar) . intToAstVarId $ 100000000
782+
var = AstVar varName
783+
vals = [-1, 0, 2.0,5.0,11.0,-17.0,23.0,29.0,-35.0,41.0,47.0,33.0, 0.1, 0.007]
784+
env = extendEnv varName (ringestData [7, 2] vals) emptyEnv
785+
let !t1 = gatherCond3 @(AstTensor AstMethodLet PrimalSpan) var
786+
let !t2 = gatherCond4 (ringestData [7, 2] vals)
787+
let !t1n = unAstNoSimplify $ gatherCond3 $ AstNoSimplify var
788+
let !t2n = unAstNoSimplify $ gatherCond4 $ AstNoSimplify var
789+
interpretAstPrimal @Concrete env t1
790+
@?= interpretAstPrimal @Concrete env t1n
791+
interpretAstPrimal @Concrete env t1n
792+
@?= interpretAstPrimal @Concrete emptyEnv t2
793+
interpretAstPrimal @Concrete emptyEnv t2
794+
@?= interpretAstPrimal @Concrete env t2n
795+
interpretAstPrimal @Concrete env
796+
(simplifyInlineContract @(TKR 2 Float) t1)
797+
@?= interpretAstPrimal @Concrete env t1
798+
interpretAstPrimal @Concrete env
799+
(simplifyInlineContract @(TKR 2 Float) t1n)
800+
@?= interpretAstPrimal @Concrete env t1n
801+
interpretAstPrimal @Concrete emptyEnv
802+
(simplifyInlineContract @(TKR 2 Float) t2)
803+
@?= interpretAstPrimal @Concrete emptyEnv t2
804+
interpretAstPrimal @Concrete env
805+
(simplifyInlineContract @(TKR 2 Float) t2n)
806+
@?= interpretAstPrimal @Concrete env t2n
807+
808+
726809
-- * Scatters instead of gathers
727810

728811
scatterNested1 :: forall target r. (ADReady target, GoodScalar r)

0 commit comments

Comments
 (0)