Skip to content

Commit f5ce66a

Browse files
committed
Don't make it too hard varying block size between training and testing data
1 parent f646172 commit f5ce66a

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

test/common/TestSpeechRNN.hs

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,9 @@ speechTestCaseRNN prefix epochs maxBatches trainWithLoss ftest flen expected =
289289
"/home/mikolaj/Downloads/spj_how_ai_really.float32.1.rms.bin"
290290
testData <-
291291
loadSpeechData
292-
@85 @10 @257 @2 -- the single batch covers the whole dataset
292+
@8 @block_size @window_size @n_labels
293+
-- With blocks size 100, this single batch covers most of the dataset.
294+
-- TODO: with block size 1, this results in tiny test data.
293295
"/home/mikolaj/Downloads/volleyball.float32.257.spectrogram.bin"
294296
"/home/mikolaj/Downloads/volleyball.float32.1.rms.bin"
295297
let testDataBatch = head testData
@@ -307,7 +309,7 @@ speechTestCaseRNN prefix epochs maxBatches trainWithLoss ftest flen expected =
307309
!testScore = ftest proxy_out_width testDataBatch parameters2
308310
!lenBatch = length batch
309311
hPutStrLn stderr $ printf "\n%s: (Batch %d with %d mini-batches)" prefix k lenBatch
310-
hPutStrLn stderr $ printf "%s: First batch training error: %.2f%%" prefix ((1 - trainScore) * 100)
312+
hPutStrLn stderr $ printf "%s: First mini-batch training error: %.2f%%" prefix ((1 - trainScore) * 100)
311313
hPutStrLn stderr $ printf "%s: Validation error: %.2f%%" prefix ((1 - testScore ) * 100)
312314
return res
313315
runEpoch :: Int -> (Domains r, StateAdam r) -> IO (Domains r)
@@ -341,19 +343,19 @@ mnistRNNTestsLong = testGroup "Speech RNN long tests"
341343
, speechTestCaseRNN @128 @64 @100 @257 @2
342344
"1 epoch, all batches" 1 9999
343345
rnnSpeechLossFused rnnSpeechTest rnnSpeechLen
344-
0.49411762
346+
0.25
345347
, speechTestCaseRNN @128 @64 @1 @257 @2
346348
"1 epoch, all batches, 1-wide blocks" 1 9999
347349
rnnSpeechLossFused rnnSpeechTest rnnSpeechLen
348-
0.19999999
350+
0.0
349351
, speechTestCaseRNN @128 @64 @100 @257 @2
350352
"10 epochs, all batches" 10 9999
351353
rnnSpeechLossFused rnnSpeechTest rnnSpeechLen
352-
0
354+
0 -- TODO
353355
, speechTestCaseRNN @128 @64 @1 @257 @2
354356
"10 epochs, all batches, 1-wide blocks" 10 9999
355357
rnnSpeechLossFused rnnSpeechTest rnnSpeechLen
356-
0
358+
0 -- TODO
357359
]
358360

359361
speechRNNTestsShort :: TestTree
@@ -369,7 +371,7 @@ speechRNNTestsShort = testGroup "Speech RNN short tests"
369371
, testCase "Load and sanity check testing speech files" $ do
370372
speechDataBatchList <-
371373
loadSpeechData
372-
@85 @10 @257 @2 @Float
374+
@8 @100 @257 @2 @Float
373375
"/home/mikolaj/Downloads/volleyball.float32.257.spectrogram.bin"
374376
"/home/mikolaj/Downloads/volleyball.float32.1.rms.bin"
375377
length speechDataBatchList @?= 1
@@ -379,5 +381,5 @@ speechRNNTestsShort = testGroup "Speech RNN short tests"
379381
maximum (map (OS.maximumA . snd) speechDataBatchList) @?= 1.0
380382
, speechTestCaseRNN @128 @64 @100 @257 @2 "1 epoch, 1 batch" 1 1
381383
rnnSpeechLossFused rnnSpeechTest rnnSpeechLen
382-
0.49411762
384+
0.25
383385
]

0 commit comments

Comments
 (0)