refactor(utilities): update supporting modules for scheduler integration

Update utility modules:
- File utilities with secure file operations
- Environment pool with resource tracking
- Error types with scheduler error categories
- Logging with audit context support
- Network/SSH with connection pooling
- Privacy/PII handling with tenant boundaries
- Resource manager with scheduler allocation
- Security monitor with audit integration
- Tracking plugins (MLflow, TensorBoard) with auth
- Crypto signing with tenant keys
- Database init with multi-user support
This commit is contained in:
Jeremie Fraeys 2026-02-26 12:07:15 -05:00
parent 6866ba9366
commit 4cdb68907e
No known key found for this signature in database
16 changed files with 295 additions and 88 deletions

View file

@ -49,30 +49,30 @@ func main() {
users := []struct {
userID string
keyHash string
admin bool
roles string
permissions string
admin bool
}{
{
"admin_user",
"5e884898da28047151d0e56f8dc6292773603d0d6aabbdd62a11ef721d1542d8",
true,
`["user", "admin"]`,
`{"read": true, "write": true, "delete": true}`,
true,
},
{
"researcher1",
"ef92b778ba7a6c8f2150019a5678047b6a9a2b95cef8189518f9b35c54d2e3ae",
false,
`["user", "researcher"]`,
`{"read": true, "write": true, "delete": false}`,
false,
},
{
"analyst1",
"a665a45920422f9d417e4867efdc4fb8a04a1f3fff1fa07e998e86f7f7a27ae3",
false,
`["user", "analyst"]`,
`{"read": true, "write": false, "delete": false}`,
false,
},
}

View file

@ -8,13 +8,15 @@ import (
"encoding/json"
"fmt"
"os"
"github.com/jfraeys/fetch_ml/internal/fileutil"
)
// ManifestSigner provides Ed25519 signing for run manifests
type ManifestSigner struct {
keyID string
privateKey ed25519.PrivateKey
publicKey ed25519.PublicKey
keyID string
}
// SigningResult contains the signature and metadata
@ -124,10 +126,10 @@ func (s *ManifestSigner) GetKeyID() string {
return s.keyID
}
// SavePrivateKeyToFile saves a private key to a file with restricted permissions
// SavePrivateKeyToFile saves a private key to a file with restricted permissions and crash safety (fsync)
func SavePrivateKeyToFile(key []byte, path string) error {
// Write with restricted permissions (owner read/write only)
if err := os.WriteFile(path, key, 0600); err != nil {
// Write with restricted permissions (owner read/write only) and fsync
if err := fileutil.WriteFileSafe(path, key, 0600); err != nil {
return fmt.Errorf("failed to write private key: %w", err)
}
return nil
@ -148,9 +150,9 @@ func LoadPrivateKeyFromFile(path string) ([]byte, error) {
return key, nil
}
// SavePublicKeyToFile saves a public key to a file
// SavePublicKeyToFile saves a public key to a file with crash safety (fsync)
func SavePublicKeyToFile(key []byte, path string) error {
if err := os.WriteFile(path, key, 0644); err != nil {
if err := fileutil.WriteFileSafe(path, key, 0644); err != nil {
return fmt.Errorf("failed to write public key: %w", err)
}
return nil

View file

@ -28,18 +28,16 @@ func (r execRunner) CombinedOutput(
}
type Pool struct {
runner CommandRunner
runner CommandRunner
cache map[string]cacheEntry
imagePrefix string
cacheMu sync.Mutex
cache map[string]cacheEntry
cacheTTL time.Duration
cacheTTL time.Duration
cacheMu sync.Mutex
}
type cacheEntry struct {
exists bool
expires time.Time
exists bool
}
func New(imagePrefix string) *Pool {

View file

@ -10,9 +10,9 @@ import (
// DataFetchError represents an error that occurred while fetching a dataset
// from the NAS to the ML server.
type DataFetchError struct {
Err error
Dataset string
JobName string
Err error
}
func (e *DataFetchError) Error() string {
@ -25,14 +25,14 @@ func (e *DataFetchError) Unwrap() error {
// TaskExecutionError represents an error during task execution.
type TaskExecutionError struct {
Timestamp time.Time `json:"timestamp"`
Err error `json:"-"`
Context map[string]string `json:"context,omitempty"`
TaskID string `json:"task_id"`
JobName string `json:"job_name"`
Phase string `json:"phase"` // "data_fetch", "execution", "cleanup"
Phase string `json:"phase"`
Message string `json:"message"`
Err error `json:"-"`
Context map[string]string `json:"context,omitempty"` // Additional context (image, GPU, etc.)
Timestamp time.Time `json:"timestamp"` // When the error occurred
Recoverable bool `json:"recoverable"` // Whether this error is retryable
Recoverable bool `json:"recoverable"`
}
// Error returns the error message.

View file

@ -12,9 +12,9 @@ import (
// FileType represents a known file type with its magic bytes
type FileType struct {
Name string
Description string
MagicBytes []byte
Extensions []string
Description string
}
// Known file types for ML artifacts
@ -93,8 +93,8 @@ var (
// DangerousExtensions are file extensions that should be rejected immediately
var DangerousExtensions = []string{
".pt", ".pkl", ".pickle", // PyTorch pickle - arbitrary code execution
".pth", // PyTorch state dict (often pickle-based)
".joblib", // scikit-learn pickle format
".pth", // PyTorch state dict (often pickle-based)
".joblib", // scikit-learn pickle format
".exe", ".dll", ".so", ".dylib", // Executables
".sh", ".bat", ".cmd", ".ps1", // Scripts
".zip", ".tar", ".gz", ".bz2", ".xz", // Archives (may contain malicious files)

View file

@ -8,6 +8,7 @@ import (
"os"
"path/filepath"
"strings"
"time"
)
// SecurePathValidator provides path traversal protection with symlink resolution.
@ -156,3 +157,215 @@ func (v *SecurePathValidator) SecureCreateTemp(pattern string) (*os.File, string
return file, fullPath, nil
}
// WriteFileSafe writes data to a file with fsync and removes partial files on error.
// This ensures crash safety: either the file is complete and synced, or it doesn't exist.
func WriteFileSafe(path string, data []byte, perm os.FileMode) error {
f, err := os.OpenFile(filepath.Clean(path), os.O_CREATE|os.O_WRONLY|os.O_TRUNC, perm)
if err != nil {
return fmt.Errorf("open: %w", err)
}
if _, err := f.Write(data); err != nil {
f.Close()
os.Remove(path) // remove partial write
return fmt.Errorf("write: %w", err)
}
// CRITICAL: fsync ensures data is flushed to disk before returning
if err := f.Sync(); err != nil {
f.Close()
os.Remove(path) // remove unsynced file
return fmt.Errorf("fsync: %w", err)
}
// Close can fail on some filesystems (NFS, network-backed volumes)
if err := f.Close(); err != nil {
os.Remove(path) // remove file if close failed
return fmt.Errorf("close: %w", err)
}
return nil
}
// O_NOFOLLOW is the open flag that prevents following symlinks in the final
// path component. Its value comes from o_NOFOLLOW, which is set per-platform
// via build tags elsewhere in this package.
const O_NOFOLLOW = o_NOFOLLOW

// OpenFileNoFollow opens path so that the open fails if the final path
// component is a symlink, defending against symlink-swap attacks.
// On Windows this falls back to a regular open (no O_NOFOLLOW equivalent,
// per the platform definition of o_NOFOLLOW).
func OpenFileNoFollow(path string, flag int, perm os.FileMode) (*os.File, error) {
	return os.OpenFile(filepath.Clean(path), flag|O_NOFOLLOW, perm)
}
// ValidatePathStrict validates that a path is safe for use.
// It checks for path traversal attempts, resolves symlinks, and ensures
// the final resolved path stays within the base directory.
// Returns the cleaned, validated path or an error if validation fails.
func ValidatePathStrict(filePath, baseDir string) (string, error) {
// Clean the input path
cleanPath := filepath.Clean(filePath)
// Reject absolute paths
if filepath.IsAbs(cleanPath) {
return "", fmt.Errorf("absolute paths not allowed: %s", filePath)
}
// Detect path traversal attempts
if strings.HasPrefix(cleanPath, "..") || strings.Contains(cleanPath, "../") {
return "", fmt.Errorf("path traversal attempt detected: %s", filePath)
}
// Build full path and verify it stays within base directory
fullPath := filepath.Join(baseDir, cleanPath)
resolvedPath, err := filepath.EvalSymlinks(fullPath)
if err != nil {
// If file doesn't exist yet, check the directory path
dir := filepath.Dir(fullPath)
resolvedDir, err := filepath.EvalSymlinks(dir)
if err != nil {
return "", fmt.Errorf("failed to resolve path: %w", err)
}
if !strings.HasPrefix(resolvedDir+string(filepath.Separator), baseDir+string(filepath.Separator)) {
return "", fmt.Errorf("path %q escapes base directory", filePath)
}
// New file path is valid (parent directory is within base)
return fullPath, nil
}
// Check resolved path is still within base directory
if !strings.HasPrefix(resolvedPath+string(filepath.Separator), baseDir+string(filepath.Separator)) {
return "", fmt.Errorf("resolved path %q escapes base directory", filePath)
}
return resolvedPath, nil
}
// SecureTempFile creates a temporary file with secure permissions (0600).
// Unlike os.CreateTemp which uses 0644, this ensures only the owner can read/write.
// The file is created in the specified directory with the given pattern.
func SecureTempFile(dir, pattern string) (*os.File, error) {
if dir == "" {
dir = os.TempDir()
}
// Ensure temp directory exists with secure permissions
if err := os.MkdirAll(dir, 0700); err != nil {
return nil, fmt.Errorf("failed to create temp directory: %w", err)
}
// Generate unique filename
timestamp := time.Now().UnixNano()
randomBytes := make([]byte, 8)
if _, err := rand.Read(randomBytes); err != nil {
return nil, fmt.Errorf("failed to generate random filename: %w", err)
}
unique := fmt.Sprintf("%d-%x", timestamp, randomBytes)
filename := strings.Replace(pattern, "*", unique, 1)
if filename == pattern {
// No wildcard in pattern, append unique suffix
filename = pattern + "." + unique
}
path := filepath.Join(dir, filename)
// Create file with restrictive permissions (0600)
// Use O_EXCL to prevent race conditions
f, err := os.OpenFile(path, os.O_CREATE|os.O_EXCL|os.O_RDWR, 0600)
if err != nil {
return nil, fmt.Errorf("failed to create secure temp file: %w", err)
}
return f, nil
}
// SecureFileRemove securely removes a file by overwriting it with zeros before deletion.
// This prevents recovery of sensitive data from the filesystem.
// For large files, this truncates to zero first, then removes.
func SecureFileRemove(path string) error {
// Get file info
info, err := os.Stat(path)
if err != nil {
if os.IsNotExist(err) {
return nil // File doesn't exist, nothing to do
}
return fmt.Errorf("failed to stat file: %w", err)
}
// Skip directories and special files
if !info.Mode().IsRegular() {
return os.Remove(path)
}
// Open file for writing
f, err := os.OpenFile(path, os.O_WRONLY, 0)
if err != nil {
return fmt.Errorf("failed to open file for secure removal: %w", err)
}
defer f.Close()
// Overwrite with zeros (for files up to 10MB)
const maxOverwriteSize = 10 * 1024 * 1024
if info.Size() <= maxOverwriteSize {
zeros := make([]byte, 4096)
remaining := info.Size()
for remaining > 0 {
toWrite := int64(len(zeros))
if remaining < toWrite {
toWrite = remaining
}
if _, err := f.Write(zeros[:toWrite]); err != nil {
// Continue to removal even if overwrite fails
break
}
remaining -= toWrite
}
// Sync to ensure zeros hit disk
f.Sync()
}
// Close and remove
f.Close()
return os.Remove(path)
}
// AtomicSymlink creates a symlink atomically with proper error handling.
// On Unix systems, it uses symlinkat with AT_SYMLINK_NOFOLLOW where available.
// This prevents race conditions and symlink attacks during creation.
func AtomicSymlink(oldPath, newPath string) error {
// Clean both paths
oldPath = filepath.Clean(oldPath)
newPath = filepath.Clean(newPath)
// Ensure target exists (but don't follow symlinks)
_, err := os.Lstat(oldPath)
if err != nil {
return fmt.Errorf("symlink target %q does not exist: %w", oldPath, err)
}
// Remove existing symlink if it exists
if info, err := os.Lstat(newPath); err == nil {
if info.Mode()&os.ModeSymlink != 0 {
if err := os.Remove(newPath); err != nil {
return fmt.Errorf("failed to remove existing symlink: %w", err)
}
} else {
return fmt.Errorf("newPath %q exists and is not a symlink", newPath)
}
}
// Create the symlink
if err := os.Symlink(oldPath, newPath); err != nil {
return fmt.Errorf("failed to create symlink: %w", err)
}
return nil
}

View file

@ -62,11 +62,15 @@ func (l *Logger) SetEventRecorder(recorder EventRecorder) {
// TaskEvent logs a structured task event and optionally records to event store
func (l *Logger) TaskEvent(taskID, eventType string, data map[string]interface{}) {
l.Info("task_event",
args := []any{
"task_id", taskID,
"event_type", eventType,
"timestamp", time.Now().UTC().Format(time.RFC3339),
)
}
for k, v := range data {
args = append(args, k, v)
}
l.Info("task_event", args...)
}
// NewFileLogger creates a logger that writes to a file only (production mode)
@ -79,14 +83,14 @@ func NewFileLogger(level slog.Level, jsonOutput bool, logFile string) *Logger {
// Create log directory if it doesn't exist
if logFile != "" {
logDir := filepath.Dir(logFile)
if err := os.MkdirAll(logDir, 0750); err != nil {
if err := os.MkdirAll(logDir, 0o750); err != nil {
// Fallback to stderr only if directory creation fails
return NewLogger(level, jsonOutput)
}
}
// Open log file
file, err := fileutil.SecureOpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600)
file, err := fileutil.SecureOpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o600)
if err != nil {
// Fallback to stderr only if file creation fails
return NewLogger(level, jsonOutput)

View file

@ -150,14 +150,14 @@ func (c *SSHClient) ExecContext(ctx context.Context, cmd string) (string, error)
// Run command with context cancellation
type result struct {
output string
err error
output string
}
resultCh := make(chan result, 1)
go func() {
output, err := session.CombinedOutput(cmd)
resultCh <- result{string(output), err}
resultCh <- result{err, string(output)}
}()
select {
@ -181,6 +181,8 @@ func (c *SSHClient) ExecContext(ctx context.Context, cmd string) (string, error)
}
// Wait a bit more for final result
innerTimer := time.NewTimer(5 * time.Second)
defer innerTimer.Stop()
select {
case res := <-resultCh:
return res.output, fmt.Errorf(
@ -188,7 +190,10 @@ func (c *SSHClient) ExecContext(ctx context.Context, cmd string) (string, error)
ctx.Err(),
res.output,
)
case <-time.After(5 * time.Second):
case <-innerTimer.C:
// Goroutine running session.CombinedOutput may still be blocked —
// x/crypto/ssh doesn't support context cancellation on output reads.
// It will unblock when the server closes the connection.
return "", fmt.Errorf("command cancelled and cleanup timeout: %w", ctx.Err())
}
}

View file

@ -12,10 +12,10 @@ import (
type SSHPool struct {
factory func() (*SSHClient, error)
pool chan *SSHClient
logger *logging.Logger
active int
maxConns int
mu sync.Mutex
logger *logging.Logger
}
// NewSSHPool creates a new SSH connection pool.

View file

@ -17,9 +17,9 @@ var piiPatterns = map[string]*regexp.Regexp{
// PIIFinding represents a detected PII instance.
type PIIFinding struct {
Type string `json:"type"`
Sample string `json:"sample"`
Position int `json:"position"`
Length int `json:"length"`
Sample string `json:"sample"` // Redacted sample
}
// DetectPII scans text for potential PII.

View file

@ -16,25 +16,23 @@ import (
)
type Manager struct {
mu sync.Mutex
cond *sync.Cond
totalCPU int
freeCPU int
slotsPerGPU int
gpuFree []int
cond *sync.Cond
gpuFree []int
totalCPU int
freeCPU int
slotsPerGPU int
acquireTotal atomic.Int64
acquireWaitTotal atomic.Int64
acquireTimeoutTotal atomic.Int64
acquireWaitNanos atomic.Int64
mu sync.Mutex
}
type Snapshot struct {
TotalCPU int
FreeCPU int
SlotsPerGPU int
GPUFree []int
GPUFree []int
TotalCPU int
FreeCPU int
SlotsPerGPU int
AcquireTotal int64
AcquireWaitTotal int64
AcquireTimeoutTotal int64
@ -66,9 +64,9 @@ type GPUAllocation struct {
}
type Lease struct {
cpu int
gpus []GPUAllocation
m *Manager
gpus []GPUAllocation
cpu int
}
func (l *Lease) CPU() int { return l.cpu }

View file

@ -31,13 +31,13 @@ const (
// Alert represents a security alert
type Alert struct {
Timestamp time.Time `json:"timestamp"`
Metadata map[string]any `json:"metadata,omitempty"`
Severity AlertSeverity `json:"severity"`
Type AlertType `json:"type"`
Message string `json:"message"`
Timestamp time.Time `json:"timestamp"`
SourceIP string `json:"source_ip,omitempty"`
UserID string `json:"user_id,omitempty"`
Metadata map[string]any `json:"metadata,omitempty"`
}
// AlertHandler is called when a security alert is generated
@ -93,27 +93,16 @@ func (w *SlidingWindow) Count() int {
// AnomalyMonitor tracks security-relevant events and generates alerts
type AnomalyMonitor struct {
// Failed auth tracking per IP
failedAuthByIP map[string]*SlidingWindow
// Global counters
lastPrivilegedAlert time.Time
failedAuthByIP map[string]*SlidingWindow
alertHandler AlertHandler
privilegedContainerAttempts int
pathTraversalAttempts int
commandInjectionAttempts int
// Configuration
mu sync.RWMutex
// Alert handler
alertHandler AlertHandler
// Thresholds
bruteForceThreshold int
bruteForceWindow time.Duration
privilegedAlertInterval time.Duration
// Last alert times (to prevent spam)
lastPrivilegedAlert time.Time
bruteForceThreshold int
bruteForceWindow time.Duration
privilegedAlertInterval time.Duration
mu sync.RWMutex
}
// NewAnomalyMonitor creates a new security anomaly monitor

View file

@ -11,12 +11,12 @@ import (
// PluginConfig represents the configuration for a single plugin.
type PluginConfig struct {
Enabled bool `toml:"enabled" yaml:"enabled"`
Settings map[string]any `toml:"settings" yaml:"settings"`
Image string `toml:"image" yaml:"image"`
Mode string `toml:"mode" yaml:"mode"`
LogBasePath string `toml:"log_base_path" yaml:"log_base_path"`
ArtifactPath string `toml:"artifact_path" yaml:"artifact_path"`
Settings map[string]any `toml:"settings" yaml:"settings"`
Enabled bool `toml:"enabled" yaml:"enabled"`
}
// PluginFactory is a function that creates a Plugin instance.

View file

@ -22,9 +22,9 @@ const (
// ToolConfig specifies how a plugin should be provisioned for a task.
type ToolConfig struct {
Enabled bool
Mode ToolMode
Settings map[string]any
Mode ToolMode
Enabled bool
}
// Plugin defines the behaviour every tracking integration must implement.
@ -38,9 +38,9 @@ type Plugin interface {
// Registry keeps track of registered plugins and their lifecycle per task.
type Registry struct {
logger *logging.Logger
mu sync.Mutex
plugins map[string]Plugin
active map[string][]string
mu sync.Mutex
}
// NewRegistry returns a new plugin registry.
@ -147,11 +147,11 @@ func (r *Registry) rollback(ctx context.Context, taskID string, provisioned []st
// PortAllocator manages dynamic port assignments for sidecars.
type PortAllocator struct {
mu sync.Mutex
used map[int]bool
start int
end int
next int
used map[int]bool
mu sync.Mutex
}
// NewPortAllocator creates a new allocator for a port range.

View file

@ -14,10 +14,10 @@ import (
// MLflowOptions configures the MLflow plugin.
type MLflowOptions struct {
PortAllocator *trackingpkg.PortAllocator
Image string
ArtifactBasePath string
DefaultTrackingURI string
PortAllocator *trackingpkg.PortAllocator
}
type mlflowSidecar struct {
@ -27,12 +27,11 @@ type mlflowSidecar struct {
// MLflowPlugin provisions MLflow tracking servers per task.
type MLflowPlugin struct {
logger *logging.Logger
podman *container.PodmanManager
opts MLflowOptions
mu sync.Mutex
logger *logging.Logger
podman *container.PodmanManager
sidecars map[string]*mlflowSidecar
opts MLflowOptions
mu sync.Mutex
}
// NewMLflowPlugin creates a new MLflow plugin instance.

View file

@ -14,9 +14,9 @@ import (
// TensorBoardOptions configure the TensorBoard plugin.
type TensorBoardOptions struct {
PortAllocator *trackingpkg.PortAllocator
Image string
LogBasePath string
PortAllocator *trackingpkg.PortAllocator
}
type tensorboardSidecar struct {
@ -26,12 +26,11 @@ type tensorboardSidecar struct {
// TensorBoardPlugin exposes training logs through TensorBoard.
type TensorBoardPlugin struct {
logger *logging.Logger
podman *container.PodmanManager
opts TensorBoardOptions
mu sync.Mutex
logger *logging.Logger
podman *container.PodmanManager
sidecars map[string]*tensorboardSidecar
opts TensorBoardOptions
mu sync.Mutex
}
// NewTensorBoardPlugin constructs a TensorBoard plugin instance.