|
8 | 8 |
|
9 | 9 |
|
10 | 10 | def main():
|
11 |
| - model_path = "/data/p301081/astronomy/Models" |
12 |
| - |
13 |
| - filepathTest = "/Trainingdata/20pix_centered_test.csv" |
14 |
| - print("Loading test data from: ", filepathTest) |
15 |
| - testSetX, testSetY = cu.load_data_from_csv(filepathTest, shuffle=False, only_positive=False, magRange=[20, 26]) |
| 11 | + |
| 12 | + image_path = "/data/pg-ds_cit/Projects/Astronomy/AstronomyProject/Images" |
| 13 | + print("Loading test data from: ", image_path) |
| 14 | + testSet, X_test, Y_test = cu.load_data_from_images(image_path, 'test') |
16 | 15 |
|
17 | 16 | # Make sure data is float32 to have enough decimals after normalization
|
18 | 17 | X_test = testSetX.astype('float32')
|
19 | 18 | # Normalize pixel values between 0 and 1
|
20 |
| - X_train /= 2**8 |
21 | 19 | X_test /= 2**8
|
22 | 20 |
|
23 |
| - # If subtract pixel mean is enabled |
24 |
| - subtract_pixel_mean = False |
25 |
| - if subtract_pixel_mean: |
26 |
| - X_train_mean = np.mean(X_train, axis=0) |
27 |
| - X_train -= X_train_mean |
28 |
| - X_test -= X_train_mean |
29 |
| - |
30 |
| - Y_train = trainSetY[:, 0:5] |
31 |
| - Y_test = testSetY[:, 0:5] |
32 |
| - |
33 |
| - Y_trainLabels = trainSetY[:, 0] |
34 |
| - Y_testLabels = testSetY[:, 0] |
35 |
| - |
36 | 21 | # input image dimensions
|
37 |
| - img_rows, img_cols = X_train.shape[1:3] |
| 22 | + img_rows, img_cols = X_test.shape[1:3] |
38 | 23 |
|
39 | 24 | # Convert to correct Keras format
|
40 |
| - X_train = X_train.reshape(X_train.shape[0], img_rows, img_cols, 1) |
41 | 25 | X_test = X_test.reshape(X_test.shape[0], img_rows, img_cols, 1)
|
42 | 26 |
|
43 | 27 | print()
|
44 |
| - print('Data loaded: train:', len(X_train), 'test:', len(X_test)) |
45 |
| - print('trainSetX:', trainSetX.shape) |
46 |
| - print('X_train:', X_train.shape) |
47 |
| - print('trainSetY:', trainSetY.shape) |
48 |
| - print('Y_train:', Y_train.shape) |
| 28 | + print('Data loaded: test:', len(X_test)) |
| 29 | + print('X_test:', X_test.shape) |
49 | 30 |
|
50 |
| - if resnet: |
51 |
| - model = load_model(modelName, custom_objects={'custom_YOLO_loss': cu.custom_YOLO_loss, |
52 |
| - 'f1_metric': cu.f1_metric}) |
53 |
| - else: |
54 |
| - model = load_model(modelName, custom_objects={'custom_loss': cu.custom_YOLO_loss}) |
| 31 | + model = load_model(model_path) |
55 | 32 |
|
56 | 33 | print(model.summary())
|
57 | 34 |
|
58 |
| - # Evaluate |
59 |
| - scoresTrain = model.evaluate(X_train, [Y_train, Y_trainLabels], verbose=2) |
60 |
| - scoresTest = model.evaluate(X_test, [Y_test, Y_testLabels], verbose=2) |
61 |
| - print(scoresTrain) |
62 |
| - print(scoresTest) |
63 |
| - |
64 | 35 | # Make predictions
|
65 |
| - predictionsTrain = model.predict(X_train) |
66 | 36 | predictionsTest = model.predict(X_test)
|
67 |
| - |
68 |
| - # ResNet predictions array is different shape that simple CNN |
69 |
| - if resnet: |
70 |
| - predictionsTrain = predictionsTrain[0] |
71 |
| - predictionsTest = predictionsTest[0] |
72 |
| - |
73 |
| - print("\nTraining set:") |
74 |
| - cu.analyze_5unit_errors(predictionsTrain, Y_train) |
75 |
| - print("\nTest set:") |
76 |
| - cu.analyze_5unit_errors(predictionsTest, Y_test) |
77 |
| - |
78 |
| - dataframe2Dtrain = cu.create_histogram_2d(predictionsTrain, trainSetY, binsMag, binsLength) |
79 |
| - cu.plot_results_heatmap(dataframe2Dtrain, binsMag, fig_name = 'Plot_2D_histogram_CNN_train.pdf', title="Deep learning training set completeness") #ym |
80 |
| - |
81 |
| - dataframe2Dtest = cu.create_histogram_2d(predictionsTest, testSetY, binsMag, binsLength) |
82 |
| - cu.plot_results_heatmap(dataframe2Dtest, binsMag, fig_name = 'Plot_2D_histogram_CNN_test.pdf', title="Deep learning test set completeness") #ym |
83 |
| - |
84 |
| - plt.show() |
| 37 | + predictionsTest = [round(pred[0]) for pred in predictionsTest] |
| 38 | + |
| 39 | + # Evaluate |
| 40 | + if Y_test: |
| 41 | + scoresTest = model.evaluate(X_test, Y_test, verbose=2) |
| 42 | + print(scoresTest) |
| 43 | + test_set_metrics = cu.analyze_5unit_errors(predictionsTest, Y_test) |
85 | 44 |
|
86 | 45 |
|
87 | 46 | if __name__ == "__main__":
|
|
0 commit comments