- Add end-to-end tests for complete workflow validation
- Include integration tests for API and database interactions
- Add unit tests for all major components and utilities
- Include performance tests for payload handling
- Add CLI API integration tests
- Include Podman container integration tests
- Add WebSocket and queue execution tests
- Include shell script tests for setup validation

Provides comprehensive test coverage to ensure platform reliability and correct functionality across all components and their interactions.
124 lines
3.5 KiB
Python
Executable file
124 lines
3.5 KiB
Python
Executable file
#!/usr/bin/env python3
|
|
import argparse
|
|
import json
|
|
import logging
|
|
from pathlib import Path
|
|
import time
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
from torch.utils.data import DataLoader
|
|
from torch.utils.data import TensorDataset
|
|
|
|
|
|
class SimpleNet(nn.Module):
    """Two-layer fully connected network: Linear -> ReLU -> Linear.

    Args:
        input_size: Number of features in the last dimension of the input.
        hidden_size: Width of the hidden layer.
        output_size: Number of output units.
    """

    def __init__(self, input_size: int, hidden_size: int, output_size: int):
        super().__init__()
        # Keep the creation order (fc1, then fc2): it determines both the
        # state_dict key order and which RNG draws initialise each layer.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw (un-normalised, no softmax) scores for *x*.

        *x* is expected to have ``input_size`` as its last dimension.
        """
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)
|
|
|
|
|
|
def _parse_args(argv=None) -> argparse.Namespace:
    """Parse and validate the command-line options for a training run.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed options.

    Exits through ``parser.error`` (status 2) when an option is invalid.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--learning_rate", type=float, default=0.001)
    parser.add_argument("--hidden_size", type=int, default=64)
    parser.add_argument("--output_dir", type=str, required=True)
    args = parser.parse_args(argv)

    # Fail fast with a usage error instead of a PyTorch traceback later
    # (DataLoader rejects batch_size < 1, Adam rejects a negative lr) or a
    # silently skipped training loop (negative epochs).
    if args.epochs < 0:
        parser.error("--epochs must be >= 0")
    if args.batch_size < 1:
        parser.error("--batch_size must be >= 1")
    if args.hidden_size < 1:
        parser.error("--hidden_size must be >= 1")
    if args.learning_rate < 0:
        parser.error("--learning_rate must be >= 0")
    return args


def _run_epoch(model, dataloader, criterion, optimizer):
    """Run one training pass over *dataloader* and return its metrics.

    Returns:
        ``(avg_loss, accuracy)``. These are the per-sample mean loss and the
        fraction of correctly classified samples, measured on each batch's
        forward pass (i.e. before that batch's optimizer step). Assumes
        *dataloader* yields at least one sample.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0
    for batch_X, batch_y in dataloader:
        optimizer.zero_grad()
        outputs = model(batch_X)
        loss = criterion(outputs, batch_y)
        loss.backward()
        optimizer.step()

        batch_size = batch_y.size(0)
        # The loss is a per-batch mean (assumes reduction="mean", the
        # CrossEntropyLoss default). Weight it by the batch size so a short
        # final batch does not skew the epoch average.
        total_loss += loss.item() * batch_size
        correct += (outputs.argmax(dim=1) == batch_y).sum().item()
        total += batch_size
    return total_loss / total, correct / total


def _evaluate(model, dataloader):
    """Return the accuracy of *model* over every sample in *dataloader*.

    Switches the model to eval mode and disables autograd for the pass.
    Assumes *dataloader* yields at least one sample.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_X, batch_y in dataloader:
            predicted = model(batch_X).argmax(dim=1)
            correct += (predicted == batch_y).sum().item()
            total += batch_y.size(0)
    return correct / total


def _save_outputs(output_dir, results, model):
    """Write ``results.json`` and ``pytorch_model.pth`` into *output_dir*.

    The directory, and any missing parents, is created if needed.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    with open(output_dir / "results.json", "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)
    torch.save(model.state_dict(), output_dir / "pytorch_model.pth")


def main(argv=None):
    """Train ``SimpleNet`` on seeded synthetic data and save its results.

    Writes two files to ``--output_dir``: ``results.json`` (hyperparameters
    and final accuracy) and ``pytorch_model.pth`` (the model ``state_dict``).

    Args:
        argv: Optional argument list used instead of ``sys.argv[1:]``. The
            default ``None`` keeps normal command-line behaviour.
    """
    args = _parse_args(argv)

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    logger.info(f"Training PyTorch model for {args.epochs} epochs...")

    # Synthetic data: standard-normal features and uniformly random labels,
    # drawn independently of the features.
    n_samples, n_features, n_classes = 1000, 20, 2
    # Seed once. The data below and the model initialisation drawn
    # afterwards are then identical across runs.
    torch.manual_seed(42)
    X = torch.randn(n_samples, n_features)
    y = torch.randint(0, n_classes, (n_samples,))

    dataset = TensorDataset(X, y)
    train_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
    # Accuracy does not depend on sample order, so evaluation skips shuffling.
    eval_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False)

    model = SimpleNet(n_features, args.hidden_size, n_classes)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)

    for epoch in range(args.epochs):
        avg_loss, accuracy = _run_epoch(model, train_loader, criterion, optimizer)
        logger.info(
            f"Epoch {epoch + 1}/{args.epochs}: Loss={avg_loss:.4f}, Acc={accuracy:.4f}"
        )
        time.sleep(0.05)  # Reduced delay for faster testing

    # NOTE: this is accuracy on the training data itself; there is no
    # held-out split.
    final_accuracy = _evaluate(model, eval_loader)

    logger.info(f"Training completed. Final accuracy: {final_accuracy:.4f}")

    results = {
        "model_type": "PyTorch",
        "epochs": args.epochs,
        "batch_size": args.batch_size,
        "learning_rate": args.learning_rate,
        "hidden_size": args.hidden_size,
        "final_accuracy": final_accuracy,
        "n_samples": len(X),
        "input_features": X.shape[1],
    }
    _save_outputs(args.output_dir, results, model)

    logger.info("Results and model saved successfully!")
|
|
|
|
|
|
# Script entry point: run training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|