Skip to content

Commit 79115d7

Browse files
committed
PyMC3 simulation of MNIST56
1 parent b0fe2f0 commit 79115d7

File tree

4 files changed

+59
-1
lines changed

4 files changed

+59
-1
lines changed
File renamed without changes.

MCMC/mnist56_pymc.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import pymc3 as pm
2+
import theano.tensor as tt
3+
import pandas as pd
4+
import numpy as np
5+
6+
mnist56_url = {
7+
"train": "https://www.dropbox.com/s/l7uppxi1wvfj45z/MNIST56_train.csv?dl=1",
8+
"test": "https://www.dropbox.com/s/399gkdk9bhqvz86/MNIST56_test.csv?dl=1"
9+
}
10+
11+
12+
def prepare_data(fold_name="train"):
13+
"""
14+
Processes MNIST56 dataset.
15+
:param fold_name: either 'train' or 'test'
16+
:return: binary pixels matrix of size (n, 25) and corresponding labels (vector of size n)
17+
"""
18+
dataframe = pd.read_csv(mnist56_url[fold_name])
19+
x = dataframe.iloc[:, :25] # MNIST 5x5 flatten
20+
x = (x > 0).astype(int) # binarize pixels
21+
y = dataframe.iloc[:, 25] # labels 0 or 1
22+
return x, y
23+
24+
25+
def predict(x_data, w):
26+
"""
27+
:param x_data: matrix of size (n, 25)
28+
:param w: matrix of size (25, 2)
29+
:return: vector of size n of predicted labels
30+
"""
31+
logit_vec = np.dot(x_data, w)
32+
y_pred = logit_vec[:, 1] > logit_vec[:, 0]
33+
y_pred = y_pred.astype(int)
34+
return y_pred
35+
36+
37+
def main():
38+
np.random.seed(113)
39+
x_train, y_train = prepare_data(fold_name="train")
40+
model = pm.Model()
41+
42+
with model:
43+
w = pm.Bernoulli('w', p=0.5, shape=(25, 2))
44+
logit_vec = tt.dot(x_train, w)
45+
logit_p = logit_vec[:, 1] - logit_vec[:, 0] # logit of p(y=1)
46+
y_obs = pm.Bernoulli('y_obs', logit_p=logit_p, observed=y_train)
47+
trace = pm.sample(draws=10, njobs=1, chains=1, n_init=1000, tune=0)
48+
w_mean = trace.get_values('w').mean(axis=0)
49+
w_binary = (w_mean > 0.5).astype(int)
50+
x_test, y_test = prepare_data(fold_name="test")
51+
y_pred = predict(x_data=x_test, w=w_binary)
52+
accuracy = (y_pred == y_test).sum() / len(y_test)
53+
print("Test accuracy: {}".format(accuracy))
54+
55+
56+
if __name__ == '__main__':
57+
main()
File renamed without changes.

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,4 +54,5 @@ trainer.train(n_epoch=100)
5454

5555
* Train plots. Navigate to [http://ec2-34-227-113-244.compute-1.amazonaws.com:8099](http://ec2-34-227-113-244.compute-1.amazonaws.com:8099) and choose the Environment you want (`main` env is empty).
5656
* For your local results, go to [http://localhost:8097](http://localhost:8097)
57-
* JAGS simulation in _R_: [paper](JAGS/paper.pdf), [code](JAGS/mcmc_jags.R)
57+
* JAGS simulation in _R_: [paper](MCMC/paper.pdf), [source](MCMC/mnist56_jags.R)
58+
* PyMC3 simulation in Python: [source](MCMC/mnist56_pymc.py)

0 commit comments

Comments
 (0)