import torch
import torch.nn as nn
import torch.nn.functional as F

from climategan.blocks import (
    BaseDecoder,
    Conv2dBlock,
    InterpolateNearest2d,
    SPADEResnetBlock,
)


def create_mask_decoder(opts, no_init=False, verbose=0):
    if opts.gen.m.use_spade:
        if verbose > 0:
            print(" - Add Spade Mask Decoder")
        # SPADE conditioning relies on depth and/or segmentation predictions
        assert "d" in opts.tasks or "s" in opts.tasks
        return MaskSpadeDecoder(opts)
    else:
        if verbose > 0:
            print(" - Add Base Mask Decoder")
        return MaskBaseDecoder(opts)
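

# --- Illustrative usage (sketch, not part of the original module) ------------
# Assumes `opts` is an attribute-style nested dict (e.g. addict.Dict), like the
# config ClimateGAN builds from its YAML defaults. The field names below are
# the ones read in this file; the *values* are illustrative placeholders and
# the project's YAML config remains the authoritative source.
def _example_build_base_mask_decoder():
    from addict import Dict  # any attribute-access dict works here

    opts = Dict()
    opts.tasks = ["m", "d"]
    opts.gen.m.use_spade = False  # dispatch to MaskBaseDecoder
    opts.gen.encoder.architecture = "deeplabv3"
    opts.gen.deeplabv3.backbone = "resnet"
    opts.gen.m.use_low_level_feats = True
    opts.gen.m.use_dada = True
    # BaseDecoder hyper-parameters (placeholder values)
    opts.gen.m.n_upsample = 2
    opts.gen.m.n_res = 2
    opts.gen.m.proj_dim = 32
    opts.gen.m.output_dim = 1
    opts.gen.m.norm = "batch"
    opts.gen.m.activ = "lrelu"
    opts.gen.m.pad_type = "reflect"
    return create_mask_decoder(opts, verbose=1)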


class MaskBaseDecoder(BaseDecoder):
    def __init__(self, opts):
        low_level_feats_dim = -1
        use_v3 = opts.gen.encoder.architecture == "deeplabv3"
        use_mobile_net = opts.gen.deeplabv3.backbone == "mobilenet"
        use_low = opts.gen.m.use_low_level_feats
        use_dada = ("d" in opts.tasks) and opts.gen.m.use_dada

        # Bottleneck (and optional low-level skip) channel dims depend on the
        # encoder backbone
        if use_v3 and use_mobile_net:
            input_dim = 320
            if use_low:
                low_level_feats_dim = 24
        elif use_v3:
            input_dim = 2048
            if use_low:
                low_level_feats_dim = 256
        else:
            input_dim = 2048

        super().__init__(
            n_upsample=opts.gen.m.n_upsample,
            n_res=opts.gen.m.n_res,
            input_dim=input_dim,
            proj_dim=opts.gen.m.proj_dim,
            output_dim=opts.gen.m.output_dim,
            norm=opts.gen.m.norm,
            activ=opts.gen.m.activ,
            pad_type=opts.gen.m.pad_type,
            output_activ="none",
            low_level_feats_dim=low_level_feats_dim,
            use_dada=use_dada,
        )
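

# --- Illustration (sketch, not part of the original module) ------------------
# The backbone-dependent channel dimensions chosen in MaskBaseDecoder.__init__,
# restated as plain data. low_level_feats_dim only applies when
# opts.gen.m.use_low_level_feats is True; -1 means no low-level skip features.
BASE_DECODER_DIMS_EXAMPLE = {
    # (encoder architecture, backbone): (input_dim, low_level_feats_dim)
    ("deeplabv3", "mobilenet"): (320, 24),
    ("deeplabv3", "other backbones"): (2048, 256),
    ("other encoders", "any"): (2048, -1),
}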


class MaskSpadeDecoder(nn.Module):
    def __init__(self, opts):
        """Create a SPADE-based mask decoder which forwards z conditioned on the
        tensor `cond` (in the original SPADE paper, conditioning is on a semantic
        map only). Each of the `num_layers` SPADEResnetBlocks (SRB) halves the
        channel dimension and is followed by a 2x nearest-neighbor upsampling,
        so the final 1-channel conv operates on:
            final_nc = latent_dim / 2 ** num_layers
        e.g. latent_dim = 256 and num_layers = 3 gives 256 -> 128 -> 64 -> 32
        channels and an 8x spatial upsampling of z.

        Options read from opts.gen.m.spade:
            latent_dim (int): number of channels expected in z
            cond_nc (int): conditioning tensor's expected number of channels
            num_layers (int): number of (SRB + upsampling) stages
            spade_use_spectral_norm (bool): use spectral normalization?
            spade_param_free_norm (str): norm to use before SPADE de-normalization
            activations.all_lrelu (bool): use LeakyReLU inside the SPADE blocks
        """
        super().__init__()
        self.opts = opts
        latent_dim = opts.gen.m.spade.latent_dim
        cond_nc = opts.gen.m.spade.cond_nc
        spade_use_spectral_norm = opts.gen.m.spade.spade_use_spectral_norm
        spade_param_free_norm = opts.gen.m.spade.spade_param_free_norm
        if self.opts.gen.m.spade.activations.all_lrelu:
            spade_activation = "lrelu"
        else:
            spade_activation = None
        spade_kernel_size = 3
        self.num_layers = opts.gen.m.spade.num_layers
        self.z_nc = latent_dim

        # Input convolutions depend on the encoder: DeepLab v3 encoders provide
        # high- and low-level feature maps which are merged, DeepLab v2 provides
        # a single bottleneck feature map.
        if (
            opts.gen.encoder.architecture == "deeplabv3"
            and opts.gen.deeplabv3.backbone == "mobilenet"
        ):
            self.input_dim = [320, 24]
            self.low_level_conv = Conv2dBlock(
                self.input_dim[1],
                self.input_dim[0],
                3,
                padding=1,
                activation="lrelu",
                pad_type="reflect",
                norm="spectral_batch",
            )
            self.merge_feats_conv = Conv2dBlock(
                self.input_dim[0] * 2,
                self.z_nc,
                3,
                padding=1,
                activation="lrelu",
                pad_type="reflect",
                norm="spectral_batch",
            )
        elif (
            opts.gen.encoder.architecture == "deeplabv3"
            and opts.gen.deeplabv3.backbone == "resnet"
        ):
            self.input_dim = [2048, 256]
            if self.opts.gen.m.use_proj:
                proj_dim = self.opts.gen.m.proj_dim
                self.low_level_conv = Conv2dBlock(
                    self.input_dim[1],
                    proj_dim,
                    3,
                    padding=1,
                    activation="lrelu",
                    pad_type="reflect",
                    norm="spectral_batch",
                )
                self.high_level_conv = Conv2dBlock(
                    self.input_dim[0],
                    proj_dim,
                    3,
                    padding=1,
                    activation="lrelu",
                    pad_type="reflect",
                    norm="spectral_batch",
                )
                self.merge_feats_conv = Conv2dBlock(
                    proj_dim * 2,
                    self.z_nc,
                    3,
                    padding=1,
                    activation="lrelu",
                    pad_type="reflect",
                    norm="spectral_batch",
                )
            else:
                self.low_level_conv = Conv2dBlock(
                    self.input_dim[1],
                    self.input_dim[0],
                    3,
                    padding=1,
                    activation="lrelu",
                    pad_type="reflect",
                    norm="spectral_batch",
                )
                self.merge_feats_conv = Conv2dBlock(
                    self.input_dim[0] * 2,
                    self.z_nc,
                    3,
                    padding=1,
                    activation="lrelu",
                    pad_type="reflect",
                    norm="spectral_batch",
                )
        elif opts.gen.encoder.architecture == "deeplabv2":
            self.input_dim = 2048
            self.fc_conv = Conv2dBlock(
                self.input_dim,
                self.z_nc,
                3,
                padding=1,
                activation="lrelu",
                pad_type="reflect",
                norm="spectral_batch",
            )
        else:
            raise ValueError("Unknown encoder type")

        # num_layers SPADE resblocks, each halving the channel dimension
        self.spade_blocks = []
        for i in range(self.num_layers):
            self.spade_blocks.append(
                SPADEResnetBlock(
                    int(self.z_nc / (2**i)),
                    int(self.z_nc / (2 ** (i + 1))),
                    cond_nc,
                    spade_use_spectral_norm,
                    spade_param_free_norm,
                    spade_kernel_size,
                    spade_activation,
                )
            )
        self.spade_blocks = nn.Sequential(*self.spade_blocks)

        # Final conv producing the 1-channel mask logits
        self.final_nc = int(self.z_nc / (2**self.num_layers))
        self.mask_conv = Conv2dBlock(
            self.final_nc,
            1,
            3,
            padding=1,
            activation="none",
            pad_type="reflect",
            norm="spectral",
        )
        self.upsample = InterpolateNearest2d(scale_factor=2)

    def forward(self, z, cond, z_depth=None):
        # z_depth is accepted but not used by this decoder
        if isinstance(z, (list, tuple)):
            # High- and low-level encoder features: project them (optionally),
            # resize the low-level map to the high-level resolution, then merge.
            z_h, z_l = z
            if self.opts.gen.m.use_proj:
                z_l = self.low_level_conv(z_l)
                z_l = F.interpolate(z_l, size=z_h.shape[-2:], mode="bilinear")
                z_h = self.high_level_conv(z_h)
            else:
                z_l = self.low_level_conv(z_l)
                z_l = F.interpolate(z_l, size=z_h.shape[-2:], mode="bilinear")
            z = torch.cat([z_h, z_l], dim=1)
            y = self.merge_feats_conv(z)
        else:
            y = self.fc_conv(z)

        # num_layers x (SPADE resblock conditioned on cond, then 2x upsampling)
        for i in range(self.num_layers):
            y = self.spade_blocks[i](y, cond)
            y = self.upsample(y)
        y = self.mask_conv(y)
        return y

    def __str__(self):
        return "MaskerSpadeDecoder"