@@ -887,7 +887,7 @@ astPrimalPart t = case t of
887
887
in astConcrete ftk (tconstantTarget 0 ftk)
888
888
889
889
AstPlusK u v -> astPrimalPart u + astPrimalPart v
890
- AstTimesK u v -> contractAstTimesK ( astPrimalPart u) ( astPrimalPart v)
890
+ AstTimesK u v -> astPrimalPart u * astPrimalPart v
891
891
AstN1K opCode u -> contractAstNumOp1 opCode (astPrimalPart u)
892
892
Ast. AstR1K opCode u -> Ast. AstR1K opCode (astPrimalPart u)
893
893
Ast. AstR2K opCode u v -> Ast. AstR2K opCode (astPrimalPart u) (astPrimalPart v)
@@ -970,7 +970,7 @@ astDualPart t = case t of
970
970
Ast. AstFromDual v -> v
971
971
972
972
AstPlusK u v -> astDualPart u + astDualPart v
973
- AstTimesK u v -> contractAstTimesK ( astDualPart u) ( astDualPart v)
973
+ AstTimesK u v -> astDualPart u * astDualPart v
974
974
AstN1K opCode u -> contractAstNumOp1 opCode (astDualPart u)
975
975
Ast. AstR1K opCode u -> Ast. AstR1K opCode (astDualPart u)
976
976
Ast. AstR2K opCode u v -> Ast. AstR2K opCode (astDualPart u) (astDualPart v)
@@ -1093,7 +1093,7 @@ astFromIntegralS t = case t of
1093
1093
Ast. AstBuild1 snat STKScalar (var, astFromIntegralK v)
1094
1094
AstConcreteS a -> AstConcreteS (tfromIntegralS a)
1095
1095
Ast. AstLet var u v -> astLet var u (astFromIntegralS v)
1096
- AstN1S NegateOp u -> negate (astFromIntegralS u)
1096
+ AstN1S NegateOp u -> negate (astFromIntegralS u)
1097
1097
AstN1S opCode u -> AstN1S opCode (astFromIntegralS u)
1098
1098
-- AstN2S opCode u v -> AstN2S opCode (astFromIntegralS u) (astFromIntegralS v)
1099
1099
-- Ast.AstI2S opCode u v ->
@@ -2585,7 +2585,7 @@ astNonIndexStep t = case t of
2585
2585
Ast. AstFromDual {} -> t
2586
2586
2587
2587
AstPlusK u v -> u + v
2588
- AstTimesK u v -> contractAstTimesK u v
2588
+ AstTimesK u v -> u * v
2589
2589
AstN1K opCode u -> contractAstNumOp1 opCode u
2590
2590
Ast. AstR1K {} -> t
2591
2591
Ast. AstR2K {} -> t
@@ -2689,7 +2689,7 @@ expandAst t = case t of
2689
2689
Ast. AstFromDual v -> Ast. AstFromDual (expandAst v)
2690
2690
2691
2691
AstPlusK u v -> expandAst u + expandAst v
2692
- AstTimesK u v -> contractAstTimesK ( expandAst u) ( expandAst v)
2692
+ AstTimesK u v -> expandAst u * expandAst v
2693
2693
AstN1K opCode u -> contractAstNumOp1 opCode (expandAst u)
2694
2694
Ast. AstR1K opCode u -> Ast. AstR1K opCode (expandAst u)
2695
2695
Ast. AstR2K opCode u v -> Ast. AstR2K opCode (expandAst u) (expandAst v)
@@ -2872,7 +2872,7 @@ simplifyAst t = case t of
2872
2872
Ast. AstFromDual v -> Ast. AstFromDual (simplifyAst v)
2873
2873
2874
2874
AstPlusK u v -> simplifyAst u + simplifyAst v
2875
- AstTimesK u v -> contractAstTimesK ( simplifyAst u) ( simplifyAst v)
2875
+ AstTimesK u v -> simplifyAst u * simplifyAst v
2876
2876
AstN1K opCode u -> contractAstNumOp1 opCode (simplifyAst u)
2877
2877
Ast. AstR1K opCode u -> Ast. AstR1K opCode (simplifyAst u)
2878
2878
Ast. AstR2K opCode u v -> Ast. AstR2K opCode (simplifyAst u) (simplifyAst v)
@@ -3227,7 +3227,7 @@ contractAst t = case t of
3227
3227
Ast. AstFromDual v -> Ast. AstFromDual (contractAst v)
3228
3228
3229
3229
AstPlusK u v -> contractAst u + contractAst v
3230
- AstTimesK u v -> contractAstTimesK ( contractAst u) ( contractAst v)
3230
+ AstTimesK u v -> contractAst u * contractAst v
3231
3231
AstN1K opCode u -> contractAstNumOp1 opCode (contractAst u)
3232
3232
Ast. AstR1K opCode u -> Ast. AstR1K opCode (contractAst u)
3233
3233
Ast. AstR2K opCode u v -> Ast. AstR2K opCode (contractAst u) (contractAst v)
@@ -3384,9 +3384,12 @@ contractRelOp GtOp (Ast.AstVar _ u) (Ast.AstVar _ v) | u == v =
3384
3384
AstBoolConst False
3385
3385
contractRelOp opCodeRel arg1 arg2 = Ast. AstRelK opCodeRel arg1 arg2
3386
3386
3387
+ -- TODO: perhaps aim for a polynomial normal form? but that requires global
3388
+ -- inspection of the whole expression
3387
3389
-- TODO: let's aim at SOP (Sum-of-Products) form, just as
3388
3390
-- ghc-typelits-natnormalise does. Also, let's associate to the right
3389
3391
-- and let's push negation down.
3392
+ -- TODO: these docs are outdated
3390
3393
--
3391
3394
-- | Normally, we wouldn't simplify tensor arithmetic so much, but some
3392
3395
-- of these ranked tensors can represent integers in indexes, so we have to.
@@ -3420,79 +3423,16 @@ contractRelOp opCodeRel arg1 arg2 = Ast.AstRelK opCodeRel arg1 arg2
3420
3423
contractAstNumOp1 :: (GoodScalar r , AstSpan s )
3421
3424
=> OpCodeNum1 -> AstTensor AstMethodLet s (TKScalar r )
3422
3425
-> AstTensor AstMethodLet s (TKScalar r )
3423
- contractAstNumOp1 NegateOp (AstConcreteK u) = AstConcreteK (negate u)
3424
- contractAstNumOp1 NegateOp (AstPlusK u v) =
3425
- AstPlusK (contractAstNumOp1 NegateOp u) (contractAstNumOp1 NegateOp v)
3426
- contractAstNumOp1 NegateOp (AstTimesK (AstConcreteK u) v) =
3427
- contractAstTimesK (AstConcreteK (negate u)) v
3428
- -- given a choice, prefer to negate a constant
3429
- contractAstNumOp1 NegateOp (AstTimesK u v) =
3430
- contractAstTimesK u (contractAstNumOp1 NegateOp v)
3431
- contractAstNumOp1 NegateOp (AstN1K NegateOp u) = u
3432
- contractAstNumOp1 NegateOp (AstN1K SignumOp u) =
3433
- contractAstNumOp1 SignumOp (contractAstNumOp1 NegateOp u)
3434
- contractAstNumOp1 NegateOp (Ast. AstI2K QuotOp u v) =
3435
- contractAstIntegralOp2 QuotOp (contractAstNumOp1 NegateOp u) v
3436
- -- v is likely positive and let's keep it so
3437
- contractAstNumOp1 NegateOp (Ast. AstI2K RemOp u v) =
3438
- contractAstIntegralOp2 RemOp (contractAstNumOp1 NegateOp u) v
3439
- -- v is likely positive and let's keep it so
3440
-
3426
+ contractAstNumOp1 NegateOp u = negate u
3441
3427
contractAstNumOp1 AbsOp (AstConcreteK u) = AstConcreteK (abs u)
3442
3428
contractAstNumOp1 AbsOp (AstN1K AbsOp u) = AstN1K AbsOp u
3443
3429
contractAstNumOp1 AbsOp (AstN1K NegateOp u) = contractAstNumOp1 AbsOp u
3444
3430
contractAstNumOp1 SignumOp (AstConcreteK u) = AstConcreteK (signum u)
3445
3431
contractAstNumOp1 SignumOp (AstN1K SignumOp u) = AstN1K SignumOp u
3446
3432
contractAstNumOp1 SignumOp (AstN1K AbsOp u) =
3447
3433
contractAstNumOp1 AbsOp (AstN1K SignumOp u)
3448
-
3449
3434
contractAstNumOp1 opCode u = AstN1K opCode u
3450
3435
3451
- -- As with AstPlusK, AstConcreteK is kept on the left.
3452
- contractAstTimesK :: (GoodScalar r , AstSpan s )
3453
- => AstTensor AstMethodLet s (TKScalar r )
3454
- -> AstTensor AstMethodLet s (TKScalar r )
3455
- -> AstTensor AstMethodLet s (TKScalar r )
3456
- contractAstTimesK (AstConcreteK 0 ) _v = AstConcreteK 0
3457
- contractAstTimesK (AstConcreteK 1 ) v = v
3458
- {- TODO: is it worth adding AstLet with a fresh variables
3459
- to share w and so make these rules safe? Perhaps after we decide
3460
- a normal form (e.g., a polynomial)?
3461
- contractAstTimesK (AstN2K PlusOp (u, v), w) =
3462
- contractAstTimesK ( contractAstTimesK (u, w)
3463
- , contractAstTimesK (v, w) )
3464
- contractAstTimesK (u, AstN2K PlusOp (v, w)) =
3465
- contractAstTimesK ( contractAstTimesK (u, v)
3466
- , contractAstTimesK (u, w) )
3467
- -}
3468
- contractAstTimesK (AstConcreteK u) (AstConcreteK v) = AstConcreteK (u * v)
3469
- contractAstTimesK (AstConcreteK u) (AstTimesK (AstConcreteK v) w) =
3470
- contractAstTimesK (AstConcreteK (u * v)) w
3471
- contractAstTimesK (AstTimesK (AstConcreteK u) v)
3472
- (AstTimesK (AstConcreteK w) x) =
3473
- AstTimesK (AstConcreteK (u * w)) (AstTimesK v x)
3474
- contractAstTimesK u w@ AstConcreteK {} = contractAstTimesK w u
3475
- contractAstTimesK u (AstTimesK v@ AstConcreteK {} w) =
3476
- contractAstTimesK v (AstTimesK u w)
3477
- -- TODO: perhaps aim for a polynomial normal form? but that requires global
3478
- -- inspection of the whole expression
3479
- contractAstTimesK u@ AstConcreteK {} (AstPlusK v w) = AstPlusK (contractAstTimesK u v) (contractAstTimesK u w)
3480
- contractAstTimesK (AstTimesK u v) w =
3481
- contractAstTimesK u (contractAstTimesK v w)
3482
- -- With static shapes, the second argument to QuotOp and RemOp
3483
- -- is often a constant, which makes such rules worth including,
3484
- -- since they are likely to fire. To help them fire, we avoid changing
3485
- -- that constant, if possible, e.g., in rules for NegateOp.
3486
- contractAstTimesK
3487
- (AstConcreteK v)
3488
- (Ast. AstI2K QuotOp (Ast. AstVar ftk2 var)
3489
- (AstConcreteK v')) | v == v' =
3490
- Ast. AstVar ftk2 var
3491
- + contractAstNumOp1 NegateOp
3492
- (Ast. AstI2K RemOp (Ast. AstVar ftk2 var)
3493
- (AstConcreteK v))
3494
- contractAstTimesK u v = AstTimesK u v
3495
-
3496
3436
contractAstIntegralOp2 :: (GoodScalar r , AstSpan s , IntegralF r )
3497
3437
=> OpCodeIntegral2
3498
3438
-> AstTensor AstMethodLet s (TKScalar r )
@@ -3506,7 +3446,7 @@ contractAstIntegralOp2 QuotOp (Ast.AstI2K RemOp _u (AstConcreteK v))
3506
3446
(AstConcreteK v')
3507
3447
| v' >= v && v >= 0 = 0
3508
3448
contractAstIntegralOp2 QuotOp (Ast. AstI2K QuotOp u v) w =
3509
- contractAstIntegralOp2 QuotOp u (contractAstTimesK v w)
3449
+ contractAstIntegralOp2 QuotOp u (v * w)
3510
3450
contractAstIntegralOp2 QuotOp (AstTimesK (AstConcreteK u) v)
3511
3451
(AstConcreteK u')
3512
3452
| u == u' = v
@@ -3665,7 +3605,7 @@ substitute1Ast i var = subst where
3665
3605
let mu = subst u
3666
3606
mv = subst v
3667
3607
in if isJust mu || isJust mv
3668
- then Just $ contractAstTimesK ( fromMaybe u mu) ( fromMaybe v mv)
3608
+ then Just $ fromMaybe u mu * fromMaybe v mv
3669
3609
else Nothing
3670
3610
Ast. AstN1K opCode u -> (\ u2 -> contractAstNumOp1 opCode u2)
3671
3611
<$> subst u
0 commit comments