Skip to content

Commit 7ed4e58

Browse files
committed
Actually load some sound and label files
1 parent ebeb9c9 commit 7ed4e58

File tree

2 files changed

+87
-10
lines changed

2 files changed

+87
-10
lines changed

horde-ad.cabal

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,8 @@ library testLibrary
204204
-- Other library packages from which modules are imported.
205205
build-depends:
206206
base
207+
, cereal
208+
, bytestring
207209
, deepseq
208210
, HUnit-approx
209211
, hmatrix

test/common/TestSpeechRNN.hs

Lines changed: 85 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,21 @@ module TestSpeechRNN (testTrees, shortTestForCITrees) where
66

77
import Prelude
88

9+
import Control.Exception (assert)
910
import Control.Monad (foldM)
1011
import qualified Data.Array.DynamicS as OT
1112
import Data.Array.Internal (valueOf)
1213
import qualified Data.Array.ShapedS as OS
14+
import qualified Data.ByteString.Lazy as LBS
1315
import Data.List (foldl', unfoldr)
1416
import Data.Proxy (Proxy (Proxy))
17+
import Data.Serialize
1518
import qualified Data.Vector.Generic as V
1619
import GHC.TypeLits (KnownNat)
17-
import Numeric.LinearAlgebra (Matrix, Vector)
20+
import Numeric.LinearAlgebra (Matrix, Numeric, Vector)
1821
import qualified Numeric.LinearAlgebra as HM
19-
import System.IO (hPutStrLn, stderr)
22+
import System.IO
23+
(IOMode (ReadMode), hPutStrLn, stderr, withBinaryFile)
2024
import System.Random
2125
import Test.Tasty
2226
import Test.Tasty.HUnit hiding (assert)
@@ -37,34 +41,97 @@ shortTestForCITrees = [ speechRNNTestsShort
3741
]
3842

3943

40-
type SpeechDataBatchS batch_size window_size n_labels r =
44+
-- | One batch of speech data: a pair of shaped arrays that share the same
-- leading @batch_size@ dimension. The first component has
-- @window_size@ elements per row, the second has @n_labels@ elements per row.
type SpeechDataBatch batch_size window_size n_labels r =
  (OS.Array '[batch_size, window_size] r, OS.Array '[batch_size, n_labels] r)
4347

44-
speechTestCaseRNNS
48+
-- | Split a list into consecutive chunks of @n@ elements each.
--
-- The final chunk is shorter than @n@ when the list length is not a multiple
-- of @n@. An empty list yields no chunks, whatever the value of @n@.
--
-- A non-positive @n@ applied to a non-empty list raises an error: 'splitAt'
-- would never consume any input in that case, so the result would be an
-- endless list of empty chunks and any consumer computing its length would
-- hang.
chunksOf :: Int -> [e] -> [[e]]
chunksOf n = unfoldr step
 where
  step [] = Nothing
  step xs
    | n <= 0 = error "chunksOf: chunk size must be positive"
    | otherwise = Just (splitAt n xs)
53+
54+
-- | Decode two raw byte streams (sounds and labels) into a list of batches.
--
-- Each stream is turned into a flat list of @r@ by prepending the serialized
-- element count and decoding the result as a list, then cut into chunks of
-- one batch each. The element count is computed from @len@ as
-- @len * chunkSize `div` batch_size@; presumably @len@ is the number of
-- samples stored in each file -- TODO confirm against the data files.
--
-- The last chunk is thrown away if smaller than batch size.
-- It crashes if the size of either file doesn't match the other
-- (the checks are assertions, so they only fire when assertions are enabled).
-- A stream that cannot be decoded aborts with 'error'.
--
-- TODO: perhaps then warn instead of failing an assertion.
-- TODO: perhaps warn about the last chunk, too.
-- TODO: this could be so much more elegant, e.g., if OS.fromList
-- returned the remaining list and so no manual size calculations would
-- be required.
-- TODO: performance, see https://github.com/schrammc/mnist-idx/blob/master/src/Data/IDX/Internal.hs
decodeSpeechData
  :: forall batch_size window_size n_labels r.
     ( Serialize r, Numeric r
     , KnownNat batch_size, KnownNat window_size, KnownNat n_labels )
  => Int -> LBS.ByteString -> LBS.ByteString
  -> [SpeechDataBatch batch_size window_size n_labels r]
decodeSpeechData len soundsBs labelsBs =
  let batchSize :: Int
      batchSize = valueOf @batch_size
      -- Number of scalar elements making up one batch in each stream.
      soundsChunkSize = batchSize * valueOf @window_size
      labelsChunkSize = batchSize * valueOf @n_labels
      -- The byte lengths of the two streams must be in the same proportion
      -- as their per-batch element counts.
      !_A1 = assert (fromIntegral (LBS.length soundsBs) * labelsChunkSize
                     == fromIntegral (LBS.length labelsBs) * soundsChunkSize) ()
      -- Decode a whole stream and cut it into full-size chunks.
      cutBs :: Int -> LBS.ByteString -> [[r]]
      cutBs chunkSize bs =
        let elemCount :: Int
            elemCount = len * chunkSize `div` batchSize
            -- The raw stream carries no header, so the element count is
            -- serialized in front of it to make it decodable as a list.
            decoded :: [r]
            decoded =
              case decodeLazy (LBS.append (encodeLazy elemCount) bs) of
                Left err -> error ("decodeSpeechData: " ++ err)
                Right l -> l
        in filter (\ch -> length ch >= chunkSize)
           $ chunksOf chunkSize decoded
      soundsChunks :: [[r]]
      soundsChunks = cutBs soundsChunkSize soundsBs
      labelsChunks :: [[r]]
      labelsChunks = cutBs labelsChunkSize labelsBs
      !_A2 = assert (not (null soundsChunks)) ()
      !_A3 = assert (length soundsChunks == length labelsChunks) ()
      makeSpeechDataBatch
        :: [r] -> [r] -> SpeechDataBatch batch_size window_size n_labels r
      makeSpeechDataBatch soundsCh labelsCh =
        (OS.fromList soundsCh, OS.fromList labelsCh)
  in zipWith makeSpeechDataBatch soundsChunks labelsChunks
93+
94+
-- | Read a sounds file and a labels file in binary mode and decode them
-- into batches with 'decodeSpeechData'; @len@ is passed through unchanged.
--
-- 'LBS.hGetContents' reads lazily, while 'withBinaryFile' closes the handle
-- as soon as the inner action returns. Both streams are therefore forced
-- completely (by computing their lengths) before leaving the handle scopes,
-- so that reading never depends on whether assertions are compiled in or
-- on how much of the result happens to be evaluated.
loadSpeechData
  :: forall batch_size window_size n_labels r.
     ( Serialize r, Numeric r
     , KnownNat batch_size, KnownNat window_size, KnownNat n_labels )
  => Int -> FilePath -> FilePath
  -> IO [SpeechDataBatch batch_size window_size n_labels r]
loadSpeechData len soundsPath labelsPath =
  withBinaryFile soundsPath ReadMode $ \soundsHandle ->
    withBinaryFile labelsPath ReadMode $ \labelsHandle -> do
      soundsContents <- LBS.hGetContents soundsHandle
      labelsContents <- LBS.hGetContents labelsHandle
      let soundsLen = LBS.length soundsContents
          labelsLen = LBS.length labelsContents
          !_A1 = assert (soundsLen > 0) ()
      -- Forcing the lengths pulls both files fully into memory while the
      -- handles are still open.
      soundsLen `seq` labelsLen `seq`
        (return $! decodeSpeechData len soundsContents labelsContents)
107+
108+
-- | Placeholder for an RNN speech training test case.
--
-- Only the test name (@prefix@) is used at present: every other argument is
-- ignored and the resulting test checks the trivial equality @1.0 == 1.0@.
speechTestCaseRNN
  :: forall out_width batch_size window_size n_labels d r m.
     ( KnownNat out_width, KnownNat batch_size, KnownNat window_size
     , KnownNat n_labels
     , r ~ Double, d ~ 'DModeGradient, m ~ DualMonadGradient Double )
  => String
  -> Int
  -> Int
  -> (forall out_width' batch_size' window_size' n_labels'.
      ( DualMonad d r m, KnownNat out_width', KnownNat batch_size'
      , KnownNat n_labels' )
      => Proxy out_width'
      -> SpeechDataBatch batch_size' window_size' n_labels' r
      -> DualNumberVariables d r
      -> m (DualNumber d r))
  -> (forall out_width' batch_size' window_size' n_labels'.
      ( IsScalar d r, KnownNat out_width', KnownNat batch_size'
      , KnownNat n_labels' )
      => Proxy out_width'
      -> SpeechDataBatch batch_size' window_size' n_labels' r
      -> Domains r
      -> r)
  -> (forall out_width'. KnownNat out_width'
      => Proxy out_width' -> (Int, [Int], [(Int, Int)], [OT.ShapeL]))
  -> Double
  -> TestTree
speechTestCaseRNN
  prefix _epochs _maxBatches _trainWithLoss _ftest _flen _expected =
    -- No training happens yet; the test body is a trivially true check.
    testCase prefix (1.0 @?= 1.0)
70137

@@ -74,4 +141,12 @@ mnistRNNTestsLong = testGroup "Speech RNN long tests"
74141

75142
-- | Short speech tests: load one pair of sound/label files and check that
-- the expected number of full batches comes back.
--
-- NOTE(review): the data files are referenced by absolute paths inside one
-- developer's home directory, so this test presumably fails on any other
-- machine (including CI) -- confirm and make the location configurable.
speechRNNTestsShort :: TestTree
speechRNNTestsShort = testGroup "Speech RNN short tests"
  [ testCase "Load and sanity check speech" $ do
      let nSamples :: Int
          nSamples = 859
      speechDataBatchList <-
        loadSpeechData
          @64 @257 @1 @Float
          nSamples
          "/home/mikolaj/Downloads/volleyball.float32.257.spectrogram.bin"
          "/home/mikolaj/Downloads/volleyball.float32.1.rms.bin"
      -- Only complete batches of 64 samples are kept by the loader.
      length speechDataBatchList @?= nSamples `div` 64
  ]

0 commit comments

Comments
 (0)