| import pickle |
| from collections import Counter |
| from itertools import product |
|
|
| import matplotlib |
| import matplotlib.patches as patches |
| import numpy as np |
| import torchvision.transforms as transforms |
| from matplotlib import gridspec |
| from matplotlib import pyplot as plt |
| from matplotlib.patches import ConnectionPatch, ConnectionStyle |
| from PIL import Image |
|
|
# Curved-arc style shared by every ConnectionPatch that links a query keypoint
# to its matched keypoint in a neighbour image.
connectionstyle = ConnectionStyle.Arc3(rad=0.2)

# Display pipeline for the query and CHM-Corr neighbours: Resize(240) scales the
# shorter side to 240 px, then the central 240x240 region is cropped.
display_transform = transforms.Compose(
    [
        transforms.Resize(240),
        transforms.CenterCrop((240, 240)),
    ]
)

# Display pipeline for kNN neighbours: Resize(256) scales the shorter side to
# 256 px, then the central 224x224 region is cropped.
display_transform_knn = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop((224, 224)),
    ]
)
|
|
|
|
def keep_top_k(input_array, K=5):
    """
    Return a 0/1 mask marking the K largest entries of ``input_array``.

    Every entry greater than or equal to the K-th largest value is marked, so
    ties at the threshold can produce more than K ones. ``K <= 0`` (or an empty
    input) yields an all-zero mask; ``K`` larger than the number of elements is
    clamped to the number of elements.

    Args:
        input_array: array-like of numeric values, any shape.
        K: how many of the largest values to keep (default 5).

    Returns:
        np.ndarray with the shape and dtype of ``input_array``, holding 1 at the
        kept positions and 0 elsewhere.
    """
    values = np.asarray(input_array)
    masked = np.zeros_like(values)
    flat = values.reshape(-1)

    # Clamp K so oversized requests keep everything instead of raising IndexError,
    # and so K == 0 keeps nothing (indexing [K - 1] == [-1] used to keep all).
    k = min(K, flat.size)
    if k <= 0:
        return masked

    # np.partition places the k-th largest value at position size - k in O(n),
    # without fully sorting the array just to read one order statistic.
    pivot = flat.size - k
    threshold = np.partition(flat, pivot)[pivot]
    masked[values >= threshold] = 1
    return masked
|
|
|
|
def arg_topK(inputarray, topK=5):
    """
    Return the flat indices of the ``topK`` largest elements, largest first.

    The array is transposed before flattening, so for a 2-D input the element
    ``[r, c]`` is reported as flat index ``c * n_rows + r`` (column-major order).
    """
    column_major = np.transpose(inputarray).ravel()
    ascending = np.argsort(column_major)
    descending = np.flip(ascending)
    return descending[:topK]
|
|
|
|
# Centres of the 7x7 patch grid, in display pixels, spanning [18, 222] inside the
# 240x240 crop produced by display_transform. Flat index ``i`` maps to
# ``(grid[i // 7], grid[i % 7])`` because of itertools.product's ordering; the
# first coordinate is used as x when boxes are drawn.
_PATCH_CENTERS = list(product(np.linspace(1 + 17, 240 - 17 - 1, 7), repeat=2))

# Number of neighbour columns shown per row.
_MAX_NEIGHBORS = 5


def _load_rgb(path):
    """Decode the image at ``path`` as RGB and release the file handle."""
    with Image.open(path) as img:
        return img.convert("RGB")


def _make_colors(n):
    """Return ``max(5, n)`` colours evenly spaced along the gist_rainbow colormap."""
    # pyplot.get_cmap works across matplotlib versions; matplotlib.cm.get_cmap
    # was deprecated in 3.7 and removed in 3.9.
    cmap = plt.get_cmap("gist_rainbow")
    # At least 5 colours keeps the palette identical to the historical one for
    # small masks, while larger masks no longer run out of colours.
    count = max(5, n)
    return [cmap(k / count) for k in range(count)]


def _add_box(ax, center, color):
    """Draw an unfilled 32x32 square centred on ``center`` (x, y) in data coordinates."""
    x, y = center
    ax.add_patch(
        patches.Rectangle(
            (x - 16, y - 16),
            32,
            32,
            linewidth=2,
            edgecolor=color,
            facecolor="none",
            alpha=1,
        )
    )


def _draw_patch_boxes(ax_query, ax_neighbor, cos_values, corr_map, nzm, colors):
    """Outline the ``nzm`` strongest query patches and the neighbour patches they map to.

    Query patches are selected with ``arg_topK`` on ``cos_values``. Each selected
    flat index is looked up in ``corr_map``, giving a pair ``(a, b)`` that is
    turned into the neighbour's flat grid index ``a * 7 + b``. Query patches that
    map to the same neighbour patch share the colour of the first one.
    """
    source = arg_topK(cos_values, topK=nzm)

    targets = []
    for idx in source:
        pair = corr_map[idx]
        targets.append((pair[0] * 7) + pair[1])

    first_seen = {}
    for pos, (src, tgt) in enumerate(zip(source, targets)):
        # Overlapping matches reuse the colour of the first query patch that hit
        # this target, in both the query and the neighbour image.
        color = colors[first_seen.setdefault(tgt, pos)]
        _add_box(ax_query, _PATCH_CENTERS[src], color)
        _add_box(ax_neighbor, _PATCH_CENTERS[tgt], color)


def _draw_keypoint_arcs(ax_query, ax_neighbor, src_kp, tgt_kp, nzm, colors):
    """Link the first ``nzm`` keypoint pairs with curved lines and scatter the keypoints.

    ``src_kp`` / ``tgt_kp`` are indexed as ``[i, 0]`` (x) and ``[i, 1]`` (y);
    presumably arrays of shape (nzm, 2) — TODO confirm, since scatter colours
    them with exactly ``nzm`` colours.
    """
    for i in range(nzm):
        ax_neighbor.add_artist(
            ConnectionPatch(
                xyA=(src_kp[i, 0], src_kp[i, 1]),
                xyB=(tgt_kp[i, 0], tgt_kp[i, 1]),
                coordsA="data",
                coordsB="data",
                axesA=ax_query,
                axesB=ax_neighbor,
                color=colors[i],
                connectionstyle=connectionstyle,
                shrinkA=1.0,
                shrinkB=1.0,
                linewidth=1,
            )
        )
    ax_query.scatter(src_kp[:, 0], src_kp[:, 1], c=colors[:nzm], s=10)
    ax_neighbor.scatter(tgt_kp[:, 0], tgt_kp[:, 1], c=colors[:nzm], s=10)


def _build_axes_grid(fig):
    """Create the 4x7 GridSpec layout used by plot_from_reranker_output.

    Returns ``axes[row][k]``: ``k == 0`` is grid column 0 and ``k = 1..5`` are
    grid columns 2..6 (grid column 1 is a narrow spacer). Rows 2 and 3 have no
    column-0 panel, so ``axes[2][0]`` and ``axes[3][0]`` are ``None``.
    """
    gs = gridspec.GridSpec(
        4,
        7,
        figure=fig,
        width_ratios=[1, 0.2, 1, 1, 1, 1, 1],
        height_ratios=[1, 1, 1, 1],
        wspace=0.1,
        hspace=0.1,
        top=0.9,
        bottom=0.05,
        left=0.17,
        right=0.845,
    )
    axes = []
    for row in range(4):
        row_axes = []
        for col in (0, 2, 3, 4, 5, 6):
            if col == 0 and row >= 2:
                row_axes.append(None)  # placeholder keeps neighbour indices at 1..5
                continue
            ax = fig.add_subplot(gs[row, col])
            ax.axis("off")
            row_axes.append(ax)
        axes.append(row_axes)
    return axes


def plot_from_reranker_output(reranker_output, draw_box=True, draw_arcs=True):
    """
    Visualize CHM-Corr results from a reranker output dict.

    Layout (4 rows):
        row 0: query, then the top CHM-Corr neighbours;
        row 1: query, then the top kNN neighbours;
        rows 2/3: query (row 2) above each CHM-Corr neighbour (row 3), annotated
        with matched patch boxes (``draw_box``) and keypoint arcs (``draw_arcs``).
    At most 5 neighbours are shown per row; unused columns get white placeholders.

    Keys read from ``reranker_output``: "q", "K", "N", "chm-prediction",
    "chm-prediction-confidence", "chm-nearest-neighbors",
    "knn-prediction-confidence", "knn-nearest-neighbors", "non_zero_mask", plus
    "masked_cos_values" / "correspondance_map" when ``draw_box`` and
    "src-keypoints" / "tgt-keypoints" when ``draw_arcs``.

    Returns:
        The matplotlib Figure.
    """
    fig = plt.figure(figsize=(32, 18))
    axes = _build_axes_grid(fig)

    query = display_transform(_load_rgb(reranker_output["q"]))
    # Clamp to [0, 5]: a negative count would otherwise index axes[0][0].
    n_chm = max(0, min(_MAX_NEIGHBORS, reranker_output["chm-prediction-confidence"]))
    n_knn = max(0, min(_MAX_NEIGHBORS, reranker_output["knn-prediction-confidence"]))

    axes[0][0].imshow(query)
    axes[0][0].set_title(
        f'Query - K={reranker_output["K"]}, N={reranker_output["N"]}', fontsize=21
    )
    axes[1][0].imshow(query)
    axes[1][0].set_title(f'Query - K={reranker_output["K"]}', fontsize=21)

    # CHM-Corr neighbours: each image is loaded once and reused in rows 0 and 3.
    for i in range(n_chm):
        neighbor = display_transform(
            _load_rgb(reranker_output["chm-nearest-neighbors"][i])
        )
        axes[0][1 + i].imshow(neighbor)
        axes[0][1 + i].set_title(f"CHM-Corr - Top - {i+1}", fontsize=21)
        axes[2][1 + i].imshow(query)
        axes[3][1 + i].imshow(neighbor)

    for i in range(n_knn):
        neighbor = display_transform_knn(
            _load_rgb(reranker_output["knn-nearest-neighbors"][i])
        )
        axes[1][1 + i].imshow(neighbor)
        axes[1][1 + i].set_title(f"kNN - Top - {i+1}", fontsize=21)

    # White placeholders sized like the real images of their row (240 px for the
    # display_transform rows, 224 px for the display_transform_knn row).
    chm_blank = Image.new(mode="RGB", size=(240, 240), color="white")
    knn_blank = Image.new(mode="RGB", size=(224, 224), color="white")
    for i in range(n_chm, _MAX_NEIGHBORS):
        for row in (0, 2, 3):
            axes[row][1 + i].imshow(chm_blank)
    for i in range(n_knn, _MAX_NEIGHBORS):
        axes[1][1 + i].imshow(knn_blank)

    nzm = reranker_output["non_zero_mask"]
    colors = _make_colors(nzm)

    for nc in range(n_chm):
        ax_query, ax_neighbor = axes[2][1 + nc], axes[3][1 + nc]
        if draw_box:
            _draw_patch_boxes(
                ax_query,
                ax_neighbor,
                reranker_output["masked_cos_values"][nc],
                reranker_output["correspondance_map"][nc],
                nzm,
                colors,
            )
        if draw_arcs:
            _draw_keypoint_arcs(
                ax_query,
                ax_neighbor,
                reranker_output["src-keypoints"][nc],
                reranker_output["tgt-keypoints"][nc],
                nzm,
                colors,
            )

    fig.text(
        0.5,
        0.95,
        f"CHM-Corr Prediction: {reranker_output['chm-prediction']}",
        ha="center",
        va="bottom",
        color="black",
        fontsize=22,
    )

    return fig


def plot_from_reranker_corrmap(reranker_output, draw_box=True):
    """
    Visualize CHM-Corr patch correspondences from a reranker output dict on a 2x7 grid.

    Column 0 of row 0 shows the query; column 1 is a narrow spacer; columns 2..6
    show the query (row 0) above each of up to 5 CHM-Corr neighbours (row 1).
    Unused columns get white placeholders. With ``draw_box`` the matched patch
    pairs are outlined.

    Keys read from ``reranker_output``: "q", "chm-prediction",
    "chm-prediction-confidence", "chm-nearest-neighbors", "non_zero_mask", plus
    "masked_cos_values" / "correspondance_map" when ``draw_box``.

    Returns:
        ``(fig, reranker_output["chm-prediction"])``.
    """
    fig, axes = plt.subplots(
        2,
        7,
        figsize=(25, 8),
        gridspec_kw={
            "wspace": 0,
            "hspace": 0,
            "width_ratios": [1, 0.28, 1, 1, 1, 1, 1],
        },
        facecolor=(1, 1, 1),
    )
    for ax in axes.flat:
        ax.axis("off")

    query = display_transform(_load_rgb(reranker_output["q"]))
    axes[0][0].imshow(query)

    # Clamp to [0, 5]: a negative count would otherwise index the spacer column.
    n_chm = max(0, min(_MAX_NEIGHBORS, reranker_output["chm-prediction-confidence"]))
    for i in range(n_chm):
        axes[0][2 + i].imshow(query)
        axes[1][2 + i].imshow(
            display_transform(_load_rgb(reranker_output["chm-nearest-neighbors"][i]))
        )

    blank = Image.new(mode="RGB", size=(240, 240), color="white")
    for i in range(n_chm, _MAX_NEIGHBORS):
        axes[0][2 + i].imshow(blank)
        axes[1][2 + i].imshow(blank)

    nzm = reranker_output["non_zero_mask"]
    if draw_box:
        colors = _make_colors(nzm)
        for nc in range(n_chm):
            _draw_patch_boxes(
                axes[0][2 + nc],
                axes[1][2 + nc],
                reranker_output["masked_cos_values"][nc],
                reranker_output["correspondance_map"][nc],
                nzm,
                colors,
            )

    return fig, reranker_output["chm-prediction"]
|
|