package api

import (
	"crypto/sha256"
	"crypto/subtle"
	"crypto/tls"
	"encoding/binary"
	"encoding/hex"
	"fmt"
	"math"
	"net"
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/google/uuid"
	"github.com/gorilla/websocket"
	"github.com/jfraeys/fetch_ml/internal/auth"
	"github.com/jfraeys/fetch_ml/internal/experiment"
	"github.com/jfraeys/fetch_ml/internal/logging"
	"github.com/jfraeys/fetch_ml/internal/queue"
	"golang.org/x/crypto/acme/autocert"
)

// Opcodes for the binary WebSocket protocol. The opcode is the first byte of
// every binary message; the remaining bytes are the opcode-specific payload.
const (
	OpcodeQueueJob      = 0x01
	OpcodeStatusRequest = 0x02
	OpcodeCancelJob     = 0x03
	OpcodePrune         = 0x04
	OpcodeLogMetric     = 0x0A
	OpcodeGetExperiment = 0x0B
)

// upgrader upgrades HTTP requests to WebSocket connections, restricting
// browser origins via checkOrigin.
var upgrader = websocket.Upgrader{CheckOrigin: checkOrigin}

// checkOrigin reports whether the Origin header of an upgrade request is
// acceptable. Requests without an Origin header are accepted; an Origin that
// cannot be parsed as a URL is rejected.
func checkOrigin(r *http.Request) bool {
	origin := r.Header.Get("Origin")
	if origin == "" {
		return true // Allow same-origin requests
	}

	parsedOrigin, err := url.Parse(origin)
	if err != nil {
		return false
	}
	return isAllowedOrigin(parsedOrigin)
}

// isAllowedOrigin reports whether origin points at localhost, a loopback
// address, or an RFC 1918 private IPv4 address.
//
// The hostname is parsed rather than prefix-matched: a prefix test on the raw
// host string also accepts names such as "localhost.evil.com" or
// "10.evil.com", and treats every 172.x.x.x address as private even though
// only 172.16.0.0/12 is.
func isAllowedOrigin(origin *url.URL) bool {
	// NOTE(review): any origin served from the development ports is accepted
	// regardless of host. This is kept for backward compatibility but lets an
	// arbitrary site on port 8080/8081 pass the check; consider removing it
	// outside development.
	switch origin.Port() {
	case "8080", "8081":
		return true
	}

	host := origin.Hostname()
	if strings.EqualFold(host, "localhost") {
		return true
	}

	ip := net.ParseIP(host)
	if ip == nil {
		return false
	}
	if ip.IsLoopback() {
		return true
	}

	ip4 := ip.To4()
	if ip4 == nil {
		return false
	}
	switch {
	case ip4[0] == 10: // 10.0.0.0/8
		return true
	case ip4[0] == 172 && ip4[1] >= 16 && ip4[1] <= 31: // 172.16.0.0/12
		return true
	case ip4[0] == 192 && ip4[1] == 168: // 192.168.0.0/16
		return true
	}
	return false
}

// WSHandler serves the binary WebSocket API.
//
// NOTE(review): handleQueueJob, handleStatusRequest and handleCancelJob
// dereference authConfig without a nil check, while ServeHTTP and the other
// handlers treat a nil authConfig as "authentication disabled". Until that is
// reconciled, construct the handler with a non-nil authConfig.
type WSHandler struct {
	authConfig *auth.AuthConfig
	logger     *logging.Logger
	expManager *experiment.Manager
	queue      *queue.TaskQueue
}

// NewWSHandler returns a WSHandler wired to the given dependencies. taskQueue
// may be nil, in which case jobs are acknowledged but not enqueued.
func NewWSHandler(authConfig *auth.AuthConfig, logger *logging.Logger, expManager *experiment.Manager, taskQueue *queue.TaskQueue) *WSHandler {
	return &WSHandler{
		authConfig: authConfig,
		logger:     logger,
		expManager: expManager,
		queue:      taskQueue,
	}
}

// ServeHTTP authenticates the request (when authentication is enabled),
// upgrades it to a WebSocket connection and then processes binary messages
// until the connection is closed or a read fails.
//
// The API key is taken from the X-API-Key header, falling back to an
// "Authorization: Bearer <key>" header.
func (h *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// Check API key before upgrading WebSocket
	apiKey := r.Header.Get("X-API-Key")
	if apiKey == "" {
		// Also check Authorization header
		authHeader := r.Header.Get("Authorization")
		if strings.HasPrefix(authHeader, "Bearer ") {
			apiKey = strings.TrimPrefix(authHeader, "Bearer ")
		}
	}

	// Validate API key if authentication is enabled
	if h.authConfig != nil && h.authConfig.Enabled {
		if _, err := h.authConfig.ValidateAPIKey(apiKey); err != nil {
			h.logger.Warn("websocket authentication failed", "error", err)
			http.Error(w, "Invalid API key", http.StatusUnauthorized)
			return
		}
	}

	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		h.logger.Error("websocket upgrade failed", "error", err)
		return
	}
	defer conn.Close()

	h.logger.Info("websocket connection established", "remote", r.RemoteAddr)

	for {
		messageType, message, err := conn.ReadMessage()
		if err != nil {
			// Going-away and abnormal closures are routine; only log the rest.
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
				h.logger.Error("websocket read error", "error", err)
			}
			break
		}

		if messageType != websocket.BinaryMessage {
			h.logger.Warn("received non-binary message")
			continue
		}

		if err := h.handleMessage(conn, message); err != nil {
			h.logger.Error("message handling error", "error", err)
			// Best-effort generic error reply; the write error is ignored
			// because the next ReadMessage will surface a broken connection.
			_ = conn.WriteMessage(websocket.BinaryMessage, []byte{0xFF, 0x00}) // Error opcode
		}
	}
}

// handleMessage dispatches a binary message to the handler for its opcode
// (the first byte). It returns an error for empty messages, unknown opcodes,
// or whatever error the selected handler returns.
func (h *WSHandler) handleMessage(conn *websocket.Conn, message []byte) error {
	if len(message) < 1 {
		return fmt.Errorf("message too short")
	}

	opcode := message[0]
	payload := message[1:]

	switch opcode {
	case OpcodeQueueJob:
		return h.handleQueueJob(conn, payload)
	case OpcodeStatusRequest:
		return h.handleStatusRequest(conn, payload)
	case OpcodeCancelJob:
		return h.handleCancelJob(conn, payload)
	case OpcodePrune:
		return h.handlePrune(conn, payload)
	case OpcodeLogMetric:
		return h.handleLogMetric(conn, payload)
	case OpcodeGetExperiment:
		return h.handleGetExperiment(conn, payload)
	default:
		return fmt.Errorf("unknown opcode: 0x%02x", opcode)
	}
}

// handleQueueJob creates the experiment directory and metadata for a job and,
// when a task queue is configured, enqueues a task for it.
//
// Every failure is reported to the client as an error packet; the returned
// error is the result of writing the response to the connection.
func (h *WSHandler) handleQueueJob(conn *websocket.Conn, payload []byte) error {
	// Protocol: [api_key_hash:64][commit_id:64][priority:1][job_name_len:1][job_name:var]
	if len(payload) < 130 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "queue job payload too short", "")
	}

	apiKeyHash := string(payload[:64])
	commitID := string(payload[64:128])
	priority := int64(payload[128])
	jobNameLen := int(payload[129])

	if len(payload) < 130+jobNameLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "")
	}

	jobName := string(payload[130 : 130+jobNameLen])

	h.logger.Info("queue job request",
		"job", jobName,
		"priority", priority,
		"commit_id", commitID,
	)

	// Validate API key and get user information
	user, err := h.authConfig.ValidateAPIKey(apiKeyHash)
	if err != nil {
		h.logger.Error("invalid api key", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error())
	}

	// Check user permissions. "job queued" is logged only after the
	// experiment has actually been created (below), not at this point.
	if h.authConfig.Enabled && !user.HasPermission("jobs:create") {
		h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:create")
		return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions to create jobs", "")
	}

	// Create experiment directory and metadata
	if err := h.expManager.CreateExperiment(commitID); err != nil {
		h.logger.Error("failed to create experiment directory", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to create experiment directory", err.Error())
	}

	// Add user info to experiment metadata
	meta := &experiment.Metadata{
		CommitID:  commitID,
		JobName:   jobName,
		User:      user.Name,
		Timestamp: time.Now().Unix(),
	}

	if err := h.expManager.WriteMetadata(meta); err != nil {
		h.logger.Error("failed to save experiment metadata", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to save experiment metadata", err.Error())
	}

	h.logger.Info("job queued", "job", jobName, "path", h.expManager.GetExperimentPath(commitID), "user", user.Name)

	packet := NewSuccessPacket(fmt.Sprintf("Job '%s' queued successfully", jobName))

	// Enqueue task if queue is available
	if h.queue != nil {
		taskID := uuid.New().String()
		task := &queue.Task{
			ID:        taskID,
			JobName:   jobName,
			Args:      "", // TODO: Add args support
			Status:    "queued",
			Priority:  priority,
			CreatedAt: time.Now(),
			UserID:    user.Name,
			Username:  user.Name,
			CreatedBy: user.Name,
			Metadata: map[string]string{
				"commit_id": commitID,
				"user_id":   user.Name,
				"username":  user.Name,
			},
		}

		if err := h.queue.AddTask(task); err != nil {
			h.logger.Error("failed to enqueue task", "error", err)
			return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue task", err.Error())
		}
		h.logger.Info("task enqueued", "task_id", taskID, "job", jobName, "user", user.Name)
	} else {
		// The client still receives a success packet in this case.
		h.logger.Warn("task queue not initialized, job not enqueued", "job", jobName)
	}

	packetData, err := packet.Serialize()
	if err != nil {
		h.logger.Error("failed to serialize packet", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Internal error", "Failed to serialize response")
	}
	return conn.WriteMessage(websocket.BinaryMessage, packetData)
}

// handleStatusRequest replies with the caller's identity, per-status task
// counts and the list of tasks the caller may see. When authentication is
// enabled, non-admin users only see tasks whose UserID or CreatedBy matches
// their name.
func (h *WSHandler) handleStatusRequest(conn *websocket.Conn, payload []byte) error {
	// Protocol: [api_key_hash:64]
	if len(payload) < 64 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "status request payload too short", "")
	}

	apiKeyHash := string(payload[0:64])
	// Only a prefix of the hash is logged.
	h.logger.Info("status request received", "api_key_hash", apiKeyHash[:16]+"...")

	// Validate API key and get user information
	user, err := h.authConfig.ValidateAPIKey(apiKeyHash)
	if err != nil {
		h.logger.Error("invalid api key", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error())
	}

	// Check user permissions for viewing jobs
	if h.authConfig.Enabled && !user.HasPermission("jobs:read") {
		h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:read")
		return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions to view jobs", "")
	}

	// Get tasks with user filtering
	var tasks []*queue.Task
	if h.queue != nil {
		allTasks, err := h.queue.GetAllTasks()
		if err != nil {
			h.logger.Error("failed to get tasks", "error", err)
			return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to retrieve tasks", err.Error())
		}

		// If auth is disabled, or the user is an admin, every task is visible.
		seesAll := !h.authConfig.Enabled || user.Admin
		for _, task := range allTasks {
			// Other users can only see their own tasks
			if seesAll || task.UserID == user.Name || task.CreatedBy == user.Name {
				tasks = append(tasks, task)
			}
		}
	}

	// Build status response with user-specific data
	status := map[string]interface{}{
		"user": map[string]interface{}{
			"name":  user.Name,
			"admin": user.Admin,
			"roles": user.Roles,
		},
		"tasks": map[string]interface{}{
			"total":     len(tasks),
			"queued":    countTasksByStatus(tasks, "queued"),
			"running":   countTasksByStatus(tasks, "running"),
			"failed":    countTasksByStatus(tasks, "failed"),
			"completed": countTasksByStatus(tasks, "completed"),
		},
		"queue": tasks, // Include filtered tasks
	}

	packet := NewSuccessPacketWithPayload("Status retrieved", status)
	packetData, err := packet.Serialize()
	if err != nil {
		h.logger.Error("failed to serialize packet", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Internal error", "Failed to serialize response")
	}
	return conn.WriteMessage(websocket.BinaryMessage, packetData)
}

// countTasksByStatus returns the number of tasks whose Status equals status.
func countTasksByStatus(tasks []*queue.Task, status string) int {
	count := 0
	for _, task := range tasks {
		if task.Status == status {
			count++
		}
	}
	return count
}

// handleCancelJob cancels the task with the given job name. When
// authentication is enabled, the caller needs the "jobs:update" permission
// and must be an admin or the owner of the task.
func (h *WSHandler) handleCancelJob(conn *websocket.Conn, payload []byte) error {
	// Protocol: [api_key_hash:64][job_name_len:1][job_name:var]
	if len(payload) < 65 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "cancel job payload too short", "")
	}

	// Parse 64-byte hex API key hash
	apiKeyHash := string(payload[0:64])
	jobNameLen := int(payload[64])

	if len(payload) < 65+jobNameLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "")
	}

	jobName := string(payload[65 : 65+jobNameLen])

	h.logger.Info("cancel job request", "job", jobName)

	// Validate API key and get user information
	user, err := h.authConfig.ValidateAPIKey(apiKeyHash)
	if err != nil {
		h.logger.Error("invalid api key", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error())
	}

	// Check user permissions for canceling jobs
	if h.authConfig.Enabled && !user.HasPermission("jobs:update") {
		h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:update")
		return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions to cancel jobs", "")
	}

	// Find the task and verify ownership
	if h.queue != nil {
		task, err := h.queue.GetTaskByName(jobName)
		if err != nil {
			h.logger.Error("task not found", "job", jobName, "error", err)
			return h.sendErrorPacket(conn, ErrorCodeJobNotFound, "Job not found", err.Error())
		}

		// Only an admin or the task's owner may cancel it.
		if h.authConfig.Enabled && !user.Admin && task.UserID != user.Name && task.CreatedBy != user.Name {
			h.logger.Error("unauthorized job cancellation attempt", "user", user.Name, "job", jobName, "task_owner", task.UserID)
			return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "You can only cancel your own jobs", "")
		}

		// Cancel the task
		if err := h.queue.CancelTask(task.ID); err != nil {
			h.logger.Error("failed to cancel task", "job", jobName, "task_id", task.ID, "error", err)
			return h.sendErrorPacket(conn, ErrorCodeJobExecutionFailed, "Failed to cancel job", err.Error())
		}

		h.logger.Info("job cancelled", "job", jobName, "task_id", task.ID, "user", user.Name)
	} else {
		// The client still receives a success packet in this case.
		h.logger.Warn("task queue not initialized, cannot cancel job", "job", jobName)
	}

	packet := NewSuccessPacket(fmt.Sprintf("Job '%s' cancelled successfully", jobName))
	packetData, err := packet.Serialize()
	if err != nil {
		h.logger.Error("failed to serialize packet", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Internal error", "Failed to serialize response")
	}
	return conn.WriteMessage(websocket.BinaryMessage, packetData)
}

// handlePrune prunes experiments through the experiment manager. Prune type 0
// passes value as the keep count; prune type 1 passes it as the
// older-than-days threshold. Any other type is rejected.
func (h *WSHandler) handlePrune(conn *websocket.Conn, payload []byte) error {
	// Protocol: [api_key_hash:64][prune_type:1][value:4]
	if len(payload) < 69 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "prune payload too short", "")
	}

	// Parse 64-byte hex API key hash
	apiKeyHash := string(payload[0:64])
	pruneType := payload[64]
	value := binary.BigEndian.Uint32(payload[65:69])

	h.logger.Info("prune request", "type", pruneType, "value", value)

	// Verify API key
	if h.authConfig != nil && h.authConfig.Enabled {
		if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
			h.logger.Error("api key verification failed", "error", err)
			return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Authentication failed", err.Error())
		}
	}

	// Convert prune parameters; the one not selected stays zero.
	var keepCount, olderThanDays int
	switch pruneType {
	case 0: // keep N
		keepCount = int(value)
	case 1: // older than days
		olderThanDays = int(value)
	default:
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, fmt.Sprintf("invalid prune type: %d", pruneType), "")
	}

	// Perform pruning
	pruned, err := h.expManager.PruneExperiments(keepCount, olderThanDays)
	if err != nil {
		h.logger.Error("prune failed", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeStorageError, "Prune operation failed", err.Error())
	}

	h.logger.Info("prune completed", "count", len(pruned), "experiments", pruned)

	// Send structured success response
	packet := NewSuccessPacket(fmt.Sprintf("Pruned %d experiments", len(pruned)))
	return h.sendResponsePacket(conn, packet)
}

// handleLogMetric records one metric value for an experiment. step is a
// big-endian uint32 and value is the big-endian IEEE 754 bit pattern of a
// float64.
func (h *WSHandler) handleLogMetric(conn *websocket.Conn, payload []byte) error {
	// Protocol: [api_key_hash:64][commit_id:64][step:4][value:8][name_len:1][name:var]
	if len(payload) < 141 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "log metric payload too short", "")
	}

	apiKeyHash := string(payload[:64])
	commitID := string(payload[64:128])
	step := int(binary.BigEndian.Uint32(payload[128:132]))
	valueBits := binary.BigEndian.Uint64(payload[132:140])
	value := math.Float64frombits(valueBits)
	nameLen := int(payload[140])

	if len(payload) < 141+nameLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid metric name length", "")
	}

	name := string(payload[141 : 141+nameLen])

	// Verify API key
	if h.authConfig != nil && h.authConfig.Enabled {
		if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
			return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Authentication failed", err.Error())
		}
	}

	if err := h.expManager.LogMetric(commitID, name, value, step); err != nil {
		h.logger.Error("failed to log metric", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to log metric", err.Error())
	}

	return h.sendResponsePacket(conn, NewSuccessPacket("Metric logged"))
}

// handleGetExperiment replies with the stored metadata and metrics of the
// experiment identified by commit ID.
func (h *WSHandler) handleGetExperiment(conn *websocket.Conn, payload []byte) error {
	// Protocol: [api_key_hash:64][commit_id:64]
	if len(payload) < 128 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "get experiment payload too short", "")
	}

	apiKeyHash := string(payload[:64])
	commitID := string(payload[64:128])

	// Verify API key
	if h.authConfig != nil && h.authConfig.Enabled {
		if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
			return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Authentication failed", err.Error())
		}
	}

	meta, err := h.expManager.ReadMetadata(commitID)
	if err != nil {
		return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "Experiment not found", err.Error())
	}

	metrics, err := h.expManager.GetMetrics(commitID)
	if err != nil {
		return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to read metrics", err.Error())
	}

	response := map[string]interface{}{
		"metadata": meta,
		"metrics":  metrics,
	}

	return h.sendResponsePacket(conn, NewSuccessPacketWithPayload("Experiment details", response))
}

// HashAPIKey returns the lowercase hex encoding of the SHA-256 digest of
// apiKey (64 characters).
func HashAPIKey(apiKey string) string {
	hash := sha256.Sum256([]byte(apiKey))
	return hex.EncodeToString(hash[:])
}

// SetupTLSConfig builds an *http.Server carrying a TLS configuration for the
// WebSocket server.
//
// When both certFile and keyFile are non-empty, the server gets a TLS 1.2+
// configuration with a fixed cipher-suite list. The files themselves are not
// loaded here — presumably the caller passes them to ListenAndServeTLS;
// confirm at the call site. Otherwise, when host is non-empty, certificates
// are obtained through autocert for that host and cached on disk. When
// neither applies, the returned server is nil.
//
// The returned error is currently always nil.
func SetupTLSConfig(certFile, keyFile string, host string) (*http.Server, error) {
	var server *http.Server

	if certFile != "" && keyFile != "" {
		// Use provided certificates
		server = &http.Server{
			TLSConfig: &tls.Config{
				MinVersion: tls.VersionTLS12,
				CipherSuites: []uint16{
					tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
					tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
					tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
				},
			},
		}
	} else if host != "" {
		// Use Let's Encrypt with autocert
		certManager := &autocert.Manager{
			Prompt:     autocert.AcceptTOS,
			HostPolicy: autocert.HostWhitelist(host),
			Cache:      autocert.DirCache("/var/www/.cache"),
		}

		server = &http.Server{
			TLSConfig: certManager.TLSConfig(),
		}
	}

	return server, nil
}

// verifyAPIKeyHash checks hexHash against the hashes of the configured API
// keys. It returns nil when authentication is disabled or a matching key is
// found, and an error when the hash has the wrong length or matches no key.
func (h *WSHandler) verifyAPIKeyHash(hexHash string) error {
	if h.authConfig == nil || !h.authConfig.Enabled {
		return nil // No auth required
	}

	// A SHA-256 digest is 64 hex characters.
	if len(hexHash) != 64 {
		return fmt.Errorf("invalid api key hash length")
	}

	// Compare in constant time so response timing does not reveal how many
	// leading characters of a stored hash the caller guessed correctly.
	candidate := []byte(hexHash)
	for _, entry := range h.authConfig.APIKeys {
		if subtle.ConstantTimeCompare([]byte(string(entry.Hash)), candidate) == 1 {
			return nil // Valid API key found
		}
	}

	return fmt.Errorf("invalid api key")
}

// sendErrorPacket sends an error response packet built from errorCode,
// message and details, returning the result of the write.
func (h *WSHandler) sendErrorPacket(conn *websocket.Conn, errorCode byte, message string, details string) error {
	packet := NewErrorPacket(errorCode, message, details)
	return h.sendResponsePacket(conn, packet)
}

// sendResponsePacket serializes packet and writes it as a binary message. If
// serialization fails, it logs the failure and writes the two-byte generic
// error response instead.
func (h *WSHandler) sendResponsePacket(conn *websocket.Conn, packet *ResponsePacket) error {
	data, err := packet.Serialize()
	if err != nil {
		h.logger.Error("failed to serialize response packet", "error", err)
		// Fallback to simple error response
		return conn.WriteMessage(websocket.BinaryMessage, []byte{0xFF, 0x00})
	}

	return conn.WriteMessage(websocket.BinaryMessage, data)
}

// sendErrorResponse removed (unused)