Removed train metric logs. Added Optuna results for edgeJaccard (25 trials)
This commit is contained in:
parent
a51e018ec0
commit
d47ab27c4a
|
@ -13,7 +13,7 @@
|
||||||
}
|
}
|
||||||
|
|
||||||
, "optuna":{
|
, "optuna":{
|
||||||
"n_trials": 20
|
"n_trials": 25
|
||||||
,"trial_epochs": 10 // Maximum epochs per trial
|
,"trial_epochs": 10 // Maximum epochs per trial
|
||||||
,"metric": "edgeJaccard" // "edgeJaccard" | "PSNR" | "SSIM" | "CER" | "MSE"
|
,"metric": "edgeJaccard" // "edgeJaccard" | "PSNR" | "SSIM" | "CER" | "MSE"
|
||||||
|
|
||||||
|
@ -23,8 +23,8 @@
|
||||||
"train": {
|
"train": {
|
||||||
"name": "train_dataset" // just name
|
"name": "train_dataset" // just name
|
||||||
, "dataset_type": "ffdnet" // "dncnn" | "dnpatch" for dncnn, | "fdncnn" | "ffdnet" | "sr" | "srmd" | "dpsr" | "plain" | "plainpatch"
|
, "dataset_type": "ffdnet" // "dncnn" | "dnpatch" for dncnn, | "fdncnn" | "ffdnet" | "sr" | "srmd" | "dpsr" | "plain" | "plainpatch"
|
||||||
, "dataroot_H": "optuna_hparams/optuna_subset/train/ground_truth" // path of H training dataset
|
, "dataroot_H": "trainsets/ground-truth" // path of H training dataset
|
||||||
, "dataroot_L": "optuna_hparams/optuna_subset/train/simulations" // path of L training dataset, if using noisy H type: null
|
, "dataroot_L": "trainsets/simulations" // path of L training dataset, if using noisy H type: null
|
||||||
, "sigma": [0, 15] // 15, 25, 50 for DnCNN | [0, 75] for FFDNet and FDnCNN
|
, "sigma": [0, 15] // 15, 25, 50 for DnCNN | [0, 75] for FFDNet and FDnCNN
|
||||||
, "num_patches_per_image": 21 // number of random patches of training image
|
, "num_patches_per_image": 21 // number of random patches of training image
|
||||||
, "H_size": 256 // patch size 40 | 64 | 96 | 128 | 192
|
, "H_size": 256 // patch size 40 | 64 | 96 | 128 | 192
|
||||||
|
@ -35,8 +35,8 @@
|
||||||
, "test": {
|
, "test": {
|
||||||
"name": "test_dataset" // just name
|
"name": "test_dataset" // just name
|
||||||
, "dataset_type": "ffdnet" // "dncnn" | "dnpatch" for dncnn, | "fdncnn" | "ffdnet" | "sr" | "srmd" | "dpsr" | "plain" | "plainpatch"
|
, "dataset_type": "ffdnet" // "dncnn" | "dnpatch" for dncnn, | "fdncnn" | "ffdnet" | "sr" | "srmd" | "dpsr" | "plain" | "plainpatch"
|
||||||
, "dataroot_H": "optuna_hparams/optuna_subset/val/ground_truth" // path of H testing dataset
|
, "dataroot_H": "testsets/ground-truth" // path of H testing dataset
|
||||||
, "dataroot_L": "optuna_hparams/optuna_subset/val/simulations" // path of L testing dataset
|
, "dataroot_L": "testsets/simulations" // path of L testing dataset
|
||||||
, "sigma_test": 10 // 15, 25, 50 for DnCNN and ffdnet
|
, "sigma_test": 10 // 15, 25, 50 for DnCNN and ffdnet
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -95,7 +95,7 @@ for phase, dataset_opt in opt['datasets'].items():
|
||||||
indexes = torch.randperm(len(test_set))[:len(test_set)//2]
|
indexes = torch.randperm(len(test_set))[:len(test_set)//2]
|
||||||
test_set = Subset(test_set, indexes)
|
test_set = Subset(test_set, indexes)
|
||||||
val_loader = DataLoader(test_set, batch_size=1,
|
val_loader = DataLoader(test_set, batch_size=1,
|
||||||
shuffle=True, num_workers=1,
|
shuffle=False, num_workers=1,
|
||||||
drop_last=False, pin_memory=True)
|
drop_last=False, pin_memory=True)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("Phase [%s] is not recognized." % phase)
|
raise NotImplementedError("Phase [%s] is not recognized." % phase)
|
||||||
|
@ -106,7 +106,7 @@ logger.info(message)
|
||||||
dataset = {'train':train_loader, 'val':val_loader}
|
dataset = {'train':train_loader, 'val':val_loader}
|
||||||
|
|
||||||
# Define model function with optuna hyperparameters
|
# Define model function with optuna hyperparameters
|
||||||
def define_model(trial, opt):
|
def define_model(opt):
|
||||||
|
|
||||||
# Initialize model
|
# Initialize model
|
||||||
model = define_Model(opt)
|
model = define_Model(opt)
|
||||||
|
@ -152,10 +152,7 @@ def train_model(trial, model, dataset, metric_dict, num_epochs=25):
|
||||||
metric = metric_dict['func']
|
metric = metric_dict['func']
|
||||||
metric_direction = metric_dict['direction']
|
metric_direction = metric_dict['direction']
|
||||||
|
|
||||||
# Copy model weights to get best weights register
|
best_metric = -1e6*(metric_direction=='maximize') + 1e6*(metric_direction=='minimize')
|
||||||
# best_model_wts = copy.deepcopy(model.state_dict())
|
|
||||||
|
|
||||||
best_metric = 0*(metric_direction=='maximize') + 1e6*(metric_direction=='minimize')
|
|
||||||
|
|
||||||
current_step = 0
|
current_step = 0
|
||||||
|
|
||||||
|
@ -167,7 +164,7 @@ def train_model(trial, model, dataset, metric_dict, num_epochs=25):
|
||||||
|
|
||||||
epoch_loss = 0.0
|
epoch_loss = 0.0
|
||||||
|
|
||||||
epoch_metric = 0.0
|
# epoch_metric = 0.0
|
||||||
|
|
||||||
# -------------------------------
|
# -------------------------------
|
||||||
# Training phase
|
# Training phase
|
||||||
|
@ -192,31 +189,24 @@ def train_model(trial, model, dataset, metric_dict, num_epochs=25):
|
||||||
model.optimize_parameters(current_step)
|
model.optimize_parameters(current_step)
|
||||||
|
|
||||||
# -------------------------------
|
# -------------------------------
|
||||||
# 4) training information (loss)
|
# 4) training information (loss and metric)
|
||||||
# -------------------------------
|
# -------------------------------
|
||||||
|
|
||||||
logs = model.current_log()
|
# visuals = model.current_visuals()
|
||||||
batch_loss = logs['G_loss'] # get batch loss / iter loss
|
# E_visual = visuals['E']
|
||||||
|
# E_img = util.tensor2uint(E_visual)
|
||||||
|
# H_visual = visuals['H']
|
||||||
|
# H_img = util.tensor2uint(H_visual)
|
||||||
|
|
||||||
visuals = model.current_visuals()
|
# epoch_metric += metric(H_img, E_img)
|
||||||
E_visual = visuals['E']
|
|
||||||
E_img = util.tensor2uint(E_visual)
|
|
||||||
H_visual = visuals['H']
|
|
||||||
H_img = util.tensor2uint(H_visual)
|
|
||||||
|
|
||||||
epoch_metric += metric(H_img, E_img)
|
|
||||||
|
|
||||||
epoch_loss += model.current_log()['G_loss']
|
epoch_loss += model.current_log()['G_loss']
|
||||||
|
|
||||||
# Train loss and metric
|
# Train loss and metric
|
||||||
avg_train_loss = epoch_loss/train_size
|
avg_train_loss = epoch_loss/train_size
|
||||||
avg_train_metric = epoch_metric/train_size
|
# avg_train_metric = epoch_metric/train_size
|
||||||
|
|
||||||
message_train = f'\nepoch:{epoch+1}/{num_epochs}\n'+'-'*14+'\ntrain loss: {:.3e}, train {}: {:.3f}\n'.format(
|
message_train = f'\nepoch:{epoch+1}/{num_epochs}\n'+'-'*14+'\ntrain loss: {:.3e}\n'.format(avg_train_loss)
|
||||||
avg_train_loss,
|
|
||||||
metric_dict['name'],
|
|
||||||
avg_train_metric
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# -------------------------------
|
# -------------------------------
|
||||||
|
@ -243,9 +233,8 @@ def train_model(trial, model, dataset, metric_dict, num_epochs=25):
|
||||||
current_loss = model.G_lossfn(torch.reshape(E_visual,(1,1,sizes[1],sizes[2])),
|
current_loss = model.G_lossfn(torch.reshape(E_visual,(1,1,sizes[1],sizes[2])),
|
||||||
torch.reshape(H_visual,(1,1,sizes[1],sizes[2])))
|
torch.reshape(H_visual,(1,1,sizes[1],sizes[2])))
|
||||||
|
|
||||||
val_metric += metric(H_img, E_img)
|
|
||||||
|
|
||||||
avg_val_loss += current_loss
|
avg_val_loss += current_loss
|
||||||
|
val_metric += metric(H_img, E_img)
|
||||||
|
|
||||||
# Val loss and metric
|
# Val loss and metric
|
||||||
avg_val_loss = avg_val_loss/idx
|
avg_val_loss = avg_val_loss/idx
|
||||||
|
@ -258,8 +247,13 @@ def train_model(trial, model, dataset, metric_dict, num_epochs=25):
|
||||||
# Write epoch log
|
# Write epoch log
|
||||||
logger.info(message_train + message_val +'-'*14)
|
logger.info(message_train + message_val +'-'*14)
|
||||||
|
|
||||||
# Update if validation metric is better
|
# Update if validation metric is better (lower when minimizing, greater when maximizing)
|
||||||
if avg_val_metric > best_metric:
|
maximizing = ( (avg_val_metric > best_metric) and metric_dict['direction'] == 'maximize')
|
||||||
|
minimizing = ( (avg_val_metric < best_metric) and metric_dict['direction'] == 'minimize')
|
||||||
|
|
||||||
|
val_metric_is_better = maximizing or minimizing
|
||||||
|
|
||||||
|
if val_metric_is_better:
|
||||||
best_metric = avg_val_metric
|
best_metric = avg_val_metric
|
||||||
# best_model_wts = copy.deepcopy(model.state_dict())
|
# best_model_wts = copy.deepcopy(model.state_dict())
|
||||||
|
|
||||||
|
@ -284,7 +278,7 @@ def train_model(trial, model, dataset, metric_dict, num_epochs=25):
|
||||||
def objective(trial):
|
def objective(trial):
|
||||||
|
|
||||||
# Set learning rate suggestions for trial
|
# Set learning rate suggestions for trial
|
||||||
trial_lr = trial.suggest_loguniform("lr", 1e-5, 1e-1)
|
trial_lr = trial.suggest_loguniform("lr", 1e-6, 1e-1)
|
||||||
opt['train']['G_optimizaer_lr'] = trial_lr
|
opt['train']['G_optimizaer_lr'] = trial_lr
|
||||||
|
|
||||||
trial_tvweight = trial.suggest_loguniform("tv_weight", 1e-7, 1e-2)
|
trial_tvweight = trial.suggest_loguniform("tv_weight", 1e-7, 1e-2)
|
||||||
|
@ -297,7 +291,7 @@ def objective(trial):
|
||||||
logger.info(message)
|
logger.info(message)
|
||||||
|
|
||||||
# Generate the model and optimizers
|
# Generate the model and optimizers
|
||||||
model = define_model(trial, opt)
|
model = define_model(opt)
|
||||||
|
|
||||||
# Select metric specified at options
|
# Select metric specified at options
|
||||||
metric_dict = define_metric(opt['optuna']['metric'])
|
metric_dict = define_metric(opt['optuna']['metric'])
|
||||||
|
@ -327,9 +321,12 @@ def save_optuna_info(study):
|
||||||
# Save page for optimization history
|
# Save page for optimization history
|
||||||
fig = optuna.visualization.plot_optimization_history(study)
|
fig = optuna.visualization.plot_optimization_history(study)
|
||||||
fig.write_html(os.path.join(root_dir,'optuna_plot_optimization_history.html'))
|
fig.write_html(os.path.join(root_dir,'optuna_plot_optimization_history.html'))
|
||||||
# Save page for optimization history
|
# Save page for intermediate values plot
|
||||||
fig = optuna.visualization.plot_intermediate_values(study)
|
fig = optuna.visualization.plot_intermediate_values(study)
|
||||||
fig.write_html(os.path.join(root_dir,'optuna_plot_intermediate_values.html'))
|
fig.write_html(os.path.join(root_dir,'optuna_plot_intermediate_values.html'))
|
||||||
|
# Save page for parallel coordinate plot
|
||||||
|
fig = optuna.visualization.plot_parallel_coordinate(study)
|
||||||
|
fig.write_html(os.path.join(root_dir,'optuna_plot_parallel_coordinate.html'))
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -344,7 +341,7 @@ sampler = optuna.samplers.TPESampler()
|
||||||
study = optuna.create_study(
|
study = optuna.create_study(
|
||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
pruner=optuna.pruners.MedianPruner(
|
pruner=optuna.pruners.MedianPruner(
|
||||||
n_startup_trials=3, n_warmup_steps=5, interval_steps=3
|
n_startup_trials=5, n_warmup_steps=3, interval_steps=3
|
||||||
),
|
),
|
||||||
direction=metric_dict['direction'])
|
direction=metric_dict['direction'])
|
||||||
|
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue