"""
train_digit_classifier.py

A fully documented training script for a convolutional neural network (CNN)
classifier trained on MNIST + EMNIST digits + blank images.

Author: Deep Shah
License: GPL-3.0
"""
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | import torch.optim as optim |
| | import torchvision |
| | import torchvision.transforms as transforms |
| | from torch.utils.data import DataLoader, Dataset, TensorDataset |
| | from sklearn.model_selection import train_test_split |
| | import os |
| |
|
| | |
| | |
| | |
| |
|
| | |
# Seed the PyTorch and NumPy global random number generators with one value.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Prefer CUDA when PyTorch reports it as available; otherwise use the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print(f"[INFO] Using device: {device}")
| |
|
| | |
| | |
| | |
| |
|
class EMNISTDigitsDataset(Dataset):
    """
    A PyTorch-compatible wrapper for the EMNIST digits dataset loaded via TensorFlow Datasets.

    The requested split is read fully into memory at construction time:
    ``images`` is a float32 array laid out (N, H, W, C) whose values have been
    divided by 255, and ``labels`` is an int64 array of length N.

    NOTE(review): images are kept exactly as TFDS yields them. The TFDS catalog
    describes EMNIST images as transposed relative to MNIST -- confirm whether
    callers must swap the height/width axes before mixing the two sources.

    Args:
        split: TFDS split name forwarded to ``tfds.load`` (default ``"train"``).
        transform: Optional callable applied per sample in ``__getitem__``. It
            receives a float32 tensor laid out (C, H, W) and must return data
            in that same layout.
    """

    def __init__(self, split="train", transform=None):
        # Imported here, so TensorFlow Datasets is only required when a
        # dataset instance is actually constructed.
        import tensorflow_datasets as tfds

        ds = tfds.load("emnist/digits", split=split, as_supervised=True)
        images = []
        labels = []
        for image, label in tfds.as_numpy(ds):
            if image.ndim == 2:
                # (H, W) -> (H, W, 1): add the missing channel axis.
                image = image[..., np.newaxis]
            elif image.ndim == 4 and image.shape[0] == 1:
                # (1, H, W, C) -> (H, W, C): drop a singleton leading axis.
                image = image[0]
            images.append(image)
            labels.append(label)
        self.images = np.array(images, dtype=np.float32) / 255.0
        self.labels = np.array(labels, dtype=np.int64)
        self.transform = transform

    def __len__(self):
        """Return the number of samples held in memory."""
        return len(self.images)

    def __getitem__(self, idx):
        """
        Return ``(image, label)`` for sample ``idx``.

        ``image`` is a float32 tensor laid out (C, H, W) and ``label`` is a
        scalar int64 tensor, with or without a transform configured.
        """
        # Stored layout is (H, W, C); convert to channel-first for PyTorch.
        image = torch.tensor(self.images[idx].transpose(2, 0, 1), dtype=torch.float32)
        if self.transform is not None:
            # The transform works on and returns (C, H, W). The previous code
            # transposed its output twice, which produced an (H, C, W) tensor.
            image = torch.as_tensor(self.transform(image), dtype=torch.float32)
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return image, label
| |
|
| | |
| | |
| | |
| |
|
| | |
# Augmentation pipeline applied to one image at a time during training:
# tensor -> PIL image -> random rotation -> random scale/shift -> tensor.
_augmentation_steps = [
    transforms.ToPILImage(),
    transforms.RandomRotation(degrees=10),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
    transforms.ToTensor(),
]
train_transform = transforms.Compose(_augmentation_steps)
| |
|
| | |
| | |
| | |
| |
|
| | |
# --- MNIST training split ---------------------------------------------------
# Scale pixel values by 1/255 and append a channel axis so the layout matches
# the (N, 28, 28, 1) blank images created below.
mnist_dataset = torchvision.datasets.MNIST(root="./data", train=True, download=True)
mnist_images = mnist_dataset.data.numpy().astype(np.float32) / 255.0
mnist_images = mnist_images[..., np.newaxis]
mnist_labels = mnist_dataset.targets.numpy()

# --- EMNIST digits training split --------------------------------------------
# The wrapper already returns float32 images divided by 255, laid out (N, H, W, C).
emnist_dataset = EMNISTDigitsDataset(split="train", transform=None)
# The TFDS catalog documents EMNIST images as flipped horizontally and rotated
# 90 degrees anti-clockwise, i.e. transposed relative to MNIST. Swap the
# height/width axes so both sources show digits in the same orientation.
# The transposed view is copied by np.concatenate below, so no extra copy here.
emnist_images = emnist_dataset.images.transpose(0, 2, 1, 3)
emnist_labels = emnist_dataset.labels

# --- Blank class --------------------------------------------------------------
# 5000 all-zero images labelled 10, the class index after the digits 0-9.
x_blank = np.zeros((5000, 28, 28, 1), dtype=np.float32)
y_blank = np.full((5000,), 10, dtype=np.int64)

# --- Merge the three sources ----------------------------------------------------
x_combined = np.concatenate([mnist_images, emnist_images, x_blank], axis=0)
y_combined = np.concatenate([mnist_labels, emnist_labels, y_blank], axis=0)

# Shuffle images and labels with the same permutation so pairs stay aligned.
indices = np.random.permutation(len(x_combined))
x_combined = x_combined[indices]
y_combined = y_combined[indices]

# --- Train / validation split (90% / 10%) ----------------------------------------
x_train, x_val, y_train, y_val = train_test_split(
    x_combined, y_combined, test_size=0.1, random_state=42
)

# Convert to channel-first tensors: (N, H, W, C) -> (N, C, H, W).
train_dataset = TensorDataset(
    torch.tensor(x_train.transpose(0, 3, 1, 2), dtype=torch.float32),
    torch.tensor(y_train, dtype=torch.long),
)
val_dataset = TensorDataset(
    torch.tensor(x_val.transpose(0, 3, 1, 2), dtype=torch.float32),
    torch.tensor(y_val, dtype=torch.long),
)

# Only the training loader reshuffles each epoch; validation order is fixed.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
| |
|
| | |
| | |
| | |
| |
|
class CNN(nn.Module):
    """
    Convolutional classifier producing 11 logits per input image.

    Architecture:
        - Three 3x3 convolution stages (1 -> 32 -> 64 -> 128 channels), each
          followed by BatchNorm and ReLU.
        - The second and third stages end with 2x2 max pooling and light
          dropout (p=0.1).
        - A dense head: Flatten -> Linear(128 * 7 * 7, 128) -> BatchNorm ->
          ReLU -> Dropout(p=0.2) -> Linear(128, 11).

    The first dense layer's size (128 * 7 * 7) matches a 28x28 input after the
    two pooling steps.
    """

    # (in_channels, out_channels, pool_and_dropout_after) per convolution stage.
    _STAGES = ((1, 32, False), (32, 64, True), (64, 128, True))

    def __init__(self):
        super().__init__()

        # Build the extractor stage by stage. Layer order is kept fixed so the
        # Sequential indices (and therefore state_dict keys) stay the same.
        feature_layers = []
        for in_channels, out_channels, downsample in self._STAGES:
            feature_layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
            feature_layers.append(nn.BatchNorm2d(out_channels))
            feature_layers.append(nn.ReLU())
            if downsample:
                feature_layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
                feature_layers.append(nn.Dropout(p=0.1))
        self.features = nn.Sequential(*feature_layers)

        flat_features = 128 * 7 * 7
        hidden_units = 128
        num_classes = 11
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_features, hidden_units),
            nn.BatchNorm1d(hidden_units),
            nn.ReLU(),
            nn.Dropout(p=0.2),
            nn.Linear(hidden_units, num_classes),
        )

    def forward(self, x):
        """Run the feature extractor, then the dense head; return raw logits."""
        feature_maps = self.features(x)
        return self.classifier(feature_maps)
| |
|
# Build the network and move its parameters to the selected device.
model = CNN().to(device)

# Cross-entropy loss computed from the model's raw logits.
criterion = nn.CrossEntropyLoss()

# Adam optimizer over all model parameters.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

# Multiply the learning rate by 0.2 after 2 epochs without a lower
# validation loss, never going below 1e-6.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.2,
    patience=2,
    min_lr=1e-6,
)

# Early-stopping state: the lowest validation loss seen so far, the weights
# recorded at that point, and how many epochs in a row failed to improve.
best_val_loss = float("inf")
best_model_state = None
patience = 5
patience_counter = 0
| |
|
| | |
| | |
| | |
| |
|
# Upper bound on training epochs; early stopping may end the run sooner.
NUM_EPOCHS = 50

for epoch in range(1, NUM_EPOCHS + 1):
    # ---------------- Training phase ----------------
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in train_loader:
        # Augment on the CPU batch first, then transfer once. The transform
        # pipeline runs per image through PIL, so moving the batch to the
        # device beforehand only added a device -> CPU -> device round trip
        # for every sample.
        images = torch.stack([train_transform(image) for image in images])
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    train_acc = 100 * correct / total
    # Mean of the per-batch mean losses.
    train_loss = running_loss / len(train_loader)

    # ---------------- Validation phase ----------------
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            val_total += labels.size(0)
            val_correct += (predicted == labels).sum().item()

    val_acc = 100 * val_correct / val_total
    val_loss /= len(val_loader)

    print(f"Epoch {epoch:02d}: Train Loss={train_loss:.4f}, Train Acc={train_acc:.2f}%, "
          f"Val Loss={val_loss:.4f}, Val Acc={val_acc:.2f}%")

    # Let the scheduler lower the learning rate when validation loss plateaus.
    scheduler.step(val_loss)

    # ---------------- Early stopping ----------------
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        # state_dict() returns references to the live parameter tensors, so
        # they must be cloned; otherwise later optimizer steps overwrite the
        # "best" weights in place and the final epoch is what gets restored.
        best_model_state = {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
        }
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print("[INFO] Early stopping triggered.")
            break
| |
|
| | |
# Restore the weights from the epoch with the lowest validation loss.
# best_model_state stays None if no epoch ever improved on the initial
# infinity (for example when every validation loss is NaN); in that case
# keep the current weights instead of crashing in load_state_dict.
if best_model_state is not None:
    model.load_state_dict(best_model_state)
else:
    print("[WARN] No best checkpoint was recorded; keeping the final-epoch weights.")

# Save the raw weights (state_dict only).
torch.save(model.state_dict(), "mnist_emnist_blank_cnn_v1.pth")
print("[INFO] Model weights saved as mnist_emnist_blank_cnn_v1.pth")

# Export a TorchScript module by tracing the model in eval mode with a
# single 1x1x28x28 example input.
model.eval()
example_input = torch.randn(1, 1, 28, 28).to(device)
scripted_model = torch.jit.trace(model, example_input)
scripted_model.save("mnist_emnist_blank_cnn_v1.pt")
print("[INFO] TorchScript model saved as mnist_emnist_blank_cnn_v1.pt")

# Export to ONNX from the CPU. model.to("cpu") moves the module in place, so
# remember the original device and move the model back afterwards, even if
# the export raises.
prev_device = next(model.parameters()).device
try:
    model_cpu = model.to("cpu").eval()
    dummy = torch.randn(1, 1, 28, 28)

    onnx_path = "mnist_emnist_blank_cnn_v1.onnx"
    torch.onnx.export(
        model_cpu,
        dummy,
        onnx_path,
        export_params=True,
        opset_version=13,
        do_constant_folding=True,
        input_names=["input"],
        output_names=["logits"],
        # Mark axis 0 as dynamic so the exported graph accepts any batch size.
        dynamic_axes={"input": {0: "batch_size"}, "logits": {0: "batch_size"}},
    )
    print(f"[INFO] ONNX model saved as {onnx_path}")
finally:
    model.to(prev_device).eval()
| |
|