Compute metrics when testing
This commit is contained in:
parent
0715eda4f6
commit
18475ea8eb
|
@ -41,7 +41,7 @@ class DatasetDrunetFineTune(data.Dataset):
|
||||||
contains one or more L representations of the H image.
|
contains one or more L representations of the H image.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
assert os.path.isdir(opt['dataroot_H']), "No es dir"
|
assert os.path.isdir(opt['dataroot_H']), f"{opt['dataroot_H']} is not a directory"
|
||||||
self.paths_H = [f for f in os.listdir(opt['dataroot_H']) if os.path.isfile(os.path.join(opt['dataroot_H'],f))]
|
self.paths_H = [f for f in os.listdir(opt['dataroot_H']) if os.path.isfile(os.path.join(opt['dataroot_H'],f))]
|
||||||
#------------------------------------------------------------------------------------------------------
|
#------------------------------------------------------------------------------------------------------
|
||||||
# For the above step you can use util.get_image_paths(), but it goes recursevely throught the tree dirs
|
# For the above step you can use util.get_image_paths(), but it goes recursevely throught the tree dirs
|
||||||
|
|
|
@ -18,6 +18,27 @@ from utils.utils_dist import get_dist_info, init_dist
|
||||||
from data.select_dataset import define_Dataset
|
from data.select_dataset import define_Dataset
|
||||||
from models.select_model import define_Model
|
from models.select_model import define_Model
|
||||||
|
|
||||||
|
# OCR metrics
|
||||||
|
# First, must install Tesseract: https://tesseract-ocr.github.io/tessdoc/Installation.html
|
||||||
|
# Second, install CER/WER and tesseract python wrapper libraries
|
||||||
|
# pip install fastwer
|
||||||
|
# pip install pybind11
|
||||||
|
# pip install pytesseract
|
||||||
|
import pytesseract
|
||||||
|
import fastwer
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_cer_wer(img_E, img_H):
|
||||||
|
# Transcribe ground-truth image to text
|
||||||
|
text_H = pytesseract.image_to_string(img_H).strip().replace('\n',' ')
|
||||||
|
|
||||||
|
# Transcribe estimated image to text
|
||||||
|
text_E = pytesseract.image_to_string(img_E).strip().replace('\n',' ')
|
||||||
|
|
||||||
|
cer = fastwer.score_sent(text_E, text_H, char_level=True)
|
||||||
|
wer = fastwer.score_sent(text_E, text_H)
|
||||||
|
|
||||||
|
return cer, wer
|
||||||
|
|
||||||
'''
|
'''
|
||||||
# --------------------------------------------
|
# --------------------------------------------
|
||||||
|
@ -96,6 +117,7 @@ def main(json_path='options/test_drunet.json'):
|
||||||
# ----------------------------------------
|
# ----------------------------------------
|
||||||
"""
|
"""
|
||||||
L_paths = util.get_image_paths(opt['datasets']['test']['dataroot_L'])
|
L_paths = util.get_image_paths(opt['datasets']['test']['dataroot_L'])
|
||||||
|
H_paths = util.get_image_paths(opt['datasets']['test']['dataroot_H'])
|
||||||
noise_sigma = opt['datasets']['test']['sigma_test']
|
noise_sigma = opt['datasets']['test']['sigma_test']
|
||||||
|
|
||||||
'''
|
'''
|
||||||
|
@ -103,12 +125,15 @@ def main(json_path='options/test_drunet.json'):
|
||||||
# Step--4 (main test)
|
# Step--4 (main test)
|
||||||
# ----------------------------------------
|
# ----------------------------------------
|
||||||
'''
|
'''
|
||||||
# avg_psnr = 0.0
|
avg_psnr = 0.0
|
||||||
# avg_ssim = 0.0
|
avg_ssim = 0.0
|
||||||
# idx = 0
|
avg_edgeJaccard = 0.0
|
||||||
|
avg_cer = 0.0
|
||||||
|
avg_wer = 0.0
|
||||||
|
idx = 0
|
||||||
|
|
||||||
for L_path in L_paths:
|
for L_path, H_path in zip(L_paths,H_paths):
|
||||||
# idx += 1
|
idx += 1
|
||||||
image_name_ext = os.path.basename(L_path)
|
image_name_ext = os.path.basename(L_path)
|
||||||
img_name, ext = os.path.splitext(image_name_ext)
|
img_name, ext = os.path.splitext(image_name_ext)
|
||||||
|
|
||||||
|
@ -118,7 +143,7 @@ def main(json_path='options/test_drunet.json'):
|
||||||
logger.info('Creating inference on test image...')
|
logger.info('Creating inference on test image...')
|
||||||
|
|
||||||
# Load image
|
# Load image
|
||||||
img_L_original = util.imread_uint(L_path, n_channels=3)
|
img_L_original = util.imread_uint(L_path, n_channels=3)[50:-50,100:-100,:]
|
||||||
img_L = img_L_original[:,:,:2]
|
img_L = img_L_original[:,:,:2]
|
||||||
img_L = util.uint2single(img_L)
|
img_L = util.uint2single(img_L)
|
||||||
img_L = util.single2tensor4(img_L)
|
img_L = util.single2tensor4(img_L)
|
||||||
|
@ -150,26 +175,36 @@ def main(json_path='options/test_drunet.json'):
|
||||||
|
|
||||||
logger.info(f'Inference of {img_name} completed. Saved at {img_dir}.')
|
logger.info(f'Inference of {img_name} completed. Saved at {img_dir}.')
|
||||||
|
|
||||||
# -----------------------
|
# Load H image and compute metrics
|
||||||
# calculate PSNR
|
img_H = util.imread_uint(H_path, n_channels=3)
|
||||||
# -----------------------
|
if img_H.ndim == 3:
|
||||||
# current_psnr = util.calculate_psnr(E_img, H_img)
|
img_H = np.mean(img_H, axis=2)
|
||||||
|
img_H = img_H.astype('uint8')
|
||||||
|
|
||||||
# -----------------------
|
# ----------------------------------------
|
||||||
# calculate SSIM
|
# compute PSNR, SSIM, edgeJaccard and CER
|
||||||
# -----------------------
|
# ----------------------------------------
|
||||||
# current_ssim = util.calculate_ssim(E_img, H_img)
|
current_psnr = util.calculate_psnr(img_E, img_H)
|
||||||
|
current_ssim = util.calculate_ssim(img_E, img_H)
|
||||||
|
current_edgeJaccard = util.calculate_edge_jaccard(img_E, img_H)
|
||||||
|
current_cer, current_wer = calculate_cer_wer(img_E, img_H)
|
||||||
|
|
||||||
# logger.info('{:->4d}--> {:>10s} | PSNR = {:<4.2f}dB, SSIM = {:<4.2f}'.format(idx, image_name_ext, current_psnr, current_ssim))
|
logger.info('{:->4d}--> {:>10s} | PSNR = {:<4.2f}dB ; SSIM = {:.3f} ; edgeJaccard = {:.3f} ; CER = {:.3f}% ; WER = {:.3f}%'.format(idx, image_name_ext, current_psnr, current_ssim, current_edgeJaccard, current_cer, current_wer))
|
||||||
|
|
||||||
# avg_psnr += current_psnr
|
avg_psnr += current_psnr
|
||||||
# avg_ssim += current_ssim
|
avg_ssim += current_ssim
|
||||||
|
avg_edgeJaccard += current_edgeJaccard
|
||||||
|
avg_cer += current_cer
|
||||||
|
avg_wer += current_wer
|
||||||
|
|
||||||
# avg_psnr = avg_psnr / idx
|
avg_psnr = avg_psnr / idx
|
||||||
# avg_ssim = avg_ssim / idx
|
avg_ssim = avg_ssim / idx
|
||||||
|
avg_edgeJaccard = avg_edgeJaccard / idx
|
||||||
|
avg_cer = avg_cer / idx
|
||||||
|
avg_wer = avg_wer / idx
|
||||||
|
|
||||||
# testing log
|
# Average log
|
||||||
# logger.info('Average PSNR : {:<.2f}dB, Average SSIM : {:<4.2f}\n'.format(avg_psnr, avg_ssim))
|
logger.info('[Average metrics] PSNR : {:<4.2f}dB, SSIM = {:.3f} : edgeJaccard = {:.3f} : CER = {:.3f}% : WER = {:.3f}%'.format(avg_psnr, avg_ssim, avg_edgeJaccard, avg_cer, avg_wer))
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
|
|
Loading…
Reference in New Issue