Skip to content

Commit 27aaf24

Browse files
author
Yasmin Mzayek
committed
inference script skeleton
1 parent e538e74 commit 27aaf24

File tree

1 file changed

+15
-56
lines changed

1 file changed

+15
-56
lines changed

inference.py

Lines changed: 15 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -8,80 +8,39 @@
88

99

1010
def main():
11-
model_path = "/data/p301081/astronomy/Models"
12-
13-
filepathTest = "/Trainingdata/20pix_centered_test.csv"
14-
print("Loading test data from: ", filepathTest)
15-
testSetX, testSetY = cu.load_data_from_csv(filepathTest, shuffle=False, only_positive=False, magRange=[20, 26])
11+
12+
image_path = "/data/pg-ds_cit/Projects/Astronomy/AstronomyProject/Images"
13+
print("Loading test data from: ", image_path)
14+
testSet, X_test, Y_test = cu.load_data_from_images(image_path, 'test')
1615

1716
# Make sure data is float32 to have enough decimals after normalization
1817
X_test = testSetX.astype('float32')
1918
# Normalize pixel values between 0 and 1
20-
X_train /= 2**8
2119
X_test /= 2**8
2220

23-
# If subtract pixel mean is enabled
24-
subtract_pixel_mean = False
25-
if subtract_pixel_mean:
26-
X_train_mean = np.mean(X_train, axis=0)
27-
X_train -= X_train_mean
28-
X_test -= X_train_mean
29-
30-
Y_train = trainSetY[:, 0:5]
31-
Y_test = testSetY[:, 0:5]
32-
33-
Y_trainLabels = trainSetY[:, 0]
34-
Y_testLabels = testSetY[:, 0]
35-
3621
# input image dimensions
37-
img_rows, img_cols = X_train.shape[1:3]
22+
img_rows, img_cols = X_test.shape[1:3]
3823

3924
# Convert to correct Keras format
40-
X_train = X_train.reshape(X_train.shape[0], img_rows, img_cols, 1)
4125
X_test = X_test.reshape(X_test.shape[0], img_rows, img_cols, 1)
4226

4327
print()
44-
print('Data loaded: train:', len(X_train), 'test:', len(X_test))
45-
print('trainSetX:', trainSetX.shape)
46-
print('X_train:', X_train.shape)
47-
print('trainSetY:', trainSetY.shape)
48-
print('Y_train:', Y_train.shape)
28+
print('Data loaded: test:', len(X_test))
29+
print('X_test:', X_test.shape)
4930

50-
if resnet:
51-
model = load_model(modelName, custom_objects={'custom_YOLO_loss': cu.custom_YOLO_loss,
52-
'f1_metric': cu.f1_metric})
53-
else:
54-
model = load_model(modelName, custom_objects={'custom_loss': cu.custom_YOLO_loss})
31+
model = load_model(model_path)
5532

5633
print(model.summary())
5734

58-
# Evaluate
59-
scoresTrain = model.evaluate(X_train, [Y_train, Y_trainLabels], verbose=2)
60-
scoresTest = model.evaluate(X_test, [Y_test, Y_testLabels], verbose=2)
61-
print(scoresTrain)
62-
print(scoresTest)
63-
6435
# Make predictions
65-
predictionsTrain = model.predict(X_train)
6636
predictionsTest = model.predict(X_test)
67-
68-
# ResNet predictions array is different shape that simple CNN
69-
if resnet:
70-
predictionsTrain = predictionsTrain[0]
71-
predictionsTest = predictionsTest[0]
72-
73-
print("\nTraining set:")
74-
cu.analyze_5unit_errors(predictionsTrain, Y_train)
75-
print("\nTest set:")
76-
cu.analyze_5unit_errors(predictionsTest, Y_test)
77-
78-
dataframe2Dtrain = cu.create_histogram_2d(predictionsTrain, trainSetY, binsMag, binsLength)
79-
cu.plot_results_heatmap(dataframe2Dtrain, binsMag, fig_name = 'Plot_2D_histogram_CNN_train.pdf', title="Deep learning training set completeness") #ym
80-
81-
dataframe2Dtest = cu.create_histogram_2d(predictionsTest, testSetY, binsMag, binsLength)
82-
cu.plot_results_heatmap(dataframe2Dtest, binsMag, fig_name = 'Plot_2D_histogram_CNN_test.pdf', title="Deep learning test set completeness") #ym
83-
84-
plt.show()
37+
predictionsTest = [round(pred[0]) for pred in predictionsTest]
38+
39+
# Evaluate
40+
if Y_test:
41+
scoresTest = model.evaluate(X_test, Y_test, verbose=2)
42+
print(scoresTest)
43+
test_set_metrics = cu.analyze_5unit_errors(predictionsTest, Y_test)
8544

8645

8746
if __name__ == "__main__":

0 commit comments

Comments
 (0)