Spaces:
Runtime error
Runtime error
| """Discriminator architecture for ClimateGAN's GAN components (a and t) | |
| """ | |
| import functools | |
| import torch | |
| import torch.nn as nn | |
| from climategan.blocks import SpectralNorm | |
| from climategan.tutils import init_weights | |
| # from torch.optim import lr_scheduler | |
| # mainly from https://github.com/sangwoomo/instagan/blob/master/models/networks.py | |
| def create_discriminator(opts, device, no_init=False, verbose=0): | |
| disc = OmniDiscriminator(opts) | |
| if no_init: | |
| return disc | |
| for task, model in disc.items(): | |
| if isinstance(model, nn.ModuleDict): | |
| for domain, domain_model in model.items(): | |
| init_weights( | |
| domain_model, | |
| init_type=opts.dis[task].init_type, | |
| init_gain=opts.dis[task].init_gain, | |
| verbose=verbose, | |
| caller=f"create_discriminator {task} {domain}", | |
| ) | |
| else: | |
| init_weights( | |
| model, | |
| init_type=opts.dis[task].init_type, | |
| init_gain=opts.dis[task].init_gain, | |
| verbose=verbose, | |
| caller=f"create_discriminator {task}", | |
| ) | |
| return disc.to(device) | |
| def define_D( | |
| input_nc, | |
| ndf, | |
| n_layers=3, | |
| norm="batch", | |
| use_sigmoid=False, | |
| get_intermediate_features=False, | |
| num_D=1, | |
| ): | |
| norm_layer = get_norm_layer(norm_type=norm) | |
| net = MultiscaleDiscriminator( | |
| input_nc, | |
| ndf, | |
| n_layers=n_layers, | |
| norm_layer=norm_layer, | |
| use_sigmoid=use_sigmoid, | |
| get_intermediate_features=get_intermediate_features, | |
| num_D=num_D, | |
| ) | |
| return net | |
| def get_norm_layer(norm_type="instance"): | |
| if not norm_type: | |
| print("norm_type is {}, defaulting to instance") | |
| norm_type = "instance" | |
| if norm_type == "batch": | |
| norm_layer = functools.partial(nn.BatchNorm2d, affine=True) | |
| elif norm_type == "instance": | |
| norm_layer = functools.partial( | |
| nn.InstanceNorm2d, affine=False, track_running_stats=False | |
| ) | |
| elif norm_type == "none": | |
| norm_layer = None | |
| else: | |
| raise NotImplementedError("normalization layer [%s] is not found" % norm_type) | |
| return norm_layer | |
| # Defines the PatchGAN discriminator with the specified arguments. | |
| class NLayerDiscriminator(nn.Module): | |
| def __init__( | |
| self, | |
| input_nc=3, | |
| ndf=64, | |
| n_layers=3, | |
| norm_layer=nn.BatchNorm2d, | |
| use_sigmoid=False, | |
| get_intermediate_features=True, | |
| ): | |
| super(NLayerDiscriminator, self).__init__() | |
| if type(norm_layer) == functools.partial: | |
| use_bias = norm_layer.func == nn.InstanceNorm2d | |
| else: | |
| use_bias = norm_layer == nn.InstanceNorm2d | |
| self.get_intermediate_features = get_intermediate_features | |
| kw = 4 | |
| padw = 1 | |
| sequence = [ | |
| [ | |
| # Use spectral normalization | |
| SpectralNorm( | |
| nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw) | |
| ), | |
| nn.LeakyReLU(0.2, True), | |
| ] | |
| ] | |
| nf_mult = 1 | |
| nf_mult_prev = 1 | |
| for n in range(1, n_layers): | |
| nf_mult_prev = nf_mult | |
| nf_mult = min(2 ** n, 8) | |
| sequence += [ | |
| [ | |
| # Use spectral normalization | |
| SpectralNorm( # TODO replace with Conv2dBlock | |
| nn.Conv2d( | |
| ndf * nf_mult_prev, | |
| ndf * nf_mult, | |
| kernel_size=kw, | |
| stride=2, | |
| padding=padw, | |
| bias=use_bias, | |
| ) | |
| ), | |
| norm_layer(ndf * nf_mult), | |
| nn.LeakyReLU(0.2, True), | |
| ] | |
| ] | |
| nf_mult_prev = nf_mult | |
| nf_mult = min(2 ** n_layers, 8) | |
| sequence += [ | |
| [ | |
| # Use spectral normalization | |
| SpectralNorm( | |
| nn.Conv2d( | |
| ndf * nf_mult_prev, | |
| ndf * nf_mult, | |
| kernel_size=kw, | |
| stride=1, | |
| padding=padw, | |
| bias=use_bias, | |
| ) | |
| ), | |
| norm_layer(ndf * nf_mult), | |
| nn.LeakyReLU(0.2, True), | |
| ] | |
| ] | |
| # Use spectral normalization | |
| sequence += [ | |
| [ | |
| SpectralNorm( | |
| nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw) | |
| ) | |
| ] | |
| ] | |
| if use_sigmoid: | |
| sequence += [[nn.Sigmoid()]] | |
| # We divide the layers into groups to extract intermediate layer outputs | |
| for n in range(len(sequence)): | |
| self.add_module("model" + str(n), nn.Sequential(*sequence[n])) | |
| # self.model = nn.Sequential(*sequence) | |
| def forward(self, input): | |
| results = [input] | |
| for submodel in self.children(): | |
| intermediate_output = submodel(results[-1]) | |
| results.append(intermediate_output) | |
| get_intermediate_features = self.get_intermediate_features | |
| if get_intermediate_features: | |
| return results[1:] | |
| else: | |
| return results[-1] | |
| # def forward(self, input): | |
| # return self.model(input) | |
| # Source: https://github.com/NVIDIA/pix2pixHD | |
| class MultiscaleDiscriminator(nn.Module): | |
| def __init__( | |
| self, | |
| input_nc=3, | |
| ndf=64, | |
| n_layers=3, | |
| norm_layer=nn.BatchNorm2d, | |
| use_sigmoid=False, | |
| get_intermediate_features=True, | |
| num_D=3, | |
| ): | |
| super(MultiscaleDiscriminator, self).__init__() | |
| # self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, | |
| # use_sigmoid=False, num_D=3, getIntermFeat=False | |
| self.n_layers = n_layers | |
| self.ndf = ndf | |
| self.norm_layer = norm_layer | |
| self.use_sigmoid = use_sigmoid | |
| self.get_intermediate_features = get_intermediate_features | |
| self.num_D = num_D | |
| for i in range(self.num_D): | |
| netD = NLayerDiscriminator( | |
| input_nc=input_nc, | |
| ndf=self.ndf, | |
| n_layers=self.n_layers, | |
| norm_layer=self.norm_layer, | |
| use_sigmoid=self.use_sigmoid, | |
| get_intermediate_features=self.get_intermediate_features, | |
| ) | |
| self.add_module("discriminator_%d" % i, netD) | |
| self.downsample = nn.AvgPool2d( | |
| 3, stride=2, padding=[1, 1], count_include_pad=False | |
| ) | |
| def forward(self, input): | |
| result = [] | |
| get_intermediate_features = self.get_intermediate_features | |
| for name, D in self.named_children(): | |
| if "discriminator" not in name: | |
| continue | |
| out = D(input) | |
| if not get_intermediate_features: | |
| out = [out] | |
| result.append(out) | |
| input = self.downsample(input) | |
| return result | |
| class OmniDiscriminator(nn.ModuleDict): | |
| def __init__(self, opts): | |
| super().__init__() | |
| if "p" in opts.tasks: | |
| if opts.dis.p.use_local_discriminator: | |
| self["p"] = nn.ModuleDict( | |
| { | |
| "global": define_D( | |
| input_nc=3, | |
| ndf=opts.dis.p.ndf, | |
| n_layers=opts.dis.p.n_layers, | |
| norm=opts.dis.p.norm, | |
| use_sigmoid=opts.dis.p.use_sigmoid, | |
| get_intermediate_features=opts.dis.p.get_intermediate_features, # noqa: E501 | |
| num_D=opts.dis.p.num_D, | |
| ), | |
| "local": define_D( | |
| input_nc=3, | |
| ndf=opts.dis.p.ndf, | |
| n_layers=opts.dis.p.n_layers, | |
| norm=opts.dis.p.norm, | |
| use_sigmoid=opts.dis.p.use_sigmoid, | |
| get_intermediate_features=opts.dis.p.get_intermediate_features, # noqa: E501 | |
| num_D=opts.dis.p.num_D, | |
| ), | |
| } | |
| ) | |
| else: | |
| self["p"] = define_D( | |
| input_nc=4, # image + mask | |
| ndf=opts.dis.p.ndf, | |
| n_layers=opts.dis.p.n_layers, | |
| norm=opts.dis.p.norm, | |
| use_sigmoid=opts.dis.p.use_sigmoid, | |
| get_intermediate_features=opts.dis.p.get_intermediate_features, | |
| num_D=opts.dis.p.num_D, | |
| ) | |
| if "m" in opts.tasks: | |
| if opts.gen.m.use_advent: | |
| if opts.dis.m.architecture == "base": | |
| if opts.dis.m.gan_type == "WGAN_norm": | |
| self["m"] = nn.ModuleDict( | |
| { | |
| "Advent": get_fc_discriminator( | |
| num_classes=2, use_norm=True | |
| ) | |
| } | |
| ) | |
| else: | |
| self["m"] = nn.ModuleDict( | |
| { | |
| "Advent": get_fc_discriminator( | |
| num_classes=2, use_norm=False | |
| ) | |
| } | |
| ) | |
| elif opts.dis.m.architecture == "OmniDiscriminator": | |
| self["m"] = nn.ModuleDict( | |
| { | |
| "Advent": define_D( | |
| input_nc=2, | |
| ndf=opts.dis.m.ndf, | |
| n_layers=opts.dis.m.n_layers, | |
| norm=opts.dis.m.norm, | |
| use_sigmoid=opts.dis.m.use_sigmoid, | |
| get_intermediate_features=opts.dis.m.get_intermediate_features, # noqa: E501 | |
| num_D=opts.dis.m.num_D, | |
| ) | |
| } | |
| ) | |
| else: | |
| raise Exception("This Discriminator is currently not supported!") | |
| if "s" in opts.tasks: | |
| if opts.gen.s.use_advent: | |
| if opts.dis.s.gan_type == "WGAN_norm": | |
| self["s"] = nn.ModuleDict( | |
| {"Advent": get_fc_discriminator(num_classes=11, use_norm=True)} | |
| ) | |
| else: | |
| self["s"] = nn.ModuleDict( | |
| {"Advent": get_fc_discriminator(num_classes=11, use_norm=False)} | |
| ) | |
| def get_fc_discriminator(num_classes=2, ndf=64, use_norm=False): | |
| if use_norm: | |
| return torch.nn.Sequential( | |
| SpectralNorm( | |
| torch.nn.Conv2d(num_classes, ndf, kernel_size=4, stride=2, padding=1) | |
| ), | |
| torch.nn.LeakyReLU(negative_slope=0.2, inplace=True), | |
| SpectralNorm( | |
| torch.nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1) | |
| ), | |
| torch.nn.LeakyReLU(negative_slope=0.2, inplace=True), | |
| SpectralNorm( | |
| torch.nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1) | |
| ), | |
| torch.nn.LeakyReLU(negative_slope=0.2, inplace=True), | |
| SpectralNorm( | |
| torch.nn.Conv2d(ndf * 4, ndf * 8, kernel_size=4, stride=2, padding=1) | |
| ), | |
| torch.nn.LeakyReLU(negative_slope=0.2, inplace=True), | |
| SpectralNorm( | |
| torch.nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=2, padding=1) | |
| ), | |
| ) | |
| else: | |
| return torch.nn.Sequential( | |
| torch.nn.Conv2d(num_classes, ndf, kernel_size=4, stride=2, padding=1), | |
| torch.nn.LeakyReLU(negative_slope=0.2, inplace=True), | |
| torch.nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1), | |
| torch.nn.LeakyReLU(negative_slope=0.2, inplace=True), | |
| torch.nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1), | |
| torch.nn.LeakyReLU(negative_slope=0.2, inplace=True), | |
| torch.nn.Conv2d(ndf * 4, ndf * 8, kernel_size=4, stride=2, padding=1), | |
| torch.nn.LeakyReLU(negative_slope=0.2, inplace=True), | |
| torch.nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=2, padding=1), | |
| ) | |