|
| 1 | +import pymc3 as pm |
| 2 | +import theano.tensor as tt |
| 3 | +import pandas as pd |
| 4 | +import numpy as np |
| 5 | + |
| 6 | +mnist56_url = { |
| 7 | + "train": "https://www.dropbox.com/s/l7uppxi1wvfj45z/MNIST56_train.csv?dl=1", |
| 8 | + "test": "https://www.dropbox.com/s/399gkdk9bhqvz86/MNIST56_test.csv?dl=1" |
| 9 | +} |
| 10 | + |
| 11 | + |
| 12 | +def prepare_data(fold_name="train"): |
| 13 | + """ |
| 14 | + Processes MNIST56 dataset. |
| 15 | + :param fold_name: either 'train' or 'test' |
| 16 | + :return: binary pixels matrix of size (n, 25) and corresponding labels (vector of size n) |
| 17 | + """ |
| 18 | + dataframe = pd.read_csv(mnist56_url[fold_name]) |
| 19 | + x = dataframe.iloc[:, :25] # MNIST 5x5 flatten |
| 20 | + x = (x > 0).astype(int) # binarize pixels |
| 21 | + y = dataframe.iloc[:, 25] # labels 0 or 1 |
| 22 | + return x, y |
| 23 | + |
| 24 | + |
| 25 | +def predict(x_data, w): |
| 26 | + """ |
| 27 | + :param x_data: matrix of size (n, 25) |
| 28 | + :param w: matrix of size (25, 2) |
| 29 | + :return: vector of size n of predicted labels |
| 30 | + """ |
| 31 | + logit_vec = np.dot(x_data, w) |
| 32 | + y_pred = logit_vec[:, 1] > logit_vec[:, 0] |
| 33 | + y_pred = y_pred.astype(int) |
| 34 | + return y_pred |
| 35 | + |
| 36 | + |
| 37 | +def main(): |
| 38 | + np.random.seed(113) |
| 39 | + x_train, y_train = prepare_data(fold_name="train") |
| 40 | + model = pm.Model() |
| 41 | + |
| 42 | + with model: |
| 43 | + w = pm.Bernoulli('w', p=0.5, shape=(25, 2)) |
| 44 | + logit_vec = tt.dot(x_train, w) |
| 45 | + logit_p = logit_vec[:, 1] - logit_vec[:, 0] # logit of p(y=1) |
| 46 | + y_obs = pm.Bernoulli('y_obs', logit_p=logit_p, observed=y_train) |
| 47 | + trace = pm.sample(draws=10, njobs=1, chains=1, n_init=1000, tune=0) |
| 48 | + w_mean = trace.get_values('w').mean(axis=0) |
| 49 | + w_binary = (w_mean > 0.5).astype(int) |
| 50 | + x_test, y_test = prepare_data(fold_name="test") |
| 51 | + y_pred = predict(x_data=x_test, w=w_binary) |
| 52 | + accuracy = (y_pred == y_test).sum() / len(y_test) |
| 53 | + print("Test accuracy: {}".format(accuracy)) |
| 54 | + |
| 55 | + |
| 56 | +if __name__ == '__main__': |
| 57 | + main() |
0 commit comments