package tests

import (
	"os"
	"path/filepath"
	"strconv"
	"testing"

	"github.com/jfraeys/fetch_ml/internal/fileutil"
	tests "github.com/jfraeys/fetch_ml/tests/fixtures"
)

// TestMLProjectVariants creates one fixture project per supported ML framework
// template inside a shared temporary directory. All checks are delegated to
// tests.CreateMLProject; this function makes no assertions of its own.
func TestMLProjectVariants(t *testing.T) {
	testDir := t.TempDir()

	// Test 1: Scikit-learn project
	t.Run("ScikitLearnProject", func(t *testing.T) {
		tests.CreateMLProject(t, testDir, "sklearn_experiment", tests.ScikitLearnTemplate())
	})

	// Test 2: XGBoost project
	t.Run("XGBoostProject", func(t *testing.T) {
		tests.CreateMLProject(t, testDir, "xgboost_experiment", tests.XGBoostTemplate())
	})

	// Test 3: PyTorch project (deep learning)
	t.Run("PyTorchProject", func(t *testing.T) {
		tests.CreateMLProject(t, testDir, "pytorch_experiment", tests.PyTorchTemplate())
	})

	// Test 4: Traditional ML (statsmodels)
	t.Run("StatsModelsProject", func(t *testing.T) {
		tests.CreateMLProject(t, testDir, "statsmodels_experiment", tests.StatsModelsTemplate())
	})
}

// TestMLProjectCompatibility simulates the upload step for every project type:
// it writes a minimal train.py and requirements.txt into an experiment
// directory, copies both into a pending-job directory under a fake server
// tree, and verifies that the copies exist.
//
// The generated train.py is never executed by this test; it is only written
// and copied.
func TestMLProjectCompatibility(t *testing.T) {
	testDir := t.TempDir()

	// Test that all project types can be uploaded and processed
	projectTypes := []string{
		"sklearn_experiment",
		"xgboost_experiment",
		"pytorch_experiment",
		"tensorflow_experiment",
		"statsmodels_experiment",
	}

	for _, projectType := range projectTypes {
		t.Run(projectType+"_UploadTest", func(t *testing.T) {
			// Create experiment directory
			experimentDir := filepath.Join(testDir, projectType)
			if err := os.MkdirAll(experimentDir, 0750); err != nil {
				t.Fatalf("Failed to create experiment directory: %v", err)
			}

			// Create minimal files.
			//
			// Go raw strings are not interpolated, so the project type is
			// spliced in as a module-level Python assignment. Without it the
			// script's references to projectType would raise NameError.
			trainFile := filepath.Join(experimentDir, "train.py")
			trainCode := `#!/usr/bin/env python3
import argparse, json, logging, time
from pathlib import Path

projectType = ` + strconv.Quote(projectType) + `

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", type=str, required=True)
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    logger.info(f"Training {projectType} model...")

    # Simulate training
    for epoch in range(3):
        logger.info(f"Epoch {epoch + 1}: training...")
        time.sleep(0.01)

    results = {"model_type": projectType, "status": "completed"}

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    with open(output_dir / "results.json", "w") as f:
        json.dump(results, f)

    logger.info("Training complete!")

if __name__ == "__main__":
    main()
`
			if err := os.WriteFile(trainFile, []byte(trainCode), 0600); err != nil {
				t.Fatalf("Failed to create train.py: %v", err)
			}

			// Create requirements.txt
			requirementsFile := filepath.Join(experimentDir, "requirements.txt")
			requirements := "# Framework-specific dependencies\n"
			if err := os.WriteFile(requirementsFile, []byte(requirements), 0600); err != nil {
				t.Fatalf("Failed to create requirements.txt: %v", err)
			}

			// Simulate upload process: the job directory name carries a fixed
			// timestamp suffix so the path is deterministic.
			serverDir := filepath.Join(testDir, "server", "home", "mluser", "ml_jobs", "pending")
			jobDir := filepath.Join(serverDir, projectType+"_20231201_143022")
			if err := os.MkdirAll(jobDir, 0750); err != nil {
				t.Fatalf("Failed to create server directories: %v", err)
			}

			// Copy files
			files := []string{"train.py", "requirements.txt"}
			for _, file := range files {
				src := filepath.Join(experimentDir, file)
				dst := filepath.Join(jobDir, file)

				data, err := fileutil.SecureFileRead(src)
				if err != nil {
					t.Fatalf("Failed to read %s: %v", file, err)
				}
				if err := os.WriteFile(dst, data, 0600); err != nil {
					t.Fatalf("Failed to copy %s: %v", file, err)
				}
			}

			// Verify upload. Any Stat failure is reported, not only
			// "not exist", so permission or I/O errors cannot pass silently.
			for _, file := range files {
				dst := filepath.Join(jobDir, file)
				if _, err := os.Stat(dst); err != nil {
					t.Errorf("Uploaded file %s should exist for %s: %v", file, projectType, err)
				}
			}
		})
	}
}