226 lines
5.7 KiB
Go
226 lines
5.7 KiB
Go
package unit
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"net/http"
|
|
"os"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
tests "github.com/jfraeys/fetch_ml/tests/fixtures"
|
|
)
|
|
|
|
// TestBasicRedisConnection tests basic Redis connectivity
|
|
func TestBasicRedisConnection(t *testing.T) {
|
|
// Not parallel: uses a shared Redis DB index + FlushDB in defer.
|
|
ctx := context.Background()
|
|
|
|
// Use fixtures for Redis operations
|
|
redisHelper, err := tests.NewRedisHelper("localhost:6379", 12)
|
|
if err != nil {
|
|
t.Skipf("Redis not available, skipping test: %v", err)
|
|
}
|
|
defer func() {
|
|
_ = redisHelper.FlushDB()
|
|
_ = redisHelper.Close()
|
|
}()
|
|
|
|
// Test basic operations
|
|
key := "test:key"
|
|
value := "test_value"
|
|
|
|
// Set
|
|
if setErr := redisHelper.GetClient().Set(ctx, key, value, time.Hour).Err(); setErr != nil {
|
|
t.Fatalf("Failed to set value: %v", setErr)
|
|
}
|
|
|
|
// Get
|
|
result, err := redisHelper.GetClient().Get(ctx, key).Result()
|
|
if err != nil {
|
|
t.Fatalf("Failed to get value: %v", err)
|
|
}
|
|
|
|
if result != value {
|
|
t.Errorf("Expected value '%s', got '%s'", value, result)
|
|
}
|
|
|
|
// Delete
|
|
if delErr := redisHelper.GetClient().Del(ctx, key).Err(); delErr != nil {
|
|
t.Fatalf("Failed to delete key: %v", delErr)
|
|
}
|
|
|
|
// Verify deleted
|
|
_, err = redisHelper.GetClient().Get(ctx, key).Result()
|
|
if err == nil {
|
|
t.Error("Expected error when getting deleted key")
|
|
}
|
|
}
|
|
|
|
// TestTaskQueueBasicOperations tests basic task queue functionality
|
|
func TestTaskQueueBasicOperations(t *testing.T) {
|
|
// Not parallel: uses a shared Redis DB index + FlushDB in defer.
|
|
|
|
// Use fixtures for Redis operations
|
|
redisHelper, err := tests.NewRedisHelper("localhost:6379", 11)
|
|
if err != nil {
|
|
t.Skipf("Redis not available, skipping test: %v", err)
|
|
}
|
|
defer func() {
|
|
_ = redisHelper.FlushDB()
|
|
_ = redisHelper.Close()
|
|
}()
|
|
|
|
// Create task queue
|
|
taskQueue, err := tests.NewTaskQueue(&tests.Config{
|
|
RedisAddr: "localhost:6379",
|
|
RedisDB: 11,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create task queue: %v", err)
|
|
}
|
|
defer func() { _ = taskQueue.Close() }()
|
|
|
|
// Test enqueue
|
|
task, err := taskQueue.EnqueueTask("simple_test", "--epochs 1", 5)
|
|
if err != nil {
|
|
t.Fatalf("Failed to enqueue task: %v", err)
|
|
}
|
|
|
|
if task.ID == "" {
|
|
t.Error("Task ID should not be empty")
|
|
}
|
|
|
|
if task.Status != "queued" {
|
|
t.Errorf("Expected status 'queued', got '%s'", task.Status)
|
|
}
|
|
|
|
// Test get
|
|
retrievedTask, err := taskQueue.GetTask(task.ID)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get task: %v", err)
|
|
}
|
|
|
|
if retrievedTask.ID != task.ID {
|
|
t.Errorf("Expected task ID %s, got %s", task.ID, retrievedTask.ID)
|
|
}
|
|
|
|
// Test get next
|
|
nextTask, err := taskQueue.GetNextTask()
|
|
if err != nil {
|
|
t.Fatalf("Failed to get next task: %v", err)
|
|
}
|
|
|
|
if nextTask == nil {
|
|
t.Fatal("Should have retrieved a task")
|
|
}
|
|
|
|
if nextTask.ID != task.ID {
|
|
t.Errorf("Expected task ID %s, got %s", task.ID, nextTask.ID)
|
|
}
|
|
|
|
// Test update
|
|
now := time.Now()
|
|
nextTask.Status = "running"
|
|
nextTask.StartedAt = &now
|
|
|
|
if updateErr := taskQueue.UpdateTask(nextTask); updateErr != nil {
|
|
t.Fatalf("Failed to update task: %v", updateErr)
|
|
}
|
|
|
|
// Verify update
|
|
updatedTask, err := taskQueue.GetTask(nextTask.ID)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get updated task: %v", err)
|
|
}
|
|
|
|
if updatedTask.Status != "running" {
|
|
t.Errorf("Expected status 'running', got '%s'", updatedTask.Status)
|
|
}
|
|
|
|
if updatedTask.StartedAt == nil {
|
|
t.Error("StartedAt should not be nil")
|
|
}
|
|
|
|
// Test metrics
|
|
if metricErr := taskQueue.RecordMetric("simple_test", "accuracy", 0.95); metricErr != nil {
|
|
t.Fatalf("Failed to record metric: %v", metricErr)
|
|
}
|
|
|
|
metrics, err := taskQueue.GetMetrics("simple_test")
|
|
if err != nil {
|
|
t.Fatalf("Failed to get metrics: %v", err)
|
|
}
|
|
|
|
if metrics["accuracy"] != "0.95" {
|
|
t.Errorf("Expected accuracy '0.95', got '%s'", metrics["accuracy"])
|
|
}
|
|
|
|
t.Log("Basic task queue operations test completed successfully")
|
|
}
|
|
|
|
// TestManageScriptHealthCheck tests the manage.sh health check functionality
|
|
func TestManageScriptHealthCheck(t *testing.T) {
|
|
t.Parallel() // Enable parallel execution
|
|
|
|
// Use fixtures for manage script operations
|
|
manageScript := "../../tools/manage.sh"
|
|
if _, err := os.Stat(manageScript); os.IsNotExist(err) {
|
|
t.Skipf("manage.sh not found at %s", manageScript)
|
|
}
|
|
|
|
ms := tests.NewManageScript(manageScript)
|
|
|
|
// Test help command to verify health command exists
|
|
output, err := ms.Status()
|
|
if err != nil {
|
|
t.Fatalf("Failed to run manage.sh status: %v", err)
|
|
}
|
|
|
|
if !strings.Contains(output, "Redis") {
|
|
t.Error("manage.sh status should include 'Redis' service status")
|
|
}
|
|
|
|
t.Log("manage.sh status command verification completed")
|
|
}
|
|
|
|
// TestAPIHealthEndpoint sends GET https://localhost:9101/health with the test
// API key and expects HTTP 200.
//
// The test is skipped when the request itself fails (e.g. the API server is
// not running). TLS certificate verification is disabled for this client so
// that a self-signed test certificate is accepted.
func TestAPIHealthEndpoint(t *testing.T) {
	t.Parallel() // Enable parallel execution

	// Skip TLS verification for self-signed certs in test.
	transport := &http.Transport{
		//nolint:gosec // G402: TLS InsecureSkipVerify set true - this is a test
		TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
	}
	// The transport is private to this test; release its keep-alive
	// connections when the test ends instead of leaking them.
	defer transport.CloseIdleConnections()

	// Short timeout so a hung server does not stall the suite.
	client := &http.Client{
		Timeout:   3 * time.Second,
		Transport: transport,
	}

	req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "https://localhost:9101/health", nil)
	if err != nil {
		t.Fatalf("Failed to create request: %v", err)
	}

	// Add required headers
	req.Header.Set("X-API-Key", "password")
	req.Header.Set("X-Forwarded-For", "127.0.0.1")

	resp, err := client.Do(req)
	if err != nil {
		// API might not be running, which is okay for this test.
		// Skipf stops the test goroutine, so resp is never used when nil.
		t.Skipf("API not available, skipping health endpoint test: %v", err)
	}
	defer func() { _ = resp.Body.Close() }()

	if resp.StatusCode != http.StatusOK {
		t.Errorf("Expected status 200, got %d", resp.StatusCode)
	}

	t.Log("API health endpoint test completed successfully")
}
|