Rhodham96 commited on
Commit
db0c75d
·
verified ·
1 Parent(s): 8f070f2

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +4 -38
README.md CHANGED
@@ -84,50 +84,16 @@ This model can be used for automated land cover classification of Sentinel-2 sat
84
  import torch
85
  import torch.nn as nn
86
 
87
- class EuroSATCNN(nn.Module):
88
- def __init__(self, num_classes, img_height=64, img_width=64):
89
- super(EuroSATCNN, self).__init__()
90
- self.features = nn.Sequential(
91
- nn.Conv2d(13, 128, kernel_size=4, padding=1),
92
- nn.ReLU(),
93
- nn.MaxPool2d(kernel_size=2),
94
-
95
- nn.Conv2d(128, 64, kernel_size=4, padding=1),
96
- nn.ReLU(),
97
- nn.MaxPool2d(kernel_size=2),
98
-
99
- nn.Conv2d(64, 32, kernel_size=4, padding=1),
100
- nn.ReLU(),
101
- nn.MaxPool2d(kernel_size=2),
102
-
103
- nn.Conv2d(32, 16, kernel_size=4, padding=1),
104
- nn.ReLU(),
105
- nn.MaxPool2d(kernel_size=2),
106
- )
107
-
108
- with torch.no_grad():
109
- dummy_input = torch.randn(1, 13, img_height, img_width)
110
- out = self.features(dummy_input)
111
- fc1_input_size = out.view(1, -1).shape[1]
112
-
113
- self.classifier = nn.Sequential(
114
- nn.Flatten(),
115
- nn.Linear(fc1_input_size, 64),
116
- nn.ReLU(),
117
- nn.Linear(64, num_classes)
118
- )
119
-
120
- def forward(self, x):
121
- x = self.features(x)
122
- x = self.classifier(x)
123
- return x
124
 
125
  # Example usage:
126
  # Assuming num_classes is known, e.g., 10 for EuroSAT
127
  # model = EuroSATCNN(num_classes=10)
 
128
  # dummy_input_image = torch.randn(1, 13, 64, 64) # Batch size 1, 13 channels, 64x64
129
  # output = model(dummy_input_image)
130
- # print(output.shape) # Should be torch.Size([1, 10]) if num_classes=10
131
 
132
 
133
  ---
 
84
  import torch
85
  import torch.nn as nn
86
 
87
+ from model_def import EuroSATCNN
88
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
  # Example usage:
91
  # Assuming num_classes is known, e.g., 10 for EuroSAT
92
  # model = EuroSATCNN(num_classes=10)
93
+ # model.load_state_dict(torch.load("pytorch_model.bin"))
94
  # dummy_input_image = torch.randn(1, 13, 64, 64) # Batch size 1, 13 channels, 64x64
95
  # output = model(dummy_input_image)
96
+ # print(output.shape) # Should be torch.Size([1, 10]) if num_classes=20
97
 
98
 
99
  ---