Fix (from KAIR): solved inference upsampling issue

This commit is contained in:
Emilio Martinez 2023-06-19 18:15:17 -03:00
parent 1ee3d61a2c
commit 24a673d849
1 changed files with 11 additions and 5 deletions

View File

@ -57,10 +57,14 @@ class UNetRes(nn.Module):
self.m_tail = B.conv(nc[0], out_nc, bias=bias, mode='C')
def forward(self, x0):
# h, w = x.size()[-2:]
# paddingBottom = int(np.ceil(h/8)*8-h)
# paddingRight = int(np.ceil(w/8)*8-w)
# x = nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(x)
# Resolve upsampling size issues with padding
h, w = x0.size()[-2:]
paddingBottom = int(np.ceil(h/8)*8-h)
paddingRight = int(np.ceil(w/8)*8-w)
x0 = nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(x0)
# Forward UNet
x1 = self.m_head(x0)
x2 = self.m_down1(x1)
@ -71,7 +75,9 @@ class UNetRes(nn.Module):
x = self.m_up2(x+x3)
x = self.m_up1(x+x2)
x = self.m_tail(x+x1)
# x = x[..., :h, :w]
# Crop result to original size
x = x[..., :h, :w]
return x