- Fix YAML tags in auth config struct (json -> yaml) - Update CLI configs to use pre-hashed API keys - Remove double hashing in WebSocket client - Fix port mapping (9102 -> 9103) in CLI commands - Update permission keys to use jobs:read, jobs:create, etc. - Clean up all debug logging from CLI and server - All user roles now authenticate correctly: * Admin: Can queue jobs and see all jobs * Researcher: Can queue jobs and see own jobs * Analyst: Can see status (read-only access) Multi-user authentication is now fully functional.
137 lines
4 KiB
Go
137 lines
4 KiB
Go
package tests
|
|
|
|
import (
|
|
"os"
|
|
"path/filepath"
|
|
"testing"
|
|
|
|
"github.com/jfraeys/fetch_ml/internal/fileutil"
|
|
tests "github.com/jfraeys/fetch_ml/tests/fixtures"
|
|
)
|
|
|
|
// TestMLProjectVariants tests different types of ML projects with zero-install workflow
|
|
func TestMLProjectVariants(t *testing.T) {
|
|
testDir := t.TempDir()
|
|
|
|
// Test 1: Scikit-learn project
|
|
t.Run("ScikitLearnProject", func(t *testing.T) {
|
|
tests.CreateMLProject(t, testDir, "sklearn_experiment", tests.ScikitLearnTemplate())
|
|
})
|
|
|
|
// Test 2: XGBoost project
|
|
t.Run("XGBoostProject", func(t *testing.T) {
|
|
tests.CreateMLProject(t, testDir, "xgboost_experiment", tests.XGBoostTemplate())
|
|
})
|
|
|
|
// Test 3: PyTorch project (deep learning)
|
|
t.Run("PyTorchProject", func(t *testing.T) {
|
|
tests.CreateMLProject(t, testDir, "pytorch_experiment", tests.PyTorchTemplate())
|
|
})
|
|
|
|
// Test 5: Traditional ML (statsmodels)
|
|
t.Run("StatsModelsProject", func(t *testing.T) {
|
|
tests.CreateMLProject(t, testDir, "statsmodels_experiment", tests.StatsModelsTemplate())
|
|
})
|
|
}
|
|
|
|
// TestMLProjectCompatibility tests that all project types work with zero-install workflow
|
|
func TestMLProjectCompatibility(t *testing.T) {
|
|
testDir := t.TempDir()
|
|
|
|
// Test that all project types can be uploaded and processed
|
|
projectTypes := []string{
|
|
"sklearn_experiment",
|
|
"xgboost_experiment",
|
|
"pytorch_experiment",
|
|
"tensorflow_experiment",
|
|
"statsmodels_experiment",
|
|
}
|
|
|
|
for _, projectType := range projectTypes {
|
|
t.Run(projectType+"_UploadTest", func(t *testing.T) {
|
|
// Create experiment directory
|
|
experimentDir := filepath.Join(testDir, projectType)
|
|
if err := os.MkdirAll(experimentDir, 0750); err != nil {
|
|
t.Fatalf("Failed to create experiment directory: %v", err)
|
|
}
|
|
|
|
// Create minimal files
|
|
trainScript := filepath.Join(experimentDir, "train.py")
|
|
trainCode := `#!/usr/bin/env python3
|
|
import argparse, json, logging, time
|
|
from pathlib import Path
|
|
|
|
def main():
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument("--output_dir", type=str, required=True)
|
|
args = parser.parse_args()
|
|
|
|
logging.basicConfig(level=logging.INFO)
|
|
logger = logging.getLogger(__name__)
|
|
|
|
logger.info(f"Training {projectType} model...")
|
|
|
|
# Simulate training
|
|
for epoch in range(3):
|
|
logger.info(f"Epoch {epoch + 1}: training...")
|
|
time.sleep(0.01)
|
|
|
|
results = {"model_type": projectType, "status": "completed"}
|
|
|
|
output_dir = Path(args.output_dir)
|
|
output_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
with open(output_dir / "results.json", "w") as f:
|
|
json.dump(results, f)
|
|
|
|
logger.info("Training complete!")
|
|
|
|
if __name__ == "__main__":
|
|
main()
|
|
`
|
|
|
|
if err := os.WriteFile(trainScript, []byte(trainCode), 0600); err != nil {
|
|
t.Fatalf("Failed to create train.py: %v", err)
|
|
}
|
|
|
|
// Create requirements.txt
|
|
requirementsFile := filepath.Join(experimentDir, "requirements.txt")
|
|
requirements := "# Framework-specific dependencies\n"
|
|
if err := os.WriteFile(requirementsFile, []byte(requirements), 0600); err != nil {
|
|
t.Fatalf("Failed to create requirements.txt: %v", err)
|
|
}
|
|
|
|
// Simulate upload process
|
|
serverDir := filepath.Join(testDir, "server", "home", "mluser", "ml_jobs", "pending")
|
|
jobDir := filepath.Join(serverDir, projectType+"_20231201_143022")
|
|
|
|
if err := os.MkdirAll(jobDir, 0750); err != nil {
|
|
t.Fatalf("Failed to create server directories: %v", err)
|
|
}
|
|
|
|
// Copy files
|
|
files := []string{"train.py", "requirements.txt"}
|
|
for _, file := range files {
|
|
src := filepath.Join(experimentDir, file)
|
|
dst := filepath.Join(jobDir, file)
|
|
|
|
data, err := fileutil.SecureFileRead(src)
|
|
if err != nil {
|
|
t.Fatalf("Failed to read %s: %v", file, err)
|
|
}
|
|
|
|
if err := os.WriteFile(dst, data, 0600); err != nil {
|
|
t.Fatalf("Failed to copy %s: %v", file, err)
|
|
}
|
|
}
|
|
|
|
// Verify upload
|
|
for _, file := range files {
|
|
dst := filepath.Join(jobDir, file)
|
|
if _, err := os.Stat(dst); os.IsNotExist(err) {
|
|
t.Errorf("Uploaded file %s should exist for %s", file, projectType)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|