Skip to content

Commit ad33ed4

Browse files
committed
Add a shorter variant of benchmark benchProd
1 parent 6dc0a7d commit ad33ed4

File tree

1 file changed

+42
-13
lines changed

1 file changed

+42
-13
lines changed

bench/common/BenchProdTools.hs

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,17 +27,17 @@ import HordeAd.Core.Ops
2727
bgroup100, bgroup1000, bgroup1e4, bgroup1e5, bgroup1e6, bgroup1e7, bgroup5e7 :: [Double] -> Benchmark
2828
bgroup100 = envProd 100 $ \args -> bgroup "100" $ benchProd args
2929

30-
bgroup1000 = envProd 1000 $ \args -> bgroup "1000" $ benchProd args
30+
bgroup1000 = envProd 1000 $ \args -> bgroup "1000" $ benchProdShort args
3131

32-
bgroup1e4 = envProd 1e4 $ \args -> bgroup "1e4" $ benchProd args
32+
bgroup1e4 = envProd 1e4 $ \args -> bgroup "1e4" $ benchProdShort args
3333

34-
bgroup1e5 = envProd 1e5 $ \args -> bgroup "1e5" $ benchProd args
34+
bgroup1e5 = envProd 1e5 $ \args -> bgroup "1e5" $ benchProdShort args
3535

36-
bgroup1e6 = envProd 1e6 $ \args -> bgroup "1e6" $ benchProd args
36+
bgroup1e6 = envProd 1e6 $ \args -> bgroup "1e6" $ benchProdShort args
3737

38-
bgroup1e7 = envProd 1e7 $ \args -> bgroup "1e7" $ benchProd args
38+
bgroup1e7 = envProd 1e7 $ \args -> bgroup "1e7" $ benchProdShort args
3939

40-
bgroup5e7 = envProd 5e7 $ \args -> bgroup "5e7" $ benchProd args
40+
bgroup5e7 = envProd 5e7 $ \args -> bgroup "5e7" $ benchProdShort args
4141
-- 5e7 == 5 * 10^7 == 0.5 * 10^8 == 0.5e8
4242

4343
envProd :: r ~ Double
@@ -64,13 +64,14 @@ envProd rat f allxs =
6464
, sfromList . fromList $ lt) )
6565
(f @k)
6666

67-
benchProd :: r ~ Double
68-
=> ( SNat n
69-
, [Concrete (TKScalar r)]
70-
, ListR n (Concrete (TKScalar r))
71-
, ListR n (Concrete (TKS '[] r))
72-
, Concrete (TKS '[n] r) )
73-
-> [Benchmark]
67+
benchProd
68+
:: r ~ Double
69+
=> ( SNat n
70+
, [Concrete (TKScalar r)]
71+
, ListR n (Concrete (TKScalar r))
72+
, ListR n (Concrete (TKS '[] r))
73+
, Concrete (TKS '[n] r) )
74+
-> [Benchmark]
7475
benchProd ~(snat, list, l, lt, t) = case snat of
7576
SNat ->
7677
[ bench "cgrad s MapAccum" $ nf (crevSMapAccum snat) t
@@ -91,6 +92,34 @@ benchProd ~(snat, list, l, lt, t) = case snat of
9192
, bench "cgrad s NotShared" $ nf (crevSNotShared snat) lt
9293
]
9394

95+
benchProdShort
96+
:: r ~ Double
97+
=> ( SNat n
98+
, [Concrete (TKScalar r)]
99+
, ListR n (Concrete (TKScalar r))
100+
, ListR n (Concrete (TKS '[] r))
101+
, Concrete (TKS '[n] r) )
102+
-> [Benchmark]
103+
benchProdShort ~(snat, list, l, lt, t) = case snat of
104+
SNat ->
105+
-- [ bench "cgrad s MapAccum" $ nf (crevSMapAccum snat) t
106+
-- , bench "grad s MapAccum" $ nf (revSMapAccum snat) t
107+
[ bench "cgrad scalar MapAccum" $ nf (crevScalarMapAccum snat) t
108+
, bench "grad scalar MapAccum" $ nf (revScalarMapAccum snat) t
109+
-- , bench "cgrad scalar list" $ nf crevScalarList list
110+
-- , bench "grad scalar list" $ nf revScalarList list
111+
, bench "cgrad scalar L" $ nf (crevScalarL snat) l
112+
-- , bench "grad scalar L" $ nf (revScalarL snat) l
113+
, bench "cgrad scalar R" $ nf (crevScalarR snat) l
114+
-- , bench "grad scalar R" $ nf (revScalarR snat) l
115+
, bench "cgrad scalar NotShared" $ nf (crevScalarNotShared snat) l
116+
-- , bench "cgrad s L" $ nf (crevSL snat) lt
117+
-- , bench "grad s L" $ nf (revSL snat) lt
118+
-- , bench "cgrad s R" $ nf (crevSR snat) lt
119+
-- , bench "grad s R" $ nf (revSR snat) lt
120+
-- , bench "cgrad s NotShared" $ nf (crevSNotShared snat) lt
121+
]
122+
94123
-- Another variant, with foldl1' and indexing, would be a disaster.
95124
-- We can define sproduct if this benchmark ends up used anywhere,
96125
-- because the current codomain of gradientFromDelta rules out

0 commit comments

Comments
 (0)