Skip to content

Commit 05f80bd

Browse files
committed
Give mnistTrainBench2VTO the same type signature as mnistTestBench2VTA has
1 parent 40434aa commit 05f80bd

File tree

1 file changed

+21
-6
lines changed

1 file changed

+21
-6
lines changed

bench/common/BenchMnistTools.hs

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ mnistBGroup2VTA chunkLength =
299299
\ xs ->
300300
bgroup ("2-hidden-layer rank 2 VTA MNIST nn with samples: "
301301
++ show chunkLength)
302-
[ mnistTestBench2VTA "30|10 "30 10 0.02 chunkLength xs
302+
[ mnistTestBench2VTA "30|10 " 30 10 0.02 chunkLength xs
303303
, mnistTrainBench2VTA "30|10 " 30 10 0.02 chunkLength xs
304304
, mnistTestBench2VTA "300|100 " 300 100 0.02 chunkLength xs
305305
, mnistTrainBench2VTA "300|100 " 300 100 0.02 chunkLength xs
@@ -336,7 +336,7 @@ mnistBGroup2VTC chunkLength =
336336
]
337337

338338
-- The same as above, but only runtime.
339-
mnistTrainBench2VTO
339+
mnistTrainBench2VTOO
340340
:: forall r. r ~ Double
341341
=> String
342342
-> Double -> Int -> [MnistDataLinearR r]
@@ -348,7 +348,7 @@ mnistTrainBench2VTO
348348
(TKR2 1 (TKScalar Double))))
349349
(TKScalar r) )
350350
-> Benchmark
351-
mnistTrainBench2VTO prefix gamma batchSize xs (targetInit, art) = do
351+
mnistTrainBench2VTOO prefix gamma batchSize xs (targetInit, art) = do
352352
let go :: [MnistDataLinearR r] -> Concrete (XParams2 r Float)
353353
-> Concrete (XParams2 r Float)
354354
go [] parameters = parameters
@@ -367,6 +367,21 @@ mnistTrainBench2VTO prefix gamma batchSize xs (targetInit, art) = do
367367
, "=" ++ show (tsize knownSTK targetInit) ]
368368
bench name $ nf gradf chunk
369369

370+
-- The same as above, but both compilation time and only runtime.
371+
mnistTrainBench2VTO
372+
:: forall r. r ~ Double
373+
=> String
374+
-> Int -> Int -> Double -> Int -> [MnistDataLinearR r]
375+
-> Benchmark
376+
mnistTrainBench2VTO prefix widthHidden widthHidden2
377+
gamma batchSize xs =
378+
let (!targetInit, !artRaw) =
379+
MnistFcnnRanked2.mnistTrainBench2VTOGradient
380+
@Double (Proxy @Float) IgnoreIncomingCotangent
381+
1 (mkStdGen 44) widthHidden widthHidden2
382+
!art = simplifyArtifactGradient artRaw
383+
in mnistTrainBench2VTOO prefix gamma batchSize xs (targetInit, art)
384+
370385
mnistBGroup2VTO :: Int -> Benchmark
371386
mnistBGroup2VTO chunkLength =
372387
let (!targetInit, !artRaw) =
@@ -381,7 +396,7 @@ mnistBGroup2VTO chunkLength =
381396
\ xs ->
382397
bgroup ("2-hidden-layer rank 2 VTO runtime MNIST nn with samples: "
383398
++ show chunkLength)
384-
[ mnistTrainBench2VTO "1500|500 " 0.02 chunkLength xs (targetInit, art)
399+
[ mnistTrainBench2VTOO "1500|500 " 0.02 chunkLength xs (targetInit, art)
385400
]
386401

387402
-- The same as above, but without simplifying the gradient.
@@ -429,7 +444,7 @@ mnistBGroup2VTOZ chunkLength =
429444
\ xs ->
430445
bgroup ("2-hidden-layer rank 2 VTOZ runtime MNIST nn with samples: "
431446
++ show chunkLength)
432-
[ mnistTrainBench2VTO "1500|500 " 0.02 chunkLength xs (targetInit, art)
447+
[ mnistTrainBench2VTOO "1500|500 " 0.02 chunkLength xs (targetInit, art)
433448
]
434449

435450
-- The same as above, but without any simplification, even the smart
@@ -478,7 +493,7 @@ mnistBGroup2VTOX chunkLength =
478493
\ xs ->
479494
bgroup ("2-hidden-layer rank 2 VTOX runtime MNIST nn with samples: "
480495
++ show chunkLength)
481-
[ mnistTrainBench2VTO "1500|500 " 0.02 chunkLength xs (targetInit, art)
496+
[ mnistTrainBench2VTOO "1500|500 " 0.02 chunkLength xs (targetInit, art)
482497
]
483498

484499
{- TODO: re-enable once -fpolymorphic-specialisation works

0 commit comments

Comments
 (0)