| import torch |
| import torch.nn as nn |
| import math |
|
|
| try: |
| from . import helpers as h |
| except: |
| import helpers as h |
|
|
|
|
|
|
class Const:
    """Constant schedule: a value that does not depend on its call arguments.

    ``getVal`` accepts and ignores extra keyword arguments (such as ``time``),
    so a ``Const`` can stand in wherever a time-dependent schedule is expected.
    """

    def __init__(self, c):
        # ``None`` is kept as-is (meaning "no fixed value, use the caller's
        # default"); anything else is coerced to float.
        self.c = c if c is None else float(c)

    def getVal(self, c=None, **kargs):
        """Return the stored constant, or ``c`` when the stored value is ``None``.

        Args:
            c: fallback returned when this ``Const`` was built with ``None``.
            **kargs: ignored.
        """
        return self.c if self.c is not None else c

    def __str__(self):
        return str(self.c)

    @staticmethod
    def initConst(x):
        """Return ``x`` unchanged if it is already a ``Const`` (or subclass), else ``Const(x)``.

        Declared as a ``staticmethod`` so that it works both as
        ``Const.initConst(x)`` and when looked up through an instance.
        """
        return x if isinstance(x, Const) else Const(x)
|
|
class Lin(Const):
    """Linear ramp from ``start`` to ``end`` over ``steps`` units of ``time``.

    For positive ``steps`` the value is ``start`` while ``time <= initial``. It
    then moves linearly and stays at ``end`` once ``time >= initial + steps``.
    """

    def __init__(self, start, end, steps, initial=0, quant=False):
        self.start = float(start)
        self.end = float(end)
        self.steps = float(steps)
        self.initial = float(initial)
        # When True, ``time`` is floored before use, so the ramp only changes
        # at whole-number times.
        self.quant = quant

    def getVal(self, time=0, **kargs):
        """Return the ramp value at ``time``; other keyword arguments are ignored."""
        if self.quant:
            time = math.floor(time)
        if self.steps == 0:
            # Zero-length ramp: jump from start to end just after ``initial``.
            # This is the limit of the ramp as steps -> 0+, and it avoids the
            # ZeroDivisionError the division below would otherwise raise.
            return self.end if time > self.initial else self.start
        progress = max(0, min(1, float(time - self.initial) / self.steps))
        return (self.end - self.start) * progress + self.start

    def __str__(self):
        # The template uses %s placeholders, so it must be %-formatted;
        # str.format would return the template unchanged.
        return "Lin(%s,%s,%s,%s, quant=%s)" % (
            str(self.start),
            str(self.end),
            str(self.steps),
            str(self.initial),
            str(self.quant),
        )
|
|
class Until(Const):
    """Piecewise schedule: ``a`` while ``time < thresh``, then ``b``.

    After the switch, ``b`` is evaluated on a clock restarted at zero, i.e. at
    ``time - thresh``.
    """

    def __init__(self, thresh, a, b):
        # Plain numbers are wrapped so both phases expose ``getVal``.
        self.a = Const.initConst(a)
        self.b = Const.initConst(b)
        self.thresh = thresh

    def getVal(self, *args, time=0, **kargs):
        if time < self.thresh:
            return self.a.getVal(*args, time=time, **kargs)
        shifted = time - self.thresh
        return self.b.getVal(*args, time=shifted, **kargs)

    def __str__(self):
        parts = (str(self.thresh), str(self.a), str(self.b))
        return "Until(%s, %s, %s)" % parts
|
|
class Scale(Const):
    """Map the value ``c`` of an inner schedule to the ratio ``c / (1 - c)``.

    ``c`` must lie in ``[0, 1)``. Exactly zero short-circuits to the int ``0``.
    """

    def __init__(self, c):
        self.c = Const.initConst(c)

    def getVal(self, *args, **kargs):
        ratio = self.c.getVal(*args, **kargs)
        if ratio != 0:
            # Outside [0, 1) the ratio would be negative or divide by zero.
            assert 0 <= ratio < 1
            return ratio / (1 - ratio)
        return 0

    def __str__(self):
        return "Scale(%s)" % (str(self.c),)
|
|
def MixLin(*args, **kargs):
    """Build ``Scale(Lin(*args, **kargs))``: a linear ramp passed through ``Scale``.

    All arguments are forwarded unchanged to ``Lin``.
    """
    ramp = Lin(*args, **kargs)
    return Scale(ramp)
|
|
class Normal(Const):
    """Random schedule: ``|N(0, 1)|`` samples scaled by the inner schedule ``c``."""

    def __init__(self, c):
        self.c = Const.initConst(c)

    def getVal(self, *args, shape=None, **kargs):
        """Sample a tensor of ``shape`` from ``|N(0, 1)|`` and multiply it by ``c``.

        Args:
            shape: sample shape passed to ``torch.randn``; defaults to ``[1]``.
                It is also forwarded to the inner schedule's ``getVal``.
        """
        # Build a fresh default per call instead of sharing one mutable list
        # default across every call.
        if shape is None:
            shape = [1]
        c = self.c.getVal(*args, shape=shape, **kargs)
        # NOTE(review): h.device presumably names the project's configured
        # torch device; confirm in helpers.
        return torch.randn(shape, device=h.device).abs() * c

    def __str__(self):
        return "Normal(%s)" % str(self.c)
|
|
class Clip(Const):
    """Clamp the value of schedule ``c`` to ``[l, u]`` (bounds are schedules too)."""

    def __init__(self, c, l, u):
        self.c = Const.initConst(c)
        self.l = Const.initConst(l)
        self.u = Const.initConst(u)

    def getVal(self, *args, **kargs):
        """Evaluate ``c``, ``l`` and ``u`` with the same arguments and clamp ``c``.

        Tensors are clamped element-wise with ``Tensor.clamp``. Any other
        scalar (float or int) is clamped with ``min``/``max``.
        """
        value = self.c.getVal(*args, **kargs)
        lo = self.l.getVal(*args, **kargs)
        hi = self.u.getVal(*args, **kargs)
        if isinstance(value, torch.Tensor):
            return value.clamp(lo, hi)
        # Testing for Tensor rather than float lets ints (e.g. the literal 0
        # that Scale.getVal returns) take the scalar path instead of failing
        # on a missing ``.clamp`` attribute.
        return min(max(value, lo), hi)

    def __str__(self):
        return "Clip(%s, %s, %s)" % (str(self.c), str(self.l), str(self.u))
|
|
class Fun(Const):
    """Schedule whose value is computed by an arbitrary callable."""

    def __init__(self, foo):
        # The callable receives every argument that getVal receives.
        self.foo = foo

    def getVal(self, *args, **kargs):
        fn = self.foo
        return fn(*args, **kargs)

    def __str__(self):
        return "Fun(...)"
|
|
class Complement(Const):
    """Return ``1 - c`` for an inner schedule value ``c`` that must lie in ``[0, 1]``."""

    def __init__(self, c):
        self.c = Const.initConst(c)

    def getVal(self, *args, **kargs):
        value = self.c.getVal(*args, **kargs)
        assert 0 <= value <= 1
        return 1 - value

    def __str__(self):
        inner = str(self.c)
        return "Complement(%s)" % inner
|
|