import torch
import torch.nn as nn


class EuroSATCNN(nn.Module):
    """Small convolutional classifier: four conv/ReLU/max-pool stages + MLP head.

    The feature extractor narrows the channel count 128 -> 64 -> 32 -> 16.
    Each stage uses a 4x4 convolution with padding 1 followed by a 2x2 max
    pool. The head flattens the feature map and maps it through a 64-unit
    hidden layer to ``num_classes`` raw scores (no softmax is applied).

    Args:
        num_classes: Number of output scores produced per sample.
        img_height: Height of the images the model will be fed. Used only to
            size the first linear layer.
        img_width: Width of the images the model will be fed. Used only to
            size the first linear layer.
        in_channels: Number of channels in the input images. Defaults to 13,
            the value that was previously hard-coded.
            NOTE(review): 13 presumably corresponds to the multispectral
            bands of the dataset -- confirm against the data loader.
    """

    # Output channels of the successive convolution stages.
    _STAGE_WIDTHS = (128, 64, 32, 16)

    def __init__(
        self,
        num_classes: int,
        img_height: int = 64,
        img_width: int = 64,
        in_channels: int = 13,
    ) -> None:
        super().__init__()

        # Build the stages as one flat Sequential so parameter names
        # (features.0, features.3, ...) stay compatible with checkpoints
        # saved from the hand-unrolled version.
        widths = (in_channels, *self._STAGE_WIDTHS)
        stages = []
        for c_in, c_out in zip(widths, widths[1:]):
            stages.extend(
                (
                    nn.Conv2d(c_in, c_out, kernel_size=4, padding=1),
                    nn.ReLU(),
                    nn.MaxPool2d(kernel_size=2),
                )
            )
        self.features = nn.Sequential(*stages)

        # Probe the extractor once to learn the flattened feature size rather
        # than deriving it by hand from the conv/pool arithmetic.
        # torch.randn is kept on purpose: it draws from the global RNG, so
        # replacing it would change the seeded initialisation of the
        # classifier layers created below.
        with torch.no_grad():
            probe = torch.randn(1, in_channels, img_height, img_width)
            fc1_input_size = self.features(probe).view(1, -1).shape[1]

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(fc1_input_size, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the feature extractor, then the classifier head.

        Args:
            x: Batch of images shaped ``(N, in_channels, img_height,
                img_width)``; the spatial size must match the one given at
                construction, otherwise the first linear layer will not fit.

        Returns:
            Tensor of shape ``(N, num_classes)`` with unnormalised scores.
        """
        return self.classifier(self.features(x))