import os
import torch
import torch.nn as nn
from utils.utils_bnorm import merge_bn, tidy_sequential
from torch.nn.parallel import DataParallel, DistributedDataParallel


class ModelBase():
    def __init__(self, opt):
        self.opt = opt                          # opt
        self.save_dir = opt['path']['models']   # save models
        # self.device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu')
        self.device = torch.device('cuda' if len(opt['gpu_ids']) != 0 else 'cpu')
        self.is_train = opt['is_train']         # training or not
        self.schedulers = []                    # schedulers
"""
|
|
# ----------------------------------------
|
|
# Preparation before training with data
|
|
# Save model during training
|
|
# ----------------------------------------
|
|
"""
|
|
|
|
def init_train(self):
|
|
pass
|
|
|
|
def load(self):
|
|
pass
|
|
|
|
def save(self, label):
|
|
pass
|
|
|
|
def define_loss(self):
|
|
pass
|
|
|
|
def define_optimizer(self):
|
|
pass
|
|
|
|
def define_scheduler(self):
|
|
pass
|
|
|
|
"""
|
|
# ----------------------------------------
|
|
# Optimization during training with data
|
|
# Testing/evaluation
|
|
# ----------------------------------------
|
|
"""
|
|
|
|
def feed_data(self, data):
|
|
pass
|
|
|
|
def optimize_parameters(self):
|
|
pass
|
|
|
|
def current_visuals(self):
|
|
pass
|
|
|
|
def current_losses(self):
|
|
pass
|
|
|
|
def update_learning_rate(self, n):
|
|
for scheduler in self.schedulers:
|
|
scheduler.step(n)
|
|
|
|

    def current_learning_rate(self):
        # return self.schedulers[0].get_lr()[0]  # get_lr() seems to always return the initial lr
        return self.schedulers[0].get_last_lr()[0]  # get_last_lr() reports the lr actually in use (PyTorch >= 1.4)

    def requires_grad(self, model, flag=True):
        for p in model.parameters():
            p.requires_grad = flag
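
    # A minimal usage sketch (hypothetical GAN-style subclass; netD is an
    # assumption, not part of this file): freeze the discriminator while the
    # generator is being optimized, then unfreeze it.
    #
    #   self.requires_grad(self.netD, False)
    #   # ... generator update ...
    #   self.requires_grad(self.netD, True)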
"""
|
|
# ----------------------------------------
|
|
# Information of net
|
|
# ----------------------------------------
|
|
"""
|
|
|
|
def print_network(self):
|
|
pass
|
|
|
|
def info_network(self):
|
|
pass
|
|
|
|
def print_params(self):
|
|
pass
|
|
|
|
def info_params(self):
|
|
pass
|
|
|
|

    def get_bare_model(self, network):
        """Get bare model, especially under wrapping with
        DistributedDataParallel or DataParallel.
        """
        if isinstance(network, (DataParallel, DistributedDataParallel)):
            network = network.module
        return network

    def model_to_device(self, network):
        """Model to device. It also wraps models with DistributedDataParallel
        or DataParallel.

        Args:
            network (nn.Module)
        """
        network = network.to(self.device)
        if self.opt['dist']:
            find_unused_parameters = self.opt.get('find_unused_parameters', True)
            use_static_graph = self.opt.get('use_static_graph', False)
            network = DistributedDataParallel(network, device_ids=[torch.cuda.current_device()], find_unused_parameters=find_unused_parameters)
            if use_static_graph:
                print('Using static graph. Make sure that "unused parameters" will not change during training loop.')
                network._set_static_graph()
        else:
            network = DataParallel(network)
        return network
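
    # A minimal usage sketch (hypothetical subclass code; define_G is an
    # assumption, not part of this file):
    #
    #   self.netG = define_G(opt)                    # any nn.Module
    #   self.netG = self.model_to_device(self.netG)  # .to(device) + DDP/DataParallel wrapping
    #
    # With opt['dist'] set, the network is wrapped in DistributedDataParallel on
    # the current CUDA device; otherwise it falls back to single-process DataParallel.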

    # ----------------------------------------
    # network name and number of parameters
    # ----------------------------------------
    def describe_network(self, network):
        network = self.get_bare_model(network)
        msg = '\n'
        msg += 'Network name: {}'.format(network.__class__.__name__) + '\n'
        msg += 'Params number: {}'.format(sum(p.numel() for p in network.parameters())) + '\n'
        msg += 'Trainable params number: {}'.format(sum(p.numel() for p in network.parameters() if p.requires_grad)) + '\n'
        msg += 'Net structure:\n{}'.format(str(network)) + '\n'
        return msg

    # ----------------------------------------
    # parameters description
    # ----------------------------------------
    def describe_params(self, network):
        network = self.get_bare_model(network)
        msg = '\n'
        msg += '| {}\t| {:^6s} | {:^6s} | {:^6s} | {:^6s} | {:^6s} || {:<20s}'.format('Train', 'mean', 'min', 'max', 'std', 'shape', 'param_name') + '\n'
        for name, param in network.named_parameters():
            if 'num_batches_tracked' not in name:
                is_trainable = param.requires_grad
                v = param.data.clone().float()
                msg += '| {}\t| {:>6.3f} | {:>6.3f} | {:>6.3f} | {:>6.3f} | {} || {:s}'.format(is_trainable, v.mean(), v.min(), v.max(), v.std(), v.shape, name) + '\n'
        return msg
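
    # A minimal usage sketch (assumes a subclass that defines self.netG):
    #
    #   print(self.describe_network(self.netG))  # name, parameter counts, structure
    #   print(self.describe_params(self.netG))   # per-parameter statistics table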
"""
|
|
# ----------------------------------------
|
|
# Save prameters
|
|
# Load prameters
|
|
# ----------------------------------------
|
|
"""
|
|
|
|

    # ----------------------------------------
    # save the state_dict of the network
    # ----------------------------------------
    def save_network(self, save_dir, network, network_label, iter_label):
        save_filename = '{}_{}.pth'.format(iter_label, network_label)
        save_path = os.path.join(save_dir, save_filename)
        network = self.get_bare_model(network)
        state_dict = network.state_dict()
        for key, param in state_dict.items():
            state_dict[key] = param.cpu()  # move to CPU so the checkpoint can be loaded on any device
        torch.save(state_dict, save_path)

    # ----------------------------------------
    # load the state_dict of the network
    # ----------------------------------------
    def load_network(self, load_path, network, strict=True, param_key='params'):
        network = self.get_bare_model(network)
        if strict:
            state_dict = torch.load(load_path)
            if param_key in state_dict.keys():
                state_dict = state_dict[param_key]
            network.load_state_dict(state_dict, strict=strict)
        else:
            state_dict_old = torch.load(load_path)
            if param_key in state_dict_old.keys():
                state_dict_old = state_dict_old[param_key]
            state_dict = network.state_dict()
            # copy parameters by position rather than by name, so checkpoints with
            # renamed (but identically ordered) layers can still be loaded
            for ((key_old, param_old), (key, param)) in zip(state_dict_old.items(), state_dict.items()):
                state_dict[key] = param_old
            network.load_state_dict(state_dict, strict=True)
            del state_dict_old, state_dict
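
    # A minimal usage sketch (hypothetical labels and paths, not part of this file):
    #
    #   self.save_network(self.save_dir, self.netG, 'G', 10000)  # -> <save_dir>/10000_G.pth
    #   self.load_network(load_path_G, self.netG, strict=True, param_key='params')
    #
    # strict=True requires checkpoint keys to match the model exactly; strict=False
    # copies parameters by position, which tolerates renamed layers but assumes both
    # state_dicts enumerate parameters in the same order.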

    # ----------------------------------------
    # save the state_dict of the optimizer
    # ----------------------------------------
    def save_optimizer(self, save_dir, optimizer, optimizer_label, iter_label):
        save_filename = '{}_{}.pth'.format(iter_label, optimizer_label)
        save_path = os.path.join(save_dir, save_filename)
        torch.save(optimizer.state_dict(), save_path)

    # ----------------------------------------
    # load the state_dict of the optimizer
    # ----------------------------------------
    def load_optimizer(self, load_path, optimizer):
        optimizer.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage.cuda(torch.cuda.current_device())))

    def update_E(self, decay=0.999):
        """Update netE as an exponential moving average (EMA) of netG:
        netE = decay * netE + (1 - decay) * netG.
        """
        netG = self.get_bare_model(self.netG)
        netG_params = dict(netG.named_parameters())
        netE_params = dict(self.netE.named_parameters())
        for k in netG_params.keys():
            netE_params[k].data.mul_(decay).add_(netG_params[k].data, alpha=1-decay)
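
    # A minimal usage sketch (assumes a subclass that builds self.netE as a
    # detached copy of netG, which is how this method is meant to be used):
    #
    #   self.G_optimizer.step()
    #   self.update_E(decay=0.999)   # netE <- 0.999 * netE + 0.001 * netG
    #
    # netE then holds an exponential moving average of netG's weights, typically
    # the network used at test time because averaged weights are more stable.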
"""
|
|
# ----------------------------------------
|
|
# Merge Batch Normalization for training
|
|
# Merge Batch Normalization for testing
|
|
# ----------------------------------------
|
|
"""
|
|
|
|

    # ----------------------------------------
    # merge bn during training
    # ----------------------------------------
    def merge_bnorm_train(self):
        merge_bn(self.netG)
        tidy_sequential(self.netG)
        # merging bn changes the parameter set, so re-create optimizer and scheduler
        self.define_optimizer()
        self.define_scheduler()

    # ----------------------------------------
    # merge bn before testing
    # ----------------------------------------
    def merge_bnorm_test(self):
        merge_bn(self.netG)
        tidy_sequential(self.netG)
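

# ----------------------------------------
# A minimal subclass sketch (illustrative only: MyModel, define_G and the
# loss/optimizer attributes are assumptions, not part of this file)
# ----------------------------------------
#
# class MyModel(ModelBase):
#     def __init__(self, opt):
#         super(MyModel, self).__init__(opt)
#         self.netG = self.model_to_device(define_G(opt))
#
#     def init_train(self):
#         self.netG.train()
#         self.define_loss()
#         self.define_optimizer()
#         self.define_scheduler()
#
#     def feed_data(self, data):
#         self.L = data['L'].to(self.device)   # input
#         self.H = data['H'].to(self.device)   # target
#
#     def optimize_parameters(self):
#         self.G_optimizer.zero_grad()
#         loss = self.G_lossfn(self.netG(self.L), self.H)
#         loss.backward()
#         self.G_optimizer.step()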