| import torch |
| import torch.nn.functional as F |
| import torch.nn as nn |
| from torch.distributions import multinomial, categorical |
| import torch.optim as optim |
|
|
| import math |
|
|
| try: |
| from . import helpers as h |
| from . import ai |
| from . import scheduling as S |
| except: |
| import helpers as h |
| import ai |
| import scheduling as S |
|
|
| import math |
| import abc |
|
|
| from torch.nn.modules.conv import _ConvNd |
| from enum import Enum |
|
|
|
|
class InferModule(nn.Module):
    """Base class for modules whose parameters depend on the input shape.

    Construction is two-phase: __init__ only records the constructor
    arguments; infer(in_shape) later calls the subclass hook
    init(in_shape, *args, **kwargs), which must create any parameters and
    return the output shape.  nn.Module.__init__ is deliberately deferred
    to infer() so Parameters are registered only once shapes are known.
    """

    def __init__(self, *args, normal = False, ibp_init = False, **kwargs):
        # NOTE: nn.Module.__init__ is intentionally NOT called here; infer()
        # calls it before any Parameter is assigned.
        self.args = args
        self.kwargs = kwargs
        self.infered = False
        self.normal = normal      # use normal(0, stdv) weight init
        self.ibp_init = ibp_init  # use orthogonal weight init / zero bias

    def infer(self, in_shape, global_args = None):
        """ this is really actually stateful. """
        # Idempotent: a second call returns self untouched.
        if self.infered:
            return self
        self.infered = True

        super(InferModule, self).__init__()
        self.inShape = list(in_shape)
        out_shape = self.init(list(in_shape), *self.args, global_args = global_args, **self.kwargs)
        if out_shape is None:
            # BUG FIX: this used to `raise "..."` (raising a str is a TypeError
            # in Python 3) and the check ran only after list(None) had already
            # blown up, making it unreachable.
            raise ValueError("init should set the out_shape")
        self.outShape = list(out_shape)

        self.reset_parameters()
        return self

    def reset_parameters(self):
        """Initialize weight/bias in place; no-op for parameterless modules."""
        if not hasattr(self,'weight') or self.weight is None:
            return
        # Fan-in style scaling: weight elements per output channel.
        n = h.product(self.weight.size()) / self.outShape[0]
        stdv = 1 / math.sqrt(n)

        if self.ibp_init:
            torch.nn.init.orthogonal_(self.weight.data)
        elif self.normal:
            self.weight.data.normal_(0, stdv)
            self.weight.data.clamp_(-1, 1)
        else:
            self.weight.data.uniform_(-stdv, stdv)

        # BUG FIX: guard with getattr — modules built with bias=False store
        # self.bias = None, and some modules define no bias attribute at all.
        if getattr(self, "bias", None) is not None:
            if self.ibp_init:
                self.bias.data.zero_()
            elif self.normal:
                self.bias.data.normal_(0, stdv)
                self.bias.data.clamp_(-1, 1)
            else:
                self.bias.data.uniform_(-stdv, stdv)

    def clip_norm(self):
        """Apply weight normalization (once) and clamp its magnitude."""
        if not hasattr(self, "weight"):
            return
        if not hasattr(self,"weight_g"):
            # torch 0.x required dim=None for a whole-tensor norm.
            if torch.__version__[0] == "0":
                nn.utils.weight_norm(self, dim=None)
            else:
                nn.utils.weight_norm(self)

        self.weight_g.data.clamp_(-h.max_c_for_norm, h.max_c_for_norm)

        if torch.__version__[0] != "0":
            self.weight_v.data.clamp_(-h.max_c_for_norm * 10000,h.max_c_for_norm * 10000)
            # BUG FIX: bias may be None (bias=False modules); hasattr alone
            # passed and then crashed on None.data.
            if getattr(self, "bias", None) is not None:
                self.bias.data.clamp_(-h.max_c_for_norm * 10000, h.max_c_for_norm * 10000)

    def regularize(self, p):
        """Return a norm penalty over this module's parameters.

        NOTE(review): on modern torch the weight terms do not forward the
        norm order `p` (`.norm()` defaults to 2); only the bias term honours
        it.  Preserved as-is because training results depend on it.
        """
        reg = 0
        if torch.__version__[0] == "0":
            for param in self.parameters():
                reg += param.norm(p)
        else:
            if hasattr(self, "weight_g"):
                # weight-norm decomposition: penalize both factors.
                reg += self.weight_g.norm().sum()
                reg += self.weight_v.norm().sum()
            elif hasattr(self, "weight"):
                reg += self.weight.norm().sum()

            # BUG FIX: bias may be None (bias=False modules).
            if getattr(self, "bias", None) is not None:
                reg += self.bias.view(-1).norm(p=p).sum()

        return reg

    def remove_norm(self):
        """Undo clip_norm()'s weight-norm reparameterization, if applied."""
        if hasattr(self,"weight_g"):
            torch.nn.utils.remove_weight_norm(self)

    def showNet(self, t = ""):
        print(t + self.__class__.__name__)

    def printNet(self, f):
        print(self.__class__.__name__, file=f)

    @abc.abstractmethod
    def forward(self, *args, **kargs):
        """Subclasses compute the layer output here."""
        pass

    def __call__(self, *args, onyx=False, **kargs):
        # onyx-mode calls bypass nn.Module's hook machinery and hit forward
        # directly (used for export-style execution).
        if onyx:
            return self.forward(*args, onyx=onyx, **kargs)
        else:
            return super(InferModule, self).__call__(*args, **kargs)

    @abc.abstractmethod
    def neuronCount(self):
        """Number of (activation) neurons this module contributes."""
        pass

    def depth(self):
        # Number of nonlinear layers; overridden by Activation to return 1.
        return 0
|
|
def getShapeConv(in_shape, conv_shape, stride = 1, padding = 0):
    """Compute the (channels, height, width) output shape of a 2-D convolution.

    in_shape is (C, H, W); conv_shape's first three entries are
    (out_channels, kernel_h, kernel_w).
    """
    _, in_h, in_w = in_shape
    out_chan, k_h, k_w = conv_shape[:3]

    def span(extent, kernel):
        # Standard conv output-size formula.
        return 1 + int((2 * padding + extent - kernel) / stride)

    return (out_chan, span(in_h, k_h), span(in_w, k_w))
|
|
def getShapeConvTranspose(in_shape, conv_shape, stride = 1, padding = 0, out_padding=0):
    """Compute the (channels, height, width) output shape of a transposed conv.

    in_shape is (C, H, W); conv_shape's first three entries are
    (out_channels, kernel_h, kernel_w).
    """
    _, in_h, in_w = in_shape
    out_chan, k_h, k_w = conv_shape[:3]

    def grow(extent, kernel):
        # Inverse of the conv output-size formula.
        return (extent - 1) * stride - 2 * padding + kernel + out_padding

    return (out_chan, grow(in_h, k_h), grow(in_w, k_w))
|
|
|
|
|
|
class Linear(InferModule):
    """Dense (fully connected) layer: y = x @ W + b."""

    def init(self, in_shape, out_shape, **kargs):
        self.in_neurons = h.product(in_shape)
        # Allow a bare int as the output size.
        if isinstance(out_shape, int):
            out_shape = [out_shape]
        self.out_neurons = h.product(out_shape)

        self.weight = torch.nn.Parameter(torch.Tensor(self.in_neurons, self.out_neurons))
        self.bias = torch.nn.Parameter(torch.Tensor(self.out_neurons))

        return out_shape

    def forward(self, x, **kargs):
        s = x.size()
        # Flatten everything but the batch dimension before the matmul.
        x = x.view(s[0], h.product(s[1:]))
        return (x.matmul(self.weight) + self.bias).view(s[0], *self.outShape)

    def neuronCount(self):
        # Linear layers contribute no ReLU neurons themselves.
        return 0

    def showNet(self, t = ""):
        print(t + "Linear out=" + str(self.out_neurons))

    def printNet(self, f):
        # BUG FIX: the header line previously went to stdout (no file=f),
        # leaving the serialized net file without its "Linear(...)" header.
        print("Linear(" + str(self.out_neurons) + ")", file = f)

        print(h.printListsNumpy(list(self.weight.transpose(1,0).data)), file= f)
        print(h.printNumpy(self.bias), file= f)
|
|
class Activation(InferModule):
    """Pointwise nonlinearity selected by name at construction time."""

    def init(self, in_shape, global_args = None, activation = "ReLU", **kargs):
        # Store the index into the fixed activation table; shape is unchanged.
        names = [ "ReLU","Sigmoid", "Tanh", "Softplus", "ELU", "SELU"]
        self.activation = names.index(activation)
        self.activation_name = activation
        return in_shape

    def regularize(self, p):
        # No parameters to penalize.
        return 0

    def forward(self, x, **kargs):
        # Dispatch table parallel to the name list in init(); x may be a
        # tensor or an abstract-domain object exposing these methods.
        ops = (
            lambda t: t.relu(),
            lambda t: t.sigmoid(),
            lambda t: t.tanh(),
            lambda t: t.softplus(),
            lambda t: t.elu(),
            lambda t: t.selu(),
        )
        return ops[self.activation](x)

    def neuronCount(self):
        return h.product(self.outShape)

    def depth(self):
        # One nonlinear layer.
        return 1

    def showNet(self, t = ""):
        print(t + self.activation_name)

    def printNet(self, f):
        pass
|
|
class ReLU(Activation):
    """Activation with its default activation name, "ReLU"."""
    pass
|
|
def activation(*args, batch_norm = False, **kargs):
    """Factory for an Activation layer, optionally preceded by a BatchNorm."""
    act = Activation(*args, **kargs)
    if batch_norm:
        return Seq(BatchNorm(), act)
    return act
|
|
class Identity(InferModule):
    """Pass-through layer: forwards its input unchanged."""

    def init(self, in_shape, global_args = None, **kargs):
        return in_shape

    def forward(self, x, **kargs):
        return x

    def regularize(self, p):
        # No parameters.
        return 0

    def neuronCount(self):
        return 0

    def printNet(self, f):
        pass

    def showNet(self, *args, **kargs):
        pass
|
|
class Dropout(InferModule):
    """Dropout whose probability p may be a (time-dependent) schedule.

    A mask is sampled by running F.dropout / F.dropout2d over a tensor of
    ones, so mask entries are 0 or 1/(1-p); the mask is then applied to x
    outside no_grad so gradients flow through x.  With alpha_dropout=True
    the mask is rescaled in the SELU-style a*mask + b form.
    """

    def init(self, in_shape, p=0.5, use_2d = False, alpha_dropout = False, **kargs):
        # p is wrapped in a schedule constant; resolved per-call via getVal(time).
        self.p = S.Const.initConst(p)
        self.use_2d = use_2d
        self.alpha_dropout = alpha_dropout
        return in_shape

    def forward(self, x, time = 0, **kargs):
        if self.training:
            with torch.no_grad():
                p = self.p.getVal(time = time)
                mask = (F.dropout2d if self.use_2d else F.dropout)(h.ones(x.size()),p=p, training=True)
            if self.alpha_dropout:
                with torch.no_grad():
                    keep_prob = 1 - p
                    # SELU negative saturation value scaled by lambda — TODO confirm.
                    alpha = -1.7580993408473766
                    a = math.pow(keep_prob + alpha * alpha * keep_prob * (1 - keep_prob), -0.5)
                    b = -a * alpha * (1 - keep_prob)
                    mask = mask * a
                return x * mask + b
            else:
                return x * mask
        else:
            # Inference: dropout is a no-op (mask already inverted-scaled).
            return x

    def neuronCount(self):
        return 0

    def showNet(self, t = ""):
        print(t + "Dropout p=" + str(self.p))

    def printNet(self, f):
        # BUG FIX: this line previously went to stdout (no file=f), unlike
        # every other printNet in this file.
        print("Dropout(" + str(self.p) + ")", file = f)
|
|
class PrintActivation(Identity):
    """Identity layer that only records an activation name for printNet()."""

    def init(self, in_shape, global_args = None, activation = "ReLU", **kargs):
        self.activation = activation
        return in_shape

    def printNet(self, f):
        # Emit just the recorded name into the serialized net.
        print(self.activation, file = f)
|
|
class PrintReLU(PrintActivation):
    """PrintActivation with its default activation name, "ReLU"."""
    pass
|
|
class Conv2D(InferModule):
    """2-D convolution whose parameters are created lazily in init()."""

    def init(self, in_shape, out_channels, kernel_size, stride = 1, global_args = None, bias=True, padding = 0, activation = "ReLU", **kargs):
        # Remember the incoming shape for printNet/showNet.
        self.prev = in_shape
        self.in_channels = in_shape[0]
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.activation = activation
        self.use_softplus = h.default(global_args, 'use_softplus', False)

        self.weight = torch.nn.Parameter(torch.Tensor(out_channels, self.in_channels, kernel_size, kernel_size))
        self.bias = torch.nn.Parameter(torch.Tensor(out_channels)) if bias else None

        return getShapeConv(in_shape, (out_channels, kernel_size, kernel_size), stride, padding)

    def forward(self, input, **kargs):
        # input may be an abstract-domain object exposing a conv2d method.
        return input.conv2d(self.weight, bias=self.bias, stride=self.stride, padding = self.padding )

    def printNet(self, f):
        print("Conv2D", file = f)
        shape_in = list(self.prev)
        kernel = [self.kernel_size, self.kernel_size]
        strides = [self.stride, self.stride]
        print(self.activation + ", filters={}, kernel_size={}, input_shape={}, stride={}, padding={}".format(self.out_channels, kernel, list(reversed(shape_in)), strides, self.padding ), file = f)
        print(h.printListsNumpy([[list(p) for p in l ] for l in self.weight.permute(2,3,1,0).data]) , file= f)
        # bias=False modules still print a zero bias vector.
        print(h.printNumpy(self.bias if self.bias is not None else h.dten(self.out_channels)), file= f)

    def showNet(self, t = ""):
        shape_in = list(self.prev)
        print(t + "Conv2D, filters={}, kernel_size={}, input_shape={}, stride={}, padding={}".format(self.out_channels, [self.kernel_size, self.kernel_size], list(reversed(shape_in)), [self.stride, self.stride], self.padding ))

    def neuronCount(self):
        # Convolutions contribute no ReLU neurons themselves.
        return 0
|
|
|
|
class ConvTranspose2D(InferModule):
    """Transposed 2-D convolution (deconvolution) with lazily-built weights."""

    def init(self, in_shape, out_channels, kernel_size, stride = 1, global_args = None, bias=True, padding = 0, out_padding=0, activation = "ReLU", **kargs):
        self.prev = in_shape
        self.in_channels = in_shape[0]
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.out_padding = out_padding
        self.activation = activation
        self.use_softplus = h.default(global_args, 'use_softplus', False)

        # conv_transpose2d expects weight shaped (in_channels, out_channels, kH, kW).
        weights_shape = (self.in_channels, self.out_channels, kernel_size, kernel_size)
        self.weight = torch.nn.Parameter(torch.Tensor(*weights_shape))
        if bias:
            # BUG FIX: bias must have out_channels entries; weights_shape[0]
            # is in_channels for a transposed convolution, which breaks
            # whenever in_channels != out_channels.
            self.bias = torch.nn.Parameter(torch.Tensor(self.out_channels))
        else:
            self.bias = None

        outshape = getShapeConvTranspose(in_shape, (out_channels, kernel_size, kernel_size), stride, padding, out_padding)
        return outshape

    def forward(self, input, **kargs):
        return input.conv_transpose2d(self.weight, bias=self.bias, stride=self.stride, padding = self.padding, output_padding=self.out_padding)

    def printNet(self, f):
        print("ConvTranspose2D", file = f)
        # BUG FIX: kernel_size is an int, so list(self.kernel_size) raised
        # TypeError; print it as a [k, k] pair like Conv2D does.
        print(self.activation + ", filters={}, kernel_size={}, input_shape={}".format(self.out_channels, [self.kernel_size, self.kernel_size], list(self.prev) ), file = f)
        print(h.printListsNumpy([[list(p) for p in l ] for l in self.weight.permute(2,3,1,0).data]) , file= f)
        # BUG FIX: guard against bias=None, mirroring Conv2D.printNet.
        print(h.printNumpy(self.bias if self.bias is not None else h.dten(self.out_channels)), file= f)

    def neuronCount(self):
        return 0
|
|
|
|
|
|
class MaxPool2D(InferModule):
    """Max pooling over square windows; stride defaults to the kernel size."""

    def init(self, in_shape, kernel_size, stride = None, **kargs):
        self.prev = in_shape
        self.kernel_size = kernel_size
        # Non-overlapping windows by default.
        self.stride = kernel_size if stride is None else stride
        # BUG FIX: the output shape must use the resolved self.stride;
        # the raw `stride` argument was passed before, crashing when it
        # was None and ignoring the kernel-size default.
        return getShapeConv(in_shape, (in_shape[0], kernel_size, kernel_size), self.stride)

    def forward(self, x, **kargs):
        return x.max_pool2d(self.kernel_size, self.stride)

    def printNet(self, f):
        # BUG FIX: stride/kernel_size are ints (list() on them raised) and
        # self.shape never existed; print [k, k] pairs like Conv2D does.
        print("MaxPool2D stride={}, kernel_size={}, input_shape={}".format([self.stride, self.stride], [self.kernel_size, self.kernel_size], list(self.prev[1:]+self.prev[:1]) ), file = f)

    def neuronCount(self):
        # Every pooled output location counts as a (max) neuron.
        return h.product(self.outShape)
|
|
class AvgPool2D(InferModule):
    """Average pooling with a fixed padding of 1 (must match forward())."""

    def init(self, in_shape, kernel_size, stride = None, **kargs):
        self.prev = in_shape
        self.kernel_size = kernel_size
        self.stride = kernel_size if stride is None else stride
        # NOTE: padding=1 is hard-coded here and in forward().
        out_size = getShapeConv(in_shape, (in_shape[0], kernel_size, kernel_size), self.stride, padding = 1)
        return out_size

    def forward(self, x, **kargs):
        # Degenerate 1x1 spatial extent: nothing left to pool.
        if h.product(x.size()[2:]) == 1:
            return x
        return x.avg_pool2d(kernel_size = self.kernel_size, stride = self.stride, padding = 1)

    def printNet(self, f):
        # BUG FIX: stride/kernel_size are ints (list() on them raised) and
        # self.shape never existed; print [k, k] pairs like Conv2D does.
        print("AvgPool2D stride={}, kernel_size={}, input_shape={}".format([self.stride, self.stride], [self.kernel_size, self.kernel_size], list(self.prev[1:]+self.prev[:1]) ), file = f)

    def neuronCount(self):
        return h.product(self.outShape)
|
|
class AdaptiveAvgPool2D(InferModule):
    """Adaptive average pooling to a fixed spatial output size."""

    def init(self, in_shape, out_shape, **kargs):
        self.prev = in_shape
        self.out_shape = list(out_shape)
        # Channels are preserved; only the spatial dims change.
        return [in_shape[0]] + self.out_shape

    def forward(self, x, **kargs):
        return x.adaptive_avg_pool2d(self.out_shape)

    def printNet(self, f):
        print("AdaptiveAvgPool2D out_Shape={} input_shape={}".format(list(self.out_shape), list(self.prev[1:]+self.prev[:1]) ), file = f)

    def neuronCount(self):
        return h.product(self.outShape)
|
|
class Normalize(InferModule):
    """Channel-wise input normalization: (x - mean) / std."""

    def init(self, in_shape, mean, std, **kargs):
        # Keep the raw values for printNet/showNet.
        self.mean_v = mean
        self.std_v = std
        self.mean = h.dten(mean)
        # Store the reciprocal so forward() multiplies instead of divides.
        self.std = 1 / h.dten(std)
        return in_shape

    def forward(self, x, **kargs):
        target = x.size()[1:]
        mu = self.mean.view(self.mean.shape[0],1,1).expand(*target)
        inv_sigma = self.std.view(self.std.shape[0],1,1).expand(*target)
        return (x - mu) * inv_sigma

    def neuronCount(self):
        return 0

    def printNet(self, f):
        print("Normalize mean={} std={}".format(self.mean_v, self.std_v), file = f)

    def showNet(self, t = ""):
        print(t + "Normalize mean={} std={}".format(self.mean_v, self.std_v))
|
|
class Flatten(InferModule):
    """Collapse all non-batch dimensions into a single flat vector."""

    def init(self, in_shape, **kargs):
        # BUG FIX: infer() wraps this return value in list(...), so it must
        # be a sequence; returning the bare int raised TypeError there.
        return [h.product(in_shape)]

    def forward(self, x, **kargs):
        s = x.size()
        return x.view(s[0], h.product(s[1:]))

    def neuronCount(self):
        return 0
|
|
class BatchNorm(InferModule):
    """Batch normalization over dim 0 with learnable per-element gamma/beta.

    Unlike nn.BatchNorm2d this normalizes over the batch dimension only
    (statistics keep the full per-element shape), and it tolerates abstract
    inputs via vanillaTensorPart().  Running statistics are kept manually in
    plain attributes, not registered buffers.
    """

    def init(self, in_shape, track_running_stats = True, momentum = 0.1, eps=1e-5, **kargs):
        # Learnable scale and shift, one value per input element.
        self.gamma = torch.nn.Parameter(torch.Tensor(*in_shape))
        self.beta = torch.nn.Parameter(torch.Tensor(*in_shape))
        self.eps = eps
        self.track_running_stats = track_running_stats
        # momentum=None switches to cumulative (1/N) averaging, as below.
        self.momentum = momentum

        # Lazily initialized on the first training batch.
        self.running_mean = None
        self.running_var = None

        self.num_batches_tracked = 0
        return in_shape

    def reset_parameters(self):
        # Start as the identity transform.
        self.gamma.data.fill_(1)
        self.beta.data.zero_()

    def forward(self, x, **kargs):
        # Mirrors torch's BatchNorm bookkeeping: pick the averaging factor.
        exponential_average_factor = 0.0
        if self.training and self.track_running_stats:

            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:
                    # Cumulative moving average.
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:
                    exponential_average_factor = self.momentum

        # Batch statistics over dim 0 (detached — stats carry no gradient).
        new_mean = x.vanillaTensorPart().detach().mean(dim=0)
        new_var = x.vanillaTensorPart().detach().var(dim=0, unbiased=False)
        # NaN guard: new_var * 0 is NaN iff new_var has a NaN/inf entry.
        if torch.isnan(new_var * 0).any():
            return x
        if self.training:
            # First batch seeds the running stats directly; afterwards EMA.
            # NOTE(review): these update even when track_running_stats is
            # False (factor stays 0.0 after the seed) — confirm intended.
            self.running_mean = (1 - exponential_average_factor) * self.running_mean + exponential_average_factor * new_mean if self.running_mean is not None else new_mean
            if self.running_var is None:
                self.running_var = new_var
            else:
                q = (1 - exponential_average_factor) * self.running_var
                r = exponential_average_factor * new_var
                self.running_var = q + r

        # Normalize with running stats when available, else batch stats.
        if self.track_running_stats and self.running_mean is not None and self.running_var is not None:
            new_mean = self.running_mean
            new_var = self.running_var

        diver = 1 / (new_var + self.eps).sqrt()

        if torch.isnan(diver).any():
            print("Really shouldn't happen ever")
            return x
        else:
            out = (x - new_mean) * diver * self.gamma + self.beta
            return out

    def neuronCount(self):
        return 0
|
|
class Unflatten2d(InferModule):
    """Reshape a flat vector into a square (C, w, w) image."""

    def init(self, in_shape, w, **kargs):
        self.w = w
        # Channel count is whatever remains after the w*w spatial elements.
        self.outChan = int(h.product(in_shape) / (w * w))
        return (self.outChan, self.w, self.w)

    def forward(self, x, **kargs):
        batch = x.size()[0]
        return x.view(batch, self.outChan, self.w, self.w)

    def neuronCount(self):
        return 0
|
|
|
|
class View(InferModule):
    """Reshape (per batch element) to a fixed output shape."""

    def init(self, in_shape, out_shape, **kargs):
        # Element counts must match for view() to succeed.
        assert(h.product(in_shape) == h.product(out_shape))
        return out_shape

    def forward(self, x, **kargs):
        batch = x.size()[0]
        return x.view(batch, *self.outShape)

    def neuronCount(self):
        return 0
| |
class Seq(InferModule):
    """Sequential container that threads shape inference through its layers."""

    def init(self, in_shape, *layers, **kargs):
        self.layers = layers
        # Register children with nn.Module so parameters are tracked.
        self.net = nn.Sequential(*layers)
        self.prev = in_shape
        shape = in_shape
        for layer in layers:
            shape = layer.infer(shape, **kargs).outShape
        return shape

    def forward(self, x, **kargs):
        out = x
        for layer in self.layers:
            out = layer(out, **kargs)
        return out

    def clip_norm(self):
        for layer in self.layers:
            layer.clip_norm()

    def regularize(self, p):
        return sum(layer.regularize(p) for layer in self.layers)

    def remove_norm(self):
        for layer in self.layers:
            layer.remove_norm()

    def printNet(self, f):
        for layer in self.layers:
            layer.printNet(f)

    def showNet(self, *args, **kargs):
        for layer in self.layers:
            layer.showNet(*args, **kargs)

    def neuronCount(self):
        return sum(layer.neuronCount() for layer in self.layers)

    def depth(self):
        return sum(layer.depth() for layer in self.layers)
| |
def FFNN(layers, last_lin = False, last_zono = False, **kargs):
    """Feed-forward net from a list of layer widths.

    With last_lin the final width becomes a bare Linear head (optionally
    preceded by CorrelateAll when last_zono is set) instead of
    Linear+activation.
    """
    hidden = layers
    tail = []
    if last_lin:
        hidden = layers[:-1]
        tail = [PrintActivation(activation = "Affine"), Linear(layers[-1],**kargs)]
        if last_zono:
            tail.insert(0, CorrelateAll(only_train=False))

    body = [Seq(PrintActivation(**kargs), Linear(width, **kargs), activation(**kargs)) for width in hidden]
    return Seq(*(body + tail))
|
|
def Conv(*args, **kargs):
    """Conv2D followed by the configured activation (see activation())."""
    return Seq(Conv2D(*args, **kargs), activation(**kargs))
|
|
def ConvTranspose(*args, **kargs):
    """ConvTranspose2D followed by the configured activation."""
    return Seq(ConvTranspose2D(*args, **kargs), activation(**kargs))
|
|
MP = MaxPool2D  # short alias for MaxPool2D
|
|
def LeNet(conv_layers, ly = [], bias = True, normal=False, **kargs):
    """LeNet-style net from conv specs plus an optional FFNN head.

    Each conv spec is either an already-built InferModule, a tuple whose
    first entry is a string (treated as MaxPool2D(*spec[1:])), or a
    (channels, kernel[, ...], stride) tuple for a Conv block.
    """
    def build(spec):
        if isinstance(spec, InferModule):
            return spec
        if isinstance(spec[0], str):
            return MaxPool2D(*spec[1:])
        stride = spec[-1] if len(spec) == 4 else 1
        return Conv(out_channels = spec[0], kernel_size = spec[1], stride = stride, bias=bias, normal=normal, **kargs)

    convs = [build(spec) for spec in conv_layers]
    if len(ly) > 0:
        return Seq(*convs, FFNN(ly, **kargs, bias=bias))
    return Seq(*convs)
|
|
def InvLeNet(ly, w, conv_layers, bias = True, normal=False, **kargs):
    """Decoder mirror of LeNet: FFNN, reshape to (C, w, w), transposed convs.

    Each conv spec is (out_channels, kernel, stride, padding, out_padding).
    NOTE(review): the transposed convs are built with bias=False regardless
    of the `bias` argument (only the FFNN honours it) — confirm intended.
    """
    def build(spec):
        return ConvTranspose(out_channels = spec[0], kernel_size = spec[1], stride = spec[2], padding = spec[3], out_padding = spec[4], bias=False, normal=normal)

    return Seq(FFNN(ly, bias=bias), Unflatten2d(w), *[build(spec) for spec in conv_layers])
|
|
class FromByteImg(InferModule):
    """Scale byte-valued images into [0, 1) floats (divide by 256)."""

    def init(self, in_shape, **kargs):
        return in_shape

    def forward(self, x, **kargs):
        return x.to_dtype()/ 256.

    def neuronCount(self):
        return 0
| |
class Skip(InferModule):
    """Run two sub-networks on the same input, concatenating along channels."""

    def init(self, in_shape, net1, net2, **kargs):
        self.net1 = net1.infer(in_shape, **kargs)
        self.net2 = net2.infer(in_shape, **kargs)
        # Spatial dims must agree; channel counts are summed.
        assert(net1.outShape[1:] == net2.outShape[1:])
        return [ net1.outShape[0] + net2.outShape[0] ] + net1.outShape[1:]

    def forward(self, x, **kargs):
        left = self.net1(x, **kargs)
        right = self.net2(x, **kargs)
        return left.cat(right, dim=1)

    def regularize(self, p):
        return self.net1.regularize(p) + self.net2.regularize(p)

    def clip_norm(self):
        for net in (self.net1, self.net2):
            net.clip_norm()

    def remove_norm(self):
        for net in (self.net1, self.net2):
            net.remove_norm()

    def neuronCount(self):
        return self.net1.neuronCount() + self.net2.neuronCount()

    def printNet(self, f):
        print("SkipNet1", file=f)
        self.net1.printNet(f)
        print("SkipNet2", file=f)
        self.net2.printNet(f)
        print("SkipCat dim=1", file=f)

    def showNet(self, t = ""):
        print(t+"SkipNet1")
        self.net1.showNet(" "+t)
        print(t+"SkipNet2")
        self.net2.showNet(" "+t)
        print(t+"SkipCat dim=1")
|
|
class ParSum(InferModule):
    """Run two sub-networks in parallel and combine their outputs (residual)."""

    def init(self, in_shape, net1, net2, **kargs):
        self.net1 = net1.infer(in_shape, **kargs)
        self.net2 = net2.infer(in_shape, **kargs)
        # Both branches must produce identical shapes.
        assert(net1.outShape == net2.outShape)
        return net1.outShape

    def forward(self, x, **kargs):
        left = self.net1(x, **kargs)
        right = self.net2(x, **kargs)
        # addPar lets abstract domains combine the branches soundly.
        return x.addPar(left, right)

    def clip_norm(self):
        for net in (self.net1, self.net2):
            net.clip_norm()

    def remove_norm(self):
        for net in (self.net1, self.net2):
            net.remove_norm()

    def neuronCount(self):
        return self.net1.neuronCount() + self.net2.neuronCount()

    def depth(self):
        # Parallel branches: the deeper branch dominates.
        return max(self.net1.depth(), self.net2.depth())

    def printNet(self, f):
        print("ParNet1", file=f)
        self.net1.printNet(f)
        print("ParNet2", file=f)
        self.net2.printNet(f)
        print("ParCat dim=1", file=f)

    def showNet(self, t = ""):
        print(t + "ParNet1")
        self.net1.showNet(" "+t)
        print(t + "ParNet2")
        self.net2.showNet(" "+t)
        print(t + "ParSum")
|
|
class ToZono(Identity):
    """Convert the abstract domain to a zonotope (hybrid_to_zono).

    With only_train=True the conversion is applied during training only.
    """

    def init(self, in_shape, customRelu = None, only_train = False, **kargs):
        self.customRelu = customRelu
        self.only_train = only_train
        return in_shape

    def forward(self, x, **kargs):
        # Skip the conversion at eval time when it is train-only.
        if self.only_train and not self.training:
            return x
        return self.abstract_forward(x, **kargs)

    def abstract_forward(self, x, **kargs):
        return x.abstractApplyLeaf('hybrid_to_zono', customRelu = self.customRelu)

    def showNet(self, t = ""):
        print(t + self.__class__.__name__ + " only_train=" + str(self.only_train))
|
|
class CorrelateAll(ToZono):
    """ToZono variant that also fully correlates the zonotope (correlate=True)."""
    def abstract_forward(self, x, **kargs):
        return x.abstractApplyLeaf('hybrid_to_zono',correlate=True, customRelu = self.customRelu)
|
|
class ToHZono(ToZono):
    """Inverse direction of ToZono: convert a zonotope to the hybrid domain."""
    def abstract_forward(self, x, **kargs):
        return x.abstractApplyLeaf('zono_to_hybrid',customRelu = self.customRelu)
|
|
class Concretize(ToZono):
    """Concretize the abstract value (train-only by default)."""
    def init(self, in_shape, only_train = True, **kargs):
        # Unlike ToZono, only_train defaults to True here.
        self.only_train = only_train
        return in_shape

    def abstract_forward(self, x, **kargs):
        return x.abstractApplyLeaf('concretize')
|
|
| |
class CorrRand(Concretize):
    """Stochastically correlate num_correlate dimensions (train-only by default)."""
    def init(self, in_shape, num_correlate, only_train = True, **kargs):
        self.only_train = only_train
        self.num_correlate = num_correlate
        return in_shape

    # NOTE(review): signature drops **kargs compared with the base class's
    # abstract_forward(x, **kargs) — callers passing kwargs would fail; confirm.
    def abstract_forward(self, x):
        return x.abstractApplyLeaf("stochasticCorrelate", self.num_correlate)

    def showNet(self, t = ""):
        print(t + self.__class__.__name__ + " only_train=" + str(self.only_train) + " num_correlate="+ str(self.num_correlate))
|
|
class CorrMaxK(CorrRand):
    """Correlate the num_correlate largest dimensions instead of random ones."""
    def abstract_forward(self, x):
        return x.abstractApplyLeaf("correlateMaxK", self.num_correlate)
|
|
|
|
class CorrMaxPool2D(Concretize):
    """Correlate via a 2-D max-pool pattern (kernel doubles as the stride)."""
    def init(self,in_shape, kernel_size, only_train = True, max_type = ai.MaxTypes.head_beta, **kargs):
        self.only_train = only_train
        self.kernel_size = kernel_size
        self.max_type = max_type
        return in_shape

    def abstract_forward(self, x):
        # stride == kernel_size: non-overlapping pooling windows.
        return x.abstractApplyLeaf("correlateMaxPool", kernel_size = self.kernel_size, stride = self.kernel_size, max_type = self.max_type)

    def showNet(self, t = ""):
        print(t + self.__class__.__name__ + " only_train=" + str(self.only_train) + " kernel_size="+ str(self.kernel_size) + " max_type=" +str(self.max_type))
|
|
class CorrMaxPool3D(Concretize):
    """Correlate via a 3-D max-pool pattern (kernel doubles as the stride)."""

    def init(self,in_shape, kernel_size, only_train = True, max_type = ai.MaxTypes.only_beta, **kargs):
        self.only_train = only_train
        self.kernel_size = kernel_size
        self.max_type = max_type
        return in_shape

    def abstract_forward(self, x):
        # Same as CorrMaxPool2D but with the 3-D pooling primitive.
        return x.abstractApplyLeaf("correlateMaxPool", kernel_size = self.kernel_size, stride = self.kernel_size, max_type = self.max_type, max_pool = F.max_pool3d)

    def showNet(self, t = ""):
        # BUG FIX: max_type is not a string, so concatenating it directly
        # raised TypeError; wrap it in str() like CorrMaxPool2D does.
        print(t + self.__class__.__name__ + " only_train=" + str(self.only_train) + " kernel_size="+ str(self.kernel_size) + " max_type=" + str(self.max_type))
|
|
class CorrFix(Concretize):
    """Correlate a fixed, evenly-spaced set of k flattened indices per sample."""
    def init(self,in_shape, k, only_train = True, **kargs):
        self.k = k
        self.only_train = only_train
        return in_shape

    def abstract_forward(self, x):
        sz = x.size()
        """
        # for more control in the future
        indxs_1 = torch.arange(start = 0, end = sz[1], step = math.ceil(sz[1] / self.dims[1]) )
        indxs_2 = torch.arange(start = 0, end = sz[2], step = math.ceil(sz[2] / self.dims[2]) )
        indxs_3 = torch.arange(start = 0, end = sz[3], step = math.ceil(sz[3] / self.dims[3]) )

        indxs = torch.stack(torch.meshgrid((indxs_1,indxs_2,indxs_3)), dim=3).view(-1,3)
        """
        # Evenly spaced indices over the flattened non-batch dims; at most
        # self.k of them (ceil keeps the stride large enough).
        szm = h.product(sz[1:])
        indxs = torch.arange(start = 0, end = szm, step = math.ceil(szm / self.k))
        # Same index set replicated for every batch element.
        indxs = indxs.unsqueeze(0).expand(sz[0], indxs.size()[0])

        return x.abstractApplyLeaf("correlate", indxs)

    def showNet(self, t = ""):
        print(t + self.__class__.__name__ + " only_train=" + str(self.only_train) + " k="+ str(self.k))
|
|
|
|
class DecorrRand(Concretize):
    """Stochastically decorrelate num_decorrelate dimensions (train-only by default)."""
    def init(self, in_shape, num_decorrelate, only_train = True, **kargs):
        self.only_train = only_train
        self.num_decorrelate = num_decorrelate
        return in_shape

    def abstract_forward(self, x):
        return x.abstractApplyLeaf("stochasticDecorrelate", self.num_decorrelate)
|
|
class DecorrMin(Concretize):
    """Decorrelate the num_decorrelate smallest dimensions (train-only by default)."""
    def init(self, in_shape, num_decorrelate, only_train = True, num_to_keep = False, **kargs):
        self.only_train = only_train
        self.num_decorrelate = num_decorrelate
        # NOTE(review): defaults to False rather than an int — presumably a
        # flag interpreted by decorrelateMin; confirm against its implementation.
        self.num_to_keep = num_to_keep
        return in_shape

    def abstract_forward(self, x):
        return x.abstractApplyLeaf("decorrelateMin", self.num_decorrelate, num_to_keep = self.num_to_keep)

    def showNet(self, t = ""):
        print(t + self.__class__.__name__ + " only_train=" + str(self.only_train) + " k="+ str(self.num_decorrelate) + " num_to_keep=" + str(self.num_to_keep) )
|
|
class DeepLoss(ToZono):
    """Attach an auxiliary deep loss (weight bw) to the abstract value.

    Non-point abstract values get wrapped in ai.TaggedDomain with an MLoss
    tag; downstream loss computation blends the ordinary loss with
    x.deep_loss() using the (possibly scheduled) weight bw.
    """
    def init(self, in_shape, bw = 0.01, act = F.relu, **kargs):
        # Always train-only.
        self.only_train = True
        # bw may be a schedule; resolved per-call via getVal(time).
        self.bw = S.Const.initConst(bw)
        self.act = act
        return in_shape

    def abstract_forward(self, x, **kargs):
        # Point (concrete) values have no abstract width to penalize.
        if x.isPoint():
            return x
        return ai.TaggedDomain(x, self.MLoss(self, x))

    class MLoss():
        """Tag carrying the owning layer and the abstract value at this point."""
        def __init__(self, obj, x):
            self.obj = obj
            self.x = x

        def loss(self, a, *args, lr = 1, time = 0, **kargs):
            # Blend: (1-bw) * downstream loss + bw * deep loss at this layer.
            bw = self.obj.bw.getVal(time = time)
            pre_loss = a.loss(*args, time = time, **kargs, lr = lr * (1 - bw))
            if bw <= 0.0:
                return pre_loss
            return (1 - bw) * pre_loss + bw * self.x.deep_loss(act = self.obj.act)

    def showNet(self, t = ""):
        print(t + self.__class__.__name__ + " only_train=" + str(self.only_train) + " bw="+ str(self.bw) + " act=" + str(self.act) )
|
|
class IdentLoss(DeepLoss):
    """DeepLoss variant whose abstract_forward is a no-op (no extra loss term)."""
    def abstract_forward(self, x, **kargs):
        return x
| |
def SkipNet(net1, net2, ffnn, **kargs):
    """Concatenate two feature extractors, then apply an FFNN head."""
    return Seq(Skip(net1,net2), FFNN(ffnn, **kargs))
|
|
def WideBlock(out_filters, downsample=False, k=3, bias=False, **kargs):
    """Wide-ResNet style block: shortcut conv in parallel with two convs.

    downsample=True halves the spatial resolution (stride 2) in both paths.
    """
    if downsample:
        k_first, skip_stride, k_skip = 4, 2, 2
    else:
        k_first, skip_stride, k_skip = 3, 1, 1

    # Shortcut path: single conv matching the main path's stride.
    shortcut = Conv2D(out_filters, kernel_size=k_skip, stride=skip_stride, padding=0, bias=bias, normal=True, **kargs)

    # Main path: conv+activation followed by a bare conv.
    residual = Seq( Conv(out_filters, kernel_size = k_first, stride = skip_stride, padding = 1, bias=bias, normal=True, **kargs)
                  , Conv2D(out_filters, kernel_size = k, stride = 1, padding = 1, bias=bias, normal=True, **kargs))

    return Seq(ParSum(shortcut, residual), activation(**kargs))
|
|
|
|
|
|
def BasicBlock(in_planes, planes, stride=1, bias = False, skip_net = False, **kargs):
    """ResNet basic block: two 3x3 convs plus a shortcut, then activation.

    When the shape changes (stride != 1 or channel growth) the shortcut is a
    1x1 conv; otherwise it is the identity unless skip_net disables it.
    """
    main = Seq( Conv(planes, kernel_size = 3, stride = stride, padding = 1, bias=bias, normal=True, **kargs)
              , Conv2D(planes, kernel_size = 3, stride = 1, padding = 1, bias=bias, normal=True, **kargs))

    if stride != 1 or in_planes != planes:
        main = ParSum(main, Conv2D(planes, kernel_size=1, stride=stride, bias=bias, normal=True, **kargs))
    elif not skip_net:
        main = ParSum(main, Identity())

    return Seq(main, activation(**kargs))
|
|
| |
def ResNet(blocksList, extra = [], bias = False, **kargs):
    """Build a ResNet: a 64-filter stem conv followed by BasicBlock stages.

    blocksList gives the number of blocks per stage; channel width doubles
    each stage.  `extra` is a list of (module, index) pairs appended after
    the block at that index.
    """

    layers = []
    in_planes = 64
    planes = 64
    # Stage stride ramps 1 -> 2 and then stays 2 (note: the inner loop
    # rebinds `stride` to the last per-block stride, usually 1).
    stride = 0
    for num_blocks in blocksList:
        if stride < 2:
            stride += 1

        # First block of a stage may downsample; the rest keep stride 1.
        strides = [stride] + [1]*(num_blocks-1)
        for stride in strides:
            layers.append(BasicBlock(in_planes, planes, stride, bias = bias, **kargs))
            in_planes = planes
        planes *= 2

    print("RESlayers: ", len(layers))
    # Splice extra modules in after the requested block indices.
    for e,l in extra:
        layers[l] = Seq(layers[l], e)

    return Seq(Conv(64, kernel_size=3, stride=1, padding = 1, bias=bias, normal=True, printShape=True),
               *layers)
|
|
|
|
|
|
def DenseNet(growthRate, depth, reduction, num_classes, bottleneck = True):
    """Build a DenseNet: three dense stages separated by transition layers.

    Each dense unit concatenates its input with its output (Skip with
    Identity), so channels grow by growthRate per unit; transitions shrink
    channels by `reduction` via 1x1 conv + 2x2 average pooling.
    """

    def Bottleneck(growthRate):
        # DenseNet-B unit: 1x1 bottleneck to 4*growthRate, then 3x3 conv.
        interChannels = 4*growthRate

        n = Seq( ReLU(),
                 Conv2D(interChannels, kernel_size=1, bias=True, ibp_init = True),
                 ReLU(),
                 Conv2D(growthRate, kernel_size=3, padding=1, bias=True, ibp_init = True)
               )

        # Concatenate input and new features along channels.
        return Skip(Identity(), n)

    def SingleLayer(growthRate):
        # Plain dense unit: single 3x3 conv producing growthRate channels.
        n = Seq( ReLU(),
                 Conv2D(growthRate, kernel_size=3, padding=1, bias=True, ibp_init = True))
        return Skip(Identity(), n)

    def Transition(nOutChannels):
        # Compress channels and halve the spatial resolution.
        return Seq( ReLU(),
                    Conv2D(nOutChannels, kernel_size = 1, bias = True, ibp_init = True),
                    AvgPool2D(kernel_size=2))

    def make_dense(growthRate, nDenseBlocks, bottleneck):
        return Seq(*[Bottleneck(growthRate) if bottleneck else SingleLayer(growthRate) for i in range(nDenseBlocks)])

    # Units per stage; bottleneck units count double (two convs each).
    nDenseBlocks = (depth-4) // 3
    if bottleneck:
        nDenseBlocks //= 2

    nChannels = 2*growthRate
    conv1 = Conv2D(nChannels, kernel_size=3, padding=1, bias=True, ibp_init = True)
    dense1 = make_dense(growthRate, nDenseBlocks, bottleneck)
    nChannels += nDenseBlocks * growthRate
    nOutChannels = int(math.floor(nChannels*reduction))
    trans1 = Transition(nOutChannels)

    nChannels = nOutChannels
    dense2 = make_dense(growthRate, nDenseBlocks, bottleneck)
    nChannels += nDenseBlocks*growthRate
    nOutChannels = int(math.floor(nChannels*reduction))
    trans2 = Transition(nOutChannels)

    nChannels = nOutChannels
    dense3 = make_dense(growthRate, nDenseBlocks, bottleneck)

    return Seq(conv1, dense1, trans1, dense2, trans2, dense3,
               ReLU(),
               AvgPool2D(kernel_size=8),
               CorrelateAll(only_train=False, ignore_point = True),
               Linear(num_classes, ibp_init = True))
|
|
|
|