Skip to content

Commit 885067e

Browse files
committed
WIP update.
1 parent 096d74c commit 885067e

20 files changed

+1185
-432
lines changed

analysis/analyse_mixture_compare.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
import argparse
2+
import collections
3+
import os
4+
5+
import numpy as np
6+
import torch
7+
import torch.utils.data as data_utils
8+
9+
import pandas as pd
10+
11+
from deconv.gmm.sgd_deconv_gmm import SGDDeconvGMM
12+
from deconv.gmm.sgd_gmm import SGDGMM
13+
from deconv.gmm.data import DeconvDataset
14+
from deconv.flow.svi import SVIFlow
15+
from deconv.utils.data_gen import generate_mixture_data
16+
17+
# Command-line interface: three positional arguments, each naming a directory
# whose entries are listed further down and loaded with torch.load.
parser = argparse.ArgumentParser(description='Process SVI on GMM results.')
for _dir_arg in ('gmm_results_dir', 'elbo_results_dir', 'iw_results_dir'):
    parser.add_argument(_dir_arg)

args = parser.parse_args()
24+
25+
# K and D are handed to the GMM constructors below; N is the repeat count for
# S and the divisor applied to every batch score.
K, D, N = 4, 2, 50000

# Create tensors as CPU floats while the test set is assembled.
torch.set_default_tensor_type(torch.FloatTensor)

# The generator returns five items; only the second (S) and the fifth, a
# (z_test, x_test) pair, are used by this script.
# NOTE(review): S is presumably the noise covariance - confirm in data_gen.
_, S, _, _, _held_out = generate_mixture_data()
z_test, x_test = _held_out

# One copy of S per datapoint.
_S_per_point = S.repeat(N, 1, 1)
test_data = DeconvDataset(x_test.squeeze(), _S_per_point)
33+
34+
# Every entry of the GMM results directory is treated as one saved state dict.
gmm_params = [
    os.path.join(args.gmm_results_dir, name)
    for name in os.listdir(args.gmm_results_dir)
]

# SVI checkpoints are grouped by the integer parsed from characters 16-17 of
# the file name; the evaluation loop looks the groups up with 1, 10, 25 and 50.
# NOTE(review): the fixed slice assumes a particular file-name layout - confirm
# against whatever script writes these files.
elbo_params = collections.defaultdict(list)
for name in os.listdir(args.elbo_results_dir):
    full_path = os.path.join(args.elbo_results_dir, name)
    group = int(name[16:18])
    elbo_params[group].append(full_path)

iw_params = collections.defaultdict(list)
for name in os.listdir(args.iw_results_dir):
    full_path = os.path.join(args.iw_results_dir, name)
    group = int(name[16:18])
    iw_params[group].append(full_path)
51+
52+
# The two GMM wrappers are built with identical arguments; checkpoints are
# later loaded into the `.module` of each.
_gmm_args = (K, D)
_gmm_kwargs = {
    'batch_size': 200,
    'device': torch.device('cuda'),
}

gmm = SGDDeconvGMM(*_gmm_args, **_gmm_kwargs)
test_gmm = SGDGMM(*_gmm_args, **_gmm_kwargs)

# Flow model that the SVI checkpoints are loaded into.
# NOTE(review): the meaning of the two positional arguments (2, 5) is not
# visible here - check SVIFlow's signature.
_svi_kwargs = {
    'device': torch.device('cuda'),
    'batch_size': 512,
    'epochs': 100,
    'lr': 1e-4,
    'n_samples': 50,
    'use_iwae': False,
    'context_size': 64,
    'hidden_features': 128,
}
svi = SVIFlow(2, 5, **_svi_kwargs)
77+
78+
# One record per evaluated checkpoint; turned into a DataFrame at the end.
results = []

# z_test[0] does not change between checkpoints, so move it to the GPU once
# instead of once per checkpoint.
_gmm_device = torch.device('cuda')
_z_eval = z_test[0].to(_gmm_device)

for p in gmm_params:
    # Read the checkpoint from disk once and feed the same state dict to both
    # wrappers (the original deserialised the file twice).
    _state = torch.load(p)
    gmm.module.load_state_dict(_state)
    test_gmm.module.load_state_dict(_state)

    with torch.no_grad():
        # Mean output of the plain GMM module on the z test points.
        logv = test_gmm.module([_z_eval]).mean().item()
        # Batch score of the deconvolving GMM on the test set, divided by N.
        logp = gmm.score_batch(test_data) / N

    # 'elbo' and 'kl' are left empty for GMM rows; 'i' is fixed at 0.
    results.append({
        'i': 0,
        'model': 'gmm',
        'elbo': None,
        'log_p_v': logv,
        'log_p_w': logp,
        'kl': None
    })
94+
95+
# Rebuild the test set for the SVI scoring below, this time carrying the
# Cholesky factor of S instead of S itself.
# S is the same for every datapoint, so factor it once and tile the factor
# rather than factorising N stacked copies of the same matrix; the resulting
# tensor has the same shape and values.
_S_chol = torch.cholesky(S)
test_data = DeconvDataset(x_test.squeeze(), _S_chol.repeat(N, 1, 1))

# From here on, newly created tensors default to CUDA floats.
torch.set_default_tensor_type(torch.cuda.FloatTensor)
98+
99+
# Label -> checkpoints grouped by the integer parsed from their file names.
param_sets = dict(svi_elbo=elbo_params, svi_iw=iw_params)

_svi_device = torch.device('cuda')

for model_name, grouped_paths in param_sets.items():
    print('Processing {}'.format(model_name))
    for group_key in (1, 10, 25, 50):
        print('Processing K: {}'.format(group_key))
        for ckpt_path in grouped_paths[group_key]:
            svi.model.load_state_dict(torch.load(ckpt_path))

            with torch.no_grad():
                # Mean prior log-probability of the z test points.
                z_on_device = z_test[0].to(_svi_device)
                logv = svi.model._prior.log_prob(z_on_device).mean().item()
                # Two scores of the test set with 100 samples each: the
                # default score first, then the one with log_prob=True.
                elbo = svi.score_batch(test_data, num_samples=100) / N
                logp = svi.score_batch(
                    test_data, num_samples=100, log_prob=True
                ) / N

            # 'kl' is the gap between the two scores.
            record = {
                'i': group_key,
                'model': model_name,
                'elbo': elbo,
                'log_p_v': logv,
                'log_p_w': logp,
                'kl': logp - elbo
            }
            results.append(record)

# Tabulate all records (GMM rows first, then the SVI rows).
df = pd.DataFrame(results)
125+
126+
127+
128+

analysis/analyse_pretraining.py

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
import argparse
2+
import collections
3+
import os
4+
5+
import numpy as np
6+
import torch
7+
import torch.utils.data as data_utils
8+
9+
import pandas as pd
10+
11+
from deconv.gmm.sgd_deconv_gmm import SGDDeconvGMM
12+
from deconv.gmm.sgd_gmm import SGDGMM
13+
from deconv.gmm.data import DeconvDataset
14+
from deconv.flow.svi import SVIFlow
15+
16+
# Command-line interface: three positional arguments, each naming a directory
# whose entries are listed further down and loaded with torch.load.
parser = argparse.ArgumentParser(description='Process SVI on GMM results.')
for _dir_arg in ('pretrain_results_dir', 'elbo_results_dir', 'iw_results_dir'):
    parser.add_argument(_dir_arg)

args = parser.parse_args()
23+
24+
# K and D are handed to the reference GMM; N is the number of prior samples
# requested and the divisor applied to every batch score.
K, D, N = 4, 2, 50000

# Create tensors as CPU floats while the reference data is generated.
torch.set_default_tensor_type(torch.FloatTensor)

# Reference mixture whose parameters are set by hand below and which is then
# sampled to obtain z_test.
ref_gmm = SGDDeconvGMM(K, D, batch_size=512, device=torch.device('cpu'))

# All-zero soft weights.
# NOTE(review): presumably unnormalised logits, i.e. uniform mixing - confirm.
ref_gmm.module.soft_weights.data = torch.zeros(K)

# Four means at distance `scale` from the origin: left, right, below, above.
scale = 2
ref_gmm.module.means.data = torch.Tensor([
    [-scale, 0],
    [scale, 0],
    [0, -scale],
    [0, scale]
])

# The first two components are narrow along the first axis, the last two
# along the second axis.
short_std = 0.3
long_std = 1
_narrow_first = [short_std, long_std]
_narrow_second = [long_std, short_std]
stds = torch.Tensor([
    _narrow_first,
    _narrow_first,
    _narrow_second,
    _narrow_second
])

# l_diag is assigned the element-wise log of `stds`.
ref_gmm.module.l_diag.data = torch.log(stds)

# Fixed seed so the sample below (and the noise drawn afterwards) repeats.
torch.manual_seed(263568)

z_test = ref_gmm.sample_prior(N)
62+
63+
# Diagonal 2x2 noise covariance: small variance on the first axis, large on
# the second.
noise_short = 0.1
noise_long = 1.0
S = torch.diag(torch.Tensor([noise_short, noise_long]))

# Zero-mean Gaussian noise with covariance S.
noise_distribution = torch.distributions.MultivariateNormal(
    loc=torch.zeros(2),
    covariance_matrix=S
)

# Noisy observations: the prior samples plus N noise draws.
_noise = noise_distribution.sample([N])
x_test = z_test + _noise

# NOTE: further down this script rebuilds test_data with the Cholesky factor
# of S before anything scores it, so this covariance-based version is not
# read by the visible code.
test_data = DeconvDataset(x_test.squeeze(), S.repeat(N, 1, 1))
78+
79+
def _paths_in(directory):
    """Return the joined path of every entry of *directory*, in listing order."""
    return [os.path.join(directory, entry) for entry in os.listdir(directory)]


# Every entry of each results directory is treated as one saved state dict.
pretrained_params = _paths_in(args.pretrain_results_dir)
elbo_params = _paths_in(args.elbo_results_dir)
iw_params = _paths_in(args.iw_results_dir)
93+
94+
# Flow model that every checkpoint is loaded into before scoring.
# NOTE(review): the meaning of the two positional arguments (2, 5) is not
# visible here - check SVIFlow's signature.
_svi_kwargs = {
    'device': torch.device('cuda'),
    'batch_size': 512,
    'epochs': 100,
    'lr': 1e-4,
    'n_samples': 50,
    'use_iwae': False,
    'context_size': 64,
    'hidden_features': 128,
}
svi = SVIFlow(2, 5, **_svi_kwargs)
106+
107+
# One record per evaluated checkpoint; turned into a DataFrame at the end.
results = []

# Test set for the SVI scoring below, carrying the Cholesky factor of S
# instead of S itself.
# S is the same for every datapoint, so factor it once and tile the factor
# rather than factorising N stacked copies of the same matrix; the resulting
# tensor has the same shape and values.
_S_chol = torch.cholesky(S)
test_data = DeconvDataset(x_test.squeeze(), _S_chol.repeat(N, 1, 1))

# From here on, newly created tensors default to CUDA floats.
torch.set_default_tensor_type(torch.cuda.FloatTensor)
112+
113+
# z_test[0] does not change between checkpoints, so move it to the GPU once
# instead of once per checkpoint.
_eval_device = torch.device('cuda')
_z_eval = z_test[0].to(_eval_device)

# The three checkpoint families are scored in exactly the same way and differ
# only in the label written to the 'model' column, so one loop handles them
# all. The order (pretrained, ELBO, IW) matches the original three loops.
_labelled_params = (
    ('pretrained', pretrained_params),
    ('svi_elbo', elbo_params),
    ('svi_iw', iw_params),
)

for _label, _paths in _labelled_params:
    for p in _paths:
        svi.model.load_state_dict(torch.load(p))

        with torch.no_grad():
            # Mean prior log-probability of the z test points.
            logv = svi.model._prior.log_prob(_z_eval).mean().item()
            # Two scores of the test set with 100 samples each: the default
            # score first, then the one with log_prob=True.
            elbo = svi.score_batch(test_data, num_samples=100) / N
            logp = svi.score_batch(
                test_data, num_samples=100, log_prob=True
            ) / N

        # 'i' is fixed at 50 for every row; 'kl' is the gap between the two
        # scores.
        results.append({
            'i': 50,
            'model': _label,
            'elbo': elbo,
            'log_p_v': logv,
            'log_p_w': logp,
            'kl': logp - elbo
        })

# Tabulate all records.
df = pd.DataFrame(results)
162+
163+
164+
165+

0 commit comments

Comments
 (0)