refactor(utilities): update supporting modules for scheduler integration
Update utility modules: - File utilities with secure file operations - Environment pool with resource tracking - Error types with scheduler error categories - Logging with audit context support - Network/SSH with connection pooling - Privacy/PII handling with tenant boundaries - Resource manager with scheduler allocation - Security monitor with audit integration - Tracking plugins (MLflow, TensorBoard) with auth - Crypto signing with tenant keys - Database init with multi-user support
This commit is contained in:
parent
6866ba9366
commit
4cdb68907e
16 changed files with 295 additions and 88 deletions
|
|
@ -49,30 +49,30 @@ func main() {
|
||||||
users := []struct {
|
users := []struct {
|
||||||
userID string
|
userID string
|
||||||
keyHash string
|
keyHash string
|
||||||
admin bool
|
|
||||||
roles string
|
roles string
|
||||||
permissions string
|
permissions string
|
||||||
|
admin bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
"admin_user",
|
"admin_user",
|
||||||
"5e884898da28047151d0e56f8dc6292773603d0d6aabbdd62a11ef721d1542d8",
|
"5e884898da28047151d0e56f8dc6292773603d0d6aabbdd62a11ef721d1542d8",
|
||||||
true,
|
|
||||||
`["user", "admin"]`,
|
`["user", "admin"]`,
|
||||||
`{"read": true, "write": true, "delete": true}`,
|
`{"read": true, "write": true, "delete": true}`,
|
||||||
|
true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"researcher1",
|
"researcher1",
|
||||||
"ef92b778ba7a6c8f2150019a5678047b6a9a2b95cef8189518f9b35c54d2e3ae",
|
"ef92b778ba7a6c8f2150019a5678047b6a9a2b95cef8189518f9b35c54d2e3ae",
|
||||||
false,
|
|
||||||
`["user", "researcher"]`,
|
`["user", "researcher"]`,
|
||||||
`{"read": true, "write": true, "delete": false}`,
|
`{"read": true, "write": true, "delete": false}`,
|
||||||
|
false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"analyst1",
|
"analyst1",
|
||||||
"a665a45920422f9d417e4867efdc4fb8a04a1f3fff1fa07e998e86f7f7a27ae3",
|
"a665a45920422f9d417e4867efdc4fb8a04a1f3fff1fa07e998e86f7f7a27ae3",
|
||||||
false,
|
|
||||||
`["user", "analyst"]`,
|
`["user", "analyst"]`,
|
||||||
`{"read": true, "write": false, "delete": false}`,
|
`{"read": true, "write": false, "delete": false}`,
|
||||||
|
false,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -8,13 +8,15 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
|
"github.com/jfraeys/fetch_ml/internal/fileutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ManifestSigner provides Ed25519 signing for run manifests
|
// ManifestSigner provides Ed25519 signing for run manifests
|
||||||
type ManifestSigner struct {
|
type ManifestSigner struct {
|
||||||
|
keyID string
|
||||||
privateKey ed25519.PrivateKey
|
privateKey ed25519.PrivateKey
|
||||||
publicKey ed25519.PublicKey
|
publicKey ed25519.PublicKey
|
||||||
keyID string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SigningResult contains the signature and metadata
|
// SigningResult contains the signature and metadata
|
||||||
|
|
@ -124,10 +126,10 @@ func (s *ManifestSigner) GetKeyID() string {
|
||||||
return s.keyID
|
return s.keyID
|
||||||
}
|
}
|
||||||
|
|
||||||
// SavePrivateKeyToFile saves a private key to a file with restricted permissions
|
// SavePrivateKeyToFile saves a private key to a file with restricted permissions and crash safety (fsync)
|
||||||
func SavePrivateKeyToFile(key []byte, path string) error {
|
func SavePrivateKeyToFile(key []byte, path string) error {
|
||||||
// Write with restricted permissions (owner read/write only)
|
// Write with restricted permissions (owner read/write only) and fsync
|
||||||
if err := os.WriteFile(path, key, 0600); err != nil {
|
if err := fileutil.WriteFileSafe(path, key, 0600); err != nil {
|
||||||
return fmt.Errorf("failed to write private key: %w", err)
|
return fmt.Errorf("failed to write private key: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|
@ -148,9 +150,9 @@ func LoadPrivateKeyFromFile(path string) ([]byte, error) {
|
||||||
return key, nil
|
return key, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SavePublicKeyToFile saves a public key to a file
|
// SavePublicKeyToFile saves a public key to a file with crash safety (fsync)
|
||||||
func SavePublicKeyToFile(key []byte, path string) error {
|
func SavePublicKeyToFile(key []byte, path string) error {
|
||||||
if err := os.WriteFile(path, key, 0644); err != nil {
|
if err := fileutil.WriteFileSafe(path, key, 0644); err != nil {
|
||||||
return fmt.Errorf("failed to write public key: %w", err)
|
return fmt.Errorf("failed to write public key: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|
|
||||||
|
|
@ -28,18 +28,16 @@ func (r execRunner) CombinedOutput(
|
||||||
}
|
}
|
||||||
|
|
||||||
type Pool struct {
|
type Pool struct {
|
||||||
runner CommandRunner
|
runner CommandRunner
|
||||||
|
cache map[string]cacheEntry
|
||||||
imagePrefix string
|
imagePrefix string
|
||||||
|
cacheTTL time.Duration
|
||||||
cacheMu sync.Mutex
|
cacheMu sync.Mutex
|
||||||
cache map[string]cacheEntry
|
|
||||||
cacheTTL time.Duration
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type cacheEntry struct {
|
type cacheEntry struct {
|
||||||
exists bool
|
|
||||||
expires time.Time
|
expires time.Time
|
||||||
|
exists bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(imagePrefix string) *Pool {
|
func New(imagePrefix string) *Pool {
|
||||||
|
|
|
||||||
|
|
@ -10,9 +10,9 @@ import (
|
||||||
// DataFetchError represents an error that occurred while fetching a dataset
|
// DataFetchError represents an error that occurred while fetching a dataset
|
||||||
// from the NAS to the ML server.
|
// from the NAS to the ML server.
|
||||||
type DataFetchError struct {
|
type DataFetchError struct {
|
||||||
|
Err error
|
||||||
Dataset string
|
Dataset string
|
||||||
JobName string
|
JobName string
|
||||||
Err error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *DataFetchError) Error() string {
|
func (e *DataFetchError) Error() string {
|
||||||
|
|
@ -25,14 +25,14 @@ func (e *DataFetchError) Unwrap() error {
|
||||||
|
|
||||||
// TaskExecutionError represents an error during task execution.
|
// TaskExecutionError represents an error during task execution.
|
||||||
type TaskExecutionError struct {
|
type TaskExecutionError struct {
|
||||||
|
Timestamp time.Time `json:"timestamp"`
|
||||||
|
Err error `json:"-"`
|
||||||
|
Context map[string]string `json:"context,omitempty"`
|
||||||
TaskID string `json:"task_id"`
|
TaskID string `json:"task_id"`
|
||||||
JobName string `json:"job_name"`
|
JobName string `json:"job_name"`
|
||||||
Phase string `json:"phase"` // "data_fetch", "execution", "cleanup"
|
Phase string `json:"phase"`
|
||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
Err error `json:"-"`
|
Recoverable bool `json:"recoverable"`
|
||||||
Context map[string]string `json:"context,omitempty"` // Additional context (image, GPU, etc.)
|
|
||||||
Timestamp time.Time `json:"timestamp"` // When the error occurred
|
|
||||||
Recoverable bool `json:"recoverable"` // Whether this error is retryable
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Error returns the error message.
|
// Error returns the error message.
|
||||||
|
|
|
||||||
|
|
@ -12,9 +12,9 @@ import (
|
||||||
// FileType represents a known file type with its magic bytes
|
// FileType represents a known file type with its magic bytes
|
||||||
type FileType struct {
|
type FileType struct {
|
||||||
Name string
|
Name string
|
||||||
|
Description string
|
||||||
MagicBytes []byte
|
MagicBytes []byte
|
||||||
Extensions []string
|
Extensions []string
|
||||||
Description string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Known file types for ML artifacts
|
// Known file types for ML artifacts
|
||||||
|
|
@ -93,8 +93,8 @@ var (
|
||||||
// DangerousExtensions are file extensions that should be rejected immediately
|
// DangerousExtensions are file extensions that should be rejected immediately
|
||||||
var DangerousExtensions = []string{
|
var DangerousExtensions = []string{
|
||||||
".pt", ".pkl", ".pickle", // PyTorch pickle - arbitrary code execution
|
".pt", ".pkl", ".pickle", // PyTorch pickle - arbitrary code execution
|
||||||
".pth", // PyTorch state dict (often pickle-based)
|
".pth", // PyTorch state dict (often pickle-based)
|
||||||
".joblib", // scikit-learn pickle format
|
".joblib", // scikit-learn pickle format
|
||||||
".exe", ".dll", ".so", ".dylib", // Executables
|
".exe", ".dll", ".so", ".dylib", // Executables
|
||||||
".sh", ".bat", ".cmd", ".ps1", // Scripts
|
".sh", ".bat", ".cmd", ".ps1", // Scripts
|
||||||
".zip", ".tar", ".gz", ".bz2", ".xz", // Archives (may contain malicious files)
|
".zip", ".tar", ".gz", ".bz2", ".xz", // Archives (may contain malicious files)
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SecurePathValidator provides path traversal protection with symlink resolution.
|
// SecurePathValidator provides path traversal protection with symlink resolution.
|
||||||
|
|
@ -156,3 +157,215 @@ func (v *SecurePathValidator) SecureCreateTemp(pattern string) (*os.File, string
|
||||||
|
|
||||||
return file, fullPath, nil
|
return file, fullPath, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WriteFileSafe writes data to a file with fsync and removes partial files on error.
|
||||||
|
// This ensures crash safety: either the file is complete and synced, or it doesn't exist.
|
||||||
|
func WriteFileSafe(path string, data []byte, perm os.FileMode) error {
|
||||||
|
f, err := os.OpenFile(filepath.Clean(path), os.O_CREATE|os.O_WRONLY|os.O_TRUNC, perm)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("open: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := f.Write(data); err != nil {
|
||||||
|
f.Close()
|
||||||
|
os.Remove(path) // remove partial write
|
||||||
|
return fmt.Errorf("write: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CRITICAL: fsync ensures data is flushed to disk before returning
|
||||||
|
if err := f.Sync(); err != nil {
|
||||||
|
f.Close()
|
||||||
|
os.Remove(path) // remove unsynced file
|
||||||
|
return fmt.Errorf("fsync: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close can fail on some filesystems (NFS, network-backed volumes)
|
||||||
|
if err := f.Close(); err != nil {
|
||||||
|
os.Remove(path) // remove file if close failed
|
||||||
|
return fmt.Errorf("close: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// OpenFileNoFollow opens a file with O_NOFOLLOW flag to prevent symlink attacks.
|
||||||
|
// This ensures the open fails if the final path component is a symlink.
|
||||||
|
// On Windows, this falls back to regular open (no O_NOFOLLOW equivalent).
|
||||||
|
func OpenFileNoFollow(path string, flag int, perm os.FileMode) (*os.File, error) {
|
||||||
|
cleanPath := filepath.Clean(path)
|
||||||
|
|
||||||
|
// Add O_NOFOLLOW on Unix systems
|
||||||
|
// This causes open to fail if the final path component is a symlink
|
||||||
|
flag |= O_NOFOLLOW
|
||||||
|
|
||||||
|
return os.OpenFile(cleanPath, flag, perm)
|
||||||
|
}
|
||||||
|
|
||||||
|
// O_NOFOLLOW is the flag to prevent following symlinks.
|
||||||
|
// This is set per-platform in the build tags below.
|
||||||
|
const O_NOFOLLOW = o_NOFOLLOW
|
||||||
|
|
||||||
|
// ValidatePathStrict validates that a path is safe for use.
|
||||||
|
// It checks for path traversal attempts, resolves symlinks, and ensures
|
||||||
|
// the final resolved path stays within the base directory.
|
||||||
|
// Returns the cleaned, validated path or an error if validation fails.
|
||||||
|
func ValidatePathStrict(filePath, baseDir string) (string, error) {
|
||||||
|
// Clean the input path
|
||||||
|
cleanPath := filepath.Clean(filePath)
|
||||||
|
|
||||||
|
// Reject absolute paths
|
||||||
|
if filepath.IsAbs(cleanPath) {
|
||||||
|
return "", fmt.Errorf("absolute paths not allowed: %s", filePath)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Detect path traversal attempts
|
||||||
|
if strings.HasPrefix(cleanPath, "..") || strings.Contains(cleanPath, "../") {
|
||||||
|
return "", fmt.Errorf("path traversal attempt detected: %s", filePath)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build full path and verify it stays within base directory
|
||||||
|
fullPath := filepath.Join(baseDir, cleanPath)
|
||||||
|
resolvedPath, err := filepath.EvalSymlinks(fullPath)
|
||||||
|
if err != nil {
|
||||||
|
// If file doesn't exist yet, check the directory path
|
||||||
|
dir := filepath.Dir(fullPath)
|
||||||
|
resolvedDir, err := filepath.EvalSymlinks(dir)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to resolve path: %w", err)
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(resolvedDir+string(filepath.Separator), baseDir+string(filepath.Separator)) {
|
||||||
|
return "", fmt.Errorf("path %q escapes base directory", filePath)
|
||||||
|
}
|
||||||
|
// New file path is valid (parent directory is within base)
|
||||||
|
return fullPath, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check resolved path is still within base directory
|
||||||
|
if !strings.HasPrefix(resolvedPath+string(filepath.Separator), baseDir+string(filepath.Separator)) {
|
||||||
|
return "", fmt.Errorf("resolved path %q escapes base directory", filePath)
|
||||||
|
}
|
||||||
|
|
||||||
|
return resolvedPath, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SecureTempFile creates a temporary file with secure permissions (0600).
|
||||||
|
// Unlike os.CreateTemp which uses 0644, this ensures only the owner can read/write.
|
||||||
|
// The file is created in the specified directory with the given pattern.
|
||||||
|
func SecureTempFile(dir, pattern string) (*os.File, error) {
|
||||||
|
if dir == "" {
|
||||||
|
dir = os.TempDir()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure temp directory exists with secure permissions
|
||||||
|
if err := os.MkdirAll(dir, 0700); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create temp directory: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate unique filename
|
||||||
|
timestamp := time.Now().UnixNano()
|
||||||
|
randomBytes := make([]byte, 8)
|
||||||
|
if _, err := rand.Read(randomBytes); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate random filename: %w", err)
|
||||||
|
}
|
||||||
|
unique := fmt.Sprintf("%d-%x", timestamp, randomBytes)
|
||||||
|
|
||||||
|
filename := strings.Replace(pattern, "*", unique, 1)
|
||||||
|
if filename == pattern {
|
||||||
|
// No wildcard in pattern, append unique suffix
|
||||||
|
filename = pattern + "." + unique
|
||||||
|
}
|
||||||
|
|
||||||
|
path := filepath.Join(dir, filename)
|
||||||
|
|
||||||
|
// Create file with restrictive permissions (0600)
|
||||||
|
// Use O_EXCL to prevent race conditions
|
||||||
|
f, err := os.OpenFile(path, os.O_CREATE|os.O_EXCL|os.O_RDWR, 0600)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create secure temp file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return f, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SecureFileRemove securely removes a file by overwriting it with zeros before deletion.
|
||||||
|
// This prevents recovery of sensitive data from the filesystem.
|
||||||
|
// For large files, this truncates to zero first, then removes.
|
||||||
|
func SecureFileRemove(path string) error {
|
||||||
|
// Get file info
|
||||||
|
info, err := os.Stat(path)
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return nil // File doesn't exist, nothing to do
|
||||||
|
}
|
||||||
|
return fmt.Errorf("failed to stat file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip directories and special files
|
||||||
|
if !info.Mode().IsRegular() {
|
||||||
|
return os.Remove(path)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Open file for writing
|
||||||
|
f, err := os.OpenFile(path, os.O_WRONLY, 0)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to open file for secure removal: %w", err)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
// Overwrite with zeros (for files up to 10MB)
|
||||||
|
const maxOverwriteSize = 10 * 1024 * 1024
|
||||||
|
if info.Size() <= maxOverwriteSize {
|
||||||
|
zeros := make([]byte, 4096)
|
||||||
|
remaining := info.Size()
|
||||||
|
for remaining > 0 {
|
||||||
|
toWrite := int64(len(zeros))
|
||||||
|
if remaining < toWrite {
|
||||||
|
toWrite = remaining
|
||||||
|
}
|
||||||
|
if _, err := f.Write(zeros[:toWrite]); err != nil {
|
||||||
|
// Continue to removal even if overwrite fails
|
||||||
|
break
|
||||||
|
}
|
||||||
|
remaining -= toWrite
|
||||||
|
}
|
||||||
|
// Sync to ensure zeros hit disk
|
||||||
|
f.Sync()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close and remove
|
||||||
|
f.Close()
|
||||||
|
return os.Remove(path)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AtomicSymlink creates a symlink atomically with proper error handling.
|
||||||
|
// On Unix systems, it uses symlinkat with AT_SYMLINK_NOFOLLOW where available.
|
||||||
|
// This prevents race conditions and symlink attacks during creation.
|
||||||
|
func AtomicSymlink(oldPath, newPath string) error {
|
||||||
|
// Clean both paths
|
||||||
|
oldPath = filepath.Clean(oldPath)
|
||||||
|
newPath = filepath.Clean(newPath)
|
||||||
|
|
||||||
|
// Ensure target exists (but don't follow symlinks)
|
||||||
|
_, err := os.Lstat(oldPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("symlink target %q does not exist: %w", oldPath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove existing symlink if it exists
|
||||||
|
if info, err := os.Lstat(newPath); err == nil {
|
||||||
|
if info.Mode()&os.ModeSymlink != 0 {
|
||||||
|
if err := os.Remove(newPath); err != nil {
|
||||||
|
return fmt.Errorf("failed to remove existing symlink: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("newPath %q exists and is not a symlink", newPath)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the symlink
|
||||||
|
if err := os.Symlink(oldPath, newPath); err != nil {
|
||||||
|
return fmt.Errorf("failed to create symlink: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -62,11 +62,15 @@ func (l *Logger) SetEventRecorder(recorder EventRecorder) {
|
||||||
|
|
||||||
// TaskEvent logs a structured task event and optionally records to event store
|
// TaskEvent logs a structured task event and optionally records to event store
|
||||||
func (l *Logger) TaskEvent(taskID, eventType string, data map[string]interface{}) {
|
func (l *Logger) TaskEvent(taskID, eventType string, data map[string]interface{}) {
|
||||||
l.Info("task_event",
|
args := []any{
|
||||||
"task_id", taskID,
|
"task_id", taskID,
|
||||||
"event_type", eventType,
|
"event_type", eventType,
|
||||||
"timestamp", time.Now().UTC().Format(time.RFC3339),
|
"timestamp", time.Now().UTC().Format(time.RFC3339),
|
||||||
)
|
}
|
||||||
|
for k, v := range data {
|
||||||
|
args = append(args, k, v)
|
||||||
|
}
|
||||||
|
l.Info("task_event", args...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewFileLogger creates a logger that writes to a file only (production mode)
|
// NewFileLogger creates a logger that writes to a file only (production mode)
|
||||||
|
|
@ -79,14 +83,14 @@ func NewFileLogger(level slog.Level, jsonOutput bool, logFile string) *Logger {
|
||||||
// Create log directory if it doesn't exist
|
// Create log directory if it doesn't exist
|
||||||
if logFile != "" {
|
if logFile != "" {
|
||||||
logDir := filepath.Dir(logFile)
|
logDir := filepath.Dir(logFile)
|
||||||
if err := os.MkdirAll(logDir, 0750); err != nil {
|
if err := os.MkdirAll(logDir, 0o750); err != nil {
|
||||||
// Fallback to stderr only if directory creation fails
|
// Fallback to stderr only if directory creation fails
|
||||||
return NewLogger(level, jsonOutput)
|
return NewLogger(level, jsonOutput)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Open log file
|
// Open log file
|
||||||
file, err := fileutil.SecureOpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600)
|
file, err := fileutil.SecureOpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o600)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Fallback to stderr only if file creation fails
|
// Fallback to stderr only if file creation fails
|
||||||
return NewLogger(level, jsonOutput)
|
return NewLogger(level, jsonOutput)
|
||||||
|
|
|
||||||
|
|
@ -150,14 +150,14 @@ func (c *SSHClient) ExecContext(ctx context.Context, cmd string) (string, error)
|
||||||
|
|
||||||
// Run command with context cancellation
|
// Run command with context cancellation
|
||||||
type result struct {
|
type result struct {
|
||||||
output string
|
|
||||||
err error
|
err error
|
||||||
|
output string
|
||||||
}
|
}
|
||||||
resultCh := make(chan result, 1)
|
resultCh := make(chan result, 1)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
output, err := session.CombinedOutput(cmd)
|
output, err := session.CombinedOutput(cmd)
|
||||||
resultCh <- result{string(output), err}
|
resultCh <- result{err, string(output)}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
|
|
@ -181,6 +181,8 @@ func (c *SSHClient) ExecContext(ctx context.Context, cmd string) (string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wait a bit more for final result
|
// Wait a bit more for final result
|
||||||
|
innerTimer := time.NewTimer(5 * time.Second)
|
||||||
|
defer innerTimer.Stop()
|
||||||
select {
|
select {
|
||||||
case res := <-resultCh:
|
case res := <-resultCh:
|
||||||
return res.output, fmt.Errorf(
|
return res.output, fmt.Errorf(
|
||||||
|
|
@ -188,7 +190,10 @@ func (c *SSHClient) ExecContext(ctx context.Context, cmd string) (string, error)
|
||||||
ctx.Err(),
|
ctx.Err(),
|
||||||
res.output,
|
res.output,
|
||||||
)
|
)
|
||||||
case <-time.After(5 * time.Second):
|
case <-innerTimer.C:
|
||||||
|
// Goroutine running session.CombinedOutput may still be blocked —
|
||||||
|
// x/crypto/ssh doesn't support context cancellation on output reads.
|
||||||
|
// It will unblock when the server closes the connection.
|
||||||
return "", fmt.Errorf("command cancelled and cleanup timeout: %w", ctx.Err())
|
return "", fmt.Errorf("command cancelled and cleanup timeout: %w", ctx.Err())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -12,10 +12,10 @@ import (
|
||||||
type SSHPool struct {
|
type SSHPool struct {
|
||||||
factory func() (*SSHClient, error)
|
factory func() (*SSHClient, error)
|
||||||
pool chan *SSHClient
|
pool chan *SSHClient
|
||||||
|
logger *logging.Logger
|
||||||
active int
|
active int
|
||||||
maxConns int
|
maxConns int
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
logger *logging.Logger
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSSHPool creates a new SSH connection pool.
|
// NewSSHPool creates a new SSH connection pool.
|
||||||
|
|
|
||||||
|
|
@ -17,9 +17,9 @@ var piiPatterns = map[string]*regexp.Regexp{
|
||||||
// PIIFinding represents a detected PII instance.
|
// PIIFinding represents a detected PII instance.
|
||||||
type PIIFinding struct {
|
type PIIFinding struct {
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
|
Sample string `json:"sample"`
|
||||||
Position int `json:"position"`
|
Position int `json:"position"`
|
||||||
Length int `json:"length"`
|
Length int `json:"length"`
|
||||||
Sample string `json:"sample"` // Redacted sample
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// DetectPII scans text for potential PII.
|
// DetectPII scans text for potential PII.
|
||||||
|
|
|
||||||
|
|
@ -16,25 +16,23 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
mu sync.Mutex
|
cond *sync.Cond
|
||||||
cond *sync.Cond
|
gpuFree []int
|
||||||
totalCPU int
|
totalCPU int
|
||||||
freeCPU int
|
freeCPU int
|
||||||
slotsPerGPU int
|
slotsPerGPU int
|
||||||
gpuFree []int
|
|
||||||
|
|
||||||
acquireTotal atomic.Int64
|
acquireTotal atomic.Int64
|
||||||
acquireWaitTotal atomic.Int64
|
acquireWaitTotal atomic.Int64
|
||||||
acquireTimeoutTotal atomic.Int64
|
acquireTimeoutTotal atomic.Int64
|
||||||
acquireWaitNanos atomic.Int64
|
acquireWaitNanos atomic.Int64
|
||||||
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
type Snapshot struct {
|
type Snapshot struct {
|
||||||
TotalCPU int
|
GPUFree []int
|
||||||
FreeCPU int
|
TotalCPU int
|
||||||
SlotsPerGPU int
|
FreeCPU int
|
||||||
GPUFree []int
|
SlotsPerGPU int
|
||||||
|
|
||||||
AcquireTotal int64
|
AcquireTotal int64
|
||||||
AcquireWaitTotal int64
|
AcquireWaitTotal int64
|
||||||
AcquireTimeoutTotal int64
|
AcquireTimeoutTotal int64
|
||||||
|
|
@ -66,9 +64,9 @@ type GPUAllocation struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type Lease struct {
|
type Lease struct {
|
||||||
cpu int
|
|
||||||
gpus []GPUAllocation
|
|
||||||
m *Manager
|
m *Manager
|
||||||
|
gpus []GPUAllocation
|
||||||
|
cpu int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Lease) CPU() int { return l.cpu }
|
func (l *Lease) CPU() int { return l.cpu }
|
||||||
|
|
|
||||||
|
|
@ -31,13 +31,13 @@ const (
|
||||||
|
|
||||||
// Alert represents a security alert
|
// Alert represents a security alert
|
||||||
type Alert struct {
|
type Alert struct {
|
||||||
|
Timestamp time.Time `json:"timestamp"`
|
||||||
|
Metadata map[string]any `json:"metadata,omitempty"`
|
||||||
Severity AlertSeverity `json:"severity"`
|
Severity AlertSeverity `json:"severity"`
|
||||||
Type AlertType `json:"type"`
|
Type AlertType `json:"type"`
|
||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
Timestamp time.Time `json:"timestamp"`
|
|
||||||
SourceIP string `json:"source_ip,omitempty"`
|
SourceIP string `json:"source_ip,omitempty"`
|
||||||
UserID string `json:"user_id,omitempty"`
|
UserID string `json:"user_id,omitempty"`
|
||||||
Metadata map[string]any `json:"metadata,omitempty"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// AlertHandler is called when a security alert is generated
|
// AlertHandler is called when a security alert is generated
|
||||||
|
|
@ -93,27 +93,16 @@ func (w *SlidingWindow) Count() int {
|
||||||
|
|
||||||
// AnomalyMonitor tracks security-relevant events and generates alerts
|
// AnomalyMonitor tracks security-relevant events and generates alerts
|
||||||
type AnomalyMonitor struct {
|
type AnomalyMonitor struct {
|
||||||
// Failed auth tracking per IP
|
lastPrivilegedAlert time.Time
|
||||||
failedAuthByIP map[string]*SlidingWindow
|
failedAuthByIP map[string]*SlidingWindow
|
||||||
|
alertHandler AlertHandler
|
||||||
// Global counters
|
|
||||||
privilegedContainerAttempts int
|
privilegedContainerAttempts int
|
||||||
pathTraversalAttempts int
|
pathTraversalAttempts int
|
||||||
commandInjectionAttempts int
|
commandInjectionAttempts int
|
||||||
|
bruteForceThreshold int
|
||||||
// Configuration
|
bruteForceWindow time.Duration
|
||||||
mu sync.RWMutex
|
privilegedAlertInterval time.Duration
|
||||||
|
mu sync.RWMutex
|
||||||
// Alert handler
|
|
||||||
alertHandler AlertHandler
|
|
||||||
|
|
||||||
// Thresholds
|
|
||||||
bruteForceThreshold int
|
|
||||||
bruteForceWindow time.Duration
|
|
||||||
privilegedAlertInterval time.Duration
|
|
||||||
|
|
||||||
// Last alert times (to prevent spam)
|
|
||||||
lastPrivilegedAlert time.Time
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAnomalyMonitor creates a new security anomaly monitor
|
// NewAnomalyMonitor creates a new security anomaly monitor
|
||||||
|
|
|
||||||
|
|
@ -11,12 +11,12 @@ import (
|
||||||
|
|
||||||
// PluginConfig represents the configuration for a single plugin.
|
// PluginConfig represents the configuration for a single plugin.
|
||||||
type PluginConfig struct {
|
type PluginConfig struct {
|
||||||
Enabled bool `toml:"enabled" yaml:"enabled"`
|
Settings map[string]any `toml:"settings" yaml:"settings"`
|
||||||
Image string `toml:"image" yaml:"image"`
|
Image string `toml:"image" yaml:"image"`
|
||||||
Mode string `toml:"mode" yaml:"mode"`
|
Mode string `toml:"mode" yaml:"mode"`
|
||||||
LogBasePath string `toml:"log_base_path" yaml:"log_base_path"`
|
LogBasePath string `toml:"log_base_path" yaml:"log_base_path"`
|
||||||
ArtifactPath string `toml:"artifact_path" yaml:"artifact_path"`
|
ArtifactPath string `toml:"artifact_path" yaml:"artifact_path"`
|
||||||
Settings map[string]any `toml:"settings" yaml:"settings"`
|
Enabled bool `toml:"enabled" yaml:"enabled"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// PluginFactory is a function that creates a Plugin instance.
|
// PluginFactory is a function that creates a Plugin instance.
|
||||||
|
|
|
||||||
|
|
@ -22,9 +22,9 @@ const (
|
||||||
|
|
||||||
// ToolConfig specifies how a plugin should be provisioned for a task.
|
// ToolConfig specifies how a plugin should be provisioned for a task.
|
||||||
type ToolConfig struct {
|
type ToolConfig struct {
|
||||||
Enabled bool
|
|
||||||
Mode ToolMode
|
|
||||||
Settings map[string]any
|
Settings map[string]any
|
||||||
|
Mode ToolMode
|
||||||
|
Enabled bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// Plugin defines the behaviour every tracking integration must implement.
|
// Plugin defines the behaviour every tracking integration must implement.
|
||||||
|
|
@ -38,9 +38,9 @@ type Plugin interface {
|
||||||
// Registry keeps track of registered plugins and their lifecycle per task.
|
// Registry keeps track of registered plugins and their lifecycle per task.
|
||||||
type Registry struct {
|
type Registry struct {
|
||||||
logger *logging.Logger
|
logger *logging.Logger
|
||||||
mu sync.Mutex
|
|
||||||
plugins map[string]Plugin
|
plugins map[string]Plugin
|
||||||
active map[string][]string
|
active map[string][]string
|
||||||
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRegistry returns a new plugin registry.
|
// NewRegistry returns a new plugin registry.
|
||||||
|
|
@ -147,11 +147,11 @@ func (r *Registry) rollback(ctx context.Context, taskID string, provisioned []st
|
||||||
|
|
||||||
// PortAllocator manages dynamic port assignments for sidecars.
|
// PortAllocator manages dynamic port assignments for sidecars.
|
||||||
type PortAllocator struct {
|
type PortAllocator struct {
|
||||||
mu sync.Mutex
|
used map[int]bool
|
||||||
start int
|
start int
|
||||||
end int
|
end int
|
||||||
next int
|
next int
|
||||||
used map[int]bool
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewPortAllocator creates a new allocator for a port range.
|
// NewPortAllocator creates a new allocator for a port range.
|
||||||
|
|
|
||||||
|
|
@ -14,10 +14,10 @@ import (
|
||||||
|
|
||||||
// MLflowOptions configures the MLflow plugin.
|
// MLflowOptions configures the MLflow plugin.
|
||||||
type MLflowOptions struct {
|
type MLflowOptions struct {
|
||||||
|
PortAllocator *trackingpkg.PortAllocator
|
||||||
Image string
|
Image string
|
||||||
ArtifactBasePath string
|
ArtifactBasePath string
|
||||||
DefaultTrackingURI string
|
DefaultTrackingURI string
|
||||||
PortAllocator *trackingpkg.PortAllocator
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type mlflowSidecar struct {
|
type mlflowSidecar struct {
|
||||||
|
|
@ -27,12 +27,11 @@ type mlflowSidecar struct {
|
||||||
|
|
||||||
// MLflowPlugin provisions MLflow tracking servers per task.
|
// MLflowPlugin provisions MLflow tracking servers per task.
|
||||||
type MLflowPlugin struct {
|
type MLflowPlugin struct {
|
||||||
logger *logging.Logger
|
logger *logging.Logger
|
||||||
podman *container.PodmanManager
|
podman *container.PodmanManager
|
||||||
opts MLflowOptions
|
|
||||||
|
|
||||||
mu sync.Mutex
|
|
||||||
sidecars map[string]*mlflowSidecar
|
sidecars map[string]*mlflowSidecar
|
||||||
|
opts MLflowOptions
|
||||||
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewMLflowPlugin creates a new MLflow plugin instance.
|
// NewMLflowPlugin creates a new MLflow plugin instance.
|
||||||
|
|
|
||||||
|
|
@ -14,9 +14,9 @@ import (
|
||||||
|
|
||||||
// TensorBoardOptions configure the TensorBoard plugin.
|
// TensorBoardOptions configure the TensorBoard plugin.
|
||||||
type TensorBoardOptions struct {
|
type TensorBoardOptions struct {
|
||||||
|
PortAllocator *trackingpkg.PortAllocator
|
||||||
Image string
|
Image string
|
||||||
LogBasePath string
|
LogBasePath string
|
||||||
PortAllocator *trackingpkg.PortAllocator
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type tensorboardSidecar struct {
|
type tensorboardSidecar struct {
|
||||||
|
|
@ -26,12 +26,11 @@ type tensorboardSidecar struct {
|
||||||
|
|
||||||
// TensorBoardPlugin exposes training logs through TensorBoard.
|
// TensorBoardPlugin exposes training logs through TensorBoard.
|
||||||
type TensorBoardPlugin struct {
|
type TensorBoardPlugin struct {
|
||||||
logger *logging.Logger
|
logger *logging.Logger
|
||||||
podman *container.PodmanManager
|
podman *container.PodmanManager
|
||||||
opts TensorBoardOptions
|
|
||||||
|
|
||||||
mu sync.Mutex
|
|
||||||
sidecars map[string]*tensorboardSidecar
|
sidecars map[string]*tensorboardSidecar
|
||||||
|
opts TensorBoardOptions
|
||||||
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewTensorBoardPlugin constructs a TensorBoard plugin instance.
|
// NewTensorBoardPlugin constructs a TensorBoard plugin instance.
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue