Spaces:
Build error
Build error
| import gradio as gr | |
| import mathutils | |
| import math | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import matplotlib | |
| import matplotlib.cm as cmx | |
| import os.path as osp | |
| import h5py | |
| import random | |
| import torch | |
| import torch.nn as nn | |
| from GDANet_cls import GDANET | |
| from DGCNN import DGCNN | |
# Class labels, one per line. A label's position in this list is the class
# index used both for the models' output scores and for the dataset labels.
with open('shape_names.txt') as names_file:
    CLASS_NAME = names_file.read().splitlines()
def _load_pretrained(net, checkpoint_path):
    """Wrap *net* in DataParallel, load a CPU-mapped checkpoint and set eval mode.

    Args:
        net: freshly constructed network instance.
        checkpoint_path: path to a state-dict file saved with ``torch.save``.

    Returns:
        The ``nn.DataParallel``-wrapped model, in eval mode.
    """
    # NOTE(review): the wrapper is applied *before* load_state_dict, which
    # presumably means the checkpoints were saved from a DataParallel model
    # (keys prefixed with 'module.') - confirm before removing the wrapper.
    model = nn.DataParallel(net)
    # NOTE: torch.load unpickles the file; only load trusted, local checkpoints.
    state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.eval()
    return model


# Robust model (GDANet trained with WOLFMix) and the DGCNN baseline.
model_gda = _load_pretrained(GDANET(), './GDANet_WOLFMix.t7')
model_dgcnn = _load_pretrained(DGCNN(), './dgcnn.t7')
# Fixed view transform applied before plotting (row-vector convention,
# ``points @ VIEW_ROTATION``). It equals a -90 degree rotation about X followed
# by a 180 degree rotation about Z, i.e. it maps (x, y, z) -> (-x, z, y).
# Written out exactly so no Euler-angle library is needed.
VIEW_ROTATION = np.array(
    [
        [-1.0, 0.0, 0.0],
        [0.0, 0.0, 1.0],
        [0.0, 1.0, 0.0],
    ]
)


def pyplot_draw_point_cloud(points, corruption, save_path='visualization.png'):
    """Render a point cloud as a 3-D scatter plot and save it as a PNG.

    Args:
        points: array of shape (N, 3) with xyz coordinates; the plot limits
            are fixed to [-1, 1] on every axis.
        corruption: text drawn as the figure title (``run`` passes the
            ground-truth class name here).
        save_path: output image path; defaults to the file name the app
            has always used.

    Returns:
        ``save_path``, the file that was written.
    """
    points = np.dot(points, VIEW_ROTATION)
    x, y, z = points[:, 0], points[:, 1], points[:, 2]

    # Colour each point by its post-rotation y coordinate on a fixed [-1, 1]
    # scale, so colours are comparable across samples.
    norm = matplotlib.colors.Normalize(vmin=-1, vmax=1)
    scalar_map = cmx.ScalarMappable(norm=norm, cmap=plt.get_cmap('winter'))

    fig = plt.figure(figsize=(5, 5))
    try:
        ax = fig.add_subplot(111, projection='3d')
        ax.scatter(x, y, z, c=scalar_map.to_rgba(y))
        ax.set_xlim(-1, 1)
        ax.set_ylim(-1, 1)
        ax.set_zlim(-1, 1)
        ax.axis('off')
        ax.set_title(corruption, fontsize=30)
        fig.tight_layout()
        fig.savefig(save_path, bbox_inches='tight', dpi=200)
    finally:
        # Always release the figure, even if rendering/saving fails, so a
        # long-running server does not accumulate open figures.
        plt.close(fig)
    return save_path
# Corruption types in the index order used by the UI's corruption radio
# buttons (Clean, Scale, Jitter, Rotate, Drop Global, Drop Local, Add Global,
# Add Local).
CORRUPTIONS = (
    'clean',
    'scale',
    'jitter',
    'rotate',
    'dropout_global',
    'dropout_local',
    'add_global',
    'add_local',
)


def load_dataset(corruption_idx, severity):
    """Load one ModelNet-C split from the ``modelnet_c`` directory.

    Args:
        corruption_idx: index into ``CORRUPTIONS``.
        severity: 1-based severity level. It is ignored for ``'clean'``.
            Other files are stored 0-based as ``<type>_<severity - 1>.h5``.

    Returns:
        ``(data, label)``: the file's ``data`` dataset as float32 and its
        ``label`` dataset as int64.
    """
    corruption_type = CORRUPTIONS[corruption_idx]
    if corruption_type == 'clean':
        file_name = corruption_type + '.h5'
    else:
        # int() guards against a float severity (e.g. 5.0 from a UI slider),
        # which would otherwise produce a non-existent '..._4.0.h5' name.
        file_name = '{}_{}.h5'.format(corruption_type, int(severity) - 1)
    # Open read-only explicitly: h5py < 3 defaulted to mode 'a', which may
    # create or modify files. The with-block closes the file even if a read
    # fails.
    with h5py.File(osp.join('modelnet_c', file_name), 'r') as f:
        data = f['data'][:].astype('float32')
        label = f['label'][:].astype('int64')
    return data, label
def recognize_pcd(model, pcd, class_names=None, k=5):
    """Classify a single point cloud and return the top-k class probabilities.

    Args:
        model: callable mapping a (1, 3, N) float tensor to (1, C) logits.
        pcd: array-like of shape (N, 3) with xyz coordinates.
        class_names: label for each of the C output classes; defaults to the
            module-level ``CLASS_NAME``.
        k: number of top classes to report (capped at C).

    Returns:
        Dict ``{class_name: probability}`` for the k most likely classes,
        ordered from most to least likely.
    """
    if class_names is None:
        class_names = CLASS_NAME
    # (N, 3) -> (1, N, 3) -> (1, 3, N): add a batch axis, channels first.
    # float32 matches the models' weights even if pcd arrives as float64.
    points = torch.tensor(pcd, dtype=torch.float32).unsqueeze(0).permute(0, 2, 1)
    # Pure inference: skip autograd bookkeeping to save time and memory.
    with torch.no_grad():
        logits = model(points)
    probs = logits.softmax(-1).flatten()
    top_probs, top_idx = torch.topk(probs, min(k, probs.numel()))
    return {
        class_names[i]: float(p)
        for p, i in zip(top_probs.tolist(), top_idx.tolist())
    }
def run(seed, corruption_idx, severity):
    """Gradio callback: draw a random sample, render it, and classify it.

    Args:
        seed: seed for choosing which sample of the split to show.
        corruption_idx: corruption-type index forwarded to ``load_dataset``.
        severity: corruption severity forwarded to ``load_dataset``.

    Returns:
        ``(image_path, dgcnn_top5, gdanet_top5)`` for the three UI outputs.
    """
    data, label = load_dataset(corruption_idx, severity)
    # A private RNG seeded exactly like random.seed(seed), but it leaves the
    # process-wide random state untouched.
    rng = random.Random(seed)
    # randrange excludes the upper bound. The previous randint(0, n) could
    # return n and index one past the end of the split.
    sample_idx = rng.randrange(data.shape[0])
    pcd, cls = data[sample_idx], label[sample_idx]
    pyplot_draw_point_cloud(pcd, CLASS_NAME[cls[0]])
    output = 'visualization.png'
    return output, recognize_pcd(model_dgcnn, pcd), recognize_pcd(model_gda, pcd)
# Markdown shown at the top of the Gradio interface. Blank lines separate the
# Markdown blocks; without them the paragraphs after the bullet list would be
# rendered as continuation text of the last list item.
description = """
Welcome to the demo of PointCloud-C! [PointCloud-C](https://pointcloud-c.github.io/home.html) is a test-suite for point cloud robustness analysis under corruptions. In this demo, you may:

- __Visualize__ various types of corrupted point clouds in [ModelNet-C](https://github.com/jiawei-ren/ModelNet-C).
- __Compare__ our proposed techniques to the baseline in terms of prediction robustness.

For more details, check out our paper [Benchmarking and Analyzing Point Cloud Classification under Corruptions, __ICML 2022__](https://arxiv.org/abs/2202.03377)!

📣 News: [The first PointCloud-C challenge](https://codalab.lisn.upsaclay.fr/competitions/6437) with Classification track and Part Segmentation track in [ECCV'22 SenseHuman workshop](https://sense-human.github.io/) is open for submission now!
"""
if __name__ == '__main__':
    # Build and launch the demo UI. Input order matches run(seed,
    # corruption_idx, severity); output order matches run's return tuple.
    iface = gr.Interface(
        fn=run,
        inputs=[
            gr.components.Number(label='Sample Seed', precision=0),
            # type="index" passes the choice's position, which load_dataset
            # uses to index its corruption list.
            gr.components.Radio(
                ['Clean', 'Scale', 'Jitter', 'Rotate', 'Drop Global', 'Drop Local', 'Add Global', 'Add Local'],
                value='Jitter', type="index", label='Corruption Type'),
            gr.components.Slider(1, 5, value=5, step=1, label='Corruption severity'),
        ],
        outputs=[
            # run() returns the path of the saved PNG, so declare a filepath
            # image ("file" is a deprecated alias).
            gr.components.Image(type="filepath", label="Visualization"),
            gr.components.Label(num_top_classes=5, label="Baseline (DGCNN) Prediction"),
            gr.components.Label(num_top_classes=5, label="Ours (GDANet+WolfMix) Prediction")
        ],
        live=False,
        allow_flagging='never',
        title="PointCloud-C",
        description=description,
        examples=[
            [0, 'Jitter', 5],
            [999, 'Drop Local', 5],
        ],
        css=".output-image, .image-preview {height: 100px !important}",
        article="<p style='text-align: center'><a href='https://github.com/ldkong1205/PointCloud-C' target='_blank'>PointCloud-C @ GitHub</a></p>"
    )
    iface.launch()