test: add comprehensive test coverage and command improvements

- Add logs and debug end-to-end tests
- Add test helper utilities
- Improve test fixtures and templates
- Update API server and config lint commands
- Add multi-user database initialization
This commit is contained in:
Jeremie Fraeys 2026-02-16 20:38:15 -05:00
parent b05470b30a
commit 7305e2bc21
No known key found for this signature in database
22 changed files with 3643 additions and 7 deletions

View file

@ -0,0 +1,125 @@
package benchmarks
import (
"os"
"path/filepath"
"testing"
"github.com/jfraeys/fetch_ml/internal/worker"
)
// BenchmarkArtifactScanGo measures the pure-Go artifact scanner, which is
// built on filepath.WalkDir.
func BenchmarkArtifactScanGo(b *testing.B) {
	root := b.TempDir()
	// Populate a representative artifact tree before timing starts.
	createTestArtifacts(b, root, 100)
	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		if _, err := worker.ScanArtifacts(root); err != nil {
			b.Fatal(err)
		}
	}
}
// BenchmarkArtifactScanNative measures the C++ platform-optimized traversal
// (fts on BSD, getdents64 on Linux, getattrlistbulk on macOS).
func BenchmarkArtifactScanNative(b *testing.B) {
	root := b.TempDir()
	// Populate a representative artifact tree before timing starts.
	createTestArtifacts(b, root, 100)
	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		if _, err := worker.ScanArtifactsNative(root); err != nil {
			b.Fatal(err)
		}
	}
}
// BenchmarkArtifactScanLarge compares both scanners against a larger tree
// (1000 files) via sub-benchmarks.
func BenchmarkArtifactScanLarge(b *testing.B) {
	root := b.TempDir()
	// One shared fixture for both sub-benchmarks.
	createTestArtifacts(b, root, 1000)
	b.Run("Go", func(b *testing.B) {
		b.ReportAllocs()
		for n := 0; n < b.N; n++ {
			if _, err := worker.ScanArtifacts(root); err != nil {
				b.Fatal(err)
			}
		}
	})
	b.Run("Native", func(b *testing.B) {
		b.ReportAllocs()
		for n := 0; n < b.N; n++ {
			if _, err := worker.ScanArtifactsNative(root); err != nil {
				b.Fatal(err)
			}
		}
	})
}
// createTestArtifacts creates a directory structure with test files
func createTestArtifacts(b testing.TB, root string, count int) {
b.Helper()
// Create nested directories
dirs := []string{
"outputs",
"outputs/models",
"outputs/checkpoints",
"logs",
"data",
}
for _, dir := range dirs {
if err := os.MkdirAll(filepath.Join(root, dir), 0750); err != nil {
b.Fatal(err)
}
}
// Create test files
for i := 0; i < count; i++ {
var path string
switch i % 5 {
case 0:
path = filepath.Join(root, "outputs", "model_"+string(rune('0'+i%10))+".pt")
case 1:
path = filepath.Join(root, "outputs", "models", "checkpoint_"+string(rune('0'+i%10))+".ckpt")
case 2:
path = filepath.Join(root, "outputs", "checkpoints", "epoch_"+string(rune('0'+i%10))+".pt")
case 3:
path = filepath.Join(root, "logs", "train_"+string(rune('0'+i%10))+".log")
case 4:
path = filepath.Join(root, "data", "batch_"+string(rune('0'+i%10))+".npy")
}
data := make([]byte, 1024*(i%10+1)) // Varying sizes 1KB-10KB
for j := range data {
data[j] = byte(i + j%256)
}
if err := os.WriteFile(path, data, 0640); err != nil {
b.Fatal(err)
}
}
// Create files that should be excluded
os.WriteFile(filepath.Join(root, "run_manifest.json"), []byte("{}"), 0640)
os.MkdirAll(filepath.Join(root, "code"), 0750)
os.WriteFile(filepath.Join(root, "code", "script.py"), []byte("# test"), 0640)
}

View file

@ -0,0 +1,175 @@
package benchmarks
import (
"testing"
"gopkg.in/yaml.v3"
)
// Sample server config YAML used by the parsing benchmarks below. The
// credential-looking values are placeholders, not real secrets.
const sampleServerConfig = `
base_path: /api/v1
data_dir: /data/fetchml
auth:
  type: jwt
  secret: "super-secret-key-for-benchmarking-only"
  token_ttl: 3600
server:
  address: :8080
  tls:
    enabled: true
    cert_file: /etc/ssl/certs/server.crt
    key_file: /etc/ssl/private/server.key
security:
  max_request_size: 1048576
  rate_limit: 100
  cors_origins:
    - https://app.fetchml.com
    - https://admin.fetchml.com
queue:
  backend: redis
  sqlite_path: /data/queue.db
  filesystem_path: /data/queue
  fallback_to_filesystem: true
redis:
  addr: localhost:6379
  password: ""
  db: 0
  pool_size: 50
database:
  driver: postgres
  dsn: postgres://user:pass@localhost/fetchml?sslmode=disable
  max_connections: 100
logging:
  level: info
  format: json
  output: stdout
resources:
  max_cpu_per_task: 8
  max_memory_per_task: 32
  max_gpu_per_task: 1
monitoring:
  enabled: true
  prometheus_port: 9090
`

// BenchmarkServerConfig is the root schema the benchmarks decode
// sampleServerConfig into. The layout mirrors the server configuration,
// duplicated here so the benchmark package stays self-contained.
type BenchmarkServerConfig struct {
	BasePath   string                    `yaml:"base_path"`
	DataDir    string                    `yaml:"data_dir"`
	Auth       BenchmarkAuthConfig       `yaml:"auth"`
	Server     BenchmarkServerSection    `yaml:"server"`
	Security   BenchmarkSecurityConfig   `yaml:"security"`
	Queue      BenchmarkQueueConfig      `yaml:"queue"`
	Redis      BenchmarkRedisConfig      `yaml:"redis"`
	Database   BenchmarkDatabaseConfig   `yaml:"database"`
	Logging    BenchmarkLoggingConfig    `yaml:"logging"`
	Resources  BenchmarkResourceConfig   `yaml:"resources"`
	Monitoring BenchmarkMonitoringConfig `yaml:"monitoring"`
}

// BenchmarkAuthConfig holds authentication settings: scheme type, signing
// secret, and token lifetime (TokenTTL; presumably seconds — TODO confirm
// against the production config).
type BenchmarkAuthConfig struct {
	Type     string `yaml:"type"`
	Secret   string `yaml:"secret"`
	TokenTTL int    `yaml:"token_ttl"`
}

// BenchmarkServerSection holds the listen address and TLS settings.
type BenchmarkServerSection struct {
	Address string             `yaml:"address"`
	TLS     BenchmarkTLSConfig `yaml:"tls"`
}

// BenchmarkTLSConfig holds TLS enablement and certificate/key file paths.
type BenchmarkTLSConfig struct {
	Enabled  bool   `yaml:"enabled"`
	CertFile string `yaml:"cert_file"`
	KeyFile  string `yaml:"key_file"`
}

// BenchmarkSecurityConfig holds request-size/rate limits and CORS origins.
type BenchmarkSecurityConfig struct {
	MaxRequestSize int      `yaml:"max_request_size"`
	RateLimit      int      `yaml:"rate_limit"`
	CORSOrigins    []string `yaml:"cors_origins"`
}

// BenchmarkQueueConfig selects the queue backend and its storage paths,
// with an optional filesystem fallback.
type BenchmarkQueueConfig struct {
	Backend              string `yaml:"backend"`
	SQLitePath           string `yaml:"sqlite_path"`
	FilesystemPath       string `yaml:"filesystem_path"`
	FallbackToFilesystem bool   `yaml:"fallback_to_filesystem"`
}

// BenchmarkRedisConfig holds Redis connection settings.
type BenchmarkRedisConfig struct {
	Addr     string `yaml:"addr"`
	Password string `yaml:"password"`
	DB       int    `yaml:"db"`
	PoolSize int    `yaml:"pool_size"`
}

// BenchmarkDatabaseConfig holds SQL database driver, DSN and pool size.
type BenchmarkDatabaseConfig struct {
	Driver         string `yaml:"driver"`
	DSN            string `yaml:"dsn"`
	MaxConnections int    `yaml:"max_connections"`
}

// BenchmarkLoggingConfig holds log level, format and output destination.
type BenchmarkLoggingConfig struct {
	Level  string `yaml:"level"`
	Format string `yaml:"format"`
	Output string `yaml:"output"`
}

// BenchmarkResourceConfig holds per-task resource ceilings.
type BenchmarkResourceConfig struct {
	MaxCPUPerTask    int `yaml:"max_cpu_per_task"`
	MaxMemoryPerTask int `yaml:"max_memory_per_task"`
	MaxGPUPerTask    int `yaml:"max_gpu_per_task"`
}

// BenchmarkMonitoringConfig holds monitoring enablement and the Prometheus
// scrape port.
type BenchmarkMonitoringConfig struct {
	Enabled        bool `yaml:"enabled"`
	PrometheusPort int  `yaml:"prometheus_port"`
}
// BenchmarkConfigYAMLUnmarshal measures decoding of the sample server config.
// Tier 3 C++ candidate: a fast binary config format could replace YAML here.
func BenchmarkConfigYAMLUnmarshal(b *testing.B) {
	raw := []byte(sampleServerConfig)
	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		var cfg BenchmarkServerConfig
		if err := yaml.Unmarshal(raw, &cfg); err != nil {
			b.Fatal(err)
		}
	}
}
// BenchmarkConfigYAMLUnmarshalLarge measures decoding of a larger document:
// the sample config extended with an extra nested section (not present in
// the target struct, so it is parsed and discarded).
func BenchmarkConfigYAMLUnmarshalLarge(b *testing.B) {
	raw := []byte(sampleServerConfig + `
extra_section:
  items:
    - id: item1
      name: "First Item"
      value: 100
    - id: item2
      name: "Second Item"
      value: 200
    - id: item3
      name: "Third Item"
      value: 300
`)
	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		var cfg BenchmarkServerConfig
		if err := yaml.Unmarshal(raw, &cfg); err != nil {
			b.Fatal(err)
		}
	}
}

View file

@ -0,0 +1,185 @@
package benchmarks
import (
"encoding/json"
"fmt"
"testing"
"time"
"github.com/jfraeys/fetch_ml/internal/queue"
)
// BenchmarkTaskJSONMarshal measures Task serialization to JSON.
// Tier 3 C++ candidate: zero-copy binary serialization.
func BenchmarkTaskJSONMarshal(b *testing.B) {
	task := createTestTask()
	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		if _, err := json.Marshal(task); err != nil {
			b.Fatal(err)
		}
	}
}
// BenchmarkTaskJSONUnmarshal measures Task deserialization from JSON.
// Tier 3 C++ candidate: memory-mapped binary format.
func BenchmarkTaskJSONUnmarshal(b *testing.B) {
	// Encode once up front; only the decode is timed.
	payload, err := json.Marshal(createTestTask())
	if err != nil {
		b.Fatal(err)
	}
	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		var decoded queue.Task
		if err := json.Unmarshal(payload, &decoded); err != nil {
			b.Fatal(err)
		}
	}
}
// BenchmarkTaskJSONRoundTrip measures a full encode/decode cycle of a Task.
func BenchmarkTaskJSONRoundTrip(b *testing.B) {
	task := createTestTask()
	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		payload, err := json.Marshal(task)
		if err != nil {
			b.Fatal(err)
		}
		var decoded queue.Task
		if err := json.Unmarshal(payload, &decoded); err != nil {
			b.Fatal(err)
		}
	}
}
// BenchmarkPrewarmStateJSONMarshal measures PrewarmState serialization, as
// used in worker prewarm coordination.
func BenchmarkPrewarmStateJSONMarshal(b *testing.B) {
	state := queue.PrewarmState{
		WorkerID:   "worker-12345",
		TaskID:     "task-67890",
		SnapshotID: "snap-abc123",
		StartedAt:  time.Now().UTC().Format(time.RFC3339),
		UpdatedAt:  time.Now().UTC().Format(time.RFC3339),
		Phase:      "running",
		EnvImage:   "fetchml/python:3.11",
		DatasetCnt: 3,
		EnvHit:     5,
		EnvMiss:    1,
		EnvBuilt:   2,
		EnvTimeNs:  1500000000,
	}
	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		if _, err := json.Marshal(state); err != nil {
			b.Fatal(err)
		}
	}
}
// BenchmarkPrewarmStateJSONUnmarshal measures PrewarmState deserialization.
func BenchmarkPrewarmStateJSONUnmarshal(b *testing.B) {
	state := queue.PrewarmState{
		WorkerID:   "worker-12345",
		TaskID:     "task-67890",
		SnapshotID: "snap-abc123",
		StartedAt:  time.Now().UTC().Format(time.RFC3339),
		UpdatedAt:  time.Now().UTC().Format(time.RFC3339),
		Phase:      "running",
		EnvImage:   "fetchml/python:3.11",
		DatasetCnt: 3,
		EnvHit:     5,
		EnvMiss:    1,
		EnvBuilt:   2,
		EnvTimeNs:  1500000000,
	}
	// Encode once; only decoding is timed.
	payload, err := json.Marshal(state)
	if err != nil {
		b.Fatal(err)
	}
	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		var decoded queue.PrewarmState
		if err := json.Unmarshal(payload, &decoded); err != nil {
			b.Fatal(err)
		}
	}
}
// BenchmarkTaskBatchJSONMarshal measures serializing a batch of 100 tasks one
// at a time, mirroring queue index operations over many tasks.
func BenchmarkTaskBatchJSONMarshal(b *testing.B) {
	batch := make([]queue.Task, 100)
	for i := range batch {
		batch[i] = createTestTaskWithID(i)
	}
	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		for _, task := range batch {
			if _, err := json.Marshal(task); err != nil {
				b.Fatal(err)
			}
		}
	}
}
// createTestTask returns the canonical benchmark Task fixture
// (createTestTaskWithID with index 0).
func createTestTask() queue.Task {
	return createTestTaskWithID(0)
}
// createTestTaskWithID builds a realistic Task fixture whose ID embeds the
// given index, with representative resource requests, dataset specs and
// metadata.
func createTestTaskWithID(id int) queue.Task {
	created := time.Now()
	return queue.Task{
		ID:        fmt.Sprintf("task-550e8400-e29b-41d4-a716-44665544%03d", id),
		JobName:   "ml-training-job",
		Args:      "--epochs=100 --batch_size=32 --learning_rate=0.001",
		Status:    "queued",
		Priority:  1000,
		CreatedAt: created,
		CreatedBy: "admin",
		UserID:    "user-550e8400-e29b-41d4-a716",
		Username:  "ml_researcher",
		DatasetSpecs: []queue.DatasetSpec{
			{Name: "training-data", Version: "v1.2.3", Checksum: "sha256:abc123", URI: "s3://bucket/data/train"},
			{Name: "validation-data", Version: "v1.0.0", Checksum: "sha256:def456", URI: "s3://bucket/data/val"},
		},
		Datasets: []string{"dataset1", "dataset2"},
		Metadata: map[string]string{
			"experiment_id": "exp-12345",
			"user_email":    "user@example.com",
			"model_type":    "transformer",
		},
		CPU:        4,
		MemoryGB:   16,
		GPU:        1,
		GPUMemory:  "24GB",
		MaxRetries: 3,
	}
}

View file

@ -0,0 +1,278 @@
package benchmarks
import (
"os"
"path/filepath"
"strconv"
"strings"
"testing"
"time"
)
// BenchmarkResolveWorkspacePath measures the workspace path canonicalization
// hot path (trim + clean + join). Tier 2 C++ candidate: SIMD string
// operations for path validation.
func BenchmarkResolveWorkspacePath(b *testing.B) {
	samples := []string{
		"my-workspace",
		"./relative/path",
		"/absolute/path/to/workspace",
		" trimmed-path ",
		"deep/nested/workspace/name",
	}
	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		// Inline mirror of the resolveWorkspacePath logic.
		candidate := strings.TrimSpace(samples[n%len(samples)])
		resolved := filepath.Clean(candidate)
		if !filepath.IsAbs(resolved) {
			resolved = filepath.Join("/data/active/workspaces", resolved)
		}
		_ = resolved
	}
}
// BenchmarkStringTrimSpace measures strings.TrimSpace, which
// service_manager.go calls heavily (55+ call sites).
func BenchmarkStringTrimSpace(b *testing.B) {
	samples := []string{
		" package-name ",
		"\t\n trimmed \r\n",
		"normal",
		" ",
		"",
		"pkg1==1.0.0",
	}
	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		_ = strings.TrimSpace(samples[n%len(samples)])
	}
}
// BenchmarkStringSplit measures strings.Split for package parsing, as used in
// parsePipList, parseCondaList and splitPackageList.
func BenchmarkStringSplit(b *testing.B) {
	samples := []string{
		"numpy,pandas,scikit-learn,tensorflow",
		"package==1.0.0",
		"single",
		"",
		"a,b,c,d,e,f,g,h,i,j",
	}
	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		_ = strings.Split(samples[n%len(samples)], ",")
	}
}
// BenchmarkStringHasPrefix measures strings.HasPrefix as used for path
// traversal detection and comment filtering.
func BenchmarkStringHasPrefix(b *testing.B) {
	inputs := []string{
		"../traversal",
		"normal/path",
		"# comment",
		"package==1.0",
		"..",
	}
	prefixes := []string{"..", "#", "package"}
	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		_ = strings.HasPrefix(inputs[n%len(inputs)], prefixes[n%len(prefixes)])
	}
}
// BenchmarkStringEqualFold measures case-insensitive comparison as used for
// blocked-package matching.
func BenchmarkStringEqualFold(b *testing.B) {
	cases := []struct {
		left, right string
	}{
		{"numpy", "NUMPY"},
		{"requests", "requests"},
		{"TensorFlow", "tensorflow"},
		{"scikit-learn", "SCIKIT-LEARN"},
		{"off", "OFF"},
	}
	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		c := cases[n%len(cases)]
		_ = strings.EqualFold(c.left, c.right)
	}
}
// BenchmarkFilepathClean measures path canonicalization as used in
// resolveWorkspacePath and other path operations.
func BenchmarkFilepathClean(b *testing.B) {
	samples := []string{
		"./relative/../path",
		"/absolute//double//slashes",
		"../../../etc/passwd",
		"normal/path",
		".",
		"..",
	}
	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		_ = filepath.Clean(samples[n%len(samples)])
	}
}
// BenchmarkFilepathJoin measures path joining, used extensively for building
// workspace, trash and state paths.
func BenchmarkFilepathJoin(b *testing.B) {
	segments := [][]string{
		{"/data", "active", "workspaces"},
		{"/tmp", "trash", "jupyter"},
		{"state", "fetch_ml_jupyter_workspaces.json"},
		{"workspace", "subdir", "file.txt"},
	}
	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		_ = filepath.Join(segments[n%len(segments)]...)
	}
}
// BenchmarkStrconvAtoi measures string-to-int conversion, used for port
// parsing and resource limits. Inputs include invalid/out-of-range values.
func BenchmarkStrconvAtoi(b *testing.B) {
	samples := []string{
		"8888",
		"0",
		"65535",
		"-1",
		"999999",
	}
	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		_, _ = strconv.Atoi(samples[n%len(samples)])
	}
}
// BenchmarkTimeFormat measures timestamp formatting as used for trash
// destination naming.
func BenchmarkTimeFormat(b *testing.B) {
	stamp := time.Now().UTC()
	const layout = "20060102_150405"
	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		_ = stamp.Format(layout)
	}
}
// BenchmarkSprintfConcat measures building destination names via the "+"
// operator. (Despite the name, no fmt.Sprintf is involved — this is the
// operator-based concatenation path.)
func BenchmarkSprintfConcat(b *testing.B) {
	names := []string{"workspace1", "my-project", "test-ws"}
	stamps := []string{"20240115_143022", "20240212_183045", "20240301_120000"}
	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		_ = names[n%len(names)] + "_" + stamps[n%len(stamps)]
	}
}
// BenchmarkPackageListParsing measures the full package-list parsing pipeline
// (Split + TrimSpace + empty filtering), mirroring splitPackageList.
func BenchmarkPackageListParsing(b *testing.B) {
	samples := []string{
		"numpy, pandas, scikit-learn, tensorflow",
		" requests , urllib3 , certifi ",
		"single-package",
		"",
		"a, b, c, d, e, f, g",
	}
	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		pieces := strings.Split(samples[n%len(samples)], ",")
		kept := make([]string, 0, len(pieces))
		for _, piece := range pieces {
			if trimmed := strings.TrimSpace(piece); trimmed != "" {
				kept = append(kept, trimmed)
			}
		}
		_ = kept
	}
}
// BenchmarkEnvLookup profiles environment variable lookups as used for
// FETCHML_JUPYTER_* configuration. Keys cover both hit and miss paths.
//
// Improvement: use b.Setenv instead of os.Setenv so the variables are
// restored automatically when the benchmark finishes, and a failure to set
// them fails the benchmark instead of being silently ignored.
func BenchmarkEnvLookup(b *testing.B) {
	// Pre-set test env vars (auto-unset when the benchmark ends).
	b.Setenv("TEST_JUPYTER_STATE_DIR", "/tmp/jupyter-state")
	b.Setenv("TEST_JUPYTER_WORKSPACE_BASE", "/tmp/workspaces")
	keys := []string{
		"TEST_JUPYTER_STATE_DIR",
		"TEST_JUPYTER_WORKSPACE_BASE",
		"NONEXISTENT_VAR", // exercises the miss path
	}
	b.ResetTimer()
	b.ReportAllocs()
	for i := 0; i < b.N; i++ {
		_, _ = os.LookupEnv(keys[i%len(keys)])
	}
}
// BenchmarkCombinedJupyterHotPath measures a typical service-manager sequence
// end to end: workspace resolution followed by trash-destination naming.
func BenchmarkCombinedJupyterHotPath(b *testing.B) {
	workspaces := []string{
		" my-project ",
		"./relative-ws",
		"/absolute/path/to/workspace",
		"deep/nested/name",
	}
	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		// resolveWorkspacePath equivalent: trim, clean, anchor if relative.
		name := strings.TrimSpace(workspaces[n%len(workspaces)])
		resolved := filepath.Clean(name)
		if !filepath.IsAbs(resolved) {
			resolved = filepath.Join("/data/active/workspaces", resolved)
		}
		// Trash destination naming (uses the trimmed name, not the cleaned path).
		stamp := time.Now().UTC().Format("20060102_150405")
		_ = filepath.Join("/tmp/trash/jupyter", name+"_"+stamp)
	}
}

View file

@ -0,0 +1,84 @@
package benchmarks
import (
"testing"
"github.com/jfraeys/fetch_ml/internal/logging"
)
// BenchmarkLogSanitizeMessage profiles log message sanitization.
// Tier 1 C++ candidate because:
//   - regex matching is CPU-intensive
//   - high-volume log pipelines process thousands of messages/sec
//   - C++ can use Hyperscan/RE2 for parallel regex matching
//
// Expected speedup: 3-5x for high-volume logging.
func BenchmarkLogSanitizeMessage(b *testing.B) {
	// Messages covering the sensitive-data patterns the sanitizer targets,
	// plus one with nothing to redact.
	samples := []string{
		"User login successful with api_key=abc123def45678901234567890abcdef",
		"JWT token: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.dozjgNryP4J3jVmNHl0w5N_XgL0n3I9PlFUP0C12LVE",
		"Redis connection: redis://:secretpassword123@localhost:6379/0",
		"User admin password=supersecret123 trying to access resource",
		"Normal log message without any sensitive data to process",
		"API call with key=fedcba9876543210fedcba9876543210 and secret=shh123",
		"Connection string: redis://:another_secret@redis.example.com:6380",
		"Authentication token: eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.test.signature",
	}
	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		_ = logging.SanitizeLogMessage(samples[n%len(samples)])
	}
}
// BenchmarkLogSanitizeArgs profiles structured log argument sanitization,
// which scans key-value pairs for sensitive field names.
func BenchmarkLogSanitizeArgs(b *testing.B) {
	// Typical structured-log key/value pairs, mixing sensitive and benign keys.
	kv := []any{
		"user_id", "user123",
		"password", "secret123",
		"api_key", "abcdef1234567890",
		"action", "login",
		"secret_token", "eyJhbGci...",
		"request_id", "req-12345",
		"database_url", "redis://:password@localhost:6379",
		"timestamp", "2024-01-01T00:00:00Z",
	}
	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		_ = logging.SanitizeArgs(kv)
	}
}
// BenchmarkLogSanitizeHighVolume simulates a high-throughput logging scenario
// (10K+ messages/sec) by sanitizing a batch of mixed messages per iteration.
func BenchmarkLogSanitizeHighVolume(b *testing.B) {
	batch := []string{
		"API request: POST /api/v1/jobs with api_key=abcdef1234567890abcdef1234567890",
		"User user123 authenticated with token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.payload.signature",
		"Database connection established to redis://:redact_me@localhost:6379",
		"Job execution started for job_name=test_job",
		"Error processing request: password=wrong_secret provided",
		"Metrics: cpu=45%, memory=2.5GB, gpu=0%",
		"Config loaded: secret_key=hidden_value123",
		"Webhook received with authorization=Bearer eyJ0eXAiOiJKV1Qi.test.sig",
	}
	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		for _, msg := range batch {
			_ = logging.SanitizeLogMessage(msg)
		}
	}
}

View file

@ -0,0 +1,40 @@
package benchmarks
import (
"testing"
"github.com/jfraeys/fetch_ml/internal/queue"
)
// BenchmarkNativeQueueBasic measures single-task AddTask throughput on the
// native queue backend. Skipped unless native libs are enabled.
func BenchmarkNativeQueueBasic(b *testing.B) {
	dir := b.TempDir()
	// Native path is optional; bail out unless explicitly enabled.
	if !queue.UseNativeQueue {
		b.Skip("Native queue not enabled (set FETCHML_NATIVE_LIBS=1)")
	}
	q, err := queue.NewNativeQueue(dir)
	if err != nil {
		b.Fatal(err)
	}
	defer q.Close()
	// One Task value is reused, rewriting its ID each iteration
	// (IDs cycle through ten digit suffixes).
	task := &queue.Task{
		ID:       "test-1",
		JobName:  "test-job",
		Priority: 100,
	}
	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		task.ID = "test-" + string(rune('0'+n%10))
		if err := q.AddTask(task); err != nil {
			b.Fatal(err)
		}
	}
}

View file

@ -0,0 +1,108 @@
package benchmarks
import (
"testing"
"github.com/jfraeys/fetch_ml/internal/queue"
)
// BenchmarkNativeQueueRebuildIndex profiles AddTask against the native binary
// queue index. Tier 1 C++ candidate (binary format vs JSON); expected ~5x
// speedup and ~99% allocation reduction.
func BenchmarkNativeQueueRebuildIndex(b *testing.B) {
	q, err := queue.NewNativeQueue(b.TempDir())
	if err != nil {
		b.Fatal(err)
	}
	defer q.Close()
	// Pre-populate the queue with 100 tasks.
	for i := 0; i < 100; i++ {
		seed := &queue.Task{
			ID:       "task-" + string(rune('0'+i/10)) + string(rune('0'+i%10)),
			JobName:  "job-" + string(rune('0'+i/10)),
			Priority: int64(100 - i),
		}
		if err := q.AddTask(seed); err != nil {
			b.Fatal(err)
		}
	}
	b.ReportAllocs()
	b.ResetTimer()
	// Time only the add path (native uses a binary index; no JSON rebuild).
	for n := 0; n < b.N; n++ {
		next := &queue.Task{
			ID:       "bench-task-" + string(rune('0'+n%10)),
			JobName:  "bench-job",
			Priority: int64(n),
		}
		if err := q.AddTask(next); err != nil {
			b.Fatal(err)
		}
	}
}
// BenchmarkNativeQueueClaimNext profiles peeking the next task from the
// native binary heap (no JSON parsing on this path).
func BenchmarkNativeQueueClaimNext(b *testing.B) {
	q, err := queue.NewNativeQueue(b.TempDir())
	if err != nil {
		b.Fatal(err)
	}
	defer q.Close()
	// Pre-populate the queue with 100 tasks.
	for i := 0; i < 100; i++ {
		seed := &queue.Task{
			ID:       "task-" + string(rune('0'+i/10)) + string(rune('0'+i%10)),
			JobName:  "job-" + string(rune('0'+i/10)),
			Priority: int64(100 - i),
		}
		if err := q.AddTask(seed); err != nil {
			b.Fatal(err)
		}
	}
	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		// Peek only; the error is deliberately ignored in the hot loop.
		_, _ = q.PeekNextTask()
	}
}
// BenchmarkNativeQueueGetAllTasks profiles a full task scan from the native
// binary index.
func BenchmarkNativeQueueGetAllTasks(b *testing.B) {
	q, err := queue.NewNativeQueue(b.TempDir())
	if err != nil {
		b.Fatal(err)
	}
	defer q.Close()
	// Pre-populate the queue with 100 tasks.
	for i := 0; i < 100; i++ {
		seed := &queue.Task{
			ID:       "task-" + string(rune('0'+i/10)) + string(rune('0'+i%10)),
			JobName:  "job-" + string(rune('0'+i/10)),
			Priority: int64(100 - i),
		}
		if err := q.AddTask(seed); err != nil {
			b.Fatal(err)
		}
	}
	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		if _, err := q.GetAllTasks(); err != nil {
			b.Fatal(err)
		}
	}
}

View file

@ -0,0 +1,189 @@
package benchmarks
import (
"archive/tar"
"bytes"
"compress/gzip"
"os"
"path/filepath"
"testing"
"github.com/jfraeys/fetch_ml/internal/worker"
)
// BenchmarkExtractTarGzGo profiles the sequential Go tar.gz extractor.
func BenchmarkExtractTarGzGo(b *testing.B) {
	root := b.TempDir()
	archive := createStreamingTestArchive(b, root, 100, 1024) // 100 files, 1KB each
	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		// Destinations cycle through ten directories; re-extraction simply
		// overwrites earlier contents.
		dst := filepath.Join(root, "extract_go_"+string(rune('0'+n%10)))
		if err := os.MkdirAll(dst, 0750); err != nil {
			b.Fatal(err)
		}
		if err := worker.ExtractTarGz(archive, dst); err != nil {
			b.Fatal(err)
		}
	}
}
// BenchmarkExtractTarGzNative profiles the C++ parallel decompression path
// (mmap + thread pool + O_DIRECT for large files).
func BenchmarkExtractTarGzNative(b *testing.B) {
	root := b.TempDir()
	archive := createStreamingTestArchive(b, root, 100, 1024)
	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		dst := filepath.Join(root, "extract_native_"+string(rune('0'+n%10)))
		if err := os.MkdirAll(dst, 0750); err != nil {
			b.Fatal(err)
		}
		if err := worker.ExtractTarGzNative(archive, dst); err != nil {
			b.Fatal(err)
		}
	}
}
// BenchmarkExtractTarGzSizes compares both extractors across archive sizes.
func BenchmarkExtractTarGzSizes(b *testing.B) {
	root := b.TempDir()
	cases := []struct {
		name     string
		files    int
		fileSize int
	}{
		{"Small", 10, 1024},    // ~10KB compressed
		{"Medium", 100, 10240}, // ~1MB compressed
		{"Large", 50, 102400},  // ~5MB compressed
	}
	for _, tc := range cases {
		b.Run(tc.name, func(b *testing.B) {
			archive := createStreamingTestArchive(b, root, tc.files, tc.fileSize)
			benchmarkBoth(b, archive, root)
		})
	}
}
// benchmarkBoth runs the Go and Native extractors as sub-benchmarks against
// the same archive, each extracting into its own set of destination dirs.
func benchmarkBoth(b *testing.B, archivePath, tmpDir string) {
	run := func(name, prefix string, extract func(string, string) error) {
		b.Run(name, func(b *testing.B) {
			b.ReportAllocs()
			for n := 0; n < b.N; n++ {
				dst := filepath.Join(tmpDir, prefix+string(rune('0'+n%10)))
				if err := os.MkdirAll(dst, 0750); err != nil {
					b.Fatal(err)
				}
				if err := extract(archivePath, dst); err != nil {
					b.Fatal(err)
				}
			}
		})
	}
	run("Go", "go_", worker.ExtractTarGz)
	run("Native", "native_", worker.ExtractTarGzNative)
}
// fileIndexString returns the decimal representation of a non-negative index.
// Local helper so this file's import set stays unchanged (no strconv).
func fileIndexString(n int) string {
	if n == 0 {
		return "0"
	}
	digits := make([]byte, 0, 6)
	for n > 0 {
		digits = append(digits, byte('0'+n%10))
		n /= 10
	}
	// Digits were collected least-significant first; reverse in place.
	for i, j := 0, len(digits)-1; i < j; i, j = i+1, j-1 {
		digits[i], digits[j] = digits[j], digits[i]
	}
	return string(digits)
}

// createStreamingTestArchive creates a tar.gz archive containing numFiles
// files of fileSize bytes each, spread across five subdirectories, and
// returns the archive path. Data/archive names embed numFiles so repeated
// calls with different sizes in the same tmpDir do not collide.
//
// Bug fix: file names were derived from string(rune('0'+i/10)), which maps
// many indices onto the same single character; with numFiles=10 every file
// landed on "file_0.bin" (5 distinct files total), and with numFiles=100
// only half the files survived overwriting. Names now embed the full index
// so the archive really contains numFiles entries.
func createStreamingTestArchive(b testing.TB, tmpDir string, numFiles, fileSize int) string {
	b.Helper()
	// Create the data directory with test files.
	dataDir := filepath.Join(tmpDir, "data_"+fileIndexString(numFiles))
	if err := os.MkdirAll(dataDir, 0750); err != nil {
		b.Fatal(err)
	}
	for i := 0; i < numFiles; i++ {
		subdir := filepath.Join(dataDir, "subdir_"+string(rune('0'+i%5)))
		if err := os.MkdirAll(subdir, 0750); err != nil {
			b.Fatal(err)
		}
		filename := filepath.Join(subdir, "file_"+fileIndexString(i)+".bin")
		data := make([]byte, fileSize)
		for j := range data {
			data[j] = byte((i + j) % 256)
		}
		if err := os.WriteFile(filename, data, 0640); err != nil {
			b.Fatal(err)
		}
	}
	// Create the tar.gz archive.
	archivePath := filepath.Join(tmpDir, "test_"+fileIndexString(numFiles)+".tar.gz")
	if err := createTarGzFromDir(dataDir, archivePath); err != nil {
		b.Fatal(err)
	}
	return archivePath
}
// createTarGzFromDir creates a tar.gz archive from a directory
func createTarGzFromDir(srcDir, dstPath string) error {
var buf bytes.Buffer
gw := gzip.NewWriter(&buf)
tw := tar.NewWriter(gw)
err := filepath.Walk(srcDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
rel, err := filepath.Rel(srcDir, path)
if err != nil {
return err
}
hdr, err := tar.FileInfoHeader(info, "")
if err != nil {
return err
}
hdr.Name = rel
if err := tw.WriteHeader(hdr); err != nil {
return err
}
if !info.IsDir() {
data, err := os.ReadFile(path)
if err != nil {
return err
}
if _, err := tw.Write(data); err != nil {
return err
}
}
return nil
})
if err != nil {
return err
}
if err := tw.Close(); err != nil {
return err
}
if err := gw.Close(); err != nil {
return err
}
return os.WriteFile(dstPath, buf.Bytes(), 0640)
}

View file

@ -14,6 +14,14 @@ import (
// ChaosTestSuite tests system resilience under various failure conditions
func TestChaosTestSuite(t *testing.T) {
// Check Redis availability at suite level and warn if not available
quickRedis := redis.NewClient(&redis.Options{Addr: "localhost:6379", DB: 6})
if err := quickRedis.Ping(context.Background()).Err(); err != nil {
t.Logf("WARNING: Redis not available at localhost:6379 - chaos tests will be skipped")
t.Logf(" To run these tests, start Redis: redis-server --port 6379")
}
quickRedis.Close()
// Tests that intentionally close/corrupt connections get their own resources
// to prevent cascading failures to subsequent subtests
@ -69,6 +77,7 @@ func TestChaosTestSuite(t *testing.T) {
rdb := setupChaosRedis(t)
if rdb == nil {
t.Skip("Redis not available for chaos tests")
return
}
defer func() { _ = rdb.Close() }()
@ -473,6 +482,7 @@ func setupChaosRedis(t *testing.T) *redis.Client {
ctx := context.Background()
if err := rdb.Ping(ctx).Err(); err != nil {
t.Logf("Skipping chaos test - Redis not available: %v", err)
t.Skipf("Redis not available for chaos tests: %v", err)
return nil
}

View file

@ -0,0 +1,50 @@
---
# Docker Compose configuration for logs and debug E2E tests
# Simplified version using pre-built golang image with source mount
services:
redis:
image: redis:7-alpine
ports:
- "6380:6379"
healthcheck:
test: ["CMD", "redis-cli", "ping"]
interval: 5s
timeout: 3s
retries: 5
api-server:
image: golang:1.25-bookworm
working_dir: /app
command: >
sh -c "
go build -o api-server ./cmd/api-server/main.go &&
./api-server --config /app/configs/api/dev.yaml
"
ports:
- "9102:9101"
environment:
- LOG_LEVEL=debug
- REDIS_ADDR=redis:6379
- FETCHML_NATIVE_LIBS=0
volumes:
- ../../:/app
- api-logs:/logs
- api-experiments:/data/experiments
- api-active:/data/active
- go-mod-cache:/go/pkg/mod
depends_on:
redis:
condition: service_healthy
healthcheck:
test: ["CMD", "wget", "-q", "--spider", "http://localhost:9101/health"]
interval: 5s
timeout: 3s
retries: 10
start_period: 30s
volumes:
api-logs:
api-experiments:
api-active:
go-mod-cache:

View file

@ -0,0 +1,590 @@
package tests
import (
"context"
"encoding/binary"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"os/exec"
"path/filepath"
"strings"
"testing"
"time"
"github.com/gorilla/websocket"
)
// Shared endpoints/constants for the logs and debug E2E suite.
const (
	logsDebugComposeFile = "docker-compose.logs-debug.yml"
	apiPort              = "9102" // host-side docker-compose port (container listens on 9101)
	apiHost              = "localhost"
)
// TestLogsDebugE2E exercises the logs and debug WebSocket APIs end-to-end
// against a docker-compose stack. Opt in with FETCH_ML_E2E_LOGS_DEBUG=1.
func TestLogsDebugE2E(t *testing.T) {
	if os.Getenv("FETCH_ML_E2E_LOGS_DEBUG") != "1" {
		t.Skip("Skipping LogsDebugE2E (set FETCH_ML_E2E_LOGS_DEBUG=1 to enable)")
	}
	composeFile := filepath.Join(".", logsDebugComposeFile)
	if _, err := os.Stat(composeFile); os.IsNotExist(err) {
		t.Skipf("Docker compose file not found: %s", composeFile)
	}
	// Both CLI tools must be present for the stack to come up.
	if _, err := exec.LookPath("docker"); err != nil {
		t.Skip("Docker not found in PATH")
	}
	if _, err := exec.LookPath("docker-compose"); err != nil {
		t.Skip("docker-compose not found in PATH")
	}
	t.Run("SetupAndLogsAPI", func(t *testing.T) {
		testLogsAPIWithDockerCompose(t, composeFile)
	})
	t.Run("SetupAndDebugAPI", func(t *testing.T) {
		testDebugAPIWithDockerCompose(t, composeFile)
	})
}
// testLogsAPIWithDockerCompose boots the compose stack, waits for it to
// report healthy, then runs the logs API checks against it.
func testLogsAPIWithDockerCompose(t *testing.T, composeFile string) {
	// Start from a clean slate.
	cleanupDockerCompose(t, composeFile)
	t.Log("Starting Docker Compose services for logs API test...")
	up := exec.Command("docker-compose", "-f", composeFile, "-p", "logsdebug-e2e", "up", "-d", "--build")
	out, err := up.CombinedOutput()
	if err != nil {
		t.Fatalf("Failed to start Docker Compose: %v\nOutput: %s", err, string(out))
	}
	// Tear the stack down once the test finishes.
	defer cleanupDockerCompose(t, composeFile)
	t.Log("Waiting for services to be healthy...")
	if !waitForServicesHealthy(t, "logsdebug-e2e", 90*time.Second) {
		t.Fatal("Services failed to become healthy")
	}
	// Give the WebSocket endpoint a moment after the health check passes.
	t.Log("Waiting additional 3 seconds for WebSocket to be ready...")
	time.Sleep(3 * time.Second)
	// Exercise the logs API.
	targetID := "ab"
	testGetLogsAPI(t, targetID)
	testStreamLogsAPI(t, targetID)
}
// testDebugAPIWithDockerCompose boots the compose stack, waits for it to
// report healthy, then attaches each supported debugger mode via the API.
func testDebugAPIWithDockerCompose(t *testing.T, composeFile string) {
	// Start from a clean slate.
	cleanupDockerCompose(t, composeFile)
	t.Log("Starting Docker Compose services for debug API test...")
	up := exec.Command("docker-compose", "-f", composeFile, "-p", "logsdebug-e2e", "up", "-d", "--build")
	out, err := up.CombinedOutput()
	if err != nil {
		t.Fatalf("Failed to start Docker Compose: %v\nOutput: %s", err, string(out))
	}
	// Tear the stack down once the test finishes.
	defer cleanupDockerCompose(t, composeFile)
	t.Log("Waiting for services to be healthy...")
	if !waitForServicesHealthy(t, "logsdebug-e2e", 90*time.Second) {
		t.Fatal("Services failed to become healthy")
	}
	// Give the WebSocket endpoint a moment after the health check passes.
	t.Log("Waiting additional 3 seconds for WebSocket to be ready...")
	time.Sleep(3 * time.Second)
	targetID := "test-job-456"
	for _, mode := range []string{"interactive", "gdb", "pdb"} {
		testAttachDebugAPI(t, targetID, mode)
	}
}
// connectWebSocketWithRetry dials wsURL up to maxRetries times with a linearly
// increasing backoff (500ms, 1s, 1.5s, ...) between attempts.
//
// On success it returns the open connection and the handshake response; the
// caller owns both. On failure it returns the last error and the last
// response (if any) so the caller can report handshake details.
func connectWebSocketWithRetry(t *testing.T, wsURL string, maxRetries int) (*websocket.Conn, *http.Response, error) {
	// Copy the default dialer by value instead of aliasing it:
	// websocket.DefaultDialer is a package-level *Dialer, so assigning
	// through the pointer and then setting HandshakeTimeout would silently
	// change the timeout for every other user of DefaultDialer in this
	// test binary.
	dialer := *websocket.DefaultDialer
	dialer.HandshakeTimeout = 10 * time.Second

	var (
		conn *websocket.Conn
		resp *http.Response
		err  error
	)
	for attempt := 1; attempt <= maxRetries; attempt++ {
		conn, resp, err = dialer.Dial(wsURL, nil)
		if err == nil {
			return conn, resp, nil
		}
		// Log handshake response details (if any) for debugging.
		if resp != nil {
			body, _ := io.ReadAll(resp.Body)
			resp.Body.Close()
			t.Logf("WebSocket connection attempt %d/%d failed: status=%d, body=%s", attempt, maxRetries, resp.StatusCode, string(body))
		} else {
			t.Logf("WebSocket connection attempt %d/%d failed: %v", attempt, maxRetries, err)
		}
		if attempt < maxRetries {
			delay := time.Duration(attempt) * 500 * time.Millisecond
			t.Logf("Waiting %v before retry...", delay)
			time.Sleep(delay)
		}
	}
	return nil, resp, err
}
// testGetLogsAPI exercises the get_logs opcode end to end: it checks plain
// HTTP reachability, opens a WebSocket, sends a binary get_logs request for
// targetID, and validates the JSON payload of the returned data packet.
// Assumes a server is reachable at apiHost:apiPort (started by the caller).
func testGetLogsAPI(t *testing.T, targetID string) {
	t.Logf("Testing get_logs API with target: %s", targetID)
	// First verify HTTP connectivity before attempting a WebSocket upgrade,
	// so a dead server fails fast with a clearer message.
	healthURL := fmt.Sprintf("http://%s:%s/health", apiHost, apiPort)
	resp, err := http.Get(healthURL)
	if err != nil {
		t.Fatalf("HTTP health check failed: %v", err)
	}
	resp.Body.Close()
	t.Logf("HTTP health check passed: status=%d", resp.StatusCode)
	// Probe the WebSocket endpoint with a plain HTTP GET purely for
	// diagnostics (should return 400 or upgrade required); failures here
	// are logged, not fatal.
	wsHTTPURL := fmt.Sprintf("http://%s:%s/ws", apiHost, apiPort)
	resp, err = http.Get(wsHTTPURL)
	if err != nil {
		t.Logf("HTTP GET to /ws failed: %v", err)
	} else {
		body, _ := io.ReadAll(resp.Body)
		resp.Body.Close()
		t.Logf("HTTP GET /ws: status=%d, body=%s", resp.StatusCode, string(body))
	}
	wsURL := fmt.Sprintf("ws://%s:%s/ws", apiHost, apiPort)
	// Connect to WebSocket with retry.
	conn, resp, err := connectWebSocketWithRetry(t, wsURL, 5)
	if resp != nil && resp.Body != nil {
		defer resp.Body.Close()
	}
	if err != nil {
		t.Fatalf("Failed to connect to WebSocket after retries: %v", err)
	}
	defer conn.Close()
	// Build the binary request frame for get_logs:
	// [opcode:1][api_key_hash:16][target_id_len:1][target_id:var]
	opcode := byte(0x20)           // OpcodeGetLogs
	apiKeyHash := make([]byte, 16) // Zero-filled for testing (no auth)
	targetIDBytes := []byte(targetID)
	message := []byte{opcode}
	message = append(message, apiKeyHash...)
	message = append(message, byte(len(targetIDBytes)))
	message = append(message, targetIDBytes...)
	t.Logf("Sending message: opcode=0x%02x, len=%d, payload_len=%d", opcode, len(message), len(message)-1)
	// Send the request frame.
	err = conn.WriteMessage(websocket.BinaryMessage, message)
	if err != nil {
		t.Fatalf("Failed to send get_logs message: %v", err)
	}
	// Read the response, bounded by a deadline so a silent server cannot
	// hang the test.
	conn.SetReadDeadline(time.Now().Add(5 * time.Second))
	messageType, response, err := conn.ReadMessage()
	if err != nil {
		t.Fatalf("Failed to read get_logs response: %v", err)
	}
	t.Logf("Received response type %d, length %d bytes", messageType, len(response))
	// Parse response.
	if len(response) < 1 {
		t.Fatal("Empty response received")
	}
	// Parse the packet using the protocol.
	packet, err := parseResponsePacket(response)
	if err != nil {
		t.Fatalf("Failed to parse response packet: %v", err)
	}
	// Verify response is a data packet (0x04 = PacketTypeData).
	if packet.PacketType != 0x04 {
		t.Errorf("Expected data packet (0x04), got 0x%02x", packet.PacketType)
	}
	// Decode the JSON payload carried by the data packet.
	var logsResponse struct {
		TargetID   string `json:"target_id"`
		Logs       string `json:"logs"`
		Truncated  bool   `json:"truncated"`
		TotalLines int    `json:"total_lines"`
	}
	if err := json.Unmarshal(packet.Payload, &logsResponse); err != nil {
		t.Fatalf("Failed to parse logs JSON response: %v", err)
	}
	// The response must echo the requested target and carry some log text.
	if logsResponse.TargetID != targetID {
		t.Errorf("Expected target_id %s, got %s", targetID, logsResponse.TargetID)
	}
	if logsResponse.Logs == "" {
		t.Error("Expected non-empty logs content")
	}
	t.Logf("Successfully received logs for target %s (%d lines, truncated=%v)",
		logsResponse.TargetID, logsResponse.TotalLines, logsResponse.Truncated)
}
// testStreamLogsAPI tests the stream_logs WebSocket endpoint: it sends a
// binary stream_logs request for targetID and verifies the server's initial
// data-packet acknowledgement reports streaming=true for that target.
// Assumes a server is reachable at apiHost:apiPort (started by the caller).
func testStreamLogsAPI(t *testing.T, targetID string) {
	t.Logf("Testing stream_logs API with target: %s", targetID)
	wsURL := fmt.Sprintf("ws://%s:%s/ws", apiHost, apiPort)
	// Connect to WebSocket with retry.
	conn, resp, err := connectWebSocketWithRetry(t, wsURL, 5)
	if resp != nil && resp.Body != nil {
		defer resp.Body.Close()
	}
	if err != nil {
		t.Fatalf("Failed to connect to WebSocket after retries: %v", err)
	}
	defer conn.Close()
	// Build the binary request frame for stream_logs:
	// [opcode:1][api_key_hash:16][target_id_len:1][target_id:var]
	opcode := byte(0x21)           // OpcodeStreamLogs
	apiKeyHash := make([]byte, 16) // zero-filled: no auth in this test
	targetIDBytes := []byte(targetID)
	message := []byte{opcode}
	message = append(message, apiKeyHash...)
	message = append(message, byte(len(targetIDBytes)))
	message = append(message, targetIDBytes...)
	// Send the request frame.
	err = conn.WriteMessage(websocket.BinaryMessage, message)
	if err != nil {
		t.Fatalf("Failed to send stream_logs message: %v", err)
	}
	// Read the acknowledgement, bounded by a deadline so a silent server
	// cannot hang the test.
	conn.SetReadDeadline(time.Now().Add(5 * time.Second))
	messageType, response, err := conn.ReadMessage()
	if err != nil {
		t.Fatalf("Failed to read stream_logs response: %v", err)
	}
	t.Logf("Received stream response type %d, length %d bytes", messageType, len(response))
	// Parse response.
	packet, err := parseResponsePacket(response)
	if err != nil {
		t.Fatalf("Failed to parse response packet: %v", err)
	}
	// The acknowledgement must be a data packet.
	if packet.PacketType != 0x04 { // PacketTypeData
		t.Errorf("Expected data packet (0x04), got 0x%02x", packet.PacketType)
	}
	var streamResponse struct {
		TargetID  string `json:"target_id"`
		Streaming bool   `json:"streaming"`
		Message   string `json:"message"`
	}
	if err := json.Unmarshal(packet.Payload, &streamResponse); err != nil {
		t.Fatalf("Failed to parse stream JSON response: %v", err)
	}
	if streamResponse.TargetID != targetID {
		t.Errorf("Expected target_id %s, got %s", targetID, streamResponse.TargetID)
	}
	if !streamResponse.Streaming {
		t.Error("Expected streaming=true in response")
	}
	t.Logf("Successfully initiated log streaming for target %s", streamResponse.TargetID)
}
// testAttachDebugAPI tests the attach_debug WebSocket endpoint: it sends a
// binary attach_debug request for targetID with the given debugType
// ("interactive", "gdb", "pdb", ...) and verifies the response echoes both.
// Assumes a server is reachable at apiHost:apiPort (started by the caller).
func testAttachDebugAPI(t *testing.T, targetID, debugType string) {
	t.Logf("Testing attach_debug API with target: %s, debug_type: %s", targetID, debugType)
	wsURL := fmt.Sprintf("ws://%s:%s/ws", apiHost, apiPort)
	// Connect to WebSocket with retry.
	conn, resp, err := connectWebSocketWithRetry(t, wsURL, 5)
	if resp != nil && resp.Body != nil {
		defer resp.Body.Close()
	}
	if err != nil {
		t.Fatalf("Failed to connect to WebSocket after retries: %v", err)
	}
	defer conn.Close()
	// Build the binary request frame for attach_debug:
	// [opcode:1][api_key_hash:16][target_id_len:1][target_id:var][debug_type:var]
	// Note: debug_type carries no length prefix — it occupies the remainder
	// of the frame.
	opcode := byte(0x22)           // OpcodeAttachDebug
	apiKeyHash := make([]byte, 16) // zero-filled: no auth in this test
	targetIDBytes := []byte(targetID)
	debugTypeBytes := []byte(debugType)
	message := []byte{opcode}
	message = append(message, apiKeyHash...)
	message = append(message, byte(len(targetIDBytes)))
	message = append(message, targetIDBytes...)
	message = append(message, debugTypeBytes...)
	// Send the request frame.
	err = conn.WriteMessage(websocket.BinaryMessage, message)
	if err != nil {
		t.Fatalf("Failed to send attach_debug message: %v", err)
	}
	// Read the response, bounded by a deadline so a silent server cannot
	// hang the test.
	conn.SetReadDeadline(time.Now().Add(5 * time.Second))
	messageType, response, err := conn.ReadMessage()
	if err != nil {
		t.Fatalf("Failed to read attach_debug response: %v", err)
	}
	t.Logf("Received debug response type %d, length %d bytes", messageType, len(response))
	// Parse response.
	packet, err := parseResponsePacket(response)
	if err != nil {
		t.Fatalf("Failed to parse response packet: %v", err)
	}
	// The response must be a data packet.
	if packet.PacketType != 0x04 { // PacketTypeData
		t.Errorf("Expected data packet (0x04), got 0x%02x", packet.PacketType)
	}
	var debugResponse struct {
		TargetID   string `json:"target_id"`
		DebugType  string `json:"debug_type"`
		Attached   bool   `json:"attached"`
		Message    string `json:"message"`
		Suggestion string `json:"suggestion"`
	}
	if err := json.Unmarshal(packet.Payload, &debugResponse); err != nil {
		t.Fatalf("Failed to parse debug JSON response: %v", err)
	}
	if debugResponse.TargetID != targetID {
		t.Errorf("Expected target_id %s, got %s", targetID, debugResponse.TargetID)
	}
	if debugResponse.DebugType != debugType {
		t.Errorf("Expected debug_type %s, got %s", debugType, debugResponse.DebugType)
	}
	// Note: attached is false in stub implementation, so it is logged but
	// not asserted.
	t.Logf("Debug attachment response: target=%s, type=%s, attached=%v",
		debugResponse.TargetID, debugResponse.DebugType, debugResponse.Attached)
}
// ResponsePacket represents a parsed WebSocket response packet.
type ResponsePacket struct {
	PacketType byte   // first byte on the wire: 0x00 success, 0x01 error, 0x04 data
	Payload    []byte // decoded payload bytes; nil when the packet carried none
}

// parseResponsePacket parses a binary WebSocket response packet.
//
// Wire layout: [packet_type:1][timestamp:8][type-specific payload].
//   - 0x00 (success): optional [len:uvarint][string]
//   - 0x01 (error):   [error_code:1][message:rest] — surfaced as an error
//   - 0x04 (data):    [type_len:uvarint][data_type][payload_len:uvarint][payload]
//   - anything else:  raw bytes after the timestamp returned unparsed
//
// All length fields are bounds-checked so a truncated or corrupt packet
// returns an error instead of panicking the test binary.
func parseResponsePacket(data []byte) (*ResponsePacket, error) {
	if len(data) < 9 { // min: type(1) + timestamp(8)
		return nil, fmt.Errorf("packet too short: %d bytes", len(data))
	}
	packetType := data[0]
	payloadStart := 9 // skip the type byte and the 8-byte timestamp

	switch packetType {
	case 0x00: // PacketTypeSuccess
		if len(data) <= payloadStart {
			return &ResponsePacket{PacketType: packetType}, nil
		}
		// Parse varint length + string.
		strLen, n := binary.Uvarint(data[payloadStart:])
		if n <= 0 {
			return nil, fmt.Errorf("invalid varint")
		}
		end := payloadStart + n + int(strLen)
		// Bounds check (also guards int overflow of a huge strLen): the
		// original slice expression panicked on truncated packets.
		if end < payloadStart+n || end > len(data) {
			return nil, fmt.Errorf("success packet payload incomplete")
		}
		return &ResponsePacket{
			PacketType: packetType,
			Payload:    data[payloadStart+n : end],
		}, nil

	case 0x01: // PacketTypeError
		if len(data) < payloadStart+1 {
			return nil, fmt.Errorf("error packet too short")
		}
		errorCode := data[payloadStart]
		msgStart := payloadStart + 1
		// Include the message whenever any bytes follow the error code.
		// (The previous condition also required len(data) > msgStart+1,
		// which silently dropped one-byte messages.)
		if len(data) > msgStart {
			return nil, fmt.Errorf("server error (code: 0x%02x): %s", errorCode, string(data[msgStart:]))
		}
		return nil, fmt.Errorf("server error (code: 0x%02x)", errorCode)

	case 0x04: // PacketTypeData
		// Format after timestamp: [data_type_len:varint][data_type][payload_len:varint][payload]
		if len(data) <= payloadStart {
			return &ResponsePacket{PacketType: packetType}, nil
		}
		// Read data_type length (varint).
		typeLen, n1 := binary.Uvarint(data[payloadStart:])
		if n1 <= 0 {
			return nil, fmt.Errorf("invalid type length varint")
		}
		// Skip the data_type string (overflow-safe).
		payloadLenStart := payloadStart + n1 + int(typeLen)
		if payloadLenStart < payloadStart+n1 || len(data) <= payloadLenStart {
			return &ResponsePacket{PacketType: packetType}, nil
		}
		// Read payload length (varint).
		payloadLen, n2 := binary.Uvarint(data[payloadLenStart:])
		if n2 <= 0 {
			return nil, fmt.Errorf("invalid payload length varint")
		}
		payloadDataStart := payloadLenStart + n2
		if len(data) < payloadDataStart+int(payloadLen) {
			return nil, fmt.Errorf("data packet payload incomplete")
		}
		return &ResponsePacket{
			PacketType: packetType,
			Payload:    data[payloadDataStart : payloadDataStart+int(payloadLen)],
		}, nil

	default:
		// Unknown packet type: hand back the raw bytes after the timestamp.
		if len(data) > payloadStart {
			return &ResponsePacket{
				PacketType: packetType,
				Payload:    data[payloadStart:],
			}, nil
		}
		return &ResponsePacket{PacketType: packetType}, nil
	}
}
// cleanupDockerCompose stops and removes the Docker Compose containers,
// networks, and volumes for the logsdebug-e2e project.
func cleanupDockerCompose(t *testing.T, composeFile string) {
	t.Log("Cleaning up Docker Compose...")
	down := exec.CommandContext(context.Background(),
		"docker-compose", "-f", composeFile, "-p", "logsdebug-e2e", "down", "--remove-orphans", "--volumes")
	down.Stdout = os.Stdout
	down.Stderr = os.Stderr
	if err := down.Run(); err != nil {
		// Best effort: a failed teardown is logged, never fatal.
		t.Logf("Warning: Failed to cleanup Docker Compose: %v", err)
	}
}
// waitForServicesHealthy polls `docker ps` once per second until every
// container in projectName reports a "(healthy)" status, or timeout elapses.
// Returns true when at least two containers (redis and api-server) are all
// healthy.
func waitForServicesHealthy(t *testing.T, projectName string, timeout time.Duration) bool {
	start := time.Now()
	for time.Since(start) < timeout {
		// Check container status.
		psCmd := exec.CommandContext(context.Background(),
			"docker", "ps", "--filter", fmt.Sprintf("name=%s", projectName), "--format", "{{.Names}}\t{{.Status}}")
		output, err := psCmd.CombinedOutput()
		if err == nil {
			status := string(output)
			t.Logf("Container status:\n%s", status)
			// Check if all containers are healthy (not just "Up").
			allHealthy := true
			containerCount := 0
			for _, line := range strings.Split(status, "\n") {
				if strings.TrimSpace(line) == "" {
					continue
				}
				containerCount++
				// Match "(healthy)" exactly: a bare Contains("healthy")
				// also matched "(unhealthy)", and "Up (health: starting)"
				// is still NOT healthy.
				if !strings.Contains(line, "(healthy)") {
					allHealthy = false
					break
				}
			}
			if allHealthy && containerCount >= 2 { // redis and api-server
				t.Log("All services are healthy")
				return true
			}
		}
		time.Sleep(1 * time.Second)
	}
	t.Logf("Timeout waiting for services after %v", timeout)
	return false
}
// TestLogsDebugHTTPHealthE2E tests the HTTP health endpoint for the
// logs/debug services. It is opt-in via FETCH_ML_E2E_LOGS_DEBUG=1 and skips
// (rather than fails) when no API server is reachable.
func TestLogsDebugHTTPHealthE2E(t *testing.T) {
	if os.Getenv("FETCH_ML_E2E_LOGS_DEBUG") != "1" {
		t.Skip("Skipping LogsDebugHTTPHealthE2E (set FETCH_ML_E2E_LOGS_DEBUG=1 to enable)")
	}

	// Check if API is already running.
	healthURL := fmt.Sprintf("http://%s:%s/health", apiHost, apiPort)
	client := &http.Client{Timeout: 5 * time.Second}
	resp, err := client.Get(healthURL)
	if err != nil {
		t.Skipf("API server not available at %s: %v", healthURL, err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		t.Errorf("Health check failed: status %d", resp.StatusCode)
	}

	// Decode the health response; a malformed body is logged, not fatal.
	var health struct {
		Status  string `json:"status"`
		Version string `json:"version"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&health); err != nil {
		t.Logf("Failed to parse health response: %v", err)
		return
	}
	t.Logf("API health: status=%s, version=%s", health.Status, health.Version)
}

View file

@ -0,0 +1,63 @@
package tests
import (
"testing"
"github.com/jfraeys/fetch_ml/internal/api/helpers"
"github.com/jfraeys/fetch_ml/internal/queue"
)
// TestDuplicateDetection verifies the duplicate detection logic
// (params-hash and dataset-ID computation behind the composite key).
func TestDuplicateDetection(t *testing.T) {
	commitID := "abc123def4567890abc1"

	// Test 1: Same args + same commit = duplicate (would be detected)
	argsA := "--epochs 10 --lr 0.001"
	argsB := "--epochs 10 --lr 0.001" // identical on purpose
	hashA := helpers.ComputeParamsHash(argsA)
	hashB := helpers.ComputeParamsHash(argsB)
	if hashA != hashB {
		t.Error("Same args should produce same hash")
	}
	t.Logf("✓ Same args produce same hash: %s", hashA)

	// Test 2: Different args = not duplicate
	argsC := "--epochs 20 --lr 0.01"
	hashC := helpers.ComputeParamsHash(argsC)
	if hashA == hashC {
		t.Error("Different args should produce different hashes")
	}
	t.Logf("✓ Different args produce different hashes: %s vs %s", hashA, hashC)

	// Test 3: Same dataset specs = same dataset_id
	specsA := []queue.DatasetSpec{{Name: "mnist", Checksum: "sha256:abc123"}}
	specsB := []queue.DatasetSpec{{Name: "mnist", Checksum: "sha256:abc123"}}
	idA := helpers.ComputeDatasetID(specsA, nil)
	idB := helpers.ComputeDatasetID(specsB, nil)
	if idA != idB {
		t.Error("Same dataset specs should produce same ID")
	}
	t.Logf("✓ Same dataset specs produce same ID: %s", idA)

	// Test 4: Different dataset = different ID
	specsC := []queue.DatasetSpec{{Name: "cifar10", Checksum: "sha256:def456"}}
	idC := helpers.ComputeDatasetID(specsC, nil)
	if idA == idC {
		t.Error("Different datasets should produce different IDs")
	}
	t.Logf("✓ Different datasets produce different IDs: %s vs %s", idA, idC)

	// Test 5: Composite key logic
	t.Log("\n=== Composite Key Detection ===")
	t.Logf("Job 1: commit=%s, dataset_id=%s, params_hash=%s", commitID, idA, hashA)
	t.Logf("Job 2: commit=%s, dataset_id=%s, params_hash=%s", commitID, idB, hashB)
	t.Log("→ These would be detected as DUPLICATE (same commit + dataset + params)")
	t.Logf("Job 3: commit=%s, dataset_id=%s, params_hash=%s", commitID, idA, hashC)
	t.Log("→ This would NOT be duplicate (different params)")
}

View file

@ -236,10 +236,10 @@ func TestWebSocketQueueWithSnapshotOpcode(t *testing.T) {
expMgr,
"",
taskQueue,
nil,
nil,
nil,
nil,
nil, // db
nil, // jupyterServiceMgr
nil, // securityConfig
nil, // auditLogger
)
server := httptest.NewServer(wsHandler)
defer server.Close()

View file

@ -65,9 +65,9 @@ func setupWSIntegrationServerWithDataDir(t *testing.T, dataDir string) (
dataDir,
tq,
db,
nil,
nil,
nil,
nil, // jupyterServiceMgr
nil, // securityConfig
nil, // auditLogger
)
server := httptest.NewServer(handler)
return server, tq, expManager, s, db

View file

@ -0,0 +1,95 @@
package api_test
import (
"testing"
"github.com/jfraeys/fetch_ml/internal/api/helpers"
"github.com/jfraeys/fetch_ml/internal/queue"
)
// TestDuplicateDetectionProcess demonstrates the duplicate detection process
// step by step: it walks four job submissions through the composite key
// (commit_id, dataset_id, params_hash) and asserts which pairs collide.
func TestDuplicateDetectionProcess(t *testing.T) {
	t.Log("=== Duplicate Detection Process Test ===")
	// Step 1: First job submission — establishes the baseline composite key.
	t.Log("\n1. First job submission:")
	commitID := "abc123def456"
	args1 := "--epochs 10 --lr 0.001"
	datasets := []queue.DatasetSpec{{Name: "mnist", Checksum: "sha256:abc123"}}
	datasetID1 := helpers.ComputeDatasetID(datasets, nil)
	paramsHash1 := helpers.ComputeParamsHash(args1)
	t.Logf("   Commit ID: %s", commitID)
	t.Logf("   Dataset ID: %s (computed from %d datasets)", datasetID1, len(datasets))
	t.Logf("   Params Hash: %s (computed from args: %s)", paramsHash1, args1)
	t.Logf("   Composite Key: (%s, %s, %s)", commitID, datasetID1, paramsHash1)
	// Step 2: Second job with SAME parameters (should be duplicate).
	t.Log("\n2. Second job submission (same params):")
	args2 := "--epochs 10 --lr 0.001"                                           // Same args
	datasets2 := []queue.DatasetSpec{{Name: "mnist", Checksum: "sha256:abc123"}} // Same dataset
	datasetID2 := helpers.ComputeDatasetID(datasets2, nil)
	paramsHash2 := helpers.ComputeParamsHash(args2)
	t.Logf("   Commit ID: %s", commitID)
	t.Logf("   Dataset ID: %s", datasetID2)
	t.Logf("   Params Hash: %s", paramsHash2)
	t.Logf("   Composite Key: (%s, %s, %s)", commitID, datasetID2, paramsHash2)
	// Verify they're the same — identical inputs must yield identical keys.
	if datasetID1 == datasetID2 && paramsHash1 == paramsHash2 {
		t.Log("   ✓ DUPLICATE DETECTED - same composite key!")
	} else {
		t.Error("   ✗ Should have been detected as duplicate")
	}
	// Step 3: Third job with DIFFERENT parameters (not duplicate).
	t.Log("\n3. Third job submission (different params):")
	args3 := "--epochs 20 --lr 0.01"                                            // Different args
	datasets3 := []queue.DatasetSpec{{Name: "mnist", Checksum: "sha256:abc123"}} // Same dataset
	datasetID3 := helpers.ComputeDatasetID(datasets3, nil)
	paramsHash3 := helpers.ComputeParamsHash(args3)
	t.Logf("   Commit ID: %s", commitID)
	t.Logf("   Dataset ID: %s", datasetID3)
	t.Logf("   Params Hash: %s", paramsHash3)
	t.Logf("   Composite Key: (%s, %s, %s)", commitID, datasetID3, paramsHash3)
	// Verify they're different — changed args must change the params hash.
	if paramsHash1 != paramsHash3 {
		t.Log("   ✓ NOT A DUPLICATE - different params_hash")
	} else {
		t.Error("   ✗ Should have different params_hash")
	}
	// Step 4: Fourth job with DIFFERENT dataset (not duplicate).
	t.Log("\n4. Fourth job submission (different dataset):")
	args4 := "--epochs 10 --lr 0.001"                                             // Same args
	datasets4 := []queue.DatasetSpec{{Name: "cifar10", Checksum: "sha256:def456"}} // Different dataset
	datasetID4 := helpers.ComputeDatasetID(datasets4, nil)
	paramsHash4 := helpers.ComputeParamsHash(args4)
	t.Logf("   Commit ID: %s", commitID)
	t.Logf("   Dataset ID: %s", datasetID4)
	t.Logf("   Params Hash: %s", paramsHash4)
	t.Logf("   Composite Key: (%s, %s, %s)", commitID, datasetID4, paramsHash4)
	// Verify they're different — changed dataset must change the dataset ID.
	if datasetID1 != datasetID4 {
		t.Log("   ✓ NOT A DUPLICATE - different dataset_id")
	} else {
		t.Error("   ✗ Should have different dataset_id")
	}
	// Step 5: Summary.
	t.Log("\n5. Summary:")
	t.Log("   - Jobs 1 & 2: Same commit_id + dataset_id + params_hash = DUPLICATE")
	t.Log("   - Job 3: Different params_hash = NOT DUPLICATE")
	t.Log("   - Job 4: Different dataset_id = NOT DUPLICATE")
	t.Log("\n   The composite key (commit_id, dataset_id, params_hash) ensures")
	t.Log("   only truly identical experiments are flagged as duplicates.")
}

View file

@ -0,0 +1,225 @@
package helpers
import (
"testing"
"time"
"github.com/jfraeys/fetch_ml/internal/api/helpers"
)
// TestDBContext checks that DBContext produces a live context whose deadline
// lands between now and now + timeout + slack.
func TestDBContext(t *testing.T) {
	ctx, cancel := helpers.DBContext(3 * time.Second)
	defer cancel()

	deadline, ok := ctx.Deadline()
	if !ok {
		t.Error("expected deadline to be set")
	}

	// Deadline should be within a reasonable time window.
	now := time.Now()
	switch {
	case deadline.Before(now):
		t.Error("deadline is in the past")
	case deadline.After(now.Add(4 * time.Second)):
		t.Error("deadline is too far in the future")
	}

	// The context must still be live immediately after creation.
	if ctx.Err() != nil {
		t.Error("context should not be cancelled yet")
	}
}
// TestDBContextShort checks that the short preset deadline is roughly 3s out.
func TestDBContextShort(t *testing.T) {
	ctx, cancel := helpers.DBContextShort()
	defer cancel()

	deadline, ok := ctx.Deadline()
	if !ok {
		t.Error("expected deadline to be set")
	}
	// Should be around 3 seconds from now.
	remaining := time.Until(deadline)
	if remaining < 2*time.Second || remaining > 4*time.Second {
		t.Errorf("expected deadline ~3s, got %v", remaining)
	}
}
// TestDBContextMedium checks that the medium preset deadline is roughly 5s out.
func TestDBContextMedium(t *testing.T) {
	ctx, cancel := helpers.DBContextMedium()
	defer cancel()

	deadline, ok := ctx.Deadline()
	if !ok {
		t.Error("expected deadline to be set")
	}
	// Should be around 5 seconds from now.
	remaining := time.Until(deadline)
	if remaining < 4*time.Second || remaining > 6*time.Second {
		t.Errorf("expected deadline ~5s, got %v", remaining)
	}
}
// TestDBContextLong checks that the long preset deadline is roughly 10s out.
func TestDBContextLong(t *testing.T) {
	ctx, cancel := helpers.DBContextLong()
	defer cancel()

	deadline, ok := ctx.Deadline()
	if !ok {
		t.Error("expected deadline to be set")
	}
	// Should be around 10 seconds from now.
	remaining := time.Until(deadline)
	if remaining < 9*time.Second || remaining > 11*time.Second {
		t.Errorf("expected deadline ~10s, got %v", remaining)
	}
}
// TestDBContextCancellation checks that calling cancel immediately closes the
// context and sets a non-nil error.
func TestDBContextCancellation(t *testing.T) {
	ctx, cancel := helpers.DBContextShort()
	cancel() // cancel immediately

	// Done must be closed once cancel has been called.
	select {
	case <-ctx.Done():
		// expected
	default:
		t.Error("context should be cancelled after calling cancel()")
	}
	// Err must report why the context ended.
	if ctx.Err() == nil {
		t.Error("expected non-nil error after cancellation")
	}
}
// TestStringSliceContains covers membership, empty/nil slices, the empty
// string as an item, and case sensitivity.
func TestStringSliceContains(t *testing.T) {
	cases := []struct {
		name  string
		slice []string
		item  string
		want  bool
	}{
		{name: "contains", slice: []string{"a", "b", "c"}, item: "b", want: true},
		{name: "not contains", slice: []string{"a", "b", "c"}, item: "d", want: false},
		{name: "empty slice", slice: []string{}, item: "a", want: false},
		{name: "nil slice", slice: nil, item: "a", want: false},
		{name: "empty string item", slice: []string{"a", "", "c"}, item: "", want: true},
		{name: "case sensitive", slice: []string{"Apple", "Banana"}, item: "apple", want: false},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := helpers.StringSliceContains(tc.slice, tc.item); got != tc.want {
				t.Errorf("StringSliceContains() = %v, want %v", got, tc.want)
			}
		})
	}
}
// TestStringSliceFilter covers selective, all-pass, none-pass, empty, and nil
// inputs and verifies both the result length and element order.
func TestStringSliceFilter(t *testing.T) {
	cases := []struct {
		name string
		in   []string
		keep func(string) bool
		want []string
	}{
		{
			name: "filter by prefix",
			in:   []string{"apple", "banana", "apricot", "cherry"},
			keep: func(s string) bool { return len(s) > 0 && s[0] == 'a' },
			want: []string{"apple", "apricot"},
		},
		{
			name: "filter empty strings",
			in:   []string{"a", "", "b", "", "c"},
			keep: func(s string) bool { return s != "" },
			want: []string{"a", "b", "c"},
		},
		{
			name: "all match",
			in:   []string{"a", "b", "c"},
			keep: func(s string) bool { return true },
			want: []string{"a", "b", "c"},
		},
		{
			name: "none match",
			in:   []string{"a", "b", "c"},
			keep: func(s string) bool { return false },
			want: []string{},
		},
		{
			name: "empty slice",
			in:   []string{},
			keep: func(s string) bool { return true },
			want: []string{},
		},
		{
			name: "nil slice",
			in:   nil,
			keep: func(s string) bool { return true },
			want: []string{},
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := helpers.StringSliceFilter(tc.in, tc.keep)
			if len(got) != len(tc.want) {
				t.Errorf("StringSliceFilter() returned %d items, want %d", len(got), len(tc.want))
				return
			}
			for i, v := range got {
				if v != tc.want[i] {
					t.Errorf("StringSliceFilter()[%d] = %q, want %q", i, v, tc.want[i])
				}
			}
		})
	}
}

View file

@ -0,0 +1,137 @@
package helpers_test
import (
"testing"
"github.com/jfraeys/fetch_ml/internal/api/helpers"
"github.com/jfraeys/fetch_ml/internal/queue"
)
// TestComputeDatasetID exercises the dataset-ID computation across spec/name
// combinations. Only the fully-empty input has a pinned expectation (empty
// ID); for the rest the ID is an implementation-defined hash, so the test
// just runs the computation.
func TestComputeDatasetID(t *testing.T) {
	cases := []struct {
		name  string
		specs []queue.DatasetSpec
		names []string
	}{
		{name: "both empty"},
		{name: "only datasets", names: []string{"dataset1", "dataset2"}},
		{
			name: "dataset specs with checksums",
			specs: []queue.DatasetSpec{
				{Name: "ds1", Checksum: "abc123"},
				{Name: "ds2", Checksum: "def456"},
			},
		},
		{
			name:  "dataset specs without checksums",
			specs: []queue.DatasetSpec{{Name: "ds1"}, {Name: "ds2"}},
		},
		{
			// Checksums from the specs should take precedence over names.
			name:  "checksums take precedence",
			specs: []queue.DatasetSpec{{Name: "ds1", Checksum: "xyz789"}},
			names: []string{"dataset1"},
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := helpers.ComputeDatasetID(tc.specs, tc.names)
			if len(tc.specs) == 0 && len(tc.names) == 0 && got != "" {
				t.Errorf("ComputeDatasetID() = %q, want empty string", got)
			}
		})
	}
}
// TestComputeParamsHash checks that empty args hash to the empty string and
// that non-empty args (with or without surrounding whitespace) hash to a
// non-empty value.
//
// Note: the previous table carried hard-coded "want" hashes that were never
// compared and were in fact the SHA-256 of the empty string mislabeled as the
// hash of the args — replaced with an explicit wantEmpty flag that encodes
// the assertions actually being made.
func TestComputeParamsHash(t *testing.T) {
	tests := []struct {
		name      string
		args      string
		wantEmpty bool
	}{
		{
			name:      "empty args",
			args:      "",
			wantEmpty: true,
		},
		{
			name: "simple args",
			args: "--lr 0.01 --epochs 10",
		},
		{
			name: "args with spaces",
			args: "  --lr 0.01  ",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := helpers.ComputeParamsHash(tt.args)
			if tt.wantEmpty && got != "" {
				t.Errorf("ComputeParamsHash() expected empty, got %q", got)
			}
			if !tt.wantEmpty && got == "" {
				t.Errorf("ComputeParamsHash() expected non-empty hash")
			}
		})
	}
}
// TestComputeParamsHash_Deterministic checks that equal inputs hash equally
// and different inputs hash differently.
func TestComputeParamsHash_Deterministic(t *testing.T) {
	const args = "--lr 0.01 --epochs 10"
	first := helpers.ComputeParamsHash(args)
	second := helpers.ComputeParamsHash(args)
	if first != second {
		t.Error("ComputeParamsHash() should be deterministic")
	}
	// Different args should produce different hashes.
	other := helpers.ComputeParamsHash("--lr 0.02 --epochs 10")
	if first == other {
		t.Error("ComputeParamsHash() should produce different hashes for different inputs")
	}
}
// TestComputeParamsHash_Whitespace checks that leading/trailing whitespace
// does not change the hash.
func TestComputeParamsHash_Whitespace(t *testing.T) {
	trimmed := helpers.ComputeParamsHash("--lr 0.01 --epochs 10")
	padded := helpers.ComputeParamsHash("  --lr 0.01 --epochs 10  ")
	repeat := helpers.ComputeParamsHash("--lr 0.01 --epochs 10")
	if trimmed != padded {
		t.Error("ComputeParamsHash() should handle leading/trailing whitespace consistently")
	}
	// Note: internal whitespace differences may or may not produce different
	// hashes depending on implementation details.
	_ = repeat
}

View file

@ -0,0 +1,451 @@
package helpers_test
import (
"bytes"
"testing"
"github.com/jfraeys/fetch_ml/internal/api/helpers"
)
// TestNewPayloadParser checks that the constructor stores the starting offset
// and the full payload unchanged.
func TestNewPayloadParser(t *testing.T) {
	data := []byte{0x01, 0x02, 0x03, 0x04, 0x05}
	parser := helpers.NewPayloadParser(data, 2)
	if got := parser.Offset(); got != 2 {
		t.Errorf("expected offset 2, got %d", got)
	}
	if !bytes.Equal(parser.Payload(), data) {
		t.Error("payload mismatch")
	}
}
// TestPayloadParser_ParseByte covers reading a byte mid-buffer and the
// out-of-bounds error cases.
func TestPayloadParser_ParseByte(t *testing.T) {
	cases := []struct {
		name    string
		payload []byte
		start   int
		want    byte
		wantErr bool
	}{
		{name: "valid byte", payload: []byte{0x00, 0x01, 0x02}, start: 1, want: 0x01},
		{name: "end of payload", payload: []byte{0x00, 0x01}, start: 2, wantErr: true},
		{name: "empty payload", payload: []byte{}, start: 0, wantErr: true},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			parser := helpers.NewPayloadParser(tc.payload, tc.start)
			got, err := parser.ParseByte()
			if (err != nil) != tc.wantErr {
				t.Errorf("ParseByte() error = %v, wantErr %v", err, tc.wantErr)
				return
			}
			if got != tc.want {
				t.Errorf("ParseByte() = %v, want %v", got, tc.want)
			}
		})
	}
}
// TestPayloadParser_ParseUint16 covers big-endian decoding at various offsets
// and the truncated-buffer error cases.
func TestPayloadParser_ParseUint16(t *testing.T) {
	cases := []struct {
		name    string
		payload []byte
		start   int
		want    uint16
		wantErr bool
	}{
		{name: "valid uint16", payload: []byte{0x01, 0x02, 0x03, 0x04}, start: 0, want: 0x0102},
		{name: "another valid uint16", payload: []byte{0x00, 0xAB, 0xCD}, start: 1, want: 0xABCD},
		{name: "too short", payload: []byte{0x00, 0x01}, start: 1, wantErr: true},
		{name: "empty at end", payload: []byte{0x00, 0x01, 0x02}, start: 3, wantErr: true},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			parser := helpers.NewPayloadParser(tc.payload, tc.start)
			got, err := parser.ParseUint16()
			if (err != nil) != tc.wantErr {
				t.Errorf("ParseUint16() error = %v, wantErr %v", err, tc.wantErr)
				return
			}
			if got != tc.want {
				t.Errorf("ParseUint16() = %v, want %v", got, tc.want)
			}
		})
	}
}
// TestPayloadParser_ParseLengthPrefixedString covers one-byte-length-prefixed
// strings: normal, empty, offset, and truncated inputs.
func TestPayloadParser_ParseLengthPrefixedString(t *testing.T) {
	cases := []struct {
		name    string
		payload []byte
		start   int
		want    string
		wantErr bool
	}{
		{name: "valid string", payload: []byte{0x05, 'h', 'e', 'l', 'l', 'o'}, start: 0, want: "hello"},
		{name: "empty string", payload: []byte{0x00, 0x01, 0x02}, start: 0, want: ""},
		{name: "from offset", payload: []byte{0x00, 0x03, 'f', 'o', 'o'}, start: 1, want: "foo"},
		{name: "too short for length", payload: []byte{0x05, 'h', 'i'}, start: 0, wantErr: true},
		{name: "no length byte", payload: []byte{}, start: 0, wantErr: true},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			parser := helpers.NewPayloadParser(tc.payload, tc.start)
			got, err := parser.ParseLengthPrefixedString()
			if (err != nil) != tc.wantErr {
				t.Errorf("ParseLengthPrefixedString() error = %v, wantErr %v", err, tc.wantErr)
				return
			}
			if got != tc.want {
				t.Errorf("ParseLengthPrefixedString() = %v, want %v", got, tc.want)
			}
		})
	}
}
// TestPayloadParser_ParseUint16PrefixedString covers two-byte-length-prefixed
// strings: normal, empty, offset, and truncated inputs.
func TestPayloadParser_ParseUint16PrefixedString(t *testing.T) {
	cases := []struct {
		name    string
		payload []byte
		start   int
		want    string
		wantErr bool
	}{
		{name: "valid string", payload: []byte{0x00, 0x05, 'h', 'e', 'l', 'l', 'o'}, start: 0, want: "hello"},
		{name: "empty string", payload: []byte{0x00, 0x00, 0x01, 0x02}, start: 0, want: ""},
		{name: "from offset", payload: []byte{0x00, 0x00, 0x00, 0x03, 'b', 'a', 'r'}, start: 2, want: "bar"},
		{name: "too short for length", payload: []byte{0x00}, start: 0, wantErr: true},
		{name: "too short for string", payload: []byte{0x00, 0x05, 'h', 'i'}, start: 0, wantErr: true},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			parser := helpers.NewPayloadParser(tc.payload, tc.start)
			got, err := parser.ParseUint16PrefixedString()
			if (err != nil) != tc.wantErr {
				t.Errorf("ParseUint16PrefixedString() error = %v, wantErr %v", err, tc.wantErr)
				return
			}
			if got != tc.want {
				t.Errorf("ParseUint16PrefixedString() = %v, want %v", got, tc.want)
			}
		})
	}
}
// TestPayloadParser_ParseBool checks single-byte boolean decoding: zero is
// false, any non-zero byte is true, and an empty payload is an error.
func TestPayloadParser_ParseBool(t *testing.T) {
	cases := []struct {
		name    string
		payload []byte
		start   int
		want    bool
		wantErr bool
	}{
		{name: "true", payload: []byte{0x01, 0x00}, start: 0, want: true, wantErr: false},
		{name: "false", payload: []byte{0x00, 0x01}, start: 0, want: false, wantErr: false},
		{name: "non-zero is true", payload: []byte{0xFF}, start: 0, want: true, wantErr: false},
		{name: "empty", payload: []byte{}, start: 0, want: false, wantErr: true},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			parser := helpers.NewPayloadParser(tc.payload, tc.start)
			got, err := parser.ParseBool()
			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Errorf("ParseBool() error = %v, wantErr %v", err, tc.wantErr)
				return
			}
			if got != tc.want {
				t.Errorf("ParseBool() = %v, want %v", got, tc.want)
			}
		})
	}
}
// TestPayloadParser_ParseFixedBytes covers fixed-width byte extraction,
// including an exact-fit read and reads that run past the payload end.
func TestPayloadParser_ParseFixedBytes(t *testing.T) {
	cases := []struct {
		name    string
		payload []byte
		start   int
		length  int
		want    []byte
		wantErr bool
	}{
		{name: "valid", payload: []byte{0x00, 0x01, 0x02, 0x03, 0x04}, start: 1, length: 3, want: []byte{0x01, 0x02, 0x03}, wantErr: false},
		{name: "exact length", payload: []byte{0x00, 0x01, 0x02}, start: 0, length: 3, want: []byte{0x00, 0x01, 0x02}, wantErr: false},
		{name: "too long", payload: []byte{0x00, 0x01}, start: 0, length: 3, want: nil, wantErr: true},
		{name: "from end", payload: []byte{0x00, 0x01, 0x02}, start: 2, length: 2, want: nil, wantErr: true},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			parser := helpers.NewPayloadParser(tc.payload, tc.start)
			got, err := parser.ParseFixedBytes(tc.length)
			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Errorf("ParseFixedBytes() error = %v, wantErr %v", err, tc.wantErr)
				return
			}
			if !bytes.Equal(got, tc.want) {
				t.Errorf("ParseFixedBytes() = %v, want %v", got, tc.want)
			}
		})
	}
}
// TestPayloadParser_Offset verifies that Offset reports the start position
// and advances by one after a single-byte read.
func TestPayloadParser_Offset(t *testing.T) {
	parser := helpers.NewPayloadParser([]byte{0x00, 0x01, 0x02, 0x03}, 1)
	if got := parser.Offset(); got != 1 {
		t.Errorf("initial Offset() = %d, want 1", got)
	}
	parser.ParseByte()
	if got := parser.Offset(); got != 2 {
		t.Errorf("after ParseByte() Offset() = %d, want 2", got)
	}
}
// TestPayloadParser_Remaining checks the unread suffix mid-payload, exactly
// at the end, and past the end (the latter two should yield nil).
func TestPayloadParser_Remaining(t *testing.T) {
	payload := []byte{0x00, 0x01, 0x02, 0x03}

	// Mid-payload: the unread tail comes back.
	if got := helpers.NewPayloadParser(payload, 2).Remaining(); !bytes.Equal(got, []byte{0x02, 0x03}) {
		t.Errorf("Remaining() = %v, want [0x02, 0x03]", got)
	}
	// Exactly at the end: nothing left.
	if got := helpers.NewPayloadParser(payload, 4).Remaining(); got != nil {
		t.Errorf("at end Remaining() = %v, want nil", got)
	}
	// Past the end: still nil rather than an out-of-range slice.
	if got := helpers.NewPayloadParser(payload, 5).Remaining(); got != nil {
		t.Errorf("beyond end Remaining() = %v, want nil", got)
	}
}
// TestPayloadParser_HasRemaining verifies HasRemaining across start offsets
// inside, at, and beyond the end of a two-byte payload.
func TestPayloadParser_HasRemaining(t *testing.T) {
	payload := []byte{0x00, 0x01}
	cases := []struct {
		start int
		want  bool
	}{
		{start: 0, want: true},
		{start: 1, want: true},
		{start: 2, want: false},
		{start: 3, want: false},
	}
	for _, tc := range cases {
		got := helpers.NewPayloadParser(payload, tc.start).HasRemaining()
		if got != tc.want {
			t.Errorf("HasRemaining() with start=%d = %v, want %v", tc.start, got, tc.want)
		}
	}
}
// TestPayloadParser_ChainedParsing walks a realistic wire layout with several
// sequential reads and confirms the parser ends fully consumed.
func TestPayloadParser_ChainedParsing(t *testing.T) {
	// Simulate a real protocol: [api_key:2][name_len:1][name:var][value:2]
	payload := []byte{
		0xAB, 0xCD, // api_key
		0x05,                    // name_len
		'h', 'e', 'l', 'l', 'o', // name
		0x00, 0x42, // value
	}
	parser := helpers.NewPayloadParser(payload, 0)

	// Fixed-width api key first.
	apiKey, err := parser.ParseFixedBytes(2)
	if err != nil {
		t.Fatalf("failed to parse api key: %v", err)
	}
	if !bytes.Equal(apiKey, []byte{0xAB, 0xCD}) {
		t.Errorf("api key = %v, want [0xAB, 0xCD]", apiKey)
	}

	// Then the length-prefixed name.
	name, err := parser.ParseLengthPrefixedString()
	if err != nil {
		t.Fatalf("failed to parse name: %v", err)
	}
	if name != "hello" {
		t.Errorf("name = %q, want 'hello'", name)
	}

	// Then the trailing uint16 value.
	value, err := parser.ParseUint16()
	if err != nil {
		t.Fatalf("failed to parse value: %v", err)
	}
	if value != 0x0042 {
		t.Errorf("value = 0x%X, want 0x0042", value)
	}

	// Everything consumed: nothing should remain.
	if parser.HasRemaining() {
		t.Error("expected no remaining bytes")
	}
}

View file

@ -0,0 +1,345 @@
package helpers
import (
"testing"
"github.com/jfraeys/fetch_ml/internal/api/helpers"
"github.com/jfraeys/fetch_ml/internal/queue"
)
// TestNewTaskErrorMapper ensures the constructor always yields a usable mapper.
func TestNewTaskErrorMapper(t *testing.T) {
	if m := helpers.NewTaskErrorMapper(); m == nil {
		t.Error("NewTaskErrorMapper() returned nil")
	}
}
// TestTaskErrorMapper_MapJupyterError verifies the mapping from a task's
// status and error text to a protocol error code. Cases cover the nil-task
// fallback, substring matching for common failure modes (OOM, disk full,
// throttling, timeouts, network faults, misconfiguration), the generic
// failed-task fallback, the unknown-status fallback, and that matching is
// case-insensitive for both status and error text.
func TestTaskErrorMapper_MapJupyterError(t *testing.T) {
	mapper := helpers.NewTaskErrorMapper()
	tests := []struct {
		name string
		task *queue.Task
		want helpers.ErrorCode
	}{
		{
			name: "nil task",
			task: nil,
			want: 0x00, // ErrorCodeUnknownError
		},
		{
			name: "cancelled task",
			task: &queue.Task{Status: "cancelled"},
			want: 0x24, // ErrorCodeJobCancelled
		},
		{
			name: "oom error",
			task: &queue.Task{Status: "failed", Error: "out of memory"},
			want: 0x30, // ErrorCodeOutOfMemory
		},
		{
			name: "oom shorthand",
			task: &queue.Task{Status: "failed", Error: "OOM killed"},
			want: 0x30, // ErrorCodeOutOfMemory
		},
		{
			name: "disk full",
			task: &queue.Task{Status: "failed", Error: "no space left on device"},
			want: 0x31, // ErrorCodeDiskFull
		},
		{
			name: "disk full alt",
			task: &queue.Task{Status: "failed", Error: "disk full"},
			want: 0x31, // ErrorCodeDiskFull
		},
		{
			name: "rate limit",
			task: &queue.Task{Status: "failed", Error: "rate limit exceeded"},
			want: 0x33, // ErrorCodeServiceUnavailable
		},
		{
			name: "throttle",
			task: &queue.Task{Status: "failed", Error: "request throttled"},
			want: 0x33, // ErrorCodeServiceUnavailable
		},
		{
			name: "timeout",
			task: &queue.Task{Status: "failed", Error: "timed out waiting"},
			want: 0x14, // ErrorCodeTimeout
		},
		{
			name: "deadline",
			task: &queue.Task{Status: "failed", Error: "context deadline exceeded"},
			want: 0x14, // ErrorCodeTimeout
		},
		{
			name: "connection refused",
			task: &queue.Task{Status: "failed", Error: "connection refused"},
			want: 0x12, // ErrorCodeNetworkError
		},
		{
			name: "connection reset",
			task: &queue.Task{Status: "failed", Error: "connection reset by peer"},
			want: 0x12, // ErrorCodeNetworkError
		},
		{
			name: "network unreachable",
			task: &queue.Task{Status: "failed", Error: "network unreachable"},
			want: 0x12, // ErrorCodeNetworkError
		},
		{
			name: "queue not configured",
			task: &queue.Task{Status: "failed", Error: "queue not configured"},
			want: 0x32, // ErrorCodeInvalidConfiguration
		},
		{
			name: "generic failed",
			task: &queue.Task{Status: "failed", Error: "something went wrong"},
			want: 0x23, // ErrorCodeJobExecutionFailed
		},
		{
			name: "unknown status",
			task: &queue.Task{Status: "unknown", Error: "unknown error"},
			want: 0x00, // ErrorCodeUnknownError
		},
		{
			name: "case insensitive - cancelled",
			task: &queue.Task{Status: "CANCELLED"},
			want: 0x24, // ErrorCodeJobCancelled
		},
		{
			name: "case insensitive - oom",
			task: &queue.Task{Status: "FAILED", Error: "OUT OF MEMORY"},
			want: 0x30, // ErrorCodeOutOfMemory
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := mapper.MapJupyterError(tt.task)
			if got != tt.want {
				t.Errorf("MapJupyterError() = 0x%02X, want 0x%02X", got, tt.want)
			}
		})
	}
}
// TestTaskErrorMapper_MapError checks MapError's behavior: known statuses map
// to their specific codes, while a nil task or an unrecognized status falls
// back to the caller-supplied default code.
func TestTaskErrorMapper_MapError(t *testing.T) {
	mapper := helpers.NewTaskErrorMapper()
	cases := []struct {
		name        string
		task        *queue.Task
		defaultCode helpers.ErrorCode
		want        helpers.ErrorCode
	}{
		// nil task: nothing to inspect, so the default wins.
		{name: "nil task returns default", task: nil, defaultCode: 0x11 /* ErrorCodeDatabaseError */, want: 0x11},
		{name: "cancelled", task: &queue.Task{Status: "cancelled"}, defaultCode: 0x00, want: 0x24 /* ErrorCodeJobCancelled */},
		{name: "oom", task: &queue.Task{Status: "failed", Error: "oom"}, defaultCode: 0x00, want: 0x30 /* ErrorCodeOutOfMemory */},
		{name: "timeout", task: &queue.Task{Status: "failed", Error: "timeout"}, defaultCode: 0x00, want: 0x14 /* ErrorCodeTimeout */},
		{name: "generic failed", task: &queue.Task{Status: "failed", Error: "generic error"}, defaultCode: 0x00, want: 0x23 /* ErrorCodeJobExecutionFailed */},
		// Unknown status: the default also wins.
		{name: "unknown status returns default", task: &queue.Task{Status: "weird", Error: "unknown"}, defaultCode: 0x01, want: 0x01},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := mapper.MapError(tc.task, tc.defaultCode)
			if got != tc.want {
				t.Errorf("MapError() = 0x%02X, want 0x%02X", got, tc.want)
			}
		})
	}
}
// TestParseResourceRequest exercises decoding of the binary resource-request
// payload: empty/nil payloads yield a nil request without error, well-formed
// payloads populate every field, and truncated payloads error out.
func TestParseResourceRequest(t *testing.T) {
	cases := []struct {
		name    string
		payload []byte
		want    *helpers.ResourceRequest
		wantErr bool
	}{
		{name: "empty payload", payload: []byte{}, want: nil, wantErr: false},
		{name: "nil payload", payload: nil, want: nil, wantErr: false},
		{name: "valid minimal", payload: []byte{4, 8, 1, 0}, want: &helpers.ResourceRequest{CPU: 4, MemoryGB: 8, GPU: 1, GPUMemory: ""}, wantErr: false},
		{name: "valid with gpu memory", payload: []byte{8, 16, 2, 4, '8', 'G', 'B', '!'}, want: &helpers.ResourceRequest{CPU: 8, MemoryGB: 16, GPU: 2, GPUMemory: "8GB!"}, wantErr: false},
		{name: "too short", payload: []byte{1, 2}, want: nil, wantErr: true},
		{name: "invalid gpu mem length", payload: []byte{1, 2, 1, 10, 'a'}, want: nil, wantErr: true},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := helpers.ParseResourceRequest(tc.payload)
			if hasErr := err != nil; hasErr != tc.wantErr {
				t.Errorf("ParseResourceRequest() error = %v, wantErr %v", err, tc.wantErr)
				return
			}
			switch {
			case tc.want == nil:
				if got != nil {
					t.Errorf("ParseResourceRequest() = %v, want nil", got)
				}
			case got == nil:
				t.Errorf("ParseResourceRequest() = nil, want %v", tc.want)
			default:
				// Field-by-field comparison so a mismatch names the field.
				if got.CPU != tc.want.CPU {
					t.Errorf("CPU = %d, want %d", got.CPU, tc.want.CPU)
				}
				if got.MemoryGB != tc.want.MemoryGB {
					t.Errorf("MemoryGB = %d, want %d", got.MemoryGB, tc.want.MemoryGB)
				}
				if got.GPU != tc.want.GPU {
					t.Errorf("GPU = %d, want %d", got.GPU, tc.want.GPU)
				}
				if got.GPUMemory != tc.want.GPUMemory {
					t.Errorf("GPUMemory = %q, want %q", got.GPUMemory, tc.want.GPUMemory)
				}
			}
		})
	}
}
// TestMarshalJSONOrEmpty checks JSON serialization of common value shapes,
// including the nil input which should encode as "null".
func TestMarshalJSONOrEmpty(t *testing.T) {
	cases := []struct {
		name string
		data interface{}
		want []byte
	}{
		{name: "simple map", data: map[string]string{"key": "value"}, want: []byte(`{"key":"value"}`)},
		{name: "string slice", data: []string{"a", "b", "c"}, want: []byte(`["a","b","c"]`)},
		{name: "empty slice", data: []int{}, want: []byte(`[]`)},
		{name: "nil", data: nil, want: []byte(`null`)},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := helpers.MarshalJSONOrEmpty(tc.data)
			// For valid JSON, compare strings since JSON formatting might vary
			if gotStr, wantStr := string(got), string(tc.want); gotStr != wantStr {
				t.Errorf("MarshalJSONOrEmpty() = %q, want %q", gotStr, wantStr)
			}
		})
	}
}
// TestMarshalJSONOrEmpty_ErrorCase feeds the helper a func value, which
// cannot be marshaled, and expects the "[]" fallback rather than an error.
func TestMarshalJSONOrEmpty_ErrorCase(t *testing.T) {
	got := string(helpers.MarshalJSONOrEmpty(func() {}))
	if want := "[]"; got != want {
		t.Errorf("MarshalJSONOrEmpty() with invalid data = %q, want %q", got, want)
	}
}
// TestMarshalJSONBytes round-trips a simple map through MarshalJSONBytes and
// pins the compact JSON output.
func TestMarshalJSONBytes(t *testing.T) {
	data := map[string]int{"count": 42}
	got, err := helpers.MarshalJSONBytes(data)
	if err != nil {
		// Fatalf, not Errorf: comparing `got` below is meaningless once
		// marshaling has failed, so stop the test here.
		t.Fatalf("MarshalJSONBytes() unexpected error: %v", err)
	}
	want := `{"count":42}`
	if string(got) != want {
		t.Errorf("MarshalJSONBytes() = %q, want %q", string(got), want)
	}
}
// TestIsEmptyJSON verifies the "emptiness" predicate for JSON blobs: empty
// bytes, null, empty containers, and whitespace count as empty; anything with
// real data does not.
func TestIsEmptyJSON(t *testing.T) {
	cases := []struct {
		name string
		data []byte
		want bool
	}{
		{name: "empty", data: []byte{}, want: true},
		{name: "null", data: []byte("null"), want: true},
		{name: "empty array", data: []byte("[]"), want: true},
		{name: "empty object", data: []byte("{}"), want: true},
		{name: "whitespace", data: []byte(" "), want: true},
		{name: "data", data: []byte(`{"key":"value"}`), want: false},
		{name: "non-empty array", data: []byte("[1,2,3]"), want: false},
		{name: "null with spaces", data: []byte(" null "), want: true},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := helpers.IsEmptyJSON(tc.data); got != tc.want {
				t.Errorf("IsEmptyJSON() = %v, want %v", got, tc.want)
			}
		})
	}
}

View file

@ -0,0 +1,486 @@
package helpers
import (
"testing"
"github.com/jfraeys/fetch_ml/internal/api/helpers"
"github.com/jfraeys/fetch_ml/internal/manifest"
"github.com/jfraeys/fetch_ml/internal/queue"
)
// TestValidateCommitIDFormat checks commit-ID validation: a 40-character hex
// string (any case) passes; wrong lengths and non-hex characters produce the
// corresponding error message.
func TestValidateCommitIDFormat(t *testing.T) {
	cases := []struct {
		name     string
		commitID string
		wantOk   bool
		wantErr  string
	}{
		{name: "valid commit ID", commitID: "aabbccddeeff00112233445566778899aabbccdd", wantOk: true, wantErr: ""},
		{name: "too short", commitID: "aabbcc", wantOk: false, wantErr: "invalid commit_id length"},
		{name: "too long", commitID: "aabbccddeeff00112233445566778899aabbccddeeff", wantOk: false, wantErr: "invalid commit_id length"},
		{name: "invalid hex", commitID: "gggggggggggggggggggggggggggggggggggggggg", wantOk: false, wantErr: "invalid commit_id hex"},
		{name: "mixed case valid", commitID: "AABBCCDDEEFF00112233445566778899AABBCCDD", wantOk: true, wantErr: ""},
		{name: "empty", commitID: "", wantOk: false, wantErr: "invalid commit_id length"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			gotOk, gotErr := helpers.ValidateCommitIDFormat(tc.commitID)
			if gotOk != tc.wantOk {
				t.Errorf("ValidateCommitIDFormat() ok = %v, want %v", gotOk, tc.wantOk)
			}
			if gotErr != tc.wantErr {
				t.Errorf("ValidateCommitIDFormat() err = %q, want %q", gotErr, tc.wantErr)
			}
		})
	}
}
func TestShouldRequireRunManifest(t *testing.T) {
tests := []struct {
name string
status string
want bool
}{
{"running", "running", true},
{"completed", "completed", true},
{"failed", "failed", true},
{"queued", "queued", false},
{"pending", "pending", false},
{"cancelled", "cancelled", false},
{"unknown", "unknown", false},
{"empty", "", false},
{"RUNNING uppercase", "RUNNING", true},
{"Completed mixed", "Completed", true},
{" Running with space", " running", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
task := &queue.Task{Status: tt.status}
got := helpers.ShouldRequireRunManifest(task)
if got != tt.want {
t.Errorf("ShouldRequireRunManifest() = %v, want %v", got, tt.want)
}
})
}
}
// TestExpectedRunManifestBucketForStatus checks the status-to-bucket mapping
// (pending/running/finished/failed) and the not-ok result for unknown input.
func TestExpectedRunManifestBucketForStatus(t *testing.T) {
	cases := []struct {
		name       string
		status     string
		wantBucket string
		wantOk     bool
	}{
		{name: "queued", status: "queued", wantBucket: "pending", wantOk: true},
		{name: "pending", status: "pending", wantBucket: "pending", wantOk: true},
		{name: "running", status: "running", wantBucket: "running", wantOk: true},
		{name: "completed", status: "completed", wantBucket: "finished", wantOk: true},
		{name: "finished", status: "finished", wantBucket: "finished", wantOk: true},
		{name: "failed", status: "failed", wantBucket: "failed", wantOk: true},
		{name: "unknown", status: "unknown", wantBucket: "", wantOk: false},
		{name: "empty", status: "", wantBucket: "", wantOk: false},
		{name: "RUNNING uppercase", status: "RUNNING", wantBucket: "running", wantOk: true},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			gotBucket, gotOk := helpers.ExpectedRunManifestBucketForStatus(tc.status)
			if gotBucket != tc.wantBucket {
				t.Errorf("ExpectedRunManifestBucketForStatus() bucket = %q, want %q", gotBucket, tc.wantBucket)
			}
			if gotOk != tc.wantOk {
				t.Errorf("ExpectedRunManifestBucketForStatus() ok = %v, want %v", gotOk, tc.wantOk)
			}
		})
	}
}
func TestValidateTaskIDMatch(t *testing.T) {
tests := []struct {
name string
rmTaskID string
expectedID string
wantOk bool
wantExpected string
wantActual string
}{
{
name: "match",
rmTaskID: "task-123",
expectedID: "task-123",
wantOk: true,
wantExpected: "task-123",
wantActual: "task-123",
},
{
name: "mismatch",
rmTaskID: "task-123",
expectedID: "task-456",
wantOk: false,
wantExpected: "task-456",
wantActual: "task-123",
},
{
name: "empty rm task ID",
rmTaskID: "",
expectedID: "task-123",
wantOk: false,
wantExpected: "task-123",
wantActual: "",
},
{
name: "whitespace trimmed - match",
rmTaskID: "task-123",
expectedID: "task-123",
wantOk: true,
wantExpected: "task-123",
wantActual: "task-123",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rm := &manifest.RunManifest{TaskID: tt.rmTaskID}
got := helpers.ValidateTaskIDMatch(rm, tt.expectedID)
if got.OK != tt.wantOk {
t.Errorf("ValidateTaskIDMatch() OK = %v, want %v", got.OK, tt.wantOk)
}
if got.Expected != tt.wantExpected {
t.Errorf("ValidateTaskIDMatch() Expected = %q, want %q", got.Expected, tt.wantExpected)
}
if got.Actual != tt.wantActual {
t.Errorf("ValidateTaskIDMatch() Actual = %q, want %q", got.Actual, tt.wantActual)
}
})
}
}
// TestValidateCommitIDMatch verifies commit-ID comparison between a run
// manifest and the expected value. Per the cases below: equal IDs pass, a
// genuine mismatch fails, either side being empty is treated as OK (with
// Expected/Actual reported accordingly), and the manifest-side value is
// whitespace-trimmed before comparison.
func TestValidateCommitIDMatch(t *testing.T) {
	tests := []struct {
		name         string
		rmCommitID   string
		expectedID   string
		wantOk       bool
		wantExpected string
		wantActual   string
	}{
		{
			name:         "both empty",
			rmCommitID:   "",
			expectedID:   "",
			wantOk:       true,
			wantExpected: "",
			wantActual:   "",
		},
		{
			name:         "match",
			rmCommitID:   "abc123",
			expectedID:   "abc123",
			wantOk:       true,
			wantExpected: "abc123",
			wantActual:   "abc123",
		},
		{
			name:         "mismatch",
			rmCommitID:   "abc123",
			expectedID:   "def456",
			wantOk:       false,
			wantExpected: "def456",
			wantActual:   "abc123",
		},
		{
			name:         "expected empty",
			rmCommitID:   "abc123",
			expectedID:   "",
			wantOk:       true,
			wantExpected: "",
			wantActual:   "",
		},
		{
			name:         "rm empty",
			rmCommitID:   "",
			expectedID:   "abc123",
			wantOk:       true,
			wantExpected: "abc123",
			wantActual:   "",
		},
		{
			name:         "whitespace trimmed",
			rmCommitID:   " abc123 ",
			expectedID:   "abc123",
			wantOk:       true,
			wantExpected: "abc123",
			wantActual:   "abc123",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := helpers.ValidateCommitIDMatch(tt.rmCommitID, tt.expectedID)
			if got.OK != tt.wantOk {
				t.Errorf("ValidateCommitIDMatch() OK = %v, want %v", got.OK, tt.wantOk)
			}
			if got.Expected != tt.wantExpected {
				t.Errorf("ValidateCommitIDMatch() Expected = %q, want %q", got.Expected, tt.wantExpected)
			}
			if got.Actual != tt.wantActual {
				t.Errorf("ValidateCommitIDMatch() Actual = %q, want %q", got.Actual, tt.wantActual)
			}
		})
	}
}
// TestValidateDepsProvenance verifies dependency-provenance comparison, where
// each side is a (file name, SHA) pair rendered as "name:sha". Per the cases
// below: both fields matching passes, a name or SHA mismatch fails, and if
// either side is entirely empty the check is skipped (OK with empty
// Expected/Actual).
func TestValidateDepsProvenance(t *testing.T) {
	tests := []struct {
		name         string
		wantName     string
		wantSHA      string
		gotName      string
		gotSHA       string
		wantOk       bool
		wantExpected string
		wantActual   string
	}{
		{
			name:         "match",
			wantName:     "requirements.txt",
			wantSHA:      "abc123",
			gotName:      "requirements.txt",
			gotSHA:       "abc123",
			wantOk:       true,
			wantExpected: "requirements.txt:abc123",
			wantActual:   "requirements.txt:abc123",
		},
		{
			name:         "name mismatch",
			wantName:     "requirements.txt",
			wantSHA:      "abc123",
			gotName:      "Pipfile",
			gotSHA:       "abc123",
			wantOk:       false,
			wantExpected: "requirements.txt:abc123",
			wantActual:   "Pipfile:abc123",
		},
		{
			name:         "sha mismatch",
			wantName:     "requirements.txt",
			wantSHA:      "abc123",
			gotName:      "requirements.txt",
			gotSHA:       "def456",
			wantOk:       false,
			wantExpected: "requirements.txt:abc123",
			wantActual:   "requirements.txt:def456",
		},
		{
			name:         "want empty",
			wantName:     "",
			wantSHA:      "",
			gotName:      "requirements.txt",
			gotSHA:       "abc123",
			wantOk:       true,
			wantExpected: "",
			wantActual:   "",
		},
		{
			name:         "got empty",
			wantName:     "requirements.txt",
			wantSHA:      "abc123",
			gotName:      "",
			gotSHA:       "",
			wantOk:       true,
			wantExpected: "",
			wantActual:   "",
		},
		{
			name:         "both empty",
			wantName:     "",
			wantSHA:      "",
			gotName:      "",
			gotSHA:       "",
			wantOk:       true,
			wantExpected: "",
			wantActual:   "",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := helpers.ValidateDepsProvenance(tt.wantName, tt.wantSHA, tt.gotName, tt.gotSHA)
			if got.OK != tt.wantOk {
				t.Errorf("ValidateDepsProvenance() OK = %v, want %v", got.OK, tt.wantOk)
			}
			if got.Expected != tt.wantExpected {
				t.Errorf("ValidateDepsProvenance() Expected = %q, want %q", got.Expected, tt.wantExpected)
			}
			if got.Actual != tt.wantActual {
				t.Errorf("ValidateDepsProvenance() Actual = %q, want %q", got.Actual, tt.wantActual)
			}
		})
	}
}
// TestValidateSnapshotID checks snapshot-ID comparison: equal IDs pass, a
// mismatch fails, and an empty side (either one) skips the check while still
// reporting the values seen.
func TestValidateSnapshotID(t *testing.T) {
	cases := []struct {
		name         string
		wantID       string
		gotID        string
		wantOk       bool
		wantExpected string
		wantActual   string
	}{
		{name: "match", wantID: "snap-123", gotID: "snap-123", wantOk: true, wantExpected: "snap-123", wantActual: "snap-123"},
		{name: "mismatch", wantID: "snap-123", gotID: "snap-456", wantOk: false, wantExpected: "snap-123", wantActual: "snap-456"},
		{name: "want empty", wantID: "", gotID: "snap-123", wantOk: true, wantExpected: "", wantActual: "snap-123"},
		{name: "got empty", wantID: "snap-123", gotID: "", wantOk: true, wantExpected: "snap-123", wantActual: ""},
		{name: "both empty", wantID: "", gotID: "", wantOk: true, wantExpected: "", wantActual: ""},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := helpers.ValidateSnapshotID(tc.wantID, tc.gotID)
			if got.OK != tc.wantOk {
				t.Errorf("ValidateSnapshotID() OK = %v, want %v", got.OK, tc.wantOk)
			}
			if got.Expected != tc.wantExpected {
				t.Errorf("ValidateSnapshotID() Expected = %q, want %q", got.Expected, tc.wantExpected)
			}
			if got.Actual != tc.wantActual {
				t.Errorf("ValidateSnapshotID() Actual = %q, want %q", got.Actual, tc.wantActual)
			}
		})
	}
}
// TestValidateSnapshotSHA mirrors the snapshot-ID test for SHA digests:
// equal digests pass, mismatches fail, and an empty side skips the check.
func TestValidateSnapshotSHA(t *testing.T) {
	cases := []struct {
		name         string
		wantSHA      string
		gotSHA       string
		wantOk       bool
		wantExpected string
		wantActual   string
	}{
		{name: "match", wantSHA: "sha256:abc123", gotSHA: "sha256:abc123", wantOk: true, wantExpected: "sha256:abc123", wantActual: "sha256:abc123"},
		{name: "mismatch", wantSHA: "sha256:abc123", gotSHA: "sha256:def456", wantOk: false, wantExpected: "sha256:abc123", wantActual: "sha256:def456"},
		{name: "want empty", wantSHA: "", gotSHA: "sha256:abc123", wantOk: true, wantExpected: "", wantActual: "sha256:abc123"},
		{name: "got empty", wantSHA: "sha256:abc123", gotSHA: "", wantOk: true, wantExpected: "sha256:abc123", wantActual: ""},
		{name: "both empty", wantSHA: "", gotSHA: "", wantOk: true, wantExpected: "", wantActual: ""},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := helpers.ValidateSnapshotSHA(tc.wantSHA, tc.gotSHA)
			if got.OK != tc.wantOk {
				t.Errorf("ValidateSnapshotSHA() OK = %v, want %v", got.OK, tc.wantOk)
			}
			if got.Expected != tc.wantExpected {
				t.Errorf("ValidateSnapshotSHA() Expected = %q, want %q", got.Expected, tc.wantExpected)
			}
			if got.Actual != tc.wantActual {
				t.Errorf("ValidateSnapshotSHA() Actual = %q, want %q", got.Actual, tc.wantActual)
			}
		})
	}
}