// Package filesystem provides a filesystem-based queue implementation package filesystem import ( "context" "encoding/json" "errors" "fmt" "os" "path/filepath" "regexp" "strings" "github.com/jfraeys/fetch_ml/internal/domain" "github.com/jfraeys/fetch_ml/internal/fileutil" ) // validTaskID is an allowlist regex for task IDs. // Only alphanumeric, underscore, and hyphen allowed. Max 128 chars. // This prevents path traversal attacks (null bytes, slashes, backslashes, etc.) var validTaskID = regexp.MustCompile(`^[a-zA-Z0-9_\-]{1,128}$`) // validateTaskID checks if a task ID is valid according to the allowlist. func validateTaskID(id string) error { if id == "" { return errors.New("task ID is required") } if !validTaskID.MatchString(id) { return fmt.Errorf("invalid task ID %q: must match %s", id, validTaskID.String()) } return nil } // writeTaskFile writes task data with O_NOFOLLOW, fsync, and cleanup on error. // This prevents symlink attacks and ensures crash safety. func writeTaskFile(path string, data []byte) error { // Use O_NOFOLLOW to prevent following symlinks (TOCTOU protection) f, err := fileutil.OpenFileNoFollow(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0640) if err != nil { return fmt.Errorf("open (symlink rejected): %w", err) } if _, err := f.Write(data); err != nil { _ = f.Close() _ = os.Remove(path) // remove partial write return fmt.Errorf("write: %w", err) } // CRITICAL: fsync ensures data is flushed to disk before returning if err := f.Sync(); err != nil { _ = f.Close() _ = os.Remove(path) // remove unsynced file return fmt.Errorf("fsync: %w", err) } // Close can fail on some filesystems (NFS, network-backed volumes) if err := f.Close(); err != nil { _ = os.Remove(path) // remove file if close failed return fmt.Errorf("close: %w", err) } return nil } // Queue implements a filesystem-based task queue type Queue struct { ctx context.Context cancel context.CancelFunc root string } type queueIndex struct { UpdatedAt string `json:"updated_at"` Tasks 
[]queueIndexTask `json:"tasks"` Version int `json:"version"` } type queueIndexTask struct { ID string `json:"id"` CreatedAt string `json:"created_at"` Priority int64 `json:"priority"` } // NewQueue creates a new filesystem queue instance func NewQueue(root string) (*Queue, error) { root = strings.TrimSpace(root) if root == "" { return nil, fmt.Errorf("filesystem queue root is required") } root = filepath.Clean(root) if err := os.MkdirAll(filepath.Join(root, "pending", "entries"), 0750); err != nil { return nil, err } for _, d := range []string{"running", "finished", "failed"} { if err := os.MkdirAll(filepath.Join(root, d), 0750); err != nil { return nil, err } } ctx, cancel := context.WithCancel(context.Background()) q := &Queue{root: root, ctx: ctx, cancel: cancel} _ = q.rebuildIndex() return q, nil } // Close closes the queue func (q *Queue) Close() error { q.cancel() return nil } // AddTask adds a task to the queue func (q *Queue) AddTask(task *domain.Task) error { if task == nil { return errors.New("task is nil") } if err := validateTaskID(task.ID); err != nil { return err } pendingDir := filepath.Join(q.root, "pending", "entries") taskFile := filepath.Join(pendingDir, task.ID+".json") // SECURITY: Verify resolved path is still inside pendingDir (symlink/traversal check) resolvedDir, err := filepath.EvalSymlinks(pendingDir) if err != nil { return fmt.Errorf("resolve pending dir: %w", err) } resolvedFile, err := filepath.EvalSymlinks(filepath.Dir(taskFile)) if err != nil { return fmt.Errorf("resolve task dir: %w", err) } if !strings.HasPrefix(resolvedFile+string(filepath.Separator), resolvedDir+string(filepath.Separator)) { return fmt.Errorf("task path %q escapes queue root", taskFile) } data, err := json.Marshal(task) if err != nil { return fmt.Errorf("failed to marshal task: %w", err) } // SECURITY: Write with fsync + cleanup on error + O_NOFOLLOW to prevent symlink attacks if err := writeTaskFile(taskFile, data); err != nil { return fmt.Errorf("failed to 
write task file: %w", err) } return nil } // GetTask retrieves a task by ID func (q *Queue) GetTask(id string) (*domain.Task, error) { if err := validateTaskID(id); err != nil { return nil, err } // Search in all directories for _, dir := range []string{"pending", "running", "finished", "failed"} { taskFile := filepath.Join(q.root, dir, "entries", id+".json") // #nosec G304 -- path is constructed from validated root and validated task ID data, err := os.ReadFile(taskFile) if err == nil { var task domain.Task if err := json.Unmarshal(data, &task); err != nil { return nil, fmt.Errorf("failed to unmarshal task: %w", err) } return &task, nil } } return nil, fmt.Errorf("task not found: %s", id) } // ListTasks lists all tasks in the queue func (q *Queue) ListTasks() ([]*domain.Task, error) { var tasks []*domain.Task for _, dir := range []string{"pending", "running", "finished", "failed"} { entriesDir := filepath.Join(q.root, dir, "entries") entries, err := os.ReadDir(entriesDir) if err != nil { continue } for _, entry := range entries { if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".json") { continue } // #nosec G304 -- path is constructed from validated root and directory entry data, err := os.ReadFile(filepath.Join(entriesDir, entry.Name())) if err != nil { continue } var task domain.Task if err := json.Unmarshal(data, &task); err != nil { continue } tasks = append(tasks, &task) } } return tasks, nil } // CancelTask cancels a task func (q *Queue) CancelTask(id string) error { if err := validateTaskID(id); err != nil { return err } // Remove from pending if exists pendingFile := filepath.Join(q.root, "pending", "entries", id+".json") if _, err := os.Stat(pendingFile); err == nil { return os.Remove(pendingFile) } return nil } // UpdateTask updates a task func (q *Queue) UpdateTask(task *domain.Task) error { if task == nil { return errors.New("task is nil") } if err := validateTaskID(task.ID); err != nil { return err } // Find current location var currentFile 
string for _, dir := range []string{"pending", "running", "finished", "failed"} { f := filepath.Join(q.root, dir, "entries", task.ID+".json") if _, err := os.Stat(f); err == nil { currentFile = f break } } if currentFile == "" { return fmt.Errorf("task not found: %s", task.ID) } data, err := json.Marshal(task) if err != nil { return fmt.Errorf("failed to marshal task: %w", err) } // SECURITY: Write with O_NOFOLLOW + fsync + cleanup on error if err := writeTaskFile(currentFile, data); err != nil { return fmt.Errorf("failed to write task file: %w", err) } return nil } // rebuildIndex rebuilds the queue index func (q *Queue) rebuildIndex() error { // Implementation would rebuild the index file return nil } // applyPrivacyFilter shapes which fields are returned based on visibility level. // Returns a shallow copy with sensitive fields redacted — never mutates stored data. func applyPrivacyFilter(task *domain.Task, level domain.VisibilityLevel) *domain.Task { out := *task // shallow copy switch level { case domain.VisibilityOpen: // Strip args and PII; keep results for reproducibility out.Args = "[redacted]" out.Output = redactSensitiveInfo(out.Output) out.Metadata = filterMetadata(out.Metadata) case domain.VisibilityLab, domain.VisibilityInstitution: // Full fields; access is audit-logged at the handler layer case domain.VisibilityPrivate: // Should never reach here (filtered by SQL), but be defensive return &domain.Task{ID: task.ID, Status: "private"} } return &out } // redactSensitiveInfo removes PII from output strings. // This is a basic implementation — extend as needed for your data. func redactSensitiveInfo(output string) string { // Simple redaction: truncate long outputs and mark as redacted if len(output) > 100 { return output[:100] + "\n[... additional output redacted for privacy ...]" } return output } // filterMetadata removes sensitive fields from metadata. // Preserves non-sensitive fields like git commit, environment, etc. 
func filterMetadata(metadata map[string]string) map[string]string {
	// nil in, nil out. Any other input, even an empty map, produces a new
	// non-nil map.
	if metadata == nil {
		return nil
	}

	// A key is sensitive when its lower-cased form contains any marker.
	// NOTE(review): substring matching is broad, e.g. "author" matches "auth".
	isSensitive := func(key string) bool {
		lowered := strings.ToLower(key)
		for _, marker := range [...]string{"api_key", "token", "password", "secret", "auth"} {
			if strings.Contains(lowered, marker) {
				return true
			}
		}
		return false
	}

	kept := map[string]string{}
	for key, value := range metadata {
		if isSensitive(key) {
			continue
		}
		kept[key] = value
	}
	return kept
}