Spaces:
Runtime error
Runtime error
| # -*- coding: UTF-8 -*- | |
| '''================================================= | |
| @Project -> File pram -> recdataset | |
| @IDE PyCharm | |
| @Author [email protected] | |
| @Date 29/01/2024 14:42 | |
| ==================================================''' | |
| import numpy as np | |
| from torch.utils.data import Dataset | |
class RecDataset(Dataset):
    """Concatenate several sub-datasets into one with a shared label space.

    Each sub-dataset is expected to expose ``dataset`` (a name), ``n_class``
    (number of labels, where label 0 means "invalid"), ``__len__`` and
    ``__getitem__`` returning a mapping with the keys ``gt_seg``, ``gt_cls``,
    ``gt_cls_dist`` and ``gt_n_seg``.

    Label 0 stays the shared "invalid" label.  The valid labels
    ``1 .. n_class - 1`` of every sub-dataset are shifted into a contiguous,
    non-overlapping range of the merged label space, so the merged
    ``n_class`` is ``1 + sum(sub.n_class - 1 for sub in sub_sets)``.
    """

    def __init__(self, sub_sets=None):
        """Build the flat sample index over all sub-datasets.

        Args:
            sub_sets: non-empty sequence of sub-datasets (see class docstring).

        Raises:
            ValueError: if ``sub_sets`` is ``None`` or empty.
        """
        super().__init__()
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if sub_sets is None or len(sub_sets) < 1:
            raise ValueError('RecDataset requires at least one sub-dataset')

        self.sub_sets = sub_sets
        self.names = []               # one name per sub-dataset
        self.sub_set_index = []       # per sample: index into ``sub_sets``
        self.seg_offsets = []         # per sample: first merged valid label
        self.sub_set_item_index = []  # per sample: index inside its sub-dataset
        self.dataset_names = []       # kept for compatibility; never filled
        self.scene_names = []         # per sample: name of its sub-dataset

        # Merged labels start at 1; 0 is reserved for "invalid".
        start_index_valid_seg = 1
        for subset_id, scene_set in enumerate(sub_sets):
            name = scene_set.dataset
            n_samples = len(scene_set)
            self.names.append(name)

            # ``extend`` keeps construction linear in the number of samples.
            self.seg_offsets.extend([start_index_valid_seg] * n_samples)
            self.sub_set_index.extend([subset_id] * n_samples)
            self.sub_set_item_index.extend(range(n_samples))
            self.scene_names.extend([name] * n_samples)

            # Only the valid labels (everything except 0) consume new slots.
            start_index_valid_seg += scene_set.n_class - 1

        self.n_class = start_index_valid_seg
        print('Load {} images {} segs from {} subsets from {}'.format(len(self.sub_set_item_index), self.n_class,
                                                                      len(sub_sets), self.names))

    def __len__(self):
        """Return the total number of samples over all sub-datasets."""
        return len(self.sub_set_item_index)

    def __getitem__(self, idx):
        """Return sample ``idx`` with its labels remapped to the merged space.

        The returned mapping is a shallow copy of the sub-dataset's sample in
        which ``gt_seg``, ``gt_cls``, ``gt_cls_dist`` and ``gt_n_seg`` are
        replaced by remapped arrays and ``scene_name`` is added.  The sample
        object owned by the sub-dataset is left untouched, so sub-datasets
        that cache their samples can be indexed repeatedly.
        """
        subset_idx = self.sub_set_index[idx]
        item_idx = self.sub_set_item_index[idx]
        offset = self.seg_offsets[idx]
        sub_set = self.sub_sets[subset_idx]
        org_n_class = sub_set.n_class

        # Copy before rewriting keys: mutating the sub-dataset's own dict
        # would corrupt cached samples and re-remap them on the next access.
        # NOTE(review): assumes the sample is a plain dict-like mapping.
        out = dict(sub_set[item_idx])

        org_gt_seg = out['gt_seg']
        org_gt_cls = out['gt_cls']
        org_gt_cls_dist = out['gt_cls_dist']
        org_gt_n_seg = out['gt_n_seg']

        # gt_seg keeps its length; the other arrays have one slot per merged label.
        gt_seg = np.zeros(shape=(org_gt_seg.shape[0],), dtype=int)
        gt_n_seg = np.zeros(shape=(self.n_class,), dtype=int)
        gt_cls = np.zeros(shape=(self.n_class,), dtype=int)
        gt_cls_dist = np.zeros(shape=(self.n_class,), dtype=float)

        # Slot 0 (invalid) is shared by every sub-dataset and copied as is.
        gt_n_seg[0] = org_gt_n_seg[0]
        gt_cls[0] = org_gt_cls[0]
        gt_cls_dist[0] = org_gt_cls_dist[0]

        # Valid labels: local label k (k >= 1) becomes merged label k + offset - 1.
        valid = org_gt_seg > 0
        gt_seg[valid] = org_gt_seg[valid] + offset - 1

        merged = slice(offset, offset + org_n_class - 1)
        gt_n_seg[merged] = org_gt_n_seg[1:]
        gt_cls[merged] = org_gt_cls[1:]
        gt_cls_dist[merged] = org_gt_cls_dist[1:]

        out['gt_seg'] = gt_seg
        out['gt_cls'] = gt_cls
        out['gt_cls_dist'] = gt_cls_dist
        out['gt_n_seg'] = gt_n_seg
        out['scene_name'] = self.scene_names[idx]
        return out