
Commit efffdbe

Update the code base for the 21st century

1 parent f50ca16 commit efffdbe

File tree

6 files changed: +1470 −659 lines changed


grid_search.py

Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@
import numpy as np
import torch
from matplotlib import pyplot as plt
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from tqdm import tqdm

from som import SOM


def quantization_error(som, data):
    # Average of each sample's distance to its nearest unit (min over the unit axis).
    _, distances = som.best_match(data)
    return torch.mean(torch.min(distances, dim=0)[0])


def grid_search_som(data, unit_range, epochs=1000, alpha_max=0.05, trials=3):
    results = []

    for num_units in tqdm(unit_range, desc="Grid Search"):
        trial_errors = []
        for _ in range(trials):
            som = SOM(data, num_units=num_units, alpha_max=alpha_max)
            som.train_batch(num_epoch=epochs, verbose=False)
            error = quantization_error(som, data)
            trial_errors.append(error.item())

        avg_error = np.mean(trial_errors)
        std_error = np.std(trial_errors)
        results.append((num_units, avg_error, std_error))

        print(
            f"Units: {num_units}, Avg Error: {avg_error:.4f}, Std Error: {std_error:.4f}"
        )

    return results


def find_elbow(x, y):
    # Normalize the data
    x = np.array(x)
    y = np.array(y)
    x_norm = (x - min(x)) / (max(x) - min(x))
    y_norm = (y - min(y)) / (max(y) - min(y))

    # Calculate the distances from each point to the line connecting the first and last points
    coords = np.vstack([x_norm, y_norm]).T
    first = coords[0]
    line_vec = coords[-1] - coords[0]
    line_vec_norm = line_vec / np.sqrt(np.sum(line_vec**2))
    vec_from_first = coords - first
    scalar_proj = np.dot(vec_from_first, line_vec_norm)
    proj = np.outer(scalar_proj, line_vec_norm)
    distances = np.sqrt(np.sum((vec_from_first - proj) ** 2, axis=1))

    # Find the elbow point (maximum distance)
    elbow_index = np.argmax(distances)
    return x[elbow_index], y[elbow_index]


if __name__ == "__main__":
    # Load Digits dataset
    digits = load_digits()
    data = torch.from_numpy(digits.data).float()

    # Normalize the data
    data = (data - data.min()) / (data.max() - data.min())

    # Split the data into train and test sets
    X_train, X_test = train_test_split(data, test_size=0.2, random_state=42)

    # Define the range of units to search
    unit_range = [9, 16, 25, 36, 49, 64, 81, 100, 121, 144, 169, 196]

    # Perform grid search
    results = grid_search_som(
        X_train, unit_range, epochs=1000, alpha_max=0.05, trials=3
    )

    # Extract units and errors
    units = [r[0] for r in results]
    errors = [r[1] for r in results]
    error_stds = [r[2] for r in results]

    # Find the elbow point
    elbow_units, elbow_error = find_elbow(units, errors)

    print(f"\nElbow point: {elbow_units:.0f} units, Error: {elbow_error:.4f}")

    # Plot the results
    plt.figure(figsize=(10, 6))
    plt.errorbar(units, errors, yerr=error_stds, fmt="o-", capsize=5)
    plt.plot(elbow_units, elbow_error, "ro", markersize=10, label="Elbow point")
    plt.xlabel("Number of Units")
    plt.ylabel("Quantization Error")
    plt.title("SOM Grid Search Results")
    plt.xscale("log")
    plt.grid(True)
    plt.legend()
    plt.show()

    # Train the final SOM with the elbow-point number of units; fit on the
    # training split only so the test-set error below stays held out
    best_som = SOM(X_train, num_units=int(elbow_units), alpha_max=0.05)
    best_som.train_batch(num_epoch=1000, verbose=True)

    # Evaluate on the test set
    test_error = quantization_error(best_som, X_test)
    print(f"\nTest set quantization error: {test_error:.4f}")

sample_run.py

Lines changed: 256 additions & 0 deletions
@@ -0,0 +1,256 @@
1+
import os
2+
3+
import matplotlib.pyplot as plt
4+
import numpy as np
5+
import torch
6+
from PIL import Image
7+
from sklearn.datasets import load_digits
8+
from sklearn.decomposition import PCA
9+
10+
from som import SOM
11+
12+
13+
def get_node_coordinates(som, pca):
14+
coords = []
15+
for i in range(som.height):
16+
for j in range(som.width):
17+
node_index = i * som.width + j
18+
node_weights = som.W[node_index].detach().numpy()
19+
coord = pca.transform([node_weights])[0]
20+
coords.append(coord)
21+
return np.array(coords)
22+
23+
24+
# Load Iris dataset
25+
data = load_digits().data
26+
data = torch.from_numpy(data).float()
27+
print(data.shape)
28+
29+
# Initialize SOM
30+
som = SOM(data, alpha_max=0.05, num_units=49)
31+
32+
# Train SOM
33+
som.train_batch(num_epoch=1000, verbose=True)
34+
35+
# Get salient instances and units
36+
salient_insts = som.salient_insts()
37+
salient_units = som.salient_units()
38+
39+
# Perform PCA to reduce data to 2D for visualization
40+
pca = PCA(n_components=2)
41+
data_2d = pca.fit_transform(som.X.numpy())
42+
units_2d = pca.transform(som.W.detach().numpy())
43+
44+
# Get node coordinates
45+
node_coords = get_node_coordinates(som, pca)
46+
47+
# Create a plot
48+
plt.figure(figsize=(12, 8))
49+
50+
# Plot data points
51+
salient_mask = som.inst_saliency.numpy()
52+
plt.scatter(
53+
data_2d[salient_mask, 0],
54+
data_2d[salient_mask, 1],
55+
c=som.ins_unit_assign[salient_mask],
56+
cmap="viridis",
57+
alpha=0.6,
58+
label="Salient Samples",
59+
)
60+
plt.scatter(
61+
data_2d[~salient_mask, 0],
62+
data_2d[~salient_mask, 1],
63+
c="red",
64+
marker="x",
65+
alpha=0.6,
66+
label="Outlier Samples",
67+
)
68+
69+
# Plot SOM units
70+
salient_units_mask = som.unit_saliency.numpy()
71+
plt.scatter(
72+
node_coords[salient_units_mask, 0],
73+
node_coords[salient_units_mask, 1],
74+
c="black",
75+
marker="s",
76+
s=50,
77+
label="Salient Units",
78+
)
79+
plt.scatter(
80+
node_coords[~salient_units_mask, 0],
81+
node_coords[~salient_units_mask, 1],
82+
c="red",
83+
marker="s",
84+
s=50,
85+
label="Outlier Units",
86+
)
87+
88+
# Draw lattice lines
89+
for i in range(som.height):
90+
for j in range(som.width):
91+
node_index = i * som.width + j
92+
if j < som.width - 1: # Horizontal line
93+
next_node_index = node_index + 1
94+
plt.plot(
95+
[node_coords[node_index, 0], node_coords[next_node_index, 0]],
96+
[node_coords[node_index, 1], node_coords[next_node_index, 1]],
97+
"gray",
98+
alpha=0.5,
99+
)
100+
if i < som.height - 1: # Vertical line
101+
next_node_index = node_index + som.width
102+
plt.plot(
103+
[node_coords[node_index, 0], node_coords[next_node_index, 0]],
104+
[node_coords[node_index, 1], node_coords[next_node_index, 1]],
105+
"gray",
106+
alpha=0.5,
107+
)
108+
109+
# Add labels and title
110+
plt.xlabel("First Principal Component")
111+
plt.ylabel("Second Principal Component")
112+
plt.title("SOM Units and Data Samples with Outliers and Lattice")
113+
plt.legend()
114+
115+
# Show the plot
116+
plt.show()
117+
118+
# Optional: Print some statistics
119+
print(f"Number of salient samples: {salient_mask.sum()}")
120+
print(f"Number of outlier samples: {(~salient_mask).sum()}")
121+
print(f"Number of salient units: {salient_units_mask.sum()}")
122+
print(f"Number of outlier units: {(~salient_units_mask).sum()}")
123+
124+
# Create a new figure for the perfect 2D lattice plot
125+
plt.figure(figsize=(12, 12))
126+
127+
# Create a perfect 2D grid for SOM nodes
128+
grid_x, grid_y = np.meshgrid(np.arange(som.width), np.arange(som.height))
129+
grid_x = grid_x.flatten()
130+
grid_y = grid_y.flatten()
131+
132+
# Plot the perfect grid
133+
plt.scatter(grid_x, grid_y, c="lightgray", s=200, marker="s")
134+
135+
# Draw grid lines
136+
for x in range(som.width):
137+
plt.axvline(x, color="lightgray", linestyle="--")
138+
for y in range(som.height):
139+
plt.axhline(y, color="lightgray", linestyle="--")
140+
141+
# Get the unit assignments for each sample
142+
unit_assignments = som.ins_unit_assign.numpy()
143+
144+
# Calculate the positions of samples on the grid
145+
sample_x = grid_x[unit_assignments].astype(float)
146+
sample_y = grid_y[unit_assignments].astype(float)
147+
148+
# Add some jitter to prevent complete overlap
149+
jitter = 0.2
150+
sample_x += np.random.uniform(-jitter, jitter, sample_x.shape)
151+
sample_y += np.random.uniform(-jitter, jitter, sample_y.shape)
152+
153+
# Plot the samples on the grid
154+
scatter = plt.scatter(
155+
sample_x, sample_y, c=som.ins_unit_assign, cmap="viridis", alpha=0.6
156+
)
157+
158+
# Highlight outlier samples
159+
outlier_mask = ~som.inst_saliency.numpy()
160+
plt.scatter(
161+
sample_x[outlier_mask],
162+
sample_y[outlier_mask],
163+
facecolors="none",
164+
edgecolors="red",
165+
s=50,
166+
linewidths=2,
167+
)
168+
169+
# Highlight outlier units
170+
for unit in np.where(~som.unit_saliency.numpy())[0]:
171+
unit_x, unit_y = som.unit_cords(unit)
172+
plt.gca().add_patch(
173+
plt.Circle((unit_x, unit_y), 0.4, fill=False, edgecolor="red", linewidth=2)
174+
)
175+
176+
# Set labels and title
177+
plt.xlabel("SOM Width")
178+
plt.ylabel("SOM Height")
179+
plt.title("Samples Mapped to Perfect 2D SOM Lattice")
180+
181+
# Set tick labels
182+
plt.xticks(range(som.width))
183+
plt.yticks(range(som.height))
184+
185+
# Add colorbar
186+
cbar = plt.colorbar(scatter)
187+
cbar.set_label("Unit Assignment")
188+
189+
# Adjust plot limits
190+
plt.xlim(-0.5, som.width - 0.5)
191+
plt.ylim(-0.5, som.height - 0.5)
192+
193+
# Show the plot
194+
plt.tight_layout()
195+
plt.show()
196+
197+
# Create a folder to save outlier images
198+
output_folder = "outlier_digits"
199+
os.makedirs(output_folder, exist_ok=True)
200+
201+
# Get the original digit images and their labels
202+
digits = load_digits()
203+
images = digits.images
204+
labels = digits.target
205+
206+
# Find the indices of outlier samples
207+
outlier_indices = np.where(~salient_mask)[0]
208+
209+
# Save outlier images
210+
for i, idx in enumerate(outlier_indices):
211+
img = images[idx]
212+
label = labels[idx]
213+
214+
# Normalize the image to 0-255 range
215+
img_normalized = ((img - img.min()) / (img.max() - img.min()) * 255).astype(
216+
np.uint8
217+
)
218+
219+
# Create a PIL Image
220+
pil_img = Image.fromarray(img_normalized)
221+
222+
# Save the image
223+
filename = f"outlier_{i}_label_{label}.png"
224+
pil_img.save(os.path.join(output_folder, filename))
225+
226+
print(f"Saved {len(outlier_indices)} outlier images to '{output_folder}' folder.")
227+
228+
# Find samples closest to salient units
229+
salient_folder = "salient_digits"
230+
os.makedirs(salient_folder, exist_ok=True)
231+
salient_unit_indices = np.where(som.unit_saliency.numpy())[0]
232+
233+
for i, unit_idx in enumerate(salient_unit_indices):
234+
# Find the sample closest to this salient unit
235+
unit_weights = som.W[unit_idx].detach().numpy()
236+
distances = np.linalg.norm(data.numpy() - unit_weights, axis=1)
237+
closest_sample_idx = np.argmin(distances)
238+
239+
img = images[closest_sample_idx]
240+
label = labels[closest_sample_idx]
241+
242+
# Normalize the image to 0-255 range
243+
img_normalized = ((img - img.min()) / (img.max() - img.min()) * 255).astype(
244+
np.uint8
245+
)
246+
247+
# Create a PIL Image
248+
pil_img = Image.fromarray(img_normalized)
249+
250+
# Save the image
251+
filename = f"salient_unit_{i}_label_{label}.png"
252+
pil_img.save(os.path.join(salient_folder, filename))
253+
254+
print(
255+
f"Saved {len(salient_unit_indices)} salient unit images to '{salient_folder}' folder."
256+
)
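
The lattice-drawing loops and the outlier-unit circles both rely on units being stored in row-major order (node_index = i * som.width + j), with som.unit_cords mapping a flat index back to grid coordinates. unit_cords is defined in som.py and not shown in this diff, so the sketch below is a hypothetical stand-in that mirrors the convention the script assumes (a 7x7 lattice for num_units=49):

# Hypothetical stand-in for som.unit_cords, matching the row-major
# node_index = i * width + j convention used in sample_run.py.
def unit_coords(node_index, width):
    # (x, y) = (column, row) on the lattice
    return node_index % width, node_index // width

assert unit_coords(0, width=7) == (0, 0)   # top-left corner
assert unit_coords(8, width=7) == (1, 1)   # second row, second column
assert unit_coords(48, width=7) == (6, 6)  # bottom-right corner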
