from collections import OrderedDict

import torch
import torch.nn as nn
from torch.optim import lr_scheduler
from torch.optim import Adam

from models.select_network import define_G
from models.model_base import ModelBase
from models.loss import CharbonnierLoss, TVLoss
from models.loss_ssim import SSIMLoss

from utils.utils_model import test_mode
from utils.utils_regularizers import regularizer_orth, regularizer_clip


class ModelPlain(ModelBase):
    """Train with pixel loss."""
    def __init__(self, opt):
        super(ModelPlain, self).__init__(opt)
        # ------------------------------------
        # define network
        # ------------------------------------
        self.opt_train = self.opt['train']    # training options
        self.netG = define_G(opt)
        self.netG = self.model_to_device(self.netG)
        if self.opt_train['E_decay'] > 0:
            self.netE = define_G(opt).to(self.device).eval()
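        # `netE` is kept as an exponential moving average (EMA) of `netG`
        # when `E_decay` > 0; it is refreshed after every optimizer step in
        # `optimize_parameters` via `update_E` (defined in ModelBase).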
"""
|
|
# ----------------------------------------
|
|
# Preparation before training with data
|
|
# Save model during training
|
|
# ----------------------------------------
|
|
"""

    # ----------------------------------------
    # initialize training
    # ----------------------------------------
    def init_train(self):
        self.load()                           # load model
        self.netG.train()                     # set training mode, needed for layers such as BatchNorm
        self.define_loss()                    # define loss
        self.define_optimizer()               # define optimizer
        self.load_optimizers()                # load optimizer
        self.define_scheduler()               # define scheduler
        self.log_dict = OrderedDict()         # log

    # ----------------------------------------
    # load pre-trained G and E models
    # ----------------------------------------
    def load(self):
        load_path_G = self.opt['path']['pretrained_netG']
        if load_path_G is not None:
            print('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG, strict=self.opt_train['G_param_strict'], param_key='params')
        load_path_E = self.opt['path']['pretrained_netE']
        if self.opt_train['E_decay'] > 0:
            if load_path_E is not None:
                print('Loading model for E [{:s}] ...'.format(load_path_E))
                self.load_network(load_path_E, self.netE, strict=self.opt_train['E_param_strict'], param_key='params_ema')
            else:
                print('Copying model for E ...')
                self.update_E(0)
            self.netE.eval()

    # ----------------------------------------
    # load optimizer
    # ----------------------------------------
    def load_optimizers(self):
        load_path_optimizerG = self.opt['path']['pretrained_optimizerG']
        if load_path_optimizerG is not None and self.opt_train['G_optimizer_reuse']:
            print('Loading optimizerG [{:s}] ...'.format(load_path_optimizerG))
            self.load_optimizer(load_path_optimizerG, self.G_optimizer)

    # ----------------------------------------
    # save model / optimizer (optional)
    # ----------------------------------------
    def save(self, iter_label):
        self.save_network(self.save_dir, self.netG, 'G', iter_label)
        if self.opt_train['E_decay'] > 0:
            self.save_network(self.save_dir, self.netE, 'E', iter_label)
        if self.opt_train['G_optimizer_reuse']:
            self.save_optimizer(self.save_dir, self.G_optimizer, 'optimizerG', iter_label)
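
    # `save_network`/`save_optimizer` come from ModelBase; they are assumed to
    # write checkpoints named like '<iter_label>_G.pth', '<iter_label>_E.pth'
    # and '<iter_label>_optimizerG.pth' into `self.save_dir`.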

    # ----------------------------------------
    # define loss
    # ----------------------------------------
    def define_loss(self):
        G_lossfn_type = self.opt_train['G_lossfn_type']
        if G_lossfn_type == 'l1':
            self.G_lossfn = nn.L1Loss().to(self.device)
        elif G_lossfn_type == 'l2':
            self.G_lossfn = nn.MSELoss().to(self.device)
        elif G_lossfn_type == 'l2sum':
            self.G_lossfn = nn.MSELoss(reduction='sum').to(self.device)
        elif G_lossfn_type == 'ssim':
            self.G_lossfn = SSIMLoss().to(self.device)
        elif G_lossfn_type == 'charbonnier':
            self.G_lossfn = CharbonnierLoss(self.opt_train['G_charbonnier_eps']).to(self.device)
        elif G_lossfn_type == 'tv':
            self.G_lossfn = TVLoss(self.opt_train['G_tvloss_weight']).to(self.device)
        else:
            raise NotImplementedError('Loss type [{:s}] is not supported.'.format(G_lossfn_type))
        self.G_lossfn_weight = self.opt_train['G_lossfn_weight']
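
    # Example of the corresponding entries in the training options JSON
    # (hypothetical values, under "train"):
    #
    #   "G_lossfn_type": "l1",        # l1 | l2 | l2sum | ssim | charbonnier | tv
    #   "G_lossfn_weight": 1.0,
    #   "G_charbonnier_eps": 1e-9     # only read when the type is "charbonnier"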

    # ----------------------------------------
    # define optimizer
    # ----------------------------------------
    def define_optimizer(self):
        G_optim_params = []
        for k, v in self.netG.named_parameters():
            if v.requires_grad:
                G_optim_params.append(v)
            else:
                print('Params [{:s}] will not be optimized.'.format(k))
        if self.opt_train['G_optimizer_type'] == 'adam':
            self.G_optimizer = Adam(G_optim_params, lr=self.opt_train['G_optimizer_lr'],
                                    betas=self.opt_train['G_optimizer_betas'],
                                    weight_decay=self.opt_train['G_optimizer_wd'])
        else:
            raise NotImplementedError
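
    # Example optimizer entries in the options JSON (hypothetical values):
    #
    #   "G_optimizer_type": "adam",
    #   "G_optimizer_lr": 1e-4,
    #   "G_optimizer_betas": [0.9, 0.999],
    #   "G_optimizer_wd": 0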

    # ----------------------------------------
    # define scheduler: "MultiStepLR" or "CosineAnnealingWarmRestarts"
    # ----------------------------------------
    def define_scheduler(self):
        if self.opt_train['G_scheduler_type'] == 'MultiStepLR':
            self.schedulers.append(lr_scheduler.MultiStepLR(self.G_optimizer,
                                                            self.opt_train['G_scheduler_milestones'],
                                                            self.opt_train['G_scheduler_gamma']
                                                            ))
        elif self.opt_train['G_scheduler_type'] == 'CosineAnnealingWarmRestarts':
            # Note: PyTorch's CosineAnnealingWarmRestarts takes (optimizer, T_0,
            # T_mult=1, eta_min=0) and has no restart-weights argument, so
            # 'G_scheduler_restart_weights' is not forwarded here; this assumes
            # 'G_scheduler_periods' is a single integer period.
            self.schedulers.append(lr_scheduler.CosineAnnealingWarmRestarts(self.G_optimizer,
                                                                            T_0=self.opt_train['G_scheduler_periods'],
                                                                            eta_min=self.opt_train['G_scheduler_eta_min']
                                                                            ))
        else:
            raise NotImplementedError
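
    # Example scheduler entries in the options JSON (hypothetical values):
    #
    #   "G_scheduler_type": "MultiStepLR",
    #   "G_scheduler_milestones": [200000, 400000, 600000, 800000],
    #   "G_scheduler_gamma": 0.5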
"""
|
|
# ----------------------------------------
|
|
# Optimization during training with data
|
|
# Testing/evaluation
|
|
# ----------------------------------------
|
|
"""
|
|
|
|

    # ----------------------------------------
    # feed L/H data
    # ----------------------------------------
    def feed_data(self, data, need_H=True):
        self.L = data['L'].to(self.device)
        if need_H:
            self.H = data['H'].to(self.device)

    # ----------------------------------------
    # feed L to netG
    # ----------------------------------------
    def netG_forward(self):
        self.E = self.netG(self.L)

    # ----------------------------------------
    # update parameters and get loss
    # ----------------------------------------
    def optimize_parameters(self, current_step):
        self.G_optimizer.zero_grad()
        self.netG_forward()
        G_loss = self.G_lossfn_weight * self.G_lossfn(self.E, self.H)
        G_loss.backward()

        # ------------------------------------
        # clip_grad
        # ------------------------------------
        # `clip_grad_norm_` helps prevent the exploding-gradient problem.
        G_optimizer_clipgrad = self.opt_train['G_optimizer_clipgrad'] if self.opt_train['G_optimizer_clipgrad'] else 0
        if G_optimizer_clipgrad > 0:
            # Clip the gradients of the network parameters (`self` is not an
            # nn.Module, so `self.parameters()` would fail here).
            torch.nn.utils.clip_grad_norm_(self.netG.parameters(), max_norm=self.opt_train['G_optimizer_clipgrad'], norm_type=2)

        self.G_optimizer.step()

        # ------------------------------------
        # regularizer
        # ------------------------------------
        G_regularizer_orthstep = self.opt_train['G_regularizer_orthstep'] if self.opt_train['G_regularizer_orthstep'] else 0
        if G_regularizer_orthstep > 0 and current_step % G_regularizer_orthstep == 0 and current_step % self.opt['train']['checkpoint_save'] != 0:
            self.netG.apply(regularizer_orth)
        G_regularizer_clipstep = self.opt_train['G_regularizer_clipstep'] if self.opt_train['G_regularizer_clipstep'] else 0
        if G_regularizer_clipstep > 0 and current_step % G_regularizer_clipstep == 0 and current_step % self.opt['train']['checkpoint_save'] != 0:
            self.netG.apply(regularizer_clip)

        # self.log_dict['G_loss'] = G_loss.item()/self.E.size()[0]  # if `reduction='sum'`
        self.log_dict['G_loss'] = G_loss.item()

        if self.opt_train['E_decay'] > 0:
            self.update_E(self.opt_train['E_decay'])
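
        # `update_E(decay)` (in ModelBase) is assumed to apply the usual EMA
        # update parameter-wise: p_E <- decay * p_E + (1 - decay) * p_G.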

    # ----------------------------------------
    # test / inference
    # ----------------------------------------
    def test(self):
        self.netG.eval()
        with torch.no_grad():
            self.netG_forward()
        self.netG.train()

    # ----------------------------------------
    # test / inference x8
    # ----------------------------------------
    def testx8(self):
        self.netG.eval()
        with torch.no_grad():
            self.E = test_mode(self.netG, self.L, mode=3, sf=self.opt['scale'], modulo=1)
        self.netG.train()
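
    # `test_mode(..., mode=3)` is assumed to run the x8 geometric self-ensemble
    # (averaging the outputs over the 8 flip/rotation variants of the input);
    # see utils/utils_model.py.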

    # ----------------------------------------
    # get log_dict
    # ----------------------------------------
    def current_log(self):
        return self.log_dict

    # ----------------------------------------
    # get L, E, H images (first image of the batch)
    # ----------------------------------------
    def current_visuals(self, need_H=True):
        out_dict = OrderedDict()
        out_dict['L'] = self.L.detach()[0].float().cpu()
        out_dict['E'] = self.E.detach()[0].float().cpu()
        if need_H:
            out_dict['H'] = self.H.detach()[0].float().cpu()
        return out_dict

    # ----------------------------------------
    # get L, E, H batch images
    # ----------------------------------------
    def current_results(self, need_H=True):
        out_dict = OrderedDict()
        out_dict['L'] = self.L.detach().float().cpu()
        out_dict['E'] = self.E.detach().float().cpu()
        if need_H:
            out_dict['H'] = self.H.detach().float().cpu()
        return out_dict
"""
|
|
# ----------------------------------------
|
|
# Information of netG
|
|
# ----------------------------------------
|
|
"""
|
|
|
|

    # ----------------------------------------
    # print network
    # ----------------------------------------
    def print_network(self):
        msg = self.describe_network(self.netG)
        print(msg)

    # ----------------------------------------
    # print params
    # ----------------------------------------
    def print_params(self):
        msg = self.describe_params(self.netG)
        print(msg)

    # ----------------------------------------
    # network information
    # ----------------------------------------
    def info_network(self):
        msg = self.describe_network(self.netG)
        return msg

    # ----------------------------------------
    # params information
    # ----------------------------------------
    def info_params(self):
        msg = self.describe_params(self.netG)
        return msg
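

# This class is usually instantiated indirectly: the training script is
# assumed to select it via `define_Model(opt)` (models/select_model.py) when
# the options file sets "model": "plain".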