|
| 1 | +import os |
| 2 | + |
| 3 | +import matplotlib.pyplot as plt |
| 4 | +import numpy as np |
| 5 | +import torch |
| 6 | +from PIL import Image |
| 7 | +from sklearn.datasets import load_digits |
| 8 | +from sklearn.decomposition import PCA |
| 9 | + |
| 10 | +from som import SOM |
| 11 | + |
| 12 | + |
def get_node_coordinates(som, pca):
    """Project every lattice node's weight vector with a fitted transformer.

    Args:
        som: Map object exposing ``height``, ``width`` and a weight tensor
            ``W`` in which row ``i * width + j`` holds the weights of the
            lattice node at row ``i``, column ``j``.
        pca: Fitted transformer whose ``transform`` accepts a 2-D array of
            samples and returns their projected coordinates.

    Returns:
        ``np.ndarray`` with one row of projected coordinates per lattice
        node, ordered row-major over the lattice (empty array for an empty
        lattice).

    Raises:
        IndexError: If ``W`` has fewer than ``height * width`` rows.
    """
    num_nodes = som.height * som.width
    if num_nodes == 0:
        # Nothing to project; transformers typically reject zero samples.
        return np.array([])
    if len(som.W) < num_nodes:
        raise IndexError(
            f"som.W has {len(som.W)} rows but the lattice needs {num_nodes}"
        )
    # One batched transform replaces a per-node call; ``.cpu()`` is a no-op
    # for CPU tensors and makes the conversion work for other devices too.
    weights = som.W[:num_nodes].detach().cpu().numpy()
    return np.asarray(pca.transform(weights))
| 22 | + |
| 23 | + |
# Load the scikit-learn digits dataset and convert it to a float32 tensor
data = torch.from_numpy(load_digits().data).float()
print(data.shape)

# Build the SOM over the samples, then train it
som = SOM(data, alpha_max=0.05, num_units=49)
som.train_batch(num_epoch=1000, verbose=True)

# Query salient instances and units
# NOTE(review): the returned values are not used further down; presumably
# these calls populate som.inst_saliency / som.unit_saliency — confirm
# before removing them.
salient_insts = som.salient_insts()
salient_units = som.salient_units()

# Fit a 2-component PCA on the samples; project samples and unit weights
pca = PCA(n_components=2)
sample_matrix = som.X.numpy()
data_2d = pca.fit_transform(sample_matrix)
unit_matrix = som.W.detach().numpy()
units_2d = pca.transform(unit_matrix)

# Projected coordinates of every lattice node, in lattice order
node_coords = get_node_coordinates(som, pca)
| 46 | + |
# Figure 1: samples, SOM units and the unit lattice in PCA space
plt.figure(figsize=(12, 8))

# Split the projected samples with the SOM's per-instance saliency mask
salient_mask = som.inst_saliency.numpy()
salient_points = data_2d[salient_mask]
outlier_points = data_2d[~salient_mask]
plt.scatter(
    salient_points[:, 0],
    salient_points[:, 1],
    c=som.ins_unit_assign[salient_mask],
    cmap="viridis",
    alpha=0.6,
    label="Salient Samples",
)
plt.scatter(
    outlier_points[:, 0],
    outlier_points[:, 1],
    c="red",
    marker="x",
    alpha=0.6,
    label="Outlier Samples",
)

# Same split for the SOM units, using the per-unit saliency mask
salient_units_mask = som.unit_saliency.numpy()
salient_nodes = node_coords[salient_units_mask]
outlier_nodes = node_coords[~salient_units_mask]
plt.scatter(
    salient_nodes[:, 0],
    salient_nodes[:, 1],
    c="black",
    marker="s",
    s=50,
    label="Salient Units",
)
plt.scatter(
    outlier_nodes[:, 0],
    outlier_nodes[:, 1],
    c="red",
    marker="s",
    s=50,
    label="Outlier Units",
)

# Connect each node to its right-hand and lower lattice neighbour
for node_index in range(som.height * som.width):
    row, col = divmod(node_index, som.width)
    neighbours = []
    if col + 1 < som.width:  # neighbour in the same lattice row
        neighbours.append(node_index + 1)
    if row + 1 < som.height:  # neighbour in the next lattice row
        neighbours.append(node_index + som.width)
    for neighbour in neighbours:
        plt.plot(
            [node_coords[node_index, 0], node_coords[neighbour, 0]],
            [node_coords[node_index, 1], node_coords[neighbour, 1]],
            "gray",
            alpha=0.5,
        )

# Axis labels, title and legend
plt.xlabel("First Principal Component")
plt.ylabel("Second Principal Component")
plt.title("SOM Units and Data Samples with Outliers and Lattice")
plt.legend()

plt.show()

# Summary counts derived from the two saliency masks
print(f"Number of salient samples: {salient_mask.sum()}")
print(f"Number of outlier samples: {(~salient_mask).sum()}")
print(f"Number of salient units: {salient_units_mask.sum()}")
print(f"Number of outlier units: {(~salient_units_mask).sum()}")
| 123 | + |
# Figure 2: samples placed on an idealised, regular 2D lattice
plt.figure(figsize=(12, 12))

# Flattened integer grid positions, one entry per lattice node
mesh_x, mesh_y = np.meshgrid(np.arange(som.width), np.arange(som.height))
grid_x = mesh_x.reshape(-1).copy()
grid_y = mesh_y.reshape(-1).copy()

# Node markers plus dashed guide lines through every column and row
plt.scatter(grid_x, grid_y, c="lightgray", s=200, marker="s")
for x in range(som.width):
    plt.axvline(x, color="lightgray", linestyle="--")
for y in range(som.height):
    plt.axhline(y, color="lightgray", linestyle="--")

# Unit index assigned to each sample
unit_assignments = som.ins_unit_assign.numpy()

# Place each sample on its unit's grid position, with uniform jitter so that
# samples sharing a unit do not overlap completely (x is drawn before y)
jitter = 0.2
sample_x = grid_x[unit_assignments].astype(float)
sample_x += np.random.uniform(-jitter, jitter, sample_x.shape)
sample_y = grid_y[unit_assignments].astype(float)
sample_y += np.random.uniform(-jitter, jitter, sample_y.shape)

# Samples coloured by assigned unit
scatter = plt.scatter(
    sample_x,
    sample_y,
    c=som.ins_unit_assign,
    cmap="viridis",
    alpha=0.6,
)

# Ring the samples whose saliency flag is false
outlier_mask = ~som.inst_saliency.numpy()
plt.scatter(
    sample_x[outlier_mask],
    sample_y[outlier_mask],
    facecolors="none",
    edgecolors="red",
    s=50,
    linewidths=2,
)

# Circle the units whose saliency flag is false
# NOTE(review): presumably som.unit_cords returns (x, y) in the same grid
# coordinates as grid_x / grid_y — confirm against the SOM implementation.
outlier_units = np.where(~som.unit_saliency.numpy())[0]
for unit in outlier_units:
    unit_x, unit_y = som.unit_cords(unit)
    ring = plt.Circle(
        (unit_x, unit_y), 0.4, fill=False, edgecolor="red", linewidth=2
    )
    plt.gca().add_patch(ring)

# Labels and title
plt.xlabel("SOM Width")
plt.ylabel("SOM Height")
plt.title("Samples Mapped to Perfect 2D SOM Lattice")

# One tick per lattice column / row
plt.xticks(range(som.width))
plt.yticks(range(som.height))

# Colorbar keyed to the sample scatter
cbar = plt.colorbar(scatter)
cbar.set_label("Unit Assignment")

# Half-cell margin around the lattice
plt.xlim(-0.5, som.width - 0.5)
plt.ylim(-0.5, som.height - 0.5)

plt.tight_layout()
plt.show()
| 196 | + |
def _to_uint8(img):
    """Linearly rescale ``img`` to the 0-255 range and return it as ``uint8``.

    A constant image has zero dynamic range; it is mapped to all zeros
    instead of dividing by zero and casting NaN to ``uint8``.
    """
    lowest = img.min()
    span = img.max() - lowest
    if span == 0:
        return np.zeros(img.shape, dtype=np.uint8)
    return ((img - lowest) / span * 255).astype(np.uint8)


def _save_digit_image(img, folder, filename):
    """Normalize ``img``, wrap it in a PIL image and save it in ``folder``."""
    Image.fromarray(_to_uint8(img)).save(os.path.join(folder, filename))


# Create a folder to save outlier images
output_folder = "outlier_digits"
os.makedirs(output_folder, exist_ok=True)

# Get the original digit images and their labels
digits = load_digits()
images = digits.images
labels = digits.target

# Indices of the samples whose saliency flag is false
outlier_indices = np.where(~salient_mask)[0]

# Save one image per outlier sample; the file name carries the running
# counter and the sample's label
for i, idx in enumerate(outlier_indices):
    label = labels[idx]
    filename = f"outlier_{i}_label_{label}.png"
    _save_digit_image(images[idx], output_folder, filename)

print(f"Saved {len(outlier_indices)} outlier images to '{output_folder}' folder.")

# For every salient unit, save the sample closest to it
salient_folder = "salient_digits"
os.makedirs(salient_folder, exist_ok=True)
salient_unit_indices = np.where(som.unit_saliency.numpy())[0]

# The sample matrix does not change between units, so convert it once
samples = data.numpy()

for i, unit_idx in enumerate(salient_unit_indices):
    # Euclidean distance from every sample to this unit's weight vector
    unit_weights = som.W[unit_idx].detach().numpy()
    distances = np.linalg.norm(samples - unit_weights, axis=1)
    closest_sample_idx = np.argmin(distances)

    label = labels[closest_sample_idx]
    filename = f"salient_unit_{i}_label_{label}.png"
    _save_digit_image(images[closest_sample_idx], salient_folder, filename)

print(
    f"Saved {len(salient_unit_indices)} salient unit images to '{salient_folder}' folder."
)
0 commit comments