NATTEN Error running Python3.12, Cuda 12.8 and Pytorch 2.8 on A100

#2
by dalow24 - opened

I am receiving this error even though I have NATTEN installed, and I have also reinstalled it.


NotImplementedError Traceback (most recent call last)
Cell In[39], line 10
7 n_steps = 2
8 for step_idx in range(n_steps):
9 # Run one prognostic step with the GOES model
---> 10 y_pred, y_pred_coords = model(y, y_coords)
12 # Run one prognostic step with the MRMS model conditioned on GOES
13 y_mrms_pred, y_coords_mrms_pred = model_mrms.call_with_conditioning(
14 y_mrms, y_coords_mrms, conditioning=y, conditioning_coords=y_coords
15 )

File /usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py:120, in context_decorator..decorate_context(*args, **kwargs)
117 @functools.wraps(func)
118 def decorate_context(*args, **kwargs):
119 with ctx_factory():
--> 120 return func(*args, **kwargs)

File /usr/local/lib/python3.12/dist-packages/earth2studio/models/batch.py:220, in batch_func._batch_wrap.._wrapper(model, *args, **kwargs)
215 raise ValueError(
216 "Mismatched batched dimensions across input (x, CoordSystem) pairs"
217 )
219 # Model forward
--> 220 out, out_coords = func(model, *new_args, **kwargs)
221 out, out_coords = self._decompress_batch(
222 out, out_coords, batched_coords, batched_shape
223 )
224 return out, out_coords

File /usr/local/lib/python3.12/dist-packages/earth2studio/models/px/stormscope.py:857, in StormScopeBase.call(self, x, coords)
853 conditioning_coords = None
855 output_coords = self.output_coords(x_coords)
--> 857 x = self._forward(
858 x,
859 x_coords,
860 conditioning=conditioning,
861 conditioning_coords=conditioning_coords,
862 )
864 return x, output_coords

File /usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py:120, in context_decorator..decorate_context(*args, **kwargs)
117 @functools.wraps(func)
118 def decorate_context(*args, **kwargs):
119 with ctx_factory():
--> 120 return func(*args, **kwargs)

File /usr/local/lib/python3.12/dist-packages/earth2studio/models/px/stormscope.py:619, in StormScopeBase._forward(self, x, coords, conditioning, conditioning_coords)
614 latents = torch.randn(
615 b * t, x.shape[3:], device=x.device, dtype=x.dtype
616 ) # shape [B*T, C, H, W]
618 # Run diffusion sampler
--> 619 out = self._edm_sampler(
620 latents=latents,
621 condition=condition,
622 sigma_min=self.start_sigma,
623 sigma_max=self.end_sigma,
624 **self.sampler_args,
625 ).to(output_dtype)
627 out = out.reshape(b, t, len(self.output_times), *out.shape[1:])
629 out = torch.where(self.valid_mask, out, torch.nan)

File /usr/local/lib/python3.12/dist-packages/earth2studio/models/px/stormscope.py:683, in StormScopeBase._edm_sampler(self, latents, condition, class_labels, randn_like, num_steps, sigma_max, sigma_min, rho, S_churn, S_min, S_max, S_noise, progress_bar)
680 x_hat = x_cur + (t_hat2 - t_cur2).sqrt() * S_noise * randn_like(x_cur)
682 # Euler step.
--> 683 denoised = active_net(
684 x_hat, t_hat, class_labels=class_labels, condition=condition
685 ).to(self._SAMPLER_DTYPE)
686 d_cur = (x_hat - denoised) / t_hat
687 x_next = x_hat + (t_next - t_hat) * d_cur

File /usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py:1773, in Module._wrapped_call_impl(self, *args, **kwargs)
1771 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1772 else:
-> 1773 return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py:1784, in Module._call_impl(self, *args, **kwargs)
1779 # If we don't have any hooks, we want to skip the rest of the logic in
1780 # this function, and just call forward.
1781 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1782 or _global_backward_pre_hooks or _global_backward_hooks
1783 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1784 return forward_call(*args, **kwargs)
1786 result = None
1787 called_always_called_hooks = set()

File /usr/local/lib/python3.12/dist-packages/earth2studio/models/nn/stormscope_util.py:172, in EDMPrecond.forward(self, x, sigma, condition, class_labels, return_logvar, force_fp32, training, **model_kwargs)
170 arg = torch.cat([arg, condition], dim=1)
171 # now we have added the p_dropout probability to the model.
--> 172 F_x = self.model(
173 (arg).to(dtype),
174 c_noise.flatten(),
175 p_dropout=p_dropout,
176 training=training,
177 **model_kwargs,
178 )
180 D_x = c_skip * x + c_out * F_x.to(torch.float32)
182 if return_logvar:

File /usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py:1773, in Module._wrapped_call_impl(self, *args, **kwargs)
1771 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1772 else:
-> 1773 return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py:1784, in Module._call_impl(self, *args, **kwargs)
1779 # If we don't have any hooks, we want to skip the rest of the logic in
1780 # this function, and just call forward.
1781 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1782 or _global_backward_pre_hooks or _global_backward_hooks
1783 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1784 return forward_call(*args, **kwargs)
1786 result = None
1787 called_always_called_hooks = set()

File /usr/local/lib/python3.12/dist-packages/earth2studio/models/nn/stormscope_util.py:242, in DropInDiT.forward(self, x, time_step_cond, label_cond, points, p_dropout, training)
239 attn_kwargs: dict[str, Any] | None = {"latent_hw": latent_hw}
241 # Note: points / cross-attention are not supported in PhysicsNeMo DiT so we ignore them
--> 242 out = self.pnm(
243 x=x,
244 t=time_step_cond,
245 condition=condition,
246 p_dropout=p_dropout,
247 attn_kwargs=attn_kwargs,
248 )
249 return out

File /usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py:1773, in Module._wrapped_call_impl(self, *args, **kwargs)
1771 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1772 else:
-> 1773 return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py:1784, in Module._call_impl(self, *args, **kwargs)
1779 # If we don't have any hooks, we want to skip the rest of the logic in
1780 # this function, and just call forward.
1781 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1782 or _global_backward_pre_hooks or _global_backward_hooks
1783 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1784 return forward_call(*args, **kwargs)
1786 result = None
1787 called_always_called_hooks = set()

File /usr/local/lib/python3.12/dist-packages/physicsnemo/experimental/models/dit/dit.py:321, in DiT.forward(self, x, t, condition, p_dropout, attn_kwargs)
318 c = t # (B, D)
320 for block in self.blocks:
--> 321 x = block(x, c, p_dropout=p_dropout, attn_kwargs=attn_kwargs) # (B, L, D)
323 # De-tokenize: (B, L, D) -> (B, C, H, W)
324 if self.force_tokenization_fp32:

File /usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py:1773, in Module._wrapped_call_impl(self, *args, **kwargs)
1771 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1772 else:
-> 1773 return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py:1784, in Module._call_impl(self, *args, **kwargs)
1779 # If we don't have any hooks, we want to skip the rest of the logic in
1780 # this function, and just call forward.
1781 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1782 or _global_backward_pre_hooks or _global_backward_hooks
1783 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1784 return forward_call(*args, **kwargs)
1786 result = None
1787 called_always_called_hooks = set()

File /usr/local/lib/python3.12/dist-packages/physicsnemo/experimental/models/dit/layers.py:591, in DiTBlock.forward(self, x, c, attn_kwargs, p_dropout)
588 elif p_dropout is not None:
589 raise ValueError("p_dropout passed to DiTBlock but intermediate_dropout is disabled")
--> 591 attention_output = self.attention(
592 modulated_attn_input,
593 **(attn_kwargs or {}),
594 )
595 x = x + attention_gate.unsqueeze(1) * attention_output
597 # Feed-forward block

File /usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py:1773, in Module._wrapped_call_impl(self, *args, **kwargs)
1771 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1772 else:
-> 1773 return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py:1784, in Module._call_impl(self, *args, **kwargs)
1779 # If we don't have any hooks, we want to skip the rest of the logic in
1780 # this function, and just call forward.
1781 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1782 or _global_backward_pre_hooks or _global_backward_hooks
1783 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1784 return forward_call(*args, **kwargs)
1786 result = None
1787 called_always_called_hooks = set()

File /usr/local/lib/python3.12/dist-packages/physicsnemo/experimental/models/dit/layers.py:367, in Natten2DSelfAttention.forward(self, x, latent_hw)
365 x = partial_na2d(q, k, v, kernel_size=self.attn_kernel, base_func=na2d, dilation=1)
366 else:
--> 367 x = na2d(q, k, v, kernel_size=self.attn_kernel)
368 x = self.attn_drop(x)
369 x = rearrange(x, "b h w head c -> b (h w) (head c)")

File /usr/local/lib/python3.12/dist-packages/natten/functional.py:833, in na2d(query, key, value, kernel_size, stride, dilation, is_causal, scale, additional_keys, additional_values, attention_kwargs, backend, q_tile_shape, kv_tile_shape, backward_q_tile_shape, backward_kv_tile_shape, backward_kv_splits, backward_use_pt_reduction, run_persistent_kernel, kernel_schedule, torch_compile)
671 def na2d(
672 query: Tensor,
673 key: Tensor,
(...) 692 torch_compile: bool = False,
693 ) -> Tensor:
694 """Computes 2-D neighborhood attention.
695
696 Args:
(...) 831 ([batch, X, Y, heads, head_dim]).
832 """
--> 833 return neighborhood_attention_generic(
834 query=query,
835 key=key,
836 value=value,
837 kernel_size=kernel_size,
838 stride=stride,
839 dilation=dilation,
840 is_causal=is_causal,
841 scale=scale,
842 additional_keys=additional_keys,
843 additional_values=additional_values,
844 attention_kwargs=attention_kwargs,
845 backend=backend,
846 q_tile_shape=q_tile_shape,
847 kv_tile_shape=kv_tile_shape,
848 backward_q_tile_shape=backward_q_tile_shape,
849 backward_kv_tile_shape=backward_kv_tile_shape,
850 backward_kv_splits=backward_kv_splits,
851 backward_use_pt_reduction=backward_use_pt_reduction,
852 run_persistent_kernel=run_persistent_kernel,
853 kernel_schedule=kernel_schedule,
854 torch_compile=torch_compile,
855 )

File /usr/local/lib/python3.12/dist-packages/natten/functional.py:387, in neighborhood_attention_generic(query, key, value, kernel_size, stride, dilation, is_causal, scale, additional_keys, additional_values, attention_kwargs, backend, q_tile_shape, kv_tile_shape, backward_q_tile_shape, backward_kv_tile_shape, backward_kv_splits, backward_use_pt_reduction, run_persistent_kernel, kernel_schedule, torch_compile)
383 return out.reshape(*output_shape)
385 scale = scale or query.shape[-1] ** -0.5
--> 387 backend = backend or choose_backend(query, key, value, torch_compile=torch_compile)
389 has_additional_attention = (
390 additional_keys is not None and additional_values is not None
391 )
393 if backend == "blackwell-fna":

File /usr/local/lib/python3.12/dist-packages/natten/backends/__init__.py:104, in choose_backend(query, key, value, torch_compile)
101 logger.debug("Backend not set; picked Flex Attention kernel.")
102 return "flex-fna"
--> 104 raise NotImplementedError(
105 "NATTEN could not find a suitable backend for this use case. "
106 "Run with NATTEN_LOG_LEVEL=DEBUG to find out why."
107 )

NotImplementedError: NATTEN could not find a suitable backend for this use case. Run with NATTEN_LOG_LEVEL=DEBUG to find out why.

dalow24 changed discussion title from Error running Python3.12, Cuda 12.8 and Pytorch 2.8 to Error running Python3.12, Cuda 12.8 and Pytorch 2.8 on A100
dalow24 changed discussion title from Error running Python3.12, Cuda 12.8 and Pytorch 2.8 on A100 to NATTEN Error running Python3.12, Cuda 12.8 and Pytorch 2.8 on A100

Sign up or log in to comment