fetch_ml/internal/api/ws.go
Jeremie Fraeys cd5640ebd2 Slim and secure: move scripts, clean configs, remove secrets
- Move ci-test.sh and setup.sh to scripts/
- Trim docs/src/zig-cli.md to current structure
- Replace hardcoded secrets with placeholders in configs
- Update .gitignore to block .env*, secrets/, keys, build artifacts
- Slim README.md to reflect current CLI/TUI split
- Add cleanup trap to ci-test.sh
- Ensure no secrets are committed
2025-12-07 13:57:51 -05:00

652 lines
20 KiB
Go

package api
import (
	"crypto/subtle"
	"crypto/tls"
	"encoding/binary"
	"encoding/json"
	"fmt"
	"math"
	"net"
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/google/uuid"
	"github.com/gorilla/websocket"
	"github.com/jfraeys/fetch_ml/internal/auth"
	"github.com/jfraeys/fetch_ml/internal/experiment"
	"github.com/jfraeys/fetch_ml/internal/logging"
	"github.com/jfraeys/fetch_ml/internal/queue"
	"github.com/jfraeys/fetch_ml/internal/telemetry"
	"golang.org/x/crypto/acme/autocert"
)
// Opcodes for the binary WebSocket protocol.
//
// The opcode is the first byte of every client message; the remaining bytes
// are the opcode-specific payload. Every payload starts with a 64-byte API
// key hash. Layouts below are given as [field:size_in_bytes].
const (
OpcodeQueueJob = 0x01 // [api_key_hash:64][commit_id:64][priority:1][job_name_len:1][job_name:var]
OpcodeStatusRequest = 0x02 // [api_key_hash:64]
OpcodeCancelJob = 0x03 // [api_key_hash:64][job_name_len:1][job_name:var]
OpcodePrune = 0x04 // [api_key_hash:64][prune_type:1][value:4]
OpcodeLogMetric = 0x0A // [api_key_hash:64][commit_id:64][step:4][value:8][name_len:1][name:var]
OpcodeGetExperiment = 0x0B // [api_key_hash:64][commit_id:64]
)
// upgrader turns incoming HTTP requests into WebSocket connections.
//
// CheckOrigin decides whether a browser-supplied Origin may connect:
//
//   - no Origin header (non-browser clients): allowed;
//   - an Origin that does not parse as a URL: rejected;
//   - an Origin whose hostname equals the hostname the request was sent to
//     (any port): allowed;
//   - an Origin whose hostname is "localhost", a loopback IP, or a private
//     network IP: allowed;
//   - anything else: rejected.
//
// Hostnames are compared exactly and IP ranges are checked on the parsed
// address rather than by string prefix, so names such as "10.evil.example"
// or public 172.x addresses are not mistaken for local-network origins.
// Unlike the earlier prefix/suffix matching, port 8080 on an arbitrary host
// is no longer sufficient on its own.
var upgrader = websocket.Upgrader{
	CheckOrigin: func(r *http.Request) bool {
		origin := r.Header.Get("Origin")
		if origin == "" {
			return true // No Origin header: not a cross-site browser request.
		}

		parsedOrigin, err := url.Parse(origin)
		if err != nil {
			return false
		}

		// Hostname strips the port and any IPv6 brackets.
		originHost := parsedOrigin.Hostname()
		if originHost == "" {
			return false
		}

		// Pages served from the same machine the client is connecting to,
		// regardless of port (e.g. a web UI on a different port than the API).
		requestHost := (&url.URL{Host: r.Host}).Hostname()
		if requestHost != "" && strings.EqualFold(originHost, requestHost) {
			return true
		}

		if strings.EqualFold(originHost, "localhost") {
			return true
		}

		// Loopback (127.0.0.0/8, ::1) and private ranges (10/8, 172.16/12,
		// 192.168/16, fc00::/7) only; DNS names never match here.
		ip := net.ParseIP(originHost)
		return ip != nil && (ip.IsLoopback() || ip.IsPrivate())
	},
	// Performance optimizations
	HandshakeTimeout:  10 * time.Second,
	ReadBufferSize:    4096,
	WriteBufferSize:   4096,
	EnableCompression: true,
}
// WSHandler handles WebSocket connections for the API.
//
// It implements http.Handler: each request is authenticated, upgraded to a
// WebSocket, and then served as a stream of binary protocol messages.
type WSHandler struct {
// authConfig holds API key configuration. Handlers treat a nil config or
// one with Enabled == false as "authentication disabled".
authConfig *auth.Config
// logger receives all diagnostic output from the handler.
logger *logging.Logger
// expManager creates, reads and prunes experiments and their metrics.
expManager *experiment.Manager
// queue is the task queue; handlers check it for nil before use.
queue *queue.TaskQueue
}
// NewWSHandler builds a WSHandler from its dependencies.
//
// The arguments are stored as given; no validation is performed here.
func NewWSHandler(
	authConfig *auth.Config,
	logger *logging.Logger,
	expManager *experiment.Manager,
	taskQueue *queue.TaskQueue,
) *WSHandler {
	handler := new(WSHandler)
	handler.authConfig = authConfig
	handler.logger = logger
	handler.expManager = expManager
	handler.queue = taskQueue
	return handler
}
// ServeHTTP authenticates the request (when authentication is enabled),
// upgrades it to a WebSocket connection, and dispatches every binary message
// to handleMessage until the peer disconnects or a read fails.
//
// Non-binary frames are ignored with a warning. When handleMessage returns an
// error, the generic two-byte error frame {0xFF, 0x00} is written to the peer
// and the loop continues.
func (h *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// Upper bound on a single inbound message. The largest request parsed in
	// this file is a few hundred bytes (opcode, fixed header and a name of at
	// most 255 bytes), so this is generous while still stopping a peer from
	// forcing an arbitrarily large read.
	const maxMessageBytes = 64 * 1024

	// Check API key before upgrading WebSocket
	apiKey := auth.ExtractAPIKeyFromRequest(r)

	// Validate API key if authentication is enabled
	if h.authConfig != nil && h.authConfig.Enabled {
		// Log the length only: any portion of the key is credential material
		// and must not end up in log files.
		h.logger.Info("websocket auth attempt", "api_key_length", len(apiKey))
		if _, err := h.authConfig.ValidateAPIKey(apiKey); err != nil {
			h.logger.Warn("websocket authentication failed", "error", err)
			http.Error(w, "Invalid API key", http.StatusUnauthorized)
			return
		}
		h.logger.Info("websocket authentication succeeded")
	}

	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		// Upgrade has already replied to the client with an HTTP error.
		h.logger.Error("websocket upgrade failed", "error", err)
		return
	}
	defer func() {
		// Best-effort close; there is nothing useful to do if it fails.
		_ = conn.Close()
	}()
	conn.SetReadLimit(maxMessageBytes)

	h.logger.Info("websocket connection established", "remote", r.RemoteAddr)

	for {
		messageType, message, err := conn.ReadMessage()
		if err != nil {
			// Going-away and abnormal closures are routine disconnects; only
			// other close reasons are worth an error log.
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
				h.logger.Error("websocket read error", "error", err)
			}
			break
		}

		if messageType != websocket.BinaryMessage {
			h.logger.Warn("received non-binary message")
			continue
		}

		if err := h.handleMessage(conn, message); err != nil {
			h.logger.Error("message handling error", "error", err)
			// Send the generic error frame. The write error is ignored on
			// purpose: if the connection is broken the next read fails and
			// ends the loop.
			_ = conn.WriteMessage(websocket.BinaryMessage, []byte{0xFF, 0x00}) // Error opcode
		}
	}
}
// handleMessage routes one binary protocol message to its opcode handler.
//
// The first byte selects the handler and the rest is passed on as the
// payload. An empty message or an unrecognised opcode yields an error;
// otherwise the handler's own result is returned.
func (h *WSHandler) handleMessage(conn *websocket.Conn, message []byte) error {
	if len(message) == 0 {
		return fmt.Errorf("message too short")
	}

	op, body := message[0], message[1:]

	var handle func(*websocket.Conn, []byte) error
	switch op {
	case OpcodeQueueJob:
		handle = h.handleQueueJob
	case OpcodeStatusRequest:
		handle = h.handleStatusRequest
	case OpcodeCancelJob:
		handle = h.handleCancelJob
	case OpcodePrune:
		handle = h.handlePrune
	case OpcodeLogMetric:
		handle = h.handleLogMetric
	case OpcodeGetExperiment:
		handle = h.handleGetExperiment
	default:
		return fmt.Errorf("unknown opcode: 0x%02x", op)
	}
	return handle(conn, body)
}
// handleQueueJob handles OpcodeQueueJob: it creates the experiment for the
// given commit, records its metadata, and enqueues a task for the job.
//
// Payload: [api_key_hash:64][commit_id:64][priority:1][job_name_len:1][job_name:var]
//
// Malformed payloads, failed authentication, missing "jobs:create"
// permission, and storage or queue failures are reported to the peer as error
// packets. The returned error is the result of writing the response to conn.
func (h *WSHandler) handleQueueJob(conn *websocket.Conn, payload []byte) error {
	if len(payload) < 130 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "queue job payload too short", "")
	}

	apiKeyHash := string(payload[:64])
	commitID := string(payload[64:128])
	priority := int64(payload[128])
	jobNameLen := int(payload[129])

	if len(payload) < 130+jobNameLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "")
	}

	jobName := string(payload[130 : 130+jobNameLen])

	h.logger.Info("queue job request",
		"job", jobName,
		"priority", priority,
		"commit_id", commitID,
	)

	// Validate API key and get user information
	var user *auth.User
	var err error
	if h.authConfig != nil {
		user, err = h.authConfig.ValidateAPIKey(apiKeyHash)
		if err != nil {
			h.logger.Error("invalid api key", "error", err)
			return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error())
		}
	} else {
		// Auth disabled - use default admin user
		user = &auth.User{
			Name:  "default",
			Admin: true,
			Roles: []string{"admin"},
			Permissions: map[string]bool{
				"*": true,
			},
		}
	}

	// Check user permissions. Nothing is logged as "queued" at this point:
	// the experiment has not been created and no task has been enqueued yet.
	if h.authConfig != nil && h.authConfig.Enabled && !user.HasPermission("jobs:create") {
		h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:create")
		return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions to create jobs", "")
	}

	// Create experiment directory and metadata (optimized)
	if _, err := telemetry.ExecWithMetrics(h.logger, "experiment.create", 50*time.Millisecond, func() (string, error) {
		return "", h.expManager.CreateExperiment(commitID)
	}); err != nil {
		h.logger.Error("failed to create experiment directory", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to create experiment directory", err.Error())
	}

	// Add user info to experiment metadata (deferred for performance).
	// This is fire-and-forget: a failure is only logged and does not affect
	// the response sent to the client.
	go func() {
		meta := &experiment.Metadata{
			CommitID:  commitID,
			JobName:   jobName,
			User:      user.Name,
			Timestamp: time.Now().Unix(),
		}
		if _, err := telemetry.ExecWithMetrics(
			h.logger, "experiment.write_metadata", 50*time.Millisecond, func() (string, error) {
				return "", h.expManager.WriteMetadata(meta)
			}); err != nil {
			h.logger.Error("failed to save experiment metadata", "error", err)
		}
	}()

	h.logger.Info("job queued", "job", jobName, "path", h.expManager.GetExperimentPath(commitID), "user", user.Name)

	packet := NewSuccessPacket(fmt.Sprintf("Job '%s' queued successfully", jobName))

	// Enqueue task if queue is available
	if h.queue != nil {
		taskID := uuid.New().String()
		task := &queue.Task{
			ID:        taskID,
			JobName:   jobName,
			Args:      "",
			Status:    "queued",
			Priority:  priority,
			CreatedAt: time.Now(),
			UserID:    user.Name,
			Username:  user.Name,
			CreatedBy: user.Name,
			Metadata: map[string]string{
				"commit_id": commitID, // Reduced redundant metadata
			},
		}

		if _, err := telemetry.ExecWithMetrics(h.logger, "queue.add_task", 20*time.Millisecond, func() (string, error) {
			return "", h.queue.AddTask(task)
		}); err != nil {
			h.logger.Error("failed to enqueue task", "error", err)
			return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue task", err.Error())
		}
		h.logger.Info("task enqueued", "task_id", taskID, "job", jobName, "user", user.Name)
	} else {
		// NOTE(review): the client still receives the success packet in this
		// case even though no task was enqueued - confirm this is intended.
		h.logger.Warn("task queue not initialized, job not enqueued", "job", jobName)
	}

	packetData, err := packet.Serialize()
	if err != nil {
		h.logger.Error("failed to serialize packet", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Internal error", "Failed to serialize response")
	}
	return conn.WriteMessage(websocket.BinaryMessage, packetData)
}
// handleStatusRequest handles OpcodeStatusRequest: it replies with a JSON
// document describing the caller and the tasks the caller may see.
//
// Payload: [api_key_hash:64]
//
// The caller needs the "jobs:read" permission when authentication is enabled.
// Admins (and every caller when authentication is disabled) see all tasks;
// other users see only tasks whose UserID or CreatedBy matches their name.
// Unlike the other handlers, the success response is raw JSON in a binary
// frame rather than a response packet. The returned error is the result of
// writing the response to conn.
func (h *WSHandler) handleStatusRequest(conn *websocket.Conn, payload []byte) error {
	if len(payload) < 64 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "status request payload too short", "")
	}

	apiKeyHash := string(payload[0:64])
	// The hash is what clients present to authenticate, so no part of it is
	// written to the log.
	h.logger.Info("status request received")

	// Validate API key and get user information
	var user *auth.User
	var err error
	if h.authConfig != nil {
		user, err = h.authConfig.ValidateAPIKey(apiKeyHash)
		if err != nil {
			h.logger.Error("invalid api key", "error", err)
			return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error())
		}
	} else {
		// Auth disabled - use default admin user
		user = &auth.User{
			Name:  "default",
			Admin: true,
			Roles: []string{"admin"},
			Permissions: map[string]bool{
				"*": true,
			},
		}
	}

	authEnabled := h.authConfig != nil && h.authConfig.Enabled

	// Check user permissions for viewing jobs
	if authEnabled && !user.HasPermission("jobs:read") {
		h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:read")
		return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions to view jobs", "")
	}

	// Get tasks with user filtering
	var tasks []*queue.Task
	if h.queue != nil {
		allTasks, err := h.queue.GetAllTasks()
		if err != nil {
			h.logger.Error("failed to get tasks", "error", err)
			return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to retrieve tasks", err.Error())
		}

		// Whether the caller may see every task does not depend on the task,
		// so it is decided once rather than per iteration.
		seesAll := !authEnabled || user.Admin
		for _, task := range allTasks {
			// Users can only see their own tasks
			if seesAll || task.UserID == user.Name || task.CreatedBy == user.Name {
				tasks = append(tasks, task)
			}
		}
	}

	// Build status response as raw JSON for CLI compatibility
	h.logger.Info("building status response")
	status := map[string]interface{}{
		"user": map[string]interface{}{
			"name":  user.Name,
			"admin": user.Admin,
			"roles": user.Roles,
		},
		"tasks": map[string]interface{}{
			"total":     len(tasks),
			"queued":    countTasksByStatus(tasks, "queued"),
			"running":   countTasksByStatus(tasks, "running"),
			"failed":    countTasksByStatus(tasks, "failed"),
			"completed": countTasksByStatus(tasks, "completed"),
		},
		// tasks stays nil when nothing is visible, which encodes as JSON null.
		"queue": tasks,
	}

	h.logger.Info("serializing JSON response")
	jsonData, err := json.Marshal(status)
	if err != nil {
		h.logger.Error("failed to marshal JSON", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Internal error", "Failed to serialize response")
	}

	h.logger.Info("sending websocket JSON response", "len", len(jsonData))
	return conn.WriteMessage(websocket.BinaryMessage, jsonData)
}
// countTasksByStatus reports how many of the given tasks have a Status equal
// to status. A nil or empty slice yields 0.
func countTasksByStatus(tasks []*queue.Task, status string) int {
	var matches int
	for i := range tasks {
		if tasks[i].Status != status {
			continue
		}
		matches++
	}
	return matches
}
// handleCancelJob handles OpcodeCancelJob: it looks the job up by name and
// cancels its task.
//
// Payload: [api_key_hash:64][job_name_len:1][job_name:var]
//
// When authentication is enabled the caller needs the "jobs:update"
// permission and must be an admin or the owner of the task (UserID or
// CreatedBy equal to the caller's name). Failures are reported to the peer as
// error packets. The returned error is the result of writing the response to
// conn.
func (h *WSHandler) handleCancelJob(conn *websocket.Conn, payload []byte) error {
	if len(payload) < 65 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "cancel job payload too short", "")
	}

	// Parse 64-byte hex API key hash
	apiKeyHash := string(payload[0:64])
	jobNameLen := int(payload[64])

	if len(payload) < 65+jobNameLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "")
	}

	jobName := string(payload[65 : 65+jobNameLen])

	h.logger.Info("cancel job request", "job", jobName)

	// Validate API key and get user information
	var user *auth.User
	var err error
	if h.authConfig != nil {
		user, err = h.authConfig.ValidateAPIKey(apiKeyHash)
		if err != nil {
			h.logger.Error("invalid api key", "error", err)
			return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error())
		}
	} else {
		// Auth disabled - use default admin user
		user = &auth.User{
			Name:  "default",
			Admin: true,
			Roles: []string{"admin"},
			Permissions: map[string]bool{
				"*": true,
			},
		}
	}

	// Computed once and nil-safe: authConfig is nil when authentication is
	// not configured, and must not be dereferenced in that case.
	authEnabled := h.authConfig != nil && h.authConfig.Enabled

	// Check user permissions for canceling jobs
	if authEnabled && !user.HasPermission("jobs:update") {
		h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:update")
		return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions to cancel jobs", "")
	}

	// Find the task and verify ownership
	if h.queue != nil {
		task, err := h.queue.GetTaskByName(jobName)
		if err != nil {
			h.logger.Error("task not found", "job", jobName, "error", err)
			return h.sendErrorPacket(conn, ErrorCodeJobNotFound, "Job not found", err.Error())
		}

		// Check if user can cancel this task (admin or owner)
		if authEnabled && !user.Admin && task.UserID != user.Name && task.CreatedBy != user.Name {
			h.logger.Error("unauthorized job cancellation attempt", "user", user.Name, "job", jobName, "task_owner", task.UserID)
			return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "You can only cancel your own jobs", "")
		}

		// Cancel the task
		if err := h.queue.CancelTask(task.ID); err != nil {
			h.logger.Error("failed to cancel task", "job", jobName, "task_id", task.ID, "error", err)
			return h.sendErrorPacket(conn, ErrorCodeJobExecutionFailed, "Failed to cancel job", err.Error())
		}

		h.logger.Info("job cancelled", "job", jobName, "task_id", task.ID, "user", user.Name)
	} else {
		// NOTE(review): the client still receives the success packet in this
		// case even though nothing was cancelled - confirm this is intended.
		h.logger.Warn("task queue not initialized, cannot cancel job", "job", jobName)
	}

	packet := NewSuccessPacket(fmt.Sprintf("Job '%s' cancelled successfully", jobName))
	packetData, err := packet.Serialize()
	if err != nil {
		h.logger.Error("failed to serialize packet", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Internal error", "Failed to serialize response")
	}
	return conn.WriteMessage(websocket.BinaryMessage, packetData)
}
// handlePrune handles OpcodePrune: it prunes stored experiments either by
// count or by age.
//
// Payload: [api_key_hash:64][prune_type:1][value:4], value big-endian.
// prune_type 0 passes value as the keep count; prune_type 1 passes value as
// the age in days; any other type is rejected. The API key hash is verified
// only when authentication is enabled. The returned error is the result of
// writing the response to conn.
func (h *WSHandler) handlePrune(conn *websocket.Conn, payload []byte) error {
	const (
		typeOffset  = 64
		valueOffset = typeOffset + 1
		payloadLen  = valueOffset + 4
	)

	if len(payload) < payloadLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "prune payload too short", "")
	}

	// The first 64 bytes carry the hex API key hash.
	keyHash := string(payload[:typeOffset])
	kind := payload[typeOffset]
	value := binary.BigEndian.Uint32(payload[valueOffset:payloadLen])

	h.logger.Info("prune request", "type", kind, "value", value)

	authRequired := h.authConfig != nil && h.authConfig.Enabled
	if authRequired {
		if verifyErr := h.verifyAPIKeyHash(keyHash); verifyErr != nil {
			h.logger.Error("api key verification failed", "error", verifyErr)
			return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Authentication failed", verifyErr.Error())
		}
	}

	// Exactly one of the two parameters is set; the other stays zero.
	var keepCount, olderThanDays int
	switch kind {
	case 0:
		keepCount = int(value)
	case 1:
		olderThanDays = int(value)
	default:
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, fmt.Sprintf("invalid prune type: %d", kind), "")
	}

	removed, err := h.expManager.PruneExperiments(keepCount, olderThanDays)
	if err != nil {
		h.logger.Error("prune failed", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeStorageError, "Prune operation failed", err.Error())
	}

	h.logger.Info("prune completed", "count", len(removed), "experiments", removed)

	summary := fmt.Sprintf("Pruned %d experiments", len(removed))
	return h.sendResponsePacket(conn, NewSuccessPacket(summary))
}
// handleLogMetric handles OpcodeLogMetric: it records one metric sample for
// an experiment.
//
// Payload: [api_key_hash:64][commit_id:64][step:4][value:8][name_len:1][name:var]
// step is a big-endian uint32 and value is the big-endian IEEE 754 bit
// pattern of a float64. The API key hash is verified only when authentication
// is enabled. The returned error is the result of writing the response to
// conn.
func (h *WSHandler) handleLogMetric(conn *websocket.Conn, payload []byte) error {
	const (
		commitOffset  = 64
		stepOffset    = 128
		valueOffset   = 132
		nameLenOffset = 140
		nameOffset    = 141
	)

	if len(payload) < nameOffset {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "log metric payload too short", "")
	}

	nameEnd := nameOffset + int(payload[nameLenOffset])
	if len(payload) < nameEnd {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid metric name length", "")
	}

	keyHash := string(payload[:commitOffset])
	commitID := string(payload[commitOffset:stepOffset])
	step := int(binary.BigEndian.Uint32(payload[stepOffset:valueOffset]))
	value := math.Float64frombits(binary.BigEndian.Uint64(payload[valueOffset:nameLenOffset]))
	name := string(payload[nameOffset:nameEnd])

	if h.authConfig != nil && h.authConfig.Enabled {
		if verifyErr := h.verifyAPIKeyHash(keyHash); verifyErr != nil {
			return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Authentication failed", verifyErr.Error())
		}
	}

	logErr := h.expManager.LogMetric(commitID, name, value, step)
	if logErr != nil {
		h.logger.Error("failed to log metric", "error", logErr)
		return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to log metric", logErr.Error())
	}

	return h.sendResponsePacket(conn, NewSuccessPacket("Metric logged"))
}
// handleGetExperiment handles OpcodeGetExperiment: it replies with the
// metadata and metrics stored for one experiment.
//
// Payload: [api_key_hash:64][commit_id:64]
// The API key hash is verified only when authentication is enabled. The
// success packet carries a payload with the keys "metadata" and "metrics".
// The returned error is the result of writing the response to conn.
func (h *WSHandler) handleGetExperiment(conn *websocket.Conn, payload []byte) error {
	const (
		hashLen    = 64
		payloadLen = 128
	)

	if len(payload) < payloadLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "get experiment payload too short", "")
	}

	keyHash := string(payload[:hashLen])
	commitID := string(payload[hashLen:payloadLen])

	if h.authConfig != nil && h.authConfig.Enabled {
		if verifyErr := h.verifyAPIKeyHash(keyHash); verifyErr != nil {
			return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Authentication failed", verifyErr.Error())
		}
	}

	// Metadata is read first; a failure there is reported as "not found".
	meta, metaErr := h.expManager.ReadMetadata(commitID)
	if metaErr != nil {
		return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "Experiment not found", metaErr.Error())
	}

	metrics, metricsErr := h.expManager.GetMetrics(commitID)
	if metricsErr != nil {
		return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to read metrics", metricsErr.Error())
	}

	details := make(map[string]interface{}, 2)
	details["metadata"] = meta
	details["metrics"] = metrics

	return h.sendResponsePacket(conn, NewSuccessPacketWithPayload("Experiment details", details))
}
// SetupTLSConfig returns an *http.Server pre-configured for TLS. Only
// ReadHeaderTimeout and TLSConfig are set; the caller is expected to fill in
// the remaining fields (address, handler, ...) before serving.
//
//   - If certFile and keyFile are both non-empty, the server gets a TLS 1.2+
//     configuration restricted to ECDHE AEAD cipher suites for both RSA and
//     ECDSA certificates. The files are not loaded here.
//   - Otherwise, if host is non-empty, certificates are obtained through
//     autocert for that single host and cached in /var/www/.cache.
//   - Otherwise no TLS setup is possible and the returned server is nil.
//     Callers must check for a nil server.
//
// The returned error is currently always nil.
func SetupTLSConfig(certFile, keyFile string, host string) (*http.Server, error) {
	switch {
	case certFile != "" && keyFile != "":
		// Use provided certificates
		return &http.Server{
			ReadHeaderTimeout: 10 * time.Second, // Prevent Slowloris attacks
			TLSConfig: &tls.Config{
				MinVersion: tls.VersionTLS12,
				// This list only applies to TLS 1.2. The ECDSA variants are
				// required for ECDSA certificates: with RSA-only suites such
				// a certificate cannot complete a TLS 1.2 handshake.
				CipherSuites: []uint16{
					tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
					tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
					tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
					tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
					tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
					tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
				},
			},
		}, nil

	case host != "":
		// Use Let's Encrypt with autocert
		certManager := &autocert.Manager{
			Prompt:     autocert.AcceptTOS,
			HostPolicy: autocert.HostWhitelist(host),
			Cache:      autocert.DirCache("/var/www/.cache"),
		}
		return &http.Server{
			ReadHeaderTimeout: 10 * time.Second, // Prevent Slowloris attacks
			TLSConfig:         certManager.TLSConfig(),
		}, nil

	default:
		// Neither certificates nor a host were supplied.
		return nil, nil
	}
}
// verifyAPIKeyHash checks the client-supplied API key hash against the hashes
// of the configured API keys.
//
// It returns nil when authentication is disabled or when hexHash matches a
// configured key, and an error otherwise. Only the length of hexHash (64) is
// validated; its characters are not checked for being hexadecimal.
func (h *WSHandler) verifyAPIKeyHash(hexHash string) error {
	if h.authConfig == nil || !h.authConfig.Enabled {
		return nil // No auth required
	}

	if len(hexHash) != 64 {
		return fmt.Errorf("invalid api key hash length")
	}

	// Compare in constant time and visit every entry without exiting early,
	// so response timing reveals neither how many leading bytes of a guess
	// were correct nor which entry matched.
	candidate := []byte(hexHash)
	matched := 0
	for _, entry := range h.authConfig.APIKeys {
		matched |= subtle.ConstantTimeCompare([]byte(string(entry.Hash)), candidate)
	}
	if matched != 1 {
		return fmt.Errorf("invalid api key")
	}
	return nil
}
// sendErrorPacket builds an error packet from the code, message and details
// and writes it to conn via sendResponsePacket, returning that call's result.
func (h *WSHandler) sendErrorPacket(conn *websocket.Conn, errorCode byte, message string, details string) error {
	return h.sendResponsePacket(conn, NewErrorPacket(errorCode, message, details))
}
// sendResponsePacket serializes packet and writes it to conn as one binary
// message. If serialization fails, the failure is logged and the generic
// two-byte error frame {0xFF, 0x00} is written instead. The returned error is
// the result of the write.
func (h *WSHandler) sendResponsePacket(conn *websocket.Conn, packet *ResponsePacket) error {
	frame, serializeErr := packet.Serialize()
	if serializeErr == nil {
		return conn.WriteMessage(websocket.BinaryMessage, frame)
	}

	h.logger.Error("failed to serialize response packet", "error", serializeErr)
	// Fall back to the minimal error frame so the peer still gets a reply.
	return conn.WriteMessage(websocket.BinaryMessage, []byte{0xFF, 0x00})
}
// sendErrorResponse removed (unused)