diff --git a/internal/worker/config.go b/internal/worker/config.go index 3ca01f4..46ebb34 100644 --- a/internal/worker/config.go +++ b/internal/worker/config.go @@ -1,6 +1,9 @@ package worker import ( + "crypto/sha256" + "encoding/hex" + "encoding/json" "fmt" "math" "net/url" @@ -73,6 +76,10 @@ type Config struct { // Default: fail-closed (trustworthiness-by-default). Set true to opt into best-effort. ProvenanceBestEffort bool `yaml:"provenance_best_effort"` + // Compliance mode: "hipaa", "standard", or empty + // When "hipaa": enforces hard requirements at startup + ComplianceMode string `yaml:"compliance_mode"` + // Phase 1: opt-in prewarming of next task artifacts (snapshot/datasets/env). PrewarmEnabled bool `yaml:"prewarm_enabled"` @@ -150,35 +157,43 @@ type SandboxConfig struct { MaxUploadSizeBytes int64 `yaml:"max_upload_size_bytes"` // Default: 10GB MaxUploadRateBps int64 `yaml:"max_upload_rate_bps"` // Default: 100MB/s MaxUploadsPerMinute int `yaml:"max_uploads_per_minute"` // Default: 10 + + // Artifact ingestion caps (NEW) + MaxArtifactFiles int `yaml:"max_artifact_files"` // Default: 10000 + MaxArtifactTotalBytes int64 `yaml:"max_artifact_total_bytes"` // Default: 100GB } // SecurityDefaults holds default values for security configuration var SecurityDefaults = struct { - NetworkMode string - ReadOnlyRoot bool - AllowSecrets bool - SeccompProfile string - NoNewPrivileges bool - DropAllCaps bool - UserNS bool - RunAsUID int - RunAsGID int - MaxUploadSizeBytes int64 - MaxUploadRateBps int64 - MaxUploadsPerMinute int + NetworkMode string + ReadOnlyRoot bool + AllowSecrets bool + SeccompProfile string + NoNewPrivileges bool + DropAllCaps bool + UserNS bool + RunAsUID int + RunAsGID int + MaxUploadSizeBytes int64 + MaxUploadRateBps int64 + MaxUploadsPerMinute int + MaxArtifactFiles int + MaxArtifactTotalBytes int64 }{ - NetworkMode: "none", - ReadOnlyRoot: true, - AllowSecrets: false, - SeccompProfile: "default-hardened", - NoNewPrivileges: true, - DropAllCaps: true, - UserNS: true, - RunAsUID: 1000, - RunAsGID: 1000, - MaxUploadSizeBytes: 10 * 1024 * 1024 * 1024, // 10GB - MaxUploadRateBps: 100 * 1024 * 1024, // 100MB/s - MaxUploadsPerMinute: 10, + NetworkMode: "none", + ReadOnlyRoot: true, + AllowSecrets: false, + SeccompProfile: "default-hardened", + NoNewPrivileges: true, + DropAllCaps: true, + UserNS: true, + RunAsUID: 1000, + RunAsGID: 1000, + MaxUploadSizeBytes: 10 * 1024 * 1024 * 1024, // 10GB + MaxUploadRateBps: 100 * 1024 * 1024, // 100MB/s + MaxUploadsPerMinute: 10, + MaxArtifactFiles: 10000, + MaxArtifactTotalBytes: 100 * 1024 * 1024 * 1024, // 100GB } // Validate checks sandbox configuration @@ -258,6 +273,14 @@ func (s *SandboxConfig) ApplySecurityDefaults() { if s.MaxUploadsPerMinute == 0 { s.MaxUploadsPerMinute = SecurityDefaults.MaxUploadsPerMinute } + + // Artifact ingestion caps + if s.MaxArtifactFiles == 0 { + s.MaxArtifactFiles = SecurityDefaults.MaxArtifactFiles + } + if s.MaxArtifactTotalBytes == 0 { + s.MaxArtifactTotalBytes = SecurityDefaults.MaxArtifactTotalBytes + } } // Getter methods for SandboxConfig interface @@ -415,7 +438,7 @@ func LoadConfig(path string) (*Config, error) { cfg.Sandbox.ApplySecurityDefaults() // Expand secrets from environment variables - if err := cfg.expandSecrets(); err != nil { + if err := cfg.ExpandSecrets(); err != nil { return nil, fmt.Errorf("secrets expansion failed: %w", err) } @@ -567,11 +590,24 @@ func (c *Config) Validate() error { } } + // HIPAA mode validation - hard requirements + if strings.ToLower(c.ComplianceMode) == "hipaa" { + if err := c.validateHIPAARequirements(); err != nil { + return fmt.Errorf("HIPAA compliance validation failed: %w", err) + } + } + return nil } -// expandSecrets replaces secret placeholders with environment variables -func (c *Config) expandSecrets() error { +// ExpandSecrets replaces secret placeholders with environment variables +// Exported for testing purposes +func (c *Config) ExpandSecrets() error { + // First validate that secrets use env var syntax (not plaintext) + if err := c.ValidateNoPlaintextSecrets(); err != nil { + return err + } + // Expand Redis password from env if using ${...} syntax if strings.Contains(c.RedisPassword, "${") { c.RedisPassword = os.ExpandEnv(c.RedisPassword) @@ -588,17 +624,13 @@ func (c *Config) expandSecrets() error { c.SnapshotStore.SessionToken = os.ExpandEnv(c.SnapshotStore.SessionToken) } - // Validate no plaintext secrets remain in critical fields - if err := c.validateNoPlaintextSecrets(); err != nil { - return err - } - return nil } -// validateNoPlaintextSecrets checks that sensitive fields use env var references +// ValidateNoPlaintextSecrets checks that sensitive fields use env var references // rather than hardcoded plaintext values. This is a HIPAA compliance requirement. -func (c *Config) validateNoPlaintextSecrets() error { +// Exported for testing purposes +func (c *Config) ValidateNoPlaintextSecrets() error { // Fields that should use ${ENV_VAR} syntax instead of plaintext sensitiveFields := []struct { name string @@ -616,11 +648,11 @@ func (c *Config) validateNoPlaintextSecrets() error { } // Check if it looks like a plaintext secret (not env var reference) - if !strings.HasPrefix(field.value, "${") && looksLikeSecret(field.value) { + if !strings.HasPrefix(field.value, "${") && LooksLikeSecret(field.value) { return fmt.Errorf( "%s appears to contain a plaintext secret (length=%d, entropy=%.2f); "+ "use ${ENV_VAR} syntax to load from environment or secrets manager", - field.name, len(field.value), calculateEntropy(field.value), + field.name, len(field.value), CalculateEntropy(field.value), ) } } @@ -628,15 +660,94 @@ func (c *Config) validateNoPlaintextSecrets() error { return nil } -// looksLikeSecret heuristically detects if a string looks like a secret credential -func looksLikeSecret(s string) bool { +// validateHIPAARequirements enforces hard HIPAA compliance requirements at startup. +// These must fail loudly rather than silently fall back to insecure defaults. +func (c *Config) validateHIPAARequirements() error { + // 1. SnapshotStore must be secure + if c.SnapshotStore.Enabled && !c.SnapshotStore.Secure { + return fmt.Errorf("snapshot_store.secure must be true in HIPAA mode") + } + + // 2. NetworkMode must be "none" (no network access) + if c.Sandbox.NetworkMode != "none" { + return fmt.Errorf("sandbox.network_mode must be 'none' in HIPAA mode, got %q", c.Sandbox.NetworkMode) + } + + // 3. SeccompProfile must be non-empty + if c.Sandbox.SeccompProfile == "" { + return fmt.Errorf("sandbox.seccomp_profile must be non-empty in HIPAA mode") + } + + // 4. NoNewPrivileges must be true + if !c.Sandbox.NoNewPrivileges { + return fmt.Errorf("sandbox.no_new_privileges must be true in HIPAA mode") + } + + // 5. All credentials must be sourced from env vars, not inline YAML + if err := c.validateNoInlineCredentials(); err != nil { + return err + } + + // 6. AllowedSecrets must not contain PHI field names + if err := c.Sandbox.validatePHIDenylist(); err != nil { + return err + } + + return nil +} + +// validateNoInlineCredentials checks that no credentials are hardcoded in config +func (c *Config) validateNoInlineCredentials() error { + // Check Redis password - must be empty or use env var syntax + if c.RedisPassword != "" && !strings.HasPrefix(c.RedisPassword, "${") { + return fmt.Errorf("redis_password must use ${ENV_VAR} syntax in HIPAA mode, not inline value") + } + + // Check SSH key - must use env var syntax + if c.SSHKey != "" && !strings.HasPrefix(c.SSHKey, "${") { + return fmt.Errorf("ssh_key must use ${ENV_VAR} syntax in HIPAA mode, not inline value") + } + + // Check SnapshotStore credentials + if c.SnapshotStore.AccessKey != "" && !strings.HasPrefix(c.SnapshotStore.AccessKey, "${") { + return fmt.Errorf("snapshot_store.access_key must use ${ENV_VAR} syntax in HIPAA mode") + } + if c.SnapshotStore.SecretKey != "" && !strings.HasPrefix(c.SnapshotStore.SecretKey, "${") { + return fmt.Errorf("snapshot_store.secret_key must use ${ENV_VAR} syntax in HIPAA mode") + } + + return nil +} + +// PHI field patterns that should not appear in AllowedSecrets +var phiDenylistPatterns = []string{ + "patient", "phi", "ssn", "social_security", "mrn", "medical_record", + "dob", "birth_date", "diagnosis", "condition", "medication", "allergy", +} + +// validatePHIDenylist checks that AllowedSecrets doesn't contain PHI field names +func (s *SandboxConfig) validatePHIDenylist() error { + for _, secret := range s.AllowedSecrets { + secretLower := strings.ToLower(secret) + for _, pattern := range phiDenylistPatterns { + if strings.Contains(secretLower, pattern) { + return fmt.Errorf("allowed_secrets contains potential PHI field %q (matches pattern %q); this could allow PHI exfiltration", secret, pattern) + } + } + } + return nil +} + +// LooksLikeSecret heuristically detects if a string looks like a secret credential +// Exported for testing purposes +func LooksLikeSecret(s string) bool { // Minimum length for secrets if len(s) < 16 { return false } // Calculate entropy to detect high-entropy strings (likely secrets) - entropy := calculateEntropy(s) + entropy := CalculateEntropy(s) // High entropy (>4 bits per char) combined with reasonable length suggests a secret if entropy > 4.0 { @@ -664,8 +775,9 @@ func looksLikeSecret(s string) bool { return false } -// calculateEntropy calculates Shannon entropy of a string in bits per character -func calculateEntropy(s string) float64 { +// CalculateEntropy calculates Shannon entropy of a string in bits per character +// Exported for testing purposes +func CalculateEntropy(s string) float64 { if len(s) == 0 { return 0 } @@ -689,6 +801,54 @@ func calculateEntropy(s string) float64 { return entropy } +// ComputeResolvedConfigHash computes a SHA-256 hash of the resolved config. +// This must be called after os.ExpandEnv, after default application, and after Validate(). +// The hash captures the actual runtime configuration, not the raw YAML file. +// This is critical for reproducibility - two different raw files that resolve +// to the same config will produce the same hash. +func (c *Config) ComputeResolvedConfigHash() (string, error) { + // Marshal config to JSON for consistent serialization + // We use a simplified struct to avoid hashing volatile fields + hashable := struct { + Host string `json:"host"` + Port int `json:"port"` + BasePath string `json:"base_path"` + MaxWorkers int `json:"max_workers"` + Resources config.ResourceConfig `json:"resources"` + GPUVendor string `json:"gpu_vendor"` + GPUVisibleDevices []int `json:"gpu_visible_devices,omitempty"` + GPUVisibleDeviceIDs []string `json:"gpu_visible_device_ids,omitempty"` + Sandbox SandboxConfig `json:"sandbox"` + ComplianceMode string `json:"compliance_mode"` + ProvenanceBestEffort bool `json:"provenance_best_effort"` + SnapshotStoreSecure bool `json:"snapshot_store_secure,omitempty"` + QueueBackend string `json:"queue_backend"` + }{ + Host: c.Host, + Port: c.Port, + BasePath: c.BasePath, + MaxWorkers: c.MaxWorkers, + Resources: c.Resources, + GPUVendor: c.GPUVendor, + GPUVisibleDevices: c.GPUVisibleDevices, + GPUVisibleDeviceIDs: c.GPUVisibleDeviceIDs, + Sandbox: c.Sandbox, + ComplianceMode: c.ComplianceMode, + ProvenanceBestEffort: c.ProvenanceBestEffort, + SnapshotStoreSecure: c.SnapshotStore.Secure, + QueueBackend: c.Queue.Backend, + } + + data, err := json.Marshal(hashable) + if err != nil { + return "", fmt.Errorf("failed to marshal config for hashing: %w", err) + } + + // Compute SHA-256 hash + hash := sha256.Sum256(data) + return hex.EncodeToString(hash[:]), nil +} + // envInt reads an integer from environment variable func envInt(name string) (int, bool) { v := strings.TrimSpace(os.Getenv(name)) diff --git a/tests/unit/security/hipaa_validation_test.go b/tests/unit/security/hipaa_validation_test.go new file mode 100644 index 0000000..81212c7 --- /dev/null +++ b/tests/unit/security/hipaa_validation_test.go @@ -0,0 +1,372 @@ +package security + +import ( + "fmt" + "os" + "path/filepath" + "testing" + + "github.com/jfraeys/fetch_ml/internal/worker" +) + +func TestHIPAAValidation_NetworkMode(t *testing.T) { + tests := []struct { + name string + networkMode string + wantErr bool + errContains string + }{ + { + name: "HIPAA mode requires network_mode=none", + networkMode: "bridge", + wantErr: true, + errContains: "network_mode must be 'none'", + }, + { + name: "HIPAA mode accepts network_mode=none", + networkMode: "none", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &worker.Config{ + ComplianceMode: "hipaa", + GPUVendor: "none", + Sandbox: worker.SandboxConfig{ + NetworkMode: tt.networkMode, + SeccompProfile: "default-hardened", + NoNewPrivileges: true, + }, + MaxWorkers: 1, + } + cfg.Sandbox.ApplySecurityDefaults() + + err := cfg.Validate() + if tt.wantErr { + if err == nil { + t.Errorf("expected error containing %q, got nil", tt.errContains) + } else if !contains(err.Error(), tt.errContains) { + t.Errorf("expected error containing %q, got %q", tt.errContains, err.Error()) + } + } else { + if err != nil { + t.Errorf("unexpected error: %v", err) + } + } + }) + } +} + +func TestHIPAAValidation_NoNewPrivileges(t *testing.T) { + cfg := &worker.Config{ + ComplianceMode: "hipaa", + GPUVendor: "none", + Sandbox: worker.SandboxConfig{ + NetworkMode: "none", + SeccompProfile: "default-hardened", + NoNewPrivileges: false, // Violation - must stay false + }, + MaxWorkers: 1, + } + // Note: Don't call ApplySecurityDefaults() as it would reset NoNewPrivileges to true + + err := cfg.Validate() + if err == nil { + t.Error("expected error for NoNewPrivileges=false in HIPAA mode, got nil") + } else if !contains(err.Error(), "no_new_privileges must be true") { + t.Errorf("expected error about no_new_privileges, got %q", err.Error()) + } +} + +func TestHIPAAValidation_SeccompProfile(t *testing.T) { + cfg := &worker.Config{ + ComplianceMode: "hipaa", + GPUVendor: "none", + Sandbox: worker.SandboxConfig{ + NetworkMode: "none", + SeccompProfile: "", // Empty - violation + NoNewPrivileges: true, + }, + MaxWorkers: 1, + } + + err := cfg.Validate() + if err == nil { + t.Error("expected error for empty SeccompProfile in HIPAA mode, got nil") + } else if !contains(err.Error(), "seccomp_profile must be non-empty") { + t.Errorf("expected error about seccomp_profile, got %q", err.Error()) + } +} + +func TestHIPAAValidation_InlineCredentials(t *testing.T) { + tests := []struct { + name string + setupFunc func(*worker.Config) + errContains string + }{ + { + name: "inline redis_password rejected", + setupFunc: func(cfg *worker.Config) { + cfg.RedisPassword = "plaintext_password" + }, + errContains: "redis_password must use ${ENV_VAR}", + }, + { + name: "env reference redis_password accepted", + setupFunc: func(cfg *worker.Config) { + cfg.RedisPassword = "${REDIS_PASSWORD}" + }, + errContains: "", // No error expected + }, + { + name: "inline ssh_key rejected", + setupFunc: func(cfg *worker.Config) { + cfg.SSHKey = "/path/to/key" + }, + errContains: "ssh_key must use ${ENV_VAR}", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &worker.Config{ + ComplianceMode: "hipaa", + GPUVendor: "none", + Sandbox: worker.SandboxConfig{ + NetworkMode: "none", + SeccompProfile: "default-hardened", + NoNewPrivileges: true, + }, + MaxWorkers: 1, + } + tt.setupFunc(cfg) + cfg.Sandbox.ApplySecurityDefaults() + + err := cfg.Validate() + if tt.errContains != "" { + if err == nil { + t.Errorf("expected error containing %q, got nil", tt.errContains) + } else if !contains(err.Error(), tt.errContains) { + t.Errorf("expected error containing %q, got %q", tt.errContains, err.Error()) + } + } else { + // For env reference test, we expect validation to pass + // (but may fail on other checks like empty env var) + if err != nil && contains(err.Error(), "must use ${ENV_VAR}") { + t.Errorf("unexpected credential error: %v", err) + } + } + }) + } +} + +func TestPHIDenylist_Validation(t *testing.T) { + tests := []struct { + name string + allowedSecrets []string + wantErr bool + errContains string + }{ + { + name: "normal secrets allowed", + allowedSecrets: []string{"HF_TOKEN", "WANDB_API_KEY", "AWS_ACCESS_KEY"}, + wantErr: false, + }, + { + name: "patient_id rejected", + allowedSecrets: []string{"PATIENT_ID"}, + wantErr: true, + errContains: "PHI field", + }, + { + name: "ssn rejected", + allowedSecrets: []string{"SSN"}, + wantErr: true, + errContains: "PHI field", + }, + { + name: "medical_record_number rejected", + allowedSecrets: []string{"MEDICAL_RECORD_NUMBER"}, + wantErr: true, + errContains: "PHI field", + }, + { + name: "diagnosis_code rejected", + allowedSecrets: []string{"DIAGNOSIS_CODE"}, + wantErr: true, + errContains: "PHI field", + }, + { + name: "mixed secrets with phi rejected", + allowedSecrets: []string{"HF_TOKEN", "PATIENT_SSN", "API_KEY"}, + wantErr: true, + errContains: "PHI field", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &worker.Config{ + ComplianceMode: "hipaa", + GPUVendor: "none", + Sandbox: worker.SandboxConfig{ + NetworkMode: "none", + SeccompProfile: "default-hardened", + NoNewPrivileges: true, + AllowedSecrets: tt.allowedSecrets, + }, + MaxWorkers: 1, + } + cfg.Sandbox.ApplySecurityDefaults() + + err := cfg.Validate() + if tt.wantErr { + if err == nil { + t.Errorf("expected error containing %q, got nil", tt.errContains) + } else if !contains(err.Error(), tt.errContains) { + t.Errorf("expected error containing %q, got %q", tt.errContains, err.Error()) + } + } else { + if err != nil { + t.Errorf("unexpected error: %v", err) + } + } + }) + } +} + +func TestArtifactIngestionCaps(t *testing.T) { + tests := []struct { + name string + maxFiles int + maxBytes int64 + createFiles int + fileSize int64 + wantErr bool + errContains string + }{ + { + name: "within file cap", + maxFiles: 10, + maxBytes: 10000, + createFiles: 5, + fileSize: 100, + wantErr: false, + }, + { + name: "exceeds file cap", + maxFiles: 3, + maxBytes: 10000, + createFiles: 5, + fileSize: 100, + wantErr: true, + errContains: "file count cap exceeded", + }, + { + name: "exceeds size cap", + maxFiles: 100, + maxBytes: 500, + createFiles: 10, + fileSize: 100, + wantErr: true, + errContains: "total size cap exceeded", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create temp directory with test files + runDir := t.TempDir() + for i := 0; i < tt.createFiles; i++ { + f := createTestFile(t, runDir, i, tt.fileSize) + t.Logf("Created: %s", f) + } + + caps := &worker.SandboxConfig{ + MaxArtifactFiles: tt.maxFiles, + MaxArtifactTotalBytes: tt.maxBytes, + } + + _, err := worker.ScanArtifacts(runDir, false, caps) + if tt.wantErr { + if err == nil { + t.Errorf("expected error containing %q, got nil", tt.errContains) + } else if !contains(err.Error(), tt.errContains) { + t.Errorf("expected error containing %q, got %q", tt.errContains, err.Error()) + } + } else { + if err != nil { + t.Errorf("unexpected error: %v", err) + } + } + }) + } +} + +func TestConfigHash_Computation(t *testing.T) { + cfg := &worker.Config{ + Host: "localhost", + Port: 22, + MaxWorkers: 4, + GPUVendor: "nvidia", + Sandbox: worker.SandboxConfig{ + NetworkMode: "none", + SeccompProfile: "default-hardened", + }, + ComplianceMode: "hipaa", + ProvenanceBestEffort: false, + } + + hash1, err := cfg.ComputeResolvedConfigHash() + if err != nil { + t.Fatalf("ComputeResolvedConfigHash failed: %v", err) + } + if hash1 == "" { + t.Error("expected non-empty hash") + } + + // Same config should produce same hash + hash2, err := cfg.ComputeResolvedConfigHash() + if err != nil { + t.Fatalf("ComputeResolvedConfigHash failed: %v", err) + } + if hash1 != hash2 { + t.Error("same config should produce same hash") + } + + // Different config should produce different hash + cfg.MaxWorkers = 8 + hash3, err := cfg.ComputeResolvedConfigHash() + if err != nil { + t.Fatalf("ComputeResolvedConfigHash failed: %v", err) + } + if hash1 == hash3 { + t.Error("different config should produce different hash") + } +} + +// Helper functions +func contains(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || len(substr) == 0 || containsInternal(s, substr)) +} + +func containsInternal(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} + +func createTestFile(t *testing.T, dir string, index int, size int64) string { + t.Helper() + fname := filepath.Join(dir, fmt.Sprintf("file_%d.txt", index)) + data := make([]byte, size) + if err := os.WriteFile(fname, data, 0640); err != nil { + t.Fatalf("failed to create test file: %v", err) + } + return fname +}