From 18475ea8eb6d670c6eae6f27713404a133a67d20 Mon Sep 17 00:00:00 2001 From: Emilio Martinez Date: Tue, 23 Jan 2024 14:26:09 -0300 Subject: [PATCH] Compute metrics when testing --- .../data/dataset_deeptempest_finetuning.py | 2 +- end-to-end/main_test_drunet.py | 77 ++++++++++++++----- 2 files changed, 57 insertions(+), 22 deletions(-) diff --git a/end-to-end/data/dataset_deeptempest_finetuning.py b/end-to-end/data/dataset_deeptempest_finetuning.py index 2203710..afea870 100644 --- a/end-to-end/data/dataset_deeptempest_finetuning.py +++ b/end-to-end/data/dataset_deeptempest_finetuning.py @@ -41,7 +41,7 @@ class DatasetDrunetFineTune(data.Dataset): contains one or more L representations of the H image. """ - assert os.path.isdir(opt['dataroot_H']), "No es dir" + assert os.path.isdir(opt['dataroot_H']), f"{opt['dataroot_H']} is not a directory" self.paths_H = [f for f in os.listdir(opt['dataroot_H']) if os.path.isfile(os.path.join(opt['dataroot_H'],f))] #------------------------------------------------------------------------------------------------------ # For the above step you can use util.get_image_paths(), but it goes recursevely throught the tree dirs diff --git a/end-to-end/main_test_drunet.py b/end-to-end/main_test_drunet.py index f7a62b5..b651b26 100644 --- a/end-to-end/main_test_drunet.py +++ b/end-to-end/main_test_drunet.py @@ -18,6 +18,27 @@ from utils.utils_dist import get_dist_info, init_dist from data.select_dataset import define_Dataset from models.select_model import define_Model +# OCR metrics +# First, must install Tesseract: https://tesseract-ocr.github.io/tessdoc/Installation.html +# Second, install CER/WER and tesseract python wrapper libraries +# pip install fastwer +# pip install pybind11 +# pip install pytesseract +import pytesseract +import fastwer + + +def calculate_cer_wer(img_E, img_H): + # Transcribe ground-truth image to text + text_H = pytesseract.image_to_string(img_H).strip().replace('\n',' ') + + # Transcribe estimated image to text + text_E = pytesseract.image_to_string(img_E).strip().replace('\n',' ') + + cer = fastwer.score_sent(text_E, text_H, char_level=True) + wer = fastwer.score_sent(text_E, text_H) + + return cer, wer ''' # -------------------------------------------- @@ -96,6 +117,7 @@ def main(json_path='options/test_drunet.json'): # ---------------------------------------- """ L_paths = util.get_image_paths(opt['datasets']['test']['dataroot_L']) + H_paths = util.get_image_paths(opt['datasets']['test']['dataroot_H']) noise_sigma = opt['datasets']['test']['sigma_test'] ''' @@ -103,12 +125,15 @@ def main(json_path='options/test_drunet.json'): # Step--4 (main test) # ---------------------------------------- ''' - # avg_psnr = 0.0 - # avg_ssim = 0.0 - # idx = 0 + avg_psnr = 0.0 + avg_ssim = 0.0 + avg_edgeJaccard = 0.0 + avg_cer = 0.0 + avg_wer = 0.0 + idx = 0 - for L_path in L_paths: - # idx += 1 + for L_path, H_path in zip(L_paths,H_paths): + idx += 1 image_name_ext = os.path.basename(L_path) img_name, ext = os.path.splitext(image_name_ext) @@ -118,7 +143,7 @@ def main(json_path='options/test_drunet.json'): logger.info('Creating inference on test image...') # Load image - img_L_original = util.imread_uint(L_path, n_channels=3) + img_L_original = util.imread_uint(L_path, n_channels=3)[50:-50,100:-100,:] img_L = img_L_original[:,:,:2] img_L = util.uint2single(img_L) img_L = util.single2tensor4(img_L) @@ -150,26 +175,36 @@ def main(json_path='options/test_drunet.json'): logger.info(f'Inference of {img_name} completed. Saved at {img_dir}.') - # ----------------------- - # calculate PSNR - # ----------------------- - # current_psnr = util.calculate_psnr(E_img, H_img) + # Load H image and compute metrics + img_H = util.imread_uint(H_path, n_channels=3) + if img_H.ndim == 3: + img_H = np.mean(img_H, axis=2) + img_H = img_H.astype('uint8') - # ----------------------- - # calculate SSIM - # ----------------------- - # current_ssim = util.calculate_ssim(E_img, H_img) + # ---------------------------------------- + # compute PSNR, SSIM, edgeJaccard and CER + # ---------------------------------------- + current_psnr = util.calculate_psnr(img_E, img_H) + current_ssim = util.calculate_ssim(img_E, img_H) + current_edgeJaccard = util.calculate_edge_jaccard(img_E, img_H) + current_cer, current_wer = calculate_cer_wer(img_E, img_H) - # logger.info('{:->4d}--> {:>10s} | PSNR = {:<4.2f}dB, SSIM = {:<4.2f}'.format(idx, image_name_ext, current_psnr, current_ssim)) + logger.info('{:->4d}--> {:>10s} | PSNR = {:<4.2f}dB ; SSIM = {:.3f} ; edgeJaccard = {:.3f} ; CER = {:.3f}% ; WER = {:.3f}%'.format(idx, image_name_ext, current_psnr, current_ssim, current_edgeJaccard, current_cer, current_wer)) - # avg_psnr += current_psnr - # avg_ssim += current_ssim + avg_psnr += current_psnr + avg_ssim += current_ssim + avg_edgeJaccard += current_edgeJaccard + avg_cer += current_cer + avg_wer += current_wer - # avg_psnr = avg_psnr / idx - # avg_ssim = avg_ssim / idx + avg_psnr = avg_psnr / idx + avg_ssim = avg_ssim / idx + avg_edgeJaccard = avg_edgeJaccard / idx + avg_cer = avg_cer / idx + avg_wer = avg_wer / idx - # testing log - # logger.info('Average PSNR : {:<.2f}dB, Average SSIM : {:<4.2f}\n'.format(avg_psnr, avg_ssim)) + # Average log + logger.info('[Average metrics] PSNR : {:<4.2f}dB, SSIM = {:.3f} : edgeJaccard = {:.3f} : CER = {:.3f}% : WER = {:.3f}%'.format(avg_psnr, avg_ssim, avg_edgeJaccard, avg_cer, avg_wer)) if __name__ == '__main__': main()