NATTEN error running Python 3.12, CUDA 12.8 and PyTorch 2.8 on an A100
I am receiving this error even though I have NATTEN installed, and I have also tried reinstalling it.
NotImplementedError Traceback (most recent call last)
Cell In[39], line 10
7 n_steps = 2
8 for step_idx in range(n_steps):
9 # Run one prognostic step with the GOES model
---> 10 y_pred, y_pred_coords = model(y, y_coords)
12 # Run one prognostic step with the MRMS model conditioned on GOES
13 y_mrms_pred, y_coords_mrms_pred = model_mrms.call_with_conditioning(
14 y_mrms, y_coords_mrms, conditioning=y, conditioning_coords=y_coords
15 )
File /usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py:120, in context_decorator..decorate_context(*args, **kwargs)
117 @functools.wraps(func)
118 def decorate_context(*args, **kwargs):
119 with ctx_factory():
--> 120 return func(*args, **kwargs)
File /usr/local/lib/python3.12/dist-packages/earth2studio/models/batch.py:220, in batch_func._batch_wrap.._wrapper(model, *args, **kwargs)
215 raise ValueError(
216 "Mismatched batched dimensions across input (x, CoordSystem) pairs"
217 )
219 # Model forward
--> 220 out, out_coords = func(model, *new_args, **kwargs)
221 out, out_coords = self._decompress_batch(
222 out, out_coords, batched_coords, batched_shape
223 )
224 return out, out_coords
File /usr/local/lib/python3.12/dist-packages/earth2studio/models/px/stormscope.py:857, in StormScopeBase.call(self, x, coords)
853 conditioning_coords = None
855 output_coords = self.output_coords(x_coords)
--> 857 x = self._forward(
858 x,
859 x_coords,
860 conditioning=conditioning,
861 conditioning_coords=conditioning_coords,
862 )
864 return x, output_coords
File /usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py:120, in context_decorator..decorate_context(*args, **kwargs)
117 @functools.wraps(func)
118 def decorate_context(*args, **kwargs):
119 with ctx_factory():
--> 120 return func(*args, **kwargs)
File /usr/local/lib/python3.12/dist-packages/earth2studio/models/px/stormscope.py:619, in StormScopeBase._forward(self, x, coords, conditioning, conditioning_coords)
614 latents = torch.randn(
615 b * t, x.shape[3:], device=x.device, dtype=x.dtype
616 ) # shape [BT, C, H, W]
618 # Run diffusion sampler
--> 619 out = self._edm_sampler(
620 latents=latents,
621 condition=condition,
622 sigma_min=self.start_sigma,
623 sigma_max=self.end_sigma,
624 **self.sampler_args,
625 ).to(output_dtype)
627 out = out.reshape(b, t, len(self.output_times), *out.shape[1:])
629 out = torch.where(self.valid_mask, out, torch.nan)
File /usr/local/lib/python3.12/dist-packages/earth2studio/models/px/stormscope.py:683, in StormScopeBase._edm_sampler(self, latents, condition, class_labels, randn_like, num_steps, sigma_max, sigma_min, rho, S_churn, S_min, S_max, S_noise, progress_bar)
680 x_hat = x_cur + (t_hat2 - t_cur2).sqrt() * S_noise * randn_like(x_cur)
682 # Euler step.
--> 683 denoised = active_net(
684 x_hat, t_hat, class_labels=class_labels, condition=condition
685 ).to(self._SAMPLER_DTYPE)
686 d_cur = (x_hat - denoised) / t_hat
687 x_next = x_hat + (t_next - t_hat) * d_cur
File /usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py:1773, in Module._wrapped_call_impl(self, *args, **kwargs)
1771 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1772 else:
-> 1773 return self._call_impl(*args, **kwargs)
File /usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py:1784, in Module._call_impl(self, *args, **kwargs)
1779 # If we don't have any hooks, we want to skip the rest of the logic in
1780 # this function, and just call forward.
1781 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1782 or _global_backward_pre_hooks or _global_backward_hooks
1783 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1784 return forward_call(*args, **kwargs)
1786 result = None
1787 called_always_called_hooks = set()
File /usr/local/lib/python3.12/dist-packages/earth2studio/models/nn/stormscope_util.py:172, in EDMPrecond.forward(self, x, sigma, condition, class_labels, return_logvar, force_fp32, training, **model_kwargs)
170 arg = torch.cat([arg, condition], dim=1)
171 # now we have added the p_dropout probability to the model.
--> 172 F_x = self.model(
173 (arg).to(dtype),
174 c_noise.flatten(),
175 p_dropout=p_dropout,
176 training=training,
177 **model_kwargs,
178 )
180 D_x = c_skip * x + c_out * F_x.to(torch.float32)
182 if return_logvar:
File /usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py:1773, in Module._wrapped_call_impl(self, *args, **kwargs)
1771 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1772 else:
-> 1773 return self._call_impl(*args, **kwargs)
File /usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py:1784, in Module._call_impl(self, *args, **kwargs)
1779 # If we don't have any hooks, we want to skip the rest of the logic in
1780 # this function, and just call forward.
1781 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1782 or _global_backward_pre_hooks or _global_backward_hooks
1783 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1784 return forward_call(*args, **kwargs)
1786 result = None
1787 called_always_called_hooks = set()
File /usr/local/lib/python3.12/dist-packages/earth2studio/models/nn/stormscope_util.py:242, in DropInDiT.forward(self, x, time_step_cond, label_cond, points, p_dropout, training)
239 attn_kwargs: dict[str, Any] | None = {"latent_hw": latent_hw}
241 # Note: points / cross-attention are not supported in PhysicsNeMo DiT so we ignore them
--> 242 out = self.pnm(
243 x=x,
244 t=time_step_cond,
245 condition=condition,
246 p_dropout=p_dropout,
247 attn_kwargs=attn_kwargs,
248 )
249 return out
File /usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py:1773, in Module._wrapped_call_impl(self, *args, **kwargs)
1771 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1772 else:
-> 1773 return self._call_impl(*args, **kwargs)
File /usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py:1784, in Module._call_impl(self, *args, **kwargs)
1779 # If we don't have any hooks, we want to skip the rest of the logic in
1780 # this function, and just call forward.
1781 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1782 or _global_backward_pre_hooks or _global_backward_hooks
1783 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1784 return forward_call(*args, **kwargs)
1786 result = None
1787 called_always_called_hooks = set()
File /usr/local/lib/python3.12/dist-packages/physicsnemo/experimental/models/dit/dit.py:321, in DiT.forward(self, x, t, condition, p_dropout, attn_kwargs)
318 c = t # (B, D)
320 for block in self.blocks:
--> 321 x = block(x, c, p_dropout=p_dropout, attn_kwargs=attn_kwargs) # (B, L, D)
323 # De-tokenize: (B, L, D) -> (B, C, H, W)
324 if self.force_tokenization_fp32:
File /usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py:1773, in Module._wrapped_call_impl(self, *args, **kwargs)
1771 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1772 else:
-> 1773 return self._call_impl(*args, **kwargs)
File /usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py:1784, in Module._call_impl(self, *args, **kwargs)
1779 # If we don't have any hooks, we want to skip the rest of the logic in
1780 # this function, and just call forward.
1781 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1782 or _global_backward_pre_hooks or _global_backward_hooks
1783 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1784 return forward_call(*args, **kwargs)
1786 result = None
1787 called_always_called_hooks = set()
File /usr/local/lib/python3.12/dist-packages/physicsnemo/experimental/models/dit/layers.py:591, in DiTBlock.forward(self, x, c, attn_kwargs, p_dropout)
588 elif p_dropout is not None:
589 raise ValueError("p_dropout passed to DiTBlock but intermediate_dropout is disabled")
--> 591 attention_output = self.attention(
592 modulated_attn_input,
593 **(attn_kwargs or {}),
594 )
595 x = x + attention_gate.unsqueeze(1) * attention_output
597 # Feed-forward block
File /usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py:1773, in Module._wrapped_call_impl(self, *args, **kwargs)
1771 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1772 else:
-> 1773 return self._call_impl(*args, **kwargs)
File /usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py:1784, in Module._call_impl(self, *args, **kwargs)
1779 # If we don't have any hooks, we want to skip the rest of the logic in
1780 # this function, and just call forward.
1781 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1782 or _global_backward_pre_hooks or _global_backward_hooks
1783 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1784 return forward_call(*args, **kwargs)
1786 result = None
1787 called_always_called_hooks = set()
File /usr/local/lib/python3.12/dist-packages/physicsnemo/experimental/models/dit/layers.py:367, in Natten2DSelfAttention.forward(self, x, latent_hw)
365 x = partial_na2d(q, k, v, kernel_size=self.attn_kernel, base_func=na2d, dilation=1)
366 else:
--> 367 x = na2d(q, k, v, kernel_size=self.attn_kernel)
368 x = self.attn_drop(x)
369 x = rearrange(x, "b h w head c -> b (h w) (head c)")
File /usr/local/lib/python3.12/dist-packages/natten/functional.py:833, in na2d(query, key, value, kernel_size, stride, dilation, is_causal, scale, additional_keys, additional_values, attention_kwargs, backend, q_tile_shape, kv_tile_shape, backward_q_tile_shape, backward_kv_tile_shape, backward_kv_splits, backward_use_pt_reduction, run_persistent_kernel, kernel_schedule, torch_compile)
671 def na2d(
672 query: Tensor,
673 key: Tensor,
(...) 692 torch_compile: bool = False,
693 ) -> Tensor:
694 """Computes 2-D neighborhood attention.
695
696 Args:
(...) 831 ([batch, X, Y, heads, head_dim]).
832 """
--> 833 return neighborhood_attention_generic(
834 query=query,
835 key=key,
836 value=value,
837 kernel_size=kernel_size,
838 stride=stride,
839 dilation=dilation,
840 is_causal=is_causal,
841 scale=scale,
842 additional_keys=additional_keys,
843 additional_values=additional_values,
844 attention_kwargs=attention_kwargs,
845 backend=backend,
846 q_tile_shape=q_tile_shape,
847 kv_tile_shape=kv_tile_shape,
848 backward_q_tile_shape=backward_q_tile_shape,
849 backward_kv_tile_shape=backward_kv_tile_shape,
850 backward_kv_splits=backward_kv_splits,
851 backward_use_pt_reduction=backward_use_pt_reduction,
852 run_persistent_kernel=run_persistent_kernel,
853 kernel_schedule=kernel_schedule,
854 torch_compile=torch_compile,
855 )
File /usr/local/lib/python3.12/dist-packages/natten/functional.py:387, in neighborhood_attention_generic(query, key, value, kernel_size, stride, dilation, is_causal, scale, additional_keys, additional_values, attention_kwargs, backend, q_tile_shape, kv_tile_shape, backward_q_tile_shape, backward_kv_tile_shape, backward_kv_splits, backward_use_pt_reduction, run_persistent_kernel, kernel_schedule, torch_compile)
383 return out.reshape(*output_shape)
385 scale = scale or query.shape[-1] ** -0.5
--> 387 backend = backend or choose_backend(query, key, value, torch_compile=torch_compile)
389 has_additional_attention = (
390 additional_keys is not None and additional_values is not None
391 )
393 if backend == "blackwell-fna":
File /usr/local/lib/python3.12/dist-packages/natten/backends/init.py:104, in choose_backend(query, key, value, torch_compile)
101 logger.debug("Backend not set; picked Flex Attention kernel.")
102 return "flex-fna"
--> 104 raise NotImplementedError(
105 "NATTEN could not find a suitable backend for this use case. "
106 "Run with NATTEN_LOG_LEVEL=DEBUG to find out why."
107 )
NotImplementedError: NATTEN could not find a suitable backend for this use case. Run with NATTEN_LOG_LEVEL=DEBUG to find out why.