- Add Prometheus, Grafana, and Loki monitoring stack
- Include pre-configured dashboards for ML metrics and logs
- Add Podman container support with security policies
- Implement ML runtime environments for multiple frameworks
- Add containerized ML project templates (PyTorch, TensorFlow, etc.)
- Include secure runner with isolation and resource limits
- Add comprehensive log aggregation and alerting
58 lines
2.1 KiB
Python
58 lines
2.1 KiB
Python
import torch
|
|
import argparse
|
|
from pathlib import Path
|
|
import sys
|
|
import os
|
|
|
|
# Add src to path for imports
|
|
sys.path.append(str(Path(__file__).parent.parent / "src"))
|
|
|
|
from data_loader import get_dataloader
|
|
from model import SimpleCNN, Trainer
|
|
from torchvision import transforms
|
|
|
|
def _positive_int(value):
    """Parse a command-line option value as a strictly positive integer.

    Used as an argparse ``type=`` so invalid sizes are rejected with a clean
    usage error instead of failing later (e.g. ``--epochs 0`` would leave the
    training history empty).

    Raises:
        argparse.ArgumentTypeError: if *value* is not an integer >= 1.
    """
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"expected a positive integer, got {value!r}"
        ) from None
    if number < 1:
        raise argparse.ArgumentTypeError(f"expected a positive integer, got {number}")
    return number


def main(argv=None):
    """Train a SimpleCNN on a registered dataset and report the final metrics.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) parses ``sys.argv[1:]``, which is the normal command-line
            behaviour. Passing a list lets tests or other scripts drive the
            entry point.

    Returns:
        Process exit status: 0 on success, 1 if the dataset could not be
        loaded. Invalid arguments make argparse raise ``SystemExit(2)``.
    """
    parser = argparse.ArgumentParser(description="Standard PyTorch Training Script")
    parser.add_argument("--dataset", type=str, default="cifar10",
                        help="Dataset name (must be registered)")
    parser.add_argument("--epochs", type=_positive_int, default=10, help="Number of epochs")
    parser.add_argument("--batch-size", type=_positive_int, default=32, help="Batch size")
    # Previously hard-coded to 10 (CIFAR-10). The default keeps that behaviour,
    # but other registered datasets can now set their own class count.
    parser.add_argument("--num-classes", type=_positive_int, default=10,
                        help="Number of output classes (default: 10, as for CIFAR-10)")
    parser.add_argument("--save-dir", type=str, default="models", help="Model save directory")
    parser.add_argument("--device", type=str, default="cpu", help="Device (cpu/cuda)")

    args = parser.parse_args(argv)

    # Standard data transforms: convert to a tensor, then apply (x - 0.5) / 0.5
    # on each of three channels. This maps values in [0, 1] to [-1, 1].
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    print(f"Loading dataset: {args.dataset}")
    try:
        dataloader = get_dataloader(args.dataset, batch_size=args.batch_size, transform=transform)
    except Exception as e:  # top-level boundary: report any loader failure and stop
        print(f"Error loading dataset: {e}", file=sys.stderr)
        print("Make sure the dataset is registered with: ml dataset register <name> <url>",
              file=sys.stderr)
        # Non-zero status so callers (shell, container runner) see the failure.
        return 1
    print("Dataset loaded successfully")

    # Initialize model
    model = SimpleCNN(num_classes=args.num_classes)
    print(f"Model: {model.model_name}")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Initialize trainer
    trainer = Trainer(model, device=args.device)

    # Train model
    print(f"Starting training for {args.epochs} epochs...")
    history = trainer.train(dataloader, epochs=args.epochs, save_dir=args.save_dir)

    print("Training completed!")
    # NOTE(review): history is assumed to be a list of per-epoch dicts with
    # "loss" and "accuracy" keys (accuracy already in percent). Confirm this
    # against Trainer.train.
    if history:
        final = history[-1]
        print(f"Final loss: {final['loss']:.4f}")
        print(f"Final accuracy: {final['accuracy']:.2f}%")
    else:
        # Guard against an IndexError on history[-1] if no epochs were recorded.
        print("No training history was returned; final metrics unavailable.")
    print(f"Models saved to: {args.save_dir}/")
    return 0


if __name__ == "__main__":
    # Propagate main's status so a failed dataset load exits non-zero.
    sys.exit(main())
|