Training settings for abs value tempest

This commit is contained in:
Emilio Martinez 2023-03-29 19:33:45 -03:00
parent c00fd3efd2
commit 69aff57894
3 changed files with 22 additions and 9 deletions

View File

@ -30,8 +30,10 @@ class DatasetFFDNet(data.Dataset):
# ------------------------------------- # -------------------------------------
# get the path of H, return None if input is None # get the path of H, return None if input is None
# ------------------------------------- # -------------------------------------
self.paths_H = util.get_image_paths(opt['dataroot_H']) self.paths_H = util.get_image_paths(opt['dataroot_H'])[:50] # Edit: overfittear con las primeras 50 imagenes
self.paths_L = util.get_image_paths(opt['dataroot_L']) self.paths_L = util.get_image_paths(opt['dataroot_L'])[:50] # Edit: las primeras 9 imagenes pertenecen a test
# print('\nNum patches:',self.num_patches_per_image,'\n')
if self.opt['phase'] == 'train': if self.opt['phase'] == 'train':
listOfLists = [list(itertools.repeat(path, self.num_patches_per_image)) for path in self.paths_H] listOfLists = [list(itertools.repeat(path, self.num_patches_per_image)) for path in self.paths_H]
self.paths_H = list(itertools.chain.from_iterable(listOfLists)) self.paths_H = list(itertools.chain.from_iterable(listOfLists))
@ -81,6 +83,11 @@ class DatasetFFDNet(data.Dataset):
# Get the patch from the simulation # Get the patch from the simulation
patch_L = img_L[rnd_h:rnd_h + self.patch_size, rnd_w:rnd_w + self.patch_size, :] patch_L = img_L[rnd_h:rnd_h + self.patch_size, rnd_w:rnd_w + self.patch_size, :]
# Get module of complex image
patch_L = patch_L.astype('float')
patch_L = np.abs(patch_L[:,:,0]+1j*patch_L[:,:,1]).astype('uint8')
# # Commented augmentation with rotating because of TMDS encoding # # Commented augmentation with rotating because of TMDS encoding
# --------------------------------- # ---------------------------------
@ -119,6 +126,11 @@ class DatasetFFDNet(data.Dataset):
img_H = img_H[:,:,np.newaxis] img_H = img_H[:,:,np.newaxis]
img_H = util.uint2single(img_H) img_H = util.uint2single(img_H)
# Get module of complex image
img_L = img_L.astype('float')
img_L = np.abs(img_L[:,:,0]+1j*img_L[:,:,1]).astype('uint8')
img_L = img_L[:,:,np.newaxis]
np.random.seed(seed=0) np.random.seed(seed=0)
img_L = img_L + np.random.normal(0, self.sigma_test/255.0, img_L.shape) img_L = img_L + np.random.normal(0, self.sigma_test/255.0, img_L.shape)
noise_level = torch.FloatTensor([self.sigma_test/255.0]) noise_level = torch.FloatTensor([self.sigma_test/255.0])

View File

@ -9,7 +9,8 @@ class ModelBase():
def __init__(self, opt): def __init__(self, opt):
self.opt = opt # opt self.opt = opt # opt
self.save_dir = opt['path']['models'] # save models self.save_dir = opt['path']['models'] # save models
self.device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu') # self.device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu')
self.device = torch.device('cuda' if len(opt['gpu_ids']) != 0 else 'cpu')
self.is_train = opt['is_train'] # training or not self.is_train = opt['is_train'] # training or not
self.schedulers = [] # schedulers self.schedulers = [] # schedulers

View File

@ -4,7 +4,7 @@
, "gpu_ids": [0] , "gpu_ids": [0]
, "scale": 1 // broadcast to "netG" if SISR , "scale": 1 // broadcast to "netG" if SISR
, "n_channels": 2 // broadcast to "datasets", 1 for grayscale, 3 for color , "n_channels": 1 // broadcast to "datasets", 1 for grayscale, 3 for color
, "n_channels_datasetload": 3 // broadcast to image training set , "n_channels_datasetload": 3 // broadcast to image training set
, "sigma": [0, 50] // 15, 25, 50 for DnCNN | [0, 75] for FFDNet and FDnCNN , "sigma": [0, 50] // 15, 25, 50 for DnCNN | [0, 75] for FFDNet and FDnCNN
, "sigma_test": 25 // 15, 25, 50 for DnCNN and ffdnet , "sigma_test": 25 // 15, 25, 50 for DnCNN and ffdnet
@ -36,7 +36,7 @@
, "netG": { , "netG": {
"net_type": "drunet" // "dncnn" | "fdncnn" | "ffdnet" | "srmd" | "dpsr" | "srresnet0" | "srresnet1" | "rrdbnet" "net_type": "drunet" // "dncnn" | "fdncnn" | "ffdnet" | "srmd" | "dpsr" | "srresnet0" | "srresnet1" | "rrdbnet"
, "in_nc": 2 // input channel number , "in_nc": 1 // input channel number
, "out_nc": 1 // ouput channel number , "out_nc": 1 // ouput channel number
, "nc": [64, 128, 256, 512] // 64 for "dncnn" , "nc": [64, 128, 256, 512] // 64 for "dncnn"
, "nb": 4 // 17 for "dncnn", 20 for dncnn3, 16 for "srresnet" , "nb": 4 // 17 for "dncnn", 20 for dncnn3, 16 for "srresnet"
@ -54,7 +54,7 @@
, "train": { , "train": {
"epochs": 1000 // number of epochs to train "epochs": 1000 // number of epochs to train
, "G_lossfn_type": "l1" // "l1" preferred | "l2sum" | "l2" | "ssim" , "G_lossfn_type": "tv" // "l1" preferred | "l2sum" | "l2" | "ssim"
, "G_lossfn_weight": 1.0 // default , "G_lossfn_weight": 1.0 // default
, "G_tvloss_weight": 1.0 // total variation weight , "G_tvloss_weight": 1.0 // total variation weight
@ -70,8 +70,8 @@
, "G_regularizer_clipstep": null // unused , "G_regularizer_clipstep": null // unused
// iteration (batch step) checkpoints // iteration (batch step) checkpoints
, "checkpoint_test": 1000 // for testing , "checkpoint_test": 500 // for testing
, "checkpoint_save": 1000 // for saving model , "checkpoint_save": 780 // for saving model
, "checkpoint_print": 100 // for print , "checkpoint_print": 16 // for print
} }
} }