# deep-tempest/end-to-end/models/network_unet.py
#
# (file-viewer header: 94 lines, 3.5 KiB, Python)

import torch
import torch.nn as nn
import models.basicblock as B
import numpy as np
'''
# ====================
# Residual U-Net
# ====================
citation:
@article{zhang2020plug,
title={Plug-and-Play Image Restoration with Deep Denoiser Prior},
author={Zhang, Kai and Li, Yawei and Zuo, Wangmeng and Zhang, Lei and Van Gool, Luc and Timofte, Radu},
journal={arXiv preprint arXiv:2008.13751},
year={2020}
}
# ====================
'''
class UNetRes(nn.Module):
    """Residual U-Net with additive skip connections.

    Layout, as built in ``__init__``:

    * ``m_head``: one conv from ``in_nc`` to ``nc[0]`` channels.
    * ``m_down1``..``m_down3``: ``nb`` residual blocks at the current width,
      then a downsampling block that moves to the next width in ``nc``.
    * ``m_body``: ``nb`` residual blocks at width ``nc[3]``.
    * ``m_up3``..``m_up1``: an upsampling block back to the previous width,
      then ``nb`` residual blocks.
    * ``m_tail``: one conv from ``nc[0]`` to ``out_nc`` channels.

    Skips are summed (``x + skip``), not concatenated, so each encoder and
    decoder stage at the same level shares the same channel width.

    The resampling blocks are built with ``mode='2'``. This presumably means
    factor-2 resampling in ``models.basicblock`` (confirm). ``forward`` relies
    on that assumption when it pads inputs to multiples of ``2 ** 3``.

    Args:
        in_nc: Number of input channels.
        out_nc: Number of output channels.
        nc: Channel widths of the four resolution levels (indices 0-3 are
            used). Any indexable sequence works. The default is a tuple to
            avoid a shared mutable default.
        nb: Number of residual blocks per stage.
        act_mode: Activation code spliced into each ``B.ResBlock`` mode
            string as ``'C' + act_mode + 'C'``. The default ``'R'`` is
            presumably ReLU in ``models.basicblock`` (confirm).
        downsample_mode: One of ``'avgpool'``, ``'maxpool'`` or
            ``'strideconv'``.
        upsample_mode: One of ``'upconv'``, ``'pixelshuffle'`` or
            ``'convtranspose'``.
        bias: Passed as ``bias`` to every ``B`` layer constructor.

    Raises:
        NotImplementedError: If ``downsample_mode`` or ``upsample_mode`` is
            not one of the supported values.
    """

    # Three factor-2 stages => the spatial size must be a multiple of 2**3.
    _PAD_MULTIPLE = 8

    def __init__(self, in_nc=3, out_nc=3, nc=(64, 128, 256, 512), nb=4, act_mode='R',
                 downsample_mode='strideconv', upsample_mode='convtranspose', bias=True):
        super().__init__()

        # Resolve both resampling factories before building any layer, so an
        # unsupported mode fails fast. Only the selected attribute of B is read.
        # '{}' rather than '{:s}': the latter raises TypeError for non-str modes.
        if downsample_mode == 'avgpool':
            downsample_block = B.downsample_avgpool
        elif downsample_mode == 'maxpool':
            downsample_block = B.downsample_maxpool
        elif downsample_mode == 'strideconv':
            downsample_block = B.downsample_strideconv
        else:
            raise NotImplementedError('downsample mode [{}] is not found'.format(downsample_mode))

        if upsample_mode == 'upconv':
            upsample_block = B.upsample_upconv
        elif upsample_mode == 'pixelshuffle':
            upsample_block = B.upsample_pixelshuffle
        elif upsample_mode == 'convtranspose':
            upsample_block = B.upsample_convtranspose
        else:
            raise NotImplementedError('upsample mode [{}] is not found'.format(upsample_mode))

        res_mode = 'C' + act_mode + 'C'

        def res_blocks(ch):
            """Return a list of ``nb`` freshly constructed residual blocks of width ``ch``."""
            return [B.ResBlock(ch, ch, bias=bias, mode=res_mode) for _ in range(nb)]

        # Construction order and attribute names match the original definition.
        # This keeps RNG-driven parameter initialisation and state_dict keys
        # (and hence pretrained checkpoints) unchanged. Call arguments are
        # evaluated left to right, so residual blocks precede the downsampler
        # and follow the upsampler.
        self.m_head = B.conv(in_nc, nc[0], bias=bias, mode='C')

        self.m_down1 = B.sequential(*res_blocks(nc[0]), downsample_block(nc[0], nc[1], bias=bias, mode='2'))
        self.m_down2 = B.sequential(*res_blocks(nc[1]), downsample_block(nc[1], nc[2], bias=bias, mode='2'))
        self.m_down3 = B.sequential(*res_blocks(nc[2]), downsample_block(nc[2], nc[3], bias=bias, mode='2'))

        self.m_body = B.sequential(*res_blocks(nc[3]))

        self.m_up3 = B.sequential(upsample_block(nc[3], nc[2], bias=bias, mode='2'), *res_blocks(nc[2]))
        self.m_up2 = B.sequential(upsample_block(nc[2], nc[1], bias=bias, mode='2'), *res_blocks(nc[1]))
        self.m_up1 = B.sequential(upsample_block(nc[1], nc[0], bias=bias, mode='2'), *res_blocks(nc[0]))

        self.m_tail = B.conv(nc[0], out_nc, bias=bias, mode='C')

    def forward(self, x0):
        """Run the network and return a result with the same height and width as ``x0``.

        The last two dimensions of ``x0`` are treated as height and width.
        They are replicate-padded on the bottom and right up to the next
        multiple of ``_PAD_MULTIPLE``, so the down- and up-sampling stages
        line up. The output is then cropped back to the original size.
        """
        h, w = x0.shape[-2:]
        # -n % m is the distance from n up to the next multiple of m (0 when
        # already aligned). This is pure integer arithmetic, with no float
        # round-trip through numpy.
        pad_bottom = -h % self._PAD_MULTIPLE
        pad_right = -w % self._PAD_MULTIPLE
        # The functional form avoids building a new ReplicationPad2d module per call.
        x0 = nn.functional.pad(x0, (0, pad_right, 0, pad_bottom), mode='replicate')

        x1 = self.m_head(x0)
        x2 = self.m_down1(x1)
        x3 = self.m_down2(x2)
        x4 = self.m_down3(x3)
        x = self.m_body(x4)
        # Additive skips: each decoder stage receives its input plus the
        # encoder feature map from the same resolution level.
        x = self.m_up3(x + x4)
        x = self.m_up2(x + x3)
        x = self.m_up1(x + x2)
        x = self.m_tail(x + x1)

        # Drop the padded rows and columns.
        return x[..., :h, :w]
if __name__ == '__main__':
    # Smoke test: the output must keep the input's spatial size. The two
    # inputs cover a size that is already a multiple of 8 (256x256) and one
    # that goes through the pad-then-crop path in UNetRes.forward (250x261).
    # NOTE(review): the top-level `import models.basicblock` needs the
    # directory containing `models/` on sys.path. Running
    # `python -m models.network_unet` from that directory should work, but
    # running `python models/network_unet.py` directly may not (confirm).
    net = UNetRes()
    net.eval()
    with torch.no_grad():
        for shape in [(1, 3, 256, 256), (1, 3, 250, 261)]:
            x = torch.rand(*shape)
            y = net(x)
            print(y.size())
            assert y.shape == x.shape, (y.shape, x.shape)