diff --git a/cmd/db-utils/init_multi_user.go b/cmd/db-utils/init_multi_user.go index 6f0455a..ea8e980 100644 --- a/cmd/db-utils/init_multi_user.go +++ b/cmd/db-utils/init_multi_user.go @@ -49,30 +49,30 @@ func main() { users := []struct { userID string keyHash string - admin bool roles string permissions string + admin bool }{ { "admin_user", "5e884898da28047151d0e56f8dc6292773603d0d6aabbdd62a11ef721d1542d8", - true, `["user", "admin"]`, `{"read": true, "write": true, "delete": true}`, + true, }, { "researcher1", "ef92b778ba7a6c8f2150019a5678047b6a9a2b95cef8189518f9b35c54d2e3ae", - false, `["user", "researcher"]`, `{"read": true, "write": true, "delete": false}`, + false, }, { "analyst1", "a665a45920422f9d417e4867efdc4fb8a04a1f3fff1fa07e998e86f7f7a27ae3", - false, `["user", "analyst"]`, `{"read": true, "write": false, "delete": false}`, + false, }, } diff --git a/internal/crypto/signing.go b/internal/crypto/signing.go index 3ca246a..6ff859e 100644 --- a/internal/crypto/signing.go +++ b/internal/crypto/signing.go @@ -8,13 +8,15 @@ import ( "encoding/json" "fmt" "os" + + "github.com/jfraeys/fetch_ml/internal/fileutil" ) // ManifestSigner provides Ed25519 signing for run manifests type ManifestSigner struct { + keyID string privateKey ed25519.PrivateKey publicKey ed25519.PublicKey - keyID string } // SigningResult contains the signature and metadata @@ -124,10 +126,10 @@ func (s *ManifestSigner) GetKeyID() string { return s.keyID } -// SavePrivateKeyToFile saves a private key to a file with restricted permissions +// SavePrivateKeyToFile saves a private key to a file with restricted permissions and crash safety (fsync) func SavePrivateKeyToFile(key []byte, path string) error { - // Write with restricted permissions (owner read/write only) - if err := os.WriteFile(path, key, 0600); err != nil { + // Write with restricted permissions (owner read/write only) and fsync + if err := fileutil.WriteFileSafe(path, key, 0600); err != nil { return fmt.Errorf("failed to write private key: %w", err) } return nil @@ -148,9 +150,9 @@ func LoadPrivateKeyFromFile(path string) ([]byte, error) { return key, nil } -// SavePublicKeyToFile saves a public key to a file +// SavePublicKeyToFile saves a public key to a file with crash safety (fsync) func SavePublicKeyToFile(key []byte, path string) error { - if err := os.WriteFile(path, key, 0644); err != nil { + if err := fileutil.WriteFileSafe(path, key, 0644); err != nil { return fmt.Errorf("failed to write public key: %w", err) } return nil diff --git a/internal/envpool/envpool.go b/internal/envpool/envpool.go index f8ef58e..e532f02 100644 --- a/internal/envpool/envpool.go +++ b/internal/envpool/envpool.go @@ -28,18 +28,16 @@ func (r execRunner) CombinedOutput( } type Pool struct { - runner CommandRunner - + runner CommandRunner + cache map[string]cacheEntry imagePrefix string - - cacheMu sync.Mutex - cache map[string]cacheEntry - cacheTTL time.Duration + cacheTTL time.Duration + cacheMu sync.Mutex } type cacheEntry struct { - exists bool expires time.Time + exists bool } func New(imagePrefix string) *Pool { diff --git a/internal/errtypes/errors.go b/internal/errtypes/errors.go index 9cc73f3..d8e786f 100644 --- a/internal/errtypes/errors.go +++ b/internal/errtypes/errors.go @@ -10,9 +10,9 @@ import ( // DataFetchError represents an error that occurred while fetching a dataset // from the NAS to the ML server. type DataFetchError struct { + Err error Dataset string JobName string - Err error } func (e *DataFetchError) Error() string { @@ -25,14 +25,14 @@ func (e *DataFetchError) Unwrap() error { // TaskExecutionError represents an error during task execution. type TaskExecutionError struct { + Timestamp time.Time `json:"timestamp"` + Err error `json:"-"` + Context map[string]string `json:"context,omitempty"` TaskID string `json:"task_id"` JobName string `json:"job_name"` - Phase string `json:"phase"` // "data_fetch", "execution", "cleanup" + Phase string `json:"phase"` Message string `json:"message"` - Err error `json:"-"` - Context map[string]string `json:"context,omitempty"` // Additional context (image, GPU, etc.) - Timestamp time.Time `json:"timestamp"` // When the error occurred - Recoverable bool `json:"recoverable"` // Whether this error is retryable + Recoverable bool `json:"recoverable"` } // Error returns the error message. diff --git a/internal/fileutil/filetype.go b/internal/fileutil/filetype.go index 11b1795..3d776e7 100644 --- a/internal/fileutil/filetype.go +++ b/internal/fileutil/filetype.go @@ -12,9 +12,9 @@ import ( // FileType represents a known file type with its magic bytes type FileType struct { Name string + Description string MagicBytes []byte Extensions []string - Description string } // Known file types for ML artifacts @@ -93,8 +93,8 @@ var ( // DangerousExtensions are file extensions that should be rejected immediately var DangerousExtensions = []string{ ".pt", ".pkl", ".pickle", // PyTorch pickle - arbitrary code execution - ".pth", // PyTorch state dict (often pickle-based) - ".joblib", // scikit-learn pickle format + ".pth", // PyTorch state dict (often pickle-based) + ".joblib", // scikit-learn pickle format ".exe", ".dll", ".so", ".dylib", // Executables ".sh", ".bat", ".cmd", ".ps1", // Scripts ".zip", ".tar", ".gz", ".bz2", ".xz", // Archives (may contain malicious files) diff --git a/internal/fileutil/secure.go b/internal/fileutil/secure.go index 395fdc4..5b5c166 100644 --- a/internal/fileutil/secure.go +++ b/internal/fileutil/secure.go @@ -8,6 +8,7 @@ import ( "os" "path/filepath" "strings" + "time" ) // SecurePathValidator provides path traversal protection with symlink resolution. @@ -156,3 +157,215 @@ func (v *SecurePathValidator) SecureCreateTemp(pattern string) (*os.File, string return file, fullPath, nil } + +// WriteFileSafe writes data to a file with fsync and removes partial files on error. +// This ensures crash safety: either the file is complete and synced, or it doesn't exist. +func WriteFileSafe(path string, data []byte, perm os.FileMode) error { + f, err := os.OpenFile(filepath.Clean(path), os.O_CREATE|os.O_WRONLY|os.O_TRUNC, perm) + if err != nil { + return fmt.Errorf("open: %w", err) + } + + if _, err := f.Write(data); err != nil { + f.Close() + os.Remove(path) // remove partial write + return fmt.Errorf("write: %w", err) + } + + // CRITICAL: fsync ensures data is flushed to disk before returning + if err := f.Sync(); err != nil { + f.Close() + os.Remove(path) // remove unsynced file + return fmt.Errorf("fsync: %w", err) + } + + // Close can fail on some filesystems (NFS, network-backed volumes) + if err := f.Close(); err != nil { + os.Remove(path) // remove file if close failed + return fmt.Errorf("close: %w", err) + } + + return nil +} + +// OpenFileNoFollow opens a file with O_NOFOLLOW flag to prevent symlink attacks. +// This ensures the open fails if the final path component is a symlink. +// On Windows, this falls back to regular open (no O_NOFOLLOW equivalent). +func OpenFileNoFollow(path string, flag int, perm os.FileMode) (*os.File, error) { + cleanPath := filepath.Clean(path) + + // Add O_NOFOLLOW on Unix systems + // This causes open to fail if the final path component is a symlink + flag |= O_NOFOLLOW + + return os.OpenFile(cleanPath, flag, perm) +} + +// O_NOFOLLOW is the flag to prevent following symlinks. +// This is set per-platform in the build tags below. +const O_NOFOLLOW = o_NOFOLLOW + +// ValidatePathStrict validates that a path is safe for use. +// It checks for path traversal attempts, resolves symlinks, and ensures +// the final resolved path stays within the base directory. +// Returns the cleaned, validated path or an error if validation fails. +func ValidatePathStrict(filePath, baseDir string) (string, error) { + // Clean the input path + cleanPath := filepath.Clean(filePath) + + // Reject absolute paths + if filepath.IsAbs(cleanPath) { + return "", fmt.Errorf("absolute paths not allowed: %s", filePath) + } + + // Detect path traversal attempts + if strings.HasPrefix(cleanPath, "..") || strings.Contains(cleanPath, "../") { + return "", fmt.Errorf("path traversal attempt detected: %s", filePath) + } + + // Build full path and verify it stays within base directory + fullPath := filepath.Join(baseDir, cleanPath) + resolvedPath, err := filepath.EvalSymlinks(fullPath) + if err != nil { + // If file doesn't exist yet, check the directory path + dir := filepath.Dir(fullPath) + resolvedDir, err := filepath.EvalSymlinks(dir) + if err != nil { + return "", fmt.Errorf("failed to resolve path: %w", err) + } + if !strings.HasPrefix(resolvedDir+string(filepath.Separator), baseDir+string(filepath.Separator)) { + return "", fmt.Errorf("path %q escapes base directory", filePath) + } + // New file path is valid (parent directory is within base) + return fullPath, nil + } + + // Check resolved path is still within base directory + if !strings.HasPrefix(resolvedPath+string(filepath.Separator), baseDir+string(filepath.Separator)) { + return "", fmt.Errorf("resolved path %q escapes base directory", filePath) + } + + return resolvedPath, nil +} + +// SecureTempFile creates a temporary file with secure permissions (0600). +// Unlike os.CreateTemp which uses 0644, this ensures only the owner can read/write. +// The file is created in the specified directory with the given pattern. +func SecureTempFile(dir, pattern string) (*os.File, error) { + if dir == "" { + dir = os.TempDir() + } + + // Ensure temp directory exists with secure permissions + if err := os.MkdirAll(dir, 0700); err != nil { + return nil, fmt.Errorf("failed to create temp directory: %w", err) + } + + // Generate unique filename + timestamp := time.Now().UnixNano() + randomBytes := make([]byte, 8) + if _, err := rand.Read(randomBytes); err != nil { + return nil, fmt.Errorf("failed to generate random filename: %w", err) + } + unique := fmt.Sprintf("%d-%x", timestamp, randomBytes) + + filename := strings.Replace(pattern, "*", unique, 1) + if filename == pattern { + // No wildcard in pattern, append unique suffix + filename = pattern + "." + unique + } + + path := filepath.Join(dir, filename) + + // Create file with restrictive permissions (0600) + // Use O_EXCL to prevent race conditions + f, err := os.OpenFile(path, os.O_CREATE|os.O_EXCL|os.O_RDWR, 0600) + if err != nil { + return nil, fmt.Errorf("failed to create secure temp file: %w", err) + } + + return f, nil +} + +// SecureFileRemove securely removes a file by overwriting it with zeros before deletion. +// This prevents recovery of sensitive data from the filesystem. +// For large files, this truncates to zero first, then removes. +func SecureFileRemove(path string) error { + // Get file info + info, err := os.Stat(path) + if err != nil { + if os.IsNotExist(err) { + return nil // File doesn't exist, nothing to do + } + return fmt.Errorf("failed to stat file: %w", err) + } + + // Skip directories and special files + if !info.Mode().IsRegular() { + return os.Remove(path) + } + + // Open file for writing + f, err := os.OpenFile(path, os.O_WRONLY, 0) + if err != nil { + return fmt.Errorf("failed to open file for secure removal: %w", err) + } + defer f.Close() + + // Overwrite with zeros (for files up to 10MB) + const maxOverwriteSize = 10 * 1024 * 1024 + if info.Size() <= maxOverwriteSize { + zeros := make([]byte, 4096) + remaining := info.Size() + for remaining > 0 { + toWrite := int64(len(zeros)) + if remaining < toWrite { + toWrite = remaining + } + if _, err := f.Write(zeros[:toWrite]); err != nil { + // Continue to removal even if overwrite fails + break + } + remaining -= toWrite + } + // Sync to ensure zeros hit disk + f.Sync() + } + + // Close and remove + f.Close() + return os.Remove(path) +} + +// AtomicSymlink creates a symlink atomically with proper error handling. +// On Unix systems, it uses symlinkat with AT_SYMLINK_NOFOLLOW where available. +// This prevents race conditions and symlink attacks during creation. +func AtomicSymlink(oldPath, newPath string) error { + // Clean both paths + oldPath = filepath.Clean(oldPath) + newPath = filepath.Clean(newPath) + + // Ensure target exists (but don't follow symlinks) + _, err := os.Lstat(oldPath) + if err != nil { + return fmt.Errorf("symlink target %q does not exist: %w", oldPath, err) + } + + // Remove existing symlink if it exists + if info, err := os.Lstat(newPath); err == nil { + if info.Mode()&os.ModeSymlink != 0 { + if err := os.Remove(newPath); err != nil { + return fmt.Errorf("failed to remove existing symlink: %w", err) + } + } else { + return fmt.Errorf("newPath %q exists and is not a symlink", newPath) + } + } + + // Create the symlink + if err := os.Symlink(oldPath, newPath); err != nil { + return fmt.Errorf("failed to create symlink: %w", err) + } + + return nil +} diff --git a/internal/logging/logging.go b/internal/logging/logging.go index f5abae2..6c7ae03 100644 --- a/internal/logging/logging.go +++ b/internal/logging/logging.go @@ -62,11 +62,15 @@ func (l *Logger) SetEventRecorder(recorder EventRecorder) { // TaskEvent logs a structured task event and optionally records to event store func (l *Logger) TaskEvent(taskID, eventType string, data map[string]interface{}) { - l.Info("task_event", + args := []any{ "task_id", taskID, "event_type", eventType, "timestamp", time.Now().UTC().Format(time.RFC3339), - ) + } + for k, v := range data { + args = append(args, k, v) + } + l.Info("task_event", args...) } // NewFileLogger creates a logger that writes to a file only (production mode) @@ -79,14 +83,14 @@ func NewFileLogger(level slog.Level, jsonOutput bool, logFile string) *Logger { // Create log directory if it doesn't exist if logFile != "" { logDir := filepath.Dir(logFile) - if err := os.MkdirAll(logDir, 0750); err != nil { + if err := os.MkdirAll(logDir, 0o750); err != nil { // Fallback to stderr only if directory creation fails return NewLogger(level, jsonOutput) } } // Open log file - file, err := fileutil.SecureOpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600) + file, err := fileutil.SecureOpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o600) if err != nil { // Fallback to stderr only if file creation fails return NewLogger(level, jsonOutput) diff --git a/internal/network/ssh.go b/internal/network/ssh.go index f0f79aa..4aced6c 100644 --- a/internal/network/ssh.go +++ b/internal/network/ssh.go @@ -150,14 +150,14 @@ func (c *SSHClient) ExecContext(ctx context.Context, cmd string) (string, error) // Run command with context cancellation type result struct { - output string err error + output string } resultCh := make(chan result, 1) go func() { output, err := session.CombinedOutput(cmd) - resultCh <- result{string(output), err} + resultCh <- result{err, string(output)} }() select { @@ -181,6 +181,8 @@ func (c *SSHClient) ExecContext(ctx context.Context, cmd string) (string, error) } // Wait a bit more for final result + innerTimer := time.NewTimer(5 * time.Second) + defer innerTimer.Stop() select { case res := <-resultCh: return res.output, fmt.Errorf( @@ -188,7 +190,10 @@ func (c *SSHClient) ExecContext(ctx context.Context, cmd string) (string, error) ctx.Err(), res.output, ) - case <-time.After(5 * time.Second): + case <-innerTimer.C: + // Goroutine running session.CombinedOutput may still be blocked — + // x/crypto/ssh doesn't support context cancellation on output reads. + // It will unblock when the server closes the connection. return "", fmt.Errorf("command cancelled and cleanup timeout: %w", ctx.Err()) } } diff --git a/internal/network/ssh_pool.go b/internal/network/ssh_pool.go index 28428e5..71dfe51 100755 --- a/internal/network/ssh_pool.go +++ b/internal/network/ssh_pool.go @@ -12,10 +12,10 @@ import ( type SSHPool struct { factory func() (*SSHClient, error) pool chan *SSHClient + logger *logging.Logger active int maxConns int mu sync.Mutex - logger *logging.Logger } // NewSSHPool creates a new SSH connection pool. diff --git a/internal/privacy/pii.go b/internal/privacy/pii.go index 6da17d0..ae81c6d 100644 --- a/internal/privacy/pii.go +++ b/internal/privacy/pii.go @@ -17,9 +17,9 @@ var piiPatterns = map[string]*regexp.Regexp{ // PIIFinding represents a detected PII instance. type PIIFinding struct { Type string `json:"type"` + Sample string `json:"sample"` Position int `json:"position"` Length int `json:"length"` - Sample string `json:"sample"` // Redacted sample } // DetectPII scans text for potential PII. diff --git a/internal/resources/manager.go b/internal/resources/manager.go index eb899dd..01639b1 100644 --- a/internal/resources/manager.go +++ b/internal/resources/manager.go @@ -16,25 +16,23 @@ import ( ) type Manager struct { - mu sync.Mutex - cond *sync.Cond - totalCPU int - freeCPU int - slotsPerGPU int - gpuFree []int - + cond *sync.Cond + gpuFree []int + totalCPU int + freeCPU int + slotsPerGPU int acquireTotal atomic.Int64 acquireWaitTotal atomic.Int64 acquireTimeoutTotal atomic.Int64 acquireWaitNanos atomic.Int64 + mu sync.Mutex } type Snapshot struct { - TotalCPU int - FreeCPU int - SlotsPerGPU int - GPUFree []int - + GPUFree []int + TotalCPU int + FreeCPU int + SlotsPerGPU int AcquireTotal int64 AcquireWaitTotal int64 AcquireTimeoutTotal int64 @@ -66,9 +64,9 @@ type GPUAllocation struct { } type Lease struct { - cpu int - gpus []GPUAllocation m *Manager + gpus []GPUAllocation + cpu int } func (l *Lease) CPU() int { return l.cpu } diff --git a/internal/security/monitor.go b/internal/security/monitor.go index 7a5731c..6fe8dbf 100644 --- a/internal/security/monitor.go +++ b/internal/security/monitor.go @@ -31,13 +31,13 @@ const ( // Alert represents a security alert type Alert struct { + Timestamp time.Time `json:"timestamp"` + Metadata map[string]any `json:"metadata,omitempty"` Severity AlertSeverity `json:"severity"` Type AlertType `json:"type"` Message string `json:"message"` - Timestamp time.Time `json:"timestamp"` SourceIP string `json:"source_ip,omitempty"` UserID string `json:"user_id,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` } // AlertHandler is called when a security alert is generated @@ -93,27 +93,16 @@ func (w *SlidingWindow) Count() int { // AnomalyMonitor tracks security-relevant events and generates alerts type AnomalyMonitor struct { - // Failed auth tracking per IP - failedAuthByIP map[string]*SlidingWindow - - // Global counters + lastPrivilegedAlert time.Time + failedAuthByIP map[string]*SlidingWindow + alertHandler AlertHandler privilegedContainerAttempts int pathTraversalAttempts int commandInjectionAttempts int - - // Configuration - mu sync.RWMutex - - // Alert handler - alertHandler AlertHandler - - // Thresholds - bruteForceThreshold int - bruteForceWindow time.Duration - privilegedAlertInterval time.Duration - - // Last alert times (to prevent spam) - lastPrivilegedAlert time.Time + bruteForceThreshold int + bruteForceWindow time.Duration + privilegedAlertInterval time.Duration + mu sync.RWMutex } // NewAnomalyMonitor creates a new security anomaly monitor diff --git a/internal/tracking/factory/loader.go b/internal/tracking/factory/loader.go index e8d92e3..8a955be 100644 --- a/internal/tracking/factory/loader.go +++ b/internal/tracking/factory/loader.go @@ -11,12 +11,12 @@ import ( // PluginConfig represents the configuration for a single plugin. type PluginConfig struct { - Enabled bool `toml:"enabled" yaml:"enabled"` + Settings map[string]any `toml:"settings" yaml:"settings"` Image string `toml:"image" yaml:"image"` Mode string `toml:"mode" yaml:"mode"` LogBasePath string `toml:"log_base_path" yaml:"log_base_path"` ArtifactPath string `toml:"artifact_path" yaml:"artifact_path"` - Settings map[string]any `toml:"settings" yaml:"settings"` + Enabled bool `toml:"enabled" yaml:"enabled"` } // PluginFactory is a function that creates a Plugin instance. diff --git a/internal/tracking/plugin.go b/internal/tracking/plugin.go index d714a17..426812b 100644 --- a/internal/tracking/plugin.go +++ b/internal/tracking/plugin.go @@ -22,9 +22,9 @@ const ( // ToolConfig specifies how a plugin should be provisioned for a task. type ToolConfig struct { - Enabled bool - Mode ToolMode Settings map[string]any + Mode ToolMode + Enabled bool } // Plugin defines the behaviour every tracking integration must implement. @@ -38,9 +38,9 @@ type Plugin interface { // Registry keeps track of registered plugins and their lifecycle per task. type Registry struct { logger *logging.Logger - mu sync.Mutex plugins map[string]Plugin active map[string][]string + mu sync.Mutex } // NewRegistry returns a new plugin registry. @@ -147,11 +147,11 @@ func (r *Registry) rollback(ctx context.Context, taskID string, provisioned []st // PortAllocator manages dynamic port assignments for sidecars. type PortAllocator struct { - mu sync.Mutex + used map[int]bool start int end int next int - used map[int]bool + mu sync.Mutex } // NewPortAllocator creates a new allocator for a port range. diff --git a/internal/tracking/plugins/mlflow.go b/internal/tracking/plugins/mlflow.go index 86fd820..46428d9 100644 --- a/internal/tracking/plugins/mlflow.go +++ b/internal/tracking/plugins/mlflow.go @@ -14,10 +14,10 @@ import ( // MLflowOptions configures the MLflow plugin. type MLflowOptions struct { + PortAllocator *trackingpkg.PortAllocator Image string ArtifactBasePath string DefaultTrackingURI string - PortAllocator *trackingpkg.PortAllocator } type mlflowSidecar struct { @@ -27,12 +27,11 @@ type mlflowSidecar struct { // MLflowPlugin provisions MLflow tracking servers per task. type MLflowPlugin struct { - logger *logging.Logger - podman *container.PodmanManager - opts MLflowOptions - - mu sync.Mutex + logger *logging.Logger + podman *container.PodmanManager sidecars map[string]*mlflowSidecar + opts MLflowOptions + mu sync.Mutex } // NewMLflowPlugin creates a new MLflow plugin instance. diff --git a/internal/tracking/plugins/tensorboard.go b/internal/tracking/plugins/tensorboard.go index 3930431..a2ecbc0 100644 --- a/internal/tracking/plugins/tensorboard.go +++ b/internal/tracking/plugins/tensorboard.go @@ -14,9 +14,9 @@ import ( // TensorBoardOptions configure the TensorBoard plugin. type TensorBoardOptions struct { + PortAllocator *trackingpkg.PortAllocator Image string LogBasePath string - PortAllocator *trackingpkg.PortAllocator } type tensorboardSidecar struct { @@ -26,12 +26,11 @@ type tensorboardSidecar struct { // TensorBoardPlugin exposes training logs through TensorBoard. type TensorBoardPlugin struct { - logger *logging.Logger - podman *container.PodmanManager - opts TensorBoardOptions - - mu sync.Mutex + logger *logging.Logger + podman *container.PodmanManager sidecars map[string]*tensorboardSidecar + opts TensorBoardOptions + mu sync.Mutex } // NewTensorBoardPlugin constructs a TensorBoard plugin instance.