Skip to content

Commit 151b17d

Browse files
committed
Add more arithmetic rules for AstConvert
1 parent c12b53f commit 151b17d

File tree

1 file changed

+37
-2
lines changed

1 file changed

+37
-2
lines changed

src/HordeAd/Core/CarriersAst.hs

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,13 @@ instance (GoodScalar r, AstSpan s)
8888
=> Num (AstTensor ms s (TKScalar r)) where
8989
AstFromPrimal u + AstFromPrimal v = AstFromPrimal $ u + v
9090
AstFromDual u + AstFromDual v = AstFromDual $ u + v
91+
-- TODO: define a pattern synonym that captures the below. Also elsewhere.
92+
AstConvert c u + AstConvert _ v
93+
| FTKS ZSS x <- ftkAst u
94+
, FTKS ZSS y <- ftkAst v
95+
, Just Refl <- matchingFTK x (convertFTK c (ftkAst u))
96+
, Just Refl <- matchingFTK x y =
97+
AstConvert c $ u + v
9198
AstConcreteK 0 + u = u
9299
u + AstConcreteK 0 = u
93100
AstConcreteK n + AstConcreteK k = AstConcreteK (n + k)
@@ -166,6 +173,12 @@ instance (GoodScalar r, AstSpan s)
166173
_ * AstConcreteK 0 = 0
167174
AstConcreteK 1 * u = u
168175
u * AstConcreteK 1 = u
176+
AstConvert c u * AstConvert _ v
177+
| FTKS ZSS x <- ftkAst u
178+
, FTKS ZSS y <- ftkAst v
179+
, Just Refl <- matchingFTK x (convertFTK c (ftkAst u))
180+
, Just Refl <- matchingFTK x y =
181+
AstConvert c $ u * v
169182
AstConcreteK n * AstConcreteK k = AstConcreteK (n * k)
170183
AstConcreteK n * AstTimesK (AstConcreteK k) u = AstConcreteK (n * k) * u
171184
AstTimesK (AstConcreteK n) u * AstConcreteK k = AstConcreteK (n * k) * u
@@ -233,19 +246,30 @@ instance (GoodScalar r, AstSpan s)
233246
-- v is likely positive and let's keep it so
234247
negate (AstI2K RemOp u v) = AstI2K RemOp (negate u) v
235248
-- v is likely positive and let's keep it so
236-
-- TODO: negate (AstFromS' ftk u) = AstFromS ftk (negate u)
237249
negate (AstConcreteK n) = AstConcreteK (negate n)
250+
negate (AstConvert c u)
251+
| FTKS ZSS x <- ftkAst u
252+
, Just Refl <- matchingFTK x (convertFTK c (ftkAst u)) =
253+
AstConvert c (negate u)
238254
negate u = AstN1K NegateOp u
239255
abs (AstFromPrimal n) = AstFromPrimal (abs n)
240256
abs (AstFromDual n) = AstFromDual (abs n)
241257
abs (AstConcreteK n) = AstConcreteK (abs n)
242258
abs (AstN1K AbsOp u) = AstN1K AbsOp u
243259
abs (AstN1K NegateOp u) = abs u
260+
abs (AstConvert c u)
261+
| FTKS ZSS x <- ftkAst u
262+
, Just Refl <- matchingFTK x (convertFTK c (ftkAst u)) =
263+
AstConvert c (abs u)
244264
abs u = AstN1K AbsOp u
245265
signum (AstFromPrimal n) = AstFromPrimal (signum n)
246266
signum (AstFromDual n) = AstFromDual (signum n)
247267
signum (AstConcreteK n) = AstConcreteK (signum n)
248268
signum (AstN1K SignumOp u) = AstN1K SignumOp u
269+
signum (AstConvert c u)
270+
| FTKS ZSS x <- ftkAst u
271+
, Just Refl <- matchingFTK x (convertFTK c (ftkAst u)) =
272+
AstConvert c (signum u)
249273
signum u = AstN1K SignumOp u
250274
fromInteger i = fromPrimal $ AstConcreteK (fromInteger i)
251275
{-# SPECIALIZE instance Num (AstTensor ms FullSpan (TKScalar Int64)) #-}
@@ -294,6 +318,12 @@ eqK _ _ = False
294318
instance (GoodScalar r, IntegralH r, Nested.IntElt r, AstSpan s)
295319
=> IntegralH (AstTensor ms s (TKScalar r)) where
296320
quotH (AstFromPrimal n) (AstFromPrimal k) = AstFromPrimal (quotH n k)
321+
quotH (AstConvert c n) (AstConvert _ k)
322+
| FTKS ZSS x <- ftkAst n
323+
, FTKS ZSS y <- ftkAst k
324+
, Just Refl <- matchingFTK x (convertFTK c (ftkAst n))
325+
, Just Refl <- matchingFTK x y =
326+
AstConvert c (quotH n k)
297327
quotH (AstConcreteK n) (AstConcreteK k) = AstConcreteK (quotH n k)
298328
quotH (AstConcreteK 0) _ = 0
299329
quotH u (AstConcreteK 1) = u
@@ -308,6 +338,12 @@ instance (GoodScalar r, IntegralH r, Nested.IntElt r, AstSpan s)
308338
in if u1 == u2 then fromPrimal $ AstConcreteK u1 else t
309339

310340
remH (AstFromPrimal n) (AstFromPrimal k) = AstFromPrimal (remH n k)
341+
remH (AstConvert c n) (AstConvert _ k)
342+
| FTKS ZSS x <- ftkAst n
343+
, FTKS ZSS y <- ftkAst k
344+
, Just Refl <- matchingFTK x (convertFTK c (ftkAst n))
345+
, Just Refl <- matchingFTK x y =
346+
AstConvert c (remH n k)
311347
remH (AstConcreteK n) (AstConcreteK k) = AstConcreteK (remH n k)
312348
remH (AstConcreteK 0) _ = 0
313349
remH _ (AstConcreteK 1) = 0
@@ -586,7 +622,6 @@ instance (GoodScalar r, IntegralH r, Nested.IntElt r, AstSpan s)
586622
remH (AstReplicate snat stk@STKS{} u) (AstReplicate _ STKS{} v) =
587623
AstReplicate snat stk $ remH u v
588624
remH (AstFromPrimal n) (AstFromPrimal k) = AstFromPrimal (remH n k)
589-
-- TODO: define a pattern synonym that captures the below. Also elsewhere.
590625
remH (AstConvert c n) (AstConvert _ k)
591626
| FTKS ZSS x <- convertFTK c (ftkAst n)
592627
, Just Refl <- matchingFTK x (ftkAst n)

0 commit comments

Comments
 (0)