# Copyright (c) Meta Platforms, Inc. and affiliates.
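"""Localization by exhaustive matching of image features against map features.

The model encodes a query image and a map tile into feature maps of matching
dimension, samples rotated templates of the image features, and cross-correlates
them against the map to produce a score volume over planar pose (u, v, yaw).
"""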
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.functional import grid_sample, log_softmax, normalize, pad

from . import get_model
from models.base import BaseModel
from models.bev_net import BEVNet
from models.bev_projection import CartesianProjection, PolarProjectionDepth
from models.voting import (
    argmax_xyr,
    argmax_xyrh,
    conv2d_fft_batchwise,
    expectation_xyr,
    log_softmax_spatial,
    mask_yaw_prior,
    nll_loss_xyr,
    nll_loss_xyr_smoothed,
    TemplateSampler,
    UAVTemplateSampler,
    UAVTemplateSamplerFast,
)
from .map_encoder import MapEncoder
from .map_encoder_single import MapEncoderSingle
from .metrics import AngleError, AngleRecall, Location2DError, Location2DRecall


class MapLocNet(BaseModel):
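    """Pose likelihood estimation by rotational template matching.

    An image encoder and a map encoder produce feature maps of matching
    dimension; rotated templates of the image features are correlated against
    the map features with an FFT-based convolution, yielding per-pixel,
    per-rotation matching scores from which argmax and expectation pose
    estimates are derived.
    """
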
    default_conf = {
        "image_size": "???",
        "val_citys": "???",
        "image_encoder": "???",
        "map_encoder": "???",
        "bev_net": "???",
        "latent_dim": "???",
        "matching_dim": "???",
        "scale_range": [0, 9],
        "num_scale_bins": "???",
        "z_min": None,
        "z_max": "???",
        "x_max": "???",
        "pixel_per_meter": "???",
        "num_rotations": "???",
        "add_temperature": False,
        "normalize_features": False,
        "padding_matching": "replicate",
        "apply_map_prior": True,
        "do_label_smoothing": False,
        "sigma_xy": 1,
        "sigma_r": 2,
        # deprecated
| "depth_parameterization": "scale", | |
| "norm_depth_scores": False, | |
| "normalize_scores_by_dim": False, | |
| "normalize_scores_by_num_valid": True, | |
| "prior_renorm": True, | |
| "retrieval_dim": None, | |
| } | |
    def _init(self, conf):
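        """Build the image/map encoders, the rotational template sampler, and
        an optional learnable matching temperature. The asserts pin deprecated
        configuration options to their only supported values.
        """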
        assert not self.conf.norm_depth_scores
        assert self.conf.depth_parameterization == "scale"
        assert not self.conf.normalize_scores_by_dim
        assert self.conf.normalize_scores_by_num_valid
        assert self.conf.prior_renorm

        # a = conf.image_encoder.get("name", "feature_extractor_v2")
        # b = conf.image_encoder.get("name")
        Encoder = get_model(conf.image_encoder.get("name"))
        self.image_encoder = Encoder(conf.image_encoder.backbone)
        if len(conf.map_encoder.num_classes) == 1:
            self.map_encoder = MapEncoderSingle(conf.map_encoder)
        else:
            self.map_encoder = MapEncoder(conf.map_encoder)
        # self.bev_net = None if conf.bev_net is None else BEVNet(conf.bev_net)

        ppm = conf.pixel_per_meter
        # self.projection_polar = PolarProjectionDepth(
        #     conf.z_max,
        #     ppm,
        #     conf.scale_range,
        #     conf.z_min,
        # )
        # self.projection_bev = CartesianProjection(
        #     conf.z_max, conf.x_max, ppm, conf.z_min
        # )
        # self.template_sampler = TemplateSampler(
        #     self.projection_bev.grid_xz, ppm, conf.num_rotations
        # )
        self.template_sampler = UAVTemplateSamplerFast(
            conf.num_rotations, w=conf.image_size // 2
        )
        # self.template_sampler = UAVTemplateSampler(conf.num_rotations)
        # self.scale_classifier = torch.nn.Linear(conf.latent_dim, conf.num_scale_bins)
        # if conf.bev_net is None:
        #     self.feature_projection = torch.nn.Linear(
        #         conf.latent_dim, conf.matching_dim
        #     )

        if conf.add_temperature:
            temperature = torch.nn.Parameter(torch.tensor(0.0))
            self.register_parameter("temperature", temperature)
    def exhaustive_voting(self, f_bev, f_map):
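        """Correlate rotated templates of the image features against the map.

        Args:
            f_bev: query image feature map, shape [B, C, H, W], used as the
                template source.
            f_map: map feature map; resized to 256x256 before matching.

        Returns:
            Score volume of shape [B, num_rotations, H_map, W_map].
        """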
        if self.conf.normalize_features:
            f_bev = normalize(f_bev, dim=1)
            f_map = normalize(f_map, dim=1)

        # Build the templates and exhaustively match against the map.
        # if confidence_bev is not None:
        #     f_bev = f_bev * confidence_bev.unsqueeze(1)
        # f_bev = f_bev.masked_fill(~valid_bev.unsqueeze(1), 0.0)
        # torch.save(f_bev, 'f_bev.pt')
        # torch.save(f_map, 'f_map.pt')
        f_map = F.interpolate(f_map, size=(256, 256), mode='bilinear', align_corners=False)
        templates = self.template_sampler(f_bev)  # [batch, 256, 8, 129, 129]
        # torch.save(templates, 'templates.pt')
        with torch.autocast("cuda", enabled=False):
            scores = conv2d_fft_batchwise(
                f_map.float(),
                templates.float(),
                padding_mode=self.conf.padding_matching,
            )
        if self.conf.add_temperature:
            scores = scores * torch.exp(self.temperature)

        # Reweight the different rotations based on the number of valid pixels
        # in each template. Axis-aligned rotations have the maximum number of valid pixels.
        # valid_templates = self.template_sampler(valid_bev.float()[None]) > (1 - 1e-4)
        # num_valid = valid_templates.float().sum((-3, -2, -1))
        # scores = scores / num_valid[..., None, None]
        return scores
    def _forward(self, data):
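        """Single-scale forward pass.

        Encodes the map and the image, runs exhaustive rotational matching,
        applies optional map priors and masks, and returns the score volume
        together with argmax and expectation pose estimates.
        """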
        pred = {}
        pred_map = pred["map"] = self.map_encoder(data)
        f_map = pred_map["map_features"][0]  # [batch, 8, 256, 256]

        # Extract image features.
        level = 0
        f_image = self.image_encoder(data)["feature_maps"][level]  # [batch, 128, 128, 176]
        # print("f_map:", f_map.shape)

        # f_bev: [batch, 8, 64, 129], f_map: [batch, 8, 256, 256], confidence: [1, 64, 129]
        scores = self.exhaustive_voting(f_image, f_map)
        scores = scores.moveaxis(1, -1)  # B, H, W, N

        if "log_prior" in pred_map and self.conf.apply_map_prior:
            scores = scores + pred_map["log_prior"][0].unsqueeze(-1)
        # pred["scores_unmasked"] = scores.clone()
        if "map_mask" in data:
            scores.masked_fill_(~data["map_mask"][..., None], -np.inf)
        if "yaw_prior" in data:
            mask_yaw_prior(scores, data["yaw_prior"], self.conf.num_rotations)

        log_probs = log_softmax_spatial(scores)
        # torch.save(scores, 'scores.pt')
        with torch.no_grad():
            uvr_max = argmax_xyr(scores).to(scores)
            uvr_avg, _ = expectation_xyr(log_probs.exp())

        return {
            **pred,
            "scores": scores,
            "log_probs": log_probs,
            "uvr_max": uvr_max,
            "uv_max": uvr_max[..., :2],
            "yaw_max": uvr_max[..., 2],
            "uvr_expectation": uvr_avg,
            "uv_expectation": uvr_avg[..., :2],
            "yaw_expectation": uvr_avg[..., 2],
            "features_image": f_image,
        }
    def _forward_scale(self, data, resize=None):
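        """Multi-scale forward pass.

        The image features are resized to each size in `resize`, matched
        against the map independently, and the resulting score volumes are
        normalized jointly over location, rotation, and scale hypotheses.
        """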
        pred = {}
        pred_map = pred["map"] = self.map_encoder(data)
        f_map = pred_map["map_features"][0]  # [batch, 8, 256, 256]

        # Extract image features.
        level = 0
        f_image = self.image_encoder(data)["feature_maps"][level]  # [batch, 128, 128, 176]
        # print("f_map:", f_map.shape)

        # Match the image features against the map at each requested size.
        scores_list = []
        for resize_size in resize:
            f_image_re = torch.nn.functional.interpolate(
                f_image, size=resize_size, mode='bilinear', align_corners=False
            )
            scores = self.exhaustive_voting(f_image_re, f_map)
            scores = scores.moveaxis(1, -1)  # B, H, W, N
            scores_list.append(scores)
        scores_list = torch.stack(scores_list, dim=-1)  # B, H, W, N, num_scales

        # Normalize jointly over location, rotation, and scale hypotheses.
        log_probs_list = log_softmax(scores_list.flatten(-4), dim=-1).reshape(scores_list.shape)

        # if "log_prior" in pred_map and self.conf.apply_map_prior:
        #     scores = scores + pred_map["log_prior"][0].unsqueeze(-1)
        # pred["scores_unmasked"] = scores.clone()
        # if "map_mask" in data:
        #     scores.masked_fill_(~data["map_mask"][..., None], -np.inf)
        # if "yaw_prior" in data:
        #     mask_yaw_prior(scores, data["yaw_prior"], self.conf.num_rotations)
        # scores shape: [batch, W, H, 64]
        # log_probs = log_softmax_spatial(scores)
        # torch.save(scores, 'scores.pt')
        with torch.no_grad():
            uvr_max = argmax_xyrh(scores_list)
            # uvr_avg, _ = expectation_xyr(log_probs_list.exp())
            uvr_avg = uvr_max

        return {
            **pred,
            "scores": scores_list,
            "log_probs": log_probs_list,
            "uvr_max": uvr_max,
            "uv_max": uvr_max[..., :2],
            "yaw_max": uvr_max[..., 2],
            "uvr_expectation": uvr_avg,
            "uv_expectation": uvr_avg[..., :2],
            "yaw_expectation": uvr_avg[..., 2],
            "features_image": f_image,
        }
    def loss(self, pred, data):
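        """Negative log-likelihood of the ground-truth pose under the
        predicted log-probability volume (optionally label-smoothed).
        """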
        xy_gt = data["uv"]
        yaw_gt = data["roll_pitch_yaw"][..., -1]
        if self.conf.do_label_smoothing:
            nll = nll_loss_xyr_smoothed(
                pred["log_probs"],
                xy_gt,
                yaw_gt,
                self.conf.sigma_xy / self.conf.pixel_per_meter,
                self.conf.sigma_r,
                mask=data.get("map_mask"),
            )
        else:
            nll = nll_loss_xyr(pred["log_probs"], xy_gt, yaw_gt)
        loss = {"total": nll, "nll": nll}
        if self.training and self.conf.add_temperature:
            loss["temperature"] = self.temperature.expand(len(nll))
        return loss
    def metrics(self):
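        """Position and yaw error/recall metrics for the predicted poses."""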
        return {
            "xy_max_error": Location2DError("uv_max", self.conf.pixel_per_meter),
            "xy_expectation_error": Location2DError(
                "uv_expectation", self.conf.pixel_per_meter
            ),
            "yaw_max_error": AngleError("yaw_max"),
            "xy_recall_1m": Location2DRecall(1.0, self.conf.pixel_per_meter, "uv_max"),
            "xy_recall_3m": Location2DRecall(3.0, self.conf.pixel_per_meter, "uv_max"),
            "xy_recall_5m": Location2DRecall(5.0, self.conf.pixel_per_meter, "uv_max"),
            # "x_recall_1m": Location2DRecall(1.0, self.conf.pixel_per_meter, "uv_max"),
            # "x_recall_3m": Location2DRecall(3.0, self.conf.pixel_per_meter, "uv_max"),
            # "x_recall_5m": Location2DRecall(5.0, self.conf.pixel_per_meter, "uv_max"),
            # "y_recall_1m": Location2DRecall(1.0, self.conf.pixel_per_meter, "uv_max"),
            # "y_recall_3m": Location2DRecall(3.0, self.conf.pixel_per_meter, "uv_max"),
            # "y_recall_5m": Location2DRecall(5.0, self.conf.pixel_per_meter, "uv_max"),
            "yaw_recall_1°": AngleRecall(1.0, "yaw_max"),
            "yaw_recall_3°": AngleRecall(3.0, "yaw_max"),
            "yaw_recall_5°": AngleRecall(5.0, "yaw_max"),
        }
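

# ----------------------------------------------------------------------------
# Minimal, self-contained sketch (independent of this repository) of the idea
# behind `exhaustive_voting`: cross-correlate a rotated copy of the image
# feature template against the map features for every candidate yaw. The real
# model performs the correlation in the Fourier domain (conv2d_fft_batchwise)
# for efficiency; here a plain spatial convolution is used instead. All sizes
# below are illustrative assumptions, not values used by MapLocNet.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    import math

    B, C, H, W = 2, 8, 64, 64  # map feature map
    t = 17                     # template side length (odd, so matching is centered)
    n_rot = 8                  # number of candidate yaw angles

    f_map_demo = torch.randn(B, C, H, W)
    template = torch.randn(B, C, t, t)

    per_rotation_scores = []
    for k in range(n_rot):
        angle = 2.0 * math.pi * k / n_rot
        cos_a, sin_a = math.cos(angle), math.sin(angle)
        # Rotate the template with an affine warp.
        theta = torch.tensor([[cos_a, -sin_a, 0.0], [sin_a, cos_a, 0.0]])
        grid = F.affine_grid(
            theta[None].expand(B, -1, -1), list(template.shape), align_corners=False
        )
        rotated = F.grid_sample(template, grid, align_corners=False)
        # Correlate each batch element's template with its own map (groups=B).
        s = F.conv2d(
            f_map_demo.reshape(1, B * C, H, W),
            rotated,            # (B, C, t, t): one filter per batch element
            padding=t // 2,
            groups=B,
        )
        per_rotation_scores.append(s.reshape(B, H, W))

    demo_scores = torch.stack(per_rotation_scores, dim=1)  # (B, n_rot, H, W)
    # The argmax over (rotation, row, column) is the most likely planar pose.
    print(demo_scores.shape)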