fetch_ml/internal/fileutil/secure.go
Jeremie Fraeys 4cdb68907e
refactor(utilities): update supporting modules for scheduler integration
Update utility modules:
- File utilities with secure file operations
- Environment pool with resource tracking
- Error types with scheduler error categories
- Logging with audit context support
- Network/SSH with connection pooling
- Privacy/PII handling with tenant boundaries
- Resource manager with scheduler allocation
- Security monitor with audit integration
- Tracking plugins (MLflow, TensorBoard) with auth
- Crypto signing with tenant keys
- Database init with multi-user support
2026-02-26 12:07:15 -05:00

371 lines
12 KiB
Go

// Package fileutil provides secure file operation utilities to prevent path traversal attacks.
package fileutil
import (
"crypto/rand"
"encoding/base64"
"fmt"
"os"
"path/filepath"
"strings"
"time"
)
// SecurePathValidator provides path traversal protection with symlink resolution.
// Paths it validates must resolve to BasePath itself or to a location beneath it.
type SecurePathValidator struct {
	// BasePath is the directory validated paths must stay within. It may be
	// relative (interpreted against the working directory) and need not exist.
	BasePath string
}

// NewSecurePathValidator creates a new path validator for a base directory.
func NewSecurePathValidator(basePath string) *SecurePathValidator {
	return &SecurePathValidator{BasePath: basePath}
}

// ValidatePath ensures inputPath resolves to a location within the base directory.
//
// Relative inputs are interpreted relative to the symlink-resolved base path;
// absolute inputs are used as given. Symlinks are resolved along the longest
// existing prefix of the path, and any not-yet-existing trailing components
// are appended unchanged. Paths for files or directories that will be
// created later are therefore checked against where they would actually land.
//
// It returns the canonical absolute path. It returns an error in these cases:
//   - the base path is unset;
//   - the path escapes the base directory, lexically or through a symlink;
//   - the path contains a symlink that cannot be resolved, such as a dangling
//     link, whose target could lie outside the base.
func (v *SecurePathValidator) ValidatePath(inputPath string) (string, error) {
	if v.BasePath == "" {
		return "", fmt.Errorf("base path not set")
	}
	// Clean the path to remove . and .. elements.
	cleaned := filepath.Clean(inputPath)

	baseAbs, err := filepath.Abs(v.BasePath)
	if err != nil {
		return "", fmt.Errorf("failed to get absolute base path: %w", err)
	}
	// Resolve symlinks in the base path so comparisons are accurate (critical
	// for macOS /tmp -> /private/tmp). If the base does not exist yet, its
	// existing ancestors are still resolved. Fall back to the plain absolute
	// form if even that fails.
	baseResolved, err := resolveDeepestExisting(baseAbs)
	if err != nil {
		baseResolved = baseAbs
	}

	var absPath string
	if filepath.IsAbs(cleaned) {
		// Resolve absolute inputs up front so they compare consistently with
		// the resolved base.
		absPath, err = resolveDeepestExisting(cleaned)
		if err != nil {
			return "", fmt.Errorf("failed to resolve path %s: %w", inputPath, err)
		}
	} else {
		// Join relative inputs with the RESOLVED base path.
		absPath = filepath.Join(baseResolved, cleaned)
	}

	// FIRST: lexical boundary check. This catches ".." traversal even when
	// nothing on the path exists yet.
	if !pathWithinBase(baseResolved, absPath) {
		return "", fmt.Errorf("path escapes base directory: %s (base is %s)", inputPath, baseResolved)
	}

	// Resolve symlinks along the whole path; this is critical for security.
	// Walking up to the deepest existing ancestor also covers a symlinked
	// directory followed by several not-yet-created components. Checking only
	// the immediate parent would miss that case.
	resolved, err := resolveDeepestExisting(absPath)
	if err != nil {
		return "", fmt.Errorf("failed to resolve path %s: %w", inputPath, err)
	}
	resolvedAbs, err := filepath.Abs(resolved)
	if err != nil {
		return "", fmt.Errorf("failed to get absolute resolved path: %w", err)
	}

	// SECOND: verify the resolved path is still within base (symlink escape check).
	if !pathWithinBase(baseResolved, resolvedAbs) {
		return "", fmt.Errorf("path escapes base directory: %s (resolved to %s, base is %s)", inputPath, resolvedAbs, baseResolved)
	}
	return resolvedAbs, nil
}

// pathWithinBase reports whether target is base itself or lies beneath it.
// Both arguments are expected to be absolute, cleaned paths.
// filepath.Rel is used instead of a string-prefix test, so a filesystem-root
// base such as "/" is also handled.
func pathWithinBase(base, target string) bool {
	rel, err := filepath.Rel(base, target)
	if err != nil {
		return false
	}
	return rel != ".." && !strings.HasPrefix(rel, ".."+string(filepath.Separator))
}

// resolveDeepestExisting resolves symlinks in the longest existing prefix of p
// and re-appends the remaining, not-yet-existing components unchanged.
//
// It returns an error when a component exists as a symlink that cannot be
// resolved (dangling or looping). Creating a file through such a link could
// write outside the intended directory. If no prefix resolves at all, the
// cleaned input is returned as-is.
func resolveDeepestExisting(p string) (string, error) {
	cur := filepath.Clean(p)
	var missing []string // unresolved trailing components, deepest first
	for {
		resolved, err := filepath.EvalSymlinks(cur)
		if err == nil {
			for i := len(missing) - 1; i >= 0; i-- {
				resolved = filepath.Join(resolved, missing[i])
			}
			return resolved, nil
		}
		if info, lerr := os.Lstat(cur); lerr == nil && info.Mode()&os.ModeSymlink != 0 {
			return "", fmt.Errorf("unresolvable symlink %s: %w", cur, err)
		}
		parent := filepath.Dir(cur)
		if parent == cur {
			// Reached the filesystem root without resolving anything.
			return filepath.Clean(p), nil
		}
		missing = append(missing, filepath.Base(cur))
		cur = parent
	}
}
// SecureFileRead returns the full contents of the file at path. The path is
// lexically normalized with filepath.Clean first. Cleaning removes "." and
// "x/.." elements but does NOT confine the path to any directory. Callers
// handling untrusted input should validate the path before calling this.
func SecureFileRead(path string) ([]byte, error) {
	cleanPath := filepath.Clean(path)
	return os.ReadFile(cleanPath)
}
// SecureFileWrite writes data to the file at path via os.WriteFile. It
// creates the file with perm if missing and truncates it otherwise. The path
// is lexically normalized with filepath.Clean first. Cleaning does NOT confine
// the path to any directory, so validate untrusted input before calling this.
func SecureFileWrite(path string, data []byte, perm os.FileMode) error {
	target := filepath.Clean(path)
	return os.WriteFile(target, data, perm)
}
// SecureOpenFile opens path with os.OpenFile using the given flag and perm.
// The path is lexically normalized with filepath.Clean first. Cleaning does
// NOT confine the path to any directory, so validate untrusted input before
// calling this.
func SecureOpenFile(path string, flag int, perm os.FileMode) (*os.File, error) {
	normalized := filepath.Clean(path)
	return os.OpenFile(normalized, flag, perm)
}
// SecureReadDir lists the entries of dirPath after checking it with
// ValidatePath. The listing is read from the path ValidatePath returns, not
// from the raw input. A validation failure is wrapped with context; errors
// from reading the directory are returned unwrapped.
func (v *SecurePathValidator) SecureReadDir(dirPath string) ([]os.DirEntry, error) {
	safeDir, err := v.ValidatePath(dirPath)
	if err == nil {
		return os.ReadDir(safeDir)
	}
	return nil, fmt.Errorf("directory path validation failed: %w", err)
}
// SecureCreateTemp creates a new temporary file directly inside the
// validator's base directory. It returns the open file and its full path.
//
// The file name is pattern, then "_", then a random URL-safe base64 suffix;
// pattern defaults to "tmp" when empty. pattern must not contain a path
// separator, since that could place the file outside the base directory.
// The file is created with O_EXCL and mode 0600. A pre-existing file or
// symlink at that name is therefore never reused or followed, and only the
// owner can access the file. The caller is responsible for closing and
// removing the file.
func (v *SecurePathValidator) SecureCreateTemp(pattern string) (*os.File, string, error) {
	validatedPath, err := v.ValidatePath("")
	if err != nil {
		return nil, "", fmt.Errorf("base directory validation failed: %w", err)
	}
	// A separator in pattern (e.g. "../x") would escape the validated directory.
	if strings.ContainsRune(pattern, '/') || strings.ContainsRune(pattern, filepath.Separator) {
		return nil, "", fmt.Errorf("temp file pattern %q contains a path separator", pattern)
	}
	if pattern == "" {
		pattern = "tmp"
	}

	// Generate a secure random suffix.
	randomBytes := make([]byte, 16)
	if _, err := rand.Read(randomBytes); err != nil {
		return nil, "", fmt.Errorf("failed to generate random bytes: %w", err)
	}
	fileName := fmt.Sprintf("%s_%s", pattern, base64.URLEncoding.EncodeToString(randomBytes))
	fullPath := filepath.Join(validatedPath, fileName)

	// O_EXCL: never open an existing file/symlink; 0600: owner-only access.
	file, err := os.OpenFile(fullPath, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0600)
	if err != nil {
		return nil, "", fmt.Errorf("failed to create temp file: %w", err)
	}
	return file, fullPath, nil
}
// WriteFileSafe writes data to path, fsyncs it, and removes the file if any
// step fails.
//
// On success the data has been written and flushed to stable storage before
// the function returns. On a reported failure the possibly partial file is
// removed, so an error never leaves a half-written file behind.
//
// This is NOT an atomic replace. An existing file at path is truncated before
// writing, so its previous contents are lost even on failure. A process crash
// mid-write can still leave a partial file on disk. Use write-to-temp plus
// rename where atomic replacement is required.
//
// perm is applied only when the file is created; an existing file keeps its
// mode.
func WriteFileSafe(path string, data []byte, perm os.FileMode) error {
	// Use the same cleaned path for opening and for cleanup. Clean is lexical,
	// so "link/../x" and Clean("link/../x") can name different files when
	// "link" is a symlink.
	cleanPath := filepath.Clean(path)
	f, err := os.OpenFile(cleanPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, perm)
	if err != nil {
		return fmt.Errorf("open: %w", err)
	}
	// discard optionally closes f, then removes the file after a failure.
	// Cleanup is best-effort; the original error is what callers need.
	discard := func(closeFirst bool) {
		if closeFirst {
			_ = f.Close()
		}
		_ = os.Remove(cleanPath)
	}

	if _, err := f.Write(data); err != nil {
		discard(true)
		return fmt.Errorf("write: %w", err)
	}
	// CRITICAL: fsync ensures data is flushed to disk before returning.
	if err := f.Sync(); err != nil {
		discard(true)
		return fmt.Errorf("fsync: %w", err)
	}
	// Close can fail on some filesystems (NFS, network-backed volumes).
	if err := f.Close(); err != nil {
		discard(false)
		return fmt.Errorf("close: %w", err)
	}
	return nil
}
// OpenFileNoFollow behaves like os.OpenFile but adds O_NOFOLLOW to flag.
// The open therefore fails when the final path component is a symlink;
// symlinks in earlier components are still followed. The path is lexically
// cleaned first.
//
// NOTE(review): on platforms where the platform-specific o_NOFOLLOW is 0
// (the original comment says Windows has no equivalent), this is a plain
// os.OpenFile. Confirm against the per-platform definitions.
func OpenFileNoFollow(path string, flag int, perm os.FileMode) (*os.File, error) {
	return os.OpenFile(filepath.Clean(path), flag|O_NOFOLLOW, perm)
}
// O_NOFOLLOW is the open flag that makes an open fail when the final path
// component is a symlink. Its value comes from o_NOFOLLOW, which is defined
// outside this file (presumably in per-OS, build-tagged files; confirm).
const (
	O_NOFOLLOW = o_NOFOLLOW
)
// ValidatePathStrict validates that the relative path filePath is safe to use
// beneath baseDir.
//
// It rejects an empty baseDir, absolute paths, and any path whose cleaned
// form starts with a ".." element. It then resolves symlinks and ensures the
// final location stays within baseDir. For comparison only, baseDir is made
// absolute and, when it exists, symlink-resolved. This correctly handles:
//   - a base reached through a symlink (e.g. /tmp -> /private/tmp on macOS);
//   - a base written non-canonically, such as "./data";
//   - a filesystem-root base.
//
// If the target exists, its symlink-resolved path is returned. If it does
// not exist yet, its parent directory must exist and resolve within baseDir,
// and filepath.Join(baseDir, filePath) is returned. A target that is itself
// a dangling symlink is rejected, because writing through it could create a
// file outside baseDir.
func ValidatePathStrict(filePath, baseDir string) (string, error) {
	if baseDir == "" {
		return "", fmt.Errorf("base directory not set")
	}
	cleanPath := filepath.Clean(filePath)
	if filepath.IsAbs(cleanPath) {
		return "", fmt.Errorf("absolute paths not allowed: %s", filePath)
	}
	// After Clean, ".." can only survive as leading elements, so checking the
	// first element is enough. Names merely starting with ".." (e.g. "..data")
	// are legitimate and are left to the containment checks below.
	if cleanPath == ".." || strings.HasPrefix(cleanPath, ".."+string(filepath.Separator)) {
		return "", fmt.Errorf("path traversal attempt detected: %s", filePath)
	}

	// Canonical base used only for containment comparisons.
	baseCanon, err := filepath.Abs(baseDir)
	if err != nil {
		return "", fmt.Errorf("failed to resolve base directory: %w", err)
	}
	if resolvedBase, err := filepath.EvalSymlinks(baseCanon); err == nil {
		baseCanon = resolvedBase
	}
	// within reports whether p (relative or absolute) lies in the base.
	// filepath.Rel avoids the string-prefix pitfalls of "/" or "./" bases.
	within := func(p string) bool {
		abs, err := filepath.Abs(p)
		if err != nil {
			return false
		}
		rel, err := filepath.Rel(baseCanon, abs)
		if err != nil {
			return false
		}
		return rel != ".." && !strings.HasPrefix(rel, ".."+string(filepath.Separator))
	}

	fullPath := filepath.Join(baseDir, cleanPath)
	resolvedPath, err := filepath.EvalSymlinks(fullPath)
	if err != nil {
		// A dangling symlink as the final component would let a later write
		// land outside baseDir, so refuse it outright.
		if info, lerr := os.Lstat(fullPath); lerr == nil && info.Mode()&os.ModeSymlink != 0 {
			return "", fmt.Errorf("path %q is an unresolvable symlink", filePath)
		}
		// The file doesn't exist yet, so check its parent directory instead.
		resolvedDir, err := filepath.EvalSymlinks(filepath.Dir(fullPath))
		if err != nil {
			return "", fmt.Errorf("failed to resolve path: %w", err)
		}
		if !within(resolvedDir) {
			return "", fmt.Errorf("path %q escapes base directory", filePath)
		}
		// New file path is valid (parent directory is within base).
		return fullPath, nil
	}
	// Check the resolved path is still within the base directory.
	if !within(resolvedPath) {
		return "", fmt.Errorf("resolved path %q escapes base directory", filePath)
	}
	return resolvedPath, nil
}
// SecureTempFile creates a new temporary file in dir with mode 0600 and
// returns it open for reading and writing.
//
// If dir is empty, os.TempDir() is used. A missing dir is created with mode
// 0700; an existing directory keeps its permissions. The first "*" in
// pattern is replaced by a unique string made from the current time in
// nanoseconds and 8 random bytes. If pattern has no "*", "." plus that string
// is appended. pattern must not contain a path separator, since that could
// place the file outside dir.
//
// The file is opened with O_EXCL, so an existing file or symlink at the
// chosen name is never reused or followed. The caller is responsible for
// closing and removing the file.
func SecureTempFile(dir, pattern string) (*os.File, error) {
	// Reject separators before touching the filesystem: "../x-*" would
	// otherwise create the file outside dir.
	if strings.ContainsRune(pattern, '/') || strings.ContainsRune(pattern, filepath.Separator) {
		return nil, fmt.Errorf("temp file pattern %q contains a path separator", pattern)
	}
	if dir == "" {
		dir = os.TempDir()
	}
	// Create the temp directory (0700) if it is missing.
	if err := os.MkdirAll(dir, 0700); err != nil {
		return nil, fmt.Errorf("failed to create temp directory: %w", err)
	}

	// Generate a unique filename component.
	timestamp := time.Now().UnixNano()
	randomBytes := make([]byte, 8)
	if _, err := rand.Read(randomBytes); err != nil {
		return nil, fmt.Errorf("failed to generate random filename: %w", err)
	}
	unique := fmt.Sprintf("%d-%x", timestamp, randomBytes)
	filename := strings.Replace(pattern, "*", unique, 1)
	if filename == pattern {
		// No wildcard in pattern, so append the unique suffix.
		filename = pattern + "." + unique
	}
	path := filepath.Join(dir, filename)

	// O_EXCL prevents races and symlink tricks; 0600 restricts access to the owner.
	f, err := os.OpenFile(path, os.O_CREATE|os.O_EXCL|os.O_RDWR, 0600)
	if err != nil {
		return nil, fmt.Errorf("failed to create secure temp file: %w", err)
	}
	return f, nil
}
// SecureFileRemove removes the file at path, first scrubbing the contents of
// regular files.
//
// Regular files up to 10 MiB are overwritten with zeros and synced before
// removal. Larger regular files are truncated to zero length instead.
// Directories, symlinks and other non-regular entries are removed directly,
// without being opened. In particular, a symlink is removed itself and its
// target is never touched. A missing path is not an error.
//
// Scrubbing is best-effort: write, sync and truncate failures do not prevent
// the removal. It also cannot guarantee erasure on copy-on-write, journaling
// or flash-backed storage, where old blocks may survive.
func SecureFileRemove(path string) error {
	// Lstat, not Stat: following a symlink here would make the overwrite
	// below clobber the link's target, which may be an unrelated file.
	info, err := os.Lstat(path)
	if err != nil {
		if os.IsNotExist(err) {
			return nil // File doesn't exist, nothing to do
		}
		return fmt.Errorf("failed to stat file: %w", err)
	}
	// Skip scrubbing for directories, symlinks and special files.
	if !info.Mode().IsRegular() {
		return os.Remove(path)
	}

	f, err := os.OpenFile(path, os.O_WRONLY, 0)
	if err != nil {
		return fmt.Errorf("failed to open file for secure removal: %w", err)
	}
	// The entry could have been swapped (e.g. for a symlink) between Lstat and
	// OpenFile. Only scrub if we opened the very file we inspected.
	opened, err := f.Stat()
	if err != nil {
		_ = f.Close()
		return fmt.Errorf("failed to stat opened file: %w", err)
	}
	if !os.SameFile(info, opened) {
		_ = f.Close()
		return fmt.Errorf("file %s changed during secure removal", path)
	}

	const maxOverwriteSize = 10 * 1024 * 1024
	if info.Size() <= maxOverwriteSize {
		zeros := make([]byte, 4096)
		for remaining := info.Size(); remaining > 0; {
			chunk := int64(len(zeros))
			if remaining < chunk {
				chunk = remaining
			}
			if _, err := f.Write(zeros[:chunk]); err != nil {
				break // best-effort: still remove the file below
			}
			remaining -= chunk
		}
		// Best-effort flush of the zeros to disk before unlinking.
		_ = f.Sync()
	} else {
		// Too large to overwrite cheaply: drop the contents instead.
		_ = f.Truncate(0)
	}

	// Close exactly once. A close error is ignored because the file is removed next.
	_ = f.Close()
	return os.Remove(path)
}
// AtomicSymlink creates or replaces newPath as a symlink pointing to oldPath.
//
// Both paths are cleaned first. The link target must exist. A relative
// oldPath is checked relative to newPath's directory, because that is where
// it is resolved when the link is followed. If newPath already exists it must
// be a symlink; otherwise an error is returned.
//
// The link is first created under a random temporary name in newPath's
// directory and then renamed over newPath. rename is atomic on POSIX
// filesystems, so readers see either the old link or the new one, never a
// missing newPath. The check that an existing newPath is a symlink is a
// separate step and is not atomic with the rename.
func AtomicSymlink(oldPath, newPath string) error {
	oldPath = filepath.Clean(oldPath)
	newPath = filepath.Clean(newPath)

	// Ensure the target exists (without following it). Relative targets are
	// checked where the link will actually resolve them.
	target := oldPath
	if !filepath.IsAbs(target) {
		target = filepath.Join(filepath.Dir(newPath), target)
	}
	if _, err := os.Lstat(target); err != nil {
		return fmt.Errorf("symlink target %q does not exist: %w", oldPath, err)
	}

	// Only an existing symlink may be replaced.
	if info, err := os.Lstat(newPath); err == nil && info.Mode()&os.ModeSymlink == 0 {
		return fmt.Errorf("newPath %q exists and is not a symlink", newPath)
	}

	suffix := make([]byte, 8)
	if _, err := rand.Read(suffix); err != nil {
		return fmt.Errorf("failed to generate temporary symlink name: %w", err)
	}
	tmpPath := fmt.Sprintf("%s.tmp-%x", newPath, suffix)
	if err := os.Symlink(oldPath, tmpPath); err != nil {
		return fmt.Errorf("failed to create symlink: %w", err)
	}
	// rename replaces newPath in a single step (it does not follow a final symlink).
	if err := os.Rename(tmpPath, newPath); err != nil {
		_ = os.Remove(tmpPath)
		return fmt.Errorf("failed to replace symlink: %w", err)
	}
	return nil
}