deep-tempest/end-to-end/utils/utils_matconvnet.py

198 lines
6.8 KiB
Python

# -*- coding: utf-8 -*-
import numpy as np
import torch
from collections import OrderedDict
# import scipy.io as io
import hdf5storage
"""
# --------------------------------------------
# Convert matconvnet SimpleNN model into pytorch model
# --------------------------------------------
# Kai Zhang (cskaizhang@gmail.com)
# https://github.com/cszn
# 28/Nov/2019
# --------------------------------------------
"""
def weights2tensor(x, squeeze=False, in_features=None, out_features=None):
    """Modified version of https://github.com/albanie/pytorch-mcn

    Adjust memory layout and load weights as a torch tensor.

    Args:
        x (ndarray): a numpy array holding a set of network weights in the
            MatConvNet layout. A 4-D array is treated as (H, W, C_in, C_out).
            A 3-D array is treated as (H, W, C_in) with a single output
            channel. A 2-D column vector of shape (N, 1) is flattened to (N,).
        squeeze (bool) [False]: whether to squeeze the tensor, i.e. remove
            singletons from the trailing dimensions. After converting to the
            pytorch layout (C_out, C_in, H, W), a shape (A, B, 1, 1) becomes a
            matrix of shape (A, B), even when A or B is 1. The leading two
            axes are never removed, and 1-D / 2-D arrays are left unchanged.
        in_features (int :: None): used together with ``out_features`` to
            reshape weights for a linear block. When both are given and
            ``squeeze`` is True, the result has exactly the shape
            (out_features, in_features).
        out_features (int :: None): see ``in_features``.

    Returns:
        torch.Tensor: the permuted weights (C-contiguous), matching the
        pytorch layout convention.

    Raises:
        ValueError: if ``in_features * out_features`` does not match the
            number of elements in ``x`` (raised by ``ndarray.reshape``).
    """
    if x.ndim == 4:
        # (H, W, C_in, C_out) -> (C_out, C_in, H, W)
        x = x.transpose((3, 2, 0, 1))
        # for FFDNet, pixel-shuffle layer
        # if x.shape[1]==13:
        #     x=x[:,[0,2,1,3, 4,6,5,7, 8,10,9,11, 12],:,:]
        # if x.shape[0]==12:
        #     x=x[[0,2,1,3, 4,6,5,7, 8,10,9,11],:,:,:]
        # if x.shape[1]==5:
        #     x=x[:,[0,2,1,3, 4],:,:]
        # if x.shape[0]==4:
        #     x=x[[0,2,1,3],:,:,:]
        ## for SRMD, pixel-shuffle layer
        # if x.shape[0]==12:
        #     x=x[[0,2,1,3, 4,6,5,7, 8,10,9,11],:,:,:]
        # if x.shape[0]==27:
        #     x=x[[0,3,6,1,4,7,2,5,8, 0+9,3+9,6+9,1+9,4+9,7+9,2+9,5+9,8+9, 0+18,3+18,6+18,1+18,4+18,7+18,2+18,5+18,8+18],:,:,:]
        # if x.shape[0]==48:
        #     x=x[[0,4,8,12,1,5,9,13,2,6,10,14,3,7,11,15, 0+16,4+16,8+16,12+16,1+16,5+16,9+16,13+16,2+16,6+16,10+16,14+16,3+16,7+16,11+16,15+16, 0+32,4+32,8+32,12+32,1+32,5+32,9+32,13+32,2+32,6+32,10+32,14+32,3+32,7+32,11+32,15+32],:,:,:]
    elif x.ndim == 3:  # add by Kai
        # Restore the missing C_out == 1 axis, then permute as for 4-D.
        x = x[:, :, :, None].transpose((3, 2, 0, 1))
    elif x.ndim == 2 and x.shape[1] == 1:
        # Column vector (e.g. a bias) -> 1-D.
        x = x.flatten()

    if squeeze:
        if in_features is not None and out_features is not None:
            # Exact linear-layer shape. Do not squeeze further, otherwise
            # out_features == 1 would collapse to a 1-D vector.
            x = x.reshape((out_features, in_features))
        else:
            # Drop trailing singleton axes (e.g. 1x1 spatial dims) but keep
            # the leading (C_out, C_in) axes so (A, B, 1, 1) -> (A, B).
            shape = x.shape
            while len(shape) > 2 and shape[-1] == 1:
                shape = shape[:-1]
            x = x.reshape(shape)
    return torch.from_numpy(np.ascontiguousarray(x))
def save_model(network, save_path):
    """Write ``network``'s state dict to ``save_path`` with all tensors on the CPU.

    The mapping returned by ``network.state_dict()`` is updated in place, with
    each tensor replaced by its ``.cpu()`` copy. That same mapping object is
    then passed to ``torch.save``.
    """
    cpu_state = network.state_dict()
    for name, tensor in list(cpu_state.items()):
        cpu_state[name] = tensor.cpu()
    torch.save(cpu_state, save_path)
def _convert_cnn_denoisers(mat_path='models/modelcolor.mat',
                           save_path='model_zoo/modelcolor.pth',
                           num_denoisers=25, num_layers=13):
    """Convert a bank of MatConvNet SimpleNN denoisers into PyTorch weights.

    Reads the ``CNNdenoiser`` variable from ``mat_path``. For each of the
    first ``num_denoisers`` networks, it inspects the first ``num_layers``
    layers. Every layer whose type is ``'conv'`` contributes a weight and a
    bias, stored under ``model.{2*k}.weight`` / ``model.{2*k}.bias``, where
    ``k`` is the running index of conv layers. (The even spacing presumably
    leaves room for the activation module after each conv in the target
    ``nn.Sequential``; confirm against the PyTorch model definition.)

    The result is an ``OrderedDict`` mapping ``str(denoiser_index)`` to that
    denoiser's state dict. It is saved with ``torch.save`` to ``save_path``.

    Args:
        mat_path (str): MATLAB file to read (loaded with ``hdf5storage``).
        save_path (str): destination ``.pth`` file. Its parent directory is
            created if missing.
        num_denoisers (int): how many denoisers to convert.
        num_layers (int): how many layers to inspect per denoiser.

    Returns:
        OrderedDict: the converted weights that were saved.
    """
    import os

    mcn = hdf5storage.loadmat(mat_path)
    # NOTE(review): the nested [0] indexing mirrors how hdf5storage exposes
    # the MATLAB cell/struct arrays of this particular file. The counts are
    # supplied by the caller rather than read from the data.
    denoisers = mcn['CNNdenoiser'][0]
    mat_net = OrderedDict()
    for idx in range(num_denoisers):
        layers = denoisers[idx][0]
        state = OrderedDict()
        count = -1
        for i in range(num_layers):
            layer = layers[i][0]
            if layer[0][0][0] == 'conv':
                count += 1
                params = layer[1][0]  # (weights, biases)
                state['model.{:d}.weight'.format(count * 2)] = weights2tensor(params[0])
                state['model.{:d}.bias'.format(count * 2)] = weights2tensor(params[1])
        mat_net[str(idx)] = state
        print('denoiser {:d}: {:d} conv layers converted'.format(idx, count + 1))

    out_dir = os.path.dirname(save_path)
    if out_dir:
        # torch.save does not create missing parent directories.
        os.makedirs(out_dir, exist_ok=True)
    torch.save(mat_net, save_path)
    return mat_net


if __name__ == '__main__':
    # from utils import utils_logger
    # import logging
    # utils_logger.logger_info('a', 'a.log')
    # logger = logging.getLogger('a')
    #
    # mcn = hdf5storage.loadmat('/model_zoo/matfile/FFDNet_Clip_gray.mat')
    # logger.info(mcn['CNNdenoiser'][0][0][0][1][0][0][0][0])
    _convert_cnn_denoisers()

    # --------------------------------------------
    # Legacy conversion snippets kept for reference
    # --------------------------------------------
    # from models.network_dncnn import IRCNN as net
    # network = net(in_nc=3, out_nc=3, nc=64)
    # state_dict = network.state_dict()
    #
    # #show_kv(state_dict)
    #
    # for i in range(len(mcn['net'][0][0][0])):
    #     print(mcn['net'][0][0][0][i][0][0][0][0])
    #
    # count = -1
    # mat_net = OrderedDict()
    # for i in range(len(mcn['net'][0][0][0])):
    #     if mcn['net'][0][0][0][i][0][0][0][0] == 'conv':
    #
    #         count += 1
    #         w = mcn['net'][0][0][0][i][0][1][0][0]
    #         print(w.shape)
    #         w = weights2tensor(w)
    #         print(w.shape)
    #
    #         b = mcn['net'][0][0][0][i][0][1][0][1]
    #         b = weights2tensor(b)
    #         print(b.shape)
    #
    #         mat_net['model.{:d}.weight'.format(count*2)] = w
    #         mat_net['model.{:d}.bias'.format(count*2)] = b
    #
    # torch.save(mat_net, 'E:/pytorch/KAIR_ongoing/model_zoo/ffdnet_gray_clip.pth')
    #
    # crt_net = torch.load('E:/pytorch/KAIR_ongoing/model_zoo/imdn_x4.pth')
    # def show_kv(net):
    #     for k, v in net.items():
    #         print(k)
    #
    # show_kv(crt_net)
    # from models.network_dncnn import DnCNN as net
    # network = net(in_nc=2, out_nc=1, nc=64, nb=20, act_mode='R')
    # from models.network_srmd import SRMD as net
    # #network = net(in_nc=1, out_nc=1, nc=64, nb=15, act_mode='R')
    # network = net(in_nc=19, out_nc=3, nc=128, nb=12, upscale=4, act_mode='R', upsample_mode='pixelshuffle')
    #
    # from models.network_rrdb import RRDB as net
    # network = net(in_nc=3, out_nc=3, nc=64, nb=23, gc=32, upscale=4, act_mode='L', upsample_mode='upconv')
    #
    # state_dict = network.state_dict()
    # for key, param in state_dict.items():
    #     print(key)
    # from models.network_imdn import IMDN as net
    # network = net(in_nc=3, out_nc=3, nc=64, nb=8, upscale=4, act_mode='L', upsample_mode='pixelshuffle')
    # state_dict = network.state_dict()
    # mat_net = OrderedDict()
    # for ((key, param),(key2, param2)) in zip(state_dict.items(), crt_net.items()):
    #     mat_net[key] = param2
    # torch.save(mat_net, 'model_zoo/imdn_x4_1.pth')
    #
    # net_old = torch.load('net_old.pth')
    # def show_kv(net):
    #     for k, v in net.items():
    #         print(k)
    #
    # show_kv(net_old)
    # from models.network_dpsr import MSRResNet_prior as net
    # model = net(in_nc=4, out_nc=3, nc=96, nb=16, upscale=4, act_mode='R', upsample_mode='pixelshuffle')
    # state_dict = network.state_dict()
    # net_new = OrderedDict()
    # for ((key, param),(key_old, param_old)) in zip(state_dict.items(), net_old.items()):
    #     net_new[key] = param_old
    # torch.save(net_new, 'net_new.pth')
    # print(key)
    # print(param.size())
    # run utils/utils_matconvnet.py