@@ -299,7 +299,7 @@ mnistBGroup2VTA chunkLength =
299
299
\ xs ->
300
300
bgroup (" 2-hidden-layer rank 2 VTA MNIST nn with samples: "
301
301
++ show chunkLength)
302
- [ mnistTestBench2VTA " 30|10 " 30 10 0.02 chunkLength xs
302
+ [ mnistTestBench2VTA " 30|10 " 30 10 0.02 chunkLength xs
303
303
, mnistTrainBench2VTA " 30|10 " 30 10 0.02 chunkLength xs
304
304
, mnistTestBench2VTA " 300|100 " 300 100 0.02 chunkLength xs
305
305
, mnistTrainBench2VTA " 300|100 " 300 100 0.02 chunkLength xs
@@ -336,7 +336,7 @@ mnistBGroup2VTC chunkLength =
336
336
]
337
337
338
338
-- The same as above, but only runtime.
339
- mnistTrainBench2VTO
339
+ mnistTrainBench2VTOO
340
340
:: forall r . r ~ Double
341
341
=> String
342
342
-> Double -> Int -> [MnistDataLinearR r ]
@@ -348,7 +348,7 @@ mnistTrainBench2VTO
348
348
(TKR2 1 (TKScalar Double ))))
349
349
(TKScalar r ) )
350
350
-> Benchmark
351
- mnistTrainBench2VTO prefix gamma batchSize xs (targetInit, art) = do
351
+ mnistTrainBench2VTOO prefix gamma batchSize xs (targetInit, art) = do
352
352
let go :: [MnistDataLinearR r ] -> Concrete (XParams2 r Float )
353
353
-> Concrete (XParams2 r Float )
354
354
go [] parameters = parameters
@@ -367,6 +367,21 @@ mnistTrainBench2VTO prefix gamma batchSize xs (targetInit, art) = do
367
367
, " =" ++ show (tsize knownSTK targetInit) ]
368
368
bench name $ nf gradf chunk
369
369
370
+ -- The same as above, but both compilation time and only runtime.
371
+ mnistTrainBench2VTO
372
+ :: forall r . r ~ Double
373
+ => String
374
+ -> Int -> Int -> Double -> Int -> [MnistDataLinearR r ]
375
+ -> Benchmark
376
+ mnistTrainBench2VTO prefix widthHidden widthHidden2
377
+ gamma batchSize xs =
378
+ let (! targetInit, ! artRaw) =
379
+ MnistFcnnRanked2. mnistTrainBench2VTOGradient
380
+ @ Double (Proxy @ Float ) IgnoreIncomingCotangent
381
+ 1 (mkStdGen 44 ) widthHidden widthHidden2
382
+ ! art = simplifyArtifactGradient artRaw
383
+ in mnistTrainBench2VTOO prefix gamma batchSize xs (targetInit, art)
384
+
370
385
mnistBGroup2VTO :: Int -> Benchmark
371
386
mnistBGroup2VTO chunkLength =
372
387
let (! targetInit, ! artRaw) =
@@ -381,7 +396,7 @@ mnistBGroup2VTO chunkLength =
381
396
\ xs ->
382
397
bgroup (" 2-hidden-layer rank 2 VTO runtime MNIST nn with samples: "
383
398
++ show chunkLength)
384
- [ mnistTrainBench2VTO " 1500|500 " 0.02 chunkLength xs (targetInit, art)
399
+ [ mnistTrainBench2VTOO " 1500|500 " 0.02 chunkLength xs (targetInit, art)
385
400
]
386
401
387
402
-- The same as above, but without simplifying the gradient.
@@ -429,7 +444,7 @@ mnistBGroup2VTOZ chunkLength =
429
444
\ xs ->
430
445
bgroup (" 2-hidden-layer rank 2 VTOZ runtime MNIST nn with samples: "
431
446
++ show chunkLength)
432
- [ mnistTrainBench2VTO " 1500|500 " 0.02 chunkLength xs (targetInit, art)
447
+ [ mnistTrainBench2VTOO " 1500|500 " 0.02 chunkLength xs (targetInit, art)
433
448
]
434
449
435
450
-- The same as above, but without any simplification, even the smart
@@ -478,7 +493,7 @@ mnistBGroup2VTOX chunkLength =
478
493
\ xs ->
479
494
bgroup (" 2-hidden-layer rank 2 VTOX runtime MNIST nn with samples: "
480
495
++ show chunkLength)
481
- [ mnistTrainBench2VTO " 1500|500 " 0.02 chunkLength xs (targetInit, art)
496
+ [ mnistTrainBench2VTOO " 1500|500 " 0.02 chunkLength xs (targetInit, art)
482
497
]
483
498
484
499
{- TODO: re-enable once -fpolymorphic-specialisation works
0 commit comments