package security

import (
	"fmt"
	"os"
	"path/filepath"
	"strings"
	"testing"

	"github.com/jfraeys/fetch_ml/internal/worker"
)

// TestHIPAAValidation_NetworkMode checks Config.Validate with ComplianceMode
// "hipaa": a sandbox network mode of "bridge" must be rejected with a
// network_mode error, and "none" must validate cleanly.
func TestHIPAAValidation_NetworkMode(t *testing.T) {
	tests := []struct {
		name        string
		networkMode string
		wantErr     bool
		errContains string
	}{
		{
			name:        "HIPAA mode requires network_mode=none",
			networkMode: "bridge",
			wantErr:     true,
			errContains: "network_mode must be 'none'",
		},
		{
			name:        "HIPAA mode accepts network_mode=none",
			networkMode: "none",
			wantErr:     false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			cfg := &worker.Config{
				ComplianceMode: "hipaa",
				GPUVendor:      "none",
				Sandbox: worker.SandboxConfig{
					NetworkMode:     tt.networkMode,
					SeccompProfile:  "default-hardened",
					NoNewPrivileges: true,
				},
				MaxWorkers: 1,
			}
			cfg.Sandbox.ApplySecurityDefaults()

			err := cfg.Validate()
			switch {
			case tt.wantErr && err == nil:
				t.Errorf("expected error containing %q, got nil", tt.errContains)
			case tt.wantErr && !contains(err.Error(), tt.errContains):
				t.Errorf("expected error containing %q, got %q", tt.errContains, err.Error())
			case !tt.wantErr && err != nil:
				t.Errorf("unexpected error: %v", err)
			}
		})
	}
}

// TestHIPAAValidation_NoNewPrivileges checks that Config.Validate rejects a
// HIPAA-mode config whose sandbox has NoNewPrivileges set to false.
func TestHIPAAValidation_NoNewPrivileges(t *testing.T) {
	cfg := &worker.Config{
		ComplianceMode: "hipaa",
		GPUVendor:      "none",
		Sandbox: worker.SandboxConfig{
			NetworkMode:     "none",
			SeccompProfile:  "default-hardened",
			NoNewPrivileges: false, // Violation - must stay false
		},
		MaxWorkers: 1,
	}

	// Note: Don't call ApplySecurityDefaults() as it would reset
	// NoNewPrivileges to true.
	err := cfg.Validate()
	if err == nil {
		t.Error("expected error for NoNewPrivileges=false in HIPAA mode, got nil")
	} else if !contains(err.Error(), "no_new_privileges must be true") {
		t.Errorf("expected error about no_new_privileges, got %q", err.Error())
	}
}

// TestHIPAAValidation_SeccompProfile checks that Config.Validate rejects a
// HIPAA-mode config whose sandbox SeccompProfile is empty.
// ApplySecurityDefaults is deliberately not called here so the empty profile
// reaches Validate as written.
func TestHIPAAValidation_SeccompProfile(t *testing.T) {
	cfg := &worker.Config{
		ComplianceMode: "hipaa",
		GPUVendor:      "none",
		Sandbox: worker.SandboxConfig{
			NetworkMode:     "none",
			SeccompProfile:  "", // Empty - violation
			NoNewPrivileges: true,
		},
		MaxWorkers: 1,
	}

	err := cfg.Validate()
	if err == nil {
		t.Error("expected error for empty SeccompProfile in HIPAA mode, got nil")
	} else if !contains(err.Error(), "seccomp_profile must be non-empty") {
		t.Errorf("expected error about seccomp_profile, got %q", err.Error())
	}
}

// TestHIPAAValidation_InlineCredentials checks that HIPAA-mode validation
// rejects credentials written inline in the config and does not raise a
// credential error for a ${ENV_VAR}-style reference.
//
// A case with an empty errContains expects no credential error; other
// validation errors are tolerated in that case.
func TestHIPAAValidation_InlineCredentials(t *testing.T) {
	tests := []struct {
		name        string
		setupFunc   func(*worker.Config)
		errContains string
	}{
		{
			name: "inline redis_password rejected",
			setupFunc: func(cfg *worker.Config) {
				cfg.RedisPassword = "plaintext_password"
			},
			errContains: "redis_password must use ${ENV_VAR}",
		},
		{
			name: "env reference redis_password accepted",
			setupFunc: func(cfg *worker.Config) {
				cfg.RedisPassword = "${REDIS_PASSWORD}"
			},
			errContains: "", // No error expected
		},
		{
			name: "inline ssh_key rejected",
			setupFunc: func(cfg *worker.Config) {
				cfg.SSHKey = "/path/to/key"
			},
			errContains: "ssh_key must use ${ENV_VAR}",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			cfg := &worker.Config{
				ComplianceMode: "hipaa",
				GPUVendor:      "none",
				Sandbox: worker.SandboxConfig{
					NetworkMode:     "none",
					SeccompProfile:  "default-hardened",
					NoNewPrivileges: true,
				},
				MaxWorkers: 1,
			}
			tt.setupFunc(cfg)
			cfg.Sandbox.ApplySecurityDefaults()

			err := cfg.Validate()
			switch {
			case tt.errContains == "":
				// For env reference test, we expect validation to pass
				// (but may fail on other checks like empty env var), so
				// only a credential-format error counts as a failure.
				if err != nil && contains(err.Error(), "must use ${ENV_VAR}") {
					t.Errorf("unexpected credential error: %v", err)
				}
			case err == nil:
				t.Errorf("expected error containing %q, got nil", tt.errContains)
			case !contains(err.Error(), tt.errContains):
				t.Errorf("expected error containing %q, got %q", tt.errContains, err.Error())
			}
		})
	}
}

// TestPHIDenylist_Validation checks that HIPAA-mode validation rejects
// AllowedSecrets entries whose names look like PHI fields (patient ID, SSN,
// medical record number, diagnosis code) and accepts ordinary secret names.
func TestPHIDenylist_Validation(t *testing.T) {
	tests := []struct {
		name           string
		allowedSecrets []string
		wantErr        bool
		errContains    string
	}{
		{
			name:           "normal secrets allowed",
			allowedSecrets: []string{"HF_TOKEN", "WANDB_API_KEY", "AWS_ACCESS_KEY"},
			wantErr:        false,
		},
		{
			name:           "patient_id rejected",
			allowedSecrets: []string{"PATIENT_ID"},
			wantErr:        true,
			errContains:    "PHI field",
		},
		{
			name:           "ssn rejected",
			allowedSecrets: []string{"SSN"},
			wantErr:        true,
			errContains:    "PHI field",
		},
		{
			name:           "medical_record_number rejected",
			allowedSecrets: []string{"MEDICAL_RECORD_NUMBER"},
			wantErr:        true,
			errContains:    "PHI field",
		},
		{
			name:           "diagnosis_code rejected",
			allowedSecrets: []string{"DIAGNOSIS_CODE"},
			wantErr:        true,
			errContains:    "PHI field",
		},
		{
			name:           "mixed secrets with phi rejected",
			allowedSecrets: []string{"HF_TOKEN", "PATIENT_SSN", "API_KEY"},
			wantErr:        true,
			errContains:    "PHI field",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			cfg := &worker.Config{
				ComplianceMode: "hipaa",
				GPUVendor:      "none",
				Sandbox: worker.SandboxConfig{
					NetworkMode:     "none",
					SeccompProfile:  "default-hardened",
					NoNewPrivileges: true,
					AllowedSecrets:  tt.allowedSecrets,
				},
				MaxWorkers: 1,
			}
			cfg.Sandbox.ApplySecurityDefaults()

			err := cfg.Validate()
			switch {
			case tt.wantErr && err == nil:
				t.Errorf("expected error containing %q, got nil", tt.errContains)
			case tt.wantErr && !contains(err.Error(), tt.errContains):
				t.Errorf("expected error containing %q, got %q", tt.errContains, err.Error())
			case !tt.wantErr && err != nil:
				t.Errorf("unexpected error: %v", err)
			}
		})
	}
}

// TestArtifactIngestionCaps checks that worker.ScanArtifacts enforces the
// MaxArtifactFiles and MaxArtifactTotalBytes limits from SandboxConfig.
// Each case fills a fresh temp directory with createFiles files of fileSize
// bytes and scans it.
func TestArtifactIngestionCaps(t *testing.T) {
	tests := []struct {
		name        string
		maxFiles    int
		maxBytes    int64
		createFiles int
		fileSize    int64
		wantErr     bool
		errContains string
	}{
		{
			name:        "within file cap",
			maxFiles:    10,
			maxBytes:    10000,
			createFiles: 5,
			fileSize:    100,
			wantErr:     false,
		},
		{
			name:        "exceeds file cap",
			maxFiles:    3,
			maxBytes:    10000,
			createFiles: 5,
			fileSize:    100,
			wantErr:     true,
			errContains: "file count cap exceeded",
		},
		{
			name:        "exceeds size cap",
			maxFiles:    100,
			maxBytes:    500,
			createFiles: 10,
			fileSize:    100,
			wantErr:     true,
			errContains: "total size cap exceeded",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Create temp directory with test files; t.TempDir removes it
			// automatically when the subtest finishes.
			runDir := t.TempDir()
			for i := 0; i < tt.createFiles; i++ {
				f := createTestFile(t, runDir, i, tt.fileSize)
				t.Logf("Created: %s", f)
			}

			caps := &worker.SandboxConfig{
				MaxArtifactFiles:      tt.maxFiles,
				MaxArtifactTotalBytes: tt.maxBytes,
			}

			_, err := worker.ScanArtifacts(runDir, false, caps)
			switch {
			case tt.wantErr && err == nil:
				t.Errorf("expected error containing %q, got nil", tt.errContains)
			case tt.wantErr && !contains(err.Error(), tt.errContains):
				t.Errorf("expected error containing %q, got %q", tt.errContains, err.Error())
			case !tt.wantErr && err != nil:
				t.Errorf("unexpected error: %v", err)
			}
		})
	}
}

// TestConfigHash_Computation checks that ComputeResolvedConfigHash returns a
// non-empty hash, returns the same hash for an unchanged config, and returns
// a different hash after MaxWorkers is changed.
func TestConfigHash_Computation(t *testing.T) {
	cfg := &worker.Config{
		Host:       "localhost",
		Port:       22,
		MaxWorkers: 4,
		GPUVendor:  "nvidia",
		Sandbox: worker.SandboxConfig{
			NetworkMode:    "none",
			SeccompProfile: "default-hardened",
		},
		ComplianceMode:       "hipaa",
		ProvenanceBestEffort: false,
	}

	hash1, err := cfg.ComputeResolvedConfigHash()
	if err != nil {
		t.Fatalf("ComputeResolvedConfigHash failed: %v", err)
	}
	if hash1 == "" {
		t.Error("expected non-empty hash")
	}

	// Same config should produce same hash
	hash2, err := cfg.ComputeResolvedConfigHash()
	if err != nil {
		t.Fatalf("ComputeResolvedConfigHash failed: %v", err)
	}
	if hash1 != hash2 {
		t.Error("same config should produce same hash")
	}

	// Different config should produce different hash
	cfg.MaxWorkers = 8
	hash3, err := cfg.ComputeResolvedConfigHash()
	if err != nil {
		t.Fatalf("ComputeResolvedConfigHash failed: %v", err)
	}
	if hash1 == hash3 {
		t.Error("different config should produce different hash")
	}
}

// Helper functions

// contains reports whether substr occurs anywhere in s. An empty substr
// always matches. It is a thin wrapper over strings.Contains.
func contains(s, substr string) bool {
	return strings.Contains(s, substr)
}

// containsInternal reports whether substr occurs anywhere in s. It is no
// longer called from this file and is retained, delegating to
// strings.Contains, in case other files in this package still call it.
func containsInternal(s, substr string) bool {
	return strings.Contains(s, substr)
}

// createTestFile writes a zero-filled file of size bytes named
// "file_<index>.txt" into dir and returns its full path. Any write error
// fails the calling test immediately.
func createTestFile(t *testing.T, dir string, index int, size int64) string {
	t.Helper()

	fname := filepath.Join(dir, fmt.Sprintf("file_%d.txt", index))
	data := make([]byte, size)
	if err := os.WriteFile(fname, data, 0o640); err != nil {
		t.Fatalf("failed to create test file: %v", err)
	}
	return fname
}