1
+ import argparse
2
+ import collections
3
+ import os
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch .utils .data as data_utils
8
+
9
+ import pandas as pd
10
+
11
+ from deconv .gmm .sgd_deconv_gmm import SGDDeconvGMM
12
+ from deconv .gmm .sgd_gmm import SGDGMM
13
+ from deconv .gmm .data import DeconvDataset
14
+ from deconv .flow .svi import SVIFlow
15
+
16
+ parser = argparse .ArgumentParser (description = 'Process SVI on GMM results.' )
17
+
18
+ parser .add_argument ('pretrain_results_dir' )
19
+ parser .add_argument ('elbo_results_dir' )
20
+ parser .add_argument ('iw_results_dir' )
21
+
22
+ args = parser .parse_args ()
23
+
24
+ K = 4
25
+ D = 2
26
+ N = 50000
27
+
28
+ torch .set_default_tensor_type (torch .FloatTensor )
29
+
30
+ ref_gmm = SGDDeconvGMM (
31
+ K ,
32
+ D ,
33
+ batch_size = 512 ,
34
+ device = torch .device ('cpu' )
35
+ )
36
+
37
+ ref_gmm .module .soft_weights .data = torch .zeros (K )
38
+ scale = 2
39
+
40
+ ref_gmm .module .means .data = torch .Tensor ([
41
+ [- scale , 0 ],
42
+ [scale , 0 ],
43
+ [0 , - scale ],
44
+ [0 , scale ]
45
+ ])
46
+
47
+ short_std = 0.3
48
+ long_std = 1
49
+
50
+ stds = torch .Tensor ([
51
+ [short_std , long_std ],
52
+ [short_std , long_std ],
53
+ [long_std , short_std ],
54
+ [long_std , short_std ]
55
+ ])
56
+
57
+ ref_gmm .module .l_diag .data = torch .log (stds )
58
+
59
+ torch .manual_seed (263568 )
60
+
61
+ z_test = ref_gmm .sample_prior (N )
62
+
63
+ noise_short = 0.1
64
+ noise_long = 1.0
65
+
66
+ S = torch .Tensor ([
67
+ [noise_short , 0 ],
68
+ [0 , noise_long ]
69
+ ])
70
+
71
+ noise_distribution = torch .distributions .MultivariateNormal (
72
+ loc = torch .Tensor ([0 , 0 ]),
73
+ covariance_matrix = S
74
+ )
75
+
76
+ x_test = z_test + noise_distribution .sample ([N ])
77
+ test_data = DeconvDataset (x_test .squeeze (), S .repeat (N , 1 , 1 ))
78
+
79
+ pretrained_params = []
80
+ for f in os .listdir (args .pretrain_results_dir ):
81
+ path = os .path .join (args .pretrain_results_dir , f )
82
+ pretrained_params .append (path )
83
+
84
+ elbo_params = []
85
+ for f in os .listdir (args .elbo_results_dir ):
86
+ path = os .path .join (args .elbo_results_dir , f )
87
+ elbo_params .append (path )
88
+
89
+ iw_params = []
90
+ for f in os .listdir (args .iw_results_dir ):
91
+ path = os .path .join (args .iw_results_dir , f )
92
+ iw_params .append (path )
93
+
94
+ svi = SVIFlow (
95
+ 2 ,
96
+ 5 ,
97
+ device = torch .device ('cuda' ),
98
+ batch_size = 512 ,
99
+ epochs = 100 ,
100
+ lr = 1e-4 ,
101
+ n_samples = 50 ,
102
+ use_iwae = False ,
103
+ context_size = 64 ,
104
+ hidden_features = 128
105
+ )
106
+
107
+ results = []
108
+
109
+ test_data = DeconvDataset (x_test .squeeze (), torch .cholesky (S .repeat (N , 1 , 1 )))
110
+
111
+ torch .set_default_tensor_type (torch .cuda .FloatTensor )
112
+
113
+ for p in pretrained_params :
114
+ svi .model .load_state_dict (torch .load (p ))
115
+ with torch .no_grad ():
116
+ logv = svi .model ._prior .log_prob (z_test [0 ].to (torch .device ('cuda' ))).mean ().item ()
117
+ elbo = svi .score_batch (test_data , num_samples = 100 ) / N
118
+ logp = svi .score_batch (test_data , num_samples = 100 , log_prob = True ) / N
119
+
120
+ results .append ({
121
+ 'i' : 50 ,
122
+ 'model' : 'pretrained' ,
123
+ 'elbo' : elbo ,
124
+ 'log_p_v' : logv ,
125
+ 'log_p_w' : logp ,
126
+ 'kl' : logp - elbo
127
+ })
128
+
129
+ for p in elbo_params :
130
+ svi .model .load_state_dict (torch .load (p ))
131
+ with torch .no_grad ():
132
+ logv = svi .model ._prior .log_prob (z_test [0 ].to (torch .device ('cuda' ))).mean ().item ()
133
+ elbo = svi .score_batch (test_data , num_samples = 100 ) / N
134
+ logp = svi .score_batch (test_data , num_samples = 100 , log_prob = True ) / N
135
+
136
+ results .append ({
137
+ 'i' : 50 ,
138
+ 'model' : 'svi_elbo' ,
139
+ 'elbo' : elbo ,
140
+ 'log_p_v' : logv ,
141
+ 'log_p_w' : logp ,
142
+ 'kl' : logp - elbo
143
+ })
144
+
145
+ for p in iw_params :
146
+ svi .model .load_state_dict (torch .load (p ))
147
+ with torch .no_grad ():
148
+ logv = svi .model ._prior .log_prob (z_test [0 ].to (torch .device ('cuda' ))).mean ().item ()
149
+ elbo = svi .score_batch (test_data , num_samples = 100 ) / N
150
+ logp = svi .score_batch (test_data , num_samples = 100 , log_prob = True ) / N
151
+
152
+ results .append ({
153
+ 'i' : 50 ,
154
+ 'model' : 'svi_iw' ,
155
+ 'elbo' : elbo ,
156
+ 'log_p_v' : logv ,
157
+ 'log_p_w' : logp ,
158
+ 'kl' : logp - elbo
159
+ })
160
+
161
+ df = pd .DataFrame (results )
162
+
163
+
164
+
165
+
0 commit comments