@@ -88,6 +88,13 @@ instance (GoodScalar r, AstSpan s)
88
88
=> Num (AstTensor ms s (TKScalar r )) where
89
89
AstFromPrimal u + AstFromPrimal v = AstFromPrimal $ u + v
90
90
AstFromDual u + AstFromDual v = AstFromDual $ u + v
91
+ -- TODO: define a pattern synonym that captures the below. Also elsewhere.
92
+ AstConvert c u + AstConvert _ v
93
+ | FTKS ZSS x <- ftkAst u
94
+ , FTKS ZSS y <- ftkAst v
95
+ , Just Refl <- matchingFTK x (convertFTK c (ftkAst u))
96
+ , Just Refl <- matchingFTK x y =
97
+ AstConvert c $ u + v
91
98
AstConcreteK 0 + u = u
92
99
u + AstConcreteK 0 = u
93
100
AstConcreteK n + AstConcreteK k = AstConcreteK (n + k)
@@ -166,6 +173,12 @@ instance (GoodScalar r, AstSpan s)
166
173
_ * AstConcreteK 0 = 0
167
174
AstConcreteK 1 * u = u
168
175
u * AstConcreteK 1 = u
176
+ AstConvert c u * AstConvert _ v
177
+ | FTKS ZSS x <- ftkAst u
178
+ , FTKS ZSS y <- ftkAst v
179
+ , Just Refl <- matchingFTK x (convertFTK c (ftkAst u))
180
+ , Just Refl <- matchingFTK x y =
181
+ AstConvert c $ u * v
169
182
AstConcreteK n * AstConcreteK k = AstConcreteK (n * k)
170
183
AstConcreteK n * AstTimesK (AstConcreteK k) u = AstConcreteK (n * k) * u
171
184
AstTimesK (AstConcreteK n) u * AstConcreteK k = AstConcreteK (n * k) * u
@@ -233,19 +246,30 @@ instance (GoodScalar r, AstSpan s)
233
246
-- v is likely positive and let's keep it so
234
247
negate (AstI2K RemOp u v) = AstI2K RemOp (negate u) v
235
248
-- v is likely positive and let's keep it so
236
- -- TODO: negate (AstFromS' ftk u) = AstFromS ftk (negate u)
237
249
negate (AstConcreteK n) = AstConcreteK (negate n)
250
+ negate (AstConvert c u)
251
+ | FTKS ZSS x <- ftkAst u
252
+ , Just Refl <- matchingFTK x (convertFTK c (ftkAst u)) =
253
+ AstConvert c (negate u)
238
254
negate u = AstN1K NegateOp u
239
255
abs (AstFromPrimal n) = AstFromPrimal (abs n)
240
256
abs (AstFromDual n) = AstFromDual (abs n)
241
257
abs (AstConcreteK n) = AstConcreteK (abs n)
242
258
abs (AstN1K AbsOp u) = AstN1K AbsOp u
243
259
abs (AstN1K NegateOp u) = abs u
260
+ abs (AstConvert c u)
261
+ | FTKS ZSS x <- ftkAst u
262
+ , Just Refl <- matchingFTK x (convertFTK c (ftkAst u)) =
263
+ AstConvert c (abs u)
244
264
abs u = AstN1K AbsOp u
245
265
signum (AstFromPrimal n) = AstFromPrimal (signum n)
246
266
signum (AstFromDual n) = AstFromDual (signum n)
247
267
signum (AstConcreteK n) = AstConcreteK (signum n)
248
268
signum (AstN1K SignumOp u) = AstN1K SignumOp u
269
+ signum (AstConvert c u)
270
+ | FTKS ZSS x <- ftkAst u
271
+ , Just Refl <- matchingFTK x (convertFTK c (ftkAst u)) =
272
+ AstConvert c (signum u)
249
273
signum u = AstN1K SignumOp u
250
274
fromInteger i = fromPrimal $ AstConcreteK (fromInteger i)
251
275
{-# SPECIALIZE instance Num (AstTensor ms FullSpan (TKScalar Int64)) #-}
@@ -294,6 +318,12 @@ eqK _ _ = False
294
318
instance (GoodScalar r , IntegralH r , Nested. IntElt r , AstSpan s )
295
319
=> IntegralH (AstTensor ms s (TKScalar r )) where
296
320
quotH (AstFromPrimal n) (AstFromPrimal k) = AstFromPrimal (quotH n k)
321
+ quotH (AstConvert c n) (AstConvert _ k)
322
+ | FTKS ZSS x <- ftkAst n
323
+ , FTKS ZSS y <- ftkAst k
324
+ , Just Refl <- matchingFTK x (convertFTK c (ftkAst n))
325
+ , Just Refl <- matchingFTK x y =
326
+ AstConvert c (quotH n k)
297
327
quotH (AstConcreteK n) (AstConcreteK k) = AstConcreteK (quotH n k)
298
328
quotH (AstConcreteK 0 ) _ = 0
299
329
quotH u (AstConcreteK 1 ) = u
@@ -308,6 +338,12 @@ instance (GoodScalar r, IntegralH r, Nested.IntElt r, AstSpan s)
308
338
in if u1 == u2 then fromPrimal $ AstConcreteK u1 else t
309
339
310
340
remH (AstFromPrimal n) (AstFromPrimal k) = AstFromPrimal (remH n k)
341
+ remH (AstConvert c n) (AstConvert _ k)
342
+ | FTKS ZSS x <- ftkAst n
343
+ , FTKS ZSS y <- ftkAst k
344
+ , Just Refl <- matchingFTK x (convertFTK c (ftkAst n))
345
+ , Just Refl <- matchingFTK x y =
346
+ AstConvert c (remH n k)
311
347
remH (AstConcreteK n) (AstConcreteK k) = AstConcreteK (remH n k)
312
348
remH (AstConcreteK 0 ) _ = 0
313
349
remH _ (AstConcreteK 1 ) = 0
@@ -586,7 +622,6 @@ instance (GoodScalar r, IntegralH r, Nested.IntElt r, AstSpan s)
586
622
remH (AstReplicate snat stk@ STKS {} u) (AstReplicate _ STKS {} v) =
587
623
AstReplicate snat stk $ remH u v
588
624
remH (AstFromPrimal n) (AstFromPrimal k) = AstFromPrimal (remH n k)
589
- -- TODO: define a pattern synonym that captures the below. Also elsewhere.
590
625
remH (AstConvert c n) (AstConvert _ k)
591
626
| FTKS ZSS x <- convertFTK c (ftkAst n)
592
627
, Just Refl <- matchingFTK x (ftkAst n)
0 commit comments