package storage

import (
	"os"
	"path/filepath"
	"testing"
	"time"

	"github.com/jfraeys/fetch_ml/internal/storage"
	fixtures "github.com/jfraeys/fetch_ml/tests/fixtures"
	_ "github.com/mattn/go-sqlite3"
)

// TestNewDB verifies that NewDBFromPath creates the database file on disk.
func TestNewDB(t *testing.T) {
	t.Parallel()

	dbPath := filepath.Join(t.TempDir(), "test.db")
	db, err := storage.NewDBFromPath(dbPath)
	if err != nil {
		t.Fatalf("Failed to create database: %v", err)
	}
	defer func() {
		if err := db.Close(); err != nil {
			t.Errorf("Failed to close database: %v", err)
		}
	}()

	// Any Stat failure (not only "does not exist") means the file is not usable.
	if _, err := os.Stat(dbPath); err != nil {
		t.Errorf("Database file was not created: %v", err)
	}
}

// TestNewDBInvalidPath verifies that NewDBFromPath fails when the parent
// directory of the database file does not exist.
func TestNewDBInvalidPath(t *testing.T) {
	t.Parallel()

	// Build the missing directory under a per-test temp dir so the path is
	// guaranteed not to exist on any platform, instead of relying on a
	// hard-coded absolute path.
	invalidPath := filepath.Join(t.TempDir(), "does", "not", "exist", "test.db")
	if _, err := storage.NewDBFromPath(invalidPath); err == nil {
		t.Error("Expected error when creating database with invalid path")
	}
}

// TestJobOperations verifies that a created job can be read back by ID.
func TestJobOperations(t *testing.T) {
	t.Parallel()

	dbPath := filepath.Join(t.TempDir(), "test.db")
	db, err := storage.NewDBFromPath(dbPath)
	if err != nil {
		t.Fatalf("Failed to create database: %v", err)
	}
	defer func() {
		if err := db.Close(); err != nil {
			t.Errorf("Failed to close database: %v", err)
		}
	}()

	if err := db.Initialize(fixtures.TestSchema); err != nil {
		t.Fatalf("Failed to initialize database: %v", err)
	}

	job := &storage.Job{
		ID:       "test-job-1",
		JobName:  "test_experiment",
		Args:     "--epochs 10",
		Status:   "pending",
		Priority: 1,
		Datasets: []string{"dataset1", "dataset2"},
		Metadata: map[string]string{"user": "testuser"},
	}
	if err := db.CreateJob(job); err != nil {
		t.Fatalf("Failed to create job: %v", err)
	}

	retrievedJob, err := db.GetJob("test-job-1")
	if err != nil {
		t.Fatalf("Failed to get job: %v", err)
	}
	if retrievedJob.ID != job.ID {
		t.Errorf("Expected job ID %s, got %s", job.ID, retrievedJob.ID)
	}
	if retrievedJob.JobName != job.JobName {
		t.Errorf("Expected job name %s, got %s", job.JobName, retrievedJob.JobName)
	}
	if retrievedJob.Status != job.Status {
		t.Errorf("Expected status %s, got %s", job.Status, retrievedJob.Status)
	}
}

// TestUpdateJobStatus verifies that UpdateJobStatus changes the status,
// assigns the worker, and sets StartedAt.
func TestUpdateJobStatus(t *testing.T) {
	t.Parallel()

	dbPath := filepath.Join(t.TempDir(), "test.db")
	db, err := storage.NewDBFromPath(dbPath)
	if err != nil {
		t.Fatalf("Failed to create database: %v", err)
	}
	defer func() {
		if err := db.Close(); err != nil {
			t.Errorf("Failed to close database: %v", err)
		}
	}()

	if err := db.Initialize(fixtures.TestSchema); err != nil {
		t.Fatalf("Failed to initialize database: %v", err)
	}

	job := &storage.Job{
		ID:       "test-job-2",
		JobName:  "test_experiment",
		Args:     "--epochs 10",
		Status:   "pending",
		Priority: 1,
	}
	if err := db.CreateJob(job); err != nil {
		t.Fatalf("Failed to create job: %v", err)
	}

	if err := db.UpdateJobStatus("test-job-2", "running", "worker-1", ""); err != nil {
		t.Fatalf("Failed to update job status: %v", err)
	}

	retrievedJob, err := db.GetJob("test-job-2")
	if err != nil {
		t.Fatalf("Failed to get updated job: %v", err)
	}
	if retrievedJob.Status != "running" {
		t.Errorf("Expected status 'running', got %s", retrievedJob.Status)
	}
	if retrievedJob.WorkerID != "worker-1" {
		t.Errorf("Expected worker ID 'worker-1', got %s", retrievedJob.WorkerID)
	}
	if retrievedJob.StartedAt == nil {
		t.Error("StartedAt should not be nil after status update")
	}
}

// TestListJobs verifies listing all jobs and filtering by status.
func TestListJobs(t *testing.T) {
	t.Parallel()

	dbPath := filepath.Join(t.TempDir(), "test.db")
	db, err := storage.NewDBFromPath(dbPath)
	if err != nil {
		t.Fatalf("Failed to create database: %v", err)
	}
	defer func() {
		if err := db.Close(); err != nil {
			t.Errorf("Failed to close database: %v", err)
		}
	}()

	if err := db.Initialize(fixtures.TestSchema); err != nil {
		t.Fatalf("Failed to initialize database: %v", err)
	}

	jobs := []*storage.Job{
		{ID: "job-1", JobName: "experiment1", Status: "pending", Priority: 1},
		{ID: "job-2", JobName: "experiment2", Status: "running", Priority: 2},
		{ID: "job-3", JobName: "experiment3", Status: "completed", Priority: 3},
	}
	for _, job := range jobs {
		if err := db.CreateJob(job); err != nil {
			t.Fatalf("Failed to create job %s: %v", job.ID, err)
		}
	}

	// An empty status and a zero limit are used here to list every job.
	allJobs, err := db.ListJobs("", 0)
	if err != nil {
		t.Fatalf("Failed to list jobs: %v", err)
	}
	if len(allJobs) != 3 {
		t.Errorf("Expected 3 jobs, got %d", len(allJobs))
	}

	pendingJobs, err := db.ListJobs("pending", 0)
	if err != nil {
		t.Fatalf("Failed to list pending jobs: %v", err)
	}
	if len(pendingJobs) != 1 {
		t.Errorf("Expected 1 pending job, got %d", len(pendingJobs))
	}
}

// TestWorkerOperations verifies that a registered worker is returned by
// GetActiveWorkers.
func TestWorkerOperations(t *testing.T) {
	t.Parallel()

	dbPath := filepath.Join(t.TempDir(), "test.db")
	db, err := storage.NewDBFromPath(dbPath)
	if err != nil {
		t.Fatalf("Failed to create database: %v", err)
	}
	defer func() {
		if err := db.Close(); err != nil {
			t.Errorf("Failed to close database: %v", err)
		}
	}()

	if err := db.Initialize(fixtures.TestSchema); err != nil {
		t.Fatalf("Failed to initialize database: %v", err)
	}

	worker := &storage.Worker{
		ID:          "worker-1",
		Hostname:    "test-host",
		Status:      "active",
		CurrentJobs: 0,
		MaxJobs:     5,
		Metadata:    map[string]string{"version": "1.0"},
	}
	if err := db.RegisterWorker(worker); err != nil {
		t.Fatalf("Failed to create worker: %v", err)
	}

	workers, err := db.GetActiveWorkers()
	if err != nil {
		t.Fatalf("Failed to get active workers: %v", err)
	}
	// Fatal, not Error: workers[0] is indexed below and would panic.
	if len(workers) == 0 {
		t.Fatal("Expected at least one active worker")
	}

	retrievedWorker := workers[0]
	if retrievedWorker.ID != worker.ID {
		t.Errorf("Expected worker ID %s, got %s", worker.ID, retrievedWorker.ID)
	}
	if retrievedWorker.Hostname != worker.Hostname {
		t.Errorf("Expected hostname %s, got %s", worker.Hostname, retrievedWorker.Hostname)
	}
}

// TestUpdateWorkerHeartbeat verifies that UpdateWorkerHeartbeat refreshes the
// worker's LastHeartbeat to approximately the current time.
func TestUpdateWorkerHeartbeat(t *testing.T) {
	t.Parallel()

	dbPath := filepath.Join(t.TempDir(), "test.db")
	db, err := storage.NewDBFromPath(dbPath)
	if err != nil {
		t.Fatalf("Failed to create database: %v", err)
	}
	defer func() {
		if err := db.Close(); err != nil {
			t.Errorf("Failed to close database: %v", err)
		}
	}()

	if err := db.Initialize(fixtures.TestSchema); err != nil {
		t.Fatalf("Failed to initialize database: %v", err)
	}

	worker := &storage.Worker{
		ID:          "worker-1",
		Hostname:    "test-host",
		Status:      "active",
		CurrentJobs: 0,
		MaxJobs:     5,
	}
	if err := db.RegisterWorker(worker); err != nil {
		t.Fatalf("Failed to create worker: %v", err)
	}

	if err := db.UpdateWorkerHeartbeat("worker-1"); err != nil {
		t.Fatalf("Failed to update worker heartbeat: %v", err)
	}

	workers, err := db.GetActiveWorkers()
	if err != nil {
		t.Fatalf("Failed to get active workers: %v", err)
	}
	if len(workers) == 0 {
		t.Fatal("No active workers found")
	}

	// Allow a few seconds of slack. These tests run in parallel, and the
	// stored timestamp may be truncated to whole seconds, so a 1s bound can
	// flake. NOTE(review): confirm the storage layer's timestamp precision.
	const heartbeatTolerance = 5 * time.Second
	if age := time.Since(workers[0].LastHeartbeat); age > heartbeatTolerance {
		t.Errorf("Worker heartbeat was not updated properly: last heartbeat %v ago", age)
	}
}

// TestJobMetrics verifies recording and reading back per-job metrics.
func TestJobMetrics(t *testing.T) {
	t.Parallel()

	dbPath := filepath.Join(t.TempDir(), "test.db")
	db, err := storage.NewDBFromPath(dbPath)
	if err != nil {
		t.Fatalf("Failed to create database: %v", err)
	}
	defer func() {
		if err := db.Close(); err != nil {
			t.Errorf("Failed to close database: %v", err)
		}
	}()

	if err := db.Initialize(fixtures.TestSchema); err != nil {
		t.Fatalf("Failed to initialize database: %v", err)
	}

	job := &storage.Job{
		ID:       "test-job-metrics",
		JobName:  "test_experiment",
		Status:   "running",
		Priority: 1,
	}
	if err := db.CreateJob(job); err != nil {
		t.Fatalf("Failed to create job: %v", err)
	}

	if err := db.RecordJobMetric("test-job-metrics", "accuracy", "0.95"); err != nil {
		t.Fatalf("Failed to record job metric: %v", err)
	}
	if err := db.RecordJobMetric("test-job-metrics", "loss", "0.05"); err != nil {
		t.Fatalf("Failed to record job metric: %v", err)
	}

	metrics, err := db.GetJobMetrics("test-job-metrics")
	if err != nil {
		t.Fatalf("Failed to get job metrics: %v", err)
	}
	if len(metrics) != 2 {
		t.Errorf("Expected 2 metrics, got %d", len(metrics))
	}
	if metrics["accuracy"] != "0.95" {
		t.Errorf("Expected accuracy 0.95, got %s", metrics["accuracy"])
	}
	if metrics["loss"] != "0.05" {
		t.Errorf("Expected loss 0.05, got %s", metrics["loss"])
	}
}

// TestSystemMetrics verifies that system metrics can be recorded without error.
func TestSystemMetrics(t *testing.T) {
	t.Parallel()

	dbPath := filepath.Join(t.TempDir(), "test.db")
	db, err := storage.NewDBFromPath(dbPath)
	if err != nil {
		t.Fatalf("Failed to create database: %v", err)
	}
	defer func() {
		if err := db.Close(); err != nil {
			t.Errorf("Failed to close database: %v", err)
		}
	}()

	if err := db.Initialize(fixtures.TestSchema); err != nil {
		t.Fatalf("Failed to initialize database: %v", err)
	}

	if err := db.RecordSystemMetric("cpu_usage", "75.5"); err != nil {
		t.Fatalf("Failed to record system metric: %v", err)
	}
	if err := db.RecordSystemMetric("memory_usage", "2.1GB"); err != nil {
		t.Fatalf("Failed to record system metric: %v", err)
	}

	// The storage API has no read-back method for system metrics, so this
	// test only checks that recording succeeds.
	t.Log("System metrics recorded successfully")
}

// TestDBConstraints verifies that duplicate job IDs are rejected and that
// looking up a missing job returns an error.
func TestDBConstraints(t *testing.T) {
	t.Parallel()

	dbPath := filepath.Join(t.TempDir(), "test_constraints.db")
	db, err := storage.NewDBFromPath(dbPath)
	if err != nil {
		t.Fatalf("Failed to create database: %v", err)
	}
	defer func() {
		if err := db.Close(); err != nil {
			t.Errorf("Failed to close database: %v", err)
		}
	}()

	if err := db.Initialize(fixtures.TestSchema); err != nil {
		t.Fatalf("Failed to initialize database: %v", err)
	}

	job := &storage.Job{
		ID:      "duplicate-test",
		JobName: "test",
		Status:  "pending",
	}
	if err := db.CreateJob(job); err != nil {
		t.Fatalf("Failed to create first job: %v", err)
	}
	if err := db.CreateJob(job); err == nil {
		t.Error("Expected error when creating duplicate job")
	}

	if _, err := db.GetJob("non-existent"); err == nil {
		t.Error("Expected error when getting non-existent job")
	}
}

// TestDBWithDatasetsAndMetadata is an end-to-end check. It covers a job
// carrying datasets and metadata, job and system metrics, job listing, and
// worker registration and heartbeat.
func TestDBWithDatasetsAndMetadata(t *testing.T) {
	t.Parallel()

	dbPath := filepath.Join(t.TempDir(), "test.db")
	db, err := storage.NewDBFromPath(dbPath)
	if err != nil {
		t.Fatalf("Failed to create database: %v", err)
	}
	defer func() {
		if err := db.Close(); err != nil {
			t.Errorf("Failed to close database: %v", err)
		}
	}()

	if err := db.Initialize(fixtures.TestSchema); err != nil {
		t.Fatalf("Failed to initialize database: %v", err)
	}

	job := &storage.Job{
		ID:       "test-job-full",
		JobName:  "test_experiment",
		Args:     "--epochs 10 --lr 0.001",
		Status:   "pending",
		Priority: 1,
		Datasets: []string{"dataset1", "dataset2"},
		Metadata: map[string]string{"gpu": "true", "memory": "8GB"},
	}
	if err := db.CreateJob(job); err != nil {
		t.Fatalf("Failed to create job: %v", err)
	}

	retrievedJob, err := db.GetJob("test-job-full")
	if err != nil {
		t.Fatalf("Failed to get job: %v", err)
	}
	if len(retrievedJob.Datasets) != 2 {
		t.Errorf("Expected 2 datasets, got %d", len(retrievedJob.Datasets))
	}
	if retrievedJob.Metadata["gpu"] != "true" {
		t.Errorf("Expected gpu=true, got %s", retrievedJob.Metadata["gpu"])
	}

	if err := db.RecordJobMetric("test-job-full", "accuracy", "0.95"); err != nil {
		t.Fatalf("Failed to record job metric: %v", err)
	}
	if err := db.RecordSystemMetric("cpu_usage", "75"); err != nil {
		t.Fatalf("Failed to record system metric: %v", err)
	}

	metrics, err := db.GetJobMetrics("test-job-full")
	if err != nil {
		t.Fatalf("Failed to get job metrics: %v", err)
	}
	if metrics["accuracy"] != "0.95" {
		t.Errorf("Expected accuracy 0.95, got %s", metrics["accuracy"])
	}

	jobs, err := db.ListJobs("", 10)
	if err != nil {
		t.Fatalf("Failed to list jobs: %v", err)
	}
	t.Logf("Found %d jobs", len(jobs))
	// Use a distinct loop variable so the outer job is not shadowed.
	for i, j := range jobs {
		t.Logf("Job %d: ID=%s, Status=%s", i, j.ID, j.Status)
	}
	if len(jobs) != 1 {
		t.Fatalf("Expected 1 job, got %d", len(jobs))
	}
	if jobs[0].ID != "test-job-full" {
		t.Fatalf("Expected job ID test-job-full, got %s", jobs[0].ID)
	}

	worker := &storage.Worker{
		ID:          "worker-full",
		Hostname:    "test-host",
		Status:      "active",
		CurrentJobs: 0,
		MaxJobs:     2,
		Metadata:    map[string]string{"cpu": "8", "memory": "16GB"},
	}
	if err := db.RegisterWorker(worker); err != nil {
		t.Fatalf("Failed to register worker: %v", err)
	}
	if err := db.UpdateWorkerHeartbeat("worker-full"); err != nil {
		t.Fatalf("Failed to update worker heartbeat: %v", err)
	}

	workers, err := db.GetActiveWorkers()
	if err != nil {
		t.Fatalf("Failed to get active workers: %v", err)
	}
	// Fatal, not Error: workers[0] is indexed below and would panic if empty.
	if len(workers) != 1 {
		t.Fatalf("Expected 1 active worker, got %d", len(workers))
	}
	if workers[0].ID != "worker-full" {
		t.Errorf("Expected worker ID worker-full, got %s", workers[0].ID)
	}
}