fetch_ml/podman/workspace/pytorch_project/train.py
Jeremie Fraeys 4aecd469a1 feat: implement comprehensive monitoring and container orchestration
- Add Prometheus, Grafana, and Loki monitoring stack
- Include pre-configured dashboards for ML metrics and logs
- Add Podman container support with security policies
- Implement ML runtime environments for multiple frameworks
- Add containerized ML project templates (PyTorch, TensorFlow, etc.)
- Include secure runner with isolation and resource limits
- Add comprehensive log aggregation and alerting
2025-12-04 16:54:49 -05:00

58 lines
2.1 KiB
Python

import torch
import argparse
from pathlib import Path
import sys
import os
# Add src to path for imports
sys.path.append(str(Path(__file__).parent.parent / "src"))
from data_loader import get_dataloader
from model import SimpleCNN, Trainer
from torchvision import transforms
def _parse_args(argv=None):
    """Build the CLI parser, parse *argv*, and validate numeric options.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``.

    Raises:
        SystemExit: with status 2 (via ``parser.error``) on unparsable
            arguments or when a size option is not a positive integer.
    """
    parser = argparse.ArgumentParser(description="Standard PyTorch Training Script")
    parser.add_argument("--dataset", type=str, default="cifar10",
                        help="Dataset name (must be registered)")
    parser.add_argument("--epochs", type=int, default=10, help="Number of epochs")
    parser.add_argument("--batch-size", type=int, default=32, help="Batch size")
    parser.add_argument("--num-classes", type=int, default=10,
                        help="Number of output classes (default 10, as for CIFAR-10)")
    parser.add_argument("--save-dir", type=str, default="models", help="Model save directory")
    parser.add_argument("--device", type=str, default="cpu", help="Device (cpu/cuda)")
    args = parser.parse_args(argv)

    # Fail fast on non-positive sizes. For example, epochs=0 would leave the
    # training history empty, and the final-metrics report reads history[-1].
    size_options = (
        ("--epochs", args.epochs),
        ("--batch-size", args.batch_size),
        ("--num-classes", args.num_classes),
    )
    for option, value in size_options:
        if value < 1:
            parser.error(f"{option} must be a positive integer, got {value}")
    return args


def main(argv=None):
    """Train a SimpleCNN on a registered dataset and print its final metrics.

    Args:
        argv: Optional argument list (defaults to ``sys.argv[1:]``), so the
            script can also be driven programmatically.

    Raises:
        SystemExit: status 2 on invalid arguments (argparse), status 1 if
            the dataset cannot be loaded.
    """
    args = _parse_args(argv)

    # Standard data transforms: tensor conversion, then per-channel
    # normalization with mean = std = 0.5. The three-tuple assumes 3-channel
    # (e.g. RGB / CIFAR-10) images -- TODO confirm for other datasets.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    print(f"Loading dataset: {args.dataset}")
    try:
        dataloader = get_dataloader(args.dataset, batch_size=args.batch_size, transform=transform)
    except Exception as e:
        # Top-level boundary. Report on stderr and exit non-zero so a container
        # runner / orchestrator records the job as failed instead of successful.
        print(f"Error loading dataset: {e}", file=sys.stderr)
        print("Make sure the dataset is registered with: ml dataset register <name> <url>",
              file=sys.stderr)
        sys.exit(1)
    print("Dataset loaded successfully")

    # Initialize model; the class count comes from the CLI (default 10 for CIFAR-10).
    model = SimpleCNN(num_classes=args.num_classes)
    print(f"Model: {model.model_name}")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Initialize trainer
    trainer = Trainer(model, device=args.device)

    # Train model
    print(f"Starting training for {args.epochs} epochs...")
    history = trainer.train(dataloader, epochs=args.epochs, save_dir=args.save_dir)
    print("Training completed!")

    # history is presumably one metrics dict per epoch with 'loss' and
    # 'accuracy' keys (TODO confirm against Trainer.train). Guard against an
    # empty result rather than crash on history[-1].
    if history:
        final = history[-1]
        print(f"Final loss: {final['loss']:.4f}")
        print(f"Final accuracy: {final['accuracy']:.2f}%")
    else:
        print("No training history was returned; final metrics unavailable.")
    print(f"Models saved to: {args.save_dir}/")
# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()