Skip to content

Commit 620894e

Browse files
committed
Replace contractAstTimesK by *
1 parent 6efbf08 commit 620894e

File tree

8 files changed

+175
-153
lines changed

8 files changed

+175
-153
lines changed

src/HordeAd/Core/AstSimplify.hs

Lines changed: 13 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -887,7 +887,7 @@ astPrimalPart t = case t of
887887
in astConcrete ftk (tconstantTarget 0 ftk)
888888

889889
AstPlusK u v -> astPrimalPart u + astPrimalPart v
890-
AstTimesK u v -> contractAstTimesK (astPrimalPart u) (astPrimalPart v)
890+
AstTimesK u v -> astPrimalPart u * astPrimalPart v
891891
AstN1K opCode u -> contractAstNumOp1 opCode (astPrimalPart u)
892892
Ast.AstR1K opCode u -> Ast.AstR1K opCode (astPrimalPart u)
893893
Ast.AstR2K opCode u v -> Ast.AstR2K opCode (astPrimalPart u) (astPrimalPart v)
@@ -970,7 +970,7 @@ astDualPart t = case t of
970970
Ast.AstFromDual v -> v
971971

972972
AstPlusK u v -> astDualPart u + astDualPart v
973-
AstTimesK u v -> contractAstTimesK (astDualPart u) (astDualPart v)
973+
AstTimesK u v -> astDualPart u * astDualPart v
974974
AstN1K opCode u -> contractAstNumOp1 opCode (astDualPart u)
975975
Ast.AstR1K opCode u -> Ast.AstR1K opCode (astDualPart u)
976976
Ast.AstR2K opCode u v -> Ast.AstR2K opCode (astDualPart u) (astDualPart v)
@@ -1093,7 +1093,7 @@ astFromIntegralS t = case t of
10931093
Ast.AstBuild1 snat STKScalar (var, astFromIntegralK v)
10941094
AstConcreteS a -> AstConcreteS (tfromIntegralS a)
10951095
Ast.AstLet var u v -> astLet var u (astFromIntegralS v)
1096-
AstN1S NegateOp u ->negate (astFromIntegralS u)
1096+
AstN1S NegateOp u -> negate (astFromIntegralS u)
10971097
AstN1S opCode u -> AstN1S opCode (astFromIntegralS u)
10981098
-- AstN2S opCode u v -> AstN2S opCode (astFromIntegralS u) (astFromIntegralS v)
10991099
-- Ast.AstI2S opCode u v ->
@@ -2585,7 +2585,7 @@ astNonIndexStep t = case t of
25852585
Ast.AstFromDual{} -> t
25862586

25872587
AstPlusK u v -> u + v
2588-
AstTimesK u v -> contractAstTimesK u v
2588+
AstTimesK u v -> u * v
25892589
AstN1K opCode u -> contractAstNumOp1 opCode u
25902590
Ast.AstR1K{} -> t
25912591
Ast.AstR2K{} -> t
@@ -2689,7 +2689,7 @@ expandAst t = case t of
26892689
Ast.AstFromDual v -> Ast.AstFromDual (expandAst v)
26902690

26912691
AstPlusK u v -> expandAst u + expandAst v
2692-
AstTimesK u v -> contractAstTimesK (expandAst u) (expandAst v)
2692+
AstTimesK u v -> expandAst u * expandAst v
26932693
AstN1K opCode u -> contractAstNumOp1 opCode (expandAst u)
26942694
Ast.AstR1K opCode u -> Ast.AstR1K opCode (expandAst u)
26952695
Ast.AstR2K opCode u v -> Ast.AstR2K opCode (expandAst u) (expandAst v)
@@ -2872,7 +2872,7 @@ simplifyAst t = case t of
28722872
Ast.AstFromDual v -> Ast.AstFromDual (simplifyAst v)
28732873

28742874
AstPlusK u v -> simplifyAst u + simplifyAst v
2875-
AstTimesK u v -> contractAstTimesK (simplifyAst u) (simplifyAst v)
2875+
AstTimesK u v -> simplifyAst u * simplifyAst v
28762876
AstN1K opCode u -> contractAstNumOp1 opCode (simplifyAst u)
28772877
Ast.AstR1K opCode u -> Ast.AstR1K opCode (simplifyAst u)
28782878
Ast.AstR2K opCode u v -> Ast.AstR2K opCode (simplifyAst u) (simplifyAst v)
@@ -3227,7 +3227,7 @@ contractAst t = case t of
32273227
Ast.AstFromDual v -> Ast.AstFromDual (contractAst v)
32283228

32293229
AstPlusK u v -> contractAst u + contractAst v
3230-
AstTimesK u v -> contractAstTimesK (contractAst u) (contractAst v)
3230+
AstTimesK u v -> contractAst u * contractAst v
32313231
AstN1K opCode u -> contractAstNumOp1 opCode (contractAst u)
32323232
Ast.AstR1K opCode u -> Ast.AstR1K opCode (contractAst u)
32333233
Ast.AstR2K opCode u v -> Ast.AstR2K opCode (contractAst u) (contractAst v)
@@ -3384,9 +3384,12 @@ contractRelOp GtOp (Ast.AstVar _ u) (Ast.AstVar _ v) | u == v =
33843384
AstBoolConst False
33853385
contractRelOp opCodeRel arg1 arg2 = Ast.AstRelK opCodeRel arg1 arg2
33863386

3387+
-- TODO: perhaps aim for a polynomial normal form? but that requires global
3388+
-- inspection of the whole expression
33873389
-- TODO: let's aim at SOP (Sum-of-Products) form, just as
33883390
-- ghc-typelits-natnormalise does. Also, let's associate to the right
33893391
-- and let's push negation down.
3392+
-- TODO: these docs are outdated
33903393
--
33913394
-- | Normally, we wouldn't simplify tensor arithmetic so much, but some
33923395
-- of these ranked tensors can represent integers in indexes, so we have to.
@@ -3420,79 +3423,16 @@ contractRelOp opCodeRel arg1 arg2 = Ast.AstRelK opCodeRel arg1 arg2
34203423
contractAstNumOp1 :: (GoodScalar r, AstSpan s)
34213424
=> OpCodeNum1 -> AstTensor AstMethodLet s (TKScalar r)
34223425
-> AstTensor AstMethodLet s (TKScalar r)
3423-
contractAstNumOp1 NegateOp (AstConcreteK u) = AstConcreteK (negate u)
3424-
contractAstNumOp1 NegateOp (AstPlusK u v) =
3425-
AstPlusK (contractAstNumOp1 NegateOp u) (contractAstNumOp1 NegateOp v)
3426-
contractAstNumOp1 NegateOp (AstTimesK (AstConcreteK u) v) =
3427-
contractAstTimesK (AstConcreteK (negate u)) v
3428-
-- given a choice, prefer to negate a constant
3429-
contractAstNumOp1 NegateOp (AstTimesK u v) =
3430-
contractAstTimesK u (contractAstNumOp1 NegateOp v)
3431-
contractAstNumOp1 NegateOp (AstN1K NegateOp u) = u
3432-
contractAstNumOp1 NegateOp (AstN1K SignumOp u) =
3433-
contractAstNumOp1 SignumOp (contractAstNumOp1 NegateOp u)
3434-
contractAstNumOp1 NegateOp (Ast.AstI2K QuotOp u v) =
3435-
contractAstIntegralOp2 QuotOp (contractAstNumOp1 NegateOp u) v
3436-
-- v is likely positive and let's keep it so
3437-
contractAstNumOp1 NegateOp (Ast.AstI2K RemOp u v) =
3438-
contractAstIntegralOp2 RemOp (contractAstNumOp1 NegateOp u) v
3439-
-- v is likely positive and let's keep it so
3440-
3426+
contractAstNumOp1 NegateOp u = negate u
34413427
contractAstNumOp1 AbsOp (AstConcreteK u) = AstConcreteK (abs u)
34423428
contractAstNumOp1 AbsOp (AstN1K AbsOp u) = AstN1K AbsOp u
34433429
contractAstNumOp1 AbsOp (AstN1K NegateOp u) = contractAstNumOp1 AbsOp u
34443430
contractAstNumOp1 SignumOp (AstConcreteK u) = AstConcreteK (signum u)
34453431
contractAstNumOp1 SignumOp (AstN1K SignumOp u) = AstN1K SignumOp u
34463432
contractAstNumOp1 SignumOp (AstN1K AbsOp u) =
34473433
contractAstNumOp1 AbsOp (AstN1K SignumOp u)
3448-
34493434
contractAstNumOp1 opCode u = AstN1K opCode u
34503435

3451-
-- As with AstPlusK, AstConcreteK is kept on the left.
3452-
contractAstTimesK :: (GoodScalar r, AstSpan s)
3453-
=> AstTensor AstMethodLet s (TKScalar r)
3454-
-> AstTensor AstMethodLet s (TKScalar r)
3455-
-> AstTensor AstMethodLet s (TKScalar r)
3456-
contractAstTimesK (AstConcreteK 0) _v = AstConcreteK 0
3457-
contractAstTimesK (AstConcreteK 1) v = v
3458-
{- TODO: is it worth adding AstLet with a fresh variables
3459-
to share w and so make these rules safe? Perhaps after we decide
3460-
a normal form (e.g., a polynomial)?
3461-
contractAstTimesK (AstN2K PlusOp (u, v), w) =
3462-
contractAstTimesK ( contractAstTimesK (u, w)
3463-
, contractAstTimesK (v, w) )
3464-
contractAstTimesK (u, AstN2K PlusOp (v, w)) =
3465-
contractAstTimesK ( contractAstTimesK (u, v)
3466-
, contractAstTimesK (u, w) )
3467-
-}
3468-
contractAstTimesK (AstConcreteK u) (AstConcreteK v) = AstConcreteK (u * v)
3469-
contractAstTimesK (AstConcreteK u) (AstTimesK (AstConcreteK v) w) =
3470-
contractAstTimesK (AstConcreteK (u * v)) w
3471-
contractAstTimesK (AstTimesK (AstConcreteK u) v)
3472-
(AstTimesK (AstConcreteK w) x) =
3473-
AstTimesK (AstConcreteK (u * w)) (AstTimesK v x)
3474-
contractAstTimesK u w@AstConcreteK{} = contractAstTimesK w u
3475-
contractAstTimesK u (AstTimesK v@AstConcreteK{} w) =
3476-
contractAstTimesK v (AstTimesK u w)
3477-
-- TODO: perhaps aim for a polynomial normal form? but that requires global
3478-
-- inspection of the whole expression
3479-
contractAstTimesK u@AstConcreteK{} (AstPlusK v w) = AstPlusK (contractAstTimesK u v) (contractAstTimesK u w)
3480-
contractAstTimesK (AstTimesK u v) w =
3481-
contractAstTimesK u (contractAstTimesK v w)
3482-
-- With static shapes, the second argument to QuotOp and RemOp
3483-
-- is often a constant, which makes such rules worth including,
3484-
-- since they are likely to fire. To help them fire, we avoid changing
3485-
-- that constant, if possible, e.g., in rules for NegateOp.
3486-
contractAstTimesK
3487-
(AstConcreteK v)
3488-
(Ast.AstI2K QuotOp (Ast.AstVar ftk2 var)
3489-
(AstConcreteK v')) | v == v' =
3490-
Ast.AstVar ftk2 var
3491-
+ contractAstNumOp1 NegateOp
3492-
(Ast.AstI2K RemOp (Ast.AstVar ftk2 var)
3493-
(AstConcreteK v))
3494-
contractAstTimesK u v = AstTimesK u v
3495-
34963436
contractAstIntegralOp2 :: (GoodScalar r, AstSpan s, IntegralF r)
34973437
=> OpCodeIntegral2
34983438
-> AstTensor AstMethodLet s (TKScalar r)
@@ -3506,7 +3446,7 @@ contractAstIntegralOp2 QuotOp (Ast.AstI2K RemOp _u (AstConcreteK v))
35063446
(AstConcreteK v')
35073447
| v' >= v && v >= 0 = 0
35083448
contractAstIntegralOp2 QuotOp (Ast.AstI2K QuotOp u v) w =
3509-
contractAstIntegralOp2 QuotOp u (contractAstTimesK v w)
3449+
contractAstIntegralOp2 QuotOp u (v * w)
35103450
contractAstIntegralOp2 QuotOp (AstTimesK (AstConcreteK u) v)
35113451
(AstConcreteK u')
35123452
| u == u' = v
@@ -3665,7 +3605,7 @@ substitute1Ast i var = subst where
36653605
let mu = subst u
36663606
mv = subst v
36673607
in if isJust mu || isJust mv
3668-
then Just $ contractAstTimesK (fromMaybe u mu) (fromMaybe v mv)
3608+
then Just $ fromMaybe u mu * fromMaybe v mv
36693609
else Nothing
36703610
Ast.AstN1K opCode u -> (\u2 -> contractAstNumOp1 opCode u2)
36713611
<$> subst u

src/HordeAd/Core/CarriersAst.hs

Lines changed: 89 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,9 @@ instance (GoodScalar r, AstSpan s)
6262
u + AstConcreteK 0 = u
6363
AstConcreteK n + AstConcreteK k = AstConcreteK (n + k)
6464
AstConcreteK n + AstPlusK (AstConcreteK k) u = AstConcreteK (n + k) + u
65-
AstPlusK (AstConcreteK n) u + AstConcreteK k =
66-
AstConcreteK (n + k) + u
65+
AstPlusK (AstConcreteK n) u + AstConcreteK k = AstConcreteK (n + k) + u
6766
AstPlusK (AstConcreteK n) u + AstPlusK (AstConcreteK k) v =
6867
AstConcreteK (n + k) + AstPlusK u v -- u and v can cancel, but unlikely
69-
AstPlusK u@AstConcreteK{} v + w = AstPlusK u (AstPlusK v w) -- as above
7068

7169
-- Unfortunately, these only fire if the required subterms are at the top
7270
-- of the reduced term, which happens rarely except in small terms.
@@ -93,16 +91,75 @@ instance (GoodScalar r, AstSpan s)
9391
+ AstPlusK (AstI2K RemOp (AstN1K NegateOp (AstVar _ var)) (AstConcreteK n)) u
9492
| var == var' && n == n' = u
9593

94+
AstPlusK u@AstConcreteK{} v + w = AstPlusK u (AstPlusK v w) -- as above
9695
u + v@AstConcreteK{} = AstPlusK v u
9796
u + AstPlusK v@AstConcreteK{} w = AstPlusK v (AstPlusK u w) -- as above
9897
u + v = AstPlusK u v
9998

99+
AstConcreteK 0 * _ = 0
100+
_ * AstConcreteK 0 = 0
101+
AstConcreteK 1 * u = u
102+
u * AstConcreteK 1 = u
100103
AstConcreteK n * AstConcreteK k = AstConcreteK (n * k)
101-
AstConcreteK n * (AstTimesK (AstConcreteK k) u) =
102-
AstTimesK (AstConcreteK (n * k)) u
104+
AstConcreteK n * AstTimesK (AstConcreteK k) u = AstConcreteK (n * k) * u
105+
AstTimesK (AstConcreteK n) u * AstConcreteK k = AstConcreteK (n * k) * u
106+
AstTimesK (AstConcreteK n) u * AstTimesK (AstConcreteK k) v =
107+
AstConcreteK (n * k) * AstTimesK u v -- u and v can cancel, but unlikely
108+
109+
u@AstConcreteK{} * AstPlusK v w = AstPlusK (u * v) (u * w)
110+
AstTimesK u@AstConcreteK{} x * AstPlusK v w =
111+
AstTimesK x (AstPlusK (u * v) (u * w))
112+
AstPlusK v w * u@AstConcreteK{} = AstPlusK (v * u) (w * u)
113+
AstPlusK v w * AstTimesK u@AstConcreteK{} x =
114+
AstTimesK (AstPlusK (v * u) (w * u)) x
115+
116+
AstN1K NegateOp u * AstN1K NegateOp v = AstTimesK u v
117+
118+
-- With static shapes, the second argument to QuotOp and RemOp
119+
-- is often a constant, which makes such rules worth including,
120+
-- since they are likely to fire. To help them fire, we avoid changing
121+
-- that constant, if possible, e.g., in rules for NegateOp.
122+
AstConcreteK n * AstI2K QuotOp (AstVar ftk2 var) (AstConcreteK n')
123+
| n == n' =
124+
AstPlusK
125+
(AstVar ftk2 var)
126+
(negate (AstI2K RemOp (AstVar ftk2 var) (AstConcreteK n)))
127+
AstTimesK (AstConcreteK n) x * AstI2K QuotOp (AstVar ftk2 var)
128+
(AstConcreteK n')
129+
| n == n' =
130+
AstTimesK
131+
x
132+
(AstPlusK
133+
(AstVar ftk2 var)
134+
(negate (AstI2K RemOp (AstVar ftk2 var) (AstConcreteK n))))
135+
AstI2K QuotOp (AstVar ftk2 var) (AstConcreteK n') * AstConcreteK n
136+
| n == n' =
137+
AstPlusK
138+
(AstVar ftk2 var)
139+
(negate (AstI2K RemOp (AstVar ftk2 var) (AstConcreteK n)))
140+
AstI2K QuotOp (AstVar ftk2 var)
141+
(AstConcreteK n') * AstTimesK (AstConcreteK n) x
142+
| n == n' =
143+
AstTimesK
144+
(AstPlusK
145+
(AstVar ftk2 var)
146+
(negate (AstI2K RemOp (AstVar ftk2 var) (AstConcreteK n))))
147+
x
148+
149+
AstTimesK u@AstConcreteK{} v * w = AstTimesK u (AstTimesK v w) -- as above
150+
u * v@AstConcreteK{} = AstTimesK v u
151+
u * AstTimesK v@AstConcreteK{} w = AstTimesK v (AstTimesK u w) -- as above
103152
u * v = AstTimesK u v
104153

105154
negate (AstConcreteK n) = AstConcreteK (negate n)
155+
negate (AstPlusK u v) = AstPlusK (negate u) (negate v)
156+
negate (AstTimesK u v) = negate u * v
157+
negate (AstN1K NegateOp u) = u
158+
negate (AstN1K SignumOp u) = AstN1K SignumOp (negate u)
159+
negate (AstI2K QuotOp u v) = AstI2K QuotOp (negate u) v
160+
-- v is likely positive and let's keep it so
161+
negate (AstI2K RemOp u v) = AstI2K RemOp (negate u) v
162+
-- v is likely positive and let's keep it so
106163
negate u = AstN1K NegateOp u
107164
abs = AstN1K AbsOp
108165
signum = AstN1K SignumOp
@@ -215,7 +272,6 @@ instance GoodScalar r
215272
AstPlusS (AstConcreteS (n + k)) u
216273
AstPlusS (AstConcreteS n) u + AstPlusS (AstConcreteS k) v =
217274
AstPlusS (AstConcreteS (n + k)) (AstPlusS u v)
218-
AstPlusS u@AstConcreteS{} v + w = AstPlusS u (AstPlusS v w)
219275

220276
-- AstN1S NegateOp (AstVar _ var) + AstVar _ var'
221277
-- | var == var' = 0
@@ -226,16 +282,42 @@ instance GoodScalar r
226282
AstVar _ var' + AstPlusS (AstN1S NegateOp (AstVar _ var)) u
227283
| var == var' = u
228284

285+
AstPlusS u@AstConcreteS{} v + w = AstPlusS u (AstPlusS v w)
229286
u + v@AstConcreteS{} = AstPlusS v u
230287
u + AstPlusS v@AstConcreteS{} w = AstPlusS v (AstPlusS u w)
231288
u + v = AstPlusS u v
232289

233290
AstConcreteS n * AstConcreteS k = AstConcreteS (n * k)
234-
AstConcreteS n * (AstTimesS (AstConcreteS k) u) =
291+
AstConcreteS n * AstTimesS (AstConcreteS k) u =
235292
AstTimesS (AstConcreteS (n * k)) u
293+
AstTimesS (AstConcreteS n) u * AstConcreteS k =
294+
AstTimesS (AstConcreteS (n * k)) u
295+
AstTimesS (AstConcreteS n) u * AstTimesS (AstConcreteS k) v =
296+
AstTimesS (AstConcreteS (n * k)) (AstTimesS u v)
297+
298+
u@AstConcreteS{} * AstPlusS v w = AstPlusS (u * v) (u * w)
299+
AstTimesS u@AstConcreteS{} x * AstPlusS v w =
300+
AstTimesS x (AstPlusS (u * v) (u * w))
301+
AstPlusS v w * u@AstConcreteS{} = AstPlusS (v * u) (w * u)
302+
AstPlusS v w * AstTimesS u@AstConcreteS{} x =
303+
AstTimesS (AstPlusS (v * u) (w * u)) x
304+
305+
AstN1S NegateOp u * AstN1S NegateOp v = AstTimesS u v
306+
307+
AstTimesS u@AstConcreteS{} v * w = AstTimesS u (AstTimesS v w)
308+
u * v@AstConcreteS{} = AstTimesS v u
309+
u * AstTimesS v@AstConcreteS{} w = AstTimesS v (AstTimesS u w)
236310
u * v = AstTimesS u v
237311

238312
negate (AstConcreteS n) = AstConcreteS (negate n)
313+
negate (AstPlusS u v) = AstPlusS (negate u) (negate v)
314+
negate (AstTimesS u v) = AstTimesS (negate u) v
315+
negate (AstN1S NegateOp u) = u
316+
negate (AstN1S SignumOp u) = AstN1S SignumOp (negate u)
317+
negate (AstI2S QuotOp u v) = AstI2S QuotOp (negate u) v
318+
-- v is likely positive and let's keep it so
319+
negate (AstI2S RemOp u v) = AstI2S RemOp (negate u) v
320+
-- v is likely positive and let's keep it so
239321
negate u = AstN1S NegateOp u
240322
abs = AstN1S AbsOp
241323
signum = AstN1S SignumOp

0 commit comments

Comments
 (0)