Fix (from KAIR): solved inference upsampling issue
This commit is contained in:
parent
1ee3d61a2c
commit
24a673d849
|
@ -57,10 +57,14 @@ class UNetRes(nn.Module):
|
||||||
self.m_tail = B.conv(nc[0], out_nc, bias=bias, mode='C')
|
self.m_tail = B.conv(nc[0], out_nc, bias=bias, mode='C')
|
||||||
|
|
||||||
def forward(self, x0):
|
def forward(self, x0):
|
||||||
# h, w = x.size()[-2:]
|
|
||||||
# paddingBottom = int(np.ceil(h/8)*8-h)
|
# Resolve upsampling size issues with padding
|
||||||
# paddingRight = int(np.ceil(w/8)*8-w)
|
h, w = x0.size()[-2:]
|
||||||
# x = nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(x)
|
paddingBottom = int(np.ceil(h/8)*8-h)
|
||||||
|
paddingRight = int(np.ceil(w/8)*8-w)
|
||||||
|
x0 = nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(x0)
|
||||||
|
|
||||||
|
# Forward UNet
|
||||||
|
|
||||||
x1 = self.m_head(x0)
|
x1 = self.m_head(x0)
|
||||||
x2 = self.m_down1(x1)
|
x2 = self.m_down1(x1)
|
||||||
|
@ -71,7 +75,9 @@ class UNetRes(nn.Module):
|
||||||
x = self.m_up2(x+x3)
|
x = self.m_up2(x+x3)
|
||||||
x = self.m_up1(x+x2)
|
x = self.m_up1(x+x2)
|
||||||
x = self.m_tail(x+x1)
|
x = self.m_tail(x+x1)
|
||||||
# x = x[..., :h, :w]
|
|
||||||
|
# Crop result to original size
|
||||||
|
x = x[..., :h, :w]
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue