"""Numerical test of the megablocks grouped GEMM op against an eager PyTorch reference."""

import torch
import megablocks


def randn(bs, x, y):
    # Small, zero-centered random values scaled by 1/(x * y), as bfloat16 on the GPU.
    out = (torch.rand(bs, x, y) - 0.5) * 2 / (y * x)
    return out.cuda().to(torch.bfloat16)


def gmm(a, b, batch_sizes, trans_b=False):
    # Eager reference for the grouped matmul: slice `a` into groups of rows given
    # by `batch_sizes`, multiply each group by its own matrix in `b`, and
    # concatenate the per-group results.
    batch_sizes = batch_sizes.cpu().numpy()

    out = []
    start = 0
    for i, size in enumerate(batch_sizes):
        rhs = b[i, :, :].t() if trans_b else b[i, :, :]
        out.append(a[start : start + size, :] @ rhs)
        start += size
    return torch.cat(out)
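
# A sketch of the indexing with hypothetical sizes (not the values used in
# test_gmm below): for batch_sizes = [2, 3], `a` has 5 rows and `b` holds two
# matrices; rows 0-1 multiply b[0] and rows 2-4 multiply b[1].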


def test_gmm():
    # Problem shape: z groups of m rows each, so `a` is (z * m, k) and `b`
    # stacks z matrices of shape (k, n) (or (n, k) when trans_b is set, in
    # which case each matrix is transposed inside the matmul).
    z = 1
    m = 128
    n = 128
    k = 128
    trans_b = False
    batch_sizes_on_device = False

    torch.manual_seed(0)
    a = randn(z, m, k).view(-1, k)
    b = randn(z, n, k) if trans_b else randn(z, k, n)
    batch_sizes = torch.tensor([m] * z)
    if batch_sizes_on_device:
        batch_sizes = batch_sizes.cuda()

    # Independent leaf copies of the same data, so the kernel's gradients can
    # be compared against the reference implementation's.
    a.requires_grad_(True)
    b.requires_grad_(True)
    a_ref = a.detach().clone().requires_grad_(True)
    b_ref = b.detach().clone().requires_grad_(True)

    # Forward: the fused grouped-GEMM kernel vs. the eager reference.
    out = megablocks.gg_ops.gmm(a, b, batch_sizes, trans_b)
    print("out", out)

    expected_out = gmm(a_ref, b_ref, batch_sizes, trans_b)

    assert torch.allclose(out, expected_out, atol=1e-3), f"Expected {expected_out}, got {out}"

    # Backward: sum() gives an all-ones gradient on the output of both paths.
    out.sum().backward()
    expected_out.sum().backward()
    assert torch.allclose(a.grad, a_ref.grad, atol=1e-3), f"Expected {a_ref.grad}, got {a.grad}"
    assert torch.allclose(b.grad, b_ref.grad, atol=1e-3), f"Expected {b_ref.grad}, got {b.grad}"
    print("Test passed successfully!")