test: add comprehensive test coverage and command improvements
- Add logs and debug end-to-end tests - Add test helper utilities - Improve test fixtures and templates - Update API server and config lint commands - Add multi-user database initialization
This commit is contained in:
parent
b05470b30a
commit
7305e2bc21
22 changed files with 3643 additions and 7 deletions
125
tests/benchmarks/artifact_scanner_bench_test.go
Normal file
125
tests/benchmarks/artifact_scanner_bench_test.go
Normal file
|
|
@ -0,0 +1,125 @@
|
|||
package benchmarks
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/worker"
|
||||
)
|
||||
|
||||
// BenchmarkArtifactScanGo profiles Go filepath.WalkDir implementation
|
||||
func BenchmarkArtifactScanGo(b *testing.B) {
|
||||
tmpDir := b.TempDir()
|
||||
|
||||
// Create test artifact structure
|
||||
createTestArtifacts(b, tmpDir, 100)
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := worker.ScanArtifacts(tmpDir)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkArtifactScanNative profiles C++ platform-optimized traversal
|
||||
// Uses: fts on BSD, getdents64 on Linux, getattrlistbulk on macOS
|
||||
func BenchmarkArtifactScanNative(b *testing.B) {
|
||||
tmpDir := b.TempDir()
|
||||
|
||||
// Create test artifact structure
|
||||
createTestArtifacts(b, tmpDir, 100)
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := worker.ScanArtifactsNative(tmpDir)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkArtifactScanLarge tests with many files
|
||||
func BenchmarkArtifactScanLarge(b *testing.B) {
|
||||
tmpDir := b.TempDir()
|
||||
|
||||
// Create 1000 test files
|
||||
createTestArtifacts(b, tmpDir, 1000)
|
||||
|
||||
b.Run("Go", func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := worker.ScanArtifacts(tmpDir)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("Native", func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := worker.ScanArtifactsNative(tmpDir)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// createTestArtifacts creates a directory structure with test files
|
||||
func createTestArtifacts(b testing.TB, root string, count int) {
|
||||
b.Helper()
|
||||
|
||||
// Create nested directories
|
||||
dirs := []string{
|
||||
"outputs",
|
||||
"outputs/models",
|
||||
"outputs/checkpoints",
|
||||
"logs",
|
||||
"data",
|
||||
}
|
||||
|
||||
for _, dir := range dirs {
|
||||
if err := os.MkdirAll(filepath.Join(root, dir), 0750); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Create test files
|
||||
for i := 0; i < count; i++ {
|
||||
var path string
|
||||
switch i % 5 {
|
||||
case 0:
|
||||
path = filepath.Join(root, "outputs", "model_"+string(rune('0'+i%10))+".pt")
|
||||
case 1:
|
||||
path = filepath.Join(root, "outputs", "models", "checkpoint_"+string(rune('0'+i%10))+".ckpt")
|
||||
case 2:
|
||||
path = filepath.Join(root, "outputs", "checkpoints", "epoch_"+string(rune('0'+i%10))+".pt")
|
||||
case 3:
|
||||
path = filepath.Join(root, "logs", "train_"+string(rune('0'+i%10))+".log")
|
||||
case 4:
|
||||
path = filepath.Join(root, "data", "batch_"+string(rune('0'+i%10))+".npy")
|
||||
}
|
||||
|
||||
data := make([]byte, 1024*(i%10+1)) // Varying sizes 1KB-10KB
|
||||
for j := range data {
|
||||
data[j] = byte(i + j%256)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(path, data, 0640); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Create files that should be excluded
|
||||
os.WriteFile(filepath.Join(root, "run_manifest.json"), []byte("{}"), 0640)
|
||||
os.MkdirAll(filepath.Join(root, "code"), 0750)
|
||||
os.WriteFile(filepath.Join(root, "code", "script.py"), []byte("# test"), 0640)
|
||||
}
|
||||
175
tests/benchmarks/config_parsing_bench_test.go
Normal file
175
tests/benchmarks/config_parsing_bench_test.go
Normal file
|
|
@ -0,0 +1,175 @@
|
|||
package benchmarks
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// Sample server config YAML for benchmarking
|
||||
const sampleServerConfig = `
|
||||
base_path: /api/v1
|
||||
data_dir: /data/fetchml
|
||||
auth:
|
||||
type: jwt
|
||||
secret: "super-secret-key-for-benchmarking-only"
|
||||
token_ttl: 3600
|
||||
server:
|
||||
address: :8080
|
||||
tls:
|
||||
enabled: true
|
||||
cert_file: /etc/ssl/certs/server.crt
|
||||
key_file: /etc/ssl/private/server.key
|
||||
security:
|
||||
max_request_size: 1048576
|
||||
rate_limit: 100
|
||||
cors_origins:
|
||||
- https://app.fetchml.com
|
||||
- https://admin.fetchml.com
|
||||
queue:
|
||||
backend: redis
|
||||
sqlite_path: /data/queue.db
|
||||
filesystem_path: /data/queue
|
||||
fallback_to_filesystem: true
|
||||
redis:
|
||||
addr: localhost:6379
|
||||
password: ""
|
||||
db: 0
|
||||
pool_size: 50
|
||||
database:
|
||||
driver: postgres
|
||||
dsn: postgres://user:pass@localhost/fetchml?sslmode=disable
|
||||
max_connections: 100
|
||||
logging:
|
||||
level: info
|
||||
format: json
|
||||
output: stdout
|
||||
resources:
|
||||
max_cpu_per_task: 8
|
||||
max_memory_per_task: 32
|
||||
max_gpu_per_task: 1
|
||||
monitoring:
|
||||
enabled: true
|
||||
prometheus_port: 9090
|
||||
`
|
||||
|
||||
type BenchmarkServerConfig struct {
|
||||
BasePath string `yaml:"base_path"`
|
||||
DataDir string `yaml:"data_dir"`
|
||||
Auth BenchmarkAuthConfig `yaml:"auth"`
|
||||
Server BenchmarkServerSection `yaml:"server"`
|
||||
Security BenchmarkSecurityConfig `yaml:"security"`
|
||||
Queue BenchmarkQueueConfig `yaml:"queue"`
|
||||
Redis BenchmarkRedisConfig `yaml:"redis"`
|
||||
Database BenchmarkDatabaseConfig `yaml:"database"`
|
||||
Logging BenchmarkLoggingConfig `yaml:"logging"`
|
||||
Resources BenchmarkResourceConfig `yaml:"resources"`
|
||||
Monitoring BenchmarkMonitoringConfig `yaml:"monitoring"`
|
||||
}
|
||||
|
||||
type BenchmarkAuthConfig struct {
|
||||
Type string `yaml:"type"`
|
||||
Secret string `yaml:"secret"`
|
||||
TokenTTL int `yaml:"token_ttl"`
|
||||
}
|
||||
|
||||
type BenchmarkServerSection struct {
|
||||
Address string `yaml:"address"`
|
||||
TLS BenchmarkTLSConfig `yaml:"tls"`
|
||||
}
|
||||
|
||||
type BenchmarkTLSConfig struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
CertFile string `yaml:"cert_file"`
|
||||
KeyFile string `yaml:"key_file"`
|
||||
}
|
||||
|
||||
type BenchmarkSecurityConfig struct {
|
||||
MaxRequestSize int `yaml:"max_request_size"`
|
||||
RateLimit int `yaml:"rate_limit"`
|
||||
CORSOrigins []string `yaml:"cors_origins"`
|
||||
}
|
||||
|
||||
type BenchmarkQueueConfig struct {
|
||||
Backend string `yaml:"backend"`
|
||||
SQLitePath string `yaml:"sqlite_path"`
|
||||
FilesystemPath string `yaml:"filesystem_path"`
|
||||
FallbackToFilesystem bool `yaml:"fallback_to_filesystem"`
|
||||
}
|
||||
|
||||
type BenchmarkRedisConfig struct {
|
||||
Addr string `yaml:"addr"`
|
||||
Password string `yaml:"password"`
|
||||
DB int `yaml:"db"`
|
||||
PoolSize int `yaml:"pool_size"`
|
||||
}
|
||||
|
||||
type BenchmarkDatabaseConfig struct {
|
||||
Driver string `yaml:"driver"`
|
||||
DSN string `yaml:"dsn"`
|
||||
MaxConnections int `yaml:"max_connections"`
|
||||
}
|
||||
|
||||
type BenchmarkLoggingConfig struct {
|
||||
Level string `yaml:"level"`
|
||||
Format string `yaml:"format"`
|
||||
Output string `yaml:"output"`
|
||||
}
|
||||
|
||||
type BenchmarkResourceConfig struct {
|
||||
MaxCPUPerTask int `yaml:"max_cpu_per_task"`
|
||||
MaxMemoryPerTask int `yaml:"max_memory_per_task"`
|
||||
MaxGPUPerTask int `yaml:"max_gpu_per_task"`
|
||||
}
|
||||
|
||||
type BenchmarkMonitoringConfig struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
PrometheusPort int `yaml:"prometheus_port"`
|
||||
}
|
||||
|
||||
// BenchmarkConfigYAMLUnmarshal profiles YAML config parsing
|
||||
// Tier 3 C++ candidate: Fast binary config format
|
||||
func BenchmarkConfigYAMLUnmarshal(b *testing.B) {
|
||||
data := []byte(sampleServerConfig)
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
var cfg BenchmarkServerConfig
|
||||
err := yaml.Unmarshal(data, &cfg)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkConfigYAMLUnmarshalLarge profiles large config parsing
|
||||
func BenchmarkConfigYAMLUnmarshalLarge(b *testing.B) {
|
||||
// Create larger config with more nested data
|
||||
largeConfig := sampleServerConfig + `
|
||||
extra_section:
|
||||
items:
|
||||
- id: item1
|
||||
name: "First Item"
|
||||
value: 100
|
||||
- id: item2
|
||||
name: "Second Item"
|
||||
value: 200
|
||||
- id: item3
|
||||
name: "Third Item"
|
||||
value: 300
|
||||
`
|
||||
data := []byte(largeConfig)
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
var cfg BenchmarkServerConfig
|
||||
err := yaml.Unmarshal(data, &cfg)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
185
tests/benchmarks/json_serialization_bench_test.go
Normal file
185
tests/benchmarks/json_serialization_bench_test.go
Normal file
|
|
@ -0,0 +1,185 @@
|
|||
package benchmarks
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
)
|
||||
|
||||
// BenchmarkTaskJSONMarshal profiles Task serialization to JSON
|
||||
// Tier 3 C++ candidate: Zero-copy binary serialization
|
||||
func BenchmarkTaskJSONMarshal(b *testing.B) {
|
||||
task := createTestTask()
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := json.Marshal(task)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkTaskJSONUnmarshal profiles Task deserialization from JSON
|
||||
// Tier 3 C++ candidate: Memory-mapped binary format
|
||||
func BenchmarkTaskJSONUnmarshal(b *testing.B) {
|
||||
task := createTestTask()
|
||||
data, err := json.Marshal(task)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
var t queue.Task
|
||||
err := json.Unmarshal(data, &t)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkTaskJSONRoundTrip profiles full serialize/deserialize cycle
|
||||
func BenchmarkTaskJSONRoundTrip(b *testing.B) {
|
||||
task := createTestTask()
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
data, err := json.Marshal(task)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
var t queue.Task
|
||||
err = json.Unmarshal(data, &t)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkPrewarmStateJSONMarshal profiles PrewarmState serialization
|
||||
// Used in worker prewarm coordination
|
||||
func BenchmarkPrewarmStateJSONMarshal(b *testing.B) {
|
||||
state := queue.PrewarmState{
|
||||
WorkerID: "worker-12345",
|
||||
TaskID: "task-67890",
|
||||
SnapshotID: "snap-abc123",
|
||||
StartedAt: time.Now().UTC().Format(time.RFC3339),
|
||||
UpdatedAt: time.Now().UTC().Format(time.RFC3339),
|
||||
Phase: "running",
|
||||
EnvImage: "fetchml/python:3.11",
|
||||
DatasetCnt: 3,
|
||||
EnvHit: 5,
|
||||
EnvMiss: 1,
|
||||
EnvBuilt: 2,
|
||||
EnvTimeNs: 1500000000,
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := json.Marshal(state)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkPrewarmStateJSONUnmarshal profiles PrewarmState deserialization
|
||||
func BenchmarkPrewarmStateJSONUnmarshal(b *testing.B) {
|
||||
state := queue.PrewarmState{
|
||||
WorkerID: "worker-12345",
|
||||
TaskID: "task-67890",
|
||||
SnapshotID: "snap-abc123",
|
||||
StartedAt: time.Now().UTC().Format(time.RFC3339),
|
||||
UpdatedAt: time.Now().UTC().Format(time.RFC3339),
|
||||
Phase: "running",
|
||||
EnvImage: "fetchml/python:3.11",
|
||||
DatasetCnt: 3,
|
||||
EnvHit: 5,
|
||||
EnvMiss: 1,
|
||||
EnvBuilt: 2,
|
||||
EnvTimeNs: 1500000000,
|
||||
}
|
||||
data, err := json.Marshal(state)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
var s queue.PrewarmState
|
||||
err := json.Unmarshal(data, &s)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkTaskBatchJSONMarshal profiles batch task serialization
|
||||
// Critical for queue index operations with many tasks
|
||||
func BenchmarkTaskBatchJSONMarshal(b *testing.B) {
|
||||
tasks := make([]queue.Task, 100)
|
||||
for i := 0; i < 100; i++ {
|
||||
tasks[i] = createTestTaskWithID(i)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, task := range tasks {
|
||||
_, err := json.Marshal(task)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Helper to create a test task
|
||||
func createTestTask() queue.Task {
|
||||
return createTestTaskWithID(0)
|
||||
}
|
||||
|
||||
func createTestTaskWithID(id int) queue.Task {
|
||||
now := time.Now()
|
||||
return queue.Task{
|
||||
ID: fmt.Sprintf("task-550e8400-e29b-41d4-a716-44665544%03d", id),
|
||||
JobName: "ml-training-job",
|
||||
Args: "--epochs=100 --batch_size=32 --learning_rate=0.001",
|
||||
Status: "queued",
|
||||
Priority: 1000,
|
||||
CreatedAt: now,
|
||||
DatasetSpecs: []queue.DatasetSpec{
|
||||
{Name: "training-data", Version: "v1.2.3", Checksum: "sha256:abc123", URI: "s3://bucket/data/train"},
|
||||
{Name: "validation-data", Version: "v1.0.0", Checksum: "sha256:def456", URI: "s3://bucket/data/val"},
|
||||
},
|
||||
Datasets: []string{"dataset1", "dataset2"},
|
||||
Metadata: map[string]string{
|
||||
"experiment_id": "exp-12345",
|
||||
"user_email": "user@example.com",
|
||||
"model_type": "transformer",
|
||||
},
|
||||
CPU: 4,
|
||||
MemoryGB: 16,
|
||||
GPU: 1,
|
||||
GPUMemory: "24GB",
|
||||
UserID: "user-550e8400-e29b-41d4-a716",
|
||||
Username: "ml_researcher",
|
||||
CreatedBy: "admin",
|
||||
MaxRetries: 3,
|
||||
}
|
||||
}
|
||||
278
tests/benchmarks/jupyter_service_bench_test.go
Normal file
278
tests/benchmarks/jupyter_service_bench_test.go
Normal file
|
|
@ -0,0 +1,278 @@
|
|||
package benchmarks
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// BenchmarkResolveWorkspacePath profiles path canonicalization hot path
|
||||
// Tier 2 C++ candidate: SIMD string operations for path validation
|
||||
func BenchmarkResolveWorkspacePath(b *testing.B) {
|
||||
testPaths := []string{
|
||||
"my-workspace",
|
||||
"./relative/path",
|
||||
"/absolute/path/to/workspace",
|
||||
" trimmed-path ",
|
||||
"deep/nested/workspace/name",
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
path := testPaths[i%len(testPaths)]
|
||||
// Simulate the resolveWorkspacePath logic inline
|
||||
ws := strings.TrimSpace(path)
|
||||
clean := filepath.Clean(ws)
|
||||
if !filepath.IsAbs(clean) {
|
||||
clean = filepath.Join("/data/active/workspaces", clean)
|
||||
}
|
||||
_ = clean
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkStringTrimSpace profiles strings.TrimSpace usage
|
||||
// Heavy usage in service_manager.go (55+ calls)
|
||||
func BenchmarkStringTrimSpace(b *testing.B) {
|
||||
testInputs := []string{
|
||||
" package-name ",
|
||||
"\t\n trimmed \r\n",
|
||||
"normal",
|
||||
" ",
|
||||
"",
|
||||
"pkg1==1.0.0",
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = strings.TrimSpace(testInputs[i%len(testInputs)])
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkStringSplit profiles strings.Split for package parsing
|
||||
// Used in parsePipList, parseCondaList, splitPackageList
|
||||
func BenchmarkStringSplit(b *testing.B) {
|
||||
testInputs := []string{
|
||||
"numpy,pandas,scikit-learn,tensorflow",
|
||||
"package==1.0.0",
|
||||
"single",
|
||||
"",
|
||||
"a,b,c,d,e,f,g,h,i,j",
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = strings.Split(testInputs[i%len(testInputs)], ",")
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkStringHasPrefix profiles strings.HasPrefix
|
||||
// Used for path traversal detection and comment filtering
|
||||
func BenchmarkStringHasPrefix(b *testing.B) {
|
||||
testInputs := []string{
|
||||
"../traversal",
|
||||
"normal/path",
|
||||
"# comment",
|
||||
"package==1.0",
|
||||
"..",
|
||||
}
|
||||
prefixes := []string{"..", "#", "package"}
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = strings.HasPrefix(testInputs[i%len(testInputs)], prefixes[i%len(prefixes)])
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkStringEqualFold profiles case-insensitive comparison
|
||||
// Used for blocked package matching
|
||||
func BenchmarkStringEqualFold(b *testing.B) {
|
||||
testPairs := []struct {
|
||||
a, b string
|
||||
}{
|
||||
{"numpy", "NUMPY"},
|
||||
{"requests", "requests"},
|
||||
{"TensorFlow", "tensorflow"},
|
||||
{"scikit-learn", "SCIKIT-LEARN"},
|
||||
{"off", "OFF"},
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
pair := testPairs[i%len(testPairs)]
|
||||
_ = strings.EqualFold(pair.a, pair.b)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkFilepathClean profiles path canonicalization
|
||||
// Used in resolveWorkspacePath and other path operations
|
||||
func BenchmarkFilepathClean(b *testing.B) {
|
||||
testPaths := []string{
|
||||
"./relative/../path",
|
||||
"/absolute//double//slashes",
|
||||
"../../../etc/passwd",
|
||||
"normal/path",
|
||||
".",
|
||||
"..",
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = filepath.Clean(testPaths[i%len(testPaths)])
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkFilepathJoin profiles path joining
|
||||
// Used extensively for building workspace, trash, state paths
|
||||
func BenchmarkFilepathJoin(b *testing.B) {
|
||||
components := [][]string{
|
||||
{"/data", "active", "workspaces"},
|
||||
{"/tmp", "trash", "jupyter"},
|
||||
{"state", "fetch_ml_jupyter_workspaces.json"},
|
||||
{"workspace", "subdir", "file.txt"},
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
parts := components[i%len(components)]
|
||||
_ = filepath.Join(parts...)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkStrconvAtoi profiles string to int conversion
|
||||
// Used for port parsing, resource limits
|
||||
func BenchmarkStrconvAtoi(b *testing.B) {
|
||||
testInputs := []string{
|
||||
"8888",
|
||||
"0",
|
||||
"65535",
|
||||
"-1",
|
||||
"999999",
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = strconv.Atoi(testInputs[i%len(testInputs)])
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkTimeFormat profiles timestamp formatting
|
||||
// Used for trash destination naming
|
||||
func BenchmarkTimeFormat(b *testing.B) {
|
||||
now := time.Now().UTC()
|
||||
format := "20060102_150405"
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = now.Format(format)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkSprintfConcat profiles string concatenation
|
||||
// Used for building destination names
|
||||
func BenchmarkSprintfConcat(b *testing.B) {
|
||||
names := []string{"workspace1", "my-project", "test-ws"}
|
||||
timestamps := []string{"20240115_143022", "20240212_183045", "20240301_120000"}
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = names[i%len(names)] + "_" + timestamps[i%len(timestamps)]
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkPackageListParsing profiles full package list parsing pipeline
|
||||
// Combines Split, TrimSpace, and filtering - mirrors splitPackageList
|
||||
func BenchmarkPackageListParsing(b *testing.B) {
|
||||
testInputs := []string{
|
||||
"numpy, pandas, scikit-learn, tensorflow",
|
||||
" requests , urllib3 , certifi ",
|
||||
"single-package",
|
||||
"",
|
||||
"a, b, c, d, e, f, g",
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
input := testInputs[i%len(testInputs)]
|
||||
parts := strings.Split(input, ",")
|
||||
out := make([]string, 0, len(parts))
|
||||
for _, p := range parts {
|
||||
p = strings.TrimSpace(p)
|
||||
if p != "" {
|
||||
out = append(out, p)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkEnvLookup profiles environment variable lookups
|
||||
// Used for FETCHML_JUPYTER_* configuration
|
||||
func BenchmarkEnvLookup(b *testing.B) {
|
||||
// Pre-set test env vars
|
||||
os.Setenv("TEST_JUPYTER_STATE_DIR", "/tmp/jupyter-state")
|
||||
os.Setenv("TEST_JUPYTER_WORKSPACE_BASE", "/tmp/workspaces")
|
||||
|
||||
keys := []string{
|
||||
"TEST_JUPYTER_STATE_DIR",
|
||||
"TEST_JUPYTER_WORKSPACE_BASE",
|
||||
"NONEXISTENT_VAR",
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = os.LookupEnv(keys[i%len(keys)])
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkCombinedJupyterHotPath profiles typical service manager operation
|
||||
// Combines multiple string/path operations as they occur in real usage
|
||||
func BenchmarkCombinedJupyterHotPath(b *testing.B) {
|
||||
testWorkspaces := []string{
|
||||
" my-project ",
|
||||
"./relative-ws",
|
||||
"/absolute/path/to/workspace",
|
||||
"deep/nested/name",
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Simulate resolveWorkspacePath + path building
|
||||
ws := strings.TrimSpace(testWorkspaces[i%len(testWorkspaces)])
|
||||
clean := filepath.Clean(ws)
|
||||
if !filepath.IsAbs(clean) {
|
||||
clean = filepath.Join("/data/active/workspaces", clean)
|
||||
}
|
||||
// Simulate trash path building
|
||||
ts := time.Now().UTC().Format("20060102_150405")
|
||||
destName := ws + "_" + ts
|
||||
_ = filepath.Join("/tmp/trash/jupyter", destName)
|
||||
}
|
||||
}
|
||||
84
tests/benchmarks/log_sanitize_bench_test.go
Normal file
84
tests/benchmarks/log_sanitize_bench_test.go
Normal file
|
|
@ -0,0 +1,84 @@
|
|||
package benchmarks
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/logging"
|
||||
)
|
||||
|
||||
// BenchmarkLogSanitizeMessage profiles log message sanitization.
|
||||
// This is a Tier 1 C++ candidate because:
|
||||
// - Regex matching is CPU-intensive
|
||||
// - High volume log pipelines process thousands of messages/sec
|
||||
// - C++ can use Hyperscan/RE2 for parallel regex matching
|
||||
// Expected speedup: 3-5x for high-volume logging
|
||||
func BenchmarkLogSanitizeMessage(b *testing.B) {
|
||||
// Test messages with various sensitive data patterns
|
||||
messages := []string{
|
||||
"User login successful with api_key=abc123def45678901234567890abcdef",
|
||||
"JWT token: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.dozjgNryP4J3jVmNHl0w5N_XgL0n3I9PlFUP0C12LVE",
|
||||
"Redis connection: redis://:secretpassword123@localhost:6379/0",
|
||||
"User admin password=supersecret123 trying to access resource",
|
||||
"Normal log message without any sensitive data to process",
|
||||
"API call with key=fedcba9876543210fedcba9876543210 and secret=shh123",
|
||||
"Connection string: redis://:another_secret@redis.example.com:6380",
|
||||
"Authentication token: eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.test.signature",
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
msg := messages[i%len(messages)]
|
||||
_ = logging.SanitizeLogMessage(msg)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkLogSanitizeArgs profiles structured log argument sanitization.
|
||||
// This processes key-value pairs looking for sensitive field names.
|
||||
func BenchmarkLogSanitizeArgs(b *testing.B) {
|
||||
// Simulate typical structured log arguments
|
||||
args := []any{
|
||||
"user_id", "user123",
|
||||
"password", "secret123",
|
||||
"api_key", "abcdef1234567890",
|
||||
"action", "login",
|
||||
"secret_token", "eyJhbGci...",
|
||||
"request_id", "req-12345",
|
||||
"database_url", "redis://:password@localhost:6379",
|
||||
"timestamp", "2024-01-01T00:00:00Z",
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = logging.SanitizeArgs(args)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkLogSanitizeHighVolume simulates high-throughput logging scenario
|
||||
// with many messages per second (e.g., 10K+ messages/sec).
|
||||
func BenchmarkLogSanitizeHighVolume(b *testing.B) {
|
||||
// Mix of message types
|
||||
testMessages := []string{
|
||||
"API request: POST /api/v1/jobs with api_key=abcdef1234567890abcdef1234567890",
|
||||
"User user123 authenticated with token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.payload.signature",
|
||||
"Database connection established to redis://:redact_me@localhost:6379",
|
||||
"Job execution started for job_name=test_job",
|
||||
"Error processing request: password=wrong_secret provided",
|
||||
"Metrics: cpu=45%, memory=2.5GB, gpu=0%",
|
||||
"Config loaded: secret_key=hidden_value123",
|
||||
"Webhook received with authorization=Bearer eyJ0eXAiOiJKV1Qi.test.sig",
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Simulate batch processing of multiple messages
|
||||
for _, msg := range testMessages {
|
||||
_ = logging.SanitizeLogMessage(msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
40
tests/benchmarks/native_queue_basic_test.go
Normal file
40
tests/benchmarks/native_queue_basic_test.go
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
package benchmarks
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
)
|
||||
|
||||
// BenchmarkNativeQueueBasic tests basic native queue operations
|
||||
func BenchmarkNativeQueueBasic(b *testing.B) {
|
||||
tmpDir := b.TempDir()
|
||||
|
||||
// Only run if native libs available
|
||||
if !queue.UseNativeQueue {
|
||||
b.Skip("Native queue not enabled (set FETCHML_NATIVE_LIBS=1)")
|
||||
}
|
||||
|
||||
q, err := queue.NewNativeQueue(tmpDir)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer q.Close()
|
||||
|
||||
// Test single add
|
||||
task := &queue.Task{
|
||||
ID: "test-1",
|
||||
JobName: "test-job",
|
||||
Priority: 100,
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
task.ID = "test-" + string(rune('0'+i%10))
|
||||
if err := q.AddTask(task); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
108
tests/benchmarks/native_queue_bench_test.go
Normal file
108
tests/benchmarks/native_queue_bench_test.go
Normal file
|
|
@ -0,0 +1,108 @@
|
|||
package benchmarks
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
)
|
||||
|
||||
// BenchmarkNativeQueueRebuildIndex profiles the native binary queue index.
|
||||
// Tier 1 C++ candidate: binary format vs JSON
|
||||
// Expected: 5x speedup, 99% allocation reduction
|
||||
func BenchmarkNativeQueueRebuildIndex(b *testing.B) {
|
||||
tmpDir := b.TempDir()
|
||||
q, err := queue.NewNativeQueue(tmpDir)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer q.Close()
|
||||
|
||||
// Seed with tasks
|
||||
for i := 0; i < 100; i++ {
|
||||
task := &queue.Task{
|
||||
ID: "task-" + string(rune('0'+i/10)) + string(rune('0'+i%10)),
|
||||
JobName: "job-" + string(rune('0'+i/10)),
|
||||
Priority: int64(100 - i),
|
||||
}
|
||||
if err := q.AddTask(task); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
// Benchmark just the add (native uses binary index, no JSON rebuild)
|
||||
for i := 0; i < b.N; i++ {
|
||||
task := &queue.Task{
|
||||
ID: "bench-task-" + string(rune('0'+i%10)),
|
||||
JobName: "bench-job",
|
||||
Priority: int64(i),
|
||||
}
|
||||
if err := q.AddTask(task); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkNativeQueueClaimNext profiles task claiming from binary heap
|
||||
func BenchmarkNativeQueueClaimNext(b *testing.B) {
|
||||
tmpDir := b.TempDir()
|
||||
q, err := queue.NewNativeQueue(tmpDir)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer q.Close()
|
||||
|
||||
// Seed with tasks
|
||||
for i := 0; i < 100; i++ {
|
||||
task := &queue.Task{
|
||||
ID: "task-" + string(rune('0'+i/10)) + string(rune('0'+i%10)),
|
||||
JobName: "job-" + string(rune('0'+i/10)),
|
||||
Priority: int64(100 - i),
|
||||
}
|
||||
if err := q.AddTask(task); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Binary heap pop - no JSON parsing
|
||||
_, _ = q.PeekNextTask()
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkNativeQueueGetAllTasks profiles full task scan from binary index
|
||||
func BenchmarkNativeQueueGetAllTasks(b *testing.B) {
|
||||
tmpDir := b.TempDir()
|
||||
q, err := queue.NewNativeQueue(tmpDir)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer q.Close()
|
||||
|
||||
// Seed with tasks
|
||||
for i := 0; i < 100; i++ {
|
||||
task := &queue.Task{
|
||||
ID: "task-" + string(rune('0'+i/10)) + string(rune('0'+i%10)),
|
||||
JobName: "job-" + string(rune('0'+i/10)),
|
||||
Priority: int64(100 - i),
|
||||
}
|
||||
if err := q.AddTask(task); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := q.GetAllTasks()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
189
tests/benchmarks/streaming_io_bench_test.go
Normal file
189
tests/benchmarks/streaming_io_bench_test.go
Normal file
|
|
@ -0,0 +1,189 @@
|
|||
package benchmarks
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/worker"
|
||||
)
|
||||
|
||||
// BenchmarkExtractTarGzGo profiles Go sequential tar.gz extraction
|
||||
func BenchmarkExtractTarGzGo(b *testing.B) {
|
||||
tmpDir := b.TempDir()
|
||||
archivePath := createStreamingTestArchive(b, tmpDir, 100, 1024) // 100 files, 1KB each
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
dstDir := filepath.Join(tmpDir, "extract_go_"+string(rune('0'+i%10)))
|
||||
if err := os.MkdirAll(dstDir, 0750); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
if err := worker.ExtractTarGz(archivePath, dstDir); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkExtractTarGzNative profiles C++ parallel decompression
|
||||
// Uses: mmap + thread pool + O_DIRECT for large files
|
||||
func BenchmarkExtractTarGzNative(b *testing.B) {
|
||||
tmpDir := b.TempDir()
|
||||
archivePath := createStreamingTestArchive(b, tmpDir, 100, 1024)
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
dstDir := filepath.Join(tmpDir, "extract_native_"+string(rune('0'+i%10)))
|
||||
if err := os.MkdirAll(dstDir, 0750); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
if err := worker.ExtractTarGzNative(archivePath, dstDir); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkExtractTarGzSizes tests different archive sizes
|
||||
func BenchmarkExtractTarGzSizes(b *testing.B) {
|
||||
tmpDir := b.TempDir()
|
||||
|
||||
// Small: 10 files, 1KB each (~10KB compressed)
|
||||
b.Run("Small", func(b *testing.B) {
|
||||
archivePath := createStreamingTestArchive(b, tmpDir, 10, 1024)
|
||||
benchmarkBoth(b, archivePath, tmpDir)
|
||||
})
|
||||
|
||||
// Medium: 100 files, 10KB each (~1MB compressed)
|
||||
b.Run("Medium", func(b *testing.B) {
|
||||
archivePath := createStreamingTestArchive(b, tmpDir, 100, 10240)
|
||||
benchmarkBoth(b, archivePath, tmpDir)
|
||||
})
|
||||
|
||||
// Large: 50 files, 100KB each (~5MB compressed)
|
||||
b.Run("Large", func(b *testing.B) {
|
||||
archivePath := createStreamingTestArchive(b, tmpDir, 50, 102400)
|
||||
benchmarkBoth(b, archivePath, tmpDir)
|
||||
})
|
||||
}
|
||||
|
||||
func benchmarkBoth(b *testing.B, archivePath, tmpDir string) {
|
||||
b.Run("Go", func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
dstDir := filepath.Join(tmpDir, "go_"+string(rune('0'+i%10)))
|
||||
if err := os.MkdirAll(dstDir, 0750); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
if err := worker.ExtractTarGz(archivePath, dstDir); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("Native", func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
dstDir := filepath.Join(tmpDir, "native_"+string(rune('0'+i%10)))
|
||||
if err := os.MkdirAll(dstDir, 0750); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
if err := worker.ExtractTarGzNative(archivePath, dstDir); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// createStreamingTestArchive creates a tar.gz archive with test files for streaming benchmarks
|
||||
func createStreamingTestArchive(b testing.TB, tmpDir string, numFiles, fileSize int) string {
|
||||
b.Helper()
|
||||
|
||||
// Create data directory with test files
|
||||
dataDir := filepath.Join(tmpDir, "data_"+string(rune('0'+numFiles/10)))
|
||||
if err := os.MkdirAll(dataDir, 0750); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
for i := 0; i < numFiles; i++ {
|
||||
subdir := filepath.Join(dataDir, "subdir_"+string(rune('0'+i%5)))
|
||||
if err := os.MkdirAll(subdir, 0750); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
filename := filepath.Join(subdir, "file_"+string(rune('0'+i/10))+".bin")
|
||||
data := make([]byte, fileSize)
|
||||
for j := range data {
|
||||
data[j] = byte((i + j) % 256)
|
||||
}
|
||||
if err := os.WriteFile(filename, data, 0640); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Create tar.gz archive
|
||||
archivePath := filepath.Join(tmpDir, "test_"+string(rune('0'+numFiles/10))+".tar.gz")
|
||||
if err := createTarGzFromDir(dataDir, archivePath); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
return archivePath
|
||||
}
|
||||
|
||||
// createTarGzFromDir creates a tar.gz archive from a directory
|
||||
func createTarGzFromDir(srcDir, dstPath string) error {
|
||||
var buf bytes.Buffer
|
||||
gw := gzip.NewWriter(&buf)
|
||||
tw := tar.NewWriter(gw)
|
||||
|
||||
err := filepath.Walk(srcDir, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rel, err := filepath.Rel(srcDir, path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
hdr, err := tar.FileInfoHeader(info, "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
hdr.Name = rel
|
||||
|
||||
if err := tw.WriteHeader(hdr); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !info.IsDir() {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tw.Write(data); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tw.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := gw.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(dstPath, buf.Bytes(), 0640)
|
||||
}
|
||||
|
|
@ -14,6 +14,14 @@ import (
|
|||
|
||||
// ChaosTestSuite tests system resilience under various failure conditions
|
||||
func TestChaosTestSuite(t *testing.T) {
|
||||
// Check Redis availability at suite level and warn if not available
|
||||
quickRedis := redis.NewClient(&redis.Options{Addr: "localhost:6379", DB: 6})
|
||||
if err := quickRedis.Ping(context.Background()).Err(); err != nil {
|
||||
t.Logf("WARNING: Redis not available at localhost:6379 - chaos tests will be skipped")
|
||||
t.Logf(" To run these tests, start Redis: redis-server --port 6379")
|
||||
}
|
||||
quickRedis.Close()
|
||||
|
||||
// Tests that intentionally close/corrupt connections get their own resources
|
||||
// to prevent cascading failures to subsequent subtests
|
||||
|
||||
|
|
@ -69,6 +77,7 @@ func TestChaosTestSuite(t *testing.T) {
|
|||
rdb := setupChaosRedis(t)
|
||||
if rdb == nil {
|
||||
t.Skip("Redis not available for chaos tests")
|
||||
return
|
||||
}
|
||||
defer func() { _ = rdb.Close() }()
|
||||
|
||||
|
|
@ -473,6 +482,7 @@ func setupChaosRedis(t *testing.T) *redis.Client {
|
|||
|
||||
ctx := context.Background()
|
||||
if err := rdb.Ping(ctx).Err(); err != nil {
|
||||
t.Logf("Skipping chaos test - Redis not available: %v", err)
|
||||
t.Skipf("Redis not available for chaos tests: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
50
tests/e2e/docker-compose.logs-debug.yml
Normal file
50
tests/e2e/docker-compose.logs-debug.yml
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
---
|
||||
# Docker Compose configuration for logs and debug E2E tests
|
||||
# Simplified version using pre-built golang image with source mount
|
||||
|
||||
services:
|
||||
redis:
|
||||
image: redis:7-alpine
|
||||
ports:
|
||||
- "6380:6379"
|
||||
healthcheck:
|
||||
test: ["CMD", "redis-cli", "ping"]
|
||||
interval: 5s
|
||||
timeout: 3s
|
||||
retries: 5
|
||||
|
||||
api-server:
|
||||
image: golang:1.25-bookworm
|
||||
working_dir: /app
|
||||
command: >
|
||||
sh -c "
|
||||
go build -o api-server ./cmd/api-server/main.go &&
|
||||
./api-server --config /app/configs/api/dev.yaml
|
||||
"
|
||||
ports:
|
||||
- "9102:9101"
|
||||
environment:
|
||||
- LOG_LEVEL=debug
|
||||
- REDIS_ADDR=redis:6379
|
||||
- FETCHML_NATIVE_LIBS=0
|
||||
volumes:
|
||||
- ../../:/app
|
||||
- api-logs:/logs
|
||||
- api-experiments:/data/experiments
|
||||
- api-active:/data/active
|
||||
- go-mod-cache:/go/pkg/mod
|
||||
depends_on:
|
||||
redis:
|
||||
condition: service_healthy
|
||||
healthcheck:
|
||||
test: ["CMD", "wget", "-q", "--spider", "http://localhost:9101/health"]
|
||||
interval: 5s
|
||||
timeout: 3s
|
||||
retries: 10
|
||||
start_period: 30s
|
||||
|
||||
volumes:
|
||||
api-logs:
|
||||
api-experiments:
|
||||
api-active:
|
||||
go-mod-cache:
|
||||
590
tests/e2e/logs_debug_e2e_test.go
Normal file
590
tests/e2e/logs_debug_e2e_test.go
Normal file
|
|
@ -0,0 +1,590 @@
|
|||
package tests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
const (
|
||||
logsDebugComposeFile = "docker-compose.logs-debug.yml"
|
||||
apiPort = "9102" // Use docker-compose port
|
||||
apiHost = "localhost"
|
||||
)
|
||||
|
||||
// TestLogsDebugE2E tests the logs and debug WebSocket API end-to-end
|
||||
func TestLogsDebugE2E(t *testing.T) {
|
||||
if os.Getenv("FETCH_ML_E2E_LOGS_DEBUG") != "1" {
|
||||
t.Skip("Skipping LogsDebugE2E (set FETCH_ML_E2E_LOGS_DEBUG=1 to enable)")
|
||||
}
|
||||
|
||||
composeFile := filepath.Join(".", logsDebugComposeFile)
|
||||
if _, err := os.Stat(composeFile); os.IsNotExist(err) {
|
||||
t.Skipf("Docker compose file not found: %s", composeFile)
|
||||
}
|
||||
|
||||
// Ensure Docker is available
|
||||
if _, err := exec.LookPath("docker"); err != nil {
|
||||
t.Skip("Docker not found in PATH")
|
||||
}
|
||||
if _, err := exec.LookPath("docker-compose"); err != nil {
|
||||
t.Skip("docker-compose not found in PATH")
|
||||
}
|
||||
|
||||
t.Run("SetupAndLogsAPI", func(t *testing.T) {
|
||||
testLogsAPIWithDockerCompose(t, composeFile)
|
||||
})
|
||||
|
||||
t.Run("SetupAndDebugAPI", func(t *testing.T) {
|
||||
testDebugAPIWithDockerCompose(t, composeFile)
|
||||
})
|
||||
}
|
||||
|
||||
// testLogsAPIWithDockerCompose tests the logs API using Docker Compose
|
||||
func testLogsAPIWithDockerCompose(t *testing.T, composeFile string) {
|
||||
// Cleanup any existing containers
|
||||
cleanupDockerCompose(t, composeFile)
|
||||
|
||||
// Start services
|
||||
t.Log("Starting Docker Compose services for logs API test...")
|
||||
upCmd := exec.CommandContext(context.Background(),
|
||||
"docker-compose", "-f", composeFile, "-p", "logsdebug-e2e", "up", "-d", "--build")
|
||||
upOutput, err := upCmd.CombinedOutput()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to start Docker Compose: %v\nOutput: %s", err, string(upOutput))
|
||||
}
|
||||
|
||||
// Cleanup after test
|
||||
defer func() {
|
||||
cleanupDockerCompose(t, composeFile)
|
||||
}()
|
||||
|
||||
// Wait for services to be healthy
|
||||
t.Log("Waiting for services to be healthy...")
|
||||
if !waitForServicesHealthy(t, "logsdebug-e2e", 90*time.Second) {
|
||||
t.Fatal("Services failed to become healthy")
|
||||
}
|
||||
|
||||
// Add delay to ensure server is fully ready for WebSocket connections
|
||||
t.Log("Waiting additional 3 seconds for WebSocket to be ready...")
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
// Test logs API
|
||||
testTargetID := "ab"
|
||||
testGetLogsAPI(t, testTargetID)
|
||||
testStreamLogsAPI(t, testTargetID)
|
||||
}
|
||||
|
||||
// testDebugAPIWithDockerCompose tests the debug API using Docker Compose
|
||||
func testDebugAPIWithDockerCompose(t *testing.T, composeFile string) {
|
||||
// Cleanup any existing containers
|
||||
cleanupDockerCompose(t, composeFile)
|
||||
|
||||
// Start services
|
||||
t.Log("Starting Docker Compose services for debug API test...")
|
||||
upCmd := exec.CommandContext(context.Background(),
|
||||
"docker-compose", "-f", composeFile, "-p", "logsdebug-e2e", "up", "-d", "--build")
|
||||
upOutput, err := upCmd.CombinedOutput()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to start Docker Compose: %v\nOutput: %s", err, string(upOutput))
|
||||
}
|
||||
|
||||
// Cleanup after test
|
||||
defer func() {
|
||||
cleanupDockerCompose(t, composeFile)
|
||||
}()
|
||||
|
||||
// Wait for services to be healthy
|
||||
t.Log("Waiting for services to be healthy...")
|
||||
if !waitForServicesHealthy(t, "logsdebug-e2e", 90*time.Second) {
|
||||
t.Fatal("Services failed to become healthy")
|
||||
}
|
||||
|
||||
// Add delay to ensure server is fully ready for WebSocket connections
|
||||
t.Log("Waiting additional 3 seconds for WebSocket to be ready...")
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
testTargetID := "test-job-456"
|
||||
testAttachDebugAPI(t, testTargetID, "interactive")
|
||||
testAttachDebugAPI(t, testTargetID, "gdb")
|
||||
testAttachDebugAPI(t, testTargetID, "pdb")
|
||||
}
|
||||
|
||||
// connectWebSocketWithRetry attempts to connect with exponential backoff
|
||||
func connectWebSocketWithRetry(t *testing.T, wsURL string, maxRetries int) (*websocket.Conn, *http.Response, error) {
|
||||
var conn *websocket.Conn
|
||||
var resp *http.Response
|
||||
var err error
|
||||
|
||||
// Create dialer with nil headers (like existing working test)
|
||||
dialer := websocket.DefaultDialer
|
||||
dialer.HandshakeTimeout = 10 * time.Second
|
||||
|
||||
for i := 0; i < maxRetries; i++ {
|
||||
conn, resp, err = dialer.Dial(wsURL, nil)
|
||||
if err == nil {
|
||||
return conn, resp, nil
|
||||
}
|
||||
|
||||
// Log response details for debugging
|
||||
if resp != nil {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
t.Logf("WebSocket connection attempt %d/%d failed: status=%d, body=%s", i+1, maxRetries, resp.StatusCode, string(body))
|
||||
} else {
|
||||
t.Logf("WebSocket connection attempt %d/%d failed: %v", i+1, maxRetries, err)
|
||||
}
|
||||
|
||||
if i < maxRetries-1 {
|
||||
delay := time.Duration(i+1) * 500 * time.Millisecond
|
||||
t.Logf("Waiting %v before retry...", delay)
|
||||
time.Sleep(delay)
|
||||
}
|
||||
}
|
||||
|
||||
return nil, resp, err
|
||||
}
|
||||
func testGetLogsAPI(t *testing.T, targetID string) {
|
||||
t.Logf("Testing get_logs API with target: %s", targetID)
|
||||
|
||||
// First verify HTTP connectivity
|
||||
healthURL := fmt.Sprintf("http://%s:%s/health", apiHost, apiPort)
|
||||
resp, err := http.Get(healthURL)
|
||||
if err != nil {
|
||||
t.Fatalf("HTTP health check failed: %v", err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
t.Logf("HTTP health check passed: status=%d", resp.StatusCode)
|
||||
|
||||
// Test WebSocket endpoint with HTTP GET (should return 400 or upgrade required)
|
||||
wsHTTPURL := fmt.Sprintf("http://%s:%s/ws", apiHost, apiPort)
|
||||
resp, err = http.Get(wsHTTPURL)
|
||||
if err != nil {
|
||||
t.Logf("HTTP GET to /ws failed: %v", err)
|
||||
} else {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
t.Logf("HTTP GET /ws: status=%d, body=%s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
wsURL := fmt.Sprintf("ws://%s:%s/ws", apiHost, apiPort)
|
||||
|
||||
// Connect to WebSocket with retry
|
||||
conn, resp, err := connectWebSocketWithRetry(t, wsURL, 5)
|
||||
if resp != nil && resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect to WebSocket after retries: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Build binary message for get_logs
|
||||
// [opcode:1][api_key_hash:16][target_id_len:1][target_id:var]
|
||||
opcode := byte(0x20) // OpcodeGetLogs
|
||||
apiKeyHash := make([]byte, 16) // Zero-filled for testing (no auth)
|
||||
targetIDBytes := []byte(targetID)
|
||||
|
||||
message := []byte{opcode}
|
||||
message = append(message, apiKeyHash...)
|
||||
message = append(message, byte(len(targetIDBytes)))
|
||||
message = append(message, targetIDBytes...)
|
||||
|
||||
t.Logf("Sending message: opcode=0x%02x, len=%d, payload_len=%d", opcode, len(message), len(message)-1)
|
||||
|
||||
// Send message
|
||||
err = conn.WriteMessage(websocket.BinaryMessage, message)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to send get_logs message: %v", err)
|
||||
}
|
||||
|
||||
// Read response
|
||||
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||
messageType, response, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read get_logs response: %v", err)
|
||||
}
|
||||
|
||||
t.Logf("Received response type %d, length %d bytes", messageType, len(response))
|
||||
|
||||
// Parse response
|
||||
if len(response) < 1 {
|
||||
t.Fatal("Empty response received")
|
||||
}
|
||||
|
||||
// Parse the packet using the protocol
|
||||
packet, err := parseResponsePacket(response)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse response packet: %v", err)
|
||||
}
|
||||
|
||||
// Verify response is a data packet (0x04 = PacketTypeData)
|
||||
if packet.PacketType != 0x04 {
|
||||
t.Errorf("Expected data packet (0x04), got 0x%02x", packet.PacketType)
|
||||
}
|
||||
|
||||
// Parse JSON payload
|
||||
var logsResponse struct {
|
||||
TargetID string `json:"target_id"`
|
||||
Logs string `json:"logs"`
|
||||
Truncated bool `json:"truncated"`
|
||||
TotalLines int `json:"total_lines"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(packet.Payload, &logsResponse); err != nil {
|
||||
t.Fatalf("Failed to parse logs JSON response: %v", err)
|
||||
}
|
||||
|
||||
// Verify response fields
|
||||
if logsResponse.TargetID != targetID {
|
||||
t.Errorf("Expected target_id %s, got %s", targetID, logsResponse.TargetID)
|
||||
}
|
||||
|
||||
if logsResponse.Logs == "" {
|
||||
t.Error("Expected non-empty logs content")
|
||||
}
|
||||
|
||||
t.Logf("Successfully received logs for target %s (%d lines, truncated=%v)",
|
||||
logsResponse.TargetID, logsResponse.TotalLines, logsResponse.Truncated)
|
||||
}
|
||||
|
||||
// testStreamLogsAPI tests the stream_logs WebSocket endpoint
|
||||
func testStreamLogsAPI(t *testing.T, targetID string) {
|
||||
t.Logf("Testing stream_logs API with target: %s", targetID)
|
||||
|
||||
wsURL := fmt.Sprintf("ws://%s:%s/ws", apiHost, apiPort)
|
||||
|
||||
// Connect to WebSocket with retry
|
||||
conn, resp, err := connectWebSocketWithRetry(t, wsURL, 5)
|
||||
if resp != nil && resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect to WebSocket after retries: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Build binary message for stream_logs
|
||||
// [opcode:1][api_key_hash:16][target_id_len:1][target_id:var]
|
||||
opcode := byte(0x21) // OpcodeStreamLogs
|
||||
apiKeyHash := make([]byte, 16)
|
||||
targetIDBytes := []byte(targetID)
|
||||
|
||||
message := []byte{opcode}
|
||||
message = append(message, apiKeyHash...)
|
||||
message = append(message, byte(len(targetIDBytes)))
|
||||
message = append(message, targetIDBytes...)
|
||||
|
||||
// Send message
|
||||
err = conn.WriteMessage(websocket.BinaryMessage, message)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to send stream_logs message: %v", err)
|
||||
}
|
||||
|
||||
// Read response
|
||||
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||
messageType, response, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read stream_logs response: %v", err)
|
||||
}
|
||||
|
||||
t.Logf("Received stream response type %d, length %d bytes", messageType, len(response))
|
||||
|
||||
// Parse response
|
||||
packet, err := parseResponsePacket(response)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse response packet: %v", err)
|
||||
}
|
||||
|
||||
// Verify response
|
||||
if packet.PacketType != 0x04 { // PacketTypeData
|
||||
t.Errorf("Expected data packet (0x04), got 0x%02x", packet.PacketType)
|
||||
}
|
||||
|
||||
var streamResponse struct {
|
||||
TargetID string `json:"target_id"`
|
||||
Streaming bool `json:"streaming"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(packet.Payload, &streamResponse); err != nil {
|
||||
t.Fatalf("Failed to parse stream JSON response: %v", err)
|
||||
}
|
||||
|
||||
if streamResponse.TargetID != targetID {
|
||||
t.Errorf("Expected target_id %s, got %s", targetID, streamResponse.TargetID)
|
||||
}
|
||||
|
||||
if !streamResponse.Streaming {
|
||||
t.Error("Expected streaming=true in response")
|
||||
}
|
||||
|
||||
t.Logf("Successfully initiated log streaming for target %s", streamResponse.TargetID)
|
||||
}
|
||||
|
||||
// testAttachDebugAPI tests the attach_debug WebSocket endpoint
|
||||
func testAttachDebugAPI(t *testing.T, targetID, debugType string) {
|
||||
t.Logf("Testing attach_debug API with target: %s, debug_type: %s", targetID, debugType)
|
||||
|
||||
wsURL := fmt.Sprintf("ws://%s:%s/ws", apiHost, apiPort)
|
||||
|
||||
// Connect to WebSocket with retry
|
||||
conn, resp, err := connectWebSocketWithRetry(t, wsURL, 5)
|
||||
if resp != nil && resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect to WebSocket after retries: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Build binary message for attach_debug
|
||||
// [opcode:1][api_key_hash:16][target_id_len:1][target_id:var][debug_type:var]
|
||||
opcode := byte(0x22) // OpcodeAttachDebug
|
||||
apiKeyHash := make([]byte, 16)
|
||||
targetIDBytes := []byte(targetID)
|
||||
debugTypeBytes := []byte(debugType)
|
||||
|
||||
message := []byte{opcode}
|
||||
message = append(message, apiKeyHash...)
|
||||
message = append(message, byte(len(targetIDBytes)))
|
||||
message = append(message, targetIDBytes...)
|
||||
message = append(message, debugTypeBytes...)
|
||||
|
||||
// Send message
|
||||
err = conn.WriteMessage(websocket.BinaryMessage, message)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to send attach_debug message: %v", err)
|
||||
}
|
||||
|
||||
// Read response
|
||||
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||
messageType, response, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read attach_debug response: %v", err)
|
||||
}
|
||||
|
||||
t.Logf("Received debug response type %d, length %d bytes", messageType, len(response))
|
||||
|
||||
// Parse response
|
||||
packet, err := parseResponsePacket(response)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse response packet: %v", err)
|
||||
}
|
||||
|
||||
// Verify response
|
||||
if packet.PacketType != 0x04 { // PacketTypeData
|
||||
t.Errorf("Expected data packet (0x04), got 0x%02x", packet.PacketType)
|
||||
}
|
||||
|
||||
var debugResponse struct {
|
||||
TargetID string `json:"target_id"`
|
||||
DebugType string `json:"debug_type"`
|
||||
Attached bool `json:"attached"`
|
||||
Message string `json:"message"`
|
||||
Suggestion string `json:"suggestion"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(packet.Payload, &debugResponse); err != nil {
|
||||
t.Fatalf("Failed to parse debug JSON response: %v", err)
|
||||
}
|
||||
|
||||
if debugResponse.TargetID != targetID {
|
||||
t.Errorf("Expected target_id %s, got %s", targetID, debugResponse.TargetID)
|
||||
}
|
||||
|
||||
if debugResponse.DebugType != debugType {
|
||||
t.Errorf("Expected debug_type %s, got %s", debugType, debugResponse.DebugType)
|
||||
}
|
||||
|
||||
// Note: attached is false in stub implementation
|
||||
t.Logf("Debug attachment response: target=%s, type=%s, attached=%v",
|
||||
debugResponse.TargetID, debugResponse.DebugType, debugResponse.Attached)
|
||||
}
|
||||
|
||||
// ResponsePacket represents a parsed WebSocket response packet
|
||||
type ResponsePacket struct {
|
||||
PacketType byte
|
||||
Payload []byte
|
||||
}
|
||||
|
||||
// parseResponsePacket parses a binary WebSocket response packet
|
||||
func parseResponsePacket(data []byte) (*ResponsePacket, error) {
|
||||
if len(data) < 9 { // min: type(1) + timestamp(8)
|
||||
return nil, fmt.Errorf("packet too short: %d bytes", len(data))
|
||||
}
|
||||
|
||||
packetType := data[0]
|
||||
// Skip timestamp (8 bytes)
|
||||
payloadStart := 9
|
||||
|
||||
// Handle different response types
|
||||
switch packetType {
|
||||
case 0x00: // PacketTypeSuccess
|
||||
if len(data) <= payloadStart {
|
||||
return &ResponsePacket{PacketType: packetType, Payload: nil}, nil
|
||||
}
|
||||
// Parse varint length + string
|
||||
strLen, n := binary.Uvarint(data[payloadStart:])
|
||||
if n <= 0 {
|
||||
return nil, fmt.Errorf("invalid varint")
|
||||
}
|
||||
return &ResponsePacket{
|
||||
PacketType: packetType,
|
||||
Payload: data[payloadStart+n : payloadStart+n+int(strLen)],
|
||||
}, nil
|
||||
|
||||
case 0x01: // PacketTypeError
|
||||
if len(data) < payloadStart+1 {
|
||||
return nil, fmt.Errorf("error packet too short")
|
||||
}
|
||||
errorCode := data[payloadStart]
|
||||
// Skip error code and read message
|
||||
msgStart := payloadStart + 1
|
||||
if len(data) > msgStart && len(data) > msgStart+1 {
|
||||
return nil, fmt.Errorf("server error (code: 0x%02x): %s", errorCode, string(data[msgStart:]))
|
||||
}
|
||||
return nil, fmt.Errorf("server error (code: 0x%02x)", errorCode)
|
||||
|
||||
case 0x04: // PacketTypeData
|
||||
// Format after timestamp: [data_type_len:varint][data_type][payload_len:varint][payload]
|
||||
if len(data) <= payloadStart {
|
||||
return &ResponsePacket{PacketType: packetType, Payload: nil}, nil
|
||||
}
|
||||
|
||||
// Read data_type length (varint)
|
||||
typeLen, n1 := binary.Uvarint(data[payloadStart:])
|
||||
if n1 <= 0 {
|
||||
return nil, fmt.Errorf("invalid type length varint")
|
||||
}
|
||||
|
||||
// Skip data_type string
|
||||
payloadLenStart := payloadStart + n1 + int(typeLen)
|
||||
if len(data) <= payloadLenStart {
|
||||
return &ResponsePacket{PacketType: packetType, Payload: nil}, nil
|
||||
}
|
||||
|
||||
// Read payload length (varint)
|
||||
payloadLen, n2 := binary.Uvarint(data[payloadLenStart:])
|
||||
if n2 <= 0 {
|
||||
return nil, fmt.Errorf("invalid payload length varint")
|
||||
}
|
||||
|
||||
payloadDataStart := payloadLenStart + n2
|
||||
if len(data) < payloadDataStart+int(payloadLen) {
|
||||
return nil, fmt.Errorf("data packet payload incomplete")
|
||||
}
|
||||
|
||||
return &ResponsePacket{
|
||||
PacketType: packetType,
|
||||
Payload: data[payloadDataStart : payloadDataStart+int(payloadLen)],
|
||||
}, nil
|
||||
|
||||
default:
|
||||
// Unknown packet type, return raw data after timestamp
|
||||
if len(data) > payloadStart {
|
||||
return &ResponsePacket{
|
||||
PacketType: packetType,
|
||||
Payload: data[payloadStart:],
|
||||
}, nil
|
||||
}
|
||||
return &ResponsePacket{PacketType: packetType}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupDockerCompose stops and removes Docker Compose containers
|
||||
func cleanupDockerCompose(t *testing.T, composeFile string) {
|
||||
t.Log("Cleaning up Docker Compose...")
|
||||
downCmd := exec.CommandContext(context.Background(),
|
||||
"docker-compose", "-f", composeFile, "-p", "logsdebug-e2e", "down", "--remove-orphans", "--volumes")
|
||||
downCmd.Stdout = os.Stdout
|
||||
downCmd.Stderr = os.Stderr
|
||||
if err := downCmd.Run(); err != nil {
|
||||
t.Logf("Warning: Failed to cleanup Docker Compose: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// waitForServicesHealthy waits for all services to be healthy
|
||||
func waitForServicesHealthy(t *testing.T, projectName string, timeout time.Duration) bool {
|
||||
start := time.Now()
|
||||
for time.Since(start) < timeout {
|
||||
// Check container status
|
||||
psCmd := exec.CommandContext(context.Background(),
|
||||
"docker", "ps", "--filter", fmt.Sprintf("name=%s", projectName), "--format", "{{.Names}}\t{{.Status}}")
|
||||
output, err := psCmd.CombinedOutput()
|
||||
if err == nil {
|
||||
status := string(output)
|
||||
t.Logf("Container status:\n%s", status)
|
||||
|
||||
// Check if all containers are healthy (not just "Up")
|
||||
allHealthy := true
|
||||
lines := strings.Split(status, "\n")
|
||||
containerCount := 0
|
||||
for _, line := range lines {
|
||||
if strings.TrimSpace(line) == "" {
|
||||
continue
|
||||
}
|
||||
containerCount++
|
||||
// Must explicitly contain "healthy" - "Up (health: starting)" is NOT healthy
|
||||
if !strings.Contains(line, "healthy") {
|
||||
allHealthy = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if allHealthy && containerCount >= 2 { // redis and api-server
|
||||
t.Log("All services are healthy")
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
|
||||
t.Logf("Timeout waiting for services after %v", timeout)
|
||||
return false
|
||||
}
|
||||
|
||||
// TestLogsDebugHTTPHealthE2E tests the HTTP health endpoint for logs/debug services
|
||||
func TestLogsDebugHTTPHealthE2E(t *testing.T) {
|
||||
if os.Getenv("FETCH_ML_E2E_LOGS_DEBUG") != "1" {
|
||||
t.Skip("Skipping LogsDebugHTTPHealthE2E (set FETCH_ML_E2E_LOGS_DEBUG=1 to enable)")
|
||||
}
|
||||
|
||||
// Check if API is already running
|
||||
healthURL := fmt.Sprintf("http://%s:%s/health", apiHost, apiPort)
|
||||
|
||||
client := &http.Client{Timeout: 5 * time.Second}
|
||||
resp, err := client.Get(healthURL)
|
||||
if err != nil {
|
||||
t.Skipf("API server not available at %s: %v", healthURL, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Health check failed: status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Parse health response
|
||||
var healthResponse struct {
|
||||
Status string `json:"status"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&healthResponse); err != nil {
|
||||
t.Logf("Failed to parse health response: %v", err)
|
||||
} else {
|
||||
t.Logf("API health: status=%s, version=%s", healthResponse.Status, healthResponse.Version)
|
||||
}
|
||||
}
|
||||
63
tests/integration/duplicate_detection_test.go
Normal file
63
tests/integration/duplicate_detection_test.go
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
package tests
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/api/helpers"
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
)
|
||||
|
||||
// TestDuplicateDetection verifies the duplicate detection logic
|
||||
func TestDuplicateDetection(t *testing.T) {
|
||||
// Test 1: Same args + same commit = duplicate (would be detected)
|
||||
commitID := "abc123def4567890abc1"
|
||||
args1 := "--epochs 10 --lr 0.001"
|
||||
args2 := "--epochs 10 --lr 0.001" // Same args
|
||||
|
||||
hash1 := helpers.ComputeParamsHash(args1)
|
||||
hash2 := helpers.ComputeParamsHash(args2)
|
||||
|
||||
if hash1 != hash2 {
|
||||
t.Error("Same args should produce same hash")
|
||||
}
|
||||
t.Logf("✓ Same args produce same hash: %s", hash1)
|
||||
|
||||
// Test 2: Different args = not duplicate
|
||||
args3 := "--epochs 20 --lr 0.01" // Different
|
||||
hash3 := helpers.ComputeParamsHash(args3)
|
||||
|
||||
if hash1 == hash3 {
|
||||
t.Error("Different args should produce different hashes")
|
||||
}
|
||||
t.Logf("✓ Different args produce different hashes: %s vs %s", hash1, hash3)
|
||||
|
||||
// Test 3: Same dataset specs = same dataset_id
|
||||
ds1 := []queue.DatasetSpec{{Name: "mnist", Checksum: "sha256:abc123"}}
|
||||
ds2 := []queue.DatasetSpec{{Name: "mnist", Checksum: "sha256:abc123"}}
|
||||
|
||||
id1 := helpers.ComputeDatasetID(ds1, nil)
|
||||
id2 := helpers.ComputeDatasetID(ds2, nil)
|
||||
|
||||
if id1 != id2 {
|
||||
t.Error("Same dataset specs should produce same ID")
|
||||
}
|
||||
t.Logf("✓ Same dataset specs produce same ID: %s", id1)
|
||||
|
||||
// Test 4: Different dataset = different ID
|
||||
ds3 := []queue.DatasetSpec{{Name: "cifar10", Checksum: "sha256:def456"}}
|
||||
id3 := helpers.ComputeDatasetID(ds3, nil)
|
||||
|
||||
if id1 == id3 {
|
||||
t.Error("Different datasets should produce different IDs")
|
||||
}
|
||||
t.Logf("✓ Different datasets produce different IDs: %s vs %s", id1, id3)
|
||||
|
||||
// Test 5: Composite key logic
|
||||
t.Log("\n=== Composite Key Detection ===")
|
||||
t.Logf("Job 1: commit=%s, dataset_id=%s, params_hash=%s", commitID, id1, hash1)
|
||||
t.Logf("Job 2: commit=%s, dataset_id=%s, params_hash=%s", commitID, id2, hash2)
|
||||
t.Log("→ These would be detected as DUPLICATE (same commit + dataset + params)")
|
||||
|
||||
t.Logf("Job 3: commit=%s, dataset_id=%s, params_hash=%s", commitID, id1, hash3)
|
||||
t.Log("→ This would NOT be duplicate (different params)")
|
||||
}
|
||||
|
|
@ -236,10 +236,10 @@ func TestWebSocketQueueWithSnapshotOpcode(t *testing.T) {
|
|||
expMgr,
|
||||
"",
|
||||
taskQueue,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil, // db
|
||||
nil, // jupyterServiceMgr
|
||||
nil, // securityConfig
|
||||
nil, // auditLogger
|
||||
)
|
||||
server := httptest.NewServer(wsHandler)
|
||||
defer server.Close()
|
||||
|
|
|
|||
|
|
@ -65,9 +65,9 @@ func setupWSIntegrationServerWithDataDir(t *testing.T, dataDir string) (
|
|||
dataDir,
|
||||
tq,
|
||||
db,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil, // jupyterServiceMgr
|
||||
nil, // securityConfig
|
||||
nil, // auditLogger
|
||||
)
|
||||
server := httptest.NewServer(handler)
|
||||
return server, tq, expManager, s, db
|
||||
|
|
|
|||
95
tests/unit/api/duplicate_detection_process_test.go
Normal file
95
tests/unit/api/duplicate_detection_process_test.go
Normal file
|
|
@ -0,0 +1,95 @@
|
|||
package api_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/api/helpers"
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
)
|
||||
|
||||
// ProcessTest demonstrates the duplicate detection process step by step
|
||||
func TestDuplicateDetectionProcess(t *testing.T) {
|
||||
t.Log("=== Duplicate Detection Process Test ===")
|
||||
|
||||
// Step 1: First job submission
|
||||
t.Log("\n1. First job submission:")
|
||||
commitID := "abc123def456"
|
||||
args1 := "--epochs 10 --lr 0.001"
|
||||
datasets := []queue.DatasetSpec{{Name: "mnist", Checksum: "sha256:abc123"}}
|
||||
|
||||
datasetID1 := helpers.ComputeDatasetID(datasets, nil)
|
||||
paramsHash1 := helpers.ComputeParamsHash(args1)
|
||||
|
||||
t.Logf(" Commit ID: %s", commitID)
|
||||
t.Logf(" Dataset ID: %s (computed from %d datasets)", datasetID1, len(datasets))
|
||||
t.Logf(" Params Hash: %s (computed from args: %s)", paramsHash1, args1)
|
||||
t.Logf(" Composite Key: (%s, %s, %s)", commitID, datasetID1, paramsHash1)
|
||||
|
||||
// Step 2: Second job with SAME parameters (should be duplicate)
|
||||
t.Log("\n2. Second job submission (same params):")
|
||||
args2 := "--epochs 10 --lr 0.001" // Same args
|
||||
datasets2 := []queue.DatasetSpec{{Name: "mnist", Checksum: "sha256:abc123"}} // Same dataset
|
||||
|
||||
datasetID2 := helpers.ComputeDatasetID(datasets2, nil)
|
||||
paramsHash2 := helpers.ComputeParamsHash(args2)
|
||||
|
||||
t.Logf(" Commit ID: %s", commitID)
|
||||
t.Logf(" Dataset ID: %s", datasetID2)
|
||||
t.Logf(" Params Hash: %s", paramsHash2)
|
||||
t.Logf(" Composite Key: (%s, %s, %s)", commitID, datasetID2, paramsHash2)
|
||||
|
||||
// Verify they're the same
|
||||
if datasetID1 == datasetID2 && paramsHash1 == paramsHash2 {
|
||||
t.Log(" ✓ DUPLICATE DETECTED - same composite key!")
|
||||
} else {
|
||||
t.Error(" ✗ Should have been detected as duplicate")
|
||||
}
|
||||
|
||||
// Step 3: Third job with DIFFERENT parameters (not duplicate)
|
||||
t.Log("\n3. Third job submission (different params):")
|
||||
args3 := "--epochs 20 --lr 0.01" // Different args
|
||||
datasets3 := []queue.DatasetSpec{{Name: "mnist", Checksum: "sha256:abc123"}} // Same dataset
|
||||
|
||||
datasetID3 := helpers.ComputeDatasetID(datasets3, nil)
|
||||
paramsHash3 := helpers.ComputeParamsHash(args3)
|
||||
|
||||
t.Logf(" Commit ID: %s", commitID)
|
||||
t.Logf(" Dataset ID: %s", datasetID3)
|
||||
t.Logf(" Params Hash: %s", paramsHash3)
|
||||
t.Logf(" Composite Key: (%s, %s, %s)", commitID, datasetID3, paramsHash3)
|
||||
|
||||
// Verify they're different
|
||||
if paramsHash1 != paramsHash3 {
|
||||
t.Log(" ✓ NOT A DUPLICATE - different params_hash")
|
||||
} else {
|
||||
t.Error(" ✗ Should have different params_hash")
|
||||
}
|
||||
|
||||
// Step 4: Fourth job with DIFFERENT dataset (not duplicate)
|
||||
t.Log("\n4. Fourth job submission (different dataset):")
|
||||
args4 := "--epochs 10 --lr 0.001" // Same args
|
||||
datasets4 := []queue.DatasetSpec{{Name: "cifar10", Checksum: "sha256:def456"}} // Different dataset
|
||||
|
||||
datasetID4 := helpers.ComputeDatasetID(datasets4, nil)
|
||||
paramsHash4 := helpers.ComputeParamsHash(args4)
|
||||
|
||||
t.Logf(" Commit ID: %s", commitID)
|
||||
t.Logf(" Dataset ID: %s", datasetID4)
|
||||
t.Logf(" Params Hash: %s", paramsHash4)
|
||||
t.Logf(" Composite Key: (%s, %s, %s)", commitID, datasetID4, paramsHash4)
|
||||
|
||||
// Verify they're different
|
||||
if datasetID1 != datasetID4 {
|
||||
t.Log(" ✓ NOT A DUPLICATE - different dataset_id")
|
||||
} else {
|
||||
t.Error(" ✗ Should have different dataset_id")
|
||||
}
|
||||
|
||||
// Step 5: Summary
|
||||
t.Log("\n5. Summary:")
|
||||
t.Log(" - Jobs 1 & 2: Same commit_id + dataset_id + params_hash = DUPLICATE")
|
||||
t.Log(" - Job 3: Different params_hash = NOT DUPLICATE")
|
||||
t.Log(" - Job 4: Different dataset_id = NOT DUPLICATE")
|
||||
t.Log("\n The composite key (commit_id, dataset_id, params_hash) ensures")
|
||||
t.Log(" only truly identical experiments are flagged as duplicates.")
|
||||
}
|
||||
225
tests/unit/api/helpers/db_helpers_test.go
Normal file
225
tests/unit/api/helpers/db_helpers_test.go
Normal file
|
|
@ -0,0 +1,225 @@
|
|||
package helpers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/api/helpers"
|
||||
)
|
||||
|
||||
func TestDBContext(t *testing.T) {
|
||||
ctx, cancel := helpers.DBContext(3 * time.Second)
|
||||
defer cancel()
|
||||
|
||||
deadline, ok := ctx.Deadline()
|
||||
if !ok {
|
||||
t.Error("expected deadline to be set")
|
||||
}
|
||||
|
||||
// Deadline should be within a reasonable time window
|
||||
now := time.Now()
|
||||
if deadline.Before(now) {
|
||||
t.Error("deadline is in the past")
|
||||
}
|
||||
if deadline.After(now.Add(4 * time.Second)) {
|
||||
t.Error("deadline is too far in the future")
|
||||
}
|
||||
|
||||
// Context should not be cancelled
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Error("context should not be cancelled yet")
|
||||
default:
|
||||
// Good
|
||||
}
|
||||
}
|
||||
|
||||
func TestDBContextShort(t *testing.T) {
|
||||
ctx, cancel := helpers.DBContextShort()
|
||||
defer cancel()
|
||||
|
||||
deadline, ok := ctx.Deadline()
|
||||
if !ok {
|
||||
t.Error("expected deadline to be set")
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
diff := deadline.Sub(now)
|
||||
|
||||
// Should be around 3 seconds
|
||||
if diff < 2*time.Second || diff > 4*time.Second {
|
||||
t.Errorf("expected deadline ~3s, got %v", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDBContextMedium(t *testing.T) {
|
||||
ctx, cancel := helpers.DBContextMedium()
|
||||
defer cancel()
|
||||
|
||||
deadline, ok := ctx.Deadline()
|
||||
if !ok {
|
||||
t.Error("expected deadline to be set")
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
diff := deadline.Sub(now)
|
||||
|
||||
// Should be around 5 seconds
|
||||
if diff < 4*time.Second || diff > 6*time.Second {
|
||||
t.Errorf("expected deadline ~5s, got %v", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDBContextLong(t *testing.T) {
|
||||
ctx, cancel := helpers.DBContextLong()
|
||||
defer cancel()
|
||||
|
||||
deadline, ok := ctx.Deadline()
|
||||
if !ok {
|
||||
t.Error("expected deadline to be set")
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
diff := deadline.Sub(now)
|
||||
|
||||
// Should be around 10 seconds
|
||||
if diff < 9*time.Second || diff > 11*time.Second {
|
||||
t.Errorf("expected deadline ~10s, got %v", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDBContextCancellation(t *testing.T) {
|
||||
ctx, cancel := helpers.DBContextShort()
|
||||
|
||||
// Cancel immediately
|
||||
cancel()
|
||||
|
||||
// Context should be cancelled
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Good - context was cancelled
|
||||
default:
|
||||
t.Error("context should be cancelled after calling cancel()")
|
||||
}
|
||||
|
||||
// Check error
|
||||
if ctx.Err() == nil {
|
||||
t.Error("expected non-nil error after cancellation")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStringSliceContains(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
slice []string
|
||||
item string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "contains",
|
||||
slice: []string{"a", "b", "c"},
|
||||
item: "b",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "not contains",
|
||||
slice: []string{"a", "b", "c"},
|
||||
item: "d",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "empty slice",
|
||||
slice: []string{},
|
||||
item: "a",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "nil slice",
|
||||
slice: nil,
|
||||
item: "a",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "empty string item",
|
||||
slice: []string{"a", "", "c"},
|
||||
item: "",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "case sensitive",
|
||||
slice: []string{"Apple", "Banana"},
|
||||
item: "apple",
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := helpers.StringSliceContains(tt.slice, tt.item)
|
||||
if got != tt.want {
|
||||
t.Errorf("StringSliceContains() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStringSliceFilter(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
slice []string
|
||||
predicate func(string) bool
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
name: "filter by prefix",
|
||||
slice: []string{"apple", "banana", "apricot", "cherry"},
|
||||
predicate: func(s string) bool { return len(s) > 0 && s[0] == 'a' },
|
||||
want: []string{"apple", "apricot"},
|
||||
},
|
||||
{
|
||||
name: "filter empty strings",
|
||||
slice: []string{"a", "", "b", "", "c"},
|
||||
predicate: func(s string) bool { return s != "" },
|
||||
want: []string{"a", "b", "c"},
|
||||
},
|
||||
{
|
||||
name: "all match",
|
||||
slice: []string{"a", "b", "c"},
|
||||
predicate: func(s string) bool { return true },
|
||||
want: []string{"a", "b", "c"},
|
||||
},
|
||||
{
|
||||
name: "none match",
|
||||
slice: []string{"a", "b", "c"},
|
||||
predicate: func(s string) bool { return false },
|
||||
want: []string{},
|
||||
},
|
||||
{
|
||||
name: "empty slice",
|
||||
slice: []string{},
|
||||
predicate: func(s string) bool { return true },
|
||||
want: []string{},
|
||||
},
|
||||
{
|
||||
name: "nil slice",
|
||||
slice: nil,
|
||||
predicate: func(s string) bool { return true },
|
||||
want: []string{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := helpers.StringSliceFilter(tt.slice, tt.predicate)
|
||||
if len(got) != len(tt.want) {
|
||||
t.Errorf("StringSliceFilter() returned %d items, want %d", len(got), len(tt.want))
|
||||
return
|
||||
}
|
||||
for i := range got {
|
||||
if got[i] != tt.want[i] {
|
||||
t.Errorf("StringSliceFilter()[%d] = %q, want %q", i, got[i], tt.want[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
137
tests/unit/api/helpers/hash_helpers_test.go
Normal file
137
tests/unit/api/helpers/hash_helpers_test.go
Normal file
|
|
@ -0,0 +1,137 @@
|
|||
package helpers_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/api/helpers"
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
)
|
||||
|
||||
func TestComputeDatasetID(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
datasetSpecs []queue.DatasetSpec
|
||||
datasets []string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "both empty",
|
||||
datasetSpecs: nil,
|
||||
datasets: nil,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "only datasets",
|
||||
datasetSpecs: nil,
|
||||
datasets: []string{"dataset1", "dataset2"},
|
||||
want: "", // will be a hash
|
||||
},
|
||||
{
|
||||
name: "dataset specs with checksums",
|
||||
datasetSpecs: []queue.DatasetSpec{
|
||||
{Name: "ds1", Checksum: "abc123"},
|
||||
{Name: "ds2", Checksum: "def456"},
|
||||
},
|
||||
datasets: nil,
|
||||
want: "", // will be a hash
|
||||
},
|
||||
{
|
||||
name: "dataset specs without checksums",
|
||||
datasetSpecs: []queue.DatasetSpec{
|
||||
{Name: "ds1"},
|
||||
{Name: "ds2"},
|
||||
},
|
||||
datasets: nil,
|
||||
want: "", // will use names
|
||||
},
|
||||
{
|
||||
name: "checksums take precedence",
|
||||
datasetSpecs: []queue.DatasetSpec{{Name: "ds1", Checksum: "xyz789"}},
|
||||
datasets: []string{"dataset1"},
|
||||
want: "", // should use checksum from specs
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := helpers.ComputeDatasetID(tt.datasetSpecs, tt.datasets)
|
||||
if tt.want == "" {
|
||||
// Just verify it returns something or empty as expected
|
||||
if len(tt.datasetSpecs) == 0 && len(tt.datasets) == 0 && got != "" {
|
||||
t.Errorf("ComputeDatasetID() = %q, want empty string", got)
|
||||
}
|
||||
} else if got != tt.want {
|
||||
t.Errorf("ComputeDatasetID() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeParamsHash(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "empty args",
|
||||
args: "",
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "simple args",
|
||||
args: "--lr 0.01 --epochs 10",
|
||||
want: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", // sha256 of trimmed args
|
||||
},
|
||||
{
|
||||
name: "args with spaces",
|
||||
args: " --lr 0.01 ",
|
||||
want: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", // sha256 of trimmed args
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := helpers.ComputeParamsHash(tt.args)
|
||||
// Just verify it returns a string (hash computation is deterministic)
|
||||
if tt.want == "" && got != "" {
|
||||
t.Errorf("ComputeParamsHash() expected empty, got %q", got)
|
||||
}
|
||||
if tt.want != "" && got == "" {
|
||||
t.Errorf("ComputeParamsHash() expected non-empty hash")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeParamsHash_Deterministic(t *testing.T) {
|
||||
args := "--lr 0.01 --epochs 10"
|
||||
hash1 := helpers.ComputeParamsHash(args)
|
||||
hash2 := helpers.ComputeParamsHash(args)
|
||||
|
||||
if hash1 != hash2 {
|
||||
t.Error("ComputeParamsHash() should be deterministic")
|
||||
}
|
||||
|
||||
// Different args should produce different hashes
|
||||
differentArgs := "--lr 0.02 --epochs 10"
|
||||
differentHash := helpers.ComputeParamsHash(differentArgs)
|
||||
|
||||
if hash1 == differentHash {
|
||||
t.Error("ComputeParamsHash() should produce different hashes for different inputs")
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeParamsHash_Whitespace(t *testing.T) {
|
||||
// Same args with different whitespace should produce the same hash
|
||||
hash1 := helpers.ComputeParamsHash("--lr 0.01 --epochs 10")
|
||||
hash2 := helpers.ComputeParamsHash(" --lr 0.01 --epochs 10 ")
|
||||
hash3 := helpers.ComputeParamsHash("--lr 0.01 --epochs 10")
|
||||
|
||||
if hash1 != hash2 {
|
||||
t.Error("ComputeParamsHash() should handle leading/trailing whitespace consistently")
|
||||
}
|
||||
// Note: internal whitespace differences may or may not produce different hashes
|
||||
// depending on implementation details
|
||||
_ = hash3
|
||||
}
|
||||
451
tests/unit/api/helpers/payload_parser_test.go
Normal file
451
tests/unit/api/helpers/payload_parser_test.go
Normal file
|
|
@ -0,0 +1,451 @@
|
|||
package helpers_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/api/helpers"
|
||||
)
|
||||
|
||||
func TestNewPayloadParser(t *testing.T) {
|
||||
payload := []byte{0x01, 0x02, 0x03, 0x04, 0x05}
|
||||
p := helpers.NewPayloadParser(payload, 2)
|
||||
|
||||
if p.Offset() != 2 {
|
||||
t.Errorf("expected offset 2, got %d", p.Offset())
|
||||
}
|
||||
if !bytes.Equal(p.Payload(), payload) {
|
||||
t.Error("payload mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPayloadParser_ParseByte(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
payload []byte
|
||||
start int
|
||||
want byte
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid byte",
|
||||
payload: []byte{0x00, 0x01, 0x02},
|
||||
start: 1,
|
||||
want: 0x01,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "end of payload",
|
||||
payload: []byte{0x00, 0x01},
|
||||
start: 2,
|
||||
want: 0,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty payload",
|
||||
payload: []byte{},
|
||||
start: 0,
|
||||
want: 0,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := helpers.NewPayloadParser(tt.payload, tt.start)
|
||||
got, err := p.ParseByte()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ParseByte() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("ParseByte() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPayloadParser_ParseUint16(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
payload []byte
|
||||
start int
|
||||
want uint16
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid uint16",
|
||||
payload: []byte{0x01, 0x02, 0x03, 0x04},
|
||||
start: 0,
|
||||
want: 0x0102,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "another valid uint16",
|
||||
payload: []byte{0x00, 0xAB, 0xCD},
|
||||
start: 1,
|
||||
want: 0xABCD,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "too short",
|
||||
payload: []byte{0x00, 0x01},
|
||||
start: 1,
|
||||
want: 0,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty at end",
|
||||
payload: []byte{0x00, 0x01, 0x02},
|
||||
start: 3,
|
||||
want: 0,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := helpers.NewPayloadParser(tt.payload, tt.start)
|
||||
got, err := p.ParseUint16()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ParseUint16() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("ParseUint16() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPayloadParser_ParseLengthPrefixedString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
payload []byte
|
||||
start int
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid string",
|
||||
payload: []byte{0x05, 'h', 'e', 'l', 'l', 'o'},
|
||||
start: 0,
|
||||
want: "hello",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
payload: []byte{0x00, 0x01, 0x02},
|
||||
start: 0,
|
||||
want: "",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "from offset",
|
||||
payload: []byte{0x00, 0x03, 'f', 'o', 'o'},
|
||||
start: 1,
|
||||
want: "foo",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "too short for length",
|
||||
payload: []byte{0x05, 'h', 'i'},
|
||||
start: 0,
|
||||
want: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "no length byte",
|
||||
payload: []byte{},
|
||||
start: 0,
|
||||
want: "",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := helpers.NewPayloadParser(tt.payload, tt.start)
|
||||
got, err := p.ParseLengthPrefixedString()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ParseLengthPrefixedString() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("ParseLengthPrefixedString() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPayloadParser_ParseUint16PrefixedString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
payload []byte
|
||||
start int
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid string",
|
||||
payload: []byte{0x00, 0x05, 'h', 'e', 'l', 'l', 'o'},
|
||||
start: 0,
|
||||
want: "hello",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
payload: []byte{0x00, 0x00, 0x01, 0x02},
|
||||
start: 0,
|
||||
want: "",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "from offset",
|
||||
payload: []byte{0x00, 0x00, 0x00, 0x03, 'b', 'a', 'r'},
|
||||
start: 2,
|
||||
want: "bar",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "too short for length",
|
||||
payload: []byte{0x00},
|
||||
start: 0,
|
||||
want: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "too short for string",
|
||||
payload: []byte{0x00, 0x05, 'h', 'i'},
|
||||
start: 0,
|
||||
want: "",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := helpers.NewPayloadParser(tt.payload, tt.start)
|
||||
got, err := p.ParseUint16PrefixedString()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ParseUint16PrefixedString() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("ParseUint16PrefixedString() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPayloadParser_ParseBool(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
payload []byte
|
||||
start int
|
||||
want bool
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "true",
|
||||
payload: []byte{0x01, 0x00},
|
||||
start: 0,
|
||||
want: true,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "false",
|
||||
payload: []byte{0x00, 0x01},
|
||||
start: 0,
|
||||
want: false,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "non-zero is true",
|
||||
payload: []byte{0xFF},
|
||||
start: 0,
|
||||
want: true,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty",
|
||||
payload: []byte{},
|
||||
start: 0,
|
||||
want: false,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := helpers.NewPayloadParser(tt.payload, tt.start)
|
||||
got, err := p.ParseBool()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ParseBool() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("ParseBool() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPayloadParser_ParseFixedBytes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
payload []byte
|
||||
start int
|
||||
length int
|
||||
want []byte
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid",
|
||||
payload: []byte{0x00, 0x01, 0x02, 0x03, 0x04},
|
||||
start: 1,
|
||||
length: 3,
|
||||
want: []byte{0x01, 0x02, 0x03},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "exact length",
|
||||
payload: []byte{0x00, 0x01, 0x02},
|
||||
start: 0,
|
||||
length: 3,
|
||||
want: []byte{0x00, 0x01, 0x02},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "too long",
|
||||
payload: []byte{0x00, 0x01},
|
||||
start: 0,
|
||||
length: 3,
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "from end",
|
||||
payload: []byte{0x00, 0x01, 0x02},
|
||||
start: 2,
|
||||
length: 2,
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := helpers.NewPayloadParser(tt.payload, tt.start)
|
||||
got, err := p.ParseFixedBytes(tt.length)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ParseFixedBytes() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !bytes.Equal(got, tt.want) {
|
||||
t.Errorf("ParseFixedBytes() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPayloadParser_Offset(t *testing.T) {
|
||||
payload := []byte{0x00, 0x01, 0x02, 0x03}
|
||||
p := helpers.NewPayloadParser(payload, 1)
|
||||
|
||||
if p.Offset() != 1 {
|
||||
t.Errorf("initial Offset() = %d, want 1", p.Offset())
|
||||
}
|
||||
|
||||
p.ParseByte()
|
||||
if p.Offset() != 2 {
|
||||
t.Errorf("after ParseByte() Offset() = %d, want 2", p.Offset())
|
||||
}
|
||||
}
|
||||
|
||||
func TestPayloadParser_Remaining(t *testing.T) {
|
||||
payload := []byte{0x00, 0x01, 0x02, 0x03}
|
||||
p := helpers.NewPayloadParser(payload, 2)
|
||||
|
||||
remaining := p.Remaining()
|
||||
if !bytes.Equal(remaining, []byte{0x02, 0x03}) {
|
||||
t.Errorf("Remaining() = %v, want [0x02, 0x03]", remaining)
|
||||
}
|
||||
|
||||
// At end
|
||||
p = helpers.NewPayloadParser(payload, 4)
|
||||
if p.Remaining() != nil {
|
||||
t.Errorf("at end Remaining() = %v, want nil", p.Remaining())
|
||||
}
|
||||
|
||||
// Beyond end
|
||||
p = helpers.NewPayloadParser(payload, 5)
|
||||
if p.Remaining() != nil {
|
||||
t.Errorf("beyond end Remaining() = %v, want nil", p.Remaining())
|
||||
}
|
||||
}
|
||||
|
||||
func TestPayloadParser_HasRemaining(t *testing.T) {
|
||||
payload := []byte{0x00, 0x01}
|
||||
|
||||
tests := []struct {
|
||||
start int
|
||||
want bool
|
||||
}{
|
||||
{0, true},
|
||||
{1, true},
|
||||
{2, false},
|
||||
{3, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
p := helpers.NewPayloadParser(payload, tt.start)
|
||||
if got := p.HasRemaining(); got != tt.want {
|
||||
t.Errorf("HasRemaining() with start=%d = %v, want %v", tt.start, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPayloadParser_ChainedParsing(t *testing.T) {
|
||||
// Simulate a real protocol: [api_key:2][name_len:1][name:var][value:2]
|
||||
payload := []byte{
|
||||
0xAB, 0xCD, // api_key
|
||||
0x05, // name_len
|
||||
'h', 'e', 'l', 'l', 'o', // name
|
||||
0x00, 0x42, // value
|
||||
}
|
||||
|
||||
p := helpers.NewPayloadParser(payload, 0)
|
||||
|
||||
// Parse api key (2 bytes)
|
||||
apiKey, err := p.ParseFixedBytes(2)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse api key: %v", err)
|
||||
}
|
||||
if !bytes.Equal(apiKey, []byte{0xAB, 0xCD}) {
|
||||
t.Errorf("api key = %v, want [0xAB, 0xCD]", apiKey)
|
||||
}
|
||||
|
||||
// Parse name
|
||||
name, err := p.ParseLengthPrefixedString()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse name: %v", err)
|
||||
}
|
||||
if name != "hello" {
|
||||
t.Errorf("name = %q, want 'hello'", name)
|
||||
}
|
||||
|
||||
// Parse value
|
||||
value, err := p.ParseUint16()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse value: %v", err)
|
||||
}
|
||||
if value != 0x0042 {
|
||||
t.Errorf("value = 0x%X, want 0x0042", value)
|
||||
}
|
||||
|
||||
// Should be at end
|
||||
if p.HasRemaining() {
|
||||
t.Error("expected no remaining bytes")
|
||||
}
|
||||
}
|
||||
345
tests/unit/api/helpers/response_helpers_test.go
Normal file
345
tests/unit/api/helpers/response_helpers_test.go
Normal file
|
|
@ -0,0 +1,345 @@
|
|||
package helpers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/api/helpers"
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
)
|
||||
|
||||
func TestNewTaskErrorMapper(t *testing.T) {
|
||||
mapper := helpers.NewTaskErrorMapper()
|
||||
if mapper == nil {
|
||||
t.Error("NewTaskErrorMapper() returned nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskErrorMapper_MapJupyterError(t *testing.T) {
|
||||
mapper := helpers.NewTaskErrorMapper()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
task *queue.Task
|
||||
want helpers.ErrorCode
|
||||
}{
|
||||
{
|
||||
name: "nil task",
|
||||
task: nil,
|
||||
want: 0x00, // ErrorCodeUnknownError
|
||||
},
|
||||
{
|
||||
name: "cancelled task",
|
||||
task: &queue.Task{Status: "cancelled"},
|
||||
want: 0x24, // ErrorCodeJobCancelled
|
||||
},
|
||||
{
|
||||
name: "oom error",
|
||||
task: &queue.Task{Status: "failed", Error: "out of memory"},
|
||||
want: 0x30, // ErrorCodeOutOfMemory
|
||||
},
|
||||
{
|
||||
name: "oom shorthand",
|
||||
task: &queue.Task{Status: "failed", Error: "OOM killed"},
|
||||
want: 0x30, // ErrorCodeOutOfMemory
|
||||
},
|
||||
{
|
||||
name: "disk full",
|
||||
task: &queue.Task{Status: "failed", Error: "no space left on device"},
|
||||
want: 0x31, // ErrorCodeDiskFull
|
||||
},
|
||||
{
|
||||
name: "disk full alt",
|
||||
task: &queue.Task{Status: "failed", Error: "disk full"},
|
||||
want: 0x31, // ErrorCodeDiskFull
|
||||
},
|
||||
{
|
||||
name: "rate limit",
|
||||
task: &queue.Task{Status: "failed", Error: "rate limit exceeded"},
|
||||
want: 0x33, // ErrorCodeServiceUnavailable
|
||||
},
|
||||
{
|
||||
name: "throttle",
|
||||
task: &queue.Task{Status: "failed", Error: "request throttled"},
|
||||
want: 0x33, // ErrorCodeServiceUnavailable
|
||||
},
|
||||
{
|
||||
name: "timeout",
|
||||
task: &queue.Task{Status: "failed", Error: "timed out waiting"},
|
||||
want: 0x14, // ErrorCodeTimeout
|
||||
},
|
||||
{
|
||||
name: "deadline",
|
||||
task: &queue.Task{Status: "failed", Error: "context deadline exceeded"},
|
||||
want: 0x14, // ErrorCodeTimeout
|
||||
},
|
||||
{
|
||||
name: "connection refused",
|
||||
task: &queue.Task{Status: "failed", Error: "connection refused"},
|
||||
want: 0x12, // ErrorCodeNetworkError
|
||||
},
|
||||
{
|
||||
name: "connection reset",
|
||||
task: &queue.Task{Status: "failed", Error: "connection reset by peer"},
|
||||
want: 0x12, // ErrorCodeNetworkError
|
||||
},
|
||||
{
|
||||
name: "network unreachable",
|
||||
task: &queue.Task{Status: "failed", Error: "network unreachable"},
|
||||
want: 0x12, // ErrorCodeNetworkError
|
||||
},
|
||||
{
|
||||
name: "queue not configured",
|
||||
task: &queue.Task{Status: "failed", Error: "queue not configured"},
|
||||
want: 0x32, // ErrorCodeInvalidConfiguration
|
||||
},
|
||||
{
|
||||
name: "generic failed",
|
||||
task: &queue.Task{Status: "failed", Error: "something went wrong"},
|
||||
want: 0x23, // ErrorCodeJobExecutionFailed
|
||||
},
|
||||
{
|
||||
name: "unknown status",
|
||||
task: &queue.Task{Status: "unknown", Error: "unknown error"},
|
||||
want: 0x00, // ErrorCodeUnknownError
|
||||
},
|
||||
{
|
||||
name: "case insensitive - cancelled",
|
||||
task: &queue.Task{Status: "CANCELLED"},
|
||||
want: 0x24, // ErrorCodeJobCancelled
|
||||
},
|
||||
{
|
||||
name: "case insensitive - oom",
|
||||
task: &queue.Task{Status: "FAILED", Error: "OUT OF MEMORY"},
|
||||
want: 0x30, // ErrorCodeOutOfMemory
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := mapper.MapJupyterError(tt.task)
|
||||
if got != tt.want {
|
||||
t.Errorf("MapJupyterError() = 0x%02X, want 0x%02X", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskErrorMapper_MapError(t *testing.T) {
|
||||
mapper := helpers.NewTaskErrorMapper()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
task *queue.Task
|
||||
defaultCode helpers.ErrorCode
|
||||
want helpers.ErrorCode
|
||||
}{
|
||||
{
|
||||
name: "nil task returns default",
|
||||
task: nil,
|
||||
defaultCode: 0x11, // ErrorCodeDatabaseError
|
||||
want: 0x11,
|
||||
},
|
||||
{
|
||||
name: "cancelled",
|
||||
task: &queue.Task{Status: "cancelled"},
|
||||
defaultCode: 0x00,
|
||||
want: 0x24, // ErrorCodeJobCancelled
|
||||
},
|
||||
{
|
||||
name: "oom",
|
||||
task: &queue.Task{Status: "failed", Error: "oom"},
|
||||
defaultCode: 0x00,
|
||||
want: 0x30, // ErrorCodeOutOfMemory
|
||||
},
|
||||
{
|
||||
name: "timeout",
|
||||
task: &queue.Task{Status: "failed", Error: "timeout"},
|
||||
defaultCode: 0x00,
|
||||
want: 0x14, // ErrorCodeTimeout
|
||||
},
|
||||
{
|
||||
name: "generic failed",
|
||||
task: &queue.Task{Status: "failed", Error: "generic error"},
|
||||
defaultCode: 0x00,
|
||||
want: 0x23, // ErrorCodeJobExecutionFailed
|
||||
},
|
||||
{
|
||||
name: "unknown status returns default",
|
||||
task: &queue.Task{Status: "weird", Error: "unknown"},
|
||||
defaultCode: 0x01,
|
||||
want: 0x01,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := mapper.MapError(tt.task, tt.defaultCode)
|
||||
if got != tt.want {
|
||||
t.Errorf("MapError() = 0x%02X, want 0x%02X", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseResourceRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
payload []byte
|
||||
want *helpers.ResourceRequest
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "empty payload",
|
||||
payload: []byte{},
|
||||
want: nil,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "nil payload",
|
||||
payload: nil,
|
||||
want: nil,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid minimal",
|
||||
payload: []byte{4, 8, 1, 0},
|
||||
want: &helpers.ResourceRequest{CPU: 4, MemoryGB: 8, GPU: 1, GPUMemory: ""},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid with gpu memory",
|
||||
payload: []byte{8, 16, 2, 4, '8', 'G', 'B', '!'},
|
||||
want: &helpers.ResourceRequest{CPU: 8, MemoryGB: 16, GPU: 2, GPUMemory: "8GB!"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "too short",
|
||||
payload: []byte{1, 2},
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid gpu mem length",
|
||||
payload: []byte{1, 2, 1, 10, 'a'},
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := helpers.ParseResourceRequest(tt.payload)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ParseResourceRequest() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if tt.want == nil {
|
||||
if got != nil {
|
||||
t.Errorf("ParseResourceRequest() = %v, want nil", got)
|
||||
}
|
||||
} else if got == nil {
|
||||
t.Errorf("ParseResourceRequest() = nil, want %v", tt.want)
|
||||
} else {
|
||||
if got.CPU != tt.want.CPU {
|
||||
t.Errorf("CPU = %d, want %d", got.CPU, tt.want.CPU)
|
||||
}
|
||||
if got.MemoryGB != tt.want.MemoryGB {
|
||||
t.Errorf("MemoryGB = %d, want %d", got.MemoryGB, tt.want.MemoryGB)
|
||||
}
|
||||
if got.GPU != tt.want.GPU {
|
||||
t.Errorf("GPU = %d, want %d", got.GPU, tt.want.GPU)
|
||||
}
|
||||
if got.GPUMemory != tt.want.GPUMemory {
|
||||
t.Errorf("GPUMemory = %q, want %q", got.GPUMemory, tt.want.GPUMemory)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalJSONOrEmpty(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data interface{}
|
||||
want []byte
|
||||
}{
|
||||
{
|
||||
name: "simple map",
|
||||
data: map[string]string{"key": "value"},
|
||||
want: []byte(`{"key":"value"}`),
|
||||
},
|
||||
{
|
||||
name: "string slice",
|
||||
data: []string{"a", "b", "c"},
|
||||
want: []byte(`["a","b","c"]`),
|
||||
},
|
||||
{
|
||||
name: "empty slice",
|
||||
data: []int{},
|
||||
want: []byte(`[]`),
|
||||
},
|
||||
{
|
||||
name: "nil",
|
||||
data: nil,
|
||||
want: []byte(`null`),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := helpers.MarshalJSONOrEmpty(tt.data)
|
||||
// For valid JSON, compare strings since JSON formatting might vary
|
||||
gotStr := string(got)
|
||||
wantStr := string(tt.want)
|
||||
if gotStr != wantStr {
|
||||
t.Errorf("MarshalJSONOrEmpty() = %q, want %q", gotStr, wantStr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalJSONOrEmpty_ErrorCase(t *testing.T) {
|
||||
// Test with a value that can't be marshaled (function)
|
||||
got := helpers.MarshalJSONOrEmpty(func() {})
|
||||
want := []byte("[]")
|
||||
if string(got) != string(want) {
|
||||
t.Errorf("MarshalJSONOrEmpty() with invalid data = %q, want %q", string(got), string(want))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalJSONBytes(t *testing.T) {
|
||||
data := map[string]int{"count": 42}
|
||||
got, err := helpers.MarshalJSONBytes(data)
|
||||
if err != nil {
|
||||
t.Errorf("MarshalJSONBytes() unexpected error: %v", err)
|
||||
}
|
||||
want := `{"count":42}`
|
||||
if string(got) != want {
|
||||
t.Errorf("MarshalJSONBytes() = %q, want %q", string(got), want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsEmptyJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
want bool
|
||||
}{
|
||||
{"empty", []byte{}, true},
|
||||
{"null", []byte("null"), true},
|
||||
{"empty array", []byte("[]"), true},
|
||||
{"empty object", []byte("{}"), true},
|
||||
{"whitespace", []byte(" "), true},
|
||||
{"data", []byte(`{"key":"value"}`), false},
|
||||
{"non-empty array", []byte("[1,2,3]"), false},
|
||||
{"null with spaces", []byte(" null "), true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := helpers.IsEmptyJSON(tt.data); got != tt.want {
|
||||
t.Errorf("IsEmptyJSON() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
486
tests/unit/api/helpers/validation_helpers_test.go
Normal file
486
tests/unit/api/helpers/validation_helpers_test.go
Normal file
|
|
@ -0,0 +1,486 @@
|
|||
package helpers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/api/helpers"
|
||||
"github.com/jfraeys/fetch_ml/internal/manifest"
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
)
|
||||
|
||||
func TestValidateCommitIDFormat(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
commitID string
|
||||
wantOk bool
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "valid commit ID",
|
||||
commitID: "aabbccddeeff00112233445566778899aabbccdd",
|
||||
wantOk: true,
|
||||
wantErr: "",
|
||||
},
|
||||
{
|
||||
name: "too short",
|
||||
commitID: "aabbcc",
|
||||
wantOk: false,
|
||||
wantErr: "invalid commit_id length",
|
||||
},
|
||||
{
|
||||
name: "too long",
|
||||
commitID: "aabbccddeeff00112233445566778899aabbccddeeff",
|
||||
wantOk: false,
|
||||
wantErr: "invalid commit_id length",
|
||||
},
|
||||
{
|
||||
name: "invalid hex",
|
||||
commitID: "gggggggggggggggggggggggggggggggggggggggg",
|
||||
wantOk: false,
|
||||
wantErr: "invalid commit_id hex",
|
||||
},
|
||||
{
|
||||
name: "mixed case valid",
|
||||
commitID: "AABBCCDDEEFF00112233445566778899AABBCCDD",
|
||||
wantOk: true,
|
||||
wantErr: "",
|
||||
},
|
||||
{
|
||||
name: "empty",
|
||||
commitID: "",
|
||||
wantOk: false,
|
||||
wantErr: "invalid commit_id length",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotOk, gotErr := helpers.ValidateCommitIDFormat(tt.commitID)
|
||||
if gotOk != tt.wantOk {
|
||||
t.Errorf("ValidateCommitIDFormat() ok = %v, want %v", gotOk, tt.wantOk)
|
||||
}
|
||||
if gotErr != tt.wantErr {
|
||||
t.Errorf("ValidateCommitIDFormat() err = %q, want %q", gotErr, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldRequireRunManifest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
status string
|
||||
want bool
|
||||
}{
|
||||
{"running", "running", true},
|
||||
{"completed", "completed", true},
|
||||
{"failed", "failed", true},
|
||||
{"queued", "queued", false},
|
||||
{"pending", "pending", false},
|
||||
{"cancelled", "cancelled", false},
|
||||
{"unknown", "unknown", false},
|
||||
{"empty", "", false},
|
||||
{"RUNNING uppercase", "RUNNING", true},
|
||||
{"Completed mixed", "Completed", true},
|
||||
{" Running with space", " running", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
task := &queue.Task{Status: tt.status}
|
||||
got := helpers.ShouldRequireRunManifest(task)
|
||||
if got != tt.want {
|
||||
t.Errorf("ShouldRequireRunManifest() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpectedRunManifestBucketForStatus(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
status string
|
||||
wantBucket string
|
||||
wantOk bool
|
||||
}{
|
||||
{"queued", "queued", "pending", true},
|
||||
{"pending", "pending", "pending", true},
|
||||
{"running", "running", "running", true},
|
||||
{"completed", "completed", "finished", true},
|
||||
{"finished", "finished", "finished", true},
|
||||
{"failed", "failed", "failed", true},
|
||||
{"unknown", "unknown", "", false},
|
||||
{"empty", "", "", false},
|
||||
{"RUNNING uppercase", "RUNNING", "running", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotBucket, gotOk := helpers.ExpectedRunManifestBucketForStatus(tt.status)
|
||||
if gotBucket != tt.wantBucket {
|
||||
t.Errorf("ExpectedRunManifestBucketForStatus() bucket = %q, want %q", gotBucket, tt.wantBucket)
|
||||
}
|
||||
if gotOk != tt.wantOk {
|
||||
t.Errorf("ExpectedRunManifestBucketForStatus() ok = %v, want %v", gotOk, tt.wantOk)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateTaskIDMatch(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
rmTaskID string
|
||||
expectedID string
|
||||
wantOk bool
|
||||
wantExpected string
|
||||
wantActual string
|
||||
}{
|
||||
{
|
||||
name: "match",
|
||||
rmTaskID: "task-123",
|
||||
expectedID: "task-123",
|
||||
wantOk: true,
|
||||
wantExpected: "task-123",
|
||||
wantActual: "task-123",
|
||||
},
|
||||
{
|
||||
name: "mismatch",
|
||||
rmTaskID: "task-123",
|
||||
expectedID: "task-456",
|
||||
wantOk: false,
|
||||
wantExpected: "task-456",
|
||||
wantActual: "task-123",
|
||||
},
|
||||
{
|
||||
name: "empty rm task ID",
|
||||
rmTaskID: "",
|
||||
expectedID: "task-123",
|
||||
wantOk: false,
|
||||
wantExpected: "task-123",
|
||||
wantActual: "",
|
||||
},
|
||||
{
|
||||
name: "whitespace trimmed - match",
|
||||
rmTaskID: "task-123",
|
||||
expectedID: "task-123",
|
||||
wantOk: true,
|
||||
wantExpected: "task-123",
|
||||
wantActual: "task-123",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rm := &manifest.RunManifest{TaskID: tt.rmTaskID}
|
||||
got := helpers.ValidateTaskIDMatch(rm, tt.expectedID)
|
||||
if got.OK != tt.wantOk {
|
||||
t.Errorf("ValidateTaskIDMatch() OK = %v, want %v", got.OK, tt.wantOk)
|
||||
}
|
||||
if got.Expected != tt.wantExpected {
|
||||
t.Errorf("ValidateTaskIDMatch() Expected = %q, want %q", got.Expected, tt.wantExpected)
|
||||
}
|
||||
if got.Actual != tt.wantActual {
|
||||
t.Errorf("ValidateTaskIDMatch() Actual = %q, want %q", got.Actual, tt.wantActual)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateCommitIDMatch(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
rmCommitID string
|
||||
expectedID string
|
||||
wantOk bool
|
||||
wantExpected string
|
||||
wantActual string
|
||||
}{
|
||||
{
|
||||
name: "both empty",
|
||||
rmCommitID: "",
|
||||
expectedID: "",
|
||||
wantOk: true,
|
||||
wantExpected: "",
|
||||
wantActual: "",
|
||||
},
|
||||
{
|
||||
name: "match",
|
||||
rmCommitID: "abc123",
|
||||
expectedID: "abc123",
|
||||
wantOk: true,
|
||||
wantExpected: "abc123",
|
||||
wantActual: "abc123",
|
||||
},
|
||||
{
|
||||
name: "mismatch",
|
||||
rmCommitID: "abc123",
|
||||
expectedID: "def456",
|
||||
wantOk: false,
|
||||
wantExpected: "def456",
|
||||
wantActual: "abc123",
|
||||
},
|
||||
{
|
||||
name: "expected empty",
|
||||
rmCommitID: "abc123",
|
||||
expectedID: "",
|
||||
wantOk: true,
|
||||
wantExpected: "",
|
||||
wantActual: "",
|
||||
},
|
||||
{
|
||||
name: "rm empty",
|
||||
rmCommitID: "",
|
||||
expectedID: "abc123",
|
||||
wantOk: true,
|
||||
wantExpected: "abc123",
|
||||
wantActual: "",
|
||||
},
|
||||
{
|
||||
name: "whitespace trimmed",
|
||||
rmCommitID: " abc123 ",
|
||||
expectedID: "abc123",
|
||||
wantOk: true,
|
||||
wantExpected: "abc123",
|
||||
wantActual: "abc123",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := helpers.ValidateCommitIDMatch(tt.rmCommitID, tt.expectedID)
|
||||
if got.OK != tt.wantOk {
|
||||
t.Errorf("ValidateCommitIDMatch() OK = %v, want %v", got.OK, tt.wantOk)
|
||||
}
|
||||
if got.Expected != tt.wantExpected {
|
||||
t.Errorf("ValidateCommitIDMatch() Expected = %q, want %q", got.Expected, tt.wantExpected)
|
||||
}
|
||||
if got.Actual != tt.wantActual {
|
||||
t.Errorf("ValidateCommitIDMatch() Actual = %q, want %q", got.Actual, tt.wantActual)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateDepsProvenance(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
wantName string
|
||||
wantSHA string
|
||||
gotName string
|
||||
gotSHA string
|
||||
wantOk bool
|
||||
wantExpected string
|
||||
wantActual string
|
||||
}{
|
||||
{
|
||||
name: "match",
|
||||
wantName: "requirements.txt",
|
||||
wantSHA: "abc123",
|
||||
gotName: "requirements.txt",
|
||||
gotSHA: "abc123",
|
||||
wantOk: true,
|
||||
wantExpected: "requirements.txt:abc123",
|
||||
wantActual: "requirements.txt:abc123",
|
||||
},
|
||||
{
|
||||
name: "name mismatch",
|
||||
wantName: "requirements.txt",
|
||||
wantSHA: "abc123",
|
||||
gotName: "Pipfile",
|
||||
gotSHA: "abc123",
|
||||
wantOk: false,
|
||||
wantExpected: "requirements.txt:abc123",
|
||||
wantActual: "Pipfile:abc123",
|
||||
},
|
||||
{
|
||||
name: "sha mismatch",
|
||||
wantName: "requirements.txt",
|
||||
wantSHA: "abc123",
|
||||
gotName: "requirements.txt",
|
||||
gotSHA: "def456",
|
||||
wantOk: false,
|
||||
wantExpected: "requirements.txt:abc123",
|
||||
wantActual: "requirements.txt:def456",
|
||||
},
|
||||
{
|
||||
name: "want empty",
|
||||
wantName: "",
|
||||
wantSHA: "",
|
||||
gotName: "requirements.txt",
|
||||
gotSHA: "abc123",
|
||||
wantOk: true,
|
||||
wantExpected: "",
|
||||
wantActual: "",
|
||||
},
|
||||
{
|
||||
name: "got empty",
|
||||
wantName: "requirements.txt",
|
||||
wantSHA: "abc123",
|
||||
gotName: "",
|
||||
gotSHA: "",
|
||||
wantOk: true,
|
||||
wantExpected: "",
|
||||
wantActual: "",
|
||||
},
|
||||
{
|
||||
name: "both empty",
|
||||
wantName: "",
|
||||
wantSHA: "",
|
||||
gotName: "",
|
||||
gotSHA: "",
|
||||
wantOk: true,
|
||||
wantExpected: "",
|
||||
wantActual: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := helpers.ValidateDepsProvenance(tt.wantName, tt.wantSHA, tt.gotName, tt.gotSHA)
|
||||
if got.OK != tt.wantOk {
|
||||
t.Errorf("ValidateDepsProvenance() OK = %v, want %v", got.OK, tt.wantOk)
|
||||
}
|
||||
if got.Expected != tt.wantExpected {
|
||||
t.Errorf("ValidateDepsProvenance() Expected = %q, want %q", got.Expected, tt.wantExpected)
|
||||
}
|
||||
if got.Actual != tt.wantActual {
|
||||
t.Errorf("ValidateDepsProvenance() Actual = %q, want %q", got.Actual, tt.wantActual)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateSnapshotID(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
wantID string
|
||||
gotID string
|
||||
wantOk bool
|
||||
wantExpected string
|
||||
wantActual string
|
||||
}{
|
||||
{
|
||||
name: "match",
|
||||
wantID: "snap-123",
|
||||
gotID: "snap-123",
|
||||
wantOk: true,
|
||||
wantExpected: "snap-123",
|
||||
wantActual: "snap-123",
|
||||
},
|
||||
{
|
||||
name: "mismatch",
|
||||
wantID: "snap-123",
|
||||
gotID: "snap-456",
|
||||
wantOk: false,
|
||||
wantExpected: "snap-123",
|
||||
wantActual: "snap-456",
|
||||
},
|
||||
{
|
||||
name: "want empty",
|
||||
wantID: "",
|
||||
gotID: "snap-123",
|
||||
wantOk: true,
|
||||
wantExpected: "",
|
||||
wantActual: "snap-123",
|
||||
},
|
||||
{
|
||||
name: "got empty",
|
||||
wantID: "snap-123",
|
||||
gotID: "",
|
||||
wantOk: true,
|
||||
wantExpected: "snap-123",
|
||||
wantActual: "",
|
||||
},
|
||||
{
|
||||
name: "both empty",
|
||||
wantID: "",
|
||||
gotID: "",
|
||||
wantOk: true,
|
||||
wantExpected: "",
|
||||
wantActual: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := helpers.ValidateSnapshotID(tt.wantID, tt.gotID)
|
||||
if got.OK != tt.wantOk {
|
||||
t.Errorf("ValidateSnapshotID() OK = %v, want %v", got.OK, tt.wantOk)
|
||||
}
|
||||
if got.Expected != tt.wantExpected {
|
||||
t.Errorf("ValidateSnapshotID() Expected = %q, want %q", got.Expected, tt.wantExpected)
|
||||
}
|
||||
if got.Actual != tt.wantActual {
|
||||
t.Errorf("ValidateSnapshotID() Actual = %q, want %q", got.Actual, tt.wantActual)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateSnapshotSHA(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
wantSHA string
|
||||
gotSHA string
|
||||
wantOk bool
|
||||
wantExpected string
|
||||
wantActual string
|
||||
}{
|
||||
{
|
||||
name: "match",
|
||||
wantSHA: "sha256:abc123",
|
||||
gotSHA: "sha256:abc123",
|
||||
wantOk: true,
|
||||
wantExpected: "sha256:abc123",
|
||||
wantActual: "sha256:abc123",
|
||||
},
|
||||
{
|
||||
name: "mismatch",
|
||||
wantSHA: "sha256:abc123",
|
||||
gotSHA: "sha256:def456",
|
||||
wantOk: false,
|
||||
wantExpected: "sha256:abc123",
|
||||
wantActual: "sha256:def456",
|
||||
},
|
||||
{
|
||||
name: "want empty",
|
||||
wantSHA: "",
|
||||
gotSHA: "sha256:abc123",
|
||||
wantOk: true,
|
||||
wantExpected: "",
|
||||
wantActual: "sha256:abc123",
|
||||
},
|
||||
{
|
||||
name: "got empty",
|
||||
wantSHA: "sha256:abc123",
|
||||
gotSHA: "",
|
||||
wantOk: true,
|
||||
wantExpected: "sha256:abc123",
|
||||
wantActual: "",
|
||||
},
|
||||
{
|
||||
name: "both empty",
|
||||
wantSHA: "",
|
||||
gotSHA: "",
|
||||
wantOk: true,
|
||||
wantExpected: "",
|
||||
wantActual: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := helpers.ValidateSnapshotSHA(tt.wantSHA, tt.gotSHA)
|
||||
if got.OK != tt.wantOk {
|
||||
t.Errorf("ValidateSnapshotSHA() OK = %v, want %v", got.OK, tt.wantOk)
|
||||
}
|
||||
if got.Expected != tt.wantExpected {
|
||||
t.Errorf("ValidateSnapshotSHA() Expected = %q, want %q", got.Expected, tt.wantExpected)
|
||||
}
|
||||
if got.Actual != tt.wantActual {
|
||||
t.Errorf("ValidateSnapshotSHA() Actual = %q, want %q", got.Actual, tt.wantActual)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Loading…
Reference in a new issue