| | import torch |
| | from safetensors.torch import load_file |
| |
|
def load_model(path='model.safetensors'):
    """Load the encoder's parameter tensors from a safetensors file.

    Args:
        path: location of the ``.safetensors`` file; defaults to
            ``'model.safetensors'`` in the current working directory.

    Returns:
        Whatever ``safetensors.torch.load_file`` returns for *path*: a mapping
        from tensor name to tensor. ``priority_encode`` consumes it directly.
    """
    state_dict = load_file(path)
    return state_dict
| |
|
def priority_encode(i7, i6, i5, i4, i3, i2, i1, i0, weights):
    """Run the two-layer threshold network for the 8-to-3 priority encoder.

    Args:
        i7, i6, i5, i4, i3, i2, i1, i0: the eight inputs, most significant
            first. Each is passed through ``float()``; presumably they are 0/1
            bits, but this is not validated here.
        weights: mapping from parameter name to tensor (e.g. the result of
            ``load_model()``). Required keys are ``layer1.h0`` .. ``layer1.h7``
            and ``layer2.y2``, ``layer2.y1``, ``layer2.y0``, ``layer2.v``, each
            with a ``.weight`` and a ``.bias`` entry.

    Returns:
        Tuple ``(y2, y1, y0, valid)`` of ints, each 0 or 1.
    """
    def fire(x, name):
        # Step activation: 1 when the pre-activation is >= 0, otherwise 0.
        pre = x @ weights[f'{name}.weight'].T + weights[f'{name}.bias']
        return int((pre >= 0).item())

    x = torch.tensor([float(b) for b in (i7, i6, i5, i4, i3, i2, i1, i0)])

    # Layer 1: eight hidden threshold units, evaluated h0 first.
    hidden = [fire(x, f'layer1.h{k}') for k in range(8)]
    h = torch.tensor([float(b) for b in hidden])

    # Layer 2: the three code bits followed by the valid flag.
    return tuple(fire(h, f'layer2.{out}') for out in ('y2', 'y1', 'y0', 'v'))
| |
|
if __name__ == '__main__':
    weights = load_model()
    print('Priority Encoder 8 (selected tests)')
    # Zero, every single-bit input, and all-ones.
    samples = [0, *(1 << b for b in range(8)), 255]
    for val in samples:
        bits = [int(c) for c in f'{val:08b}']  # MSB first, i.e. i7 .. i0
        y2, y1, y0, v = priority_encode(*bits, weights)
        idx = (y2 << 2) | (y1 << 1) | y0
        print(f' {val:3d} ({val:08b}) -> y={idx} ({y2}{y1}{y0}) v={v}')
| |
|