# network_unet.py
import torch
import torch.nn as nn
import numpy as np

import models.basicblock as B

'''
# ====================
# Residual U-Net
# ====================
citation:
@article{zhang2020plug,
  title={Plug-and-Play Image Restoration with Deep Denoiser Prior},
  author={Zhang, Kai and Li, Yawei and Zuo, Wangmeng and Zhang, Lei and Van Gool, Luc and Timofte, Radu},
  journal={arXiv preprint},
  year={2020}
}
# ====================
'''
class UNetRes(nn.Module):
    """Residual U-Net (DRUNet) for image restoration.

    An encoder-decoder with three 2x downsampling stages, `nb` residual
    blocks at every scale, and additive (element-wise sum) skip
    connections between mirrored encoder/decoder scales.

    Args:
        in_nc (int): number of input channels.
        out_nc (int): number of output channels.
        nc (list[int]): feature widths at the four scales (head, /2, /4, /8).
        nb (int): number of residual blocks per scale.
        act_mode (str): activation code understood by models.basicblock
            (e.g. 'R' for ReLU).
        downsample_mode (str): one of 'avgpool', 'maxpool', 'strideconv'.
        upsample_mode (str): one of 'upconv', 'pixelshuffle', 'convtranspose'.
        bias (bool): whether convolution layers carry a bias term.

    Raises:
        NotImplementedError: if an unknown down/upsample mode is given.
    """

    def __init__(self, in_nc=3, out_nc=3, nc=[64, 128, 256, 512], nb=4, act_mode='R', downsample_mode='strideconv', upsample_mode='convtranspose', bias=True):
        super(UNetRes, self).__init__()

        self.m_head = B.conv(in_nc, nc[0], bias=bias, mode='C')

        # Select the downsampling operator used on the encoder path.
        if downsample_mode == 'avgpool':
            downsample_block = B.downsample_avgpool
        elif downsample_mode == 'maxpool':
            downsample_block = B.downsample_maxpool
        elif downsample_mode == 'strideconv':
            downsample_block = B.downsample_strideconv
        else:
            raise NotImplementedError('downsample mode [{:s}] is not found'.format(downsample_mode))

        # Encoder: nb residual blocks followed by a 2x downsample per scale.
        self.m_down1 = B.sequential(*[B.ResBlock(nc[0], nc[0], bias=bias, mode='C'+act_mode+'C') for _ in range(nb)], downsample_block(nc[0], nc[1], bias=bias, mode='2'))
        self.m_down2 = B.sequential(*[B.ResBlock(nc[1], nc[1], bias=bias, mode='C'+act_mode+'C') for _ in range(nb)], downsample_block(nc[1], nc[2], bias=bias, mode='2'))
        self.m_down3 = B.sequential(*[B.ResBlock(nc[2], nc[2], bias=bias, mode='C'+act_mode+'C') for _ in range(nb)], downsample_block(nc[2], nc[3], bias=bias, mode='2'))

        # Bottleneck at the coarsest scale.
        self.m_body = B.sequential(*[B.ResBlock(nc[3], nc[3], bias=bias, mode='C'+act_mode+'C') for _ in range(nb)])

        # Select the upsampling operator used on the decoder path.
        if upsample_mode == 'upconv':
            upsample_block = B.upsample_upconv
        elif upsample_mode == 'pixelshuffle':
            upsample_block = B.upsample_pixelshuffle
        elif upsample_mode == 'convtranspose':
            upsample_block = B.upsample_convtranspose
        else:
            raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))

        # Decoder: a 2x upsample followed by nb residual blocks per scale.
        self.m_up3 = B.sequential(upsample_block(nc[3], nc[2], bias=bias, mode='2'), *[B.ResBlock(nc[2], nc[2], bias=bias, mode='C'+act_mode+'C') for _ in range(nb)])
        self.m_up2 = B.sequential(upsample_block(nc[2], nc[1], bias=bias, mode='2'), *[B.ResBlock(nc[1], nc[1], bias=bias, mode='C'+act_mode+'C') for _ in range(nb)])
        self.m_up1 = B.sequential(upsample_block(nc[1], nc[0], bias=bias, mode='2'), *[B.ResBlock(nc[0], nc[0], bias=bias, mode='C'+act_mode+'C') for _ in range(nb)])

        self.m_tail = B.conv(nc[0], out_nc, bias=bias, mode='C')

    def forward(self, x0):
        """Run the network; output keeps the input's spatial size.

        Assumes x0 is an (N, in_nc, H, W) tensor — TODO confirm against
        callers. H and W need not be multiples of 8: the input is
        replication-padded so the three 2x downsamplings divide evenly,
        and the result is cropped back afterwards.
        """
        h, w = x0.size()[-2:]
        # Pad H/W up to the next multiple of 8 (= 2**3 downsample stages).
        # Integer arithmetic replaces the original np.ceil float round-trip,
        # and a functional pad avoids constructing a fresh ReplicationPad2d
        # module on every forward call.
        padding_bottom = -h % 8
        padding_right = -w % 8
        if padding_bottom or padding_right:
            x0 = nn.functional.pad(x0, (0, padding_right, 0, padding_bottom), mode='replicate')

        # U-Net forward pass.
        x1 = self.m_head(x0)
        x2 = self.m_down1(x1)
        x3 = self.m_down2(x2)
        x4 = self.m_down3(x3)
        x = self.m_body(x4)
        # Additive skip connections between mirrored scales.
        x = self.m_up3(x + x4)
        x = self.m_up2(x + x3)
        x = self.m_up1(x + x2)
        x = self.m_tail(x + x1)

        # Crop back to the original (pre-padding) spatial size.
        x = x[..., :h, :w]

        return x
if __name__ == '__main__':
    # Quick smoke test: push a random 256x256 RGB image through the
    # default-configured network and report the output size.
    dummy = torch.rand(1, 3, 256, 256)
    model = UNetRes()
    model.eval()
    with torch.no_grad():
        result = model(dummy)
    print(result.size())

# run models/network_unet.py