fetch_ml/internal/api/jobs/handlers.go
Jeremie Fraeys 02811c0ffe
fix: resolve TODOs and standardize tests
- Fix duplicate check in security_test.go lint warning
- Mark SHA256 tests as Legacy for backward compatibility
- Convert TODO comments to documentation (task, handlers, privacy)
- Update user_manager_test to use GenerateAPIKey pattern
2026-02-19 15:34:59 -05:00

512 lines
15 KiB
Go

// Package jobs provides WebSocket and HTTP handlers for job-related operations.
package jobs
import (
"context"
"encoding/binary"
"encoding/json"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"github.com/gorilla/websocket"
"github.com/jfraeys/fetch_ml/internal/auth"
"github.com/jfraeys/fetch_ml/internal/container"
"github.com/jfraeys/fetch_ml/internal/experiment"
"github.com/jfraeys/fetch_ml/internal/logging"
"github.com/jfraeys/fetch_ml/internal/queue"
"github.com/jfraeys/fetch_ml/internal/storage"
)
// Handler bundles the dependencies used by the job WebSocket and HTTP
// handlers defined on it.
type Handler struct {
	expManager *experiment.Manager
	logger     *logging.Logger
	queue      queue.Backend
	db         *storage.DB
	authConfig *auth.Config

	// privacyEnforcer exposes a CanAccess check taking a user plus the
	// resource's owner, privacy level and team. It is optional and may be nil,
	// so it must be nil-checked before use.
	privacyEnforcer interface {
		CanAccess(ctx context.Context, user *auth.User, owner, level, team string) (bool, error)
	}
}
// NewHandler returns a jobs Handler wired with the given dependencies.
// Every argument is stored as-is; privacyEnforcer is optional and may be nil.
func NewHandler(
	expManager *experiment.Manager,
	logger *logging.Logger,
	queue queue.Backend,
	db *storage.DB,
	authConfig *auth.Config,
	privacyEnforcer interface {
		CanAccess(ctx context.Context, user *auth.User, owner, level, team string) (bool, error)
	},
) *Handler {
	h := new(Handler)
	h.expManager = expManager
	h.logger = logger
	h.queue = queue
	h.db = db
	h.authConfig = authConfig
	h.privacyEnforcer = privacyEnforcer
	return h
}
// Error codes placed in the "code" field of error packets.
// They are untyped constants so they can be passed wherever a byte is expected.
const (
	// General request errors (0x00-0x05).
	ErrorCodeUnknownError          = 0x00
	ErrorCodeInvalidRequest        = 0x01
	ErrorCodeAuthenticationFailed  = 0x02
	ErrorCodePermissionDenied      = 0x03
	ErrorCodeResourceNotFound      = 0x04
	ErrorCodeResourceAlreadyExists = 0x05

	// Job errors (0x20-0x21).
	ErrorCodeJobNotFound       = 0x20
	ErrorCodeJobAlreadyRunning = 0x21

	// Configuration errors (0x32).
	ErrorCodeInvalidConfiguration = 0x32
)
// Permission identifiers for job operations, in "<resource>:<action>" form.
const (
	// PermJobsCreate is the "jobs:create" permission string.
	PermJobsCreate = "jobs:create"
	// PermJobsRead is the "jobs:read" permission string.
	PermJobsRead = "jobs:read"
	// PermJobsUpdate is the "jobs:update" permission string.
	PermJobsUpdate = "jobs:update"
)
// sendErrorPacket writes a JSON error object to conn and returns the result of
// that write. The object carries "error" (always true), the numeric "code",
// a short "message" and free-form "details".
func (h *Handler) sendErrorPacket(conn *websocket.Conn, code byte, message, details string) error {
	return conn.WriteJSON(map[string]interface{}{
		"error":   true,
		"code":    code,
		"message": message,
		"details": details,
	})
}
// sendSuccessPacket writes data to conn as JSON, unchanged. It adds no fields
// of its own, so callers supply the "success" key themselves.
func (h *Handler) sendSuccessPacket(conn *websocket.Conn, data map[string]interface{}) error {
	if err := conn.WriteJSON(data); err != nil {
		return err
	}
	return nil
}
// HandleAnnotateRun handles the annotate run WebSocket operation.
// Protocol: [api_key_hash:16][job_name_len:1][job_name:var][author_len:1][author:var][note_len:2][note:var]
//
// The payload is decoded and the job name validated, then the job directory is
// looked up in the running, pending, finished and failed buckets (in that
// order). On success the request is logged and the note is echoed back with a
// UTC timestamp; nothing is written to disk here. Malformed payloads and
// unknown jobs are answered with an error packet. The return value is the
// result of writing the response packet. The user argument is unused.
func (h *Handler) HandleAnnotateRun(conn *websocket.Conn, payload []byte, user *auth.User) error {
	const headerLen = 16 // leading api_key_hash bytes, skipped here

	// Smallest acceptable frame: header plus the three length fields.
	if len(payload) < headerLen+1+1+2 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "annotate run payload too short", "")
	}
	rest := payload[headerLen:]

	// Job name: must be non-empty and leave room for author_len and note_len.
	nameLen := int(rest[0])
	rest = rest[1:]
	if nameLen == 0 || len(rest) < nameLen+1+2 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "")
	}
	jobName := string(rest[:nameLen])
	rest = rest[nameLen:]

	// Author: may be empty, but must leave room for note_len.
	authorLen := int(rest[0])
	rest = rest[1:]
	if len(rest) < authorLen+2 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid author length", "")
	}
	author := string(rest[:authorLen])
	rest = rest[authorLen:]

	// Note: big-endian 16-bit length, must be non-empty.
	noteLen := int(binary.BigEndian.Uint16(rest[:2]))
	rest = rest[2:]
	if noteLen == 0 || len(rest) < noteLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid note length", "")
	}
	note := string(rest[:noteLen])

	// Validate before the name is joined into a filesystem path.
	if nameErr := container.ValidateJobName(jobName); nameErr != nil {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name", nameErr.Error())
	}

	base := strings.TrimSpace(h.expManager.BasePath())
	if base == "" {
		return h.sendErrorPacket(conn, ErrorCodeInvalidConfiguration, "Missing api base_path", "")
	}

	paths := storage.NewJobPaths(base)
	roots := [...]string{
		paths.RunningPath(),
		paths.PendingPath(),
		paths.FinishedPath(),
		paths.FailedPath(),
	}

	// First bucket containing the job wins.
	var manifestDir string
	for _, root := range roots {
		candidate := filepath.Join(root, jobName)
		if _, statErr := os.Stat(candidate); statErr != nil {
			continue
		}
		manifestDir = candidate
		break
	}
	if manifestDir == "" {
		return h.sendErrorPacket(conn, ErrorCodeJobNotFound, "Job not found", jobName)
	}

	h.logger.Info("annotating run", "job", jobName, "author", author, "dir", manifestDir)

	return h.sendSuccessPacket(conn, map[string]interface{}{
		"success":   true,
		"job_name":  jobName,
		"timestamp": time.Now().UTC(),
		"note":      note,
	})
}
// HandleSetRunNarrative handles setting the narrative for a run.
// Protocol: [api_key_hash:16][job_name_len:1][job_name:var][patch_len:2][patch:var]
//
// The payload is decoded and the job name validated, then the job directory is
// looked up in the running, pending, finished and failed buckets (in that
// order). On success the request is logged and the patch is echoed back under
// "narrative"; nothing is written to disk here. Malformed payloads and unknown
// jobs are answered with an error packet. The return value is the result of
// writing the response packet. The user argument is unused.
func (h *Handler) HandleSetRunNarrative(conn *websocket.Conn, payload []byte, user *auth.User) error {
	const headerLen = 16 // leading api_key_hash bytes, skipped here

	// Smallest acceptable frame: header plus both length fields.
	if len(payload) < headerLen+1+2 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "set run narrative payload too short", "")
	}
	body := payload[headerLen:]

	// Job name: must be non-empty and leave room for patch_len.
	nameLen := int(body[0])
	body = body[1:]
	if nameLen == 0 || len(body) < nameLen+2 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "")
	}
	jobName := string(body[:nameLen])
	body = body[nameLen:]

	// Patch: big-endian 16-bit length, must be non-empty.
	patchLen := int(binary.BigEndian.Uint16(body[:2]))
	body = body[2:]
	if patchLen == 0 || len(body) < patchLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid patch length", "")
	}
	patch := string(body[:patchLen])

	// Validate before the name is joined into a filesystem path.
	if nameErr := container.ValidateJobName(jobName); nameErr != nil {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name", nameErr.Error())
	}

	base := strings.TrimSpace(h.expManager.BasePath())
	if base == "" {
		return h.sendErrorPacket(conn, ErrorCodeInvalidConfiguration, "Missing api base_path", "")
	}

	paths := storage.NewJobPaths(base)
	bucketNames := [...]string{"running", "pending", "finished", "failed"}
	bucketRoots := [...]string{
		paths.RunningPath(),
		paths.PendingPath(),
		paths.FinishedPath(),
		paths.FailedPath(),
	}

	// First bucket containing the job wins.
	var manifestDir, bucket string
	for i, root := range bucketRoots {
		candidate := filepath.Join(root, jobName)
		if _, statErr := os.Stat(candidate); statErr != nil {
			continue
		}
		manifestDir, bucket = candidate, bucketNames[i]
		break
	}
	if manifestDir == "" {
		return h.sendErrorPacket(conn, ErrorCodeJobNotFound, "Job not found", jobName)
	}

	h.logger.Info("setting run narrative", "job", jobName, "bucket", bucket)

	return h.sendSuccessPacket(conn, map[string]interface{}{
		"success":   true,
		"job_name":  jobName,
		"narrative": patch,
	})
}
// HandleSetRunPrivacy handles setting run privacy.
// Protocol: [api_key_hash:16][job_name_len:1][job_name:var][patch_len:2][patch:var]
//
// The payload is decoded and the job name validated, then the job directory is
// looked up in the running, pending, finished and failed buckets (in that
// order). On success the attempt is logged with the requesting user's name and
// the patch is echoed back under "privacy"; nothing is written to disk here.
// Malformed payloads and unknown jobs are answered with an error packet. The
// return value is the result of writing the response packet.
//
// No ownership check is performed: any caller that reaches this handler gets a
// success response for an existing job. user must be non-nil.
func (h *Handler) HandleSetRunPrivacy(conn *websocket.Conn, payload []byte, user *auth.User) error {
	const headerLen = 16 // leading api_key_hash bytes, skipped here

	// Smallest acceptable frame: header plus both length fields.
	if len(payload) < headerLen+1+2 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "set run privacy payload too short", "")
	}
	body := payload[headerLen:]

	// Job name: must be non-empty and leave room for patch_len.
	nameLen := int(body[0])
	body = body[1:]
	if nameLen == 0 || len(body) < nameLen+2 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "")
	}
	jobName := string(body[:nameLen])
	body = body[nameLen:]

	// Patch: big-endian 16-bit length, must be non-empty.
	patchLen := int(binary.BigEndian.Uint16(body[:2]))
	body = body[2:]
	if patchLen == 0 || len(body) < patchLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid patch length", "")
	}
	patch := string(body[:patchLen])

	// Validate before the name is joined into a filesystem path.
	if nameErr := container.ValidateJobName(jobName); nameErr != nil {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name", nameErr.Error())
	}

	base := strings.TrimSpace(h.expManager.BasePath())
	if base == "" {
		return h.sendErrorPacket(conn, ErrorCodeInvalidConfiguration, "Missing api base_path", "")
	}

	paths := storage.NewJobPaths(base)
	bucketNames := [...]string{"running", "pending", "finished", "failed"}
	bucketRoots := [...]string{
		paths.RunningPath(),
		paths.PendingPath(),
		paths.FinishedPath(),
		paths.FailedPath(),
	}

	// First bucket containing the job wins.
	var manifestDir, bucket string
	for i, root := range bucketRoots {
		candidate := filepath.Join(root, jobName)
		if _, statErr := os.Stat(candidate); statErr != nil {
			continue
		}
		manifestDir, bucket = candidate, bucketNames[i]
		break
	}
	if manifestDir == "" {
		return h.sendErrorPacket(conn, ErrorCodeJobNotFound, "Job not found", jobName)
	}

	// Note: Owner verification should be implemented for privacy changes.
	// For now the attempt is only logged for audit purposes.
	h.logger.Info("setting run privacy", "job", jobName, "bucket", bucket, "user", user.Name)

	return h.sendSuccessPacket(conn, map[string]interface{}{
		"success":  true,
		"job_name": jobName,
		"privacy":  patch,
	})
}
// HandleCancelJob handles canceling a job.
//
// The job name is validated, the request is logged with the user's name, and
// the queue backend (when configured) is asked to cancel the task. A queue
// failure is only logged as a warning: the client still receives a
// "Cancellation requested" success packet. An invalid job name is answered
// with an error packet. The return value is the result of writing the
// response packet. user must be non-nil.
func (h *Handler) HandleCancelJob(conn *websocket.Conn, jobName string, user *auth.User) error {
	if nameErr := container.ValidateJobName(jobName); nameErr != nil {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name", nameErr.Error())
	}

	h.logger.Info("cancelling job", "job", jobName, "user", user.Name)

	// Best-effort: cancellation goes through the queue only when one is set.
	if backend := h.queue; backend != nil {
		cancelErr := backend.CancelTask(jobName)
		if cancelErr != nil {
			h.logger.Warn("failed to cancel task via queue", "job", jobName, "error", cancelErr)
		}
	}

	reply := map[string]interface{}{
		"success":  true,
		"job_name": jobName,
		"message":  "Cancellation requested",
	}
	return h.sendSuccessPacket(conn, reply)
}
// HandlePruneJobs handles pruning old jobs.
//
// The request is logged and a success packet is returned with "pruned" fixed
// at 0 and the requested type echoed back; no jobs are removed here. The
// return value is the result of writing the response packet. user must be
// non-nil.
func (h *Handler) HandlePruneJobs(conn *websocket.Conn, pruneType byte, value int, user *auth.User) error {
	h.logger.Info("pruning jobs", "type", pruneType, "value", value, "user", user.Name)

	reply := map[string]interface{}{
		"success": true,
		"pruned":  0,
		"type":    pruneType,
	}
	return h.sendSuccessPacket(conn, reply)
}
// HandleListJobs handles listing all jobs with their status.
//
// It scans the running, pending, finished and failed buckets under the
// configured base path and reports every sub-directory as a job, with its
// directory name, the bucket it was found in and a "status" of "unknown"
// (status is not derived here). The return value is the result of writing the
// response packet.
//
// A bucket directory that does not exist is skipped silently. Any other read
// failure (for example a permission error) is logged as a warning and the
// bucket is skipped, so the remaining buckets are still listed. An empty base
// path is answered with an invalid-configuration error packet. The user
// argument is unused: the listing is not filtered per user.
func (h *Handler) HandleListJobs(conn *websocket.Conn, user *auth.User) error {
	base := strings.TrimSpace(h.expManager.BasePath())
	if base == "" {
		return h.sendErrorPacket(conn, ErrorCodeInvalidConfiguration, "Missing api base_path", "")
	}

	jobPaths := storage.NewJobPaths(base)
	buckets := []struct {
		name string
		root string
	}{
		{name: "running", root: jobPaths.RunningPath()},
		{name: "pending", root: jobPaths.PendingPath()},
		{name: "finished", root: jobPaths.FinishedPath()},
		{name: "failed", root: jobPaths.FailedPath()},
	}

	// Keep the slice non-nil so an empty listing encodes as [] rather than null.
	jobs := []map[string]interface{}{}
	for _, b := range buckets {
		entries, err := os.ReadDir(b.root)
		if err != nil {
			// A missing bucket is normal; anything else would otherwise
			// produce a silently incomplete listing.
			if !os.IsNotExist(err) {
				h.logger.Warn("failed to read job bucket", "bucket", b.name, "root", b.root, "error", err)
			}
			continue
		}
		for _, entry := range entries {
			if !entry.IsDir() {
				continue
			}
			jobs = append(jobs, map[string]interface{}{
				"name":   entry.Name(),
				"status": "unknown",
				"bucket": b.name,
			})
		}
	}

	return h.sendSuccessPacket(conn, map[string]interface{}{
		"success": true,
		"jobs":    jobs,
		"count":   len(jobs),
	})
}
// HTTP Handlers for REST API
// ListJobsHTTP handles HTTP requests for listing jobs.
// It is a placeholder: every request is answered with 501 Not Implemented.
func (h *Handler) ListJobsHTTP(w http.ResponseWriter, r *http.Request) {
	const notImplemented = "Not implemented"
	http.Error(w, notImplemented, http.StatusNotImplemented)
}
// GetJobStatusHTTP handles HTTP requests for job status.
// It is a placeholder: every request is answered with 501 Not Implemented.
func (h *Handler) GetJobStatusHTTP(w http.ResponseWriter, r *http.Request) {
	const notImplemented = "Not implemented"
	http.Error(w, notImplemented, http.StatusNotImplemented)
}
// GetExperimentHistoryHTTP handles GET /api/experiments/:id/history.
//
// The experiment ID is read from the "id" path value; an empty ID is answered
// with 400 Bad Request. The "all_users" query parameter is parsed and logged
// but does not affect the response yet.
//
// The response body is placeholder data: apart from "experiment_id", every
// field is hard-coded. All timestamps derive from one clock reading so they
// are consistent within a response. A failure to write the JSON body (for
// example because the client disconnected) is logged as a warning; the status
// line has already been sent by then, so it cannot be changed.
func (h *Handler) GetExperimentHistoryHTTP(w http.ResponseWriter, r *http.Request) {
	experimentID := r.PathValue("id")
	if experimentID == "" {
		http.Error(w, "Missing experiment ID", http.StatusBadRequest)
		return
	}

	// Check for all_users param (team view).
	allUsers := r.URL.Query().Get("all_users") == "true"
	h.logger.Info("getting experiment history", "experiment", experimentID, "all_users", allUsers)

	// Read the clock once so related timestamps agree exactly.
	now := time.Now().UTC()

	// Placeholder response.
	response := map[string]interface{}{
		"experiment_id": experimentID,
		"history": []map[string]interface{}{
			{
				"timestamp": now,
				"event":     "run_started",
				"details":   "Experiment run initiated",
			},
			{
				"timestamp": now.Add(-time.Hour),
				"event":     "config_applied",
				"details":   "Configuration loaded",
			},
		},
		"annotations": []map[string]interface{}{
			{
				"author":    "user",
				"timestamp": now,
				"note":      "Initial run - baseline results",
			},
		},
		"config_snapshot": map[string]interface{}{
			"learning_rate": 0.001,
			"batch_size":    32,
			"epochs":        100,
		},
	}

	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(response); err != nil {
		h.logger.Warn("failed to write experiment history response", "experiment", experimentID, "error", err)
	}
}
// ListAllJobsHTTP handles GET /api/jobs?all_users=true for team collaboration.
//
// It scans the running, pending, finished and failed buckets under the
// configured base path and reports every sub-directory as a job with a
// "status" of "unknown". The response's "view" is "team" when all_users=true
// and "personal" otherwise.
//
// The scan itself does not depend on all_users: the same directories are
// listed in both views. In team view each job additionally carries placeholder
// "owner" and "team" values; real ownership is not looked up here.
//
// An empty base path is answered with 500. A bucket directory that does not
// exist is skipped silently; any other read failure is logged as a warning and
// the bucket is skipped. A failure to write the JSON body is logged as a
// warning, since the status line has already been sent by then.
func (h *Handler) ListAllJobsHTTP(w http.ResponseWriter, r *http.Request) {
	// Check if requesting all users' jobs (team view) or just own jobs.
	allUsers := r.URL.Query().Get("all_users") == "true"

	base := strings.TrimSpace(h.expManager.BasePath())
	if base == "" {
		http.Error(w, "Server configuration error", http.StatusInternalServerError)
		return
	}

	jobPaths := storage.NewJobPaths(base)
	buckets := []struct {
		name string
		root string
	}{
		{name: "running", root: jobPaths.RunningPath()},
		{name: "pending", root: jobPaths.PendingPath()},
		{name: "finished", root: jobPaths.FinishedPath()},
		{name: "failed", root: jobPaths.FailedPath()},
	}

	// Keep the slice non-nil so an empty listing encodes as [] rather than null.
	jobs := []map[string]interface{}{}
	for _, b := range buckets {
		entries, err := os.ReadDir(b.root)
		if err != nil {
			// A missing bucket is normal; anything else would otherwise
			// produce a silently incomplete listing.
			if !os.IsNotExist(err) {
				h.logger.Warn("failed to read job bucket", "bucket", b.name, "root", b.root, "error", err)
			}
			continue
		}
		for _, entry := range entries {
			if !entry.IsDir() {
				continue
			}
			job := map[string]interface{}{
				"name":   entry.Name(),
				"status": "unknown",
				"bucket": b.name,
			}
			if allUsers {
				// Placeholder: In production, retrieve actual owner from job
				// metadata and verify team membership through identity provider.
				job["owner"] = "current_user"
				job["team"] = "ml-team"
			}
			jobs = append(jobs, job)
		}
	}

	view := "personal"
	if allUsers {
		view = "team"
	}

	response := map[string]interface{}{
		"success": true,
		"jobs":    jobs,
		"count":   len(jobs),
		"view":    view,
	}

	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(response); err != nil {
		h.logger.Warn("failed to write job list response", "error", err)
	}
}