Add compliance_mode field to Config with strict HIPAA validation:
- Requires SnapshotStore.Secure=true in HIPAA mode
- Requires NetworkMode="none" for tenant isolation
- Requires a non-empty SeccompProfile
- Requires NoNewPrivileges=true
- Enforces credentials via environment variables only (no inline YAML values)

Add PHI denylist validation for AllowedSecrets:
- Blocks secrets matching the patterns: patient, ssn, mrn, medical_record, diagnosis, dob, birth, mrn_number, patient_id, patient_name
- Prevents accidental PHI exfiltration via secret channels

Add comprehensive test coverage in hipaa_validation_test.go:
- Network mode enforcement tests
- NoNewPrivileges requirement tests
- Seccomp profile validation tests
- Inline credential rejection tests
- PHI denylist validation tests

Closes: compliance_mode and PHI denylist items from the security plan
372 lines
9.1 KiB
Go
372 lines
9.1 KiB
Go
package security
|
|
|
|
import (
	"fmt"
	"os"
	"path/filepath"
	"strings"
	"testing"

	"github.com/jfraeys/fetch_ml/internal/worker"
)
|
|
|
|
// TestHIPAAValidation_NetworkMode verifies that, with ComplianceMode "hipaa",
// Config.Validate rejects a sandbox network mode of "bridge" and accepts
// "none". The remaining sandbox hardening fields are set to passing values so
// network mode is the only variable under test.
func TestHIPAAValidation_NetworkMode(t *testing.T) {
	cases := []struct {
		name        string
		mode        string
		expectErr   bool
		errFragment string
	}{
		{
			name:        "HIPAA mode requires network_mode=none",
			mode:        "bridge",
			expectErr:   true,
			errFragment: "network_mode must be 'none'",
		},
		{
			name: "HIPAA mode accepts network_mode=none",
			mode: "none",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			cfg := &worker.Config{
				ComplianceMode: "hipaa",
				GPUVendor:      "none",
				Sandbox: worker.SandboxConfig{
					NetworkMode:     tc.mode,
					SeccompProfile:  "default-hardened",
					NoNewPrivileges: true,
				},
				MaxWorkers: 1,
			}
			cfg.Sandbox.ApplySecurityDefaults()

			err := cfg.Validate()
			switch {
			case tc.expectErr && err == nil:
				t.Errorf("expected error containing %q, got nil", tc.errFragment)
			case tc.expectErr && !contains(err.Error(), tc.errFragment):
				// err is non-nil here: the previous case handled err == nil.
				t.Errorf("expected error containing %q, got %q", tc.errFragment, err.Error())
			case !tc.expectErr && err != nil:
				t.Errorf("unexpected error: %v", err)
			}
		})
	}
}
|
|
|
|
// TestHIPAAValidation_NoNewPrivileges verifies that Config.Validate rejects a
// HIPAA-mode config whose sandbox has NoNewPrivileges disabled, and that the
// error mentions "no_new_privileges must be true".
func TestHIPAAValidation_NoNewPrivileges(t *testing.T) {
	cfg := &worker.Config{
		ComplianceMode: "hipaa",
		GPUVendor:      "none",
		Sandbox: worker.SandboxConfig{
			NetworkMode:     "none",
			SeccompProfile:  "default-hardened",
			NoNewPrivileges: false, // deliberately false: this is the violation under test
		},
		MaxWorkers: 1,
	}
	// ApplySecurityDefaults() is intentionally NOT called: it would reset
	// NoNewPrivileges to true and hide the violation.

	err := cfg.Validate()
	if err == nil {
		t.Error("expected error for NoNewPrivileges=false in HIPAA mode, got nil")
		return
	}
	if !contains(err.Error(), "no_new_privileges must be true") {
		t.Errorf("expected error about no_new_privileges, got %q", err.Error())
	}
}
|
|
|
|
// TestHIPAAValidation_SeccompProfile verifies that Config.Validate rejects a
// HIPAA-mode config with an empty sandbox SeccompProfile, and that the error
// mentions "seccomp_profile must be non-empty".
func TestHIPAAValidation_SeccompProfile(t *testing.T) {
	cfg := &worker.Config{
		ComplianceMode: "hipaa",
		GPUVendor:      "none",
		Sandbox: worker.SandboxConfig{
			NetworkMode:     "none",
			SeccompProfile:  "", // empty on purpose: this is the violation under test
			NoNewPrivileges: true,
		},
		MaxWorkers: 1,
	}
	// NOTE(review): ApplySecurityDefaults() is not called here, presumably
	// because it would fill in a default seccomp profile — confirm.

	err := cfg.Validate()
	if err == nil {
		t.Error("expected error for empty SeccompProfile in HIPAA mode, got nil")
		return
	}
	if !contains(err.Error(), "seccomp_profile must be non-empty") {
		t.Errorf("expected error about seccomp_profile, got %q", err.Error())
	}
}
|
|
|
|
// TestHIPAAValidation_InlineCredentials verifies that, in HIPAA mode,
// credential fields given as literal values are rejected with a
// "must use ${ENV_VAR}" error, while a "${VAR}" reference is not rejected for
// that reason.
func TestHIPAAValidation_InlineCredentials(t *testing.T) {
	cases := []struct {
		name     string
		mutate   func(*worker.Config)
		wantFrag string // empty means no credential-format error is expected
	}{
		{
			name:     "inline redis_password rejected",
			mutate:   func(c *worker.Config) { c.RedisPassword = "plaintext_password" },
			wantFrag: "redis_password must use ${ENV_VAR}",
		},
		{
			name:   "env reference redis_password accepted",
			mutate: func(c *worker.Config) { c.RedisPassword = "${REDIS_PASSWORD}" },
		},
		{
			name:     "inline ssh_key rejected",
			mutate:   func(c *worker.Config) { c.SSHKey = "/path/to/key" },
			wantFrag: "ssh_key must use ${ENV_VAR}",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			cfg := &worker.Config{
				ComplianceMode: "hipaa",
				GPUVendor:      "none",
				Sandbox: worker.SandboxConfig{
					NetworkMode:     "none",
					SeccompProfile:  "default-hardened",
					NoNewPrivileges: true,
				},
				MaxWorkers: 1,
			}
			tc.mutate(cfg)
			cfg.Sandbox.ApplySecurityDefaults()

			err := cfg.Validate()

			if tc.wantFrag == "" {
				// Only a credential-format complaint fails this case. Validate may
				// still reject the config for unrelated reasons (e.g. an empty
				// environment variable behind the reference).
				if err != nil && contains(err.Error(), "must use ${ENV_VAR}") {
					t.Errorf("unexpected credential error: %v", err)
				}
				return
			}

			switch {
			case err == nil:
				t.Errorf("expected error containing %q, got nil", tc.wantFrag)
			case !contains(err.Error(), tc.wantFrag):
				t.Errorf("expected error containing %q, got %q", tc.wantFrag, err.Error())
			}
		})
	}
}
|
|
|
|
// TestPHIDenylist_Validation verifies that, in HIPAA mode, Config.Validate
// rejects any AllowedSecrets entry whose name matches the PHI denylist
// (reporting "PHI field"), even when mixed with ordinary secrets, and accepts
// a list of ordinary secret names.
func TestPHIDenylist_Validation(t *testing.T) {
	cases := []struct {
		name      string
		secrets   []string
		expectErr bool
		fragment  string
	}{
		{name: "normal secrets allowed", secrets: []string{"HF_TOKEN", "WANDB_API_KEY", "AWS_ACCESS_KEY"}},
		{name: "patient_id rejected", secrets: []string{"PATIENT_ID"}, expectErr: true, fragment: "PHI field"},
		{name: "ssn rejected", secrets: []string{"SSN"}, expectErr: true, fragment: "PHI field"},
		{name: "medical_record_number rejected", secrets: []string{"MEDICAL_RECORD_NUMBER"}, expectErr: true, fragment: "PHI field"},
		{name: "diagnosis_code rejected", secrets: []string{"DIAGNOSIS_CODE"}, expectErr: true, fragment: "PHI field"},
		{name: "mixed secrets with phi rejected", secrets: []string{"HF_TOKEN", "PATIENT_SSN", "API_KEY"}, expectErr: true, fragment: "PHI field"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			cfg := &worker.Config{
				ComplianceMode: "hipaa",
				GPUVendor:      "none",
				Sandbox: worker.SandboxConfig{
					NetworkMode:     "none",
					SeccompProfile:  "default-hardened",
					NoNewPrivileges: true,
					AllowedSecrets:  tc.secrets,
				},
				MaxWorkers: 1,
			}
			cfg.Sandbox.ApplySecurityDefaults()

			err := cfg.Validate()
			switch {
			case tc.expectErr && err == nil:
				t.Errorf("expected error containing %q, got nil", tc.fragment)
			case tc.expectErr && !contains(err.Error(), tc.fragment):
				// err is non-nil here: the previous case handled err == nil.
				t.Errorf("expected error containing %q, got %q", tc.fragment, err.Error())
			case !tc.expectErr && err != nil:
				t.Errorf("unexpected error: %v", err)
			}
		})
	}
}
|
|
|
|
// TestArtifactIngestionCaps populates a temporary run directory with
// fixed-size files and checks that worker.ScanArtifacts enforces the
// MaxArtifactFiles and MaxArtifactTotalBytes caps from SandboxConfig.
func TestArtifactIngestionCaps(t *testing.T) {
	cases := []struct {
		name       string
		fileCap    int
		byteCap    int64
		numFiles   int
		bytesEach  int64
		expectErr  bool
		errMessage string
	}{
		{
			name:      "within file cap",
			fileCap:   10,
			byteCap:   10000,
			numFiles:  5,
			bytesEach: 100,
		},
		{
			name:       "exceeds file cap",
			fileCap:    3,
			byteCap:    10000,
			numFiles:   5,
			bytesEach:  100,
			expectErr:  true,
			errMessage: "file count cap exceeded",
		},
		{
			// 10 files x 100 bytes = 1000 bytes, above the 500-byte cap.
			name:       "exceeds size cap",
			fileCap:    100,
			byteCap:    500,
			numFiles:   10,
			bytesEach:  100,
			expectErr:  true,
			errMessage: "total size cap exceeded",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			runDir := t.TempDir()
			for idx := 0; idx < tc.numFiles; idx++ {
				t.Logf("Created: %s", createTestFile(t, runDir, idx, tc.bytesEach))
			}

			caps := &worker.SandboxConfig{
				MaxArtifactFiles:      tc.fileCap,
				MaxArtifactTotalBytes: tc.byteCap,
			}

			_, err := worker.ScanArtifacts(runDir, false, caps)
			switch {
			case tc.expectErr && err == nil:
				t.Errorf("expected error containing %q, got nil", tc.errMessage)
			case tc.expectErr && !contains(err.Error(), tc.errMessage):
				// err is non-nil here: the previous case handled err == nil.
				t.Errorf("expected error containing %q, got %q", tc.errMessage, err.Error())
			case !tc.expectErr && err != nil:
				t.Errorf("unexpected error: %v", err)
			}
		})
	}
}
|
|
|
|
// TestConfigHash_Computation checks three properties of
// Config.ComputeResolvedConfigHash:
//   - it returns a non-empty hash,
//   - it returns the same hash on repeated calls for an unchanged config,
//   - it returns a different hash after a field (MaxWorkers) is changed.
func TestConfigHash_Computation(t *testing.T) {
	cfg := &worker.Config{
		Host:       "localhost",
		Port:       22,
		MaxWorkers: 4,
		GPUVendor:  "nvidia",
		Sandbox: worker.SandboxConfig{
			NetworkMode:    "none",
			SeccompProfile: "default-hardened",
		},
		ComplianceMode:       "hipaa",
		ProvenanceBestEffort: false,
	}

	baseline, err := cfg.ComputeResolvedConfigHash()
	if err != nil {
		t.Fatalf("ComputeResolvedConfigHash failed: %v", err)
	}
	if baseline == "" {
		t.Error("expected non-empty hash")
	}

	// Determinism: hashing the unmodified config again must match.
	repeat, err := cfg.ComputeResolvedConfigHash()
	if err != nil {
		t.Fatalf("ComputeResolvedConfigHash failed: %v", err)
	}
	if repeat != baseline {
		t.Error("same config should produce same hash")
	}

	// Sensitivity: changing MaxWorkers must change the hash.
	cfg.MaxWorkers = 8
	changed, err := cfg.ComputeResolvedConfigHash()
	if err != nil {
		t.Fatalf("ComputeResolvedConfigHash failed: %v", err)
	}
	if changed == baseline {
		t.Error("different config should produce different hash")
	}
}
|
|
|
|
// Helper functions
|
|
// contains reports whether substr occurs anywhere in s. An empty substr is
// contained in every string. It is a thin alias for strings.Contains, kept so
// existing call sites in this package keep working.
func contains(s, substr string) bool {
	return strings.Contains(s, substr)
}
|
|
|
|
// containsInternal reports whether substr occurs anywhere in s (an empty
// substr always matches). It previously hand-rolled a substring scan; it now
// delegates to strings.Contains, which has identical semantics. The name is
// kept in case other files in this package still call it.
func containsInternal(s, substr string) bool {
	return strings.Contains(s, substr)
}
|
|
|
|
// createTestFile writes a zero-filled file named "file_<index>.txt" containing
// exactly size bytes into dir with permissions 0640, and returns its path.
// A write failure aborts the calling test via t.Fatalf.
func createTestFile(t *testing.T, dir string, index int, size int64) string {
	t.Helper()

	path := filepath.Join(dir, fmt.Sprintf("file_%d.txt", index))
	payload := make([]byte, size)

	err := os.WriteFile(path, payload, 0640)
	if err != nil {
		t.Fatalf("failed to create test file: %v", err)
	}
	return path
}
|