# RPCANet / train.py
# fengyiwu's picture
# Upload 93 files
# 82b70d0 verified
import os
import os.path as osp
import time
import datetime
from argparse import ArgumentParser
import numpy as np
import torch
import torch.utils.data as Data
from tensorboardX import SummaryWriter
from tqdm import tqdm
from models import get_model
from utils.data import *
from utils.loss import SoftLoULoss
from utils.lr_scheduler import *
from utils.metrics import SegmentationMetricTPFNFP
from utils.my_pd_fa import *
from utils.pd_fa import *
from utils.logger import setup_logger
def parse_args():
    """Parse the command-line options and attach derived run settings.

    Besides the parsed options, the returned namespace carries:
      * ``time_name``   - local start time formatted as ``%Y%m%dT%H-%M-%S``
      * ``folder_name`` - ``<time_name>_<net_name>_<dataset>``
      * ``save_folder`` - ``folder_name`` joined onto ``--base-dir``
      * ``logger``      - the object returned by ``setup_logger`` for ``save_folder``

    Side effect: when ``--seed`` is non-zero, ``set_seeds`` is called with it.
    """
    parser = ArgumentParser(description='Implement of RPCANets')

    # (flag, add_argument keyword arguments), registered in this order.
    option_table = (
        # Dataset parameters
        ('--base-size', dict(type=int, default=256, help='base size of images')),
        ('--crop-size', dict(type=int, default=256, help='crop size of images')),
        ('--dataset', dict(type=str, default='sirst', help='choose datasets')),
        # Training parameters
        ('--batch-size', dict(type=int, default=8, help='batch size for training')),
        ('--epochs', dict(type=int, default=50, help='number of epochs')),
        ('--warm-up-epochs', dict(type=int, default=0, help='warm up epochs')),
        ('--lr', dict(type=float, default=5e-4, help='learning rate')),
        ('--gpu', dict(type=str, default='0', help='GPU number')),
        ('--seed', dict(type=int, default=42, help='seed')),
        ('--lr-scheduler', dict(type=str, default='poly', help='learning rate scheduler')),
        # Net parameters
        ('--net-name', dict(type=str, default='rpcanet', help='net name: fcn')),
        # Rank parameters (currently disabled):
        # ('--rank', dict(type=int, default=8, help='rank number')),
        # Save parameters
        ('--save-iter-step', dict(type=int, default=1, help='save model per step iters')),
        ('--log-per-iter', dict(type=int, default=1, help='interval of logging')),
        ('--base-dir', dict(type=str, default='./result/', help='saving dir')),
    )
    for flag, spec in option_table:
        parser.add_argument(flag, **spec)

    args = parser.parse_args()

    # Output folder: one directory per run, named after start time, net and dataset.
    args.time_name = time.strftime('%Y%m%dT%H-%M-%S', time.localtime())
    args.folder_name = '_'.join((args.time_name, args.net_name, args.dataset))
    args.save_folder = osp.join(args.base_dir, args.folder_name)

    # A seed of 0 means "do not seed".
    if args.seed:
        set_seeds(args.seed)

    # The logger travels with the namespace so downstream code can reuse it.
    args.logger = setup_logger("Robust PCA Networks", args.save_folder, 0, filename='log.txt')
    return args
def set_seeds(seed):
    """Seed every random number generator this script can reach.

    Seeds Python's built-in ``random`` module, NumPy's global RNG, and the
    torch CPU and CUDA generators, and exports ``PYTHONHASHSEED``.

    Args:
        seed: Integer seed applied to all generators.
    """
    # Imported locally so this function does not depend on the file's import block.
    import random

    # Previously left unseeded, so anything drawing from ``random`` was not
    # reproducible even with a fixed --seed.
    random.seed(seed)
    np.random.seed(seed)
    # Hash randomisation is fixed at interpreter start-up, so this only
    # influences Python processes launched after this point.
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # torch.backends.cudnn.deterministic = True
class Trainer(object):
    """Train a network and periodically validate it.

    The constructor builds the train/validation data loaders, the model,
    the two loss terms, the learning-rate scheduler, the optimizer, the
    evaluation metric, a TensorBoard writer and the logger from ``args``.

    The training objective is
    ``softiou(out_T, labels) + MSE_WEIGHT * mse(out_D, data)``,
    where ``out_D, out_T`` are the two outputs of the network.

    Attributes set here and updated during training:
        iter_num:       number of optimizer steps taken so far.
        best_iou:       highest IoU returned by the metric in any validation.
        best_fmeasure:  highest F-measure returned in any validation.
        eval_loss:      mean total loss over the batches of the most recent
                        ``validation()`` call.
    """

    # Weight of the MSE term in the total loss.
    MSE_WEIGHT = 0.01

    def __init__(self, args):
        """Build all training components from the parsed option namespace.

        Args:
            args: Namespace as returned by ``parse_args`` (reads ``dataset``,
                ``base_size``, ``crop_size``, ``batch_size``, ``epochs``,
                ``gpu``, ``net_name``, ``lr``, ``lr_scheduler``,
                ``save_folder``, ``folder_name`` and ``logger``).

        Raises:
            NotImplementedError: if ``args.dataset`` is not a known name.
        """
        self.args = args
        self.iter_num = 0

        ## dataset
        trainset, valset = self._build_datasets(args)
        self.train_data_loader = Data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True)
        self.val_data_loader = Data.DataLoader(valset, batch_size=args.batch_size, shuffle=True)
        self.iter_per_epoch = len(self.train_data_loader)
        self.max_iter = args.epochs * self.iter_per_epoch

        ## GPU
        # NOTE(review): CUDA_VISIBLE_DEVICES is set after CUDA availability has
        # already been queried, and the same value is reused as the device
        # index below - confirm that selecting a GPU other than '0' behaves
        # as intended.
        if torch.cuda.is_available():
            os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
        self.device = torch.device("cuda:{}".format(args.gpu) if torch.cuda.is_available() else "cpu")

        ## model
        self.net = get_model(args.net_name)
        self.net = self.net.to(self.device)

        ## criterion
        self.softiou = SoftLoULoss()
        self.mse = torch.nn.MSELoss()

        ## lr scheduler
        self.scheduler = LR_Scheduler_Head(args.lr_scheduler, args.lr,
                                           args.epochs, len(self.train_data_loader), lr_step=10)

        ## optimizer
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=args.lr)

        ## evaluation metrics
        self.metric = SegmentationMetricTPFNFP(nclass=1)
        self.best_iou = 0
        self.best_fmeasure = 0
        self.eval_loss = 0  # tmp values
        self.iou = 0
        self.fmeasure = 0
        self.eval_my_PD_FA = my_PD_FA()
        self.eval_PD_FA = PD_FA()

        ## SummaryWriter
        self.writer = SummaryWriter(log_dir=args.save_folder)
        self.writer.add_text(args.folder_name, 'Args:%s, ' % args)

        ## log info
        self.logger = args.logger
        self.logger.info(args)
        self.logger.info("Using device: {}".format(self.device))

    @staticmethod
    def _build_datasets(args):
        """Return ``(trainset, valset)`` for ``args.dataset``.

        An if/elif chain is used on purpose: only the dataset classes of the
        selected branch are looked up, so a name missing from the star
        imports affects that dataset alone.
        """
        name = args.dataset
        if name == 'sirstaug':
            trainset = SirstAugDataset(base_dir=r'./datasets/sirst_aug',
                                       mode='train', base_size=args.base_size)
            valset = SirstAugDataset(base_dir=r'./datasets/sirst_aug',
                                     mode='test', base_size=args.base_size)
        elif name == 'irstd1k':
            trainset = IRSTD1kDataset(base_dir=r'./datasets/IRSTD-1k', mode='train', base_size=args.base_size)
            valset = IRSTD1kDataset(base_dir=r'./datasets/IRSTD-1k', mode='test', base_size=args.base_size)
        elif name == 'nudt':
            trainset = NUDTDataset(base_dir=r'./datasets/NUDT-SIRST', mode='train', base_size=args.base_size)
            valset = NUDTDataset(base_dir=r'./datasets/NUDT-SIRST', mode='test', base_size=args.base_size)
        elif name == 'sirst':
            trainset = SirstDataset(base_dir=r'./datasets/sirst', mode='train', base_size=args.base_size)
            valset = SirstDataset(base_dir=r'./datasets/sirst', mode='val', base_size=args.base_size)
        elif name == 'drive':
            trainset = DriveDatasetTrain(base_dir=r'./datasets/DRIVE', mode='train',
                                         base_size=args.base_size, patch_size=args.crop_size)
            valset = DriveDatasetTest(base_dir=r'./datasets/DRIVE', mode='test', base_size=args.base_size)
        elif name == 'CHASEDB1':
            trainset = CHASEDB1DatasetTrain(base_dir=r'./datasets/CHASEDB1', mode='train',
                                            base_size=args.base_size, patch_size=args.crop_size)
            valset = CHASEDB1DatasetTest(base_dir=r'./datasets/CHASEDB1', mode='test', base_size=args.base_size)
        elif name == 'stare':
            trainset = STAREDatasetTrain(base_dir=r'./datasets/STARE', mode='train',
                                         base_size=args.base_size, patch_size=args.crop_size)
            valset = STAREDatasetTest(base_dir=r'./datasets/STARE', mode='test', base_size=args.base_size)
        else:
            raise NotImplementedError("unknown dataset: {!r}".format(name))
        return trainset, valset

    def training(self):
        """Run the full training loop for ``self.args.epochs`` epochs.

        Each iteration: update the learning rate through the scheduler, do
        one optimizer step on the combined loss, write the loss terms and
        learning rate to the summary writer, optionally log progress, and
        call ``validation()`` every ``save_iter_step`` iterations and at the
        end of every epoch.
        """
        # Use the namespace stored on the instance; the method must not rely
        # on module-level ``args`` / ``trainer`` names existing.
        args = self.args
        log_step = args.log_per_iter
        save_step = args.save_iter_step

        start_time = time.time()
        base_log = "Epoch-Iter: [{:d}/{:d}]-[{:d}/{:d}] || Lr: {:.6f} || Loss: {:.4f}={:.4f}+{:.4f} || " \
                   "Cost Time: {} || Estimated Time: {}"

        for epoch in range(args.epochs):
            for i, (data, labels) in enumerate(self.train_data_loader):
                # validation() switches the net to eval mode, so re-enable
                # training mode on every iteration.
                self.net.train()
                self.scheduler(self.optimizer, i, epoch, self.best_iou)

                data = data.to(self.device)
                labels = labels.to(self.device)

                out_D, out_T = self.net(data)
                loss_softiou = self.softiou(out_T, labels)
                loss_mse = self.mse(out_D, data)
                loss_all = loss_softiou + self.MSE_WEIGHT * loss_mse

                self.optimizer.zero_grad()
                loss_all.backward()
                self.optimizer.step()
                self.iter_num += 1

                elapsed = time.time() - start_time
                cost_string = str(datetime.timedelta(seconds=int(elapsed)))
                # ETA = mean time per iteration so far * iterations remaining.
                eta_seconds = (elapsed / self.iter_num) * (self.max_iter - self.iter_num)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

                current_lr = self.optimizer.param_groups[0]['lr']
                self.writer.add_scalar('Train Loss/Loss All', loss_all.item(), self.iter_num)
                self.writer.add_scalar('Train Loss/Loss SoftIoU', loss_softiou.item(), self.iter_num)
                self.writer.add_scalar('Train Loss/Loss MSE', loss_mse.item(), self.iter_num)
                self.writer.add_scalar('Learning rate/', current_lr, self.iter_num)

                # A non-positive interval disables the corresponding action
                # instead of dividing by zero.
                if log_step > 0 and self.iter_num % log_step == 0:
                    # ``i + 1`` so the last iteration of an epoch reads [N/N], not [0/N].
                    self.logger.info(base_log.format(epoch + 1, args.epochs, i + 1,
                                                     self.iter_per_epoch, current_lr,
                                                     loss_all.item(), loss_softiou.item(), loss_mse.item(),
                                                     cost_string, eta_string))

                periodic = save_step > 0 and self.iter_num % save_step == 0
                end_of_epoch = self.iter_num % self.iter_per_epoch == 0
                if periodic or end_of_epoch:
                    self.validation()

    def validation(self):
        """Evaluate on the validation loader and checkpoint the model.

        Resets and updates the metric over all validation batches, stores the
        mean validation loss in ``self.eval_loss``, always saves
        ``latest.pkl``, saves ``best.pkl`` when the IoU improves, tracks the
        best F-measure, and writes the results to the summary writer and the
        logger.
        """
        self.metric.reset()
        self.net.eval()
        base_log = "Data: {:s}, IoU: {:.4f}/{:.4f}, F1: {:.4f}/{:.4f} "

        loss_sum = 0.0
        batch_count = 0
        with torch.no_grad():
            for data, labels in self.val_data_loader:
                out_D, out_T = self.net(data.to(self.device))
                # Metric and loss are evaluated on the CPU against the
                # untouched CPU batch.
                out_D, out_T = out_D.cpu(), out_T.cpu()

                loss_softiou = self.softiou(out_T, labels)
                loss_mse = self.mse(out_D, data)
                # Plain float weight keeps every operand on the same device.
                loss_all = loss_softiou + self.MSE_WEIGHT * loss_mse
                loss_sum += loss_all.item()
                batch_count += 1

                self.metric.update(labels, out_T)

        if batch_count:
            self.eval_loss = loss_sum / batch_count

        iou, prec, recall, fmeasure = self.metric.get()

        torch.save(self.net.state_dict(), osp.join(self.args.save_folder, 'latest.pkl'))
        if iou > self.best_iou:
            self.best_iou = iou
            torch.save(self.net.state_dict(), osp.join(self.args.save_folder, 'best.pkl'))
        if fmeasure > self.best_fmeasure:
            self.best_fmeasure = fmeasure

        self.writer.add_scalar('Test/Loss', self.eval_loss, self.iter_num)
        self.writer.add_scalar('Test/IoU', iou, self.iter_num)
        self.writer.add_scalar('Test/F1', fmeasure, self.iter_num)
        self.writer.add_scalar('Best/IoU', self.best_iou, self.iter_num)
        self.writer.add_scalar('Best/Fmeasure', self.best_fmeasure, self.iter_num)

        self.logger.info(base_log.format(self.args.dataset, iou, self.best_iou, fmeasure, self.best_fmeasure))
if __name__ == '__main__':
    # Script entry point: parse the options, build the trainer, run the
    # training loop, then print the best validation scores it recorded.
    args = parse_args()
    trainer = Trainer(args)
    trainer.training()
    summary = 'Best mIoU: %.5f, Best Fmeasure: %.5f\n\n' % (trainer.best_iou, trainer.best_fmeasure)
    print(summary)