"""Train a ``SimpleCNN`` on a dataset obtained via ``get_dataloader``.

Command-line flags: ``--dataset``, ``--epochs``, ``--batch-size``,
``--num-classes``, ``--save-dir`` and ``--device``. The process exit status is
0 on success and 1 when setup (device check or dataset loading) fails.
"""
import argparse
import os
import sys
from pathlib import Path

import torch
from torchvision import transforms

# Make the project's ``src`` directory (located at ../src relative to this
# script's directory) importable. It is inserted at the front of sys.path so
# the local ``data_loader`` / ``model`` modules win over any same-named
# installed packages.
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))

from data_loader import get_dataloader  # noqa: E402  (requires the sys.path tweak above)
from model import SimpleCNN, Trainer  # noqa: E402


def _positive_int(text):
    """argparse ``type=`` converter that accepts only integers >= 1.

    Raises:
        argparse.ArgumentTypeError: if the value is < 1 (argparse turns this,
            and the ValueError from a non-integer string, into a usage error).
    """
    value = int(text)
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
    return value


def _parse_args(argv=None):
    """Build the CLI parser and parse ``argv`` (``sys.argv[1:]`` when None)."""
    parser = argparse.ArgumentParser(description="Standard PyTorch Training Script")
    parser.add_argument("--dataset", type=str, default="cifar10",
                        help="Dataset name (must be registered)")
    parser.add_argument("--epochs", type=_positive_int, default=10,
                        help="Number of epochs")
    parser.add_argument("--batch-size", type=_positive_int, default=32,
                        help="Batch size")
    parser.add_argument("--num-classes", type=_positive_int, default=10,
                        help="Number of output classes (default: 10, as for CIFAR-10)")
    parser.add_argument("--save-dir", type=str, default="models",
                        help="Model save directory")
    parser.add_argument("--device", type=str, default="cpu",
                        help="Device (cpu/cuda)")
    return parser.parse_args(argv)


def main(argv=None):
    """Parse CLI arguments, load the dataset, train the model and report metrics.

    Args:
        argv: Optional list of command-line arguments; defaults to
            ``sys.argv[1:]`` (argparse behavior) when None.

    Returns:
        Process exit status: 0 on success, 1 if the requested CUDA device is
        unavailable or the dataset could not be loaded.
    """
    args = _parse_args(argv)

    # Fail fast with a clear message instead of a late, opaque error from the
    # trainer when CUDA is requested on a machine without it.
    if args.device.startswith("cuda") and not torch.cuda.is_available():
        print(f"Error: device '{args.device}' requested but CUDA is not available",
              file=sys.stderr)
        return 1

    # Standard data transforms. The 3-element mean/std tuples assume
    # 3-channel (e.g. RGB) input images.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    print(f"Loading dataset: {args.dataset}")
    try:
        dataloader = get_dataloader(args.dataset, batch_size=args.batch_size,
                                    transform=transform)
    except Exception as e:  # top-level boundary: report any loader failure, exit non-zero
        print(f"Error loading dataset: {e}", file=sys.stderr)
        print("Make sure the dataset is registered with: ml dataset register <name>",
              file=sys.stderr)
        return 1
    print("Dataset loaded successfully")

    # Initialize model
    model = SimpleCNN(num_classes=args.num_classes)
    print(f"Model: {model.model_name}")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Initialize trainer
    trainer = Trainer(model, device=args.device)

    # Train model
    print(f"Starting training for {args.epochs} epochs...")
    history = trainer.train(dataloader, epochs=args.epochs, save_dir=args.save_dir)

    print("Training completed!")
    # Guard against an empty history so reporting cannot raise IndexError.
    if history:
        final = history[-1]
        print(f"Final loss: {final['loss']:.4f}")
        print(f"Final accuracy: {final['accuracy']:.2f}%")
    else:
        print("No training history returned; skipping final metrics.")
    print(f"Models saved to: {args.save_dir}/")
    return 0


if __name__ == "__main__":
    sys.exit(main())