import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from pathlib import Path
import json
import time
from typing import Dict, Any, Optional


class StandardModel(nn.Module):
    """Base class for standard PyTorch models with checkpoint support."""

    def __init__(self):
        super().__init__()
        self.model_name = self.__class__.__name__
        # Per-epoch records of {'epoch', 'loss', 'accuracy'}, appended by Trainer.
        self.training_history = []

    def forward(self, x):
        raise NotImplementedError

    def save_checkpoint(self, epoch: int, loss: float, optimizer_state: Dict,
                        save_dir: str = "models"):
        """Save model checkpoint in standard format.

        Writes ``<model_name>_epoch_<epoch>.pth`` containing the model and
        optimizer state, and a sibling JSON file with the accumulated
        training history.

        Args:
            epoch: Epoch index stored in the checkpoint and file name.
            loss: Loss value recorded alongside the weights.
            optimizer_state: ``optimizer.state_dict()`` to persist.
            save_dir: Destination directory (created if missing).
        """
        save_path = Path(save_dir)
        # parents=True so nested paths like "runs/exp1/models" work too.
        save_path.mkdir(parents=True, exist_ok=True)

        checkpoint = {
            'model_name': self.model_name,
            'epoch': epoch,
            'model_state_dict': self.state_dict(),
            'optimizer_state_dict': optimizer_state,
            'loss': loss,
            'timestamp': time.time()
        }

        filename = f"{self.model_name}_epoch_{epoch}.pth"
        torch.save(checkpoint, save_path / filename)

        # Also save training history
        with open(save_path / f"{self.model_name}_history.json", 'w') as f:
            json.dump(self.training_history, f, indent=2)

    def load_checkpoint(self, checkpoint_path: str,
                        map_location: Optional[str] = None):
        """Load model checkpoint and return ``(epoch, loss)``.

        Args:
            checkpoint_path: Path to a ``.pth`` file written by
                :meth:`save_checkpoint`.
            map_location: Forwarded to :func:`torch.load`. Pass ``"cpu"`` to
                restore a GPU-saved checkpoint on a CPU-only machine; the
                default ``None`` keeps the original device mapping.
        """
        checkpoint = torch.load(checkpoint_path, map_location=map_location)
        self.load_state_dict(checkpoint['model_state_dict'])
        return checkpoint['epoch'], checkpoint['loss']


class SimpleCNN(StandardModel):
    """Simple CNN for image classification.

    Expects 3-channel image input; AdaptiveAvgPool2d makes the classifier
    independent of the spatial input size.
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.num_classes = num_classes

        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1))
        )

        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x


class Trainer:
    """Standard training loop."""

    def __init__(self, model: StandardModel, device: str = "cpu",
                 lr: float = 0.001):
        """Args:
            model: Model to train; moved to ``device``.
            device: Torch device string, e.g. "cpu" or "cuda".
            lr: Adam learning rate (previously hard-coded to 0.001).
        """
        self.model = model.to(device)
        self.device = device
        self.criterion = nn.CrossEntropyLoss()
        # nn.Module.to() moves parameters in place; reference self.model
        # explicitly so the optimizer is unambiguously tied to the
        # on-device parameters.
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)

    def train_epoch(self, dataloader: DataLoader, epoch: int):
        """Train for one epoch; returns ``(mean_loss, accuracy_percent)``."""
        self.model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for batch_idx, (data, targets) in enumerate(dataloader):
            data, targets = data.to(self.device), targets.to(self.device)

            self.optimizer.zero_grad()
            outputs = self.model(data)
            loss = self.criterion(outputs, targets)
            loss.backward()
            self.optimizer.step()

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            if batch_idx % 100 == 0:
                print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}')

        epoch_loss = running_loss / len(dataloader)
        epoch_acc = 100. * correct / total

        # Record training history
        self.model.training_history.append({
            'epoch': epoch,
            'loss': epoch_loss,
            'accuracy': epoch_acc
        })

        return epoch_loss, epoch_acc

    def train(self, dataloader: DataLoader, epochs: int,
              save_dir: str = "models"):
        """Full training loop.

        Checkpoints the model whenever the epoch loss improves on the best
        seen so far; returns the accumulated training history.
        """
        best_loss = float('inf')

        for epoch in range(epochs):
            loss, acc = self.train_epoch(dataloader, epoch)
            print(f'Epoch {epoch}: Loss {loss:.4f}, Accuracy {acc:.2f}%')

            # Save best model
            if loss < best_loss:
                best_loss = loss
                self.model.save_checkpoint(
                    epoch, loss, self.optimizer.state_dict(), save_dir
                )
                print(f'Saved best model at epoch {epoch}')

        return self.model.training_history


if __name__ == "__main__":
    # Example usage
    model = SimpleCNN(num_classes=10)
    trainer = Trainer(model)

    print(f"Model: {model.model_name}")
    print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

    # This would be used with a real dataloader
    # history = trainer.train(dataloader, epochs=10)