From 69aff57894618b7379ea900b72a7bf803ca9235b Mon Sep 17 00:00:00 2001 From: Emilio Martinez Date: Wed, 29 Mar 2023 19:33:45 -0300 Subject: [PATCH] Training settings for abs value tempest --- KAIR/data/dataset_ffdnet.py | 16 ++++++++++++++-- KAIR/models/model_base.py | 3 ++- KAIR/options/train_drunet.json | 12 ++++++------ 3 files changed, 22 insertions(+), 9 deletions(-) diff --git a/KAIR/data/dataset_ffdnet.py b/KAIR/data/dataset_ffdnet.py index eff2890..f3a8b31 100644 --- a/KAIR/data/dataset_ffdnet.py +++ b/KAIR/data/dataset_ffdnet.py @@ -30,8 +30,10 @@ class DatasetFFDNet(data.Dataset): # ------------------------------------- # get the path of H, return None if input is None # ------------------------------------- - self.paths_H = util.get_image_paths(opt['dataroot_H']) - self.paths_L = util.get_image_paths(opt['dataroot_L']) + self.paths_H = util.get_image_paths(opt['dataroot_H'])[:50] # Edit: overfittear con las primeras 50 imagenes + self.paths_L = util.get_image_paths(opt['dataroot_L'])[:50] # Edit: las primeras 9 imagenes pertenecen a test + + # print('\nNum patches:',self.num_patches_per_image,'\n') if self.opt['phase'] == 'train': listOfLists = [list(itertools.repeat(path, self.num_patches_per_image)) for path in self.paths_H] self.paths_H = list(itertools.chain.from_iterable(listOfLists)) @@ -81,6 +83,11 @@ class DatasetFFDNet(data.Dataset): # Get the patch from the simulation patch_L = img_L[rnd_h:rnd_h + self.patch_size, rnd_w:rnd_w + self.patch_size, :] + # Get module of complex image + patch_L = patch_L.astype('float') + patch_L = np.abs(patch_L[:,:,0]+1j*patch_L[:,:,1]).astype('uint8') + + # # Commented augmentation with rotating because of TMDS encoding # --------------------------------- @@ -119,6 +126,11 @@ class DatasetFFDNet(data.Dataset): img_H = img_H[:,:,np.newaxis] img_H = util.uint2single(img_H) + # Get module of complex image + img_L = img_L.astype('float') + img_L = np.abs(img_L[:,:,0]+1j*img_L[:,:,1]).astype('uint8') + img_L = img_L[:,:,np.newaxis] + np.random.seed(seed=0) img_L = img_L + np.random.normal(0, self.sigma_test/255.0, img_L.shape) noise_level = torch.FloatTensor([self.sigma_test/255.0]) diff --git a/KAIR/models/model_base.py b/KAIR/models/model_base.py index 0ae3bce..965ba10 100644 --- a/KAIR/models/model_base.py +++ b/KAIR/models/model_base.py @@ -9,7 +9,8 @@ class ModelBase(): def __init__(self, opt): self.opt = opt # opt self.save_dir = opt['path']['models'] # save models - self.device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu') + # self.device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu') + self.device = torch.device('cuda' if len(opt['gpu_ids']) != 0 else 'cpu') self.is_train = opt['is_train'] # training or not self.schedulers = [] # schedulers diff --git a/KAIR/options/train_drunet.json b/KAIR/options/train_drunet.json index c498a9e..ac83728 100644 --- a/KAIR/options/train_drunet.json +++ b/KAIR/options/train_drunet.json @@ -4,7 +4,7 @@ , "gpu_ids": [0] , "scale": 1 // broadcast to "netG" if SISR - , "n_channels": 2 // broadcast to "datasets", 1 for grayscale, 3 for color + , "n_channels": 1 // broadcast to "datasets", 1 for grayscale, 3 for color , "n_channels_datasetload": 3 // broadcast to image training set , "sigma": [0, 50] // 15, 25, 50 for DnCNN | [0, 75] for FFDNet and FDnCNN , "sigma_test": 25 // 15, 25, 50 for DnCNN and ffdnet @@ -36,7 +36,7 @@ , "netG": { "net_type": "drunet" // "dncnn" | "fdncnn" | "ffdnet" | "srmd" | "dpsr" | "srresnet0" | "srresnet1" | "rrdbnet" - , "in_nc": 2 // input channel number + , "in_nc": 1 // input channel number , "out_nc": 1 // ouput channel number , "nc": [64, 128, 256, 512] // 64 for "dncnn" , "nb": 4 // 17 for "dncnn", 20 for dncnn3, 16 for "srresnet" @@ -54,7 +54,7 @@ , "train": { "epochs": 1000 // number of epochs to train - , "G_lossfn_type": "l1" // "l1" preferred | "l2sum" | "l2" | "ssim" + , "G_lossfn_type": "tv" // "l1" preferred | "l2sum" | "l2" | "ssim" , "G_lossfn_weight": 1.0 // default , "G_tvloss_weight": 1.0 // total variation weight @@ -70,8 +70,8 @@ , "G_regularizer_clipstep": null // unused // iteration (batch step) checkpoints - , "checkpoint_test": 1000 // for testing - , "checkpoint_save": 1000 // for saving model - , "checkpoint_print": 100 // for print + , "checkpoint_test": 500 // for testing + , "checkpoint_save": 780 // for saving model + , "checkpoint_print": 16 // for print } }