|
|
import os |
|
|
import os.path as osp |
|
|
import time |
|
|
import datetime |
|
|
from argparse import ArgumentParser |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.utils.data as Data |
|
|
from tensorboardX import SummaryWriter |
|
|
from tqdm import tqdm |
|
|
|
|
|
from models import get_model |
|
|
from utils.data import * |
|
|
from utils.loss import SoftLoULoss |
|
|
from utils.lr_scheduler import * |
|
|
from utils.metrics import SegmentationMetricTPFNFP |
|
|
from utils.my_pd_fa import * |
|
|
from utils.pd_fa import * |
|
|
from utils.logger import setup_logger |
|
|
|
|
|
|
|
|
def parse_args():
    """Build the CLI parser, parse arguments, and derive run metadata.

    Side effects: seeds all RNGs via ``set_seeds`` (unless ``--seed 0``)
    and creates the run logger writing to ``<save_folder>/log.txt``.

    Returns:
        argparse.Namespace with the parsed options plus ``time_name``,
        ``folder_name``, ``save_folder`` and ``logger`` attached.
    """
    parser = ArgumentParser(description='Implement of RPCANets')

    # (flag, add_argument kwargs) — data-driven registration keeps the
    # option list easy to scan and extend.
    options = [
        # image sizes
        ('--base-size', dict(type=int, default=256, help='base size of images')),
        ('--crop-size', dict(type=int, default=256, help='crop size of images')),
        ('--dataset', dict(type=str, default='sirst', help='choose datasets')),
        # training hyper-parameters
        ('--batch-size', dict(type=int, default=8, help='batch size for training')),
        ('--epochs', dict(type=int, default=50, help='number of epochs')),
        ('--warm-up-epochs', dict(type=int, default=0, help='warm up epochs')),
        ('--lr', dict(type=float, default=5e-4, help='learning rate')),
        ('--gpu', dict(type=str, default='0', help='GPU number')),
        ('--seed', dict(type=int, default=42, help='seed')),
        ('--lr-scheduler', dict(type=str, default='poly', help='learning rate scheduler')),
        # network
        ('--net-name', dict(type=str, default='rpcanet', help='net name: fcn')),
        # bookkeeping
        ('--save-iter-step', dict(type=int, default=1, help='save model per step iters')),
        ('--log-per-iter', dict(type=int, default=1, help='interval of logging')),
        ('--base-dir', dict(type=str, default='./result/', help='saving dir')),
    ]
    for flag, kwargs in options:
        parser.add_argument(flag, **kwargs)

    args = parser.parse_args()

    # Unique output folder: <base_dir>/<timestamp>_<net>_<dataset>.
    args.time_name = time.strftime('%Y%m%dT%H-%M-%S', time.localtime(time.time()))
    args.folder_name = f'{args.time_name}_{args.net_name}_{args.dataset}'
    args.save_folder = osp.join(args.base_dir, args.folder_name)

    # seed == 0 is treated as "do not seed".
    if args.seed != 0:
        set_seeds(args.seed)

    args.logger = setup_logger("Robust PCA Networks", args.save_folder, 0, filename='log.txt')
    return args
|
|
|
|
|
|
|
|
def set_seeds(seed):
    """Seed every RNG source used by the pipeline for reproducibility.

    Covers Python's stdlib ``random``, NumPy, and PyTorch (CPU and all
    CUDA devices). The CUDA calls are safe no-ops when no GPU is present.

    Args:
        seed (int): the seed applied to every generator.
    """
    import random  # local import: stdlib RNG was previously left unseeded

    random.seed(seed)  # fix: stdlib random was not seeded before
    np.random.seed(seed)
    # NOTE(review): setting PYTHONHASHSEED at runtime does not affect the
    # already-running interpreter's hash randomization; it only influences
    # subprocesses. Kept for child-process reproducibility.
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)      # current CUDA device (no-op on CPU-only)
    torch.cuda.manual_seed_all(seed)  # every CUDA device (no-op on CPU-only)
|
|
|
|
|
|
|
|
|
|
|
class Trainer(object):
    """Training/validation driver for RPCANet-style decomposition models.

    The model returns a pair ``(out_D, out_T)`` — a reconstructed image and
    a target/segmentation map. The loss is SoftIoU on ``out_T`` plus a small
    MSE reconstruction term (weight 0.01) on ``out_D``.

    Args:
        args: parsed namespace from ``parse_args()``; must carry ``logger``
            and ``save_folder`` in addition to the CLI options.
    """

    def __init__(self, args):
        self.args = args
        self.iter_num = 0  # global iteration counter across all epochs

        # ---------- dataset selection ----------
        if args.dataset == 'sirstaug':
            trainset = SirstAugDataset(base_dir=r'./datasets/sirst_aug',
                                       mode='train', base_size=args.base_size)
            valset = SirstAugDataset(base_dir=r'./datasets/sirst_aug',
                                     mode='test', base_size=args.base_size)
        elif args.dataset == 'irstd1k':
            trainset = IRSTD1kDataset(base_dir=r'./datasets/IRSTD-1k', mode='train', base_size=args.base_size)
            valset = IRSTD1kDataset(base_dir=r'./datasets/IRSTD-1k', mode='test', base_size=args.base_size)
        elif args.dataset == 'nudt':
            trainset = NUDTDataset(base_dir=r'./datasets/NUDT-SIRST', mode='train', base_size=args.base_size)
            valset = NUDTDataset(base_dir=r'./datasets/NUDT-SIRST', mode='test', base_size=args.base_size)
        elif args.dataset == 'sirst':
            trainset = SirstDataset(base_dir=r'./datasets/sirst', mode='train', base_size=args.base_size)
            valset = SirstDataset(base_dir=r'./datasets/sirst', mode='val', base_size=args.base_size)
        elif args.dataset == 'drive':
            trainset = DriveDatasetTrain(base_dir=r'./datasets/DRIVE', mode='train', base_size=args.base_size, patch_size=args.crop_size)
            valset = DriveDatasetTest(base_dir=r'./datasets/DRIVE', mode='test', base_size=args.base_size)
        elif args.dataset == 'CHASEDB1':
            trainset = CHASEDB1DatasetTrain(base_dir=r'./datasets/CHASEDB1', mode='train', base_size=args.base_size, patch_size=args.crop_size)
            valset = CHASEDB1DatasetTest(base_dir=r'./datasets/CHASEDB1', mode='test', base_size=args.base_size)
        elif args.dataset == 'stare':
            trainset = STAREDatasetTrain(base_dir=r'./datasets/STARE', mode='train', base_size=args.base_size, patch_size=args.crop_size)
            valset = STAREDatasetTest(base_dir=r'./datasets/STARE', mode='test', base_size=args.base_size)
        else:
            raise NotImplementedError

        self.train_data_loader = Data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True)
        self.val_data_loader = Data.DataLoader(valset, batch_size=args.batch_size, shuffle=True)
        self.iter_per_epoch = len(self.train_data_loader)
        self.max_iter = args.epochs * self.iter_per_epoch

        # ---------- device ----------
        if torch.cuda.is_available():
            os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
        # NOTE(review): CUDA_VISIBLE_DEVICES is set after torch.cuda has
        # already been queried, and args.gpu is then also used as an index
        # into the masked device view — this is only consistent for
        # --gpu 0. Confirm intended behavior on multi-GPU hosts.
        self.device = torch.device("cuda:{}".format(args.gpu) if torch.cuda.is_available() else "cpu")

        # ---------- model ----------
        self.net = get_model(args.net_name)
        self.net = self.net.to(self.device)

        # ---------- losses ----------
        self.softiou = SoftLoULoss()
        self.mse = torch.nn.MSELoss()

        # ---------- lr scheduler and optimizer ----------
        self.scheduler = LR_Scheduler_Head(args.lr_scheduler, args.lr,
                                           args.epochs, len(self.train_data_loader), lr_step=10)
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=args.lr)

        # ---------- evaluation metrics ----------
        self.metric = SegmentationMetricTPFNFP(nclass=1)
        self.best_iou = 0
        self.best_fmeasure = 0
        self.eval_loss = 0  # NOTE(review): never updated — see validation()
        self.iou = 0
        self.fmeasure = 0
        self.eval_my_PD_FA = my_PD_FA()
        self.eval_PD_FA = PD_FA()

        # ---------- tensorboard writer and logger ----------
        self.writer = SummaryWriter(log_dir=args.save_folder)
        self.writer.add_text(args.folder_name, 'Args:%s, ' % args)

        self.logger = args.logger
        self.logger.info(args)
        self.logger.info("Using device: {}".format(self.device))

    def training(self):
        """Run the full training loop; validates/saves per configured steps.

        Fixes vs. the original: uses ``self.args``/``self.optimizer``
        instead of module globals ``args``/``trainer`` (which only existed
        when run as a script), and logs the in-epoch iteration as ``i + 1``
        (the old ``iter_num % iter_per_epoch`` printed 0 at epoch end).
        """
        start_time = time.time()
        base_log = "Epoch-Iter: [{:d}/{:d}]-[{:d}/{:d}] || Lr: {:.6f} || Loss: {:.4f}={:.4f}+{:.4f} || " \
                   "Cost Time: {} || Estimated Time: {}"
        for epoch in range(self.args.epochs):  # fix: was global `args`
            for i, (data, labels) in enumerate(self.train_data_loader):
                self.net.train()

                # Per-iteration LR schedule (poly by default).
                self.scheduler(self.optimizer, i, epoch, self.best_iou)

                data = data.to(self.device)
                labels = labels.to(self.device)
                out_D, out_T = self.net(data)

                # SoftIoU on the target map + weighted MSE reconstruction.
                loss_softiou = self.softiou(out_T, labels)
                loss_mse = self.mse(out_D, data)
                gamma = torch.Tensor([0.01]).to(self.device)
                loss_all = loss_softiou + torch.mul(gamma, loss_mse)

                self.optimizer.zero_grad()
                loss_all.backward()
                self.optimizer.step()

                self.iter_num += 1

                # Elapsed / estimated-remaining wall time for the log line.
                cost_string = str(datetime.timedelta(seconds=int(time.time() - start_time)))
                eta_seconds = ((time.time() - start_time) / self.iter_num) * (self.max_iter - self.iter_num)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

                # fix: dropped pointless np.mean() around scalar .item()
                self.writer.add_scalar('Train Loss/Loss All', loss_all.item(), self.iter_num)
                self.writer.add_scalar('Train Loss/Loss SoftIoU', loss_softiou.item(), self.iter_num)
                self.writer.add_scalar('Train Loss/Loss MSE', loss_mse.item(), self.iter_num)
                # fix: was `trainer.optimizer` (module global)
                self.writer.add_scalar('Learning rate/', self.optimizer.param_groups[0]['lr'], self.iter_num)

                if self.iter_num % self.args.log_per_iter == 0:
                    self.logger.info(base_log.format(epoch + 1, self.args.epochs, i + 1,
                                                     self.iter_per_epoch, self.optimizer.param_groups[0]['lr'],
                                                     loss_all.item(), loss_softiou.item(), loss_mse.item(),
                                                     cost_string, eta_string))

                # Validate every save_iter_step iterations and at epoch end.
                if (self.iter_num % self.args.save_iter_step) == 0 or self.iter_num % self.iter_per_epoch == 0:
                    self.validation()

    def validation(self):
        """Evaluate on the validation set, save 'latest'/'best' checkpoints,
        and log IoU/F1 to tensorboard and the text logger.
        """
        self.metric.reset()
        self.net.eval()
        base_log = "Data: {:s}, IoU: {:.4f}/{:.4f}, F1: {:.4f}/{:.4f} "
        for i, (data, labels) in enumerate(self.val_data_loader):
            with torch.no_grad():
                out_D, out_T = self.net(data.to(self.device))
                out_D, out_T = out_D.cpu(), out_T.cpu()

            # NOTE(review): this loss is computed but never accumulated or
            # reported (self.eval_loss stays 0) — kept for parity.
            loss_softiou = self.softiou(out_T, labels)
            loss_mse = self.mse(out_D, data)
            # fix: gamma must live on CPU here — outputs/labels were moved
            # to CPU, so the old `.to(self.device)` raised a device-mismatch
            # error whenever training ran on CUDA.
            gamma = torch.Tensor([0.01])
            loss_all = loss_softiou + torch.mul(gamma, loss_mse)

            self.metric.update(labels, out_T)

        iou, prec, recall, fmeasure = self.metric.get()

        # Always keep the latest weights; snapshot the best-IoU model too.
        torch.save(self.net.state_dict(), osp.join(self.args.save_folder, 'latest.pkl'))
        if iou > self.best_iou:
            self.best_iou = iou
            torch.save(self.net.state_dict(), osp.join(self.args.save_folder, 'best.pkl'))
        if fmeasure > self.best_fmeasure:
            self.best_fmeasure = fmeasure

        self.writer.add_scalar('Test/IoU', iou, self.iter_num)
        self.writer.add_scalar('Test/F1', fmeasure, self.iter_num)
        self.writer.add_scalar('Best/IoU', self.best_iou, self.iter_num)
        self.writer.add_scalar('Best/Fmeasure', self.best_fmeasure, self.iter_num)
        self.logger.info(base_log.format(self.args.dataset, iou, self.best_iou, fmeasure, self.best_fmeasure))
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Parse CLI options (also seeds RNGs and sets up the run logger/folder).
    args = parse_args()

    # Build the trainer and run; training() triggers validation periodically.
    trainer = Trainer(args)
    trainer.training()

    # Report the best validation metrics observed over the whole run.
    print(f'Best mIoU: {trainer.best_iou:.5f}, Best Fmeasure: {trainer.best_fmeasure:.5f}\n\n')
|
|
|