Removed timm import
modeling_dass.py  +34 -2  (CHANGED)
@@ -10,7 +10,6 @@ import warnings
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.checkpoint as checkpoint
-from timm.models.layers import DropPath, trunc_normal_
 from functools import partial
 from typing import Optional, Callable, Any, Union
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss

@@ -718,6 +717,39 @@ def selective_scan_fn(
 ############## HuggingFace modeling file #################
 ##########################################################

+def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
+    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
+    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
+    'survival rate' as the argument.
+
+    """
+    if drop_prob == 0. or not training:
+        return x
+    keep_prob = 1 - drop_prob
+    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+    if keep_prob > 0.0 and scale_by_keep:
+        random_tensor.div_(keep_prob)
+    return x * random_tensor
+
+class DropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+    """
+    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+        self.scale_by_keep = scale_by_keep
+
+    def forward(self, x):
+        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
+
+    def extra_repr(self):
+        return f'drop_prob={round(self.drop_prob,3):0.3f}'
+
 class DASSLinear2d(nn.Linear):
     def __init__(self, *args, groups=1, **kwargs):
         nn.Linear.__init__(self, *args, **kwargs)

@@ -1094,7 +1126,7 @@ class DASSPreTrainedModel(PreTrainedModel):
     def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
         """Initialize the weights"""
         if isinstance(module, nn.Linear):
-            trunc_normal_(module.weight, std=0.02)
+            nn.init.trunc_normal_(module.weight, std=0.02)
         if isinstance(module, nn.Linear) and module.bias is not None:
             nn.init.constant_(module.bias, 0)
         elif isinstance(module, nn.LayerNorm):