Renamed "KAIR" folder to "End-to-End"
This commit is contained in:
parent f2d68d8994
commit b6e708fe88
File diff suppressed because it is too large
@@ -1,235 +0,0 @@
import os.path
import logging
import argparse

import numpy as np
from datetime import datetime
from collections import OrderedDict
# from scipy.io import loadmat

import torch

from utils import utils_logger
from utils import utils_model
from utils import utils_image as util

from utils import utils_option as option
from models.select_model import define_Model


'''
Spyder (Python 3.6)
PyTorch 1.1.0
Windows 10 or Linux

Kai Zhang (cskaizhang@gmail.com)
github: https://github.com/cszn/KAIR
https://github.com/cszn/DnCNN

@article{zhang2017beyond,
  title={Beyond a gaussian denoiser: Residual learning of deep cnn for image denoising},
  author={Zhang, Kai and Zuo, Wangmeng and Chen, Yunjin and Meng, Deyu and Zhang, Lei},
  journal={IEEE Transactions on Image Processing},
  volume={26},
  number={7},
  pages={3142--3155},
  year={2017},
  publisher={IEEE}
}

% If you have any question, please feel free to contact me.
% Kai Zhang (e-mail: cskaizhang@gmail.com; github: https://github.com/cszn)

by Kai Zhang (12/Dec./2019)
'''

"""
# --------------------------------------------
|--model_zoo                 # model_zoo
   |--dncnn_15               # model_name
   |--dncnn_25
   |--dncnn_50
   |--dncnn_gray_blind
   |--dncnn_color_blind
   |--dncnn3
|--testset                   # testsets
   |--set12                  # testset_name
   |--bsd68
   |--cbsd68
|--results                   # results
   |--set12_dncnn_15         # result_name = testset_name + '_' + model_name
   |--set12_dncnn_25
   |--bsd68_dncnn_15
# --------------------------------------------
"""


def main(json_path='options/train_drunet.json'):

    # ----------------------------------------
    # Preparation
    # ----------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument('--testsetH_name', type=str, default='web_images_test', help='test set, bsd68 | set12')
    parser.add_argument('--testsetL_name', type=str, default='simulations', help='test set, bsd68 | set12')
    parser.add_argument('--noise_level_img', type=int, default=15, help='noise level: 15, 25, 50')
    parser.add_argument('--x8', type=bool, default=False, help='x8 to boost performance')
    parser.add_argument('--show_img', type=bool, default=False, help='show the image')
    parser.add_argument('--model_pool', type=str, default='model_zoo', help='path of model_zoo')
    parser.add_argument('--testsets', type=str, default='testsets', help='path of testing folder')
    parser.add_argument('--results', type=str, default='results', help='path of results')
    parser.add_argument('--need_degradation', type=bool, default=True, help='add noise or not')
    parser.add_argument('--task_current', type=str, default='dn', help='dn for denoising, fixed!')
    parser.add_argument('--sf', type=int, default=1, help='unused for denoising')
    parser.add_argument('--opt', type=str, default=json_path, help='Path to option JSON file.')
    args = parser.parse_args()

    '''
    # ----------------------------------------
    # Step--1 (prepare opt)
    # ----------------------------------------
    '''
    opt = option.parse(args.opt, is_train=False)

    border = 0
    n_channels = 3
    # --<--<--<--<--<--<--<--<--<--<--<--<--<-

    # ----------------------------------------
    # return None for missing key
    # ----------------------------------------
    opt = option.dict_to_nonedict(opt)

    # if 'color' in args.model_name:
    #     n_channels = 3  # fixed, 1 for grayscale image, 3 for color image
    # else:
    #     n_channels = 1  # fixed for grayscale image
    # if args.model_name in ['dncnn_gray_blind', 'dncnn_color_blind', 'dncnn3']:
    #     nb = 20  # fixed
    # else:
    #     nb = 17  # fixed

    model_path = opt['path']['pretrained_netG']
    pretrain_name = model_path.split('/')[-1].split('.')[0]
    model_name = 'DnCNN_' + pretrain_name
    result_name = f'eval_std{args.noise_level_img}_{model_name}'  # fixed
    border = args.sf if args.task_current == 'sr' else 0  # shave border to calculate PSNR and SSIM
    # model_path = os.path.join(args.model_pool, args.model_name + '.pth')

    # ----------------------------------------
    # L_path, E_path, H_path
    # ----------------------------------------
    L_path = os.path.join(args.testsets, args.testsetL_name)  # L_path, for Low-quality images
    H_path = os.path.join(args.testsets, args.testsetH_name)  # H_path, for High-quality images
    E_path = os.path.join(args.results, result_name)          # E_path, for Estimated images
    util.mkdir(E_path)

    if H_path == L_path:
        args.need_degradation = True
    logger_name = result_name
    utils_logger.logger_info(logger_name, log_path=os.path.join(E_path, logger_name + '.log'))
    logger = logging.getLogger(logger_name)

    need_H = True if H_path is not None else False
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # ----------------------------------------
    # load model
    # ----------------------------------------
    # from models.network_dncnn import DnCNN as net
    # model = net(in_nc=n_channels, out_nc=n_channels, nc=64, nb=nb, act_mode='R')
    # # model = net(in_nc=n_channels, out_nc=n_channels, nc=64, nb=nb, act_mode='BR')  # use this if BN is not merged by utils_bnorm.merge_bn(model)
    # model.load_state_dict(torch.load(model_path), strict=True)

    model = define_Model(opt)
    model.load()

    # model.eval()
    # for k, v in model.named_parameters():
    #     v.requires_grad = False
    # model = model.to(device)
    # logger.info('Model path: {:s}'.format(model_path))
    # number_parameters = sum(map(lambda x: x.numel(), model.parameters()))
    # logger.info('Params number: {}'.format(number_parameters))

    test_results = OrderedDict()
    test_results['psnr'] = []
    test_results['ssim'] = []

    logger.info('model_name:{}, image sigma:{}'.format(model_name, args.noise_level_img))
    logger.info(L_path)
    L_paths = util.get_image_paths(L_path)
    H_paths = util.get_image_paths(H_path) if need_H else None

    for idx, img in enumerate(L_paths):

        # ------------------------------------
        # (1) img_L
        # ------------------------------------
        img_name, ext = os.path.splitext(os.path.basename(img))
        # logger.info('{:->4d}--> {:>10s}'.format(idx + 1, img_name + ext))
        img_L = util.imread_uint(img, n_channels=n_channels)[:, :, :2]
        img_L = util.uint2single(img_L)

        if args.need_degradation:  # degradation process
            np.random.seed(seed=0)  # for reproducibility
            img_L += np.random.normal(0, args.noise_level_img / 255., img_L.shape)

        if args.show_img:
            util.imshow(util.single2uint(img_L), title='Noisy image with noise level {}'.format(args.noise_level_img))

        img_L = util.single2tensor4(img_L)
        img_L = img_L.to(device)

        # ------------------------------------
        # (2) img_E
        # ------------------------------------
        if not args.x8:
            model.feed_data({'L': img_L}, need_H=False)
            model.test()
            visuals = model.current_visuals(need_H=False)
            img_E = visuals['E']
        else:
            img_E = utils_model.test_mode(model, img_L, mode=3)

        img_E = util.tensor2uint(img_E)[:, :, 0]

        if need_H:

            # --------------------------------
            # (3) img_H
            # --------------------------------
            img_H = util.imread_uint(H_paths[idx], n_channels=n_channels)
            img_H = np.mean(img_H, axis=2)
            img_H = img_H.squeeze()

            # --------------------------------
            # PSNR and SSIM
            # --------------------------------
            psnr = util.calculate_psnr(img_E, img_H, border=border)
            ssim = util.calculate_ssim(img_E, img_H, border=border)
            test_results['psnr'].append(psnr)
            test_results['ssim'].append(ssim)
            logger.info('{:s} - PSNR: {:.2f} dB; SSIM: {:.4f}.'.format(img_name + ext, psnr, ssim))
            if args.show_img:
                util.imshow(np.concatenate([img_E, img_H], axis=1), title='Recovered / Ground-truth')

        # ------------------------------------
        # save results
        # ------------------------------------
        util.imsave(img_E, os.path.join(E_path, img_name + ext))

    if need_H:
        ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
        ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
        logger.info('Average PSNR/SSIM - {} - PSNR: {:.2f} dB; SSIM: {:.4f}'.format(result_name, ave_psnr, ave_ssim))


if __name__ == '__main__':
    main()
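For reference, the border shaving used above when computing PSNR (border = args.sf for super-resolution, 0 for denoising) amounts to the following standalone computation. This is a minimal NumPy sketch for illustration, not the project's util.calculate_psnr:

import numpy as np

def psnr_with_border(img_e, img_h, border=0):
    """PSNR in dB between two uint8 images, ignoring `border` pixels on each side."""
    if border > 0:
        img_e = img_e[border:-border, border:-border]
        img_h = img_h[border:-border, border:-border]
    mse = np.mean((img_e.astype(np.float64) - img_h.astype(np.float64)) ** 2)
    return float('inf') if mse == 0 else 20 * np.log10(255.0 / np.sqrt(mse))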
@@ -1,225 +0,0 @@
import os.path
import math
import argparse
import time
import random
import numpy as np
from collections import OrderedDict
import logging
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import torch

from utils import utils_logger
from utils import utils_image as util
from utils import utils_option as option
from utils.utils_dist import get_dist_info, init_dist

from data.select_dataset import define_Dataset
from models.select_model import define_Model


def save_deeptempest_result(imgs_dict, save_path):

    # Name of original image and saving path
    file_name = imgs_dict['image_name']
    save_img_path = os.path.join(save_path, file_name)

    # High resolution image
    H = imgs_dict['H_vis']
    util.imsave(H, save_img_path + '_H.png')

    # Low resolution image (real/imaginary part and module)
    L = util.tensor2uint(imgs_dict['L'])
    L_complex = np.pad(L, ((0, 0), (0, 0), (0, 1)))  # real and imaginary parts
    util.imsave(L_complex, save_img_path + '_L_complex.png')

    L = L.astype('float')
    L = np.abs(L[:, :, 0] + 1j * L[:, :, 1])
    L = 255 * (L - L.min()) / (L.max() - L.min())  # module (magnitude), rescaled to [0, 255]
    util.imsave(L, save_img_path + '_L_module.png')

    E = imgs_dict['E_vis']
    util.imsave(E, save_img_path + '_E.png')


'''
# --------------------------------------------
# training code for DRUNet
# --------------------------------------------
# Kai Zhang (cskaizhang@gmail.com)
# github: https://github.com/cszn/KAIR
'''


def drunet_pipeline(json_path='options/train_drunet.json'):

    test_imgs_list = []

    '''
    # ----------------------------------------
    # Step--1 (prepare opt)
    # ----------------------------------------
    '''
    # parser = argparse.ArgumentParser()
    # parser.add_argument('--opt', type=str, default=json_path, help='Path to option JSON file.')
    # parser.add_argument('--launcher', default='pytorch', help='job launcher')
    # parser.add_argument('--local_rank', type=int, default=0)
    # parser.add_argument('--dist', default=False)

    opt = option.parse(json_path, is_train=True)
    # opt['dist'] = parser.parse_args().dist

    # ----------------------------------------
    # distributed settings
    # ----------------------------------------
    if opt['dist']:
        init_dist('pytorch')
    opt['rank'], opt['world_size'] = get_dist_info()

    # if opt['rank'] == 0:
    #     util.mkdirs((path for key, path in opt['path'].items() if 'pretrained' not in key))

    # ----------------------------------------
    # update opt
    # ----------------------------------------
    # -->-->-->-->-->-->-->-->-->-->-->-->-->-
    init_iter_G, init_path_G = option.find_last_checkpoint(opt['path']['models'], net_type='G')
    # init_iter_G, init_path_G = option.find_last_checkpoint(opt['path']['models'], net_type='G', pretrained_path=opt['path']['pretrained_netG'])
    opt['path']['pretrained_netG'] = init_path_G
    init_iter_optimizerG, init_path_optimizerG = option.find_last_checkpoint(opt['path']['models'], net_type='optimizerG')
    opt['path']['pretrained_optimizerG'] = init_path_optimizerG
    current_step = max(init_iter_G, init_iter_optimizerG)

    border = opt['scale']
    # --<--<--<--<--<--<--<--<--<--<--<--<--<-

    # ----------------------------------------
    # save opt to a '../option.json' file
    # ----------------------------------------
    # if opt['rank'] == 0:
    #     option.save(opt)

    # ----------------------------------------
    # return None for missing key
    # ----------------------------------------
    opt = option.dict_to_nonedict(opt)

    # ----------------------------------------
    # configure logger
    # ----------------------------------------
    if opt['rank'] == 0:
        logger_name = 'train'
        utils_logger.logger_info(logger_name, os.path.join(opt['path']['log'], logger_name + '.log'))
        logger = logging.getLogger(logger_name)
        logger.info(option.dict2str(opt))

    # ----------------------------------------
    # seed
    # ----------------------------------------
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    print('Random seed: {}'.format(seed))
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    '''
    # ----------------------------------------
    # Step--2 (create dataloader)
    # ----------------------------------------
    '''
    # ----------------------------------------
    # 1) create_dataset
    # 2) create_dataloader for train and test
    # ----------------------------------------
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'test':
            test_set = define_Dataset(dataset_opt)
            test_loader = DataLoader(test_set, batch_size=1,
                                     shuffle=False, num_workers=1,
                                     drop_last=False, pin_memory=True)

    '''
    # ----------------------------------------
    # Step--3 (initialize model)
    # ----------------------------------------
    '''
    model = define_Model(opt)
    model.init_train()
    # if opt['rank'] == 0:
    #     logger.info(model.info_network())
    #     logger.info(model.info_params())

    '''
    # ----------------------------------------
    # Step--4 (main test)
    # ----------------------------------------
    '''
    avg_psnr = 0.0
    avg_loss = 0.0
    idx = 0

    for j, test_data in enumerate(test_loader):
        idx += 1
        image_name_ext = os.path.basename(test_data['L_path'][0])
        img_name, ext = os.path.splitext(image_name_ext)

        # img_dir = os.path.join(opt['path']['images'], img_name)
        # util.mkdir(img_dir)

        model.feed_data(test_data)
        model.test()

        visuals = model.current_visuals()
        E_img = util.tensor2uint(visuals['E'])
        H_img = util.tensor2uint(visuals['H'])

        test_imgs_list.append({'L': test_data['L'],
                               'H': test_data['H'],
                               'H_vis': H_img,
                               'E_vis': E_img,
                               'image_name': img_name
                               })

        # -----------------------
        # save estimated image E
        # -----------------------
        # save_img_path = os.path.join(img_dir, '{:s}_{:d}.png'.format(img_name, current_step))
        # util.imsave(E_img, save_img_path)

        # -----------------------
        # calculate PSNR
        # -----------------------
        current_psnr = util.calculate_psnr(E_img, H_img, border=border)

        # -----------------------
        # calculate loss
        # -----------------------
        current_loss = model.G_lossfn_weight * model.G_lossfn(model.E, model.H)

        logger.info('{:->4d}--> {:>10s} | PSNR = {:<4.2f}dB ; G_loss = {:.3e}'.format(idx, image_name_ext, current_psnr, current_loss))

        avg_psnr += current_psnr
        avg_loss += current_loss

    avg_psnr = avg_psnr / idx
    avg_loss = avg_loss / idx

    # testing log
    logger.info('Average PSNR : {:<.2f}dB, Average loss : {:.3e}\n'.format(avg_psnr, avg_loss))

    return test_imgs_list


if __name__ == '__main__':
    imgs = drunet_pipeline()
    save_path = 'testsets/web_subset/visuals'  # COMPLETE SAVING PATH
    for img in imgs:
        save_deeptempest_result(img, save_path)
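save_deeptempest_result above treats the two channels of L as the real and imaginary parts of the captured signal. As a minimal sketch (assuming a H x W x 2 uint8 array), the saved module image is simply the rescaled complex magnitude:

import numpy as np

def complex_module(L):
    """Magnitude of a (H, W, 2) real/imaginary array, rescaled to [0, 255]."""
    L = L.astype('float')
    mod = np.abs(L[:, :, 0] + 1j * L[:, :, 1])
    return 255 * (mod - mod.min()) / (mod.max() - mod.min())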
@@ -1,28 +0,0 @@
import os
from PIL import Image
import numpy as np

lista_dir = os.listdir('./')

h_total = 1600
v_total = 900

h_image = 1024
v_image = 768

I_pad = 255 * np.ones((v_total, h_total, 3), dtype=np.uint8)

for file in lista_dir:

    file_ext = file.split('.')[-1]
    file_name = file.split('.' + file_ext)[0]

    if file_ext != 'jpg':
        continue

    I = Image.open(file)
    I = np.array(I)

    I_pad[(v_total - v_image) // 2:-(v_total - v_image) // 2, (h_total - h_image) // 2:-(h_total - h_image) // 2, :] = I

    Image.fromarray(I_pad).save(f'./Roboflow_dataset_{file_name}.png')
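The slicing above centers each 1024x768 image on the 1600x900 canvas; the margins work out to 288 pixels per side horizontally and 66 vertically:

h_total, v_total = 1600, 900
h_image, v_image = 1024, 768
left, top = (h_total - h_image) // 2, (v_total - v_image) // 2  # 288, 66
assert (left * 2 + h_image, top * 2 + v_image) == (h_total, v_total)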
@@ -1,34 +0,0 @@
import os
from PIL import Image
import numpy as np

lista_dir = os.listdir('./')

h_total = 1600
v_total = 900

h_image = 1280
v_image = 1280

I_pad = 255 * np.ones((v_total, h_total, 3), dtype=np.uint8)

for i, file in enumerate(lista_dir):

    file_ext = file.split('.')[-1]
    file_name = file.split('.' + file_ext)[0]

    if file_ext != 'png':
        continue

    I = Image.open(file)
    I = np.array(I)[:, :, :3]

    if I.shape[:2] != (1280, 1280):
        continue

    if i % 2 == 0:
        I_pad[:, (h_total - h_image) // 2:-(h_total - h_image) // 2, :] = I[:v_total, :, :]
    else:
        I_pad[:, (h_total - h_image) // 2:-(h_total - h_image) // 2, :] = I[v_image - v_total:, :, :]

    Image.fromarray(I_pad).save(f'./{file_name}.png')
Binary file not shown.
Binary file not shown.
Binary file not shown.
File diff suppressed because it is too large
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
Binary file not shown.
Before: Size 192 KiB
@@ -1,263 +0,0 @@
# -*- coding: utf-8 -*-
"""
Created on Mon Apr 24 15:43:29 2017
@author: zhaoy
"""
import cv2
import numpy as np
from skimage import transform as trans

# reference facial points, a list of coordinates (x, y)
REFERENCE_FACIAL_POINTS = [
    [30.29459953, 51.69630051],
    [65.53179932, 51.50139999],
    [48.02519989, 71.73660278],
    [33.54930115, 92.3655014],
    [62.72990036, 92.20410156]
]

DEFAULT_CROP_SIZE = (96, 112)


def _umeyama(src, dst, estimate_scale=True, scale=1.0):
    """Estimate N-D similarity transformation with or without scaling.

    Parameters
    ----------
    src : (M, N) array
        Source coordinates.
    dst : (M, N) array
        Destination coordinates.
    estimate_scale : bool
        Whether to estimate scaling factor.

    Returns
    -------
    T : (N + 1, N + 1)
        The homogeneous similarity transformation matrix. The matrix contains
        NaN values only if the problem is not well-conditioned.

    References
    ----------
    .. [1] "Least-squares estimation of transformation parameters between two
           point patterns", Shinji Umeyama, PAMI 1991, :DOI:`10.1109/34.88573`
    """
    num = src.shape[0]
    dim = src.shape[1]

    # Compute mean of src and dst.
    src_mean = src.mean(axis=0)
    dst_mean = dst.mean(axis=0)

    # Subtract mean from src and dst.
    src_demean = src - src_mean
    dst_demean = dst - dst_mean

    # Eq. (38).
    A = dst_demean.T @ src_demean / num

    # Eq. (39).
    d = np.ones((dim,), dtype=np.double)
    if np.linalg.det(A) < 0:
        d[dim - 1] = -1

    T = np.eye(dim + 1, dtype=np.double)

    U, S, V = np.linalg.svd(A)

    # Eq. (40) and (43).
    rank = np.linalg.matrix_rank(A)
    if rank == 0:
        return np.nan * T, scale  # return a tuple, as in every other branch
    elif rank == dim - 1:
        if np.linalg.det(U) * np.linalg.det(V) > 0:
            T[:dim, :dim] = U @ V
        else:
            s = d[dim - 1]
            d[dim - 1] = -1
            T[:dim, :dim] = U @ np.diag(d) @ V
            d[dim - 1] = s
    else:
        T[:dim, :dim] = U @ np.diag(d) @ V

    if estimate_scale:
        # Eq. (41) and (42).
        scale = 1.0 / src_demean.var(axis=0).sum() * (S @ d)
    # else: keep the scale passed in by the caller

    T[:dim, dim] = dst_mean - scale * (T[:dim, :dim] @ src_mean.T)
    T[:dim, :dim] *= scale

    return T, scale
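# Illustrative sanity check of _umeyama (the points and the 30-degree
# rotation below are made up): an exact similarity transform should be
# recovered exactly from noiseless correspondences.
def _umeyama_demo():
    theta, s = np.pi / 6, 1.5
    R = np.array([[np.cos(theta), -np.sin(theta)],
                  [np.sin(theta), np.cos(theta)]])
    src = np.array([[0., 0.], [1., 0.], [0., 1.], [1., 1.]])
    dst = s * src @ R.T + np.array([2.0, -1.0])
    T, scale = _umeyama(src, dst, estimate_scale=True)
    src_h = np.hstack([src, np.ones((len(src), 1))])  # homogeneous coordinates
    assert np.allclose((src_h @ T.T)[:, :2], dst)
    assert np.allclose(scale, s)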
class FaceWarpException(Exception):
    def __str__(self):
        return 'In File {}:{}'.format(__file__, super().__str__())


def get_reference_facial_points(output_size=None,
                                inner_padding_factor=0.0,
                                outer_padding=(0, 0),
                                default_square=False):
    tmp_5pts = np.array(REFERENCE_FACIAL_POINTS)
    tmp_crop_size = np.array(DEFAULT_CROP_SIZE)

    # 0) make the inner region a square
    if default_square:
        size_diff = max(tmp_crop_size) - tmp_crop_size
        tmp_5pts += size_diff / 2
        tmp_crop_size += size_diff

    if (output_size and
            output_size[0] == tmp_crop_size[0] and
            output_size[1] == tmp_crop_size[1]):
        print('output_size == DEFAULT_CROP_SIZE {}: return default reference points'.format(tmp_crop_size))
        return tmp_5pts

    if (inner_padding_factor == 0 and
            outer_padding == (0, 0)):
        if output_size is None:
            print('No paddings to do: return default reference points')
            return tmp_5pts
        else:
            raise FaceWarpException(
                'No paddings to do, output_size must be None or {}'.format(tmp_crop_size))

    # check output size
    if not (0 <= inner_padding_factor <= 1.0):
        raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)')

    if ((inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0)
            and output_size is None):
        output_size = (tmp_crop_size * (1 + inner_padding_factor * 2)).astype(np.int32)
        output_size += np.array(outer_padding)
        print(' deduced from paddings, output_size = ', output_size)

    if not (outer_padding[0] < output_size[0]
            and outer_padding[1] < output_size[1]):
        raise FaceWarpException('Not (outer_padding[0] < output_size[0] '
                                'and outer_padding[1] < output_size[1])')

    # 1) pad the inner region according to inner_padding_factor
    if inner_padding_factor > 0:
        size_diff = tmp_crop_size * inner_padding_factor * 2
        tmp_5pts += size_diff / 2
        tmp_crop_size += np.round(size_diff).astype(np.int32)

    # 2) resize the padded inner region
    size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2

    if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[1] * tmp_crop_size[0]:
        raise FaceWarpException('Must have (output_size - outer_padding) '
                                '= some_scale * (crop_size * (1.0 + inner_padding_factor))')

    scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0]
    tmp_5pts = tmp_5pts * scale_factor
    # size_diff = tmp_crop_size * (scale_factor - min(scale_factor))
    # tmp_5pts = tmp_5pts + size_diff / 2
    tmp_crop_size = size_bf_outer_pad

    # 3) add outer_padding to make output_size
    reference_5point = tmp_5pts + np.array(outer_padding)
    tmp_crop_size = output_size

    return reference_5point


def get_affine_transform_matrix(src_pts, dst_pts):
    tfm = np.float32([[1, 0, 0], [0, 1, 0]])
    n_pts = src_pts.shape[0]
    ones = np.ones((n_pts, 1), src_pts.dtype)
    src_pts_ = np.hstack([src_pts, ones])
    dst_pts_ = np.hstack([dst_pts, ones])

    A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_, rcond=None)

    if rank == 3:
        tfm = np.float32([
            [A[0, 0], A[1, 0], A[2, 0]],
            [A[0, 1], A[1, 1], A[2, 1]]
        ])
    elif rank == 2:
        tfm = np.float32([
            [A[0, 0], A[1, 0], 0],
            [A[0, 1], A[1, 1], 0]
        ])

    return tfm


def warp_and_crop_face(src_img,
                       facial_pts,
                       reference_pts=None,
                       crop_size=(96, 112),
                       align_type='similarity'):  # similarity | cv2_affine | affine
    if reference_pts is None:
        if crop_size[0] == 96 and crop_size[1] == 112:
            reference_pts = REFERENCE_FACIAL_POINTS
        else:
            default_square = False
            inner_padding_factor = 0
            outer_padding = (0, 0)
            output_size = crop_size

            reference_pts = get_reference_facial_points(output_size,
                                                        inner_padding_factor,
                                                        outer_padding,
                                                        default_square)

    ref_pts = np.float32(reference_pts)
    ref_pts_shp = ref_pts.shape
    if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2:
        raise FaceWarpException(
            'reference_pts.shape must be (K,2) or (2,K) and K>2')

    if ref_pts_shp[0] == 2:
        ref_pts = ref_pts.T

    src_pts = np.float32(facial_pts)
    src_pts_shp = src_pts.shape
    if max(src_pts_shp) < 3 or min(src_pts_shp) != 2:
        raise FaceWarpException(
            'facial_pts.shape must be (K,2) or (2,K) and K>2')

    if src_pts_shp[0] == 2:
        src_pts = src_pts.T

    if src_pts.shape != ref_pts.shape:
        raise FaceWarpException(
            'facial_pts and reference_pts must have the same shape')

    if align_type == 'cv2_affine':
        tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3])
        tfm_inv = cv2.getAffineTransform(ref_pts[0:3], src_pts[0:3])
    elif align_type == 'affine':
        tfm = get_affine_transform_matrix(src_pts, ref_pts)
        tfm_inv = get_affine_transform_matrix(ref_pts, src_pts)
    else:
        params, scale = _umeyama(src_pts, ref_pts)
        tfm = params[:2, :]

        params, _ = _umeyama(ref_pts, src_pts, False, scale=1.0 / scale)
        tfm_inv = params[:2, :]

    face_img = cv2.warpAffine(src_img, tfm, (crop_size[0], crop_size[1]), flags=3)

    return face_img, tfm_inv
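Typical use of warp_and_crop_face takes five detected landmarks. The sketch below is illustrative only: the file name and landmark coordinates are made up and would normally come from a face detector such as MTCNN:

import cv2
import numpy as np

img = cv2.imread('face.jpg')  # hypothetical input image
landmarks = np.array([[196., 226.], [316., 226.],   # eye centers (made up)
                      [256., 296.],                 # nose tip
                      [206., 360.], [306., 360.]])  # mouth corners
aligned, tfm_inv = warp_and_crop_face(img, landmarks, crop_size=(96, 112))
cv2.imwrite('face_aligned.png', aligned)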
@@ -1,631 +0,0 @@
# -*- coding: utf-8 -*-
import numpy as np
import cv2
import torch

from utils import utils_image as util

import random
from scipy import ndimage
import scipy
import scipy.stats as ss
from scipy.interpolate import interp2d
from scipy.linalg import orth


"""
# --------------------------------------------
# Super-Resolution
# --------------------------------------------
#
# Kai Zhang (cskaizhang@gmail.com)
# https://github.com/cszn
# From 2019/03--2021/08
# --------------------------------------------
"""


def modcrop_np(img, sf):
    '''
    Args:
        img: numpy image, WxH or WxHxC
        sf: scale factor

    Return:
        cropped image
    '''
    w, h = img.shape[:2]
    im = np.copy(img)
    return im[:w - w % sf, :h - h % sf, ...]


"""
# --------------------------------------------
# anisotropic Gaussian kernels
# --------------------------------------------
"""


def analytic_kernel(k):
    """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
    k_size = k.shape[0]
    # Calculate the big kernel's size
    big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
    # Loop over the small kernel to fill the big one
    for r in range(k_size):
        for c in range(k_size):
            big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
    # Crop the edges of the big kernel to ignore very small values and increase run time of SR
    crop = k_size // 2
    cropped_big_k = big_k[crop:-crop, crop:-crop]
    # Normalize to 1
    return cropped_big_k / cropped_big_k.sum()


def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
    """ generate an anisotropic Gaussian kernel
    Args:
        ksize : e.g., 15, kernel size
        theta : [0, pi], rotation angle range
        l1    : [0.1, 50], scaling of eigenvalues
        l2    : [0.1, l1], scaling of eigenvalues
        If l1 = l2, you get an isotropic Gaussian kernel.

    Returns:
        k : kernel
    """
    v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
    V = np.array([[v[0], v[1]], [v[1], -v[0]]])
    D = np.array([[l1, 0], [0, l2]])
    Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
    k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)

    return k


def gm_blur_kernel(mean, cov, size=15):
    center = size / 2.0 + 0.5
    k = np.zeros([size, size])
    for y in range(size):
        for x in range(size):
            cy = y - center + 1
            cx = x - center + 1
            k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)

    k = k / np.sum(k)
    return k


def shift_pixel(x, sf, upper_left=True):
    """shift pixel for super-resolution with different scale factors
    Args:
        x: WxHxC or WxH
        sf: scale factor
        upper_left: shift direction
    """
    h, w = x.shape[:2]
    shift = (sf - 1) * 0.5
    xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
    if upper_left:
        x1 = xv + shift
        y1 = yv + shift
    else:
        x1 = xv - shift
        y1 = yv - shift

    x1 = np.clip(x1, 0, w - 1)
    y1 = np.clip(y1, 0, h - 1)

    if x.ndim == 2:
        x = interp2d(xv, yv, x)(x1, y1)
    if x.ndim == 3:
        for i in range(x.shape[-1]):
            x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)

    return x


def blur(x, k):
    '''
    x: image, NxcxHxW
    k: kernel, Nx1xhxw
    '''
    n, c = x.shape[:2]
    p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
    x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
    k = k.repeat(1, c, 1, 1)
    k = k.view(-1, 1, k.shape[2], k.shape[3])
    x = x.view(1, -1, x.shape[2], x.shape[3])
    x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
    x = x.view(n, c, x.shape[2], x.shape[3])

    return x
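# Shape sketch for blur() (illustrative): one kernel per batch element is
# applied to every channel via a single grouped convolution, and the
# replicate padding keeps the spatial size unchanged.
def _blur_demo():
    x = torch.rand(2, 3, 64, 64)  # N x C x H x W
    k = anisotropic_Gaussian(ksize=15, theta=np.pi / 4, l1=6, l2=2)
    k = torch.from_numpy(k).float().view(1, 1, 15, 15).repeat(2, 1, 1, 1)  # N x 1 x h x w
    assert blur(x, k).shape == x.shape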
def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
    """
    # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
    # Kai Zhang
    # min_var = 0.175 * sf  # variance of the Gaussian kernel will be sampled between min_var and max_var
    # max_var = 2.5 * sf
    """
    # Set random eigenvalues (lambdas) and angle (theta) for the COV matrix
    lambda_1 = min_var + np.random.rand() * (max_var - min_var)
    lambda_2 = min_var + np.random.rand() * (max_var - min_var)
    theta = np.random.rand() * np.pi  # random theta
    noise = -noise_level + np.random.rand(*k_size) * noise_level * 2

    # Set COV matrix using lambdas and theta
    LAMBDA = np.diag([lambda_1, lambda_2])
    Q = np.array([[np.cos(theta), -np.sin(theta)],
                  [np.sin(theta), np.cos(theta)]])
    SIGMA = Q @ LAMBDA @ Q.T
    INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]

    # Set expectation position (shifting kernel for aligned image)
    MU = k_size // 2 - 0.5 * (scale_factor - 1)  # - 0.5 * (scale_factor - k_size % 2)
    MU = MU[None, None, :, None]

    # Create meshgrid for Gaussian
    [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
    Z = np.stack([X, Y], 2)[:, :, :, None]

    # Calculate Gaussian for every pixel of the kernel
    ZZ = Z - MU
    ZZ_t = ZZ.transpose(0, 1, 3, 2)
    raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)

    # shift the kernel so it will be centered
    # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)

    # Normalize the kernel and return
    # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
    kernel = raw_kernel / np.sum(raw_kernel)
    return kernel


def fspecial_gaussian(hsize, sigma):
    hsize = [hsize, hsize]
    siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
    std = sigma
    [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
    arg = -(x * x + y * y) / (2 * std * std)
    h = np.exp(arg)
    h[h < np.finfo(float).eps * h.max()] = 0  # np.finfo: the scipy.finfo alias was removed in recent SciPy
    sumh = h.sum()
    if sumh != 0:
        h = h / sumh
    return h


def fspecial_laplacian(alpha):
    alpha = max([0, min([alpha, 1])])
    h1 = alpha / (alpha + 1)
    h2 = (1 - alpha) / (alpha + 1)
    h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
    h = np.array(h)
    return h


def fspecial(filter_type, *args, **kwargs):
    '''
    python code from:
    https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
    '''
    if filter_type == 'gaussian':
        return fspecial_gaussian(*args, **kwargs)
    if filter_type == 'laplacian':
        return fspecial_laplacian(*args, **kwargs)


"""
# --------------------------------------------
# degradation models
# --------------------------------------------
"""


def bicubic_degradation(x, sf=3):
    '''
    Args:
        x: HxWxC image, [0, 1]
        sf: down-scale factor

    Return:
        bicubicly downsampled LR image
    '''
    x = util.imresize_np(x, scale=1 / sf)
    return x


def srmd_degradation(x, k, sf=3):
    ''' blur + bicubic downsampling

    Args:
        x: HxWxC image, [0, 1]
        k: hxw, double
        sf: down-scale factor

    Return:
        downsampled LR image

    Reference:
        @inproceedings{zhang2018learning,
          title={Learning a single convolutional super-resolution network for multiple degradations},
          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
          pages={3262--3271},
          year={2018}
        }
    '''
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')  # 'nearest' | 'mirror'; ndimage.filters is deprecated
    x = bicubic_degradation(x, sf=sf)
    return x


def dpsr_degradation(x, k, sf=3):
    ''' bicubic downsampling + blur

    Args:
        x: HxWxC image, [0, 1]
        k: hxw, double
        sf: down-scale factor

    Return:
        downsampled LR image

    Reference:
        @inproceedings{zhang2019deep,
          title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
          pages={1671--1681},
          year={2019}
        }
    '''
    x = bicubic_degradation(x, sf=sf)
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
    return x


def classical_degradation(x, k, sf=3):
    ''' blur + downsampling

    Args:
        x: HxWxC image, [0, 1]/[0, 255]
        k: hxw, double
        sf: down-scale factor

    Return:
        downsampled LR image
    '''
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
    # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
    st = 0
    return x[st::sf, st::sf, ...]


def add_sharpening(img, weight=0.5, radius=50, threshold=10):
    """USM sharpening, borrowed from Real-ESRGAN.
    Input image: I; blurry image: B.
    1. K = I + weight * (I - B)
    2. Mask = 1 if abs(I - B) > threshold, else: 0
    3. Blur the mask.
    4. Out = Mask * K + (1 - Mask) * I
    Args:
        img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
        weight (float): Sharpening weight. Default: 0.5.
        radius (float): Kernel size of the Gaussian blur. Default: 50.
        threshold (int): Mask threshold on the residual, in [0, 255] scale. Default: 10.
    """
    if radius % 2 == 0:
        radius += 1
    blur = cv2.GaussianBlur(img, (radius, radius), 0)
    residual = img - blur
    mask = np.abs(residual) * 255 > threshold
    mask = mask.astype('float32')
    soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)

    K = img + weight * residual
    K = np.clip(K, 0, 1)
    return soft_mask * K + (1 - soft_mask) * img


def add_blur(img, sf=4):
    wd2 = 4.0 + sf
    wd = 2.0 + 0.2 * sf
    if random.random() < 0.5:
        l1 = wd2 * random.random()
        l2 = wd2 * random.random()
        k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
    else:
        k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random())
    img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror')

    return img


def add_resize(img, sf=4):
    rnum = np.random.rand()
    if rnum > 0.8:    # up
        sf1 = random.uniform(1, 2)
    elif rnum < 0.7:  # down
        sf1 = random.uniform(0.5 / sf, 1)
    else:
        sf1 = 1.0
    img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
    img = np.clip(img, 0.0, 1.0)

    return img


def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
    noise_level = random.randint(noise_level1, noise_level2)
    rnum = np.random.rand()
    if rnum > 0.6:    # add color Gaussian noise
        img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
    elif rnum < 0.4:  # add grayscale Gaussian noise
        img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
    else:             # add channel-correlated noise
        L = noise_level2 / 255.
        D = np.diag(np.random.rand(3))
        U = orth(np.random.rand(3, 3))
        conv = np.dot(np.dot(np.transpose(U), D), U)
        img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
    img = np.clip(img, 0.0, 1.0)
    return img
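# Note on the channel-correlated branch above (it reappears in
# add_speckle_noise below): conv = U^T D U with orthonormal U and diagonal
# D >= 0 is symmetric positive semi-definite by construction, hence a valid
# covariance for np.random.multivariate_normal. A quick illustrative check:
def _cov_demo():
    D = np.diag(np.random.rand(3))
    U = orth(np.random.rand(3, 3))
    conv = np.transpose(U) @ D @ U
    assert np.allclose(conv, conv.T)                   # symmetric
    assert np.all(np.linalg.eigvalsh(conv) >= -1e-12)  # PSD up to round-off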
def add_speckle_noise(img, noise_level1=2, noise_level2=25):
    noise_level = random.randint(noise_level1, noise_level2)
    img = np.clip(img, 0.0, 1.0)
    rnum = random.random()
    if rnum > 0.6:
        img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
    elif rnum < 0.4:
        img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
    else:
        L = noise_level2 / 255.
        D = np.diag(np.random.rand(3))
        U = orth(np.random.rand(3, 3))
        conv = np.dot(np.dot(np.transpose(U), D), U)
        img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
    img = np.clip(img, 0.0, 1.0)
    return img


def add_Poisson_noise(img):
    img = np.clip((img * 255.0).round(), 0, 255) / 255.
    vals = 10 ** (2 * random.random() + 2.0)  # exponent uniform in [2, 4]
    if random.random() < 0.5:
        img = np.random.poisson(img * vals).astype(np.float32) / vals
    else:
        img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
        img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
        noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
        img += noise_gray[:, :, np.newaxis]
    img = np.clip(img, 0.0, 1.0)
    return img


def add_JPEG_noise(img):
    quality_factor = random.randint(30, 95)
    img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
    result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
    img = cv2.imdecode(encimg, 1)
    img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
    return img


def random_crop(lq, hq, sf=4, lq_patchsize=64):
    h, w = lq.shape[:2]
    rnd_h = random.randint(0, h - lq_patchsize)
    rnd_w = random.randint(0, w - lq_patchsize)
    lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]

    rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
    hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
    return lq, hq


def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
    """
    This is the degradation model of BSRGAN from the paper
    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
    ----------
    img: HxWxC, [0, 1], its size should be larger than (lq_patchsize x sf) x (lq_patchsize x sf)
    sf: scale factor
    isp_model: camera ISP model

    Returns
    -------
    img: low-quality patch, size: lq_patchsize x lq_patchsize x C, range: [0, 1]
    hq: corresponding high-quality patch, size: (lq_patchsize x sf) x (lq_patchsize x sf) x C, range: [0, 1]
    """
    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
    sf_ori = sf

    h1, w1 = img.shape[:2]
    img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]  # mod crop (rows by h1, columns by w1)
    h, w = img.shape[:2]

    if h < lq_patchsize * sf or w < lq_patchsize * sf:
        raise ValueError(f'img size ({h1}X{w1}) is too small!')

    hq = img.copy()

    if sf == 4 and random.random() < scale2_prob:  # downsample1
        if np.random.rand() < 0.5:
            img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
        else:
            img = util.imresize_np(img, 1 / 2, True)
        img = np.clip(img, 0.0, 1.0)
        sf = 2

    shuffle_order = random.sample(range(7), 7)
    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
    if idx1 > idx2:  # keep downsample3 last
        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]

    for i in shuffle_order:

        if i == 0:
            img = add_blur(img, sf=sf)

        elif i == 1:
            img = add_blur(img, sf=sf)

        elif i == 2:
            a, b = img.shape[1], img.shape[0]
            # downsample2
            if random.random() < 0.75:
                sf1 = random.uniform(1, 2 * sf)
                img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
            else:
                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
                k_shifted = shift_pixel(k, sf)
                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
                img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
                img = img[0::sf, 0::sf, ...]  # nearest downsampling
            img = np.clip(img, 0.0, 1.0)

        elif i == 3:
            # downsample3
            img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
            img = np.clip(img, 0.0, 1.0)

        elif i == 4:
            # add Gaussian noise
            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)

        elif i == 5:
            # add JPEG noise
            if random.random() < jpeg_prob:
                img = add_JPEG_noise(img)

        elif i == 6:
            # add processed camera sensor noise
            if random.random() < isp_prob and isp_model is not None:
                with torch.no_grad():
                    img, hq = isp_model.forward(img.copy(), hq)

    # add final JPEG compression noise
    img = add_JPEG_noise(img)

    # random crop
    img, hq = random_crop(img, hq, sf_ori, lq_patchsize)

    return img, hq


def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=False, lq_patchsize=64, isp_model=None):
    """
    This is an extended degradation model combining
    the degradation models of BSRGAN and Real-ESRGAN
    ----------
    img: HxWxC, [0, 1], its size should be larger than (lq_patchsize x sf) x (lq_patchsize x sf)
    sf: scale factor
    shuffle_prob: probability of shuffling the degradation order
    use_sharp: sharpen the img first
    isp_model: camera ISP model

    Returns
    -------
    img: low-quality patch, size: lq_patchsize x lq_patchsize x C, range: [0, 1]
    hq: corresponding high-quality patch, size: (lq_patchsize x sf) x (lq_patchsize x sf) x C, range: [0, 1]
    """
    h1, w1 = img.shape[:2]
    img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]  # mod crop
    h, w = img.shape[:2]

    if h < lq_patchsize * sf or w < lq_patchsize * sf:
        raise ValueError(f'img size ({h1}X{w1}) is too small!')

    if use_sharp:
        img = add_sharpening(img)
    hq = img.copy()

    if random.random() < shuffle_prob:
        shuffle_order = random.sample(range(13), 13)
    else:
        shuffle_order = list(range(13))
        # local shuffle for noise; JPEG is always the last one
        shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6)))
        shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13)))

    poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1

    for i in shuffle_order:
        if i == 0:
            img = add_blur(img, sf=sf)
        elif i == 1:
            img = add_resize(img, sf=sf)
        elif i == 2:
            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
        elif i == 3:
            if random.random() < poisson_prob:
                img = add_Poisson_noise(img)
        elif i == 4:
            if random.random() < speckle_prob:
                img = add_speckle_noise(img)
        elif i == 5:
            if random.random() < isp_prob and isp_model is not None:
                with torch.no_grad():
                    img, hq = isp_model.forward(img.copy(), hq)
        elif i == 6:
            img = add_JPEG_noise(img)
        elif i == 7:
            img = add_blur(img, sf=sf)
        elif i == 8:
            img = add_resize(img, sf=sf)
        elif i == 9:
            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
        elif i == 10:
            if random.random() < poisson_prob:
                img = add_Poisson_noise(img)
        elif i == 11:
            if random.random() < speckle_prob:
                img = add_speckle_noise(img)
        elif i == 12:
            if random.random() < isp_prob and isp_model is not None:
                with torch.no_grad():
                    img, hq = isp_model.forward(img.copy(), hq)
        else:
            print('check the shuffle!')

    # resize to desired size
    img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])), interpolation=random.choice([1, 2, 3]))

    # add final JPEG compression noise
    img = add_JPEG_noise(img)

    # random crop
    img, hq = random_crop(img, hq, sf, lq_patchsize)

    return img, hq


if __name__ == '__main__':
    img = util.imread_uint('utils/test.png', 3)
    img = util.uint2single(img)
    sf = 4

    for i in range(20):
        img_lq, img_hq = degradation_bsrgan(img, sf=sf, lq_patchsize=72)
        print(i)
        lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), interpolation=0)
        img_concat = np.concatenate([lq_nearest, util.single2uint(img_hq)], axis=1)
        util.imsave(img_concat, str(i) + '.png')

    # for i in range(10):
    #     img_lq, img_hq = degradation_bsrgan_plus(img, sf=sf, shuffle_prob=0.1, use_sharp=True, lq_patchsize=64)
    #     print(i)
    #     lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), interpolation=0)
    #     img_concat = np.concatenate([lq_nearest, util.single2uint(img_hq)], axis=1)
    #     util.imsave(img_concat, str(i) + '.png')

    # run utils/utils_blindsr.py
@@ -1,93 +0,0 @@
import math
import requests
from tqdm import tqdm


'''
borrowed from
https://github.com/xinntao/BasicSR/blob/28883e15eedc3381d23235ff3cf7c454c4be87e6/basicsr/utils/download_util.py
'''


def sizeof_fmt(size, suffix='B'):
    """Get a human-readable file size.
    Args:
        size (int): File size.
        suffix (str): Suffix. Default: 'B'.
    Return:
        str: Formatted file size.
    """
    for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
        if abs(size) < 1024.0:
            return f'{size:3.1f} {unit}{suffix}'
        size /= 1024.0
    return f'{size:3.1f} Y{suffix}'
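# Illustrative check: sizeof_fmt repeatedly divides by 1024 until the value
# drops below 1024, then formats it with the matching unit prefix.
def _sizeof_fmt_demo():
    assert sizeof_fmt(1536) == '1.5 KB'
    assert sizeof_fmt(3 * 1024 ** 3) == '3.0 GB'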
def download_file_from_google_drive(file_id, save_path):
    """Download files from Google Drive.
    Ref:
    https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive  # noqa: E501
    Args:
        file_id (str): File id.
        save_path (str): Save path.
    """
    session = requests.Session()
    URL = 'https://docs.google.com/uc?export=download'
    params = {'id': file_id}

    response = session.get(URL, params=params, stream=True)
    token = get_confirm_token(response)
    if token:
        params['confirm'] = token
        response = session.get(URL, params=params, stream=True)

    # get file size
    response_file_size = session.get(
        URL, params=params, stream=True, headers={'Range': 'bytes=0-2'})
    if 'Content-Range' in response_file_size.headers:
        file_size = int(
            response_file_size.headers['Content-Range'].split('/')[1])
    else:
        file_size = None

    save_response_content(response, save_path, file_size)


def get_confirm_token(response):
    for key, value in response.cookies.items():
        if key.startswith('download_warning'):
            return value
    return None


def save_response_content(response,
                          destination,
                          file_size=None,
                          chunk_size=32768):
    if file_size is not None:
        pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')

        readable_file_size = sizeof_fmt(file_size)
    else:
        pbar = None

    with open(destination, 'wb') as f:
        downloaded_size = 0
        for chunk in response.iter_content(chunk_size):
            downloaded_size += chunk_size
            if pbar is not None:
                pbar.update(1)
                pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} '
                                     f'/ {readable_file_size}')
            if chunk:  # filter out keep-alive new chunks
                f.write(chunk)
        if pbar is not None:
            pbar.close()


if __name__ == "__main__":
    file_id = '1WNULM1e8gRNvsngVscsQ8tpaOqJ4mYtv'
    save_path = 'BSRGAN.pth'
    download_file_from_google_drive(file_id, save_path)
@@ -1,848 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
from utils import utils_image as util
|
||||
import random
|
||||
|
||||
import scipy
|
||||
import scipy.stats as ss
|
||||
import scipy.io as io
|
||||
from scipy import ndimage
|
||||
from scipy.interpolate import interp2d
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
"""
|
||||
# --------------------------------------------
|
||||
# Super-Resolution
|
||||
# --------------------------------------------
|
||||
#
|
||||
# Kai Zhang (cskaizhang@gmail.com)
|
||||
# https://github.com/cszn
|
||||
# modified by Kai Zhang (github: https://github.com/cszn)
|
||||
# 03/03/2020
|
||||
# --------------------------------------------
|
||||
"""
|
||||
|
||||
|
||||
"""
|
||||
# --------------------------------------------
|
||||
# anisotropic Gaussian kernels
|
||||
# --------------------------------------------
|
||||
"""
|
||||
|
||||
|
||||
def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
|
||||
""" generate an anisotropic Gaussian kernel
|
||||
Args:
|
||||
ksize : e.g., 15, kernel size
|
||||
theta : [0, pi], rotation angle range
|
||||
l1 : [0.1,50], scaling of eigenvalues
|
||||
l2 : [0.1,l1], scaling of eigenvalues
|
||||
If l1 = l2, will get an isotropic Gaussian kernel.
|
||||
Returns:
|
||||
k : kernel
|
||||
"""
|
||||
|
||||
v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
|
||||
V = np.array([[v[0], v[1]], [v[1], -v[0]]])
|
||||
D = np.array([[l1, 0], [0, l2]])
|
||||
Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
|
||||
k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
|
||||
|
||||
return k
|
||||
|
||||
|
||||
def gm_blur_kernel(mean, cov, size=15):
|
||||
center = size / 2.0 + 0.5
|
||||
k = np.zeros([size, size])
|
||||
for y in range(size):
|
||||
for x in range(size):
|
||||
cy = y - center + 1
|
||||
cx = x - center + 1
|
||||
k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
|
||||
|
||||
k = k / np.sum(k)
|
||||
return k
|
||||
|
||||
|
||||
"""
|
||||
# --------------------------------------------
|
||||
# calculate PCA projection matrix
|
||||
# --------------------------------------------
|
||||
"""
|
||||
|
||||
|
||||
def get_pca_matrix(x, dim_pca=15):
|
||||
"""
|
||||
Args:
|
||||
x: 225x10000 matrix
|
||||
dim_pca: 15
|
||||
Returns:
|
||||
pca_matrix: 15x225
|
||||
"""
|
||||
C = np.dot(x, x.T)
|
||||
w, v = scipy.linalg.eigh(C)
|
||||
pca_matrix = v[:, -dim_pca:].T
|
||||
|
||||
return pca_matrix
|
||||
|
||||
|
||||
def show_pca(x):
|
||||
"""
|
||||
x: PCA projection matrix, e.g., 15x225
|
||||
"""
|
||||
for i in range(x.shape[0]):
|
||||
xc = np.reshape(x[i, :], (int(np.sqrt(x.shape[1])), -1), order="F")
|
||||
util.surf(xc)
|
||||
|
||||
|
||||
def cal_pca_matrix(path='PCA_matrix.mat', ksize=15, l_max=12.0, dim_pca=15, num_samples=500):
|
||||
kernels = np.zeros([ksize*ksize, num_samples], dtype=np.float32)
|
||||
for i in range(num_samples):
|
||||
|
||||
theta = np.pi*np.random.rand(1)
|
||||
l1 = 0.1+l_max*np.random.rand(1)
|
||||
l2 = 0.1+(l1-0.1)*np.random.rand(1)
|
||||
|
||||
k = anisotropic_Gaussian(ksize=ksize, theta=theta[0], l1=l1[0], l2=l2[0])
|
||||
|
||||
# util.imshow(k)
|
||||
|
||||
kernels[:, i] = np.reshape(k, (-1), order="F") # k.flatten(order='F')
|
||||
|
||||
# io.savemat('k.mat', {'k': kernels})
|
||||
|
||||
pca_matrix = get_pca_matrix(kernels, dim_pca=dim_pca)
|
||||
|
||||
io.savemat(path, {'p': pca_matrix})
|
||||
|
||||
return pca_matrix
|
||||
|
||||
|
||||
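
def _demo_pca_projection():
    # Illustrative sketch (not part of the original file): project a sampled
    # kernel onto the PCA basis, as done for SRMD-style degradation maps.
    # num_samples is kept small so the demo runs quickly, and the saved .mat
    # path is a throwaway example name.
    pca_matrix = cal_pca_matrix(path='PCA_example.mat', ksize=15, l_max=12.0, dim_pca=15, num_samples=200)
    k = anisotropic_Gaussian(ksize=15, theta=0.3*np.pi, l1=6, l2=2)
    code = np.dot(pca_matrix, k.flatten(order='F'))  # 15-dimensional kernel code
    assert code.shape == (15,)
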
"""
|
||||
# --------------------------------------------
|
||||
# shifted anisotropic Gaussian kernels
|
||||
# --------------------------------------------
|
||||
"""
|
||||
|
||||
|
||||
def shifted_anisotropic_Gaussian(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
|
||||
""""
|
||||
# modified version of https://github.com/assafshocher/BlindSR_dataset_generator
|
||||
# Kai Zhang
|
||||
# min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
|
||||
# max_var = 2.5 * sf
|
||||
"""
|
||||
# Set random eigen-vals (lambdas) and angle (theta) for COV matrix
|
||||
lambda_1 = min_var + np.random.rand() * (max_var - min_var)
|
||||
lambda_2 = min_var + np.random.rand() * (max_var - min_var)
|
||||
theta = np.random.rand() * np.pi # random theta
|
||||
noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
|
||||
|
||||
# Set COV matrix using Lambdas and Theta
|
||||
LAMBDA = np.diag([lambda_1, lambda_2])
|
||||
Q = np.array([[np.cos(theta), -np.sin(theta)],
|
||||
[np.sin(theta), np.cos(theta)]])
|
||||
SIGMA = Q @ LAMBDA @ Q.T
|
||||
INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
|
||||
|
||||
# Set expectation position (shifting kernel for aligned image)
|
||||
MU = k_size // 2 - 0.5*(scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
|
||||
MU = MU[None, None, :, None]
|
||||
|
||||
# Create meshgrid for Gaussian
|
||||
[X,Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
|
||||
Z = np.stack([X, Y], 2)[:, :, :, None]
|
||||
|
||||
# Calcualte Gaussian for every pixel of the kernel
|
||||
ZZ = Z-MU
|
||||
ZZ_t = ZZ.transpose(0,1,3,2)
|
||||
raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
|
||||
|
||||
# shift the kernel so it will be centered
|
||||
#raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
|
||||
|
||||
# Normalize the kernel and return
|
||||
#kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
|
||||
kernel = raw_kernel / np.sum(raw_kernel)
|
||||
return kernel
|
||||
|
||||
|
||||

def gen_kernel(k_size=np.array([25, 25]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=12., noise_level=0):
    """
    # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
    # Kai Zhang
    # min_var = 0.175 * sf  # variance of the gaussian kernel will be sampled between min_var and max_var
    # max_var = 2.5 * sf
    """
    sf = random.choice([1, 2, 3, 4])
    scale_factor = np.array([sf, sf])
    # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
    lambda_1 = min_var + np.random.rand() * (max_var - min_var)
    lambda_2 = min_var + np.random.rand() * (max_var - min_var)
    theta = np.random.rand() * np.pi  # random theta
    noise = 0  # -noise_level + np.random.rand(*k_size) * noise_level * 2

    # Set COV matrix using Lambdas and Theta
    LAMBDA = np.diag([lambda_1, lambda_2])
    Q = np.array([[np.cos(theta), -np.sin(theta)],
                  [np.sin(theta), np.cos(theta)]])
    SIGMA = Q @ LAMBDA @ Q.T
    INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]

    # Set expectation position (shifting kernel for aligned image)
    MU = k_size // 2 - 0.5*(scale_factor - 1)  # - 0.5 * (scale_factor - k_size % 2)
    MU = MU[None, None, :, None]

    # Create meshgrid for Gaussian
    [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
    Z = np.stack([X, Y], 2)[:, :, :, None]

    # Calculate Gaussian for every pixel of the kernel
    ZZ = Z - MU
    ZZ_t = ZZ.transpose(0, 1, 3, 2)
    raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)

    # shift the kernel so it will be centered
    # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)

    # Normalize the kernel and return
    # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
    kernel = raw_kernel / np.sum(raw_kernel)
    return kernel

"""
|
||||
# --------------------------------------------
|
||||
# degradation models
|
||||
# --------------------------------------------
|
||||
"""
|
||||
|
||||
|
||||
def bicubic_degradation(x, sf=3):
|
||||
'''
|
||||
Args:
|
||||
x: HxWxC image, [0, 1]
|
||||
sf: down-scale factor
|
||||
Return:
|
||||
bicubicly downsampled LR image
|
||||
'''
|
||||
x = util.imresize_np(x, scale=1/sf)
|
||||
return x
|
||||
|
||||
|
||||
def srmd_degradation(x, k, sf=3):
|
||||
''' blur + bicubic downsampling
|
||||
Args:
|
||||
x: HxWxC image, [0, 1]
|
||||
k: hxw, double
|
||||
sf: down-scale factor
|
||||
Return:
|
||||
downsampled LR image
|
||||
Reference:
|
||||
@inproceedings{zhang2018learning,
|
||||
title={Learning a single convolutional super-resolution network for multiple degradations},
|
||||
author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
|
||||
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
|
||||
pages={3262--3271},
|
||||
year={2018}
|
||||
}
|
||||
'''
|
||||
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
|
||||
x = bicubic_degradation(x, sf=sf)
|
||||
return x
|
||||
|
||||
|
||||
def dpsr_degradation(x, k, sf=3):
|
||||
|
||||
''' bicubic downsampling + blur
|
||||
Args:
|
||||
x: HxWxC image, [0, 1]
|
||||
k: hxw, double
|
||||
sf: down-scale factor
|
||||
Return:
|
||||
downsampled LR image
|
||||
Reference:
|
||||
@inproceedings{zhang2019deep,
|
||||
title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
|
||||
author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
|
||||
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
|
||||
pages={1671--1681},
|
||||
year={2019}
|
||||
}
|
||||
'''
|
||||
x = bicubic_degradation(x, sf=sf)
|
||||
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
|
||||
return x
|
||||
|
||||
|
||||
def classical_degradation(x, k, sf=3):
|
||||
''' blur + downsampling
|
||||
|
||||
Args:
|
||||
x: HxWxC image, [0, 1]/[0, 255]
|
||||
k: hxw, double
|
||||
sf: down-scale factor
|
||||
|
||||
Return:
|
||||
downsampled LR image
|
||||
'''
|
||||
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
|
||||
#x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
|
||||
st = 0
|
||||
return x[st::sf, st::sf, ...]
|
||||
|
||||
|
||||
def modcrop_np(img, sf):
|
||||
'''
|
||||
Args:
|
||||
img: numpy image, WxH or WxHxC
|
||||
sf: scale factor
|
||||
Return:
|
||||
cropped image
|
||||
'''
|
||||
w, h = img.shape[:2]
|
||||
im = np.copy(img)
|
||||
return im[:w - w % sf, :h - h % sf, ...]
|
||||
|
||||
|
||||
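
def _demo_degradation_order():
    # Illustrative sketch (not part of the original file): blur->downsample
    # (srmd_degradation) and downsample->blur (dpsr_degradation) generally
    # produce different LR images, which is why both models are kept.
    # Relies on util.imresize_np imported at the top of this file.
    x = np.random.rand(24, 24, 3)
    k = anisotropic_Gaussian(ksize=7, theta=0.5*np.pi, l1=2, l2=1)
    lr_srmd = srmd_degradation(x, k, sf=3)
    lr_dpsr = dpsr_degradation(x, k, sf=3)
    assert lr_srmd.shape == lr_dpsr.shape == (8, 8, 3)
    assert not np.allclose(lr_srmd, lr_dpsr)
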

'''
# =================
# Numpy
# =================
'''


def shift_pixel(x, sf, upper_left=True):
    """shift pixel for super-resolution with different scale factors
    Args:
        x: WxHxC or WxH, image or kernel
        sf: scale factor
        upper_left: shift direction
    """
    h, w = x.shape[:2]
    shift = (sf-1)*0.5
    xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
    if upper_left:
        x1 = xv + shift
        y1 = yv + shift
    else:
        x1 = xv - shift
        y1 = yv - shift

    x1 = np.clip(x1, 0, w-1)
    y1 = np.clip(y1, 0, h-1)

    if x.ndim == 2:
        x = interp2d(xv, yv, x)(x1, y1)
    if x.ndim == 3:
        for i in range(x.shape[-1]):
            x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)

    return x


'''
# =================
# pytorch
# =================
'''

def splits(a, sf):
    '''
    a: tensor NxCxWxHx2
    sf: scale factor
    out: tensor NxCx(W/sf)x(H/sf)x2x(sf^2)
    '''
    b = torch.stack(torch.chunk(a, sf, dim=2), dim=5)
    b = torch.cat(torch.chunk(b, sf, dim=3), dim=5)
    return b


def c2c(x):
    return torch.from_numpy(np.stack([np.float32(x.real), np.float32(x.imag)], axis=-1))


def r2c(x):
    return torch.stack([x, torch.zeros_like(x)], -1)


def cdiv(x, y):
    a, b = x[..., 0], x[..., 1]
    c, d = y[..., 0], y[..., 1]
    cd2 = c**2 + d**2
    return torch.stack([(a*c+b*d)/cd2, (b*c-a*d)/cd2], -1)


def csum(x, y):
    return torch.stack([x[..., 0] + y, x[..., 1]], -1)


def cabs(x):
    return torch.pow(x[..., 0]**2+x[..., 1]**2, 0.5)


def cmul(t1, t2):
    '''
    complex multiplication
    t1: NxCxHxWx2
    output: NxCxHxWx2
    '''
    real1, imag1 = t1[..., 0], t1[..., 1]
    real2, imag2 = t2[..., 0], t2[..., 1]
    return torch.stack([real1 * real2 - imag1 * imag2, real1 * imag2 + imag1 * real2], dim=-1)


def cconj(t, inplace=False):
    '''
    complex conjugation
    t: NxCxHxWx2
    output: NxCxHxWx2
    '''
    c = t.clone() if not inplace else t
    c[..., 1] *= -1
    return c


# note: torch.rfft/torch.irfft and torch.fft(t, 2)/torch.ifft(t, 2) are the
# pre-1.8 PyTorch APIs (removed in PyTorch 1.8); this file targets older versions.
def rfft(t):
    return torch.rfft(t, 2, onesided=False)


def irfft(t):
    return torch.irfft(t, 2, onesided=False)


def fft(t):
    return torch.fft(t, 2)


def ifft(t):
    return torch.ifft(t, 2)

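
def _demo_complex_helpers():
    # Illustrative sketch (not part of the original file): in the 2-channel
    # complex layout used above, x * conj(x) has the squared magnitude in the
    # real channel and zero in the imaginary channel.
    x = torch.randn(1, 1, 4, 4, 2)
    prod = cmul(x, cconj(x))
    assert torch.allclose(prod[..., 0], cabs(x)**2, atol=1e-6)
    assert torch.allclose(prod[..., 1], torch.zeros_like(prod[..., 1]), atol=1e-6)
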

def p2o(psf, shape):
    '''
    Convert a point-spread function to an optical transfer function (PyTorch).
    Args:
        psf: NxCxhxw
        shape: [H, W]
    Returns:
        otf: NxCxHxWx2
    '''
    otf = torch.zeros(psf.shape[:-2] + shape).type_as(psf)
    otf[..., :psf.shape[2], :psf.shape[3]].copy_(psf)
    for axis, axis_size in enumerate(psf.shape[2:]):
        otf = torch.roll(otf, -int(axis_size / 2), dims=axis+2)
    otf = torch.rfft(otf, 2, onesided=False)
    # zero out imaginary parts that are within FFT roundoff error
    n_ops = torch.sum(torch.tensor(psf.shape).type_as(psf) * torch.log2(torch.tensor(psf.shape).type_as(psf)))
    otf[..., 1][torch.abs(otf[..., 1]) < n_ops*2.22e-16] = torch.tensor(0).type_as(psf)
    return otf


'''
# =================
# PyTorch
# =================
'''

def INVLS_pytorch(FB, FBC, F2B, FR, tau, sf=2):
    '''
    FB: NxCxWxHx2
    F2B: NxCxWxHx2

    Matlab reference:
    x1 = FB.*FR;
    FBR = BlockMM(nr,nc,Nb,m,x1);
    invW = BlockMM(nr,nc,Nb,m,F2B);
    invWBR = FBR./(invW + tau*Nb);
    fun = @(block_struct) block_struct.data.*invWBR;
    FCBinvWBR = blockproc(FBC,[nr,nc],fun);
    FX = (FR-FCBinvWBR)/tau;
    Xest = real(ifft2(FX));
    '''
    x1 = cmul(FB, FR)
    FBR = torch.mean(splits(x1, sf), dim=-1, keepdim=False)
    invW = torch.mean(splits(F2B, sf), dim=-1, keepdim=False)
    invWBR = cdiv(FBR, csum(invW, tau))
    FCBinvWBR = cmul(FBC, invWBR.repeat(1, 1, sf, sf, 1))
    FX = (FR - FCBinvWBR)/tau
    Xest = torch.irfft(FX, 2, onesided=False)
    return Xest


def real2complex(x):
    return torch.stack([x, torch.zeros_like(x)], -1)


def modcrop(img, sf):
    '''
    img: tensor image, NxCxWxH or CxWxH or WxH
    sf: scale factor
    '''
    w, h = img.shape[-2:]
    im = img.clone()
    return im[..., :w - w % sf, :h - h % sf]


def upsample(x, sf=3, center=False):
    '''
    s-fold upsampler: zero-filling, i.e., inserting sf-1 zeros between samples
    x: tensor image, NxCxWxH
    '''
    st = (sf-1)//2 if center else 0
    z = torch.zeros((x.shape[0], x.shape[1], x.shape[2]*sf, x.shape[3]*sf)).type_as(x)
    z[..., st::sf, st::sf].copy_(x)
    return z


def downsample(x, sf=3, center=False):
    st = (sf-1)//2 if center else 0
    return x[..., st::sf, st::sf]


def circular_pad(x, pad):
    '''
    # x[N, 1, W, H] -> x[N, 1, W + 2 pad, H + 2 pad] (periodic padding)
    '''
    x = torch.cat([x, x[:, :, 0:pad, :]], dim=2)
    x = torch.cat([x, x[:, :, :, 0:pad]], dim=3)
    x = torch.cat([x[:, :, -2 * pad:-pad, :], x], dim=2)
    x = torch.cat([x[:, :, :, -2 * pad:-pad], x], dim=3)
    return x


def pad_circular(input, padding):
    # type: (Tensor, List[int]) -> Tensor
    r"""
    Arguments
    :param input: tensor of shape :math:`(N, C_{\text{in}}, H, [W, D]))`
    :param padding: (tuple): m-elem tuple where m is the degree of convolution
    Returns
    :return: tensor of shape :math:`(N, C_{\text{in}}, [D + 2 * padding[0],
                                     H + 2 * padding[1]], W + 2 * padding[2]))`
    """
    offset = 3
    for dimension in range(input.dim() - offset + 1):
        input = dim_pad_circular(input, padding[dimension], dimension + offset)
    return input


def dim_pad_circular(input, padding, dimension):
    # type: (Tensor, int, int) -> Tensor
    input = torch.cat([input, input[[slice(None)] * (dimension - 1) +
                       [slice(0, padding)]]], dim=dimension - 1)
    input = torch.cat([input[[slice(None)] * (dimension - 1) +
                       [slice(-2 * padding, -padding)]], input], dim=dimension - 1)
    return input

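
def _demo_up_down_roundtrip():
    # Illustrative sketch (not part of the original file): 'upsample' is the
    # s-fold zero-filling upsampler, so downsampling at the same stride and
    # offset recovers the input exactly.
    x = torch.randn(1, 3, 8, 8)
    for center in (False, True):
        z = upsample(x, sf=3, center=center)
        assert torch.equal(downsample(z, sf=3, center=center), x)
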

def imfilter(x, k):
    '''
    x: image, NxcxHxW
    k: kernel, cx1xhxw
    '''
    x = pad_circular(x, padding=((k.shape[-2]-1)//2, (k.shape[-1]-1)//2))
    x = torch.nn.functional.conv2d(x, k, groups=x.shape[1])
    return x


def G(x, k, sf=3, center=False):
    '''
    x: image, NxcxHxW
    k: kernel, cx1xhxw
    sf: scale factor
    center: sample the first pixel (False) or the middle one (True)

    Matlab function:
    tmp = imfilter(x,h,'circular');
    y = downsample2(tmp,K);
    '''
    x = downsample(imfilter(x, k), sf=sf, center=center)
    return x


def Gt(x, k, sf=3, center=False):
    '''
    x: image, NxcxHxW
    k: kernel, cx1xhxw
    sf: scale factor
    center: sample the first pixel (False) or the middle one (True)

    Matlab function:
    tmp = upsample2(x,K);
    y = imfilter(tmp,h,'circular');
    '''
    x = imfilter(upsample(x, sf=sf, center=center), k)
    return x


def interpolation_down(x, sf, center=False):
    mask = torch.zeros_like(x)
    if center:
        start = torch.tensor((sf-1)//2)
        mask[..., start::sf, start::sf] = torch.tensor(1).type_as(x)
        LR = x[..., start::sf, start::sf]
    else:
        mask[..., ::sf, ::sf] = torch.tensor(1).type_as(x)
        LR = x[..., ::sf, ::sf]
    y = x.mul(mask)

    return LR, y, mask

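
def _demo_G_Gt_adjoint():
    # Illustrative sketch (not part of the original file): for a symmetric
    # blur kernel, Gt acts as the adjoint of G under the circular boundary
    # used by 'imfilter', i.e. <G(x), y> == <x, Gt(y)> up to float32 precision.
    x = torch.randn(1, 1, 12, 12)
    y = torch.randn(1, 1, 4, 4)
    k = torch.from_numpy(anisotropic_Gaussian(ksize=5, theta=0., l1=2, l2=1)).float().view(1, 1, 5, 5)
    lhs = torch.sum(G(x, k, sf=3) * y)
    rhs = torch.sum(x * Gt(y, k, sf=3))
    assert torch.allclose(lhs, rhs, atol=1e-4)
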

'''
# =================
# Numpy
# =================
'''


def blockproc(im, blocksize, fun):
    xblocks = np.split(im, range(blocksize[0], im.shape[0], blocksize[0]), axis=0)
    xblocks_proc = []
    for xb in xblocks:
        yblocks = np.split(xb, range(blocksize[1], im.shape[1], blocksize[1]), axis=1)
        yblocks_proc = []
        for yb in yblocks:
            yb_proc = fun(yb)
            yblocks_proc.append(yb_proc)
        xblocks_proc.append(np.concatenate(yblocks_proc, axis=1))

    proc = np.concatenate(xblocks_proc, axis=0)

    return proc


def fun_reshape(a):
    return np.reshape(a, (-1, 1, a.shape[-1]), order='F')


def fun_mul(a, b):
    return a*b


def BlockMM(nr, nc, Nb, m, x1):
    '''
    Matlab reference:
    myfun = @(block_struct) reshape(block_struct.data,m,1);
    x1 = blockproc(x1,[nr nc],myfun);
    x1 = reshape(x1,m,Nb);
    x1 = sum(x1,2);
    x = reshape(x1,nr,nc);
    '''
    fun = fun_reshape
    x1 = blockproc(x1, blocksize=(nr, nc), fun=fun)
    x1 = np.reshape(x1, (m, Nb, x1.shape[-1]), order='F')
    x1 = np.sum(x1, 1)
    x = np.reshape(x1, (nr, nc, x1.shape[-1]), order='F')
    return x


def INVLS(FB, FBC, F2B, FR, tau, Nb, nr, nc, m):
    '''
    Matlab reference:
    x1 = FB.*FR;
    FBR = BlockMM(nr,nc,Nb,m,x1);
    invW = BlockMM(nr,nc,Nb,m,F2B);
    invWBR = FBR./(invW + tau*Nb);
    fun = @(block_struct) block_struct.data.*invWBR;
    FCBinvWBR = blockproc(FBC,[nr,nc],fun);
    FX = (FR-FCBinvWBR)/tau;
    Xest = real(ifft2(FX));
    '''
    x1 = FB*FR
    FBR = BlockMM(nr, nc, Nb, m, x1)
    invW = BlockMM(nr, nc, Nb, m, F2B)
    invWBR = FBR/(invW + tau*Nb)
    FCBinvWBR = blockproc(FBC, [nr, nc], lambda im: fun_mul(im, invWBR))
    FX = (FR-FCBinvWBR)/tau
    Xest = np.real(np.fft.ifft2(FX, axes=(0, 1)))
    return Xest


def psf2otf(psf, shape=None):
    """
    Convert a point-spread function to an optical transfer function.
    Computes the Fast Fourier Transform (FFT) of the point-spread function
    (PSF) array and creates the optical transfer function (OTF) array that is
    not influenced by the PSF off-centering.
    By default, the OTF array is the same size as the PSF array.
    To ensure that the OTF is not altered due to PSF off-centering, psf2otf
    post-pads the PSF array (down or to the right) with zeros to match the
    dimensions specified in shape, then circularly shifts the values of the
    PSF array up (or to the left) until the central pixel reaches the (1, 1)
    position.
    Parameters
    ----------
    psf : `numpy.ndarray`
        PSF array
    shape : tuple of int
        Output shape of the OTF array
    Returns
    -------
    otf : `numpy.ndarray`
        OTF array
    Notes
    -----
    Adapted from MATLAB's psf2otf function
    """
    if shape is None:
        shape = psf.shape
    shape = np.array(shape)
    if np.all(psf == 0):
        # return np.zeros_like(psf)
        return np.zeros(shape)
    if len(psf.shape) == 1:
        psf = psf.reshape((1, psf.shape[0]))
    inshape = psf.shape
    psf = zero_pad(psf, shape, position='corner')
    for axis, axis_size in enumerate(inshape):
        psf = np.roll(psf, -int(axis_size / 2), axis=axis)
    # Compute the OTF
    otf = np.fft.fft2(psf, axes=(0, 1))
    # Estimate the rough number of operations involved in the FFT
    # and discard the PSF imaginary part if within roundoff error
    # roundoff error = machine epsilon = sys.float_info.epsilon
    # or np.finfo().eps
    n_ops = np.sum(psf.size * np.log2(psf.shape))
    otf = np.real_if_close(otf, tol=n_ops)
    return otf

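
def _demo_psf2otf():
    # Illustrative sketch (not part of the original file): multiplying an image
    # spectrum by the OTF is equivalent to circular ('wrap') convolution with
    # the centered PSF, which is exactly the property psf2otf preserves.
    img = np.random.rand(32, 32)
    psf = anisotropic_Gaussian(ksize=15, theta=0.3*np.pi, l1=4, l2=2)
    ref = ndimage.convolve(img, psf, mode='wrap')
    via_otf = np.real(np.fft.ifft2(np.fft.fft2(img) * psf2otf(psf, img.shape)))
    assert np.allclose(ref, via_otf, atol=1e-8)
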

def zero_pad(image, shape, position='corner'):
    """
    Extends an image to a certain size with zeros.
    Parameters
    ----------
    image: real 2d `numpy.ndarray`
        Input image
    shape: tuple of int
        Desired output shape of the image
    position : str, optional
        The position of the input image in the output one:
            * 'corner'
                top-left corner (default)
            * 'center'
                centered
    Returns
    -------
    padded_img: real `numpy.ndarray`
        The zero-padded image
    """
    shape = np.asarray(shape, dtype=int)
    imshape = np.asarray(image.shape, dtype=int)
    if np.alltrue(imshape == shape):
        return image
    if np.any(shape <= 0):
        raise ValueError("ZERO_PAD: null or negative shape given")
    dshape = shape - imshape
    if np.any(dshape < 0):
        raise ValueError("ZERO_PAD: target size smaller than source one")
    pad_img = np.zeros(shape, dtype=image.dtype)
    idx, idy = np.indices(imshape)
    if position == 'center':
        if np.any(dshape % 2 != 0):
            raise ValueError("ZERO_PAD: source and target shapes "
                             "have different parity.")
        offx, offy = dshape // 2
    else:
        offx, offy = (0, 0)
    pad_img[idx + offx, idy + offy] = image
    return pad_img


def upsample_np(x, sf=3, center=False):
    st = (sf-1)//2 if center else 0
    z = np.zeros((x.shape[0]*sf, x.shape[1]*sf, x.shape[2]))
    z[st::sf, st::sf, ...] = x
    return z


def downsample_np(x, sf=3, center=False):
    st = (sf-1)//2 if center else 0
    return x[st::sf, st::sf, ...]


def imfilter_np(x, k):
    '''
    x: image, HxWxC
    k: kernel, hxw
    '''
    x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
    return x


def G_np(x, k, sf=3, center=False):
    '''
    x: image, HxWxC
    k: kernel, hxw

    Matlab function:
    tmp = imfilter(x,h,'circular');
    y = downsample2(tmp,K);
    '''
    x = downsample_np(imfilter_np(x, k), sf=sf, center=center)
    return x


def Gt_np(x, k, sf=3, center=False):
    '''
    x: image, HxWxC
    k: kernel, hxw

    Matlab function:
    tmp = upsample2(x,K);
    y = imfilter(tmp,h,'circular');
    '''
    x = imfilter_np(upsample_np(x, sf=sf, center=center), k)
    return x


if __name__ == '__main__':
    img = util.imread_uint('test.bmp', 3)

    img = util.uint2single(img)
    k = anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6)
    util.imshow(k*10)

    for sf in [2, 3, 4]:

        # modcrop
        img = modcrop_np(img, sf=sf)

        # 1) bicubic degradation
        img_b = bicubic_degradation(img, sf=sf)
        print(img_b.shape)

        # 2) srmd degradation
        img_s = srmd_degradation(img, k, sf=sf)
        print(img_s.shape)

        # 3) dpsr degradation
        img_d = dpsr_degradation(img, k, sf=sf)
        print(img_d.shape)

        # 4) classical degradation
        img_c = classical_degradation(img, k, sf=sf)  # renamed from img_d to avoid shadowing the dpsr result
        print(img_c.shape)

    k = anisotropic_Gaussian(ksize=7, theta=0.25*np.pi, l1=0.01, l2=0.01)
    # print(k)
    # util.imshow(k*10)

    k = shifted_anisotropic_Gaussian(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.8, max_var=10.8, noise_level=0.0)
    # util.imshow(k*10)

    # PCA
    # pca_matrix = cal_pca_matrix(ksize=15, l_max=10.0, dim_pca=15, num_samples=12500)
    # print(pca_matrix.shape)
    # show_pca(pca_matrix)

    # run utils/utils_sisr.py
@@ -1,493 +0,0 @@
import math
import os
import cv2
import numpy as np
import torch
import random
from os import path as osp
from torch.nn import functional as F
from torchvision.utils import make_grid  # used by tensor2img below
from abc import ABCMeta, abstractmethod

def scandir(dir_path, suffix=None, recursive=False, full_path=False):
    """Scan a directory to find the interested files.

    Args:
        dir_path (str): Path of the directory.
        suffix (str | tuple(str), optional): File suffix that we are
            interested in. Default: None.
        recursive (bool, optional): If set to True, recursively scan the
            directory. Default: False.
        full_path (bool, optional): If set to True, include the dir_path.
            Default: False.

    Returns:
        A generator for all the interested files with relative paths.
    """

    if (suffix is not None) and not isinstance(suffix, (str, tuple)):
        raise TypeError('"suffix" must be a string or tuple of strings')

    root = dir_path

    def _scandir(dir_path, suffix, recursive):
        for entry in os.scandir(dir_path):
            if not entry.name.startswith('.') and entry.is_file():
                if full_path:
                    return_path = entry.path
                else:
                    return_path = osp.relpath(entry.path, root)

                if suffix is None:
                    yield return_path
                elif return_path.endswith(suffix):
                    yield return_path
            else:
                if recursive:
                    yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
                else:
                    continue

    return _scandir(dir_path, suffix=suffix, recursive=recursive)


def read_img_seq(path, require_mod_crop=False, scale=1, return_imgname=False):
    """Read a sequence of images from a given folder path.

    Args:
        path (list[str] | str): List of image paths or image folder path.
        require_mod_crop (bool): Require mod crop for each image.
            Default: False.
        scale (int): Scale factor for mod_crop. Default: 1.
        return_imgname (bool): Whether to return image names. Default: False.

    Returns:
        Tensor: size (t, c, h, w), RGB, [0, 1].
        list[str]: Returned image name list.
    """
    if isinstance(path, list):
        img_paths = path
    else:
        img_paths = sorted(list(scandir(path, full_path=True)))
    imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths]

    if require_mod_crop:
        imgs = [mod_crop(img, scale) for img in imgs]
    imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
    imgs = torch.stack(imgs, dim=0)

    if return_imgname:
        imgnames = [osp.splitext(osp.basename(path))[0] for path in img_paths]
        return imgs, imgnames
    else:
        return imgs


def img2tensor(imgs, bgr2rgb=True, float32=True):
    """Numpy array to tensor.

    Args:
        imgs (list[ndarray] | ndarray): Input images.
        bgr2rgb (bool): Whether to change bgr to rgb.
        float32 (bool): Whether to change to float32.

    Returns:
        list[tensor] | tensor: Tensor images. If returned results only have
            one element, just return tensor.
    """

    def _totensor(img, bgr2rgb, float32):
        if img.shape[2] == 3 and bgr2rgb:
            if img.dtype == 'float64':
                img = img.astype('float32')
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = torch.from_numpy(img.transpose(2, 0, 1))
        if float32:
            img = img.float()
        return img

    if isinstance(imgs, list):
        return [_totensor(img, bgr2rgb, float32) for img in imgs]
    else:
        return _totensor(imgs, bgr2rgb, float32)


def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
    """Convert torch Tensors into image numpy arrays.

    After clamping to [min, max], values will be normalized to [0, 1].

    Args:
        tensor (Tensor or list[Tensor]): Accepted shapes:
            1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
            2) 3D Tensor of shape (3/1 x H x W);
            3) 2D Tensor of shape (H x W).
            Tensor channels should be in RGB order.
        rgb2bgr (bool): Whether to change rgb to bgr.
        out_type (numpy type): Output types. If ``np.uint8``, transform
            outputs to uint8 type with range [0, 255]; otherwise, float type
            with range [0, 1]. Default: ``np.uint8``.
        min_max (tuple[int]): min and max values for clamp.

    Returns:
        (ndarray or list[ndarray]): 3D ndarray of shape (H x W x C) OR 2D
            ndarray of shape (H x W). The channel order is BGR.
    """
    if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
        raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')

    if torch.is_tensor(tensor):
        tensor = [tensor]
    result = []
    for _tensor in tensor:
        _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
        _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])

        n_dim = _tensor.dim()
        if n_dim == 4:
            img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
            img_np = img_np.transpose(1, 2, 0)
            if rgb2bgr:
                img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
        elif n_dim == 3:
            img_np = _tensor.numpy()
            img_np = img_np.transpose(1, 2, 0)
            if img_np.shape[2] == 1:  # gray image
                img_np = np.squeeze(img_np, axis=2)
            else:
                if rgb2bgr:
                    img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
        elif n_dim == 2:
            img_np = _tensor.numpy()
        else:
            raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}')
        if out_type == np.uint8:
            # Unlike MATLAB, numpy.uint8() WILL NOT round by default.
            img_np = (img_np * 255.0).round()
        img_np = img_np.astype(out_type)
        result.append(img_np)
    if len(result) == 1:
        result = result[0]
    return result

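
def _demo_tensor_roundtrip():
    # Illustrative sketch (not part of the original file): a BGR image with
    # uint8-quantized values survives the img2tensor -> tensor2img round trip
    # exactly (up to the uint8 quantization itself).
    img = (np.random.rand(16, 16, 3) * 255).round().astype(np.float32) / 255.
    t = img2tensor(img.copy(), bgr2rgb=True, float32=True)  # HWC BGR -> CHW RGB
    back = tensor2img(t, rgb2bgr=True, min_max=(0, 1))      # CHW RGB -> HWC BGR uint8
    assert np.array_equal(back, (img * 255).round().astype(np.uint8))
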

def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
    """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).

    We use a vertical flip and transpose for the rotation implementation.
    All the images in the list use the same augmentation.

    Args:
        imgs (list[ndarray] | ndarray): Images to be augmented. If the input
            is an ndarray, it will be transformed to a list.
        hflip (bool): Horizontal flip. Default: True.
        rotation (bool): Rotation. Default: True.
        flows (list[ndarray]): Flows to be augmented. If the input is an
            ndarray, it will be transformed to a list.
            Dimension is (h, w, 2). Default: None.
        return_status (bool): Return the status of flip and rotation.
            Default: False.

    Returns:
        list[ndarray] | ndarray: Augmented images and flows. If returned
            results only have one element, just return ndarray.
    """
    hflip = hflip and random.random() < 0.5
    vflip = rotation and random.random() < 0.5
    rot90 = rotation and random.random() < 0.5

    def _augment(img):
        if hflip:  # horizontal
            cv2.flip(img, 1, img)
        if vflip:  # vertical
            cv2.flip(img, 0, img)
        if rot90:
            img = img.transpose(1, 0, 2)
        return img

    def _augment_flow(flow):
        if hflip:  # horizontal
            cv2.flip(flow, 1, flow)
            flow[:, :, 0] *= -1
        if vflip:  # vertical
            cv2.flip(flow, 0, flow)
            flow[:, :, 1] *= -1
        if rot90:
            flow = flow.transpose(1, 0, 2)
            flow = flow[:, :, [1, 0]]
        return flow

    if not isinstance(imgs, list):
        imgs = [imgs]
    imgs = [_augment(img) for img in imgs]
    if len(imgs) == 1:
        imgs = imgs[0]

    if flows is not None:
        if not isinstance(flows, list):
            flows = [flows]
        flows = [_augment_flow(flow) for flow in flows]
        if len(flows) == 1:
            flows = flows[0]
        return imgs, flows
    else:
        if return_status:
            return imgs, (hflip, vflip, rot90)
        else:
            return imgs


def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path=None):
    """Paired random crop. Supports Numpy array and Tensor inputs.

    It crops lists of lq and gt images with corresponding locations.

    Args:
        img_gts (list[ndarray] | ndarray | list[Tensor] | Tensor): GT images.
            Note that all images should have the same shape. If the input is
            an ndarray, it will be transformed to a list containing itself.
        img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
            should have the same shape. If the input is an ndarray, it will
            be transformed to a list containing itself.
        gt_patch_size (int): GT patch size.
        scale (int): Scale factor.
        gt_path (str): Path to ground-truth. Default: None.

    Returns:
        list[ndarray] | ndarray: GT images and LQ images. If returned results
            only have one element, just return ndarray.
    """

    if not isinstance(img_gts, list):
        img_gts = [img_gts]
    if not isinstance(img_lqs, list):
        img_lqs = [img_lqs]

    # determine input type: Numpy array or Tensor
    input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy'

    if input_type == 'Tensor':
        h_lq, w_lq = img_lqs[0].size()[-2:]
        h_gt, w_gt = img_gts[0].size()[-2:]
    else:
        h_lq, w_lq = img_lqs[0].shape[0:2]
        h_gt, w_gt = img_gts[0].shape[0:2]
    lq_patch_size = gt_patch_size // scale

    if h_gt != h_lq * scale or w_gt != w_lq * scale:
        # note: the original had a stray comma here that passed two arguments to ValueError
        raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
                         f'multiplication of LQ ({h_lq}, {w_lq}).')
    if h_lq < lq_patch_size or w_lq < lq_patch_size:
        raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
                         f'({lq_patch_size}, {lq_patch_size}). '
                         f'Please remove {gt_path}.')

    # randomly choose top and left coordinates for lq patch
    top = random.randint(0, h_lq - lq_patch_size)
    left = random.randint(0, w_lq - lq_patch_size)

    # crop lq patch
    if input_type == 'Tensor':
        img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs]
    else:
        img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]

    # crop corresponding gt patch
    top_gt, left_gt = int(top * scale), int(left * scale)
    if input_type == 'Tensor':
        img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts]
    else:
        img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
    if len(img_gts) == 1:
        img_gts = img_gts[0]
    if len(img_lqs) == 1:
        img_lqs = img_lqs[0]
    return img_gts, img_lqs

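
def _demo_paired_random_crop():
    # Illustrative sketch (not part of the original file): crop aligned GT/LQ
    # patches at scale 4; the LQ patch is gt_patch_size // scale on each side.
    scale, gt_patch_size = 4, 32
    gt = np.random.rand(64, 64, 3)
    lq = np.random.rand(16, 16, 3)
    gt_patch, lq_patch = paired_random_crop(gt, lq, gt_patch_size, scale)
    assert gt_patch.shape == (32, 32, 3) and lq_patch.shape == (8, 8, 3)
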

# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py  # noqa: E501
class BaseStorageBackend(metaclass=ABCMeta):
    """Abstract class of storage backends.

    All backends need to implement two apis: ``get()`` and ``get_text()``.
    ``get()`` reads the file as a byte stream and ``get_text()`` reads the
    file as texts.
    """

    @abstractmethod
    def get(self, filepath):
        pass

    @abstractmethod
    def get_text(self, filepath):
        pass


class MemcachedBackend(BaseStorageBackend):
    """Memcached storage backend.

    Attributes:
        server_list_cfg (str): Config file for memcached server list.
        client_cfg (str): Config file for memcached client.
        sys_path (str | None): Additional path to be appended to `sys.path`.
            Default: None.
    """

    def __init__(self, server_list_cfg, client_cfg, sys_path=None):
        if sys_path is not None:
            import sys
            sys.path.append(sys_path)
        try:
            import mc
        except ImportError:
            raise ImportError('Please install memcached to enable MemcachedBackend.')

        self.server_list_cfg = server_list_cfg
        self.client_cfg = client_cfg
        self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg)
        # mc.pyvector serves as a pointer to a memory cache
        self._mc_buffer = mc.pyvector()

    def get(self, filepath):
        filepath = str(filepath)
        import mc
        self._client.Get(filepath, self._mc_buffer)
        value_buf = mc.ConvertBuffer(self._mc_buffer)
        return value_buf

    def get_text(self, filepath):
        raise NotImplementedError


class HardDiskBackend(BaseStorageBackend):
    """Raw hard disks storage backend."""

    def get(self, filepath):
        filepath = str(filepath)
        with open(filepath, 'rb') as f:
            value_buf = f.read()
        return value_buf

    def get_text(self, filepath):
        filepath = str(filepath)
        with open(filepath, 'r') as f:
            value_buf = f.read()
        return value_buf


class LmdbBackend(BaseStorageBackend):
    """Lmdb storage backend.

    Args:
        db_paths (str | list[str]): Lmdb database paths.
        client_keys (str | list[str]): Lmdb client keys. Default: 'default'.
        readonly (bool, optional): Lmdb environment parameter. If True,
            disallow any write operations. Default: True.
        lock (bool, optional): Lmdb environment parameter. If False, when
            concurrent access occurs, do not lock the database. Default: False.
        readahead (bool, optional): Lmdb environment parameter. If False,
            disable the OS filesystem readahead mechanism, which may improve
            random read performance when a database is larger than RAM.
            Default: False.

    Attributes:
        db_paths (list): Lmdb database path.
        _client (list): A list of several lmdb envs.
    """

    def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs):
        try:
            import lmdb
        except ImportError:
            raise ImportError('Please install lmdb to enable LmdbBackend.')

        if isinstance(client_keys, str):
            client_keys = [client_keys]

        if isinstance(db_paths, list):
            self.db_paths = [str(v) for v in db_paths]
        elif isinstance(db_paths, str):
            self.db_paths = [str(db_paths)]
        assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, '
                                                        f'but received {len(client_keys)} and {len(self.db_paths)}.')

        self._client = {}
        for client, path in zip(client_keys, self.db_paths):
            self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs)

    def get(self, filepath, client_key):
        """Get values according to the filepath from one lmdb named client_key.

        Args:
            filepath (str | obj:`Path`): Here, filepath is the lmdb key.
            client_key (str): Used for distinguishing different lmdb envs.
        """
        filepath = str(filepath)
        assert client_key in self._client, (f'client_key {client_key} is not in lmdb clients.')
        client = self._client[client_key]
        with client.begin(write=False) as txn:
            value_buf = txn.get(filepath.encode('ascii'))
        return value_buf

    def get_text(self, filepath):
        raise NotImplementedError


class FileClient(object):
    """A general file client to access files in different backends.

    The client loads a file or text in a specified backend from its path
    and returns it as a binary file. It can also register other backend
    accessors with a given name and backend class.

    Attributes:
        backend (str): The storage backend type. Options are "disk",
            "memcached" and "lmdb".
        client (:obj:`BaseStorageBackend`): The backend object.
    """

    _backends = {
        'disk': HardDiskBackend,
        'memcached': MemcachedBackend,
        'lmdb': LmdbBackend,
    }

    def __init__(self, backend='disk', **kwargs):
        if backend not in self._backends:
            raise ValueError(f'Backend {backend} is not supported. Currently supported ones'
                             f' are {list(self._backends.keys())}')
        self.backend = backend
        self.client = self._backends[backend](**kwargs)

    def get(self, filepath, client_key='default'):
        # client_key is used only for lmdb, where different fileclients have
        # different lmdb environments.
        if self.backend == 'lmdb':
            return self.client.get(filepath, client_key)
        else:
            return self.client.get(filepath)

    def get_text(self, filepath):
        return self.client.get_text(filepath)

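
def _demo_file_client():
    # Illustrative sketch (not part of the original file): read an image with
    # the default disk backend and decode it with imfrombytes (defined below).
    # 'example.png' is a placeholder path.
    client = FileClient(backend='disk')
    img_bytes = client.get('example.png')
    img = imfrombytes(img_bytes, flag='color', float32=True)  # HWC BGR in [0, 1]
    return img
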

def imfrombytes(content, flag='color', float32=False):
    """Read an image from bytes.

    Args:
        content (bytes): Image bytes got from files or other streams.
        flag (str): Flags specifying the color type of a loaded image,
            candidates are `color`, `grayscale` and `unchanged`.
        float32 (bool): Whether to change to float32. If True, will also
            normalize to [0, 1]. Default: False.

    Returns:
        ndarray: Loaded image array.
    """
    img_np = np.frombuffer(content, np.uint8)
    imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED}
    img = cv2.imdecode(img_np, imread_flags[flag])
    if float32:
        img = img.astype(np.float32) / 255.
    return img
@@ -1,555 +0,0 @@
import os
import cv2
import numpy as np
import torch
import random
from os import path as osp
from torchvision.utils import make_grid
import sys
from pathlib import Path
import six
from collections import OrderedDict
import math
import glob
import av
import io
from cv2 import (CAP_PROP_FOURCC, CAP_PROP_FPS, CAP_PROP_FRAME_COUNT,
                 CAP_PROP_FRAME_HEIGHT, CAP_PROP_FRAME_WIDTH,
                 CAP_PROP_POS_FRAMES, VideoWriter_fourcc)

# Python >= 3.3 has a built-in FileNotFoundError; fall back to IOError otherwise.
if sys.version_info <= (3, 3):
    FileNotFoundError = IOError
else:
    FileNotFoundError = FileNotFoundError

def is_str(x):
    """Whether the input is a string instance."""
    return isinstance(x, six.string_types)


def is_filepath(x):
    return is_str(x) or isinstance(x, Path)


def fopen(filepath, *args, **kwargs):
    if is_str(filepath):
        return open(filepath, *args, **kwargs)
    elif isinstance(filepath, Path):
        return filepath.open(*args, **kwargs)
    raise ValueError('`filepath` should be a string or a Path')


def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
    if not osp.isfile(filename):
        raise FileNotFoundError(msg_tmpl.format(filename))


def mkdir_or_exist(dir_name, mode=0o777):
    if dir_name == '':
        return
    dir_name = osp.expanduser(dir_name)
    os.makedirs(dir_name, mode=mode, exist_ok=True)


def symlink(src, dst, overwrite=True, **kwargs):
    if os.path.lexists(dst) and overwrite:
        os.remove(dst)
    os.symlink(src, dst, **kwargs)


def scandir(dir_path, suffix=None, recursive=False, case_sensitive=True):
    """Scan a directory to find the interested files.

    Args:
        dir_path (str | :obj:`Path`): Path of the directory.
        suffix (str | tuple(str), optional): File suffix that we are
            interested in. Default: None.
        recursive (bool, optional): If set to True, recursively scan the
            directory. Default: False.
        case_sensitive (bool, optional): If set to False, ignore the case of
            suffix. Default: True.

    Returns:
        A generator for all the interested files with relative paths.
    """
    if isinstance(dir_path, (str, Path)):
        dir_path = str(dir_path)
    else:
        raise TypeError('"dir_path" must be a string or Path object')

    if (suffix is not None) and not isinstance(suffix, (str, tuple)):
        raise TypeError('"suffix" must be a string or tuple of strings')

    if suffix is not None and not case_sensitive:
        suffix = suffix.lower() if isinstance(suffix, str) else tuple(
            item.lower() for item in suffix)

    root = dir_path

    def _scandir(dir_path, suffix, recursive, case_sensitive):
        for entry in os.scandir(dir_path):
            if not entry.name.startswith('.') and entry.is_file():
                rel_path = osp.relpath(entry.path, root)
                _rel_path = rel_path if case_sensitive else rel_path.lower()
                if suffix is None or _rel_path.endswith(suffix):
                    yield rel_path
            elif recursive and os.path.isdir(entry.path):
                # scan recursively if entry.path is a directory
                yield from _scandir(entry.path, suffix, recursive,
                                    case_sensitive)

    return _scandir(dir_path, suffix, recursive, case_sensitive)


class Cache:

    def __init__(self, capacity):
        self._cache = OrderedDict()
        self._capacity = int(capacity)
        if capacity <= 0:
            raise ValueError('capacity must be a positive integer')

    @property
    def capacity(self):
        return self._capacity

    @property
    def size(self):
        return len(self._cache)

    def put(self, key, val):
        if key in self._cache:
            return
        if len(self._cache) >= self.capacity:
            self._cache.popitem(last=False)
        self._cache[key] = val

    def get(self, key, default=None):
        val = self._cache[key] if key in self._cache else default
        return val

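
def _demo_cache():
    # Illustrative sketch (not part of the original file): Cache evicts its
    # oldest entry (FIFO) once capacity is reached.
    cache = Cache(capacity=2)
    cache.put(0, 'a')
    cache.put(1, 'b')
    cache.put(2, 'c')  # evicts key 0
    assert cache.get(0) is None and cache.get(2) == 'c'
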

class VideoReader:
    """Video class with similar usage to a list object.

    This video wrapper class provides convenient apis to access frames.
    There exists an issue with OpenCV's VideoCapture class that jumping to a
    certain frame may be inaccurate. It is fixed in this class by checking
    the position after jumping each time.
    A cache is used when decoding videos, so if the same frame is visited a
    second time there is no need to decode it again if it is still stored in
    the cache.
    """

    def __init__(self, filename, cache_capacity=10):
        # Check whether the video path is a url
        if not filename.startswith(('https://', 'http://')):
            check_file_exist(filename, 'Video file not found: ' + filename)
        self._vcap = cv2.VideoCapture(filename)
        assert cache_capacity > 0
        self._cache = Cache(cache_capacity)
        self._position = 0
        # get basic info
        self._width = int(self._vcap.get(CAP_PROP_FRAME_WIDTH))
        self._height = int(self._vcap.get(CAP_PROP_FRAME_HEIGHT))
        self._fps = self._vcap.get(CAP_PROP_FPS)
        self._frame_cnt = int(self._vcap.get(CAP_PROP_FRAME_COUNT))
        self._fourcc = self._vcap.get(CAP_PROP_FOURCC)

    @property
    def vcap(self):
        """:obj:`cv2.VideoCapture`: The raw VideoCapture object."""
        return self._vcap

    @property
    def opened(self):
        """bool: Indicate whether the video is opened."""
        return self._vcap.isOpened()

    @property
    def width(self):
        """int: Width of video frames."""
        return self._width

    @property
    def height(self):
        """int: Height of video frames."""
        return self._height

    @property
    def resolution(self):
        """tuple: Video resolution (width, height)."""
        return (self._width, self._height)

    @property
    def fps(self):
        """float: FPS of the video."""
        return self._fps

    @property
    def frame_cnt(self):
        """int: Total frames of the video."""
        return self._frame_cnt

    @property
    def fourcc(self):
        """str: "Four character code" of the video."""
        return self._fourcc

    @property
    def position(self):
        """int: Current cursor position, indicating frame decoded."""
        return self._position

    def _get_real_position(self):
        return int(round(self._vcap.get(CAP_PROP_POS_FRAMES)))

    def _set_real_position(self, frame_id):
        self._vcap.set(CAP_PROP_POS_FRAMES, frame_id)
        pos = self._get_real_position()
        for _ in range(frame_id - pos):
            self._vcap.read()
        self._position = frame_id

    def read(self):
        """Read the next frame.

        If the next frame has been decoded before and is in the cache, it is
        returned directly; otherwise it is decoded, cached and returned.

        Returns:
            ndarray or None: Return the frame if successful, otherwise None.
        """
        # pos = self._position
        if self._cache:
            img = self._cache.get(self._position)
            if img is not None:
                ret = True
            else:
                if self._position != self._get_real_position():
                    self._set_real_position(self._position)
                ret, img = self._vcap.read()
                if ret:
                    self._cache.put(self._position, img)
        else:
            ret, img = self._vcap.read()
        if ret:
            self._position += 1
        return img

    def get_frame(self, frame_id):
        """Get frame by index.

        Args:
            frame_id (int): Index of the expected frame, 0-based.

        Returns:
            ndarray or None: Return the frame if successful, otherwise None.
        """
        if frame_id < 0 or frame_id >= self._frame_cnt:
            raise IndexError(
                f'"frame_id" must be between 0 and {self._frame_cnt - 1}')
        if frame_id == self._position:
            return self.read()
        if self._cache:
            img = self._cache.get(frame_id)
            if img is not None:
                self._position = frame_id + 1
                return img
        self._set_real_position(frame_id)
        ret, img = self._vcap.read()
        if ret:
            if self._cache:
                self._cache.put(self._position, img)
            self._position += 1
        return img

    def current_frame(self):
        """Get the current frame (the frame that was just visited).

        Returns:
            ndarray or None: If the video is fresh, return None, otherwise
                return the frame.
        """
        if self._position == 0:
            return None
        return self._cache.get(self._position - 1)

    def cvt2frames(self,
                   frame_dir,
                   file_start=0,
                   filename_tmpl='{:06d}.jpg',
                   start=0,
                   max_num=0,
                   show_progress=False):
        """Convert a video to frame images.

        Args:
            frame_dir (str): Output directory to store all the frame images.
            file_start (int): Filenames will start from the specified number.
            filename_tmpl (str): Filename template with the index as the
                placeholder.
            start (int): The starting frame index.
            max_num (int): Maximum number of frames to be written.
            show_progress (bool): Whether to show a progress bar.
        """
        mkdir_or_exist(frame_dir)
        if max_num == 0:
            task_num = self.frame_cnt - start
        else:
            task_num = min(self.frame_cnt - start, max_num)
        if task_num <= 0:
            raise ValueError('start must be less than total frame number')
        if start > 0:
            self._set_real_position(start)

        def write_frame(file_idx):
            img = self.read()
            if img is None:
                return
            filename = osp.join(frame_dir, filename_tmpl.format(file_idx))
            cv2.imwrite(filename, img)

        if show_progress:
            pass
            # track_progress(write_frame, range(file_start, file_start + task_num))
        else:
            for i in range(task_num):
                write_frame(file_start + i)

    def __len__(self):
        return self.frame_cnt

    def __getitem__(self, index):
        if isinstance(index, slice):
            return [
                self.get_frame(i)
                for i in range(*index.indices(self.frame_cnt))
            ]
        # support negative indexing
        if index < 0:
            index += self.frame_cnt
            if index < 0:
                raise IndexError('index out of range')
        return self.get_frame(index)

    def __iter__(self):
        self._set_real_position(0)
        return self

    def __next__(self):
        img = self.read()
        if img is not None:
            return img
        else:
            raise StopIteration

    next = __next__

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self._vcap.release()

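
def _demo_video_reader():
    # Illustrative sketch (not part of the original file): VideoReader supports
    # the context-manager, sequence and iterator protocols defined above.
    # 'test.mp4' is a placeholder path.
    with VideoReader('test.mp4', cache_capacity=10) as vr:
        print(len(vr), vr.fps, vr.resolution)
        first, last = vr[0], vr[-1]
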

def frames2video(frame_dir,
                 video_file,
                 fps=30,
                 fourcc='XVID',
                 filename_tmpl='{:06d}.jpg',
                 start=0,
                 end=0,
                 show_progress=False):
    """Read the frame images from a directory and join them as a video.

    Args:
        frame_dir (str): The directory containing video frames.
        video_file (str): Output filename.
        fps (float): FPS of the output video.
        fourcc (str): Fourcc of the output video; this should be compatible
            with the output file type.
        filename_tmpl (str): Filename template with the index as the variable.
        start (int): Starting frame index.
        end (int): Ending frame index.
        show_progress (bool): Whether to show a progress bar.
    """
    if end == 0:
        ext = filename_tmpl.split('.')[-1]
        end = len([name for name in scandir(frame_dir, ext)])
    first_file = osp.join(frame_dir, filename_tmpl.format(start))
    check_file_exist(first_file, 'The start frame not found: ' + first_file)
    img = cv2.imread(first_file)
    height, width = img.shape[:2]
    resolution = (width, height)
    vwriter = cv2.VideoWriter(video_file, VideoWriter_fourcc(*fourcc), fps,
                              resolution)

    def write_frame(file_idx):
        filename = osp.join(frame_dir, filename_tmpl.format(file_idx))
        img = cv2.imread(filename)
        vwriter.write(img)

    if show_progress:
        pass
        # track_progress(write_frame, range(start, end))
    else:
        for i in range(start, end):
            write_frame(i)
    vwriter.release()


def video2images(video_path, output_dir):
    vidcap = cv2.VideoCapture(video_path)
    in_fps = vidcap.get(cv2.CAP_PROP_FPS)
    print('video fps:', in_fps)
    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)
    loaded, frame = vidcap.read()
    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    print(f'number of total frames is: {total_frames:06}')
    for i_frame in range(total_frames):
        if i_frame % 100 == 0:
            print(f'{i_frame:06} / {total_frames:06}')
        frame_name = os.path.join(output_dir, f'{i_frame:06}' + '.png')
        cv2.imwrite(frame_name, frame)
        loaded, frame = vidcap.read()


def images2video(image_dir, video_path, fps=24, image_ext='png'):
    '''
    # codec = cv2.VideoWriter_fourcc(*'XVID')
    # codec = cv2.VideoWriter_fourcc('A','V','C','1')
    # codec = cv2.VideoWriter_fourcc('Y','U','V','1')
    # codec = cv2.VideoWriter_fourcc('P','I','M','1')
    # codec = cv2.VideoWriter_fourcc('M','J','P','G')
    codec = cv2.VideoWriter_fourcc('M','P','4','2')
    # codec = cv2.VideoWriter_fourcc('D','I','V','3')
    # codec = cv2.VideoWriter_fourcc('D','I','V','X')
    # codec = cv2.VideoWriter_fourcc('U','2','6','3')
    # codec = cv2.VideoWriter_fourcc('I','2','6','3')
    # codec = cv2.VideoWriter_fourcc('F','L','V','1')
    # codec = cv2.VideoWriter_fourcc('H','2','6','4')
    # codec = cv2.VideoWriter_fourcc('A','Y','U','V')
    # codec = cv2.VideoWriter_fourcc('I','U','Y','V')
    Commonly used codecs:
    cv2.VideoWriter_fourcc("I", "4", "2", "0")
        compressed YUV with 4:2:0 chroma subsampling; widely compatible but produces very large .avi files
    cv2.VideoWriter_fourcc("P", "I", "M", "1")
        MPEG-1 encoding, .avi files
    cv2.VideoWriter_fourcc("X", "V", "I", "D")
        MPEG-4 encoding, average file size, .avi extension
    cv2.VideoWriter_fourcc("T", "H", "E", "O")
        Ogg Vorbis, .ogv extension
    cv2.VideoWriter_fourcc("F", "L", "V", "1")
        Flash video, .flv extension
    '''
    image_files = sorted(glob.glob(os.path.join(image_dir, '*.{}'.format(image_ext))))
    print(len(image_files))
    height, width, _ = cv2.imread(image_files[0]).shape
    out_fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G')  # cv2.VideoWriter_fourcc(*'MP4V')
    out_video = cv2.VideoWriter(video_path, out_fourcc, fps, (width, height))

    for image_file in image_files:
        img = cv2.imread(image_file)
        img = cv2.resize(img, (width, height), interpolation=3)  # 3 = cv2.INTER_AREA
        out_video.write(img)
    out_video.release()

def add_video_compression(imgs):
|
||||
codec_type = ['libx264', 'h264', 'mpeg4']
|
||||
codec_prob = [1 / 3., 1 / 3., 1 / 3.]
|
||||
codec = random.choices(codec_type, codec_prob)[0]
|
||||
# codec = 'mpeg4'
|
||||
bitrate = [1e4, 1e5]
|
||||
bitrate = np.random.randint(bitrate[0], bitrate[1] + 1)
|
||||
|
||||
buf = io.BytesIO()
|
||||
with av.open(buf, 'w', 'mp4') as container:
|
||||
stream = container.add_stream(codec, rate=1)
|
||||
stream.height = imgs[0].shape[0]
|
||||
stream.width = imgs[0].shape[1]
|
||||
stream.pix_fmt = 'yuv420p'
|
||||
stream.bit_rate = bitrate
|
||||
|
||||
for img in imgs:
|
||||
img = np.uint8((img.clip(0, 1)*255.).round())
|
||||
frame = av.VideoFrame.from_ndarray(img, format='rgb24')
|
||||
frame.pict_type = 'NONE'
|
||||
# pdb.set_trace()
|
||||
for packet in stream.encode(frame):
|
||||
container.mux(packet)
|
||||
|
||||
# Flush stream
|
||||
for packet in stream.encode():
|
||||
container.mux(packet)
|
||||
|
||||
outputs = []
|
||||
with av.open(buf, 'r', 'mp4') as container:
|
||||
if container.streams.video:
|
||||
for frame in container.decode(**{'video': 0}):
|
||||
outputs.append(
|
||||
frame.to_rgb().to_ndarray().astype(np.float32) / 255.)
|
||||
|
||||
#outputs = np.stack(outputs, axis=0)
|
||||
return outputs
|
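
# Minimal usage sketch for add_video_compression (assumes PyAV is installed;
# the random frames below are stand-ins for real [0, 1] float32 RGB frames):
#
#   frames = [np.random.rand(64, 64, 3).astype(np.float32) for _ in range(7)]
#   compressed = add_video_compression(frames)  # list of float32 arrays in [0, 1]
#   assert compressed[0].shape == frames[0].shape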


if __name__ == '__main__':

    # -----------------------------------
    # test VideoReader(filename, cache_capacity=10)
    # -----------------------------------
    # video_reader = VideoReader('utils/test.mp4')
    # from utils import utils_image as util
    # inputs = []
    # for frame in video_reader:
    #     print(frame.dtype)
    #     util.imshow(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    #     # util.imshow(np.flip(frame, axis=2))

    # -----------------------------------
    # test video2images(video_path, output_dir)
    # -----------------------------------
    # video2images('utils/test.mp4', 'frames')

    # -----------------------------------
    # test images2video(image_dir, video_path, fps=24, image_ext='png')
    # -----------------------------------
    # images2video('frames', 'video_02.mp4', fps=30, image_ext='png')

    # -----------------------------------
    # test frames2video(frame_dir, video_file, fps=30, fourcc='XVID', filename_tmpl='{:06d}.png')
    # -----------------------------------
    # frames2video('frames', 'video_01.mp4', filename_tmpl='{:06d}.png')

    # -----------------------------------
    # test add_video_compression(imgs)
    # -----------------------------------
    # imgs = []
    # image_ext = 'png'
    # frames = 'frames'
    # from utils import utils_image as util
    # image_files = sorted(glob.glob(os.path.join(frames, '*.{}'.format(image_ext))))
    # for i, image_file in enumerate(image_files):
    #     if i < 7:
    #         img = util.imread_uint(image_file, 3)
    #         img = util.uint2single(img)
    #         imgs.append(img)
    #
    # results = add_video_compression(imgs)
    # for i, img in enumerate(results):
    #     util.imshow(util.single2uint(img))
    #     util.imsave(util.single2uint(img), f'{i:05}.png')

    # run utils/utils_video.py
@@ -0,0 +1,77 @@
{
  "task": "drunet"                // root/task/images-models-options
  , "model": "plain"              // "plain"
  , "gpu_ids": [0]

  , "scale": 0                    // broadcast to "netG" if SISR
  , "n_channels": 1               // broadcast to "datasets", 1 for grayscale, 3 for color
  , "n_channels_datasetload": 3   // broadcast to image training set
  , "sigma": [0, 50]              // 15, 25, 50 for DnCNN | [0, 75] for FFDNet and FDnCNN
  , "sigma_test": 5               // 15, 25, 50 for DnCNN and FFDNet

  , "path": {
    "root": "denoising"           // "denoising" | "superresolution"
    , "pretrained_netG": null     // path of the pretrained model; set to null to train from scratch
  }

  , "datasets": {
    "train": {
      "name": "train_dataset"     // just a name
      , "dataset_type": "ffdnet"  // "dncnn" | "dnpatch" for dncnn | "fdncnn" | "ffdnet" | "sr" | "srmd" | "dpsr" | "plain" | "plainpatch"
      , "dataroot_H": "trainsets/web_images_train"  // path of the H training dataset
      , "num_patches_per_image": 20                 // number of random patches per training image
      , "dataroot_L": "trainsets/simulations"       // path of the L training dataset; set to null to synthesize noisy images from H
      , "H_size": 128             // patch size 40 | 64 | 96 | 128 | 192
      , "dataloader_shuffle": true
      , "dataloader_num_workers": 8
      , "dataloader_batch_size": 64  // batch size 1 | 16 | 32 | 48 | 64 | 128
    }
    , "test": {
      "name": "test_dataset"      // just a name
      , "dataset_type": "ffdnet"  // "dncnn" | "dnpatch" for dncnn | "fdncnn" | "ffdnet" | "sr" | "srmd" | "dpsr" | "plain" | "plainpatch"
      , "dataroot_H": "testsets/web_images_test"  // path of the H testing dataset
      , "dataroot_L": "testsets/simulations"      // path of the L testing dataset
    }
  }

  , "netG": {
    "net_type": "drunet"          // "dncnn" | "fdncnn" | "ffdnet" | "srmd" | "dpsr" | "srresnet0" | "srresnet1" | "rrdbnet"
    , "in_nc": 1                  // input channel number
    , "out_nc": 1                 // output channel number
    , "nc": [64, 128, 256, 512]   // 64 for "dncnn"
    , "nb": 4                     // 17 for "dncnn", 20 for "dncnn3", 16 for "srresnet"
    , "gc": 32                    // unused
    , "ng": 2                     // unused
    , "reduction": 16             // unused
    , "act_mode": "R"             // "BR" for BN+ReLU | "R" for ReLU
    , "upsample_mode": "convtranspose"  // "pixelshuffle" | "convtranspose" | "upconv"
    , "downsample_mode": "strideconv"   // "strideconv" | "avgpool" | "maxpool"
    , "bias": false
    , "init_type": "orthogonal"   // "orthogonal" | "normal" | "uniform" | "xavier_normal" | "xavier_uniform" | "kaiming_normal" | "kaiming_uniform"
    , "init_bn_type": "uniform"   // "uniform" | "constant"
    , "init_gain": 0.2
  }

  , "train": {
    "epochs": 301                    // number of epochs to train
    , "G_lossfn_type": "l1"          // "l1" preferred | "l2sum" | "l2" | "ssim"
    , "G_lossfn_weight": 1.0         // default
    , "G_tvloss_weight": 0.1         // total variation loss weight

    , "G_optimizer_type": "adam"     // fixed; adam is enough
    , "G_optimizer_lr": 1e-4         // learning rate
    , "G_optimizer_clipgrad": null   // unused

    , "G_scheduler_type": "MultiStepLR"  // "MultiStepLR" is enough
    , "G_scheduler_milestones": [640, 980, 1600, 1920, 2400, 4800, 6400, 9280]
    , "G_scheduler_gamma": 0.1

    , "G_regularizer_orthstep": null // unused
    , "G_regularizer_clipstep": null // unused

    // iteration (batch step) checkpoints
    , "checkpoint_test": 320         // interval for testing
    , "checkpoint_save": 1600        // interval for saving the model
    , "checkpoint_print": 32         // interval for printing the loss
  }
}
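
Since the options file above uses // line comments, it is not strict JSON and plain json.load will reject it. A minimal loader sketch (the helper name load_commented_json is hypothetical, not the repository's own option parser; it assumes '//' never occurs inside a JSON string value):

import json
import re


def load_commented_json(path):
    """Parse an options file that is JSON plus '//' line comments."""
    with open(path) as f:
        text = f.read()
    # naive comment stripping: assumes '//' never appears inside a string value
    return json.loads(re.sub(r'//.*', '', text))


opt = load_commented_json('options/train_drunet.json')

# With torch imported and a model netG defined, the settings under "train"
# map directly onto standard PyTorch objects:
#   optimizer = torch.optim.Adam(netG.parameters(), lr=opt['train']['G_optimizer_lr'])
#   scheduler = torch.optim.lr_scheduler.MultiStepLR(
#       optimizer,
#       milestones=opt['train']['G_scheduler_milestones'],
#       gamma=opt['train']['G_scheduler_gamma'])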