339 lines
13 KiB
Python
339 lines
13 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
from torch.nn import functional as F
|
|
from torch.nn.utils import spectral_norm
|
|
import models.basicblock as B
|
|
import functools
|
|
import numpy as np
|
|
|
|
|
|
"""
|
|
# --------------------------------------------
# Discriminator_PatchGAN
# Discriminator_UNet
# Discriminator_VGG_96
# Discriminator_VGG_128
# Discriminator_VGG_192
# Discriminator_VGG_128_SN
# --------------------------------------------
|
|
"""
|
|
|
|
|
|
# --------------------------------------------
|
|
# PatchGAN discriminator
|
|
# If n_layers = 3, then the receptive field is 70x70
|
|
# --------------------------------------------
|
|
class Discriminator_PatchGAN(nn.Module):
    '''PatchGAN discriminator producing a map of per-patch scores.

    With n_layers = 3 the receptive field of each score is 70x70.
    '''

    def __init__(self, input_nc=3, ndf=64, n_layers=3, norm_type='spectral'):
        '''Assemble the convolutional stack.

        Args:
            input_nc: number of input channels
            ndf: base channel number
            n_layers: number of conv layers with stride 2
            norm_type: 'batch', 'instance', 'spectral', 'batchspectral',
                'instancespectral'
        Returns:
            tensor: score (from forward)
        '''
        super(Discriminator_PatchGAN, self).__init__()
        self.n_layers = n_layers
        make_norm = self.get_norm_layer(norm_type=norm_type)

        kernel = 4
        pad = int(np.ceil((kernel - 1.0) / 2))

        def conv(c_in, c_out, stride):
            # Every conv shares kernel/padding; spectral norm is applied
            # only when norm_type asks for it.
            layer = nn.Conv2d(c_in, c_out, kernel_size=kernel, stride=stride, padding=pad)
            return self.use_spectral_norm(layer, norm_type)

        def act():
            return nn.LeakyReLU(0.2, True)

        # Stem: strided conv + activation, no normalisation layer.
        stages = [[conv(input_nc, ndf, 2), act()]]

        # (n_layers - 1) stride-2 stages followed by one stride-1 stage;
        # the channel count doubles at each stage and is capped at 512.
        width = ndf
        for stride in [2] * (n_layers - 1) + [1]:
            prev, width = width, min(width * 2, 512)
            stages.append([conv(prev, width, stride), make_norm(width), act()])

        # Head: project to a single-channel score map.
        stages.append([conv(width, 1, 1)])

        # Child names ('child0', 'child1', ...) form part of the state_dict keys.
        self.model = nn.Sequential()
        for idx, layers in enumerate(stages):
            self.model.add_module('child' + str(idx), nn.Sequential(*layers))

        self.model.apply(self.weights_init)

    def use_spectral_norm(self, module, norm_type='spectral'):
        '''Wrap module in spectral_norm when norm_type contains 'spectral'.'''
        return spectral_norm(module) if 'spectral' in norm_type else module

    def get_norm_layer(self, norm_type='instance'):
        '''Return a factory for the normalisation layer selected by norm_type.

        'batch' takes precedence over 'instance'; any other value yields
        nn.Identity.
        '''
        if 'batch' in norm_type:
            return functools.partial(nn.BatchNorm2d, affine=True)
        if 'instance' in norm_type:
            return functools.partial(nn.InstanceNorm2d, affine=False)
        return functools.partial(nn.Identity)

    def weights_init(self, m):
        '''Initialise conv weights ~ N(0, 0.02) and BatchNorm2d weights ~ N(1, 0.02) with zero bias.'''
        name = type(m).__name__
        if 'Conv' in name:
            m.weight.data.normal_(0.0, 0.02)
        elif 'BatchNorm2d' in name:
            m.weight.data.normal_(1.0, 0.02)
            m.bias.data.fill_(0)

    def forward(self, x):
        '''Run x through the sequential stack and return the score map.'''
        return self.model(x)
|
|
|
|
|
|
class Discriminator_UNet(nn.Module):
    """Defines a U-Net discriminator with spectral normalization (SN).

    Three stride-2 convolutions form the encoder, three bilinear
    upsample + convolution steps form the decoder, and encoder features
    are added back as skip connections. The first and last convolutions
    are plain ``nn.Conv2d``; every other convolution is wrapped in
    ``spectral_norm``. The output is a single-channel map with the same
    spatial size as the input.

    Args:
        input_nc: number of input channels
        ndf: base channel number
    """

    def __init__(self, input_nc=3, ndf=64):
        super(Discriminator_UNet, self).__init__()
        norm = spectral_norm

        self.conv0 = nn.Conv2d(input_nc, ndf, kernel_size=3, stride=1, padding=1)

        # downsample
        self.conv1 = norm(nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False))
        self.conv2 = norm(nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False))
        self.conv3 = norm(nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False))
        # upsample
        self.conv4 = norm(nn.Conv2d(ndf * 8, ndf * 4, 3, 1, 1, bias=False))
        self.conv5 = norm(nn.Conv2d(ndf * 4, ndf * 2, 3, 1, 1, bias=False))
        self.conv6 = norm(nn.Conv2d(ndf * 2, ndf, 3, 1, 1, bias=False))

        # extra
        self.conv7 = norm(nn.Conv2d(ndf, ndf, 3, 1, 1, bias=False))
        self.conv8 = norm(nn.Conv2d(ndf, ndf, 3, 1, 1, bias=False))

        self.conv9 = nn.Conv2d(ndf, 1, 3, 1, 1)
        print('using the UNet discriminator')

    def forward(self, x):
        """Return a score map of shape (N, 1, H, W) for input (N, input_nc, H, W).

        Each decoder step upsamples to the spatial size of the encoder
        feature it is about to be added to. When H and W are multiples of
        8 this is exactly a 2x upsample; for other sizes it keeps the skip
        additions shape-compatible instead of raising a size mismatch.
        """
        x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True)
        x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True)
        x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True)
        x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True)

        # upsample
        x3 = F.interpolate(x3, size=x2.shape[-2:], mode='bilinear', align_corners=False)
        x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True)

        x4 = x4 + x2
        x4 = F.interpolate(x4, size=x1.shape[-2:], mode='bilinear', align_corners=False)
        x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True)

        x5 = x5 + x1
        x5 = F.interpolate(x5, size=x0.shape[-2:], mode='bilinear', align_corners=False)
        x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True)

        x6 = x6 + x0

        # extra
        out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True)
        out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True)
        out = self.conv9(out)

        return out
|
|
|
|
|
|
# --------------------------------------------
|
|
# VGG style Discriminator with 96x96 input
|
|
# --------------------------------------------
|
|
class Discriminator_VGG_96(nn.Module):
    """VGG style discriminator intended for 96x96 inputs.

    Ten blocks built with ``B.conv`` alternate between a kernel-3 block
    and a kernel-4, stride-2 block, followed by a two-layer fully
    connected classifier that outputs one score per sample.

    Args:
        in_nc: number of input channels
        base_nc: base channel number; the last feature block has
            ``base_nc * 8`` channels
        ac_type: suffix appended to 'C' to form the ``mode`` string passed
            to ``B.conv`` (its meaning is defined in models.basicblock)
    """

    def __init__(self, in_nc=3, base_nc=64, ac_type='BL'):
        super(Discriminator_VGG_96, self).__init__()
        # features
        # hxw, c
        # 96, 64
        conv0 = B.conv(in_nc, base_nc, kernel_size=3, mode='C')
        conv1 = B.conv(base_nc, base_nc, kernel_size=4, stride=2, mode='C'+ac_type)
        # 48, 64
        conv2 = B.conv(base_nc, base_nc*2, kernel_size=3, stride=1, mode='C'+ac_type)
        conv3 = B.conv(base_nc*2, base_nc*2, kernel_size=4, stride=2, mode='C'+ac_type)
        # 24, 128
        conv4 = B.conv(base_nc*2, base_nc*4, kernel_size=3, stride=1, mode='C'+ac_type)
        conv5 = B.conv(base_nc*4, base_nc*4, kernel_size=4, stride=2, mode='C'+ac_type)
        # 12, 256
        conv6 = B.conv(base_nc*4, base_nc*8, kernel_size=3, stride=1, mode='C'+ac_type)
        conv7 = B.conv(base_nc*8, base_nc*8, kernel_size=4, stride=2, mode='C'+ac_type)
        # 6, 512
        conv8 = B.conv(base_nc*8, base_nc*8, kernel_size=3, stride=1, mode='C'+ac_type)
        conv9 = B.conv(base_nc*8, base_nc*8, kernel_size=4, stride=2, mode='C'+ac_type)
        # 3, 512
        self.features = B.sequential(conv0, conv1, conv2, conv3, conv4,
                                     conv5, conv6, conv7, conv8, conv9)

        # classifier
        # The flattened feature size follows base_nc (base_nc * 8 channels
        # on a 3x3 map) so that non-default widths stay consistent with
        # the feature extractor; with base_nc=64 this is 512 * 3 * 3.
        self.classifier = nn.Sequential(
            nn.Linear(base_nc * 8 * 3 * 3, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    def forward(self, x):
        """Return a (N, 1) score tensor for a batch of images."""
        x = self.features(x)
        # Flatten everything except the batch dimension.
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
|
|
|
|
|
|
# --------------------------------------------
|
|
# VGG style Discriminator with 128x128 input
|
|
# --------------------------------------------
|
|
class Discriminator_VGG_128(nn.Module):
    """VGG style discriminator intended for 128x128 inputs.

    Ten blocks built with ``B.conv`` alternate between a kernel-3 block
    and a kernel-4, stride-2 block, followed by a two-layer fully
    connected classifier that outputs one score per sample.

    Args:
        in_nc: number of input channels
        base_nc: base channel number; the last feature block has
            ``base_nc * 8`` channels
        ac_type: suffix appended to 'C' to form the ``mode`` string passed
            to ``B.conv`` (its meaning is defined in models.basicblock)
    """

    def __init__(self, in_nc=3, base_nc=64, ac_type='BL'):
        super(Discriminator_VGG_128, self).__init__()
        # features
        # hxw, c
        # 128, 64
        conv0 = B.conv(in_nc, base_nc, kernel_size=3, mode='C')
        conv1 = B.conv(base_nc, base_nc, kernel_size=4, stride=2, mode='C'+ac_type)
        # 64, 64
        conv2 = B.conv(base_nc, base_nc*2, kernel_size=3, stride=1, mode='C'+ac_type)
        conv3 = B.conv(base_nc*2, base_nc*2, kernel_size=4, stride=2, mode='C'+ac_type)
        # 32, 128
        conv4 = B.conv(base_nc*2, base_nc*4, kernel_size=3, stride=1, mode='C'+ac_type)
        conv5 = B.conv(base_nc*4, base_nc*4, kernel_size=4, stride=2, mode='C'+ac_type)
        # 16, 256
        conv6 = B.conv(base_nc*4, base_nc*8, kernel_size=3, stride=1, mode='C'+ac_type)
        conv7 = B.conv(base_nc*8, base_nc*8, kernel_size=4, stride=2, mode='C'+ac_type)
        # 8, 512
        conv8 = B.conv(base_nc*8, base_nc*8, kernel_size=3, stride=1, mode='C'+ac_type)
        conv9 = B.conv(base_nc*8, base_nc*8, kernel_size=4, stride=2, mode='C'+ac_type)
        # 4, 512
        self.features = B.sequential(conv0, conv1, conv2, conv3, conv4,
                                     conv5, conv6, conv7, conv8, conv9)

        # classifier
        # The flattened feature size follows base_nc (base_nc * 8 channels
        # on a 4x4 map) so that non-default widths stay consistent with
        # the feature extractor; with base_nc=64 this is 512 * 4 * 4.
        self.classifier = nn.Sequential(nn.Linear(base_nc * 8 * 4 * 4, 100),
                                        nn.LeakyReLU(0.2, True),
                                        nn.Linear(100, 1))

    def forward(self, x):
        """Return a (N, 1) score tensor for a batch of images."""
        x = self.features(x)
        # Flatten everything except the batch dimension.
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
|
|
|
|
|
|
# --------------------------------------------
|
|
# VGG style Discriminator with 192x192 input
|
|
# --------------------------------------------
|
|
class Discriminator_VGG_192(nn.Module):
    """VGG style discriminator intended for 192x192 inputs.

    Twelve blocks built with ``B.conv`` alternate between a kernel-3 block
    and a kernel-4, stride-2 block, followed by a two-layer fully
    connected classifier that outputs one score per sample.

    Args:
        in_nc: number of input channels
        base_nc: base channel number; the last feature block has
            ``base_nc * 8`` channels
        ac_type: suffix appended to 'C' to form the ``mode`` string passed
            to ``B.conv`` (its meaning is defined in models.basicblock)
    """

    def __init__(self, in_nc=3, base_nc=64, ac_type='BL'):
        super(Discriminator_VGG_192, self).__init__()
        # features
        # hxw, c
        # 192, 64
        conv0 = B.conv(in_nc, base_nc, kernel_size=3, mode='C')
        conv1 = B.conv(base_nc, base_nc, kernel_size=4, stride=2, mode='C'+ac_type)
        # 96, 64
        conv2 = B.conv(base_nc, base_nc*2, kernel_size=3, stride=1, mode='C'+ac_type)
        conv3 = B.conv(base_nc*2, base_nc*2, kernel_size=4, stride=2, mode='C'+ac_type)
        # 48, 128
        conv4 = B.conv(base_nc*2, base_nc*4, kernel_size=3, stride=1, mode='C'+ac_type)
        conv5 = B.conv(base_nc*4, base_nc*4, kernel_size=4, stride=2, mode='C'+ac_type)
        # 24, 256
        conv6 = B.conv(base_nc*4, base_nc*8, kernel_size=3, stride=1, mode='C'+ac_type)
        conv7 = B.conv(base_nc*8, base_nc*8, kernel_size=4, stride=2, mode='C'+ac_type)
        # 12, 512
        conv8 = B.conv(base_nc*8, base_nc*8, kernel_size=3, stride=1, mode='C'+ac_type)
        conv9 = B.conv(base_nc*8, base_nc*8, kernel_size=4, stride=2, mode='C'+ac_type)
        # 6, 512
        conv10 = B.conv(base_nc*8, base_nc*8, kernel_size=3, stride=1, mode='C'+ac_type)
        conv11 = B.conv(base_nc*8, base_nc*8, kernel_size=4, stride=2, mode='C'+ac_type)
        # 3, 512
        self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5,
                                     conv6, conv7, conv8, conv9, conv10, conv11)

        # classifier
        # The flattened feature size follows base_nc (base_nc * 8 channels
        # on a 3x3 map) so that non-default widths stay consistent with
        # the feature extractor; with base_nc=64 this is 512 * 3 * 3.
        self.classifier = nn.Sequential(nn.Linear(base_nc * 8 * 3 * 3, 100),
                                        nn.LeakyReLU(0.2, True),
                                        nn.Linear(100, 1))

    def forward(self, x):
        """Return a (N, 1) score tensor for a batch of images."""
        x = self.features(x)
        # Flatten everything except the batch dimension.
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
|
|
|
|
|
|
# --------------------------------------------
|
|
# SN-VGG style Discriminator with 128x128 input
|
|
# --------------------------------------------
|
|
class Discriminator_VGG_128_SN(nn.Module):
    """SN-VGG style discriminator for 3-channel 128x128 inputs.

    Every convolution and linear layer is wrapped in ``spectral_norm``;
    a shared LeakyReLU(0.2) follows each layer except the last.
    """

    def __init__(self):
        super(Discriminator_VGG_128_SN, self).__init__()
        self.lrelu = nn.LeakyReLU(0.2, True)

        # features
        # One (in, out) pair per resolution stage. Each stage is a 3x3
        # stride-1 conv followed by a 4x4 stride-2 conv that halves the
        # spatial size: 128 -> 64 -> 32 -> 16 -> 8 -> 4.
        stage_channels = [(3, 64), (64, 128), (128, 256), (256, 512), (512, 512)]
        for stage, (c_in, c_out) in enumerate(stage_channels):
            keep = spectral_norm(nn.Conv2d(c_in, c_out, 3, 1, 1))
            halve = spectral_norm(nn.Conv2d(c_out, c_out, 4, 2, 1))
            # Attribute names conv0..conv9 are part of the state_dict keys.
            setattr(self, 'conv{}'.format(2 * stage), keep)
            setattr(self, 'conv{}'.format(2 * stage + 1), halve)

        # classifier
        self.linear0 = spectral_norm(nn.Linear(512 * 4 * 4, 100))
        self.linear1 = spectral_norm(nn.Linear(100, 1))

    def forward(self, x):
        """Return a (N, 1) score tensor for a batch of images."""
        for idx in range(10):
            layer = getattr(self, 'conv{}'.format(idx))
            x = self.lrelu(layer(x))
        # Flatten everything except the batch dimension.
        flat = x.view(x.size(0), -1)
        hidden = self.lrelu(self.linear0(flat))
        return self.linear1(hidden)
|
|
|
|
|
|
if __name__ == '__main__':

    # Smoke test: push one random image of the matching size through each
    # fixed-input discriminator and print the output size.
    checks = (
        (96, Discriminator_VGG_96),
        (128, Discriminator_VGG_128),
        (192, Discriminator_VGG_192),
        (128, Discriminator_VGG_128_SN),
    )
    for side, build in checks:
        x = torch.rand(1, 3, side, side)
        net = build()
        net.eval()
        with torch.no_grad():
            y = net(x)
        print(y.size())
|
|
|
|
# run models/network_discriminator.py
|