@@ -6,17 +6,21 @@ module TestSpeechRNN (testTrees, shortTestForCITrees) where
6
6
7
7
import Prelude
8
8
9
+ import Control.Exception (assert )
9
10
import Control.Monad (foldM )
10
11
import qualified Data.Array.DynamicS as OT
11
12
import Data.Array.Internal (valueOf )
12
13
import qualified Data.Array.ShapedS as OS
14
+ import qualified Data.ByteString.Lazy as LBS
13
15
import Data.List (foldl' , unfoldr )
14
16
import Data.Proxy (Proxy (Proxy ))
17
+ import Data.Serialize
15
18
import qualified Data.Vector.Generic as V
16
19
import GHC.TypeLits (KnownNat )
17
- import Numeric.LinearAlgebra (Matrix , Vector )
20
+ import Numeric.LinearAlgebra (Matrix , Numeric , Vector )
18
21
import qualified Numeric.LinearAlgebra as HM
19
- import System.IO (hPutStrLn , stderr )
22
+ import System.IO
23
+ (IOMode (ReadMode ), hPutStrLn , stderr , withBinaryFile )
20
24
import System.Random
21
25
import Test.Tasty
22
26
import Test.Tasty.HUnit hiding (assert )
@@ -37,34 +41,97 @@ shortTestForCITrees = [ speechRNNTestsShort
37
41
]
38
42
39
43
40
- type SpeechDataBatchS batch_size window_size n_labels r =
44
+ type SpeechDataBatch batch_size window_size n_labels r =
41
45
( OS. Array '[batch_size , window_size ] r
42
46
, OS. Array '[batch_size , n_labels ] r )
43
47
44
- speechTestCaseRNNS
48
+ chunksOf :: Int -> [e ] -> [[e ]]
49
+ chunksOf n = go where
50
+ go [] = []
51
+ go l = let (chunk, rest) = splitAt n l
52
+ in chunk : go rest
53
+
54
+ -- The last chunk is thrown away if smaller than batch size.
55
+ -- It crashes if the size of either file doesn't match the other.
56
+ -- TODO: perhaps then warn instead of failing an assertion.
57
+ -- TODO: perhaps warn about the last chunk, too.
58
+ -- TODO: this could be so much more elegant, e.g., if OS.fromList
59
+ -- returned the remaining list and so no manual size calculations would
60
+ -- be required.
61
+ -- TODO: performance, see https://github.com/schrammc/mnist-idx/blob/master/src/Data/IDX/Internal.hs
62
+ decodeSpeechData
63
+ :: forall batch_size window_size n_labels r .
64
+ ( Serialize r , Numeric r
65
+ , KnownNat batch_size , KnownNat window_size , KnownNat n_labels )
66
+ => Int -> LBS. ByteString -> LBS. ByteString
67
+ -> [SpeechDataBatch batch_size window_size n_labels r ]
68
+ decodeSpeechData len soundsBs labelsBs =
69
+ let soundsChunkSize = valueOf @ batch_size * valueOf @ window_size
70
+ labelsChunkSize = valueOf @ batch_size * valueOf @ n_labels
71
+ ! _A1 = assert (fromIntegral (LBS. length soundsBs) * labelsChunkSize
72
+ == fromIntegral (LBS. length labelsBs) * soundsChunkSize) ()
73
+ cutBs :: Int -> LBS. ByteString -> [[r ]]
74
+ cutBs chunkSize bs =
75
+ let list :: [r ] =
76
+ case decodeLazy
77
+ $ LBS. append (encodeLazy
78
+ $ len * chunkSize `div` valueOf @ batch_size )
79
+ bs of
80
+ Left err -> error err
81
+ Right l -> l
82
+ in filter (\ ch -> length ch >= chunkSize)
83
+ $ chunksOf chunkSize list
84
+ soundsChunks :: [[r ]] = cutBs soundsChunkSize soundsBs
85
+ labelsChunks :: [[r ]] = cutBs labelsChunkSize labelsBs
86
+ ! _A2 = assert (length soundsChunks > 0 ) ()
87
+ ! _A3 = assert (length soundsChunks == length labelsChunks) ()
88
+ makeSpeechDataBatch
89
+ :: [r ] -> [r ] -> SpeechDataBatch batch_size window_size n_labels r
90
+ makeSpeechDataBatch soundsCh labelsCh =
91
+ (OS. fromList soundsCh, OS. fromList labelsCh)
92
+ in zipWith makeSpeechDataBatch soundsChunks labelsChunks
93
+
94
+ loadSpeechData
95
+ :: forall batch_size window_size n_labels r .
96
+ ( Serialize r , Numeric r
97
+ , KnownNat batch_size , KnownNat window_size , KnownNat n_labels )
98
+ => Int -> FilePath -> FilePath
99
+ -> IO [SpeechDataBatch batch_size window_size n_labels r ]
100
+ loadSpeechData len soundsPath labelsPath =
101
+ withBinaryFile soundsPath ReadMode $ \ soundsHandle ->
102
+ withBinaryFile labelsPath ReadMode $ \ labelsHandle -> do
103
+ soundsContents <- LBS. hGetContents soundsHandle
104
+ labelsContents <- LBS. hGetContents labelsHandle
105
+ let ! _A1 = assert (LBS. length soundsContents > 0 ) ()
106
+ return $! decodeSpeechData len soundsContents labelsContents
107
+
108
+ speechTestCaseRNN
45
109
:: forall out_width batch_size window_size n_labels d r m .
46
110
( KnownNat out_width , KnownNat batch_size , KnownNat window_size
111
+ , KnownNat n_labels
47
112
, r ~ Double , d ~ 'DModeGradient, m ~ DualMonadGradient Double )
48
113
=> String
49
114
-> Int
50
115
-> Int
51
116
-> (forall out_width' batch_size' window_size' n_labels' .
52
- (DualMonad d r m , KnownNat out_width' , KnownNat batch_size' )
117
+ ( DualMonad d r m , KnownNat out_width' , KnownNat batch_size'
118
+ , KnownNat n_labels' )
53
119
=> Proxy out_width'
54
- -> SpeechDataBatchS batch_size' window_size' n_labels' r
120
+ -> SpeechDataBatch batch_size' window_size' n_labels' r
55
121
-> DualNumberVariables d r
56
122
-> m (DualNumber d r ))
57
123
-> (forall out_width' batch_size' window_size' n_labels' .
58
- (IsScalar d r , KnownNat out_width' , KnownNat batch_size' )
124
+ ( IsScalar d r , KnownNat out_width' , KnownNat batch_size'
125
+ , KnownNat n_labels' )
59
126
=> Proxy out_width'
60
- -> SpeechDataBatchS batch_size' window_size' n_labels' r
127
+ -> SpeechDataBatch batch_size' window_size' n_labels' r
61
128
-> Domains r
62
129
-> r )
63
130
-> (forall out_width' . KnownNat out_width'
64
131
=> Proxy out_width' -> (Int , [Int ], [(Int , Int )], [OT. ShapeL ]))
65
132
-> Double
66
133
-> TestTree
67
- speechTestCaseRNNS prefix epochs maxBatches trainWithLoss ftest flen expected =
134
+ speechTestCaseRNN prefix epochs maxBatches trainWithLoss ftest flen expected =
68
135
testCase prefix $
69
136
1.0 @?= 1.0
70
137
@@ -74,4 +141,12 @@ mnistRNNTestsLong = testGroup "Speech RNN long tests"
74
141
75
142
speechRNNTestsShort :: TestTree
76
143
speechRNNTestsShort = testGroup " Speech RNN short tests"
77
- []
144
+ [ testCase " Load and sanity check speech" $ do
145
+ speechDataBatchList <-
146
+ loadSpeechData
147
+ @ 64 @ 257 @ 1 @ Float
148
+ 859
149
+ " /home/mikolaj/Downloads/volleyball.float32.257.spectrogram.bin"
150
+ " /home/mikolaj/Downloads/volleyball.float32.1.rms.bin"
151
+ length speechDataBatchList @?= 859 `div` 64
152
+ ]
0 commit comments