From 7305e2bc2147f595700c45de9ea774d6278c05dd Mon Sep 17 00:00:00 2001 From: Jeremie Fraeys Date: Mon, 16 Feb 2026 20:38:15 -0500 Subject: [PATCH] test: add comprehensive test coverage and command improvements - Add logs and debug end-to-end tests - Add test helper utilities - Improve test fixtures and templates - Update API server and config lint commands - Add multi-user database initialization --- .../benchmarks/artifact_scanner_bench_test.go | 125 ++++ tests/benchmarks/config_parsing_bench_test.go | 175 ++++++ .../json_serialization_bench_test.go | 185 ++++++ .../benchmarks/jupyter_service_bench_test.go | 278 +++++++++ tests/benchmarks/log_sanitize_bench_test.go | 84 +++ tests/benchmarks/native_queue_basic_test.go | 40 ++ tests/benchmarks/native_queue_bench_test.go | 108 ++++ tests/benchmarks/streaming_io_bench_test.go | 189 ++++++ tests/chaos/chaos_test.go | 10 + tests/e2e/docker-compose.logs-debug.yml | 50 ++ tests/e2e/logs_debug_e2e_test.go | 590 ++++++++++++++++++ tests/integration/duplicate_detection_test.go | 63 ++ .../jupyter_experiment_test.go} | 0 .../protocol_test.go} | 0 .../websocket_queue_integration_test.go | 8 +- .../ws_handler_integration_test.go | 6 +- .../api/duplicate_detection_process_test.go | 95 +++ tests/unit/api/helpers/db_helpers_test.go | 225 +++++++ tests/unit/api/helpers/hash_helpers_test.go | 137 ++++ tests/unit/api/helpers/payload_parser_test.go | 451 +++++++++++++ .../unit/api/helpers/response_helpers_test.go | 345 ++++++++++ .../api/helpers/validation_helpers_test.go | 486 +++++++++++++++ 22 files changed, 3643 insertions(+), 7 deletions(-) create mode 100644 tests/benchmarks/artifact_scanner_bench_test.go create mode 100644 tests/benchmarks/config_parsing_bench_test.go create mode 100644 tests/benchmarks/json_serialization_bench_test.go create mode 100644 tests/benchmarks/jupyter_service_bench_test.go create mode 100644 tests/benchmarks/log_sanitize_bench_test.go create mode 100644 tests/benchmarks/native_queue_basic_test.go create mode 100644 tests/benchmarks/native_queue_bench_test.go create mode 100644 tests/benchmarks/streaming_io_bench_test.go create mode 100644 tests/e2e/docker-compose.logs-debug.yml create mode 100644 tests/e2e/logs_debug_e2e_test.go create mode 100644 tests/integration/duplicate_detection_test.go rename tests/{jupyter_experiment_integration_test.go => integration/jupyter_experiment_test.go} (100%) rename tests/{integration_protocol_test.go => integration/protocol_test.go} (100%) create mode 100644 tests/unit/api/duplicate_detection_process_test.go create mode 100644 tests/unit/api/helpers/db_helpers_test.go create mode 100644 tests/unit/api/helpers/hash_helpers_test.go create mode 100644 tests/unit/api/helpers/payload_parser_test.go create mode 100644 tests/unit/api/helpers/response_helpers_test.go create mode 100644 tests/unit/api/helpers/validation_helpers_test.go diff --git a/tests/benchmarks/artifact_scanner_bench_test.go b/tests/benchmarks/artifact_scanner_bench_test.go new file mode 100644 index 0000000..6d5acb4 --- /dev/null +++ b/tests/benchmarks/artifact_scanner_bench_test.go @@ -0,0 +1,125 @@ +package benchmarks + +import ( + "os" + "path/filepath" + "testing" + + "github.com/jfraeys/fetch_ml/internal/worker" +) + +// BenchmarkArtifactScanGo profiles Go filepath.WalkDir implementation +func BenchmarkArtifactScanGo(b *testing.B) { + tmpDir := b.TempDir() + + // Create test artifact structure + createTestArtifacts(b, tmpDir, 100) + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + _, err := worker.ScanArtifacts(tmpDir) + if err != nil { + b.Fatal(err) + } + } +} + +// BenchmarkArtifactScanNative profiles C++ platform-optimized traversal +// Uses: fts on BSD, getdents64 on Linux, getattrlistbulk on macOS +func BenchmarkArtifactScanNative(b *testing.B) { + tmpDir := b.TempDir() + + // Create test artifact structure + createTestArtifacts(b, tmpDir, 100) + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + _, err := worker.ScanArtifactsNative(tmpDir) + if err != nil { + b.Fatal(err) + } + } +} + +// BenchmarkArtifactScanLarge tests with many files +func BenchmarkArtifactScanLarge(b *testing.B) { + tmpDir := b.TempDir() + + // Create 1000 test files + createTestArtifacts(b, tmpDir, 1000) + + b.Run("Go", func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _, err := worker.ScanArtifacts(tmpDir) + if err != nil { + b.Fatal(err) + } + } + }) + + b.Run("Native", func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _, err := worker.ScanArtifactsNative(tmpDir) + if err != nil { + b.Fatal(err) + } + } + }) +} + +// createTestArtifacts creates a directory structure with test files +func createTestArtifacts(b testing.TB, root string, count int) { + b.Helper() + + // Create nested directories + dirs := []string{ + "outputs", + "outputs/models", + "outputs/checkpoints", + "logs", + "data", + } + + for _, dir := range dirs { + if err := os.MkdirAll(filepath.Join(root, dir), 0750); err != nil { + b.Fatal(err) + } + } + + // Create test files + for i := 0; i < count; i++ { + var path string + switch i % 5 { + case 0: + path = filepath.Join(root, "outputs", "model_"+string(rune('0'+i%10))+".pt") + case 1: + path = filepath.Join(root, "outputs", "models", "checkpoint_"+string(rune('0'+i%10))+".ckpt") + case 2: + path = filepath.Join(root, "outputs", "checkpoints", "epoch_"+string(rune('0'+i%10))+".pt") + case 3: + path = filepath.Join(root, "logs", "train_"+string(rune('0'+i%10))+".log") + case 4: + path = filepath.Join(root, "data", "batch_"+string(rune('0'+i%10))+".npy") + } + + data := make([]byte, 1024*(i%10+1)) // Varying sizes 1KB-10KB + for j := range data { + data[j] = byte(i + j%256) + } + + if err := os.WriteFile(path, data, 0640); err != nil { + b.Fatal(err) + } + } + + // Create files that should be excluded + os.WriteFile(filepath.Join(root, "run_manifest.json"), []byte("{}"), 0640) + os.MkdirAll(filepath.Join(root, "code"), 0750) + os.WriteFile(filepath.Join(root, "code", "script.py"), []byte("# test"), 0640) +} diff --git a/tests/benchmarks/config_parsing_bench_test.go b/tests/benchmarks/config_parsing_bench_test.go new file mode 100644 index 0000000..e5ad50a --- /dev/null +++ b/tests/benchmarks/config_parsing_bench_test.go @@ -0,0 +1,175 @@ +package benchmarks + +import ( + "testing" + + "gopkg.in/yaml.v3" +) + +// Sample server config YAML for benchmarking +const sampleServerConfig = ` +base_path: /api/v1 +data_dir: /data/fetchml +auth: + type: jwt + secret: "super-secret-key-for-benchmarking-only" + token_ttl: 3600 +server: + address: :8080 + tls: + enabled: true + cert_file: /etc/ssl/certs/server.crt + key_file: /etc/ssl/private/server.key +security: + max_request_size: 1048576 + rate_limit: 100 + cors_origins: + - https://app.fetchml.com + - https://admin.fetchml.com +queue: + backend: redis + sqlite_path: /data/queue.db + filesystem_path: /data/queue + fallback_to_filesystem: true +redis: + addr: localhost:6379 + password: "" + db: 0 + pool_size: 50 +database: + driver: postgres + dsn: postgres://user:pass@localhost/fetchml?sslmode=disable + max_connections: 100 +logging: + level: info + format: json + output: stdout +resources: + max_cpu_per_task: 8 + max_memory_per_task: 32 + max_gpu_per_task: 1 +monitoring: + enabled: true + prometheus_port: 9090 +` + +type BenchmarkServerConfig struct { + BasePath string `yaml:"base_path"` + DataDir string `yaml:"data_dir"` + Auth BenchmarkAuthConfig `yaml:"auth"` + Server BenchmarkServerSection `yaml:"server"` + Security BenchmarkSecurityConfig `yaml:"security"` + Queue BenchmarkQueueConfig `yaml:"queue"` + Redis BenchmarkRedisConfig `yaml:"redis"` + Database BenchmarkDatabaseConfig `yaml:"database"` + Logging BenchmarkLoggingConfig `yaml:"logging"` + Resources BenchmarkResourceConfig `yaml:"resources"` + Monitoring BenchmarkMonitoringConfig `yaml:"monitoring"` +} + +type BenchmarkAuthConfig struct { + Type string `yaml:"type"` + Secret string `yaml:"secret"` + TokenTTL int `yaml:"token_ttl"` +} + +type BenchmarkServerSection struct { + Address string `yaml:"address"` + TLS BenchmarkTLSConfig `yaml:"tls"` +} + +type BenchmarkTLSConfig struct { + Enabled bool `yaml:"enabled"` + CertFile string `yaml:"cert_file"` + KeyFile string `yaml:"key_file"` +} + +type BenchmarkSecurityConfig struct { + MaxRequestSize int `yaml:"max_request_size"` + RateLimit int `yaml:"rate_limit"` + CORSOrigins []string `yaml:"cors_origins"` +} + +type BenchmarkQueueConfig struct { + Backend string `yaml:"backend"` + SQLitePath string `yaml:"sqlite_path"` + FilesystemPath string `yaml:"filesystem_path"` + FallbackToFilesystem bool `yaml:"fallback_to_filesystem"` +} + +type BenchmarkRedisConfig struct { + Addr string `yaml:"addr"` + Password string `yaml:"password"` + DB int `yaml:"db"` + PoolSize int `yaml:"pool_size"` +} + +type BenchmarkDatabaseConfig struct { + Driver string `yaml:"driver"` + DSN string `yaml:"dsn"` + MaxConnections int `yaml:"max_connections"` +} + +type BenchmarkLoggingConfig struct { + Level string `yaml:"level"` + Format string `yaml:"format"` + Output string `yaml:"output"` +} + +type BenchmarkResourceConfig struct { + MaxCPUPerTask int `yaml:"max_cpu_per_task"` + MaxMemoryPerTask int `yaml:"max_memory_per_task"` + MaxGPUPerTask int `yaml:"max_gpu_per_task"` +} + +type BenchmarkMonitoringConfig struct { + Enabled bool `yaml:"enabled"` + PrometheusPort int `yaml:"prometheus_port"` +} + +// BenchmarkConfigYAMLUnmarshal profiles YAML config parsing +// Tier 3 C++ candidate: Fast binary config format +func BenchmarkConfigYAMLUnmarshal(b *testing.B) { + data := []byte(sampleServerConfig) + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + var cfg BenchmarkServerConfig + err := yaml.Unmarshal(data, &cfg) + if err != nil { + b.Fatal(err) + } + } +} + +// BenchmarkConfigYAMLUnmarshalLarge profiles large config parsing +func BenchmarkConfigYAMLUnmarshalLarge(b *testing.B) { + // Create larger config with more nested data + largeConfig := sampleServerConfig + ` + extra_section: + items: + - id: item1 + name: "First Item" + value: 100 + - id: item2 + name: "Second Item" + value: 200 + - id: item3 + name: "Third Item" + value: 300 +` + data := []byte(largeConfig) + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + var cfg BenchmarkServerConfig + err := yaml.Unmarshal(data, &cfg) + if err != nil { + b.Fatal(err) + } + } +} diff --git a/tests/benchmarks/json_serialization_bench_test.go b/tests/benchmarks/json_serialization_bench_test.go new file mode 100644 index 0000000..2d49917 --- /dev/null +++ b/tests/benchmarks/json_serialization_bench_test.go @@ -0,0 +1,185 @@ +package benchmarks + +import ( + "encoding/json" + "fmt" + "testing" + "time" + + "github.com/jfraeys/fetch_ml/internal/queue" +) + +// BenchmarkTaskJSONMarshal profiles Task serialization to JSON +// Tier 3 C++ candidate: Zero-copy binary serialization +func BenchmarkTaskJSONMarshal(b *testing.B) { + task := createTestTask() + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + _, err := json.Marshal(task) + if err != nil { + b.Fatal(err) + } + } +} + +// BenchmarkTaskJSONUnmarshal profiles Task deserialization from JSON +// Tier 3 C++ candidate: Memory-mapped binary format +func BenchmarkTaskJSONUnmarshal(b *testing.B) { + task := createTestTask() + data, err := json.Marshal(task) + if err != nil { + b.Fatal(err) + } + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + var t queue.Task + err := json.Unmarshal(data, &t) + if err != nil { + b.Fatal(err) + } + } +} + +// BenchmarkTaskJSONRoundTrip profiles full serialize/deserialize cycle +func BenchmarkTaskJSONRoundTrip(b *testing.B) { + task := createTestTask() + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + data, err := json.Marshal(task) + if err != nil { + b.Fatal(err) + } + var t queue.Task + err = json.Unmarshal(data, &t) + if err != nil { + b.Fatal(err) + } + } +} + +// BenchmarkPrewarmStateJSONMarshal profiles PrewarmState serialization +// Used in worker prewarm coordination +func BenchmarkPrewarmStateJSONMarshal(b *testing.B) { + state := queue.PrewarmState{ + WorkerID: "worker-12345", + TaskID: "task-67890", + SnapshotID: "snap-abc123", + StartedAt: time.Now().UTC().Format(time.RFC3339), + UpdatedAt: time.Now().UTC().Format(time.RFC3339), + Phase: "running", + EnvImage: "fetchml/python:3.11", + DatasetCnt: 3, + EnvHit: 5, + EnvMiss: 1, + EnvBuilt: 2, + EnvTimeNs: 1500000000, + } + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + _, err := json.Marshal(state) + if err != nil { + b.Fatal(err) + } + } +} + +// BenchmarkPrewarmStateJSONUnmarshal profiles PrewarmState deserialization +func BenchmarkPrewarmStateJSONUnmarshal(b *testing.B) { + state := queue.PrewarmState{ + WorkerID: "worker-12345", + TaskID: "task-67890", + SnapshotID: "snap-abc123", + StartedAt: time.Now().UTC().Format(time.RFC3339), + UpdatedAt: time.Now().UTC().Format(time.RFC3339), + Phase: "running", + EnvImage: "fetchml/python:3.11", + DatasetCnt: 3, + EnvHit: 5, + EnvMiss: 1, + EnvBuilt: 2, + EnvTimeNs: 1500000000, + } + data, err := json.Marshal(state) + if err != nil { + b.Fatal(err) + } + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + var s queue.PrewarmState + err := json.Unmarshal(data, &s) + if err != nil { + b.Fatal(err) + } + } +} + +// BenchmarkTaskBatchJSONMarshal profiles batch task serialization +// Critical for queue index operations with many tasks +func BenchmarkTaskBatchJSONMarshal(b *testing.B) { + tasks := make([]queue.Task, 100) + for i := 0; i < 100; i++ { + tasks[i] = createTestTaskWithID(i) + } + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + for _, task := range tasks { + _, err := json.Marshal(task) + if err != nil { + b.Fatal(err) + } + } + } +} + +// Helper to create a test task +func createTestTask() queue.Task { + return createTestTaskWithID(0) +} + +func createTestTaskWithID(id int) queue.Task { + now := time.Now() + return queue.Task{ + ID: fmt.Sprintf("task-550e8400-e29b-41d4-a716-44665544%03d", id), + JobName: "ml-training-job", + Args: "--epochs=100 --batch_size=32 --learning_rate=0.001", + Status: "queued", + Priority: 1000, + CreatedAt: now, + DatasetSpecs: []queue.DatasetSpec{ + {Name: "training-data", Version: "v1.2.3", Checksum: "sha256:abc123", URI: "s3://bucket/data/train"}, + {Name: "validation-data", Version: "v1.0.0", Checksum: "sha256:def456", URI: "s3://bucket/data/val"}, + }, + Datasets: []string{"dataset1", "dataset2"}, + Metadata: map[string]string{ + "experiment_id": "exp-12345", + "user_email": "user@example.com", + "model_type": "transformer", + }, + CPU: 4, + MemoryGB: 16, + GPU: 1, + GPUMemory: "24GB", + UserID: "user-550e8400-e29b-41d4-a716", + Username: "ml_researcher", + CreatedBy: "admin", + MaxRetries: 3, + } +} diff --git a/tests/benchmarks/jupyter_service_bench_test.go b/tests/benchmarks/jupyter_service_bench_test.go new file mode 100644 index 0000000..39d9aaa --- /dev/null +++ b/tests/benchmarks/jupyter_service_bench_test.go @@ -0,0 +1,278 @@ +package benchmarks + +import ( + "os" + "path/filepath" + "strconv" + "strings" + "testing" + "time" +) + +// BenchmarkResolveWorkspacePath profiles path canonicalization hot path +// Tier 2 C++ candidate: SIMD string operations for path validation +func BenchmarkResolveWorkspacePath(b *testing.B) { + testPaths := []string{ + "my-workspace", + "./relative/path", + "/absolute/path/to/workspace", + " trimmed-path ", + "deep/nested/workspace/name", + } + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + path := testPaths[i%len(testPaths)] + // Simulate the resolveWorkspacePath logic inline + ws := strings.TrimSpace(path) + clean := filepath.Clean(ws) + if !filepath.IsAbs(clean) { + clean = filepath.Join("/data/active/workspaces", clean) + } + _ = clean + } +} + +// BenchmarkStringTrimSpace profiles strings.TrimSpace usage +// Heavy usage in service_manager.go (55+ calls) +func BenchmarkStringTrimSpace(b *testing.B) { + testInputs := []string{ + " package-name ", + "\t\n trimmed \r\n", + "normal", + " ", + "", + "pkg1==1.0.0", + } + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + _ = strings.TrimSpace(testInputs[i%len(testInputs)]) + } +} + +// BenchmarkStringSplit profiles strings.Split for package parsing +// Used in parsePipList, parseCondaList, splitPackageList +func BenchmarkStringSplit(b *testing.B) { + testInputs := []string{ + "numpy,pandas,scikit-learn,tensorflow", + "package==1.0.0", + "single", + "", + "a,b,c,d,e,f,g,h,i,j", + } + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + _ = strings.Split(testInputs[i%len(testInputs)], ",") + } +} + +// BenchmarkStringHasPrefix profiles strings.HasPrefix +// Used for path traversal detection and comment filtering +func BenchmarkStringHasPrefix(b *testing.B) { + testInputs := []string{ + "../traversal", + "normal/path", + "# comment", + "package==1.0", + "..", + } + prefixes := []string{"..", "#", "package"} + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + _ = strings.HasPrefix(testInputs[i%len(testInputs)], prefixes[i%len(prefixes)]) + } +} + +// BenchmarkStringEqualFold profiles case-insensitive comparison +// Used for blocked package matching +func BenchmarkStringEqualFold(b *testing.B) { + testPairs := []struct { + a, b string + }{ + {"numpy", "NUMPY"}, + {"requests", "requests"}, + {"TensorFlow", "tensorflow"}, + {"scikit-learn", "SCIKIT-LEARN"}, + {"off", "OFF"}, + } + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + pair := testPairs[i%len(testPairs)] + _ = strings.EqualFold(pair.a, pair.b) + } +} + +// BenchmarkFilepathClean profiles path canonicalization +// Used in resolveWorkspacePath and other path operations +func BenchmarkFilepathClean(b *testing.B) { + testPaths := []string{ + "./relative/../path", + "/absolute//double//slashes", + "../../../etc/passwd", + "normal/path", + ".", + "..", + } + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + _ = filepath.Clean(testPaths[i%len(testPaths)]) + } +} + +// BenchmarkFilepathJoin profiles path joining +// Used extensively for building workspace, trash, state paths +func BenchmarkFilepathJoin(b *testing.B) { + components := [][]string{ + {"/data", "active", "workspaces"}, + {"/tmp", "trash", "jupyter"}, + {"state", "fetch_ml_jupyter_workspaces.json"}, + {"workspace", "subdir", "file.txt"}, + } + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + parts := components[i%len(components)] + _ = filepath.Join(parts...) + } +} + +// BenchmarkStrconvAtoi profiles string to int conversion +// Used for port parsing, resource limits +func BenchmarkStrconvAtoi(b *testing.B) { + testInputs := []string{ + "8888", + "0", + "65535", + "-1", + "999999", + } + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + _, _ = strconv.Atoi(testInputs[i%len(testInputs)]) + } +} + +// BenchmarkTimeFormat profiles timestamp formatting +// Used for trash destination naming +func BenchmarkTimeFormat(b *testing.B) { + now := time.Now().UTC() + format := "20060102_150405" + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + _ = now.Format(format) + } +} + +// BenchmarkSprintfConcat profiles string concatenation +// Used for building destination names +func BenchmarkSprintfConcat(b *testing.B) { + names := []string{"workspace1", "my-project", "test-ws"} + timestamps := []string{"20240115_143022", "20240212_183045", "20240301_120000"} + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + _ = names[i%len(names)] + "_" + timestamps[i%len(timestamps)] + } +} + +// BenchmarkPackageListParsing profiles full package list parsing pipeline +// Combines Split, TrimSpace, and filtering - mirrors splitPackageList +func BenchmarkPackageListParsing(b *testing.B) { + testInputs := []string{ + "numpy, pandas, scikit-learn, tensorflow", + " requests , urllib3 , certifi ", + "single-package", + "", + "a, b, c, d, e, f, g", + } + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + input := testInputs[i%len(testInputs)] + parts := strings.Split(input, ",") + out := make([]string, 0, len(parts)) + for _, p := range parts { + p = strings.TrimSpace(p) + if p != "" { + out = append(out, p) + } + } + } +} + +// BenchmarkEnvLookup profiles environment variable lookups +// Used for FETCHML_JUPYTER_* configuration +func BenchmarkEnvLookup(b *testing.B) { + // Pre-set test env vars + os.Setenv("TEST_JUPYTER_STATE_DIR", "/tmp/jupyter-state") + os.Setenv("TEST_JUPYTER_WORKSPACE_BASE", "/tmp/workspaces") + + keys := []string{ + "TEST_JUPYTER_STATE_DIR", + "TEST_JUPYTER_WORKSPACE_BASE", + "NONEXISTENT_VAR", + } + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + _, _ = os.LookupEnv(keys[i%len(keys)]) + } +} + +// BenchmarkCombinedJupyterHotPath profiles typical service manager operation +// Combines multiple string/path operations as they occur in real usage +func BenchmarkCombinedJupyterHotPath(b *testing.B) { + testWorkspaces := []string{ + " my-project ", + "./relative-ws", + "/absolute/path/to/workspace", + "deep/nested/name", + } + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + // Simulate resolveWorkspacePath + path building + ws := strings.TrimSpace(testWorkspaces[i%len(testWorkspaces)]) + clean := filepath.Clean(ws) + if !filepath.IsAbs(clean) { + clean = filepath.Join("/data/active/workspaces", clean) + } + // Simulate trash path building + ts := time.Now().UTC().Format("20060102_150405") + destName := ws + "_" + ts + _ = filepath.Join("/tmp/trash/jupyter", destName) + } +} diff --git a/tests/benchmarks/log_sanitize_bench_test.go b/tests/benchmarks/log_sanitize_bench_test.go new file mode 100644 index 0000000..1a6e24d --- /dev/null +++ b/tests/benchmarks/log_sanitize_bench_test.go @@ -0,0 +1,84 @@ +package benchmarks + +import ( + "testing" + + "github.com/jfraeys/fetch_ml/internal/logging" +) + +// BenchmarkLogSanitizeMessage profiles log message sanitization. +// This is a Tier 1 C++ candidate because: +// - Regex matching is CPU-intensive +// - High volume log pipelines process thousands of messages/sec +// - C++ can use Hyperscan/RE2 for parallel regex matching +// Expected speedup: 3-5x for high-volume logging +func BenchmarkLogSanitizeMessage(b *testing.B) { + // Test messages with various sensitive data patterns + messages := []string{ + "User login successful with api_key=abc123def45678901234567890abcdef", + "JWT token: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.dozjgNryP4J3jVmNHl0w5N_XgL0n3I9PlFUP0C12LVE", + "Redis connection: redis://:secretpassword123@localhost:6379/0", + "User admin password=supersecret123 trying to access resource", + "Normal log message without any sensitive data to process", + "API call with key=fedcba9876543210fedcba9876543210 and secret=shh123", + "Connection string: redis://:another_secret@redis.example.com:6380", + "Authentication token: eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.test.signature", + } + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + msg := messages[i%len(messages)] + _ = logging.SanitizeLogMessage(msg) + } +} + +// BenchmarkLogSanitizeArgs profiles structured log argument sanitization. +// This processes key-value pairs looking for sensitive field names. +func BenchmarkLogSanitizeArgs(b *testing.B) { + // Simulate typical structured log arguments + args := []any{ + "user_id", "user123", + "password", "secret123", + "api_key", "abcdef1234567890", + "action", "login", + "secret_token", "eyJhbGci...", + "request_id", "req-12345", + "database_url", "redis://:password@localhost:6379", + "timestamp", "2024-01-01T00:00:00Z", + } + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + _ = logging.SanitizeArgs(args) + } +} + +// BenchmarkLogSanitizeHighVolume simulates high-throughput logging scenario +// with many messages per second (e.g., 10K+ messages/sec). +func BenchmarkLogSanitizeHighVolume(b *testing.B) { + // Mix of message types + testMessages := []string{ + "API request: POST /api/v1/jobs with api_key=abcdef1234567890abcdef1234567890", + "User user123 authenticated with token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.payload.signature", + "Database connection established to redis://:redact_me@localhost:6379", + "Job execution started for job_name=test_job", + "Error processing request: password=wrong_secret provided", + "Metrics: cpu=45%, memory=2.5GB, gpu=0%", + "Config loaded: secret_key=hidden_value123", + "Webhook received with authorization=Bearer eyJ0eXAiOiJKV1Qi.test.sig", + } + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + // Simulate batch processing of multiple messages + for _, msg := range testMessages { + _ = logging.SanitizeLogMessage(msg) + } + } +} diff --git a/tests/benchmarks/native_queue_basic_test.go b/tests/benchmarks/native_queue_basic_test.go new file mode 100644 index 0000000..63946e1 --- /dev/null +++ b/tests/benchmarks/native_queue_basic_test.go @@ -0,0 +1,40 @@ +package benchmarks + +import ( + "testing" + + "github.com/jfraeys/fetch_ml/internal/queue" +) + +// BenchmarkNativeQueueBasic tests basic native queue operations +func BenchmarkNativeQueueBasic(b *testing.B) { + tmpDir := b.TempDir() + + // Only run if native libs available + if !queue.UseNativeQueue { + b.Skip("Native queue not enabled (set FETCHML_NATIVE_LIBS=1)") + } + + q, err := queue.NewNativeQueue(tmpDir) + if err != nil { + b.Fatal(err) + } + defer q.Close() + + // Test single add + task := &queue.Task{ + ID: "test-1", + JobName: "test-job", + Priority: 100, + } + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + task.ID = "test-" + string(rune('0'+i%10)) + if err := q.AddTask(task); err != nil { + b.Fatal(err) + } + } +} diff --git a/tests/benchmarks/native_queue_bench_test.go b/tests/benchmarks/native_queue_bench_test.go new file mode 100644 index 0000000..fa0aeae --- /dev/null +++ b/tests/benchmarks/native_queue_bench_test.go @@ -0,0 +1,108 @@ +package benchmarks + +import ( + "testing" + + "github.com/jfraeys/fetch_ml/internal/queue" +) + +// BenchmarkNativeQueueRebuildIndex profiles the native binary queue index. +// Tier 1 C++ candidate: binary format vs JSON +// Expected: 5x speedup, 99% allocation reduction +func BenchmarkNativeQueueRebuildIndex(b *testing.B) { + tmpDir := b.TempDir() + q, err := queue.NewNativeQueue(tmpDir) + if err != nil { + b.Fatal(err) + } + defer q.Close() + + // Seed with tasks + for i := 0; i < 100; i++ { + task := &queue.Task{ + ID: "task-" + string(rune('0'+i/10)) + string(rune('0'+i%10)), + JobName: "job-" + string(rune('0'+i/10)), + Priority: int64(100 - i), + } + if err := q.AddTask(task); err != nil { + b.Fatal(err) + } + } + + b.ResetTimer() + b.ReportAllocs() + + // Benchmark just the add (native uses binary index, no JSON rebuild) + for i := 0; i < b.N; i++ { + task := &queue.Task{ + ID: "bench-task-" + string(rune('0'+i%10)), + JobName: "bench-job", + Priority: int64(i), + } + if err := q.AddTask(task); err != nil { + b.Fatal(err) + } + } +} + +// BenchmarkNativeQueueClaimNext profiles task claiming from binary heap +func BenchmarkNativeQueueClaimNext(b *testing.B) { + tmpDir := b.TempDir() + q, err := queue.NewNativeQueue(tmpDir) + if err != nil { + b.Fatal(err) + } + defer q.Close() + + // Seed with tasks + for i := 0; i < 100; i++ { + task := &queue.Task{ + ID: "task-" + string(rune('0'+i/10)) + string(rune('0'+i%10)), + JobName: "job-" + string(rune('0'+i/10)), + Priority: int64(100 - i), + } + if err := q.AddTask(task); err != nil { + b.Fatal(err) + } + } + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + // Binary heap pop - no JSON parsing + _, _ = q.PeekNextTask() + } +} + +// BenchmarkNativeQueueGetAllTasks profiles full task scan from binary index +func BenchmarkNativeQueueGetAllTasks(b *testing.B) { + tmpDir := b.TempDir() + q, err := queue.NewNativeQueue(tmpDir) + if err != nil { + b.Fatal(err) + } + defer q.Close() + + // Seed with tasks + for i := 0; i < 100; i++ { + task := &queue.Task{ + ID: "task-" + string(rune('0'+i/10)) + string(rune('0'+i%10)), + JobName: "job-" + string(rune('0'+i/10)), + Priority: int64(100 - i), + } + if err := q.AddTask(task); err != nil { + b.Fatal(err) + } + } + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + _, err := q.GetAllTasks() + if err != nil { + b.Fatal(err) + } + } +} diff --git a/tests/benchmarks/streaming_io_bench_test.go b/tests/benchmarks/streaming_io_bench_test.go new file mode 100644 index 0000000..cc30359 --- /dev/null +++ b/tests/benchmarks/streaming_io_bench_test.go @@ -0,0 +1,189 @@ +package benchmarks + +import ( + "archive/tar" + "bytes" + "compress/gzip" + "os" + "path/filepath" + "testing" + + "github.com/jfraeys/fetch_ml/internal/worker" +) + +// BenchmarkExtractTarGzGo profiles Go sequential tar.gz extraction +func BenchmarkExtractTarGzGo(b *testing.B) { + tmpDir := b.TempDir() + archivePath := createStreamingTestArchive(b, tmpDir, 100, 1024) // 100 files, 1KB each + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + dstDir := filepath.Join(tmpDir, "extract_go_"+string(rune('0'+i%10))) + if err := os.MkdirAll(dstDir, 0750); err != nil { + b.Fatal(err) + } + if err := worker.ExtractTarGz(archivePath, dstDir); err != nil { + b.Fatal(err) + } + } +} + +// BenchmarkExtractTarGzNative profiles C++ parallel decompression +// Uses: mmap + thread pool + O_DIRECT for large files +func BenchmarkExtractTarGzNative(b *testing.B) { + tmpDir := b.TempDir() + archivePath := createStreamingTestArchive(b, tmpDir, 100, 1024) + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + dstDir := filepath.Join(tmpDir, "extract_native_"+string(rune('0'+i%10))) + if err := os.MkdirAll(dstDir, 0750); err != nil { + b.Fatal(err) + } + if err := worker.ExtractTarGzNative(archivePath, dstDir); err != nil { + b.Fatal(err) + } + } +} + +// BenchmarkExtractTarGzSizes tests different archive sizes +func BenchmarkExtractTarGzSizes(b *testing.B) { + tmpDir := b.TempDir() + + // Small: 10 files, 1KB each (~10KB compressed) + b.Run("Small", func(b *testing.B) { + archivePath := createStreamingTestArchive(b, tmpDir, 10, 1024) + benchmarkBoth(b, archivePath, tmpDir) + }) + + // Medium: 100 files, 10KB each (~1MB compressed) + b.Run("Medium", func(b *testing.B) { + archivePath := createStreamingTestArchive(b, tmpDir, 100, 10240) + benchmarkBoth(b, archivePath, tmpDir) + }) + + // Large: 50 files, 100KB each (~5MB compressed) + b.Run("Large", func(b *testing.B) { + archivePath := createStreamingTestArchive(b, tmpDir, 50, 102400) + benchmarkBoth(b, archivePath, tmpDir) + }) +} + +func benchmarkBoth(b *testing.B, archivePath, tmpDir string) { + b.Run("Go", func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + dstDir := filepath.Join(tmpDir, "go_"+string(rune('0'+i%10))) + if err := os.MkdirAll(dstDir, 0750); err != nil { + b.Fatal(err) + } + if err := worker.ExtractTarGz(archivePath, dstDir); err != nil { + b.Fatal(err) + } + } + }) + + b.Run("Native", func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + dstDir := filepath.Join(tmpDir, "native_"+string(rune('0'+i%10))) + if err := os.MkdirAll(dstDir, 0750); err != nil { + b.Fatal(err) + } + if err := worker.ExtractTarGzNative(archivePath, dstDir); err != nil { + b.Fatal(err) + } + } + }) +} + +// createStreamingTestArchive creates a tar.gz archive with test files for streaming benchmarks +func createStreamingTestArchive(b testing.TB, tmpDir string, numFiles, fileSize int) string { + b.Helper() + + // Create data directory with test files + dataDir := filepath.Join(tmpDir, "data_"+string(rune('0'+numFiles/10))) + if err := os.MkdirAll(dataDir, 0750); err != nil { + b.Fatal(err) + } + + for i := 0; i < numFiles; i++ { + subdir := filepath.Join(dataDir, "subdir_"+string(rune('0'+i%5))) + if err := os.MkdirAll(subdir, 0750); err != nil { + b.Fatal(err) + } + + filename := filepath.Join(subdir, "file_"+string(rune('0'+i/10))+".bin") + data := make([]byte, fileSize) + for j := range data { + data[j] = byte((i + j) % 256) + } + if err := os.WriteFile(filename, data, 0640); err != nil { + b.Fatal(err) + } + } + + // Create tar.gz archive + archivePath := filepath.Join(tmpDir, "test_"+string(rune('0'+numFiles/10))+".tar.gz") + if err := createTarGzFromDir(dataDir, archivePath); err != nil { + b.Fatal(err) + } + + return archivePath +} + +// createTarGzFromDir creates a tar.gz archive from a directory +func createTarGzFromDir(srcDir, dstPath string) error { + var buf bytes.Buffer + gw := gzip.NewWriter(&buf) + tw := tar.NewWriter(gw) + + err := filepath.Walk(srcDir, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + + rel, err := filepath.Rel(srcDir, path) + if err != nil { + return err + } + + hdr, err := tar.FileInfoHeader(info, "") + if err != nil { + return err + } + hdr.Name = rel + + if err := tw.WriteHeader(hdr); err != nil { + return err + } + + if !info.IsDir() { + data, err := os.ReadFile(path) + if err != nil { + return err + } + if _, err := tw.Write(data); err != nil { + return err + } + } + + return nil + }) + if err != nil { + return err + } + + if err := tw.Close(); err != nil { + return err + } + if err := gw.Close(); err != nil { + return err + } + + return os.WriteFile(dstPath, buf.Bytes(), 0640) +} diff --git a/tests/chaos/chaos_test.go b/tests/chaos/chaos_test.go index 7d00942..c82cfb7 100644 --- a/tests/chaos/chaos_test.go +++ b/tests/chaos/chaos_test.go @@ -14,6 +14,14 @@ import ( // ChaosTestSuite tests system resilience under various failure conditions func TestChaosTestSuite(t *testing.T) { + // Check Redis availability at suite level and warn if not available + quickRedis := redis.NewClient(&redis.Options{Addr: "localhost:6379", DB: 6}) + if err := quickRedis.Ping(context.Background()).Err(); err != nil { + t.Logf("WARNING: Redis not available at localhost:6379 - chaos tests will be skipped") + t.Logf(" To run these tests, start Redis: redis-server --port 6379") + } + quickRedis.Close() + // Tests that intentionally close/corrupt connections get their own resources // to prevent cascading failures to subsequent subtests @@ -69,6 +77,7 @@ func TestChaosTestSuite(t *testing.T) { rdb := setupChaosRedis(t) if rdb == nil { t.Skip("Redis not available for chaos tests") + return } defer func() { _ = rdb.Close() }() @@ -473,6 +482,7 @@ func setupChaosRedis(t *testing.T) *redis.Client { ctx := context.Background() if err := rdb.Ping(ctx).Err(); err != nil { + t.Logf("Skipping chaos test - Redis not available: %v", err) t.Skipf("Redis not available for chaos tests: %v", err) return nil } diff --git a/tests/e2e/docker-compose.logs-debug.yml b/tests/e2e/docker-compose.logs-debug.yml new file mode 100644 index 0000000..e5c8dc4 --- /dev/null +++ b/tests/e2e/docker-compose.logs-debug.yml @@ -0,0 +1,50 @@ +--- +# Docker Compose configuration for logs and debug E2E tests +# Simplified version using pre-built golang image with source mount + +services: + redis: + image: redis:7-alpine + ports: + - "6380:6379" + healthcheck: + test: ["CMD", "redis-cli", "ping"] + interval: 5s + timeout: 3s + retries: 5 + + api-server: + image: golang:1.25-bookworm + working_dir: /app + command: > + sh -c " + go build -o api-server ./cmd/api-server/main.go && + ./api-server --config /app/configs/api/dev.yaml + " + ports: + - "9102:9101" + environment: + - LOG_LEVEL=debug + - REDIS_ADDR=redis:6379 + - FETCHML_NATIVE_LIBS=0 + volumes: + - ../../:/app + - api-logs:/logs + - api-experiments:/data/experiments + - api-active:/data/active + - go-mod-cache:/go/pkg/mod + depends_on: + redis: + condition: service_healthy + healthcheck: + test: ["CMD", "wget", "-q", "--spider", "http://localhost:9101/health"] + interval: 5s + timeout: 3s + retries: 10 + start_period: 30s + +volumes: + api-logs: + api-experiments: + api-active: + go-mod-cache: diff --git a/tests/e2e/logs_debug_e2e_test.go b/tests/e2e/logs_debug_e2e_test.go new file mode 100644 index 0000000..1b17a8b --- /dev/null +++ b/tests/e2e/logs_debug_e2e_test.go @@ -0,0 +1,590 @@ +package tests + +import ( + "context" + "encoding/binary" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "os/exec" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/gorilla/websocket" +) + +const ( + logsDebugComposeFile = "docker-compose.logs-debug.yml" + apiPort = "9102" // Use docker-compose port + apiHost = "localhost" +) + +// TestLogsDebugE2E tests the logs and debug WebSocket API end-to-end +func TestLogsDebugE2E(t *testing.T) { + if os.Getenv("FETCH_ML_E2E_LOGS_DEBUG") != "1" { + t.Skip("Skipping LogsDebugE2E (set FETCH_ML_E2E_LOGS_DEBUG=1 to enable)") + } + + composeFile := filepath.Join(".", logsDebugComposeFile) + if _, err := os.Stat(composeFile); os.IsNotExist(err) { + t.Skipf("Docker compose file not found: %s", composeFile) + } + + // Ensure Docker is available + if _, err := exec.LookPath("docker"); err != nil { + t.Skip("Docker not found in PATH") + } + if _, err := exec.LookPath("docker-compose"); err != nil { + t.Skip("docker-compose not found in PATH") + } + + t.Run("SetupAndLogsAPI", func(t *testing.T) { + testLogsAPIWithDockerCompose(t, composeFile) + }) + + t.Run("SetupAndDebugAPI", func(t *testing.T) { + testDebugAPIWithDockerCompose(t, composeFile) + }) +} + +// testLogsAPIWithDockerCompose tests the logs API using Docker Compose +func testLogsAPIWithDockerCompose(t *testing.T, composeFile string) { + // Cleanup any existing containers + cleanupDockerCompose(t, composeFile) + + // Start services + t.Log("Starting Docker Compose services for logs API test...") + upCmd := exec.CommandContext(context.Background(), + "docker-compose", "-f", composeFile, "-p", "logsdebug-e2e", "up", "-d", "--build") + upOutput, err := upCmd.CombinedOutput() + if err != nil { + t.Fatalf("Failed to start Docker Compose: %v\nOutput: %s", err, string(upOutput)) + } + + // Cleanup after test + defer func() { + cleanupDockerCompose(t, composeFile) + }() + + // Wait for services to be healthy + t.Log("Waiting for services to be healthy...") + if !waitForServicesHealthy(t, "logsdebug-e2e", 90*time.Second) { + t.Fatal("Services failed to become healthy") + } + + // Add delay to ensure server is fully ready for WebSocket connections + t.Log("Waiting additional 3 seconds for WebSocket to be ready...") + time.Sleep(3 * time.Second) + + // Test logs API + testTargetID := "ab" + testGetLogsAPI(t, testTargetID) + testStreamLogsAPI(t, testTargetID) +} + +// testDebugAPIWithDockerCompose tests the debug API using Docker Compose +func testDebugAPIWithDockerCompose(t *testing.T, composeFile string) { + // Cleanup any existing containers + cleanupDockerCompose(t, composeFile) + + // Start services + t.Log("Starting Docker Compose services for debug API test...") + upCmd := exec.CommandContext(context.Background(), + "docker-compose", "-f", composeFile, "-p", "logsdebug-e2e", "up", "-d", "--build") + upOutput, err := upCmd.CombinedOutput() + if err != nil { + t.Fatalf("Failed to start Docker Compose: %v\nOutput: %s", err, string(upOutput)) + } + + // Cleanup after test + defer func() { + cleanupDockerCompose(t, composeFile) + }() + + // Wait for services to be healthy + t.Log("Waiting for services to be healthy...") + if !waitForServicesHealthy(t, "logsdebug-e2e", 90*time.Second) { + t.Fatal("Services failed to become healthy") + } + + // Add delay to ensure server is fully ready for WebSocket connections + t.Log("Waiting additional 3 seconds for WebSocket to be ready...") + time.Sleep(3 * time.Second) + + testTargetID := "test-job-456" + testAttachDebugAPI(t, testTargetID, "interactive") + testAttachDebugAPI(t, testTargetID, "gdb") + testAttachDebugAPI(t, testTargetID, "pdb") +} + +// connectWebSocketWithRetry attempts to connect with exponential backoff +func connectWebSocketWithRetry(t *testing.T, wsURL string, maxRetries int) (*websocket.Conn, *http.Response, error) { + var conn *websocket.Conn + var resp *http.Response + var err error + + // Create dialer with nil headers (like existing working test) + dialer := websocket.DefaultDialer + dialer.HandshakeTimeout = 10 * time.Second + + for i := 0; i < maxRetries; i++ { + conn, resp, err = dialer.Dial(wsURL, nil) + if err == nil { + return conn, resp, nil + } + + // Log response details for debugging + if resp != nil { + body, _ := io.ReadAll(resp.Body) + resp.Body.Close() + t.Logf("WebSocket connection attempt %d/%d failed: status=%d, body=%s", i+1, maxRetries, resp.StatusCode, string(body)) + } else { + t.Logf("WebSocket connection attempt %d/%d failed: %v", i+1, maxRetries, err) + } + + if i < maxRetries-1 { + delay := time.Duration(i+1) * 500 * time.Millisecond + t.Logf("Waiting %v before retry...", delay) + time.Sleep(delay) + } + } + + return nil, resp, err +} +func testGetLogsAPI(t *testing.T, targetID string) { + t.Logf("Testing get_logs API with target: %s", targetID) + + // First verify HTTP connectivity + healthURL := fmt.Sprintf("http://%s:%s/health", apiHost, apiPort) + resp, err := http.Get(healthURL) + if err != nil { + t.Fatalf("HTTP health check failed: %v", err) + } + resp.Body.Close() + t.Logf("HTTP health check passed: status=%d", resp.StatusCode) + + // Test WebSocket endpoint with HTTP GET (should return 400 or upgrade required) + wsHTTPURL := fmt.Sprintf("http://%s:%s/ws", apiHost, apiPort) + resp, err = http.Get(wsHTTPURL) + if err != nil { + t.Logf("HTTP GET to /ws failed: %v", err) + } else { + body, _ := io.ReadAll(resp.Body) + resp.Body.Close() + t.Logf("HTTP GET /ws: status=%d, body=%s", resp.StatusCode, string(body)) + } + + wsURL := fmt.Sprintf("ws://%s:%s/ws", apiHost, apiPort) + + // Connect to WebSocket with retry + conn, resp, err := connectWebSocketWithRetry(t, wsURL, 5) + if resp != nil && resp.Body != nil { + defer resp.Body.Close() + } + if err != nil { + t.Fatalf("Failed to connect to WebSocket after retries: %v", err) + } + defer conn.Close() + + // Build binary message for get_logs + // [opcode:1][api_key_hash:16][target_id_len:1][target_id:var] + opcode := byte(0x20) // OpcodeGetLogs + apiKeyHash := make([]byte, 16) // Zero-filled for testing (no auth) + targetIDBytes := []byte(targetID) + + message := []byte{opcode} + message = append(message, apiKeyHash...) + message = append(message, byte(len(targetIDBytes))) + message = append(message, targetIDBytes...) + + t.Logf("Sending message: opcode=0x%02x, len=%d, payload_len=%d", opcode, len(message), len(message)-1) + + // Send message + err = conn.WriteMessage(websocket.BinaryMessage, message) + if err != nil { + t.Fatalf("Failed to send get_logs message: %v", err) + } + + // Read response + conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + messageType, response, err := conn.ReadMessage() + if err != nil { + t.Fatalf("Failed to read get_logs response: %v", err) + } + + t.Logf("Received response type %d, length %d bytes", messageType, len(response)) + + // Parse response + if len(response) < 1 { + t.Fatal("Empty response received") + } + + // Parse the packet using the protocol + packet, err := parseResponsePacket(response) + if err != nil { + t.Fatalf("Failed to parse response packet: %v", err) + } + + // Verify response is a data packet (0x04 = PacketTypeData) + if packet.PacketType != 0x04 { + t.Errorf("Expected data packet (0x04), got 0x%02x", packet.PacketType) + } + + // Parse JSON payload + var logsResponse struct { + TargetID string `json:"target_id"` + Logs string `json:"logs"` + Truncated bool `json:"truncated"` + TotalLines int `json:"total_lines"` + } + + if err := json.Unmarshal(packet.Payload, &logsResponse); err != nil { + t.Fatalf("Failed to parse logs JSON response: %v", err) + } + + // Verify response fields + if logsResponse.TargetID != targetID { + t.Errorf("Expected target_id %s, got %s", targetID, logsResponse.TargetID) + } + + if logsResponse.Logs == "" { + t.Error("Expected non-empty logs content") + } + + t.Logf("Successfully received logs for target %s (%d lines, truncated=%v)", + logsResponse.TargetID, logsResponse.TotalLines, logsResponse.Truncated) +} + +// testStreamLogsAPI tests the stream_logs WebSocket endpoint +func testStreamLogsAPI(t *testing.T, targetID string) { + t.Logf("Testing stream_logs API with target: %s", targetID) + + wsURL := fmt.Sprintf("ws://%s:%s/ws", apiHost, apiPort) + + // Connect to WebSocket with retry + conn, resp, err := connectWebSocketWithRetry(t, wsURL, 5) + if resp != nil && resp.Body != nil { + defer resp.Body.Close() + } + if err != nil { + t.Fatalf("Failed to connect to WebSocket after retries: %v", err) + } + defer conn.Close() + + // Build binary message for stream_logs + // [opcode:1][api_key_hash:16][target_id_len:1][target_id:var] + opcode := byte(0x21) // OpcodeStreamLogs + apiKeyHash := make([]byte, 16) + targetIDBytes := []byte(targetID) + + message := []byte{opcode} + message = append(message, apiKeyHash...) + message = append(message, byte(len(targetIDBytes))) + message = append(message, targetIDBytes...) + + // Send message + err = conn.WriteMessage(websocket.BinaryMessage, message) + if err != nil { + t.Fatalf("Failed to send stream_logs message: %v", err) + } + + // Read response + conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + messageType, response, err := conn.ReadMessage() + if err != nil { + t.Fatalf("Failed to read stream_logs response: %v", err) + } + + t.Logf("Received stream response type %d, length %d bytes", messageType, len(response)) + + // Parse response + packet, err := parseResponsePacket(response) + if err != nil { + t.Fatalf("Failed to parse response packet: %v", err) + } + + // Verify response + if packet.PacketType != 0x04 { // PacketTypeData + t.Errorf("Expected data packet (0x04), got 0x%02x", packet.PacketType) + } + + var streamResponse struct { + TargetID string `json:"target_id"` + Streaming bool `json:"streaming"` + Message string `json:"message"` + } + + if err := json.Unmarshal(packet.Payload, &streamResponse); err != nil { + t.Fatalf("Failed to parse stream JSON response: %v", err) + } + + if streamResponse.TargetID != targetID { + t.Errorf("Expected target_id %s, got %s", targetID, streamResponse.TargetID) + } + + if !streamResponse.Streaming { + t.Error("Expected streaming=true in response") + } + + t.Logf("Successfully initiated log streaming for target %s", streamResponse.TargetID) +} + +// testAttachDebugAPI tests the attach_debug WebSocket endpoint +func testAttachDebugAPI(t *testing.T, targetID, debugType string) { + t.Logf("Testing attach_debug API with target: %s, debug_type: %s", targetID, debugType) + + wsURL := fmt.Sprintf("ws://%s:%s/ws", apiHost, apiPort) + + // Connect to WebSocket with retry + conn, resp, err := connectWebSocketWithRetry(t, wsURL, 5) + if resp != nil && resp.Body != nil { + defer resp.Body.Close() + } + if err != nil { + t.Fatalf("Failed to connect to WebSocket after retries: %v", err) + } + defer conn.Close() + + // Build binary message for attach_debug + // [opcode:1][api_key_hash:16][target_id_len:1][target_id:var][debug_type:var] + opcode := byte(0x22) // OpcodeAttachDebug + apiKeyHash := make([]byte, 16) + targetIDBytes := []byte(targetID) + debugTypeBytes := []byte(debugType) + + message := []byte{opcode} + message = append(message, apiKeyHash...) + message = append(message, byte(len(targetIDBytes))) + message = append(message, targetIDBytes...) + message = append(message, debugTypeBytes...) + + // Send message + err = conn.WriteMessage(websocket.BinaryMessage, message) + if err != nil { + t.Fatalf("Failed to send attach_debug message: %v", err) + } + + // Read response + conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + messageType, response, err := conn.ReadMessage() + if err != nil { + t.Fatalf("Failed to read attach_debug response: %v", err) + } + + t.Logf("Received debug response type %d, length %d bytes", messageType, len(response)) + + // Parse response + packet, err := parseResponsePacket(response) + if err != nil { + t.Fatalf("Failed to parse response packet: %v", err) + } + + // Verify response + if packet.PacketType != 0x04 { // PacketTypeData + t.Errorf("Expected data packet (0x04), got 0x%02x", packet.PacketType) + } + + var debugResponse struct { + TargetID string `json:"target_id"` + DebugType string `json:"debug_type"` + Attached bool `json:"attached"` + Message string `json:"message"` + Suggestion string `json:"suggestion"` + } + + if err := json.Unmarshal(packet.Payload, &debugResponse); err != nil { + t.Fatalf("Failed to parse debug JSON response: %v", err) + } + + if debugResponse.TargetID != targetID { + t.Errorf("Expected target_id %s, got %s", targetID, debugResponse.TargetID) + } + + if debugResponse.DebugType != debugType { + t.Errorf("Expected debug_type %s, got %s", debugType, debugResponse.DebugType) + } + + // Note: attached is false in stub implementation + t.Logf("Debug attachment response: target=%s, type=%s, attached=%v", + debugResponse.TargetID, debugResponse.DebugType, debugResponse.Attached) +} + +// ResponsePacket represents a parsed WebSocket response packet +type ResponsePacket struct { + PacketType byte + Payload []byte +} + +// parseResponsePacket parses a binary WebSocket response packet +func parseResponsePacket(data []byte) (*ResponsePacket, error) { + if len(data) < 9 { // min: type(1) + timestamp(8) + return nil, fmt.Errorf("packet too short: %d bytes", len(data)) + } + + packetType := data[0] + // Skip timestamp (8 bytes) + payloadStart := 9 + + // Handle different response types + switch packetType { + case 0x00: // PacketTypeSuccess + if len(data) <= payloadStart { + return &ResponsePacket{PacketType: packetType, Payload: nil}, nil + } + // Parse varint length + string + strLen, n := binary.Uvarint(data[payloadStart:]) + if n <= 0 { + return nil, fmt.Errorf("invalid varint") + } + return &ResponsePacket{ + PacketType: packetType, + Payload: data[payloadStart+n : payloadStart+n+int(strLen)], + }, nil + + case 0x01: // PacketTypeError + if len(data) < payloadStart+1 { + return nil, fmt.Errorf("error packet too short") + } + errorCode := data[payloadStart] + // Skip error code and read message + msgStart := payloadStart + 1 + if len(data) > msgStart && len(data) > msgStart+1 { + return nil, fmt.Errorf("server error (code: 0x%02x): %s", errorCode, string(data[msgStart:])) + } + return nil, fmt.Errorf("server error (code: 0x%02x)", errorCode) + + case 0x04: // PacketTypeData + // Format after timestamp: [data_type_len:varint][data_type][payload_len:varint][payload] + if len(data) <= payloadStart { + return &ResponsePacket{PacketType: packetType, Payload: nil}, nil + } + + // Read data_type length (varint) + typeLen, n1 := binary.Uvarint(data[payloadStart:]) + if n1 <= 0 { + return nil, fmt.Errorf("invalid type length varint") + } + + // Skip data_type string + payloadLenStart := payloadStart + n1 + int(typeLen) + if len(data) <= payloadLenStart { + return &ResponsePacket{PacketType: packetType, Payload: nil}, nil + } + + // Read payload length (varint) + payloadLen, n2 := binary.Uvarint(data[payloadLenStart:]) + if n2 <= 0 { + return nil, fmt.Errorf("invalid payload length varint") + } + + payloadDataStart := payloadLenStart + n2 + if len(data) < payloadDataStart+int(payloadLen) { + return nil, fmt.Errorf("data packet payload incomplete") + } + + return &ResponsePacket{ + PacketType: packetType, + Payload: data[payloadDataStart : payloadDataStart+int(payloadLen)], + }, nil + + default: + // Unknown packet type, return raw data after timestamp + if len(data) > payloadStart { + return &ResponsePacket{ + PacketType: packetType, + Payload: data[payloadStart:], + }, nil + } + return &ResponsePacket{PacketType: packetType}, nil + } +} + +// cleanupDockerCompose stops and removes Docker Compose containers +func cleanupDockerCompose(t *testing.T, composeFile string) { + t.Log("Cleaning up Docker Compose...") + downCmd := exec.CommandContext(context.Background(), + "docker-compose", "-f", composeFile, "-p", "logsdebug-e2e", "down", "--remove-orphans", "--volumes") + downCmd.Stdout = os.Stdout + downCmd.Stderr = os.Stderr + if err := downCmd.Run(); err != nil { + t.Logf("Warning: Failed to cleanup Docker Compose: %v", err) + } +} + +// waitForServicesHealthy waits for all services to be healthy +func waitForServicesHealthy(t *testing.T, projectName string, timeout time.Duration) bool { + start := time.Now() + for time.Since(start) < timeout { + // Check container status + psCmd := exec.CommandContext(context.Background(), + "docker", "ps", "--filter", fmt.Sprintf("name=%s", projectName), "--format", "{{.Names}}\t{{.Status}}") + output, err := psCmd.CombinedOutput() + if err == nil { + status := string(output) + t.Logf("Container status:\n%s", status) + + // Check if all containers are healthy (not just "Up") + allHealthy := true + lines := strings.Split(status, "\n") + containerCount := 0 + for _, line := range lines { + if strings.TrimSpace(line) == "" { + continue + } + containerCount++ + // Must explicitly contain "healthy" - "Up (health: starting)" is NOT healthy + if !strings.Contains(line, "healthy") { + allHealthy = false + break + } + } + + if allHealthy && containerCount >= 2 { // redis and api-server + t.Log("All services are healthy") + return true + } + } + + time.Sleep(1 * time.Second) + } + + t.Logf("Timeout waiting for services after %v", timeout) + return false +} + +// TestLogsDebugHTTPHealthE2E tests the HTTP health endpoint for logs/debug services +func TestLogsDebugHTTPHealthE2E(t *testing.T) { + if os.Getenv("FETCH_ML_E2E_LOGS_DEBUG") != "1" { + t.Skip("Skipping LogsDebugHTTPHealthE2E (set FETCH_ML_E2E_LOGS_DEBUG=1 to enable)") + } + + // Check if API is already running + healthURL := fmt.Sprintf("http://%s:%s/health", apiHost, apiPort) + + client := &http.Client{Timeout: 5 * time.Second} + resp, err := client.Get(healthURL) + if err != nil { + t.Skipf("API server not available at %s: %v", healthURL, err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("Health check failed: status %d", resp.StatusCode) + } + + // Parse health response + var healthResponse struct { + Status string `json:"status"` + Version string `json:"version"` + } + + if err := json.NewDecoder(resp.Body).Decode(&healthResponse); err != nil { + t.Logf("Failed to parse health response: %v", err) + } else { + t.Logf("API health: status=%s, version=%s", healthResponse.Status, healthResponse.Version) + } +} diff --git a/tests/integration/duplicate_detection_test.go b/tests/integration/duplicate_detection_test.go new file mode 100644 index 0000000..7ed55af --- /dev/null +++ b/tests/integration/duplicate_detection_test.go @@ -0,0 +1,63 @@ +package tests + +import ( + "testing" + + "github.com/jfraeys/fetch_ml/internal/api/helpers" + "github.com/jfraeys/fetch_ml/internal/queue" +) + +// TestDuplicateDetection verifies the duplicate detection logic +func TestDuplicateDetection(t *testing.T) { + // Test 1: Same args + same commit = duplicate (would be detected) + commitID := "abc123def4567890abc1" + args1 := "--epochs 10 --lr 0.001" + args2 := "--epochs 10 --lr 0.001" // Same args + + hash1 := helpers.ComputeParamsHash(args1) + hash2 := helpers.ComputeParamsHash(args2) + + if hash1 != hash2 { + t.Error("Same args should produce same hash") + } + t.Logf("✓ Same args produce same hash: %s", hash1) + + // Test 2: Different args = not duplicate + args3 := "--epochs 20 --lr 0.01" // Different + hash3 := helpers.ComputeParamsHash(args3) + + if hash1 == hash3 { + t.Error("Different args should produce different hashes") + } + t.Logf("✓ Different args produce different hashes: %s vs %s", hash1, hash3) + + // Test 3: Same dataset specs = same dataset_id + ds1 := []queue.DatasetSpec{{Name: "mnist", Checksum: "sha256:abc123"}} + ds2 := []queue.DatasetSpec{{Name: "mnist", Checksum: "sha256:abc123"}} + + id1 := helpers.ComputeDatasetID(ds1, nil) + id2 := helpers.ComputeDatasetID(ds2, nil) + + if id1 != id2 { + t.Error("Same dataset specs should produce same ID") + } + t.Logf("✓ Same dataset specs produce same ID: %s", id1) + + // Test 4: Different dataset = different ID + ds3 := []queue.DatasetSpec{{Name: "cifar10", Checksum: "sha256:def456"}} + id3 := helpers.ComputeDatasetID(ds3, nil) + + if id1 == id3 { + t.Error("Different datasets should produce different IDs") + } + t.Logf("✓ Different datasets produce different IDs: %s vs %s", id1, id3) + + // Test 5: Composite key logic + t.Log("\n=== Composite Key Detection ===") + t.Logf("Job 1: commit=%s, dataset_id=%s, params_hash=%s", commitID, id1, hash1) + t.Logf("Job 2: commit=%s, dataset_id=%s, params_hash=%s", commitID, id2, hash2) + t.Log("→ These would be detected as DUPLICATE (same commit + dataset + params)") + + t.Logf("Job 3: commit=%s, dataset_id=%s, params_hash=%s", commitID, id1, hash3) + t.Log("→ This would NOT be duplicate (different params)") +} diff --git a/tests/jupyter_experiment_integration_test.go b/tests/integration/jupyter_experiment_test.go similarity index 100% rename from tests/jupyter_experiment_integration_test.go rename to tests/integration/jupyter_experiment_test.go diff --git a/tests/integration_protocol_test.go b/tests/integration/protocol_test.go similarity index 100% rename from tests/integration_protocol_test.go rename to tests/integration/protocol_test.go diff --git a/tests/integration/websocket_queue_integration_test.go b/tests/integration/websocket_queue_integration_test.go index 797f8bd..8254439 100644 --- a/tests/integration/websocket_queue_integration_test.go +++ b/tests/integration/websocket_queue_integration_test.go @@ -236,10 +236,10 @@ func TestWebSocketQueueWithSnapshotOpcode(t *testing.T) { expMgr, "", taskQueue, - nil, - nil, - nil, - nil, + nil, // db + nil, // jupyterServiceMgr + nil, // securityConfig + nil, // auditLogger ) server := httptest.NewServer(wsHandler) defer server.Close() diff --git a/tests/integration/ws_handler_integration_test.go b/tests/integration/ws_handler_integration_test.go index 556f4f0..ffff270 100644 --- a/tests/integration/ws_handler_integration_test.go +++ b/tests/integration/ws_handler_integration_test.go @@ -65,9 +65,9 @@ func setupWSIntegrationServerWithDataDir(t *testing.T, dataDir string) ( dataDir, tq, db, - nil, - nil, - nil, + nil, // jupyterServiceMgr + nil, // securityConfig + nil, // auditLogger ) server := httptest.NewServer(handler) return server, tq, expManager, s, db diff --git a/tests/unit/api/duplicate_detection_process_test.go b/tests/unit/api/duplicate_detection_process_test.go new file mode 100644 index 0000000..42a08ef --- /dev/null +++ b/tests/unit/api/duplicate_detection_process_test.go @@ -0,0 +1,95 @@ +package api_test + +import ( + "testing" + + "github.com/jfraeys/fetch_ml/internal/api/helpers" + "github.com/jfraeys/fetch_ml/internal/queue" +) + +// ProcessTest demonstrates the duplicate detection process step by step +func TestDuplicateDetectionProcess(t *testing.T) { + t.Log("=== Duplicate Detection Process Test ===") + + // Step 1: First job submission + t.Log("\n1. First job submission:") + commitID := "abc123def456" + args1 := "--epochs 10 --lr 0.001" + datasets := []queue.DatasetSpec{{Name: "mnist", Checksum: "sha256:abc123"}} + + datasetID1 := helpers.ComputeDatasetID(datasets, nil) + paramsHash1 := helpers.ComputeParamsHash(args1) + + t.Logf(" Commit ID: %s", commitID) + t.Logf(" Dataset ID: %s (computed from %d datasets)", datasetID1, len(datasets)) + t.Logf(" Params Hash: %s (computed from args: %s)", paramsHash1, args1) + t.Logf(" Composite Key: (%s, %s, %s)", commitID, datasetID1, paramsHash1) + + // Step 2: Second job with SAME parameters (should be duplicate) + t.Log("\n2. Second job submission (same params):") + args2 := "--epochs 10 --lr 0.001" // Same args + datasets2 := []queue.DatasetSpec{{Name: "mnist", Checksum: "sha256:abc123"}} // Same dataset + + datasetID2 := helpers.ComputeDatasetID(datasets2, nil) + paramsHash2 := helpers.ComputeParamsHash(args2) + + t.Logf(" Commit ID: %s", commitID) + t.Logf(" Dataset ID: %s", datasetID2) + t.Logf(" Params Hash: %s", paramsHash2) + t.Logf(" Composite Key: (%s, %s, %s)", commitID, datasetID2, paramsHash2) + + // Verify they're the same + if datasetID1 == datasetID2 && paramsHash1 == paramsHash2 { + t.Log(" ✓ DUPLICATE DETECTED - same composite key!") + } else { + t.Error(" ✗ Should have been detected as duplicate") + } + + // Step 3: Third job with DIFFERENT parameters (not duplicate) + t.Log("\n3. Third job submission (different params):") + args3 := "--epochs 20 --lr 0.01" // Different args + datasets3 := []queue.DatasetSpec{{Name: "mnist", Checksum: "sha256:abc123"}} // Same dataset + + datasetID3 := helpers.ComputeDatasetID(datasets3, nil) + paramsHash3 := helpers.ComputeParamsHash(args3) + + t.Logf(" Commit ID: %s", commitID) + t.Logf(" Dataset ID: %s", datasetID3) + t.Logf(" Params Hash: %s", paramsHash3) + t.Logf(" Composite Key: (%s, %s, %s)", commitID, datasetID3, paramsHash3) + + // Verify they're different + if paramsHash1 != paramsHash3 { + t.Log(" ✓ NOT A DUPLICATE - different params_hash") + } else { + t.Error(" ✗ Should have different params_hash") + } + + // Step 4: Fourth job with DIFFERENT dataset (not duplicate) + t.Log("\n4. Fourth job submission (different dataset):") + args4 := "--epochs 10 --lr 0.001" // Same args + datasets4 := []queue.DatasetSpec{{Name: "cifar10", Checksum: "sha256:def456"}} // Different dataset + + datasetID4 := helpers.ComputeDatasetID(datasets4, nil) + paramsHash4 := helpers.ComputeParamsHash(args4) + + t.Logf(" Commit ID: %s", commitID) + t.Logf(" Dataset ID: %s", datasetID4) + t.Logf(" Params Hash: %s", paramsHash4) + t.Logf(" Composite Key: (%s, %s, %s)", commitID, datasetID4, paramsHash4) + + // Verify they're different + if datasetID1 != datasetID4 { + t.Log(" ✓ NOT A DUPLICATE - different dataset_id") + } else { + t.Error(" ✗ Should have different dataset_id") + } + + // Step 5: Summary + t.Log("\n5. Summary:") + t.Log(" - Jobs 1 & 2: Same commit_id + dataset_id + params_hash = DUPLICATE") + t.Log(" - Job 3: Different params_hash = NOT DUPLICATE") + t.Log(" - Job 4: Different dataset_id = NOT DUPLICATE") + t.Log("\n The composite key (commit_id, dataset_id, params_hash) ensures") + t.Log(" only truly identical experiments are flagged as duplicates.") +} diff --git a/tests/unit/api/helpers/db_helpers_test.go b/tests/unit/api/helpers/db_helpers_test.go new file mode 100644 index 0000000..5ab371f --- /dev/null +++ b/tests/unit/api/helpers/db_helpers_test.go @@ -0,0 +1,225 @@ +package helpers + +import ( + "testing" + "time" + + "github.com/jfraeys/fetch_ml/internal/api/helpers" +) + +func TestDBContext(t *testing.T) { + ctx, cancel := helpers.DBContext(3 * time.Second) + defer cancel() + + deadline, ok := ctx.Deadline() + if !ok { + t.Error("expected deadline to be set") + } + + // Deadline should be within a reasonable time window + now := time.Now() + if deadline.Before(now) { + t.Error("deadline is in the past") + } + if deadline.After(now.Add(4 * time.Second)) { + t.Error("deadline is too far in the future") + } + + // Context should not be cancelled + select { + case <-ctx.Done(): + t.Error("context should not be cancelled yet") + default: + // Good + } +} + +func TestDBContextShort(t *testing.T) { + ctx, cancel := helpers.DBContextShort() + defer cancel() + + deadline, ok := ctx.Deadline() + if !ok { + t.Error("expected deadline to be set") + } + + now := time.Now() + diff := deadline.Sub(now) + + // Should be around 3 seconds + if diff < 2*time.Second || diff > 4*time.Second { + t.Errorf("expected deadline ~3s, got %v", diff) + } +} + +func TestDBContextMedium(t *testing.T) { + ctx, cancel := helpers.DBContextMedium() + defer cancel() + + deadline, ok := ctx.Deadline() + if !ok { + t.Error("expected deadline to be set") + } + + now := time.Now() + diff := deadline.Sub(now) + + // Should be around 5 seconds + if diff < 4*time.Second || diff > 6*time.Second { + t.Errorf("expected deadline ~5s, got %v", diff) + } +} + +func TestDBContextLong(t *testing.T) { + ctx, cancel := helpers.DBContextLong() + defer cancel() + + deadline, ok := ctx.Deadline() + if !ok { + t.Error("expected deadline to be set") + } + + now := time.Now() + diff := deadline.Sub(now) + + // Should be around 10 seconds + if diff < 9*time.Second || diff > 11*time.Second { + t.Errorf("expected deadline ~10s, got %v", diff) + } +} + +func TestDBContextCancellation(t *testing.T) { + ctx, cancel := helpers.DBContextShort() + + // Cancel immediately + cancel() + + // Context should be cancelled + select { + case <-ctx.Done(): + // Good - context was cancelled + default: + t.Error("context should be cancelled after calling cancel()") + } + + // Check error + if ctx.Err() == nil { + t.Error("expected non-nil error after cancellation") + } +} + +func TestStringSliceContains(t *testing.T) { + tests := []struct { + name string + slice []string + item string + want bool + }{ + { + name: "contains", + slice: []string{"a", "b", "c"}, + item: "b", + want: true, + }, + { + name: "not contains", + slice: []string{"a", "b", "c"}, + item: "d", + want: false, + }, + { + name: "empty slice", + slice: []string{}, + item: "a", + want: false, + }, + { + name: "nil slice", + slice: nil, + item: "a", + want: false, + }, + { + name: "empty string item", + slice: []string{"a", "", "c"}, + item: "", + want: true, + }, + { + name: "case sensitive", + slice: []string{"Apple", "Banana"}, + item: "apple", + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := helpers.StringSliceContains(tt.slice, tt.item) + if got != tt.want { + t.Errorf("StringSliceContains() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestStringSliceFilter(t *testing.T) { + tests := []struct { + name string + slice []string + predicate func(string) bool + want []string + }{ + { + name: "filter by prefix", + slice: []string{"apple", "banana", "apricot", "cherry"}, + predicate: func(s string) bool { return len(s) > 0 && s[0] == 'a' }, + want: []string{"apple", "apricot"}, + }, + { + name: "filter empty strings", + slice: []string{"a", "", "b", "", "c"}, + predicate: func(s string) bool { return s != "" }, + want: []string{"a", "b", "c"}, + }, + { + name: "all match", + slice: []string{"a", "b", "c"}, + predicate: func(s string) bool { return true }, + want: []string{"a", "b", "c"}, + }, + { + name: "none match", + slice: []string{"a", "b", "c"}, + predicate: func(s string) bool { return false }, + want: []string{}, + }, + { + name: "empty slice", + slice: []string{}, + predicate: func(s string) bool { return true }, + want: []string{}, + }, + { + name: "nil slice", + slice: nil, + predicate: func(s string) bool { return true }, + want: []string{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := helpers.StringSliceFilter(tt.slice, tt.predicate) + if len(got) != len(tt.want) { + t.Errorf("StringSliceFilter() returned %d items, want %d", len(got), len(tt.want)) + return + } + for i := range got { + if got[i] != tt.want[i] { + t.Errorf("StringSliceFilter()[%d] = %q, want %q", i, got[i], tt.want[i]) + } + } + }) + } +} diff --git a/tests/unit/api/helpers/hash_helpers_test.go b/tests/unit/api/helpers/hash_helpers_test.go new file mode 100644 index 0000000..5f8ca97 --- /dev/null +++ b/tests/unit/api/helpers/hash_helpers_test.go @@ -0,0 +1,137 @@ +package helpers_test + +import ( + "testing" + + "github.com/jfraeys/fetch_ml/internal/api/helpers" + "github.com/jfraeys/fetch_ml/internal/queue" +) + +func TestComputeDatasetID(t *testing.T) { + tests := []struct { + name string + datasetSpecs []queue.DatasetSpec + datasets []string + want string + }{ + { + name: "both empty", + datasetSpecs: nil, + datasets: nil, + want: "", + }, + { + name: "only datasets", + datasetSpecs: nil, + datasets: []string{"dataset1", "dataset2"}, + want: "", // will be a hash + }, + { + name: "dataset specs with checksums", + datasetSpecs: []queue.DatasetSpec{ + {Name: "ds1", Checksum: "abc123"}, + {Name: "ds2", Checksum: "def456"}, + }, + datasets: nil, + want: "", // will be a hash + }, + { + name: "dataset specs without checksums", + datasetSpecs: []queue.DatasetSpec{ + {Name: "ds1"}, + {Name: "ds2"}, + }, + datasets: nil, + want: "", // will use names + }, + { + name: "checksums take precedence", + datasetSpecs: []queue.DatasetSpec{{Name: "ds1", Checksum: "xyz789"}}, + datasets: []string{"dataset1"}, + want: "", // should use checksum from specs + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := helpers.ComputeDatasetID(tt.datasetSpecs, tt.datasets) + if tt.want == "" { + // Just verify it returns something or empty as expected + if len(tt.datasetSpecs) == 0 && len(tt.datasets) == 0 && got != "" { + t.Errorf("ComputeDatasetID() = %q, want empty string", got) + } + } else if got != tt.want { + t.Errorf("ComputeDatasetID() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestComputeParamsHash(t *testing.T) { + tests := []struct { + name string + args string + want string + }{ + { + name: "empty args", + args: "", + want: "", + }, + { + name: "simple args", + args: "--lr 0.01 --epochs 10", + want: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", // sha256 of trimmed args + }, + { + name: "args with spaces", + args: " --lr 0.01 ", + want: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", // sha256 of trimmed args + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := helpers.ComputeParamsHash(tt.args) + // Just verify it returns a string (hash computation is deterministic) + if tt.want == "" && got != "" { + t.Errorf("ComputeParamsHash() expected empty, got %q", got) + } + if tt.want != "" && got == "" { + t.Errorf("ComputeParamsHash() expected non-empty hash") + } + }) + } +} + +func TestComputeParamsHash_Deterministic(t *testing.T) { + args := "--lr 0.01 --epochs 10" + hash1 := helpers.ComputeParamsHash(args) + hash2 := helpers.ComputeParamsHash(args) + + if hash1 != hash2 { + t.Error("ComputeParamsHash() should be deterministic") + } + + // Different args should produce different hashes + differentArgs := "--lr 0.02 --epochs 10" + differentHash := helpers.ComputeParamsHash(differentArgs) + + if hash1 == differentHash { + t.Error("ComputeParamsHash() should produce different hashes for different inputs") + } +} + +func TestComputeParamsHash_Whitespace(t *testing.T) { + // Same args with different whitespace should produce the same hash + hash1 := helpers.ComputeParamsHash("--lr 0.01 --epochs 10") + hash2 := helpers.ComputeParamsHash(" --lr 0.01 --epochs 10 ") + hash3 := helpers.ComputeParamsHash("--lr 0.01 --epochs 10") + + if hash1 != hash2 { + t.Error("ComputeParamsHash() should handle leading/trailing whitespace consistently") + } + // Note: internal whitespace differences may or may not produce different hashes + // depending on implementation details + _ = hash3 +} diff --git a/tests/unit/api/helpers/payload_parser_test.go b/tests/unit/api/helpers/payload_parser_test.go new file mode 100644 index 0000000..0cc29e9 --- /dev/null +++ b/tests/unit/api/helpers/payload_parser_test.go @@ -0,0 +1,451 @@ +package helpers_test + +import ( + "bytes" + "testing" + + "github.com/jfraeys/fetch_ml/internal/api/helpers" +) + +func TestNewPayloadParser(t *testing.T) { + payload := []byte{0x01, 0x02, 0x03, 0x04, 0x05} + p := helpers.NewPayloadParser(payload, 2) + + if p.Offset() != 2 { + t.Errorf("expected offset 2, got %d", p.Offset()) + } + if !bytes.Equal(p.Payload(), payload) { + t.Error("payload mismatch") + } +} + +func TestPayloadParser_ParseByte(t *testing.T) { + tests := []struct { + name string + payload []byte + start int + want byte + wantErr bool + }{ + { + name: "valid byte", + payload: []byte{0x00, 0x01, 0x02}, + start: 1, + want: 0x01, + wantErr: false, + }, + { + name: "end of payload", + payload: []byte{0x00, 0x01}, + start: 2, + want: 0, + wantErr: true, + }, + { + name: "empty payload", + payload: []byte{}, + start: 0, + want: 0, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := helpers.NewPayloadParser(tt.payload, tt.start) + got, err := p.ParseByte() + if (err != nil) != tt.wantErr { + t.Errorf("ParseByte() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("ParseByte() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestPayloadParser_ParseUint16(t *testing.T) { + tests := []struct { + name string + payload []byte + start int + want uint16 + wantErr bool + }{ + { + name: "valid uint16", + payload: []byte{0x01, 0x02, 0x03, 0x04}, + start: 0, + want: 0x0102, + wantErr: false, + }, + { + name: "another valid uint16", + payload: []byte{0x00, 0xAB, 0xCD}, + start: 1, + want: 0xABCD, + wantErr: false, + }, + { + name: "too short", + payload: []byte{0x00, 0x01}, + start: 1, + want: 0, + wantErr: true, + }, + { + name: "empty at end", + payload: []byte{0x00, 0x01, 0x02}, + start: 3, + want: 0, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := helpers.NewPayloadParser(tt.payload, tt.start) + got, err := p.ParseUint16() + if (err != nil) != tt.wantErr { + t.Errorf("ParseUint16() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("ParseUint16() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestPayloadParser_ParseLengthPrefixedString(t *testing.T) { + tests := []struct { + name string + payload []byte + start int + want string + wantErr bool + }{ + { + name: "valid string", + payload: []byte{0x05, 'h', 'e', 'l', 'l', 'o'}, + start: 0, + want: "hello", + wantErr: false, + }, + { + name: "empty string", + payload: []byte{0x00, 0x01, 0x02}, + start: 0, + want: "", + wantErr: false, + }, + { + name: "from offset", + payload: []byte{0x00, 0x03, 'f', 'o', 'o'}, + start: 1, + want: "foo", + wantErr: false, + }, + { + name: "too short for length", + payload: []byte{0x05, 'h', 'i'}, + start: 0, + want: "", + wantErr: true, + }, + { + name: "no length byte", + payload: []byte{}, + start: 0, + want: "", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := helpers.NewPayloadParser(tt.payload, tt.start) + got, err := p.ParseLengthPrefixedString() + if (err != nil) != tt.wantErr { + t.Errorf("ParseLengthPrefixedString() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("ParseLengthPrefixedString() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestPayloadParser_ParseUint16PrefixedString(t *testing.T) { + tests := []struct { + name string + payload []byte + start int + want string + wantErr bool + }{ + { + name: "valid string", + payload: []byte{0x00, 0x05, 'h', 'e', 'l', 'l', 'o'}, + start: 0, + want: "hello", + wantErr: false, + }, + { + name: "empty string", + payload: []byte{0x00, 0x00, 0x01, 0x02}, + start: 0, + want: "", + wantErr: false, + }, + { + name: "from offset", + payload: []byte{0x00, 0x00, 0x00, 0x03, 'b', 'a', 'r'}, + start: 2, + want: "bar", + wantErr: false, + }, + { + name: "too short for length", + payload: []byte{0x00}, + start: 0, + want: "", + wantErr: true, + }, + { + name: "too short for string", + payload: []byte{0x00, 0x05, 'h', 'i'}, + start: 0, + want: "", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := helpers.NewPayloadParser(tt.payload, tt.start) + got, err := p.ParseUint16PrefixedString() + if (err != nil) != tt.wantErr { + t.Errorf("ParseUint16PrefixedString() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("ParseUint16PrefixedString() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestPayloadParser_ParseBool(t *testing.T) { + tests := []struct { + name string + payload []byte + start int + want bool + wantErr bool + }{ + { + name: "true", + payload: []byte{0x01, 0x00}, + start: 0, + want: true, + wantErr: false, + }, + { + name: "false", + payload: []byte{0x00, 0x01}, + start: 0, + want: false, + wantErr: false, + }, + { + name: "non-zero is true", + payload: []byte{0xFF}, + start: 0, + want: true, + wantErr: false, + }, + { + name: "empty", + payload: []byte{}, + start: 0, + want: false, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := helpers.NewPayloadParser(tt.payload, tt.start) + got, err := p.ParseBool() + if (err != nil) != tt.wantErr { + t.Errorf("ParseBool() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("ParseBool() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestPayloadParser_ParseFixedBytes(t *testing.T) { + tests := []struct { + name string + payload []byte + start int + length int + want []byte + wantErr bool + }{ + { + name: "valid", + payload: []byte{0x00, 0x01, 0x02, 0x03, 0x04}, + start: 1, + length: 3, + want: []byte{0x01, 0x02, 0x03}, + wantErr: false, + }, + { + name: "exact length", + payload: []byte{0x00, 0x01, 0x02}, + start: 0, + length: 3, + want: []byte{0x00, 0x01, 0x02}, + wantErr: false, + }, + { + name: "too long", + payload: []byte{0x00, 0x01}, + start: 0, + length: 3, + want: nil, + wantErr: true, + }, + { + name: "from end", + payload: []byte{0x00, 0x01, 0x02}, + start: 2, + length: 2, + want: nil, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := helpers.NewPayloadParser(tt.payload, tt.start) + got, err := p.ParseFixedBytes(tt.length) + if (err != nil) != tt.wantErr { + t.Errorf("ParseFixedBytes() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !bytes.Equal(got, tt.want) { + t.Errorf("ParseFixedBytes() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestPayloadParser_Offset(t *testing.T) { + payload := []byte{0x00, 0x01, 0x02, 0x03} + p := helpers.NewPayloadParser(payload, 1) + + if p.Offset() != 1 { + t.Errorf("initial Offset() = %d, want 1", p.Offset()) + } + + p.ParseByte() + if p.Offset() != 2 { + t.Errorf("after ParseByte() Offset() = %d, want 2", p.Offset()) + } +} + +func TestPayloadParser_Remaining(t *testing.T) { + payload := []byte{0x00, 0x01, 0x02, 0x03} + p := helpers.NewPayloadParser(payload, 2) + + remaining := p.Remaining() + if !bytes.Equal(remaining, []byte{0x02, 0x03}) { + t.Errorf("Remaining() = %v, want [0x02, 0x03]", remaining) + } + + // At end + p = helpers.NewPayloadParser(payload, 4) + if p.Remaining() != nil { + t.Errorf("at end Remaining() = %v, want nil", p.Remaining()) + } + + // Beyond end + p = helpers.NewPayloadParser(payload, 5) + if p.Remaining() != nil { + t.Errorf("beyond end Remaining() = %v, want nil", p.Remaining()) + } +} + +func TestPayloadParser_HasRemaining(t *testing.T) { + payload := []byte{0x00, 0x01} + + tests := []struct { + start int + want bool + }{ + {0, true}, + {1, true}, + {2, false}, + {3, false}, + } + + for _, tt := range tests { + p := helpers.NewPayloadParser(payload, tt.start) + if got := p.HasRemaining(); got != tt.want { + t.Errorf("HasRemaining() with start=%d = %v, want %v", tt.start, got, tt.want) + } + } +} + +func TestPayloadParser_ChainedParsing(t *testing.T) { + // Simulate a real protocol: [api_key:2][name_len:1][name:var][value:2] + payload := []byte{ + 0xAB, 0xCD, // api_key + 0x05, // name_len + 'h', 'e', 'l', 'l', 'o', // name + 0x00, 0x42, // value + } + + p := helpers.NewPayloadParser(payload, 0) + + // Parse api key (2 bytes) + apiKey, err := p.ParseFixedBytes(2) + if err != nil { + t.Fatalf("failed to parse api key: %v", err) + } + if !bytes.Equal(apiKey, []byte{0xAB, 0xCD}) { + t.Errorf("api key = %v, want [0xAB, 0xCD]", apiKey) + } + + // Parse name + name, err := p.ParseLengthPrefixedString() + if err != nil { + t.Fatalf("failed to parse name: %v", err) + } + if name != "hello" { + t.Errorf("name = %q, want 'hello'", name) + } + + // Parse value + value, err := p.ParseUint16() + if err != nil { + t.Fatalf("failed to parse value: %v", err) + } + if value != 0x0042 { + t.Errorf("value = 0x%X, want 0x0042", value) + } + + // Should be at end + if p.HasRemaining() { + t.Error("expected no remaining bytes") + } +} diff --git a/tests/unit/api/helpers/response_helpers_test.go b/tests/unit/api/helpers/response_helpers_test.go new file mode 100644 index 0000000..9c580ca --- /dev/null +++ b/tests/unit/api/helpers/response_helpers_test.go @@ -0,0 +1,345 @@ +package helpers + +import ( + "testing" + + "github.com/jfraeys/fetch_ml/internal/api/helpers" + "github.com/jfraeys/fetch_ml/internal/queue" +) + +func TestNewTaskErrorMapper(t *testing.T) { + mapper := helpers.NewTaskErrorMapper() + if mapper == nil { + t.Error("NewTaskErrorMapper() returned nil") + } +} + +func TestTaskErrorMapper_MapJupyterError(t *testing.T) { + mapper := helpers.NewTaskErrorMapper() + + tests := []struct { + name string + task *queue.Task + want helpers.ErrorCode + }{ + { + name: "nil task", + task: nil, + want: 0x00, // ErrorCodeUnknownError + }, + { + name: "cancelled task", + task: &queue.Task{Status: "cancelled"}, + want: 0x24, // ErrorCodeJobCancelled + }, + { + name: "oom error", + task: &queue.Task{Status: "failed", Error: "out of memory"}, + want: 0x30, // ErrorCodeOutOfMemory + }, + { + name: "oom shorthand", + task: &queue.Task{Status: "failed", Error: "OOM killed"}, + want: 0x30, // ErrorCodeOutOfMemory + }, + { + name: "disk full", + task: &queue.Task{Status: "failed", Error: "no space left on device"}, + want: 0x31, // ErrorCodeDiskFull + }, + { + name: "disk full alt", + task: &queue.Task{Status: "failed", Error: "disk full"}, + want: 0x31, // ErrorCodeDiskFull + }, + { + name: "rate limit", + task: &queue.Task{Status: "failed", Error: "rate limit exceeded"}, + want: 0x33, // ErrorCodeServiceUnavailable + }, + { + name: "throttle", + task: &queue.Task{Status: "failed", Error: "request throttled"}, + want: 0x33, // ErrorCodeServiceUnavailable + }, + { + name: "timeout", + task: &queue.Task{Status: "failed", Error: "timed out waiting"}, + want: 0x14, // ErrorCodeTimeout + }, + { + name: "deadline", + task: &queue.Task{Status: "failed", Error: "context deadline exceeded"}, + want: 0x14, // ErrorCodeTimeout + }, + { + name: "connection refused", + task: &queue.Task{Status: "failed", Error: "connection refused"}, + want: 0x12, // ErrorCodeNetworkError + }, + { + name: "connection reset", + task: &queue.Task{Status: "failed", Error: "connection reset by peer"}, + want: 0x12, // ErrorCodeNetworkError + }, + { + name: "network unreachable", + task: &queue.Task{Status: "failed", Error: "network unreachable"}, + want: 0x12, // ErrorCodeNetworkError + }, + { + name: "queue not configured", + task: &queue.Task{Status: "failed", Error: "queue not configured"}, + want: 0x32, // ErrorCodeInvalidConfiguration + }, + { + name: "generic failed", + task: &queue.Task{Status: "failed", Error: "something went wrong"}, + want: 0x23, // ErrorCodeJobExecutionFailed + }, + { + name: "unknown status", + task: &queue.Task{Status: "unknown", Error: "unknown error"}, + want: 0x00, // ErrorCodeUnknownError + }, + { + name: "case insensitive - cancelled", + task: &queue.Task{Status: "CANCELLED"}, + want: 0x24, // ErrorCodeJobCancelled + }, + { + name: "case insensitive - oom", + task: &queue.Task{Status: "FAILED", Error: "OUT OF MEMORY"}, + want: 0x30, // ErrorCodeOutOfMemory + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := mapper.MapJupyterError(tt.task) + if got != tt.want { + t.Errorf("MapJupyterError() = 0x%02X, want 0x%02X", got, tt.want) + } + }) + } +} + +func TestTaskErrorMapper_MapError(t *testing.T) { + mapper := helpers.NewTaskErrorMapper() + + tests := []struct { + name string + task *queue.Task + defaultCode helpers.ErrorCode + want helpers.ErrorCode + }{ + { + name: "nil task returns default", + task: nil, + defaultCode: 0x11, // ErrorCodeDatabaseError + want: 0x11, + }, + { + name: "cancelled", + task: &queue.Task{Status: "cancelled"}, + defaultCode: 0x00, + want: 0x24, // ErrorCodeJobCancelled + }, + { + name: "oom", + task: &queue.Task{Status: "failed", Error: "oom"}, + defaultCode: 0x00, + want: 0x30, // ErrorCodeOutOfMemory + }, + { + name: "timeout", + task: &queue.Task{Status: "failed", Error: "timeout"}, + defaultCode: 0x00, + want: 0x14, // ErrorCodeTimeout + }, + { + name: "generic failed", + task: &queue.Task{Status: "failed", Error: "generic error"}, + defaultCode: 0x00, + want: 0x23, // ErrorCodeJobExecutionFailed + }, + { + name: "unknown status returns default", + task: &queue.Task{Status: "weird", Error: "unknown"}, + defaultCode: 0x01, + want: 0x01, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := mapper.MapError(tt.task, tt.defaultCode) + if got != tt.want { + t.Errorf("MapError() = 0x%02X, want 0x%02X", got, tt.want) + } + }) + } +} + +func TestParseResourceRequest(t *testing.T) { + tests := []struct { + name string + payload []byte + want *helpers.ResourceRequest + wantErr bool + }{ + { + name: "empty payload", + payload: []byte{}, + want: nil, + wantErr: false, + }, + { + name: "nil payload", + payload: nil, + want: nil, + wantErr: false, + }, + { + name: "valid minimal", + payload: []byte{4, 8, 1, 0}, + want: &helpers.ResourceRequest{CPU: 4, MemoryGB: 8, GPU: 1, GPUMemory: ""}, + wantErr: false, + }, + { + name: "valid with gpu memory", + payload: []byte{8, 16, 2, 4, '8', 'G', 'B', '!'}, + want: &helpers.ResourceRequest{CPU: 8, MemoryGB: 16, GPU: 2, GPUMemory: "8GB!"}, + wantErr: false, + }, + { + name: "too short", + payload: []byte{1, 2}, + want: nil, + wantErr: true, + }, + { + name: "invalid gpu mem length", + payload: []byte{1, 2, 1, 10, 'a'}, + want: nil, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := helpers.ParseResourceRequest(tt.payload) + if (err != nil) != tt.wantErr { + t.Errorf("ParseResourceRequest() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.want == nil { + if got != nil { + t.Errorf("ParseResourceRequest() = %v, want nil", got) + } + } else if got == nil { + t.Errorf("ParseResourceRequest() = nil, want %v", tt.want) + } else { + if got.CPU != tt.want.CPU { + t.Errorf("CPU = %d, want %d", got.CPU, tt.want.CPU) + } + if got.MemoryGB != tt.want.MemoryGB { + t.Errorf("MemoryGB = %d, want %d", got.MemoryGB, tt.want.MemoryGB) + } + if got.GPU != tt.want.GPU { + t.Errorf("GPU = %d, want %d", got.GPU, tt.want.GPU) + } + if got.GPUMemory != tt.want.GPUMemory { + t.Errorf("GPUMemory = %q, want %q", got.GPUMemory, tt.want.GPUMemory) + } + } + }) + } +} + +func TestMarshalJSONOrEmpty(t *testing.T) { + tests := []struct { + name string + data interface{} + want []byte + }{ + { + name: "simple map", + data: map[string]string{"key": "value"}, + want: []byte(`{"key":"value"}`), + }, + { + name: "string slice", + data: []string{"a", "b", "c"}, + want: []byte(`["a","b","c"]`), + }, + { + name: "empty slice", + data: []int{}, + want: []byte(`[]`), + }, + { + name: "nil", + data: nil, + want: []byte(`null`), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := helpers.MarshalJSONOrEmpty(tt.data) + // For valid JSON, compare strings since JSON formatting might vary + gotStr := string(got) + wantStr := string(tt.want) + if gotStr != wantStr { + t.Errorf("MarshalJSONOrEmpty() = %q, want %q", gotStr, wantStr) + } + }) + } +} + +func TestMarshalJSONOrEmpty_ErrorCase(t *testing.T) { + // Test with a value that can't be marshaled (function) + got := helpers.MarshalJSONOrEmpty(func() {}) + want := []byte("[]") + if string(got) != string(want) { + t.Errorf("MarshalJSONOrEmpty() with invalid data = %q, want %q", string(got), string(want)) + } +} + +func TestMarshalJSONBytes(t *testing.T) { + data := map[string]int{"count": 42} + got, err := helpers.MarshalJSONBytes(data) + if err != nil { + t.Errorf("MarshalJSONBytes() unexpected error: %v", err) + } + want := `{"count":42}` + if string(got) != want { + t.Errorf("MarshalJSONBytes() = %q, want %q", string(got), want) + } +} + +func TestIsEmptyJSON(t *testing.T) { + tests := []struct { + name string + data []byte + want bool + }{ + {"empty", []byte{}, true}, + {"null", []byte("null"), true}, + {"empty array", []byte("[]"), true}, + {"empty object", []byte("{}"), true}, + {"whitespace", []byte(" "), true}, + {"data", []byte(`{"key":"value"}`), false}, + {"non-empty array", []byte("[1,2,3]"), false}, + {"null with spaces", []byte(" null "), true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := helpers.IsEmptyJSON(tt.data); got != tt.want { + t.Errorf("IsEmptyJSON() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/tests/unit/api/helpers/validation_helpers_test.go b/tests/unit/api/helpers/validation_helpers_test.go new file mode 100644 index 0000000..4a62b71 --- /dev/null +++ b/tests/unit/api/helpers/validation_helpers_test.go @@ -0,0 +1,486 @@ +package helpers + +import ( + "testing" + + "github.com/jfraeys/fetch_ml/internal/api/helpers" + "github.com/jfraeys/fetch_ml/internal/manifest" + "github.com/jfraeys/fetch_ml/internal/queue" +) + +func TestValidateCommitIDFormat(t *testing.T) { + tests := []struct { + name string + commitID string + wantOk bool + wantErr string + }{ + { + name: "valid commit ID", + commitID: "aabbccddeeff00112233445566778899aabbccdd", + wantOk: true, + wantErr: "", + }, + { + name: "too short", + commitID: "aabbcc", + wantOk: false, + wantErr: "invalid commit_id length", + }, + { + name: "too long", + commitID: "aabbccddeeff00112233445566778899aabbccddeeff", + wantOk: false, + wantErr: "invalid commit_id length", + }, + { + name: "invalid hex", + commitID: "gggggggggggggggggggggggggggggggggggggggg", + wantOk: false, + wantErr: "invalid commit_id hex", + }, + { + name: "mixed case valid", + commitID: "AABBCCDDEEFF00112233445566778899AABBCCDD", + wantOk: true, + wantErr: "", + }, + { + name: "empty", + commitID: "", + wantOk: false, + wantErr: "invalid commit_id length", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotOk, gotErr := helpers.ValidateCommitIDFormat(tt.commitID) + if gotOk != tt.wantOk { + t.Errorf("ValidateCommitIDFormat() ok = %v, want %v", gotOk, tt.wantOk) + } + if gotErr != tt.wantErr { + t.Errorf("ValidateCommitIDFormat() err = %q, want %q", gotErr, tt.wantErr) + } + }) + } +} + +func TestShouldRequireRunManifest(t *testing.T) { + tests := []struct { + name string + status string + want bool + }{ + {"running", "running", true}, + {"completed", "completed", true}, + {"failed", "failed", true}, + {"queued", "queued", false}, + {"pending", "pending", false}, + {"cancelled", "cancelled", false}, + {"unknown", "unknown", false}, + {"empty", "", false}, + {"RUNNING uppercase", "RUNNING", true}, + {"Completed mixed", "Completed", true}, + {" Running with space", " running", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + task := &queue.Task{Status: tt.status} + got := helpers.ShouldRequireRunManifest(task) + if got != tt.want { + t.Errorf("ShouldRequireRunManifest() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestExpectedRunManifestBucketForStatus(t *testing.T) { + tests := []struct { + name string + status string + wantBucket string + wantOk bool + }{ + {"queued", "queued", "pending", true}, + {"pending", "pending", "pending", true}, + {"running", "running", "running", true}, + {"completed", "completed", "finished", true}, + {"finished", "finished", "finished", true}, + {"failed", "failed", "failed", true}, + {"unknown", "unknown", "", false}, + {"empty", "", "", false}, + {"RUNNING uppercase", "RUNNING", "running", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotBucket, gotOk := helpers.ExpectedRunManifestBucketForStatus(tt.status) + if gotBucket != tt.wantBucket { + t.Errorf("ExpectedRunManifestBucketForStatus() bucket = %q, want %q", gotBucket, tt.wantBucket) + } + if gotOk != tt.wantOk { + t.Errorf("ExpectedRunManifestBucketForStatus() ok = %v, want %v", gotOk, tt.wantOk) + } + }) + } +} + +func TestValidateTaskIDMatch(t *testing.T) { + tests := []struct { + name string + rmTaskID string + expectedID string + wantOk bool + wantExpected string + wantActual string + }{ + { + name: "match", + rmTaskID: "task-123", + expectedID: "task-123", + wantOk: true, + wantExpected: "task-123", + wantActual: "task-123", + }, + { + name: "mismatch", + rmTaskID: "task-123", + expectedID: "task-456", + wantOk: false, + wantExpected: "task-456", + wantActual: "task-123", + }, + { + name: "empty rm task ID", + rmTaskID: "", + expectedID: "task-123", + wantOk: false, + wantExpected: "task-123", + wantActual: "", + }, + { + name: "whitespace trimmed - match", + rmTaskID: "task-123", + expectedID: "task-123", + wantOk: true, + wantExpected: "task-123", + wantActual: "task-123", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rm := &manifest.RunManifest{TaskID: tt.rmTaskID} + got := helpers.ValidateTaskIDMatch(rm, tt.expectedID) + if got.OK != tt.wantOk { + t.Errorf("ValidateTaskIDMatch() OK = %v, want %v", got.OK, tt.wantOk) + } + if got.Expected != tt.wantExpected { + t.Errorf("ValidateTaskIDMatch() Expected = %q, want %q", got.Expected, tt.wantExpected) + } + if got.Actual != tt.wantActual { + t.Errorf("ValidateTaskIDMatch() Actual = %q, want %q", got.Actual, tt.wantActual) + } + }) + } +} + +func TestValidateCommitIDMatch(t *testing.T) { + tests := []struct { + name string + rmCommitID string + expectedID string + wantOk bool + wantExpected string + wantActual string + }{ + { + name: "both empty", + rmCommitID: "", + expectedID: "", + wantOk: true, + wantExpected: "", + wantActual: "", + }, + { + name: "match", + rmCommitID: "abc123", + expectedID: "abc123", + wantOk: true, + wantExpected: "abc123", + wantActual: "abc123", + }, + { + name: "mismatch", + rmCommitID: "abc123", + expectedID: "def456", + wantOk: false, + wantExpected: "def456", + wantActual: "abc123", + }, + { + name: "expected empty", + rmCommitID: "abc123", + expectedID: "", + wantOk: true, + wantExpected: "", + wantActual: "", + }, + { + name: "rm empty", + rmCommitID: "", + expectedID: "abc123", + wantOk: true, + wantExpected: "abc123", + wantActual: "", + }, + { + name: "whitespace trimmed", + rmCommitID: " abc123 ", + expectedID: "abc123", + wantOk: true, + wantExpected: "abc123", + wantActual: "abc123", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := helpers.ValidateCommitIDMatch(tt.rmCommitID, tt.expectedID) + if got.OK != tt.wantOk { + t.Errorf("ValidateCommitIDMatch() OK = %v, want %v", got.OK, tt.wantOk) + } + if got.Expected != tt.wantExpected { + t.Errorf("ValidateCommitIDMatch() Expected = %q, want %q", got.Expected, tt.wantExpected) + } + if got.Actual != tt.wantActual { + t.Errorf("ValidateCommitIDMatch() Actual = %q, want %q", got.Actual, tt.wantActual) + } + }) + } +} + +func TestValidateDepsProvenance(t *testing.T) { + tests := []struct { + name string + wantName string + wantSHA string + gotName string + gotSHA string + wantOk bool + wantExpected string + wantActual string + }{ + { + name: "match", + wantName: "requirements.txt", + wantSHA: "abc123", + gotName: "requirements.txt", + gotSHA: "abc123", + wantOk: true, + wantExpected: "requirements.txt:abc123", + wantActual: "requirements.txt:abc123", + }, + { + name: "name mismatch", + wantName: "requirements.txt", + wantSHA: "abc123", + gotName: "Pipfile", + gotSHA: "abc123", + wantOk: false, + wantExpected: "requirements.txt:abc123", + wantActual: "Pipfile:abc123", + }, + { + name: "sha mismatch", + wantName: "requirements.txt", + wantSHA: "abc123", + gotName: "requirements.txt", + gotSHA: "def456", + wantOk: false, + wantExpected: "requirements.txt:abc123", + wantActual: "requirements.txt:def456", + }, + { + name: "want empty", + wantName: "", + wantSHA: "", + gotName: "requirements.txt", + gotSHA: "abc123", + wantOk: true, + wantExpected: "", + wantActual: "", + }, + { + name: "got empty", + wantName: "requirements.txt", + wantSHA: "abc123", + gotName: "", + gotSHA: "", + wantOk: true, + wantExpected: "", + wantActual: "", + }, + { + name: "both empty", + wantName: "", + wantSHA: "", + gotName: "", + gotSHA: "", + wantOk: true, + wantExpected: "", + wantActual: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := helpers.ValidateDepsProvenance(tt.wantName, tt.wantSHA, tt.gotName, tt.gotSHA) + if got.OK != tt.wantOk { + t.Errorf("ValidateDepsProvenance() OK = %v, want %v", got.OK, tt.wantOk) + } + if got.Expected != tt.wantExpected { + t.Errorf("ValidateDepsProvenance() Expected = %q, want %q", got.Expected, tt.wantExpected) + } + if got.Actual != tt.wantActual { + t.Errorf("ValidateDepsProvenance() Actual = %q, want %q", got.Actual, tt.wantActual) + } + }) + } +} + +func TestValidateSnapshotID(t *testing.T) { + tests := []struct { + name string + wantID string + gotID string + wantOk bool + wantExpected string + wantActual string + }{ + { + name: "match", + wantID: "snap-123", + gotID: "snap-123", + wantOk: true, + wantExpected: "snap-123", + wantActual: "snap-123", + }, + { + name: "mismatch", + wantID: "snap-123", + gotID: "snap-456", + wantOk: false, + wantExpected: "snap-123", + wantActual: "snap-456", + }, + { + name: "want empty", + wantID: "", + gotID: "snap-123", + wantOk: true, + wantExpected: "", + wantActual: "snap-123", + }, + { + name: "got empty", + wantID: "snap-123", + gotID: "", + wantOk: true, + wantExpected: "snap-123", + wantActual: "", + }, + { + name: "both empty", + wantID: "", + gotID: "", + wantOk: true, + wantExpected: "", + wantActual: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := helpers.ValidateSnapshotID(tt.wantID, tt.gotID) + if got.OK != tt.wantOk { + t.Errorf("ValidateSnapshotID() OK = %v, want %v", got.OK, tt.wantOk) + } + if got.Expected != tt.wantExpected { + t.Errorf("ValidateSnapshotID() Expected = %q, want %q", got.Expected, tt.wantExpected) + } + if got.Actual != tt.wantActual { + t.Errorf("ValidateSnapshotID() Actual = %q, want %q", got.Actual, tt.wantActual) + } + }) + } +} + +func TestValidateSnapshotSHA(t *testing.T) { + tests := []struct { + name string + wantSHA string + gotSHA string + wantOk bool + wantExpected string + wantActual string + }{ + { + name: "match", + wantSHA: "sha256:abc123", + gotSHA: "sha256:abc123", + wantOk: true, + wantExpected: "sha256:abc123", + wantActual: "sha256:abc123", + }, + { + name: "mismatch", + wantSHA: "sha256:abc123", + gotSHA: "sha256:def456", + wantOk: false, + wantExpected: "sha256:abc123", + wantActual: "sha256:def456", + }, + { + name: "want empty", + wantSHA: "", + gotSHA: "sha256:abc123", + wantOk: true, + wantExpected: "", + wantActual: "sha256:abc123", + }, + { + name: "got empty", + wantSHA: "sha256:abc123", + gotSHA: "", + wantOk: true, + wantExpected: "sha256:abc123", + wantActual: "", + }, + { + name: "both empty", + wantSHA: "", + gotSHA: "", + wantOk: true, + wantExpected: "", + wantActual: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := helpers.ValidateSnapshotSHA(tt.wantSHA, tt.gotSHA) + if got.OK != tt.wantOk { + t.Errorf("ValidateSnapshotSHA() OK = %v, want %v", got.OK, tt.wantOk) + } + if got.Expected != tt.wantExpected { + t.Errorf("ValidateSnapshotSHA() Expected = %q, want %q", got.Expected, tt.wantExpected) + } + if got.Actual != tt.wantActual { + t.Errorf("ValidateSnapshotSHA() Actual = %q, want %q", got.Actual, tt.wantActual) + } + }) + } +}