Core domain and utility updates:
- domain/task.go: Task model with visibility system
  * Visibility enum: private, lab, institution, open
  * Group associations for lab-scoped access
  * CreatedBy tracking for ownership
  * Sharing metadata with expiry
- config/paths.go: Group-scoped data directories and audit log paths
- crypto/signing.go: Key management for audit sealing, token signature verification
- container/supply_chain.go: Image provenance tracking, vulnerability scanning
- fileutil/filetype.go: MIME type detection and security validation
- fileutil/secure.go: Protected file permissions, secure deletion
- jupyter/: Package and service manager updates
- experiment/manager.go: Visibility cascade from experiments to tasks
- network/ssh.go: SSH tunneling improvements
- queue/: Filesystem queue enhancements
311 lines
8.6 KiB
Go
311 lines
8.6 KiB
Go
// Package filesystem provides a filesystem-based queue implementation
|
|
package filesystem
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"regexp"
|
|
"strings"
|
|
|
|
"github.com/jfraeys/fetch_ml/internal/domain"
|
|
"github.com/jfraeys/fetch_ml/internal/fileutil"
|
|
)
|
|
|
|
// validTaskID is the allowlist pattern every task ID must satisfy: between 1
// and 128 characters drawn from ASCII letters, digits, underscore, and hyphen.
// Path separators, dots, and NUL bytes are outside that set, so a matching ID
// cannot change the directory a task file is written to.
var validTaskID = regexp.MustCompile(`^[a-zA-Z0-9_\-]{1,128}$`)

// validateTaskID reports whether id is acceptable as a task identifier.
//
// It returns nil for a valid ID, a "required" error for the empty string, and
// otherwise an error that quotes the offending ID together with the allowlist
// pattern.
func validateTaskID(id string) error {
	switch {
	case id == "":
		return errors.New("task ID is required")
	case !validTaskID.MatchString(id):
		return fmt.Errorf("invalid task ID %q: must match %s", id, validTaskID.String())
	default:
		return nil
	}
}
|
|
|
|
// writeTaskFile writes task data with O_NOFOLLOW, fsync, and cleanup on error.
|
|
// This prevents symlink attacks and ensures crash safety.
|
|
func writeTaskFile(path string, data []byte) error {
|
|
// Use O_NOFOLLOW to prevent following symlinks (TOCTOU protection)
|
|
f, err := fileutil.OpenFileNoFollow(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0640)
|
|
if err != nil {
|
|
return fmt.Errorf("open (symlink rejected): %w", err)
|
|
}
|
|
|
|
if _, err := f.Write(data); err != nil {
|
|
_ = f.Close()
|
|
_ = os.Remove(path) // remove partial write
|
|
return fmt.Errorf("write: %w", err)
|
|
}
|
|
|
|
// CRITICAL: fsync ensures data is flushed to disk before returning
|
|
if err := f.Sync(); err != nil {
|
|
_ = f.Close()
|
|
_ = os.Remove(path) // remove unsynced file
|
|
return fmt.Errorf("fsync: %w", err)
|
|
}
|
|
|
|
// Close can fail on some filesystems (NFS, network-backed volumes)
|
|
if err := f.Close(); err != nil {
|
|
_ = os.Remove(path) // remove file if close failed
|
|
return fmt.Errorf("close: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Queue is a task queue whose tasks are stored as JSON files beneath a root
// directory, one sub-directory per task state.
type Queue struct {
	// ctx is created by NewQueue and cancelled through cancel.
	ctx context.Context
	// cancel cancels ctx; Close invokes it.
	cancel context.CancelFunc
	// root is the cleaned path of the queue's top-level directory.
	root string
}
|
|
|
|
type (
	// queueIndex is the JSON shape of a queue index document: a timestamp, a
	// list of task summaries, and a format version.
	// NOTE(review): nothing in this file reads or writes it yet; presumably
	// it is what rebuildIndex will produce — confirm.
	queueIndex struct {
		UpdatedAt string           `json:"updated_at"`
		Tasks     []queueIndexTask `json:"tasks"`
		Version   int              `json:"version"`
	}

	// queueIndexTask is one task summary inside a queueIndex.
	queueIndexTask struct {
		ID        string `json:"id"`
		CreatedAt string `json:"created_at"`
		Priority  int64  `json:"priority"`
	}
)
|
|
|
|
// NewQueue creates a new filesystem queue instance
|
|
func NewQueue(root string) (*Queue, error) {
|
|
root = strings.TrimSpace(root)
|
|
if root == "" {
|
|
return nil, fmt.Errorf("filesystem queue root is required")
|
|
}
|
|
root = filepath.Clean(root)
|
|
if err := os.MkdirAll(filepath.Join(root, "pending", "entries"), 0750); err != nil {
|
|
return nil, err
|
|
}
|
|
for _, d := range []string{"running", "finished", "failed"} {
|
|
if err := os.MkdirAll(filepath.Join(root, d), 0750); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
q := &Queue{root: root, ctx: ctx, cancel: cancel}
|
|
_ = q.rebuildIndex()
|
|
return q, nil
|
|
}
|
|
|
|
// Close closes the queue
|
|
func (q *Queue) Close() error {
|
|
q.cancel()
|
|
return nil
|
|
}
|
|
|
|
// AddTask adds a task to the queue
|
|
func (q *Queue) AddTask(task *domain.Task) error {
|
|
if task == nil {
|
|
return errors.New("task is nil")
|
|
}
|
|
if err := validateTaskID(task.ID); err != nil {
|
|
return err
|
|
}
|
|
|
|
pendingDir := filepath.Join(q.root, "pending", "entries")
|
|
taskFile := filepath.Join(pendingDir, task.ID+".json")
|
|
|
|
// SECURITY: Verify resolved path is still inside pendingDir (symlink/traversal check)
|
|
resolvedDir, err := filepath.EvalSymlinks(pendingDir)
|
|
if err != nil {
|
|
return fmt.Errorf("resolve pending dir: %w", err)
|
|
}
|
|
resolvedFile, err := filepath.EvalSymlinks(filepath.Dir(taskFile))
|
|
if err != nil {
|
|
return fmt.Errorf("resolve task dir: %w", err)
|
|
}
|
|
if !strings.HasPrefix(resolvedFile+string(filepath.Separator), resolvedDir+string(filepath.Separator)) {
|
|
return fmt.Errorf("task path %q escapes queue root", taskFile)
|
|
}
|
|
|
|
data, err := json.Marshal(task)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal task: %w", err)
|
|
}
|
|
|
|
// SECURITY: Write with fsync + cleanup on error + O_NOFOLLOW to prevent symlink attacks
|
|
if err := writeTaskFile(taskFile, data); err != nil {
|
|
return fmt.Errorf("failed to write task file: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetTask retrieves a task by ID
|
|
func (q *Queue) GetTask(id string) (*domain.Task, error) {
|
|
if err := validateTaskID(id); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Search in all directories
|
|
for _, dir := range []string{"pending", "running", "finished", "failed"} {
|
|
taskFile := filepath.Join(q.root, dir, "entries", id+".json")
|
|
// #nosec G304 -- path is constructed from validated root and validated task ID
|
|
data, err := os.ReadFile(taskFile)
|
|
if err == nil {
|
|
var task domain.Task
|
|
if err := json.Unmarshal(data, &task); err != nil {
|
|
return nil, fmt.Errorf("failed to unmarshal task: %w", err)
|
|
}
|
|
return &task, nil
|
|
}
|
|
}
|
|
|
|
return nil, fmt.Errorf("task not found: %s", id)
|
|
}
|
|
|
|
// ListTasks lists all tasks in the queue
|
|
func (q *Queue) ListTasks() ([]*domain.Task, error) {
|
|
var tasks []*domain.Task
|
|
|
|
for _, dir := range []string{"pending", "running", "finished", "failed"} {
|
|
entriesDir := filepath.Join(q.root, dir, "entries")
|
|
entries, err := os.ReadDir(entriesDir)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
for _, entry := range entries {
|
|
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".json") {
|
|
continue
|
|
}
|
|
|
|
// #nosec G304 -- path is constructed from validated root and directory entry
|
|
data, err := os.ReadFile(filepath.Join(entriesDir, entry.Name()))
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
var task domain.Task
|
|
if err := json.Unmarshal(data, &task); err != nil {
|
|
continue
|
|
}
|
|
tasks = append(tasks, &task)
|
|
}
|
|
}
|
|
|
|
return tasks, nil
|
|
}
|
|
|
|
// CancelTask cancels a task
|
|
func (q *Queue) CancelTask(id string) error {
|
|
if err := validateTaskID(id); err != nil {
|
|
return err
|
|
}
|
|
// Remove from pending if exists
|
|
pendingFile := filepath.Join(q.root, "pending", "entries", id+".json")
|
|
if _, err := os.Stat(pendingFile); err == nil {
|
|
return os.Remove(pendingFile)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// UpdateTask updates a task
|
|
func (q *Queue) UpdateTask(task *domain.Task) error {
|
|
if task == nil {
|
|
return errors.New("task is nil")
|
|
}
|
|
if err := validateTaskID(task.ID); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Find current location
|
|
var currentFile string
|
|
for _, dir := range []string{"pending", "running", "finished", "failed"} {
|
|
f := filepath.Join(q.root, dir, "entries", task.ID+".json")
|
|
if _, err := os.Stat(f); err == nil {
|
|
currentFile = f
|
|
break
|
|
}
|
|
}
|
|
|
|
if currentFile == "" {
|
|
return fmt.Errorf("task not found: %s", task.ID)
|
|
}
|
|
|
|
data, err := json.Marshal(task)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal task: %w", err)
|
|
}
|
|
|
|
// SECURITY: Write with O_NOFOLLOW + fsync + cleanup on error
|
|
if err := writeTaskFile(currentFile, data); err != nil {
|
|
return fmt.Errorf("failed to write task file: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// rebuildIndex rebuilds the queue index
|
|
func (q *Queue) rebuildIndex() error {
|
|
// Implementation would rebuild the index file
|
|
return nil
|
|
}
|
|
|
|
// applyPrivacyFilter shapes which fields are returned based on visibility level.
|
|
// Returns a shallow copy with sensitive fields redacted — never mutates stored data.
|
|
func applyPrivacyFilter(task *domain.Task, level domain.VisibilityLevel) *domain.Task {
|
|
out := *task // shallow copy
|
|
switch level {
|
|
case domain.VisibilityOpen:
|
|
// Strip args and PII; keep results for reproducibility
|
|
out.Args = "[redacted]"
|
|
out.Output = redactSensitiveInfo(out.Output)
|
|
out.Metadata = filterMetadata(out.Metadata)
|
|
case domain.VisibilityLab, domain.VisibilityInstitution:
|
|
// Full fields; access is audit-logged at the handler layer
|
|
case domain.VisibilityPrivate:
|
|
// Should never reach here (filtered by SQL), but be defensive
|
|
return &domain.Task{ID: task.ID, Status: "private"}
|
|
}
|
|
return &out
|
|
}
|
|
|
|
// redactSensitiveInfo limits how much task output is exposed.
//
// This is a basic implementation: it does not look for PII, it only truncates.
// Output of at most 100 bytes is returned unchanged. Longer output is cut at
// the last rune boundary at or before byte 100 — so a multi-byte UTF-8
// character is never split — and a redaction notice is appended.
func redactSensitiveInfo(output string) string {
	const maxBytes = 100
	if len(output) <= maxBytes {
		return output
	}

	// Ranging over a string yields the byte index at which each rune starts;
	// keep the largest such index that does not exceed the limit.
	cut := 0
	for i := range output {
		if i > maxBytes {
			break
		}
		cut = i
	}
	return output[:cut] + "\n[... additional output redacted for privacy ...]"
}
|
|
|
|
// filterMetadata returns a copy of metadata without entries whose key looks
// sensitive.
//
// A key is dropped when its lower-cased form contains any of "api_key",
// "token", "password", "secret", or "auth" as a substring; every other entry
// is copied unchanged. The input map is not modified. A nil input yields nil;
// otherwise the result is a non-nil map, possibly empty.
func filterMetadata(metadata map[string]string) map[string]string {
	if metadata == nil {
		return nil
	}

	// isSensitive reports whether key contains one of the blocked markers,
	// ignoring case.
	isSensitive := func(key string) bool {
		folded := strings.ToLower(key)
		for _, marker := range [...]string{"api_key", "token", "password", "secret", "auth"} {
			if strings.Contains(folded, marker) {
				return true
			}
		}
		return false
	}

	kept := map[string]string{}
	for key, value := range metadata {
		if isSensitive(key) {
			continue
		}
		kept[key] = value
	}
	return kept
}
|