| | |
| | |
| | |
| |
|
| | import gc |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| | try: |
| | import conv1d_cpp |
| | except: |
| | pass |
| | from .utils import column_split |
| |
|
# Names accepted for the `iir_prefill_style` argument of `HyenaInferenceEngine`;
# the constructor asserts that the requested style is a member of this list.
# NOTE(review): several of the corresponding `prefill_via_*` methods are still
# stubs that raise NotImplementedError.
IIR_PREFILL_MODES = [
    "recurrence",
    "modal-fft",
    "hybrid-modal-recurrence",
    "modal-scan",
    "canonical-fft",
    "iir-fir-caching",
]
| |
|
| |
|
def canonicalize_modal_system(poles, residues):
    """Bring a modal system (poles and residues) into canonical form.

    This is a placeholder: the transformation has not been implemented yet,
    so every call fails.

    Args:
        poles (Tensor): Poles of the modal system.
        residues (Tensor): Residues of the modal system.

    Returns:
        Tuple[Tensor, Tensor]: Canonicalized ``(poles, residues)`` once implemented.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()
| |
|
| |
|
def list_tensors(idx):
    """Debug helper: report every live tensor tracked by the garbage collector.

    For each tensor found, its type and size are printed to stdout and a line
    ``"<type> <size> <first element>"`` is appended to ``tensors_{idx}.txt`` in
    the current working directory.

    Args:
        idx: Tag interpolated into the output file name.
    """
    report = []
    for obj in gc.get_objects():
        # Deliberately best-effort: arbitrary objects returned by the GC may
        # raise on inspection (or on formatting), and one bad object must not
        # abort the whole dump.
        try:
            if not isinstance(obj, torch.Tensor):
                continue
            print(type(obj), obj.size())
            # 0-dim and empty tensors cannot be indexed with [0]; report the
            # tensor itself so they still show up in the file.
            if obj.dim() > 0 and obj.shape[0] > 0:
                el = obj[0]
            else:
                el = obj
            report.append(f"{type(obj)} {obj.size()} {el}\n")
        except Exception:
            continue

    # Open the file once, and only when there is something to record.
    if report:
        with open(f"tensors_{idx}.txt", "a") as f:
            f.writelines(report)
| |
|
| |
|
class HyenaInferenceEngine:
    """Prefill and single-step kernels for the FIR and IIR filters of a Hyena block.

    The ``parallel_*`` methods process a whole sequence at once (prefill), the
    ``step_*`` methods advance the filters by one token, and the
    ``prefill_via_*`` methods compute the IIR state that decoding resumes from
    and store it in ``inference_params.state_dict``.

    Attributes:
        fir_fn: Callable kept for the short FIR filter (not used by the methods
            in this class; each call receives its own ``fir_fn``).
        iir_prefill_style: One of ``IIR_PREFILL_MODES``.
        layer_idx: Key under which some prefill methods store the IIR state.
        low_mem_mode: When True, ``parallel_iir`` drops its local references
            and empties the CUDA cache after prefill.
    """

    def __init__(
        self,
        fir_fn=None,
        iir_prefill_style="modal-fft",
        layer_idx=None,
    ) -> None:
        """Store the configuration; `iir_prefill_style` must be in `IIR_PREFILL_MODES`."""
        self.fir_fn = fir_fn
        assert iir_prefill_style in IIR_PREFILL_MODES, f"iir_prefill_style must be one of {IIR_PREFILL_MODES}"
        self.iir_prefill_style = iir_prefill_style
        self.layer_idx = layer_idx
        self.low_mem_mode = False

    def parallel_fir(
        self,
        fir_fn,
        u,
        weight,
        bias,
        L,
        fir_length=3,
        inference_params=None,
        prefill_mode=None,
        padding_mask=None,
    ):
        """Apply the short FIR filter to a whole sequence.

        Args:
            fir_fn: Either ``torch.nn.functional.conv1d`` (depthwise convolution
                with ``weight``) or another callable applied directly to ``u``.
            u: Input sequence; assumed ``(batch, seq, channels)`` since it is
                permuted to channels-first for ``conv1d`` -- TODO confirm.
            weight: Depthwise filter taps, only used with ``conv1d``.
            bias: Per-channel bias added after filtering.
            L: Number of output positions to keep.
            fir_length: Number of filter taps.
            inference_params: When not None, the FIR state is also returned.
            prefill_mode: Unused; kept for interface compatibility.
            padding_mask: Optional tensor multiplied into the output
                (broadcast over the channel axis).

        Returns:
            Tuple ``(z_pre, fir_state)``: the channels-first filtered sequence
            and the last ``fir_length - 1`` inputs (channels-first), or None
            for the state when ``inference_params`` is None.
        """
        custom_fir = fir_fn != torch.nn.functional.conv1d

        if custom_fir:
            z_pre = fir_fn(u)[:, :L]
            z_pre = z_pre.permute(0, 2, 1)
        else:
            u = u.permute(0, 2, 1)
            # Padding by fir_length - 1 and keeping the first L outputs makes
            # the depthwise convolution causal.
            z_pre = fir_fn(
                u,
                weight,
                bias=None,
                stride=1,
                padding=fir_length - 1,
                groups=u.shape[1],
            )[..., :L]

        z_pre = z_pre + bias[None, :, None]

        # isinstance (rather than an exact type comparison) so that Tensor
        # subclasses are honoured as masks too.
        if isinstance(padding_mask, torch.Tensor):
            z_pre = z_pre * padding_mask[:, None]

        fir_state = None
        if inference_params is not None:
            # Keep the trailing inputs needed to continue the filter in step_fir.
            if custom_fir:
                fir_state = u[:, -fir_length + 1 :].permute(0, 2, 1)
            else:
                fir_state = u[..., -fir_length + 1 :]

        return z_pre, fir_state

    def parallel_iir(
        self,
        z_pre,
        h,
        D,
        L,
        poles,
        residues,
        t,
        dims,
        layer_idx,
        inference_params=None,
        prefill_style="fft",
        fftconv_fn=None,
        padding_mask=None,
        use_flashfft=False,
        column_split_hyena=False,
        long_fir_threshold=None,
    ):
        """Apply the long (IIR) filter and gating to a whole sequence.

        ``z_pre`` is split along the channel axis into ``x2, x1, v``; the
        product ``x1 * v`` is filtered with ``h`` and the result is
        ``(filtered + D * x1 * v) * x2``.

        Args:
            z_pre: Channels-first output of ``parallel_fir``.
            h: Impulse response of the long filter.
            D: Per-channel skip coefficient.
            L: Sequence length.
            poles, residues: Modal parameters, used by the prefill methods.
            t: Time indices forwarded to ``prefill_via_modal_fft``.
            dims: Tuple ``(hidden_size, num_attention_heads,
                hidden_size_per_attention_head, state_size, hyena_filter_groups)``.
            layer_idx: Key forwarded to ``prefill_via_modal_fft``.
            inference_params: When not None, the IIR state is prefilled into
                ``inference_params.state_dict``.
            prefill_style: ``"fft"`` or ``"recurrence"``.
            fftconv_fn: Fused FFT convolution, used when ``use_flashfft`` is
                True and ``L`` is even.
            padding_mask: Unused here; kept for interface compatibility.
            use_flashfft: Whether to use ``fftconv_fn``.
            column_split_hyena: Whether ``x2, x1, v`` are interleaved per
                attention head instead of stored as three contiguous chunks.
            long_fir_threshold: When not None, ``h`` is truncated to this many
                taps and applied with a depthwise ``conv1d``.

        Returns:
            The gated output permuted to ``(dim0, dim2, dim1)`` of the
            channels-first result.

        Raises:
            NotImplementedError: If ``inference_params`` is given with an
                unsupported ``prefill_style``.
        """
        fft_size = 2 * L
        hidden_size, num_attention_heads, hidden_size_per_attention_head, _, _ = dims
        # Input spectrum; only the plain FFT branch produces it. Defined up
        # front so that every other branch hands a well-defined None to the
        # state prefill instead of an unbound name.
        X_s = None

        if column_split_hyena:
            z = z_pre.reshape(
                z_pre.shape[0],
                num_attention_heads,
                3 * hidden_size_per_attention_head,
                z_pre.shape[2],
            )
            x2, x1, v = (
                z[:, :, :hidden_size_per_attention_head],
                z[
                    :,
                    :,
                    hidden_size_per_attention_head : 2 * hidden_size_per_attention_head,
                ],
                z[:, :, 2 * hidden_size_per_attention_head :],
            )
            x2, x1, v = (
                x2.reshape(x2.shape[0], -1, x2.shape[-1]),
                x1.reshape(x1.shape[0], -1, x1.shape[-1]),
                v.reshape(v.shape[0], -1, v.shape[-1]),
            )
        else:
            x2, x1, v = z_pre.split([hidden_size, hidden_size, hidden_size], dim=1)

        x1v = x1 * v

        if inference_params is not None and prefill_style == "recurrence":
            # The recurrence yields the filter output and stores the state.
            y = self.prefill_via_direct_recurrence(
                inference_params=inference_params,
                x1v=x1v,
                L=L,
                poles=poles,
                residues=residues,
            )

        else:
            if use_flashfft and (L % 2) == 0:
                y = fftconv_fn(
                    x1v.to(dtype=torch.bfloat16).contiguous(),
                    h.to(dtype=torch.float32),
                )

            elif long_fir_threshold is None:
                # Zero-padded FFT convolution; norm="forward" on the inverse
                # together with the explicit 1 / fft_size on H gives the
                # unscaled linear convolution.
                H = torch.fft.rfft(h.to(dtype=torch.float32), n=fft_size) / fft_size
                X_s = torch.fft.fft(x1v.to(dtype=torch.float32), n=fft_size)
                X = X_s[..., : H.shape[-1]]
                if len(z_pre.shape) > 3:
                    H = H.unsqueeze(1)
                y = torch.fft.irfft(X * H, n=fft_size, norm="forward")[..., :L]

            else:
                assert h.shape[0] == 1, "batch size must be 1 for long_fir_threshold"
                h = h[0][:, None]  # one single-channel kernel per group for depthwise conv1d
                h = h[..., :long_fir_threshold]
                # NOTE(review): conv1d is a cross-correlation, so the LAST tap
                # multiplies the current sample, whereas the FFT branch treats
                # h[0] as the current-sample tap. Unless the caller passes a
                # time-reversed h, this looks like it needs h.flip(-1) --
                # confirm against the caller.
                y = F.conv1d(
                    x1v,
                    h.to(dtype=x1v.dtype),
                    stride=1,
                    groups=x1v.shape[1],
                    padding=h.shape[-1] - 1,
                )[..., :L]

        y = y.to(dtype=x1v.dtype)
        y = (y + x1v * D.unsqueeze(-1)) * x2

        if inference_params is not None:
            if prefill_style == "fft":
                self.prefill_via_modal_fft(
                    inference_params=inference_params,
                    x1v=x1v,
                    X_s=X_s,
                    L=L,
                    t=t,
                    poles=poles,
                    dims=dims,
                    layer_idx=layer_idx,
                    use_flashfft=use_flashfft,
                    fftconv_fn=fftconv_fn,
                )

            elif prefill_style == "recurrence":
                # The state was already stored by prefill_via_direct_recurrence.
                pass
            else:
                raise NotImplementedError
            if self.low_mem_mode:
                # Drop local references so the cached blocks can be released.
                del z_pre, x2, x1, v, x1v, h, poles, residues
                torch.cuda.empty_cache()

        return y.permute(0, 2, 1)

    def step_fir(self, u, fir_state, weight, bias=None):
        """Advance the short FIR filter by one step.

        Note:
            `fir_state` contains the last `short_filter_length - 1` elements of `u`: `u_(L-2), u_(L-1), ...`
            We assume dimensions of `short_filter_weight` to be `[d, 1, short_filter_len]` (SISO / multi SISO layout).

        Args:
            u: Current input.
            fir_state: Previous inputs, oldest first along the last axis.
            weight: Filter taps; the last tap multiplies the current input.
            bias: Optional bias; skipped when None.

        Returns:
            Tuple ``(y, fir_state)`` with the output and the shifted state,
            whose last slot now holds ``u``.
        """
        h0, h = weight[..., 0, -1], weight[..., 0, :-1]
        h0, h = h0[None], h[None]
        y = h0 * u + torch.sum(fir_state * h, dim=-1)
        # `bias` defaults to None, so it must not be added unconditionally.
        if bias is not None:
            y = y + bias

        # torch.roll returns a new tensor, so the caller's state is untouched.
        fir_state = torch.roll(fir_state, -1, dims=2)
        fir_state[..., -1] = u
        return y, fir_state

    def step_iir(self, x2, x1, v, D, residues, poles, iir_state, iir_groups=1):
        """Advance the IIR filter (modal form) by one step.

        Computes ``state <- poles * state + x1 * v`` and returns
        ``x2 * (Re(sum(residues * state)) + D * x1 * v)``.

        Args:
            x2, x1, v: Gating and value projections for the current token.
            D: Skip coefficient.
            residues, poles: Real tensors whose trailing axis of size 2 holds
                (real, imag); a singleton axis before it is indexed away.
            iir_state: Complex state from the previous step.
            iir_groups: Only 1 is supported.

        Returns:
            Tuple ``(y, iir_state)``.

        Raises:
            NotImplementedError: If ``iir_groups > 1``.
        """
        # Fail before doing any work for the unsupported grouped case.
        if iir_groups > 1:
            raise NotImplementedError

        x1v = x1 * v

        residues, poles = (
            torch.view_as_complex(residues.to(torch.float32)),
            torch.view_as_complex(poles.to(torch.float32)),
        )
        # Drop the singleton axis and add a leading axis to broadcast over batch.
        residues, poles = residues[..., 0][None], poles[..., 0][None]
        iir_state = poles * iir_state + x1v[..., None]

        res_state = torch.sum(residues * iir_state, dim=-1).real

        y = x2 * (res_state + D * x1v)

        return y, iir_state

    def prefill_via_fir_caching(self, u, inference_params, L, *args, **kwargs):
        """Turns the IIR filter into a FIR and uses a cache for decoding. Not implemented."""
        raise NotImplementedError(":)")

    def prefill_via_direct_recurrence(
        self, inference_params, x1v, L, residues, poles, *args, **kwargs
    ) -> torch.Tensor:
        """
        Compute the IIR output and state via explicit SSM recurrence (modal form).

        For every step ``i < L`` this evaluates, in complex arithmetic,
        ``state <- poles * state + x1v[..., i]`` and
        ``output[..., i] = Re(sum_n residues_n * state_n)``, i.e. exactly what
        `step_iir` computes token by token. The final state is stored in
        ``inference_params.state_dict[self.layer_idx]``.

        Args:
            inference_params: Object exposing a ``state_dict`` mapping.
            x1v: Input with the sequence on the last axis (indexed as
                ``[batch, channel, time]``).
            L: Number of steps to run.
            residues, poles: Real tensors indexed as ``[channel, state, 0, re/im]``.

        Returns:
            Tensor shaped and typed like ``x1v`` holding the filter output.

        Note:
            dtypes: [state: complex64, poles: float32, output: same as x1v]
        """
        # (1, channels, state_dim) complex views of the modal parameters.
        poles_c = torch.view_as_complex(poles[:, :, 0].to(torch.float32).contiguous())[None]
        residues_c = torch.view_as_complex(residues[:, :, 0].to(torch.float32).contiguous())[None]

        state_dim = poles.shape[1]
        state = torch.zeros(
            x1v.shape[0],
            x1v.shape[1],
            state_dim,
            dtype=torch.complex64,
            device=x1v.device,
        )
        output = torch.zeros_like(x1v)
        x1v_f32 = x1v.to(torch.float32)

        for i in range(L):
            # Out-of-place complex update: both components of the new state
            # are derived from the previous state, never from a half-updated one.
            state = poles_c * state + x1v_f32[:, :, i, None]
            output[:, :, i] = torch.sum(residues_c * state, dim=-1).real

        inference_params.state_dict[self.layer_idx] = state

        return output

    def prefill_via_hybrid_recurrence(self, inference_params, u, log_poles, x1v_f_a, L, *args, **kwargs):
        """
        Compute the IIR state via hybrid recurrence-convolution over blocks. Not implemented.
        """
        raise NotImplementedError(":)")

    def prefill_via_scan(self, u, inference_params=None, *args, **kwargs):
        """Compute the IIR state via a scan. Not implemented."""
        raise NotImplementedError

    def prefill_via_canonical_fft(self, u, inference_params=None, *args, **kwargs):
        """
        Compute the IIR state via a single FFT with the denominator of the SSM in companion form.

        This is the most memory efficient "parallelized" prefilling method for Hyena.
        Not implemented.

        From: https://arxiv.org/abs/2310.18780
        """
        raise NotImplementedError(":)")

    def prefill_via_modal_fft(
        self,
        inference_params,
        x1v,
        L,
        poles,
        t,
        dims,
        layer_idx,
        X_s=None,
        use_flashfft=False,
        fftconv_fn=None,
        state_dtype=torch.complex64,
        *args,
        **kwargs,
    ):
        """
        Compute the IIR state via a single FFT, using the poles of the SSM in modal form.

        The state after ``L`` inputs is the convolution of ``x1v`` with
        ``poles ** t`` evaluated at position ``L - 1``; it is written to
        ``inference_params.state_dict``.

        Args:
            inference_params: Object exposing a ``state_dict`` mapping.
            x1v: Filter input with the sequence on the last axis.
            L: Sequence length.
            poles: Poles as a real tensor with a trailing (real, imag) axis.
            t: Time indices used as exponents of the poles.
            dims: Tuple ``(hidden_size, _, _, state_size, hyena_filter_groups)``.
            layer_idx: Key for the stored state on the plain FFT path.
            X_s: Optional precomputed ``fft(x1v, n=2 * L)``; computed here when
                omitted (plain FFT path only).
            use_flashfft: Use ``fftconv_fn`` instead of ``torch.fft``.
            fftconv_fn: Fused FFT convolution callable.
            state_dtype: Dtype of the stored state on the plain FFT path.

        Raises:
            NotImplementedError: If ``use_flashfft`` and ``hyena_filter_groups > 1``.
        """
        hidden_size, _, _, state_size, hyena_filter_groups = dims

        if use_flashfft:
            # Treat real and imaginary parts as separate real channels.
            poles = poles.squeeze().reshape(poles.shape[0], -1)[..., None]

            state_s = poles**t
            if hyena_filter_groups > 1:
                raise NotImplementedError

            x1v = x1v[:, :, None].repeat(1, 1, 2 * state_size, 1)
            x1v = x1v.reshape(x1v.shape[0], -1, x1v.shape[-1])
            state_s = state_s[None]

            state = fftconv_fn(
                x1v.contiguous(),
                state_s.to(dtype=torch.float32),
            )
            state = state[..., L - 1].reshape(x1v.shape[0], hidden_size, state_size, 2)
            state = torch.view_as_complex(state.contiguous().to(dtype=torch.float32))
            # NOTE(review): this branch stores under self.layer_idx while the
            # branch below uses the `layer_idx` argument -- confirm that the
            # two always agree.
            inference_params.state_dict[self.layer_idx] = state
        else:
            fft_size = 2 * L
            if X_s is None:
                # The caller had no spectrum to share (e.g. it used the fused
                # or truncated-FIR convolution path); compute it here.
                X_s = torch.fft.fft(x1v.to(dtype=torch.float32), n=fft_size)
            bs = x1v.shape[0]
            poles = torch.view_as_complex(poles.to(torch.float32))
            state_s = poles**t
            state_S = torch.fft.fft(state_s, n=fft_size).repeat(bs, 1, 1, 1)
            if hyena_filter_groups > 1:
                state_S = state_S.repeat_interleave(hidden_size // hyena_filter_groups, 1)
            state = torch.fft.ifft(X_s[..., None, :] * state_S, n=fft_size)
            inference_params.state_dict[layer_idx] = state[..., L - 1].to(dtype=state_dtype)

    def _compute_state(self, log_poles, u, t, L, *args, **kwargs):
        """
        Compute the IIR state given an input `u` and log_poles of the modal system.

        Args:
            log_poles: Complex logarithm of the poles.
            u: Real input with the sequence on the last axis.
            t: Time indices multiplied with ``log_poles``.
            L: Sequence length.

        Returns:
            Complex tensor with the state at each of the first ``L`` positions
            on the last axis.
        """
        bs = u.shape[0]
        fft_size = 2 * L
        # A full FFT is required: the pole sequence is complex, so its
        # spectrum has fft_size bins and cannot be multiplied with the
        # half-spectrum an rfft would produce.
        U = torch.fft.fft(u.to(torch.float32), n=fft_size)
        x = (log_poles * t).exp()
        X = torch.fft.fft(x, n=fft_size).repeat(bs, 1, 1, 1)
        state = torch.fft.ifft(U[..., None, :] * X, n=fft_size)[..., :L]
        return state
| |
|