test: modernize test suite for streamlined infrastructure
- Update E2E tests for consolidated docker-compose.test.yml - Remove references to obsolete logs-debug.yml - Enhance test fixtures and utilities - Improve integration test coverage for KMS, queue, scheduler - Update unit tests for config constants and worker execution - Modernize cleanup-status.sh with new Makefile targets
This commit is contained in:
parent
61081655d2
commit
5f53104fcd
18 changed files with 105 additions and 211 deletions
|
|
@ -81,16 +81,17 @@ docker system df --format "table {{.Type}}\t{{.TotalCount}}\t{{.Size}}\t{{.Recla
|
|||
echo ""
|
||||
if [ "$containers" -gt 0 ] || [ "$images" -gt 0 ] || [ "$networks" -gt 0 ] || [ "$volumes" -gt 0 ]; then
|
||||
log_warning "Resources found that can be cleaned up"
|
||||
echo " Quick cleanup: make self-cleanup"
|
||||
echo " Force cleanup: make self-cleanup --force"
|
||||
echo " Full cleanup: make self-cleanup --all"
|
||||
echo " Quick cleanup: make clean"
|
||||
echo " Force cleanup: make clean-release"
|
||||
echo " Full cleanup: ./scripts/release/cleanup.sh all"
|
||||
else
|
||||
log_success "No fetch_ml resources found - system is clean!"
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "=== Cleanup Commands ==="
|
||||
echo " make self-cleanup - Interactive cleanup"
|
||||
echo " make clean - Remove build artifacts"
|
||||
echo " make clean-release - Thorough pre-release cleanup"
|
||||
echo " ./scripts/cleanup.sh --dry-run - Preview what would be cleaned"
|
||||
echo " ./scripts/cleanup.sh --force - Force cleanup without prompts"
|
||||
echo " ./scripts/cleanup.sh --all - Clean everything including images"
|
||||
|
|
|
|||
|
|
@ -204,7 +204,7 @@ func runMLExperimentPhase(t *testing.T, cliPath, cliConfigDir, testDir string) {
|
|||
t.Fatalf("Failed to create experiment dir: %v", err)
|
||||
}
|
||||
|
||||
trainScript := filepath.Join(expDir, "train.py")
|
||||
tEntrypoint := filepath.Join(expDir, "train.py")
|
||||
trainCode := `#!/usr/bin/env python3
|
||||
import json
|
||||
import sys
|
||||
|
|
@ -231,7 +231,7 @@ print("Training completed successfully!")
|
|||
print(f"Results: {results}")
|
||||
sys.exit(0)
|
||||
`
|
||||
if err := os.WriteFile(trainScript, []byte(trainCode), 0600); err != nil {
|
||||
if err := os.WriteFile(tEntrypoint, []byte(trainCode), 0600); err != nil {
|
||||
t.Fatalf("Failed to create train.py: %v", err)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,49 +0,0 @@
|
|||
---
|
||||
# Docker Compose configuration for logs and debug E2E tests
|
||||
# Simplified version using pre-built golang image with source mount
|
||||
|
||||
services:
|
||||
redis:
|
||||
image: redis:7-alpine
|
||||
ports:
|
||||
- "6379:6379"
|
||||
healthcheck:
|
||||
test: ["CMD", "redis-cli", "ping"]
|
||||
interval: 5s
|
||||
timeout: 3s
|
||||
retries: 5
|
||||
|
||||
api-server:
|
||||
image: golang:1.25-bookworm
|
||||
working_dir: /app
|
||||
command: >
|
||||
sh -c "
|
||||
go build -o api-server ./cmd/api-server/main.go &&
|
||||
./api-server --config /app/configs/api/dev.yaml
|
||||
"
|
||||
ports:
|
||||
- "9102:9101"
|
||||
environment:
|
||||
- LOG_LEVEL=debug
|
||||
- REDIS_ADDR=redis:6379
|
||||
volumes:
|
||||
- ../../:/app
|
||||
- api-logs:/logs
|
||||
- api-experiments:/data/experiments
|
||||
- api-active:/data/active
|
||||
- go-mod-cache:/go/pkg/mod
|
||||
depends_on:
|
||||
redis:
|
||||
condition: service_healthy
|
||||
healthcheck:
|
||||
test: ["CMD", "wget", "-q", "--spider", "http://localhost:9101/health"]
|
||||
interval: 5s
|
||||
timeout: 3s
|
||||
retries: 10
|
||||
start_period: 30s
|
||||
|
||||
volumes:
|
||||
api-logs:
|
||||
api-experiments:
|
||||
api-active:
|
||||
go-mod-cache:
|
||||
|
|
@ -76,10 +76,10 @@ func TestExampleExecution(t *testing.T) {
|
|||
for _, project := range projects {
|
||||
t.Run(project, func(t *testing.T) {
|
||||
projectDir := filepath.Join(examplesDir, project)
|
||||
trainScript := filepath.Join(projectDir, "train.py")
|
||||
tEntrypoint := filepath.Join(projectDir, "train.py")
|
||||
|
||||
// Test script syntax by checking if it can be parsed
|
||||
output, err := executeCommand("python3", "-m", "py_compile", trainScript)
|
||||
output, err := executeCommand("python3", "-m", "py_compile", tEntrypoint)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to compile %s: %v", project, err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -56,7 +56,7 @@ func TestMLProjectCompatibility(t *testing.T) {
|
|||
}
|
||||
|
||||
// Create minimal files
|
||||
trainScript := filepath.Join(experimentDir, "train.py")
|
||||
tEntrypoint := filepath.Join(experimentDir, "train.py")
|
||||
trainCode := `#!/usr/bin/env python3
|
||||
import argparse, json, logging, time
|
||||
from pathlib import Path
|
||||
|
|
@ -90,7 +90,7 @@ if __name__ == "__main__":
|
|||
main()
|
||||
`
|
||||
|
||||
if err := os.WriteFile(trainScript, []byte(trainCode), 0600); err != nil {
|
||||
if err := os.WriteFile(tEntrypoint, []byte(trainCode), 0600); err != nil {
|
||||
t.Fatalf("Failed to create train.py: %v", err)
|
||||
}
|
||||
|
||||
|
|
|
|||
10
tests/fixtures/ml_templates.go
vendored
10
tests/fixtures/ml_templates.go
vendored
|
|
@ -4,7 +4,7 @@ package tests
|
|||
// MLProjectTemplate represents a template for creating ML projects
|
||||
type MLProjectTemplate struct {
|
||||
Name string
|
||||
TrainScript string
|
||||
Entrypoint string
|
||||
Requirements string
|
||||
}
|
||||
|
||||
|
|
@ -12,7 +12,7 @@ type MLProjectTemplate struct {
|
|||
func ScikitLearnTemplate() MLProjectTemplate {
|
||||
return MLProjectTemplate{
|
||||
Name: "Scikit-learn",
|
||||
TrainScript: `#!/usr/bin/env python3
|
||||
Entrypoint: `#!/usr/bin/env python3
|
||||
import argparse, json, logging, time
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
|
|
@ -77,7 +77,7 @@ pandas>=1.3.0
|
|||
func StatsModelsTemplate() MLProjectTemplate {
|
||||
return MLProjectTemplate{
|
||||
Name: "StatsModels",
|
||||
TrainScript: `#!/usr/bin/env python3
|
||||
Entrypoint: `#!/usr/bin/env python3
|
||||
import argparse, json, logging, time
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
|
|
@ -158,7 +158,7 @@ numpy>=1.21.0
|
|||
func XGBoostTemplate() MLProjectTemplate {
|
||||
return MLProjectTemplate{
|
||||
Name: "XGBoost",
|
||||
TrainScript: `#!/usr/bin/env python3
|
||||
Entrypoint: `#!/usr/bin/env python3
|
||||
import argparse, json, logging, time
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
|
|
@ -240,7 +240,7 @@ numpy>=1.21.0
|
|||
func PyTorchTemplate() MLProjectTemplate {
|
||||
return MLProjectTemplate{
|
||||
Name: "PyTorch",
|
||||
TrainScript: `#!/usr/bin/env python3
|
||||
Entrypoint: `#!/usr/bin/env python3
|
||||
import argparse, json, logging, time
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
|
|
|
|||
2
tests/fixtures/test_utils.go
vendored
2
tests/fixtures/test_utils.go
vendored
|
|
@ -545,7 +545,7 @@ func CreateMLProject(t *testing.T, testDir, projectName string, template MLProje
|
|||
|
||||
// Create training script
|
||||
trainScript := filepath.Join(experimentDir, "train.py")
|
||||
if err := os.WriteFile(trainScript, []byte(template.TrainScript), 0600); err != nil {
|
||||
if err := os.WriteFile(trainScript, []byte(template.Entrypoint), 0600); err != nil {
|
||||
t.Fatalf("Failed to create train.py: %v", err)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ func TestIntegrationE2E(t *testing.T) {
|
|||
}
|
||||
|
||||
// Create standard ML project files
|
||||
trainScript := filepath.Join(jobDir, "train.py")
|
||||
tEntrypoint := filepath.Join(jobDir, "train.py")
|
||||
requirementsFile := filepath.Join(jobDir, "requirements.txt")
|
||||
readmeFile := filepath.Join(jobDir, "README.md")
|
||||
|
||||
|
|
@ -108,7 +108,7 @@ if __name__ == "__main__":
|
|||
`
|
||||
|
||||
//nolint:gosec // G306: Script needs execute permissions
|
||||
if err := os.WriteFile(trainScript, []byte(trainCode), 0750); err != nil {
|
||||
if err := os.WriteFile(tEntrypoint, []byte(trainCode), 0750); err != nil {
|
||||
t.Fatalf("Failed to create train.py: %v", err)
|
||||
}
|
||||
|
||||
|
|
@ -202,7 +202,7 @@ python train.py --epochs 2 --lr 0.01 --output_dir ./results
|
|||
}
|
||||
|
||||
// Test 4: Execute job (zero-install style)
|
||||
if err := executeZeroInstallJob(mlServer, nextTask, jobBaseDir, trainScript); err != nil {
|
||||
if err := executeZeroInstallJob(mlServer, nextTask, jobBaseDir, tEntrypoint); err != nil {
|
||||
t.Fatalf("Failed to execute job: %v", err)
|
||||
}
|
||||
|
||||
|
|
@ -261,7 +261,7 @@ python train.py --epochs 2 --lr 0.01 --output_dir ./results
|
|||
}
|
||||
|
||||
// executeZeroInstallJob simulates zero-install job execution
|
||||
func executeZeroInstallJob(server *tests.MLServer, task *tests.Task, baseDir, trainScript string) error {
|
||||
func executeZeroInstallJob(server *tests.MLServer, task *tests.Task, baseDir, tEntrypoint string) error {
|
||||
// Move job to running directory
|
||||
pendingPath := filepath.Join(baseDir, "pending", task.JobName)
|
||||
runningPath := filepath.Join(baseDir, statusRunning, task.JobName)
|
||||
|
|
@ -278,7 +278,7 @@ func executeZeroInstallJob(server *tests.MLServer, task *tests.Task, baseDir, tr
|
|||
|
||||
cmd := fmt.Sprintf("cd %s && python3 %s --output_dir %s %s",
|
||||
runningPath,
|
||||
filepath.Base(trainScript),
|
||||
filepath.Base(tEntrypoint),
|
||||
outputDir,
|
||||
task.Args,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@ package tests_test
|
|||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"os/exec"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
|
@ -18,6 +20,33 @@ func TestVaultProvider_Integration(t *testing.T) {
|
|||
t.Skip("Skipping integration test in short mode")
|
||||
}
|
||||
|
||||
// Skip if Docker is not available
|
||||
if _, err := exec.LookPath("docker"); err != nil {
|
||||
t.Skip("Docker not available, skipping container-based test")
|
||||
}
|
||||
if err := exec.Command("docker", "ps").Run(); err != nil {
|
||||
t.Skip("Docker daemon not running, skipping container-based test")
|
||||
}
|
||||
// Testcontainers requires Docker socket access - skip if not available
|
||||
if os.Getenv("DOCKER_HOST") == "" {
|
||||
// Check for default Docker socket locations
|
||||
dockerSocketPaths := []string{
|
||||
"/var/run/docker.sock",
|
||||
"/run/docker.sock",
|
||||
os.Getenv("HOME") + "/.docker/run/docker.sock", // rootless Docker
|
||||
}
|
||||
socketFound := false
|
||||
for _, path := range dockerSocketPaths {
|
||||
if _, err := os.Stat(path); err == nil {
|
||||
socketFound = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !socketFound {
|
||||
t.Skip("Docker socket not found, skipping container-based test")
|
||||
}
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Start Vault container
|
||||
|
|
@ -83,6 +112,33 @@ func TestAWSKMSProvider_Integration(t *testing.T) {
|
|||
t.Skip("Skipping integration test in short mode")
|
||||
}
|
||||
|
||||
// Skip if Docker is not available
|
||||
if _, err := exec.LookPath("docker"); err != nil {
|
||||
t.Skip("Docker not available, skipping container-based test")
|
||||
}
|
||||
if err := exec.Command("docker", "ps").Run(); err != nil {
|
||||
t.Skip("Docker daemon not running, skipping container-based test")
|
||||
}
|
||||
// Testcontainers requires Docker socket access - skip if not available
|
||||
if os.Getenv("DOCKER_HOST") == "" {
|
||||
// Check for default Docker socket locations
|
||||
dockerSocketPaths := []string{
|
||||
"/var/run/docker.sock",
|
||||
"/run/docker.sock",
|
||||
os.Getenv("HOME") + "/.docker/run/docker.sock", // rootless Docker
|
||||
}
|
||||
socketFound := false
|
||||
for _, path := range dockerSocketPaths {
|
||||
if _, err := os.Stat(path); err == nil {
|
||||
socketFound = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !socketFound {
|
||||
t.Skip("Docker socket not found, skipping container-based test")
|
||||
}
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Start LocalStack container with KMS
|
||||
|
|
@ -162,7 +218,7 @@ func TestTenantKeyManager_WithMemoryProvider(t *testing.T) {
|
|||
}
|
||||
|
||||
// Create TenantKeyManager
|
||||
tkm := crypto.NewTenantKeyManager(provider, cache, config)
|
||||
tkm := crypto.NewTenantKeyManager(provider, cache, config, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
ctx = context.WithValue(ctx, "test", true)
|
||||
|
|
|
|||
|
|
@ -1,116 +0,0 @@
|
|||
package tests
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/api"
|
||||
)
|
||||
|
||||
func TestProtocolSerialization(t *testing.T) {
|
||||
// Test success packet
|
||||
successPacket := api.NewSuccessPacket("Operation completed successfully")
|
||||
data, err := successPacket.Serialize()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to serialize success packet: %v", err)
|
||||
}
|
||||
|
||||
// Verify packet type
|
||||
if len(data) < 1 || data[0] != api.PacketTypeSuccess {
|
||||
t.Errorf("Expected packet type %d, got %d", api.PacketTypeSuccess, data[0])
|
||||
}
|
||||
|
||||
// Verify timestamp is present (9 bytes minimum: 1 type + 8 timestamp)
|
||||
if len(data) < 9 {
|
||||
t.Errorf("Expected at least 9 bytes, got %d", len(data))
|
||||
}
|
||||
|
||||
// Test error packet
|
||||
errorPacket := api.NewErrorPacket(api.ErrorCodeAuthenticationFailed, "Auth failed", "Invalid API key")
|
||||
data, err = errorPacket.Serialize()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to serialize error packet: %v", err)
|
||||
}
|
||||
|
||||
if len(data) < 1 || data[0] != api.PacketTypeError {
|
||||
t.Errorf("Expected packet type %d, got %d", api.PacketTypeError, data[0])
|
||||
}
|
||||
|
||||
// Test progress packet
|
||||
progressPacket := api.NewProgressPacket(api.ProgressTypePercentage, 75, 100, "Processing...")
|
||||
data, err = progressPacket.Serialize()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to serialize progress packet: %v", err)
|
||||
}
|
||||
|
||||
if len(data) < 1 || data[0] != api.PacketTypeProgress {
|
||||
t.Errorf("Expected packet type %d, got %d", api.PacketTypeProgress, data[0])
|
||||
}
|
||||
|
||||
// Test status packet
|
||||
statusPacket := api.NewStatusPacket(`{"workers":1,"queued":0}`)
|
||||
data, err = statusPacket.Serialize()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to serialize status packet: %v", err)
|
||||
}
|
||||
|
||||
if len(data) < 1 || data[0] != api.PacketTypeStatus {
|
||||
t.Errorf("Expected packet type %d, got %d", api.PacketTypeStatus, data[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorMessageMapping(t *testing.T) {
|
||||
tests := map[byte]string{
|
||||
api.ErrorCodeUnknownError: "Unknown error occurred",
|
||||
api.ErrorCodeAuthenticationFailed: "Authentication failed",
|
||||
api.ErrorCodeJobNotFound: "Job not found",
|
||||
api.ErrorCodeServerOverloaded: "Server is overloaded",
|
||||
}
|
||||
|
||||
for code, expected := range tests {
|
||||
actual := api.GetErrorMessage(code)
|
||||
if actual != expected {
|
||||
t.Errorf("Expected error message '%s' for code %d, got '%s'", expected, code, actual)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogLevelMapping(t *testing.T) {
|
||||
tests := map[byte]string{
|
||||
api.LogLevelDebug: "DEBUG",
|
||||
api.LogLevelInfo: "INFO",
|
||||
api.LogLevelWarn: "WARN",
|
||||
api.LogLevelError: "ERROR",
|
||||
}
|
||||
|
||||
for level, expected := range tests {
|
||||
actual := api.GetLogLevelName(level)
|
||||
if actual != expected {
|
||||
t.Errorf("Expected log level '%s' for level %d, got '%s'", expected, level, actual)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimestampConsistency(t *testing.T) {
|
||||
before := time.Now().Unix()
|
||||
|
||||
packet := api.NewSuccessPacket("Test message")
|
||||
data, err := packet.Serialize()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to serialize: %v", err)
|
||||
}
|
||||
|
||||
after := time.Now().Unix()
|
||||
|
||||
// Extract timestamp (bytes 1-8, big-endian)
|
||||
if len(data) < 9 {
|
||||
t.Fatalf("Packet too short: %d bytes", len(data))
|
||||
}
|
||||
|
||||
timestamp := binary.BigEndian.Uint64(data[1:9])
|
||||
|
||||
if timestamp < uint64(before) || timestamp > uint64(after) {
|
||||
t.Errorf("Timestamp %d not in expected range [%d, %d]", timestamp, before, after)
|
||||
}
|
||||
}
|
||||
|
|
@ -228,12 +228,12 @@ func simulateJobExecution(t *testing.T, runningJobDir, jobName string, priority
|
|||
}
|
||||
}
|
||||
|
||||
func detectFramework(t *testing.T, trainScript string) string {
|
||||
func detectFramework(t *testing.T, entrypoint string) string {
|
||||
t.Helper()
|
||||
|
||||
scriptContent, err := fileutil.SecureFileRead(trainScript)
|
||||
scriptContent, err := fileutil.SecureFileRead(entrypoint)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read train.py: %v", err)
|
||||
t.Fatalf("Failed to read entrypoint script file: %v", err)
|
||||
}
|
||||
|
||||
scriptStr := string(scriptContent)
|
||||
|
|
@ -325,7 +325,7 @@ func TestQueueCapacity(t *testing.T) {
|
|||
// Copy actual example files
|
||||
if _, err := os.Stat(sourceDir); os.IsNotExist(err) {
|
||||
// Create minimal files if example doesn't exist
|
||||
trainScript := filepath.Join(jobDir, "train.py")
|
||||
entrypoint := filepath.Join(jobDir, "train.py")
|
||||
script := fmt.Sprintf(`#!/usr/bin/env python3
|
||||
import json, time
|
||||
from pathlib import Path
|
||||
|
|
@ -349,7 +349,7 @@ if __name__ == "__main__":
|
|||
`, i, example)
|
||||
|
||||
//nolint:gosec // G306: Script needs execute permissions
|
||||
if err := os.WriteFile(trainScript, []byte(script), 0750); err != nil {
|
||||
if err := os.WriteFile(entrypoint, []byte(script), 0750); err != nil {
|
||||
t.Fatalf("Failed to create train script for job %d: %v", i, err)
|
||||
}
|
||||
} else {
|
||||
|
|
@ -456,10 +456,10 @@ func TestResourceIsolation(t *testing.T) {
|
|||
sourceDir := examplesDir.GetPath(expName)
|
||||
|
||||
// Read actual example to create realistic results
|
||||
trainScript := filepath.Join(sourceDir, "train.py")
|
||||
entrypoint := filepath.Join(sourceDir, "train.py")
|
||||
|
||||
framework := "unknown"
|
||||
if content, err := fileutil.SecureFileRead(trainScript); err == nil {
|
||||
if content, err := fileutil.SecureFileRead(entrypoint); err == nil {
|
||||
scriptStr := string(content)
|
||||
switch {
|
||||
case contains(scriptStr, "sklearn"):
|
||||
|
|
|
|||
|
|
@ -260,9 +260,9 @@ func TestStarvationPrevention(t *testing.T) {
|
|||
msg1 := <-recvCh
|
||||
require.Equal(t, scheduler.MsgJobAssign, msg1.Type)
|
||||
|
||||
var spec1 scheduler.JobSpec
|
||||
json.Unmarshal(msg1.Payload, &spec1)
|
||||
assert.Equal(t, "high-priority-job", spec1.ID)
|
||||
var payload1 scheduler.JobAssignPayload
|
||||
json.Unmarshal(msg1.Payload, &payload1)
|
||||
assert.Equal(t, "high-priority-job", payload1.Spec.ID)
|
||||
|
||||
// Complete first job
|
||||
conn.WriteJSON(scheduler.Message{
|
||||
|
|
@ -286,9 +286,9 @@ func TestStarvationPrevention(t *testing.T) {
|
|||
msg2 := <-recvCh
|
||||
require.Equal(t, scheduler.MsgJobAssign, msg2.Type)
|
||||
|
||||
var spec2 scheduler.JobSpec
|
||||
json.Unmarshal(msg2.Payload, &spec2)
|
||||
assert.Equal(t, "low-priority-job", spec2.ID)
|
||||
var payload2 scheduler.JobAssignPayload
|
||||
json.Unmarshal(msg2.Payload, &payload2)
|
||||
assert.Equal(t, "low-priority-job", payload2.Spec.ID)
|
||||
}
|
||||
|
||||
// Helper function to create test worker with token auth
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ func TestWebSocketQueueEndToEnd(t *testing.T) {
|
|||
if testing.Short() {
|
||||
t.Skip("skipping websocket queue integration in short mode")
|
||||
}
|
||||
t.Parallel()
|
||||
|
||||
// Miniredis provides an in-memory Redis compatible server for realistic queue tests.
|
||||
redisServer, err := miniredis.Run()
|
||||
|
|
@ -136,6 +137,7 @@ func TestWebSocketQueueEndToEndSQLite(t *testing.T) {
|
|||
if testing.Short() {
|
||||
t.Skip("skipping websocket queue integration in short mode")
|
||||
}
|
||||
t.Parallel()
|
||||
|
||||
queuePath := filepath.Join(t.TempDir(), "queue.db")
|
||||
taskQueue, err := queue.NewSQLiteQueue(queuePath)
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ func TestWorkerLocalMode(t *testing.T) {
|
|||
}
|
||||
|
||||
// Create standard ML project files
|
||||
trainScript := filepath.Join(jobDir, "train.py")
|
||||
entrypoint := filepath.Join(jobDir, "train.py")
|
||||
requirementsFile := filepath.Join(jobDir, "requirements.txt")
|
||||
|
||||
// Create train.py (zero-install style)
|
||||
|
|
@ -96,7 +96,7 @@ if __name__ == "__main__":
|
|||
`
|
||||
|
||||
//nolint:gosec // G306: Script needs execute permissions
|
||||
if err := os.WriteFile(trainScript, []byte(trainCode), 0750); err != nil {
|
||||
if err := os.WriteFile(entrypoint, []byte(trainCode), 0750); err != nil {
|
||||
t.Fatalf("Failed to create train.py: %v", err)
|
||||
}
|
||||
|
||||
|
|
@ -137,7 +137,7 @@ numpy>=1.21.0
|
|||
}
|
||||
|
||||
// Test file operations
|
||||
output, err = server.Exec(fmt.Sprintf("test -f %s && echo 'exists'", trainScript))
|
||||
output, err = server.Exec(fmt.Sprintf("test -f %s && echo 'exists'", entrypoint))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to test file existence: %v", err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -184,10 +184,10 @@ output_dir: "/home/mluser/ml_jobs"
|
|||
}
|
||||
|
||||
// Verify uploaded files exist
|
||||
uploadedTrainScript := filepath.Join(
|
||||
uploadedEntrypoint := filepath.Join(
|
||||
testDir, "server", "home", "mluser", "ml_jobs", "pending",
|
||||
"my_experiment_20231201_143022", "train.py")
|
||||
if _, err := os.Stat(uploadedTrainScript); os.IsNotExist(err) {
|
||||
if _, err := os.Stat(uploadedEntrypoint); os.IsNotExist(err) {
|
||||
t.Error("Uploaded train.py should exist in pending directory")
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ func TestDefaultConstants(t *testing.T) {
|
|||
{config.DefaultRedisPort, 6379, "DefaultRedisPort"},
|
||||
{config.DefaultRedisAddr, "localhost:6379", "DefaultRedisAddr"},
|
||||
{config.DefaultBasePath, "/mnt/nas/jobs", "DefaultBasePath"},
|
||||
{config.DefaultTrainScript, "train.py", "DefaultTrainScript"},
|
||||
{config.DefaultEntrypoint, "train.py", "DefaultEntrypoint"},
|
||||
{config.DefaultDataDir, "/data/active", "DefaultDataDir"},
|
||||
{config.DefaultLocalDataDir, "./data/active", "DefaultLocalDataDir"},
|
||||
{config.DefaultNASDataDir, "/mnt/datasets", "DefaultNASDataDir"},
|
||||
|
|
|
|||
|
|
@ -133,7 +133,7 @@ func TestKMSProtocol_EncryptDecrypt(t *testing.T) {
|
|||
Cache: kms.DefaultCacheConfig(),
|
||||
}
|
||||
|
||||
tkm := crypto.NewTenantKeyManager(provider, cache, config)
|
||||
tkm := crypto.NewTenantKeyManager(provider, cache, config, nil)
|
||||
|
||||
// Provision tenant
|
||||
hierarchy, err := tkm.ProvisionTenant("protocol-test-tenant")
|
||||
|
|
@ -189,7 +189,7 @@ func TestKMSProtocol_MultiTenantIsolation(t *testing.T) {
|
|||
Cache: kms.DefaultCacheConfig(),
|
||||
}
|
||||
|
||||
tkm := crypto.NewTenantKeyManager(provider, cache, config)
|
||||
tkm := crypto.NewTenantKeyManager(provider, cache, config, nil)
|
||||
|
||||
// Provision two tenants
|
||||
tenant1, err := tkm.ProvisionTenant("tenant-1")
|
||||
|
|
@ -239,7 +239,7 @@ func TestKMSProtocol_CacheHit(t *testing.T) {
|
|||
Cache: kms.DefaultCacheConfig(),
|
||||
}
|
||||
|
||||
tkm := crypto.NewTenantKeyManager(provider, cache, config)
|
||||
tkm := crypto.NewTenantKeyManager(provider, cache, config, nil)
|
||||
|
||||
hierarchy, _ := tkm.ProvisionTenant("cache-test")
|
||||
|
||||
|
|
@ -279,7 +279,7 @@ func TestKMSProtocol_KeyRotation(t *testing.T) {
|
|||
Cache: kms.DefaultCacheConfig(),
|
||||
}
|
||||
|
||||
tkm := crypto.NewTenantKeyManager(provider, cache, config)
|
||||
tkm := crypto.NewTenantKeyManager(provider, cache, config, nil)
|
||||
|
||||
// Provision tenant
|
||||
hierarchy, _ := tkm.ProvisionTenant("rotation-test")
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ func TestRunManifest_WrittenForLocalModeRun(t *testing.T) {
|
|||
cfg := &worker.Config{
|
||||
BasePath: base,
|
||||
LocalMode: true,
|
||||
TrainScript: "train.py",
|
||||
Entrypoint: "train.py",
|
||||
PodmanImage: "python:3.11",
|
||||
WorkerID: "worker-test",
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue