fetch_ml/internal/queue/filesystem/queue.go
Jeremie Fraeys 6866ba9366
refactor(queue): integrate scheduler backend and storage improvements
Update queue and storage systems for scheduler integration:
- Queue backend with scheduler coordination
- Filesystem queue with batch operations
- Deduplication with tenant-aware keys
- Storage layer with audit logging hooks
- Domain models (Task, Events, Errors) with scheduler fields
- Database layer with tenant isolation
- Dataset storage with integrity checks
2026-02-26 12:06:46 -05:00

255 lines
6.7 KiB
Go

// Package filesystem provides a filesystem-based queue implementation.
package filesystem
import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"regexp"
"strings"
"github.com/jfraeys/fetch_ml/internal/domain"
"github.com/jfraeys/fetch_ml/internal/fileutil"
)
// validTaskID is the allowlist pattern for task IDs: 1 to 128 characters drawn
// from ASCII letters, digits, underscore and hyphen. Path separators, dots and
// NUL bytes are outside that set, so an accepted ID cannot alter the directory
// a task file name is joined onto.
var validTaskID = regexp.MustCompile(`^[a-zA-Z0-9_\-]{1,128}$`)

// validateTaskID reports whether id may be used as a task ID.
//
// It returns nil when id matches validTaskID, a "task ID is required" error for
// the empty string, and otherwise an error that quotes the rejected ID together
// with the pattern it failed to match.
func validateTaskID(id string) error {
	switch {
	case id == "":
		// Reported separately so callers get a clearer message than a
		// pattern mismatch for a missing ID.
		return errors.New("task ID is required")
	case !validTaskID.MatchString(id):
		return fmt.Errorf("invalid task ID %q: must match %s", id, validTaskID.String())
	default:
		return nil
	}
}
// writeTaskFile writes data to path and makes sure it reached stable storage.
//
// The file is opened through fileutil.OpenFileNoFollow with create, write-only
// and truncate flags and mode 0640 (presumably refusing to follow a symlink at
// path, per the helper's name; confirm in fileutil). The data is then written,
// fsynced and the file closed. If the write, the fsync or the close fails, the
// file at path is removed and the failure is returned wrapped with the name of
// the step that failed.
//
// NOTE(review): because the open truncates and every failure path removes
// path, a failed overwrite of an existing file leaves no file behind at all.
func writeTaskFile(path string, data []byte) error {
	f, err := fileutil.OpenFileNoFollow(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0640)
	if err != nil {
		return fmt.Errorf("open (symlink rejected): %w", err)
	}

	// Write and fsync share one cleanup path: close the handle, then drop the
	// partial or unsynced file.
	var failure error
	if _, writeErr := f.Write(data); writeErr != nil {
		failure = fmt.Errorf("write: %w", writeErr)
	} else if syncErr := f.Sync(); syncErr != nil {
		failure = fmt.Errorf("fsync: %w", syncErr)
	}
	if failure != nil {
		f.Close()
		os.Remove(path)
		return failure
	}

	// Close can itself report a write-back failure on some filesystems
	// (NFS, network-backed volumes), so it is checked as well.
	if closeErr := f.Close(); closeErr != nil {
		os.Remove(path)
		return fmt.Errorf("close: %w", closeErr)
	}
	return nil
}
// Queue implements a filesystem-based task queue.
// Tasks are stored as JSON files in per-state directories below root.
type Queue struct {
// ctx is the context created in NewQueue; it is canceled by Close.
ctx context.Context
// cancel cancels ctx and is invoked by Close.
cancel context.CancelFunc
// root is the cleaned base directory that holds the queue's state directories.
root string
}
// queueIndex is the JSON shape of a queue index document: a timestamp, a list
// of task entries and a version number.
// NOTE(review): nothing in the code visible here reads or writes this type;
// presumably it backs the index file that rebuildIndex refers to — confirm.
type queueIndex struct {
// UpdatedAt is serialized as "updated_at"; its string format is not fixed here.
UpdatedAt string `json:"updated_at"`
// Tasks is serialized as "tasks".
Tasks []queueIndexTask `json:"tasks"`
// Version is serialized as "version".
Version int `json:"version"`
}
// queueIndexTask is one element of queueIndex.Tasks: a task's ID, creation
// time and priority as they appear in the JSON index.
type queueIndexTask struct {
// ID is serialized as "id".
ID string `json:"id"`
// CreatedAt is serialized as "created_at"; its string format is not fixed here.
CreatedAt string `json:"created_at"`
// Priority is serialized as "priority".
Priority int64 `json:"priority"`
}
// NewQueue creates a filesystem queue rooted at root.
//
// root is trimmed of surrounding whitespace and cleaned; an empty value is
// rejected. The directories pending/entries, running, finished and failed are
// created below root with mode 0750 when missing, and the first creation error
// is returned unchanged. The returned Queue carries a cancelable context that
// Close cancels.
//
// NOTE(review): only "pending" receives an "entries" subdirectory here, while
// GetTask, ListTasks and UpdateTask look in <state>/entries for every state —
// confirm which layout is intended.
func NewQueue(root string) (*Queue, error) {
	root = strings.TrimSpace(root)
	if root == "" {
		return nil, fmt.Errorf("filesystem queue root is required")
	}
	root = filepath.Clean(root)

	// Created in this order; the first failure aborts construction.
	layout := []string{
		filepath.Join(root, "pending", "entries"),
		filepath.Join(root, "running"),
		filepath.Join(root, "finished"),
		filepath.Join(root, "failed"),
	}
	for _, dir := range layout {
		if err := os.MkdirAll(dir, 0750); err != nil {
			return nil, err
		}
	}

	ctx, cancel := context.WithCancel(context.Background())
	queue := &Queue{
		ctx:    ctx,
		cancel: cancel,
		root:   root,
	}
	// The index rebuild is best-effort: its error is deliberately discarded
	// so that a usable queue is still returned.
	_ = queue.rebuildIndex()
	return queue, nil
}
// Close shuts the queue down by canceling the context created in NewQueue.
//
// It always returns nil. Calling Close on a nil *Queue, or on a Queue that was
// not built by NewQueue and therefore has no cancel function, is a no-op
// instead of a nil-function panic. Repeated calls are harmless because a
// context.CancelFunc does nothing after its first invocation.
func (q *Queue) Close() error {
	if q == nil || q.cancel == nil {
		return nil
	}
	q.cancel()
	return nil
}
// AddTask persists task as <root>/pending/entries/<task.ID>.json.
//
// task must be non-nil and task.ID must pass validateTaskID. Before writing,
// both the queue root and the pending directory are resolved through any
// symlinks, and the call fails if the resolved pending directory does not lie
// inside the resolved root. The task is then JSON-encoded and written with
// writeTaskFile; a file that already exists for the same ID is overwritten.
//
// Errors: "task is nil", the validateTaskID error, a wrapped resolution error,
// "task path ... escapes queue root", or a wrapped marshal/write failure.
func (q *Queue) AddTask(task *domain.Task) error {
	if task == nil {
		return errors.New("task is nil")
	}
	if err := validateTaskID(task.ID); err != nil {
		return err
	}

	pendingDir := filepath.Join(q.root, "pending", "entries")
	taskFile := filepath.Join(pendingDir, task.ID+".json")

	// SECURITY: the containment check is made against the queue root. A valid
	// task ID contains no path separator, so the task file's parent is always
	// pendingDir itself; comparing the resolved parent with the resolved
	// pendingDir could never fail. What can actually redirect the write is a
	// symlinked "pending" or "entries" directory, which this comparison catches.
	resolvedRoot, err := filepath.EvalSymlinks(q.root)
	if err != nil {
		return fmt.Errorf("resolve queue root: %w", err)
	}
	resolvedDir, err := filepath.EvalSymlinks(pendingDir)
	if err != nil {
		return fmt.Errorf("resolve pending dir: %w", err)
	}
	// filepath.Rel (rather than a string-prefix test) also handles a root of
	// "/" correctly. A Rel error (e.g. one path absolute, one relative) is
	// treated as an escape so the check fails closed.
	rel, err := filepath.Rel(resolvedRoot, resolvedDir)
	if err != nil || rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
		return fmt.Errorf("task path %q escapes queue root", taskFile)
	}

	data, err := json.Marshal(task)
	if err != nil {
		return fmt.Errorf("failed to marshal task: %w", err)
	}

	// SECURITY: writeTaskFile opens without following a symlink at the final
	// path component, fsyncs, and removes the file on failure.
	if err := writeTaskFile(taskFile, data); err != nil {
		return fmt.Errorf("failed to write task file: %w", err)
	}
	return nil
}
// GetTask loads the task with the given ID.
//
// id must pass validateTaskID. The state directories are searched in the order
// pending, running, finished, failed for <state>/entries/<id>.json, and the
// first file that can be read is decoded and returned.
//
// Errors:
//   - the validateTaskID error for an invalid id;
//   - "failed to unmarshal task" (wrapped) if the file found is not valid JSON
//     for a task;
//   - "read task <id>" (wrapped) if no file could be read and at least one
//     attempt failed for a reason other than the file not existing, so that a
//     permission or I/O problem is not reported as a missing task;
//   - "task not found: <id>" if the file exists in none of the states.
func (q *Queue) GetTask(id string) (*domain.Task, error) {
	if err := validateTaskID(id); err != nil {
		return nil, err
	}

	// readErr keeps the first unexpected read failure. The search still
	// continues through the remaining states, so a readable copy elsewhere is
	// returned exactly as before.
	var readErr error
	for _, state := range []string{"pending", "running", "finished", "failed"} {
		taskFile := filepath.Join(q.root, state, "entries", id+".json")
		data, err := os.ReadFile(taskFile)
		if err != nil {
			if readErr == nil && !errors.Is(err, os.ErrNotExist) {
				readErr = err
			}
			continue
		}

		var task domain.Task
		if err := json.Unmarshal(data, &task); err != nil {
			return nil, fmt.Errorf("failed to unmarshal task: %w", err)
		}
		return &task, nil
	}

	if readErr != nil {
		return nil, fmt.Errorf("read task %s: %w", id, readErr)
	}
	return nil, fmt.Errorf("task not found: %s", id)
}
// ListTasks returns every task that can be decoded from the queue.
//
// For each state (pending, running, finished, failed, in that order) it reads
// <state>/entries and decodes each regular "*.json" entry. The listing is
// best-effort: a state directory that cannot be read, a file that cannot be
// read, and a file that is not valid task JSON are all skipped silently. The
// returned error is always nil, and the slice is nil when nothing was found.
func (q *Queue) ListTasks() ([]*domain.Task, error) {
	// load decodes a single task file, returning nil when the file is
	// unreadable or malformed.
	load := func(path string) *domain.Task {
		raw, err := os.ReadFile(path)
		if err != nil {
			return nil
		}
		task := new(domain.Task)
		if json.Unmarshal(raw, task) != nil {
			return nil
		}
		return task
	}

	var tasks []*domain.Task
	for _, state := range []string{"pending", "running", "finished", "failed"} {
		dir := filepath.Join(q.root, state, "entries")
		listing, err := os.ReadDir(dir)
		if err != nil {
			continue
		}
		for _, item := range listing {
			name := item.Name()
			if !item.IsDir() && strings.HasSuffix(name, ".json") {
				if task := load(filepath.Join(dir, name)); task != nil {
					tasks = append(tasks, task)
				}
			}
		}
	}
	return tasks, nil
}
// CancelTask cancels a pending task by deleting
// <root>/pending/entries/<id>.json.
//
// id must pass validateTaskID. Only the pending directory is touched; files
// for the same ID in other states are left alone. A missing pending file is
// not an error, so cancelling an unknown or already-cancelled task returns
// nil. Any other removal failure is returned.
func (q *Queue) CancelTask(id string) error {
	if err := validateTaskID(id); err != nil {
		return err
	}

	pendingFile := filepath.Join(q.root, "pending", "entries", id+".json")
	// Remove directly instead of Stat-then-Remove: a file that disappears
	// between the two calls would otherwise surface as a spurious error, and
	// a Stat failure unrelated to existence would be reported as success.
	if err := os.Remove(pendingFile); err != nil && !errors.Is(err, os.ErrNotExist) {
		return err
	}
	return nil
}
// UpdateTask rewrites the stored file of an existing task in place.
//
// task must be non-nil and task.ID must pass validateTaskID. The state
// directories are probed in the order pending, running, finished, failed for
// <state>/entries/<task.ID>.json; the first file found is overwritten with the
// JSON encoding of task via writeTaskFile. The task is not moved between
// states.
//
// Errors:
//   - "task is nil" or the validateTaskID error;
//   - "locate task <id>" (wrapped) if the file was found nowhere and at least
//     one probe failed for a reason other than the file not existing, so that
//     a permission or I/O problem is not reported as a missing task;
//   - "task not found: <id>" if the file exists in none of the states;
//   - a wrapped marshal or write failure.
//
// NOTE(review): locating and writing are separate steps. If the file is moved
// in between, writeTaskFile (which opens with O_CREATE) recreates it at the
// old location.
func (q *Queue) UpdateTask(task *domain.Task) error {
	if task == nil {
		return errors.New("task is nil")
	}
	if err := validateTaskID(task.ID); err != nil {
		return err
	}

	var (
		currentFile string
		statErr     error // first probe failure that is not "does not exist"
	)
	for _, state := range []string{"pending", "running", "finished", "failed"} {
		candidate := filepath.Join(q.root, state, "entries", task.ID+".json")
		_, err := os.Stat(candidate)
		if err == nil {
			currentFile = candidate
			break
		}
		if statErr == nil && !errors.Is(err, os.ErrNotExist) {
			statErr = err
		}
	}
	if currentFile == "" {
		if statErr != nil {
			return fmt.Errorf("locate task %s: %w", task.ID, statErr)
		}
		return fmt.Errorf("task not found: %s", task.ID)
	}

	data, err := json.Marshal(task)
	if err != nil {
		return fmt.Errorf("failed to marshal task: %w", err)
	}

	// SECURITY: writeTaskFile opens without following a symlink at the final
	// path component, fsyncs, and removes the file on failure.
	if err := writeTaskFile(currentFile, data); err != nil {
		return fmt.Errorf("failed to write task file: %w", err)
	}
	return nil
}
// rebuildIndex is the hook for regenerating the queue's index file.
// It is currently a placeholder: it performs no work and always returns nil.
func (q *Queue) rebuildIndex() error {
	// Not implemented yet; a real implementation would rebuild the index file.
	return nil
}