- Add API server with WebSocket support and REST endpoints
- Implement authentication system with API keys and permissions
- Add task queue system with Redis backend and error handling
- Include storage layer with database migrations and schemas
- Add comprehensive logging, metrics, and telemetry
- Implement security middleware and network utilities
- Add experiment management and container orchestration
- Include configuration management with smart defaults
606 lines
19 KiB
Go
606 lines
19 KiB
Go
package api
|
|
|
|
import (
	"crypto/sha256"
	"crypto/subtle"
	"crypto/tls"
	"encoding/binary"
	"encoding/hex"
	"fmt"
	"math"
	"net"
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/google/uuid"
	"github.com/gorilla/websocket"
	"github.com/jfraeys/fetch_ml/internal/auth"
	"github.com/jfraeys/fetch_ml/internal/experiment"
	"github.com/jfraeys/fetch_ml/internal/logging"
	"github.com/jfraeys/fetch_ml/internal/queue"
	"golang.org/x/crypto/acme/autocert"
)
|
|
|
|
// Opcodes of the binary WebSocket protocol. The opcode is the first byte of
// every binary frame; the bytes after it are the opcode-specific payload that
// WSHandler.handleMessage routes to the matching handler.
const (
	// OpcodeQueueJob asks the server to queue a job for a commit.
	OpcodeQueueJob = 0x01
	// OpcodeStatusRequest asks for the caller's view of the task queue.
	OpcodeStatusRequest = 0x02
	// OpcodeCancelJob asks the server to cancel a job by name.
	OpcodeCancelJob = 0x03
	// OpcodePrune asks the server to prune stored experiments.
	OpcodePrune = 0x04
	// OpcodeLogMetric records one metric value for an experiment.
	OpcodeLogMetric = 0x0A
	// OpcodeGetExperiment fetches an experiment's metadata and metrics.
	OpcodeGetExperiment = 0x0B
)
|
|
|
|
var upgrader = websocket.Upgrader{
|
|
CheckOrigin: func(r *http.Request) bool {
|
|
// Allow localhost and homelab origins for development
|
|
origin := r.Header.Get("Origin")
|
|
if origin == "" {
|
|
return true // Allow same-origin requests
|
|
}
|
|
|
|
// Parse origin URL
|
|
parsedOrigin, err := url.Parse(origin)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
|
|
// Allow localhost and local network origins
|
|
host := parsedOrigin.Host
|
|
return strings.HasSuffix(host, ":8080") ||
|
|
strings.HasSuffix(host, ":8081") ||
|
|
strings.HasPrefix(host, "localhost") ||
|
|
strings.HasPrefix(host, "127.0.0.1") ||
|
|
strings.HasPrefix(host, "192.168.") ||
|
|
strings.HasPrefix(host, "10.") ||
|
|
strings.HasPrefix(host, "172.")
|
|
},
|
|
}
|
|
|
|
type WSHandler struct {
|
|
authConfig *auth.AuthConfig
|
|
logger *logging.Logger
|
|
expManager *experiment.Manager
|
|
queue *queue.TaskQueue
|
|
}
|
|
|
|
func NewWSHandler(authConfig *auth.AuthConfig, logger *logging.Logger, expManager *experiment.Manager, taskQueue *queue.TaskQueue) *WSHandler {
|
|
return &WSHandler{
|
|
authConfig: authConfig,
|
|
logger: logger,
|
|
expManager: expManager,
|
|
queue: taskQueue,
|
|
}
|
|
}
|
|
|
|
func (h *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
// Check API key before upgrading WebSocket
|
|
apiKey := r.Header.Get("X-API-Key")
|
|
if apiKey == "" {
|
|
// Also check Authorization header
|
|
authHeader := r.Header.Get("Authorization")
|
|
if strings.HasPrefix(authHeader, "Bearer ") {
|
|
apiKey = strings.TrimPrefix(authHeader, "Bearer ")
|
|
}
|
|
}
|
|
|
|
// Validate API key if authentication is enabled
|
|
if h.authConfig != nil && h.authConfig.Enabled {
|
|
if _, err := h.authConfig.ValidateAPIKey(apiKey); err != nil {
|
|
h.logger.Warn("websocket authentication failed", "error", err)
|
|
http.Error(w, "Invalid API key", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
}
|
|
|
|
conn, err := upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
h.logger.Error("websocket upgrade failed", "error", err)
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
|
|
h.logger.Info("websocket connection established", "remote", r.RemoteAddr)
|
|
|
|
for {
|
|
messageType, message, err := conn.ReadMessage()
|
|
if err != nil {
|
|
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
|
h.logger.Error("websocket read error", "error", err)
|
|
}
|
|
break
|
|
}
|
|
|
|
if messageType != websocket.BinaryMessage {
|
|
h.logger.Warn("received non-binary message")
|
|
continue
|
|
}
|
|
|
|
if err := h.handleMessage(conn, message); err != nil {
|
|
h.logger.Error("message handling error", "error", err)
|
|
// Send error response
|
|
_ = conn.WriteMessage(websocket.BinaryMessage, []byte{0xFF, 0x00}) // Error opcode
|
|
}
|
|
}
|
|
}
|
|
|
|
func (h *WSHandler) handleMessage(conn *websocket.Conn, message []byte) error {
|
|
if len(message) < 1 {
|
|
return fmt.Errorf("message too short")
|
|
}
|
|
|
|
opcode := message[0]
|
|
payload := message[1:]
|
|
|
|
switch opcode {
|
|
case OpcodeQueueJob:
|
|
return h.handleQueueJob(conn, payload)
|
|
case OpcodeStatusRequest:
|
|
return h.handleStatusRequest(conn, payload)
|
|
case OpcodeCancelJob:
|
|
return h.handleCancelJob(conn, payload)
|
|
case OpcodePrune:
|
|
return h.handlePrune(conn, payload)
|
|
case OpcodeLogMetric:
|
|
return h.handleLogMetric(conn, payload)
|
|
case OpcodeGetExperiment:
|
|
return h.handleGetExperiment(conn, payload)
|
|
default:
|
|
return fmt.Errorf("unknown opcode: 0x%02x", opcode)
|
|
}
|
|
}
|
|
|
|
func (h *WSHandler) handleQueueJob(conn *websocket.Conn, payload []byte) error {
|
|
// Protocol: [api_key_hash:64][commit_id:64][priority:1][job_name_len:1][job_name:var]
|
|
if len(payload) < 130 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "queue job payload too short", "")
|
|
}
|
|
|
|
apiKeyHash := string(payload[:64])
|
|
commitID := string(payload[64:128])
|
|
priority := int64(payload[128])
|
|
jobNameLen := int(payload[129])
|
|
|
|
if len(payload) < 130+jobNameLen {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "")
|
|
}
|
|
|
|
jobName := string(payload[130 : 130+jobNameLen])
|
|
|
|
h.logger.Info("queue job request",
|
|
"job", jobName,
|
|
"priority", priority,
|
|
"commit_id", commitID,
|
|
)
|
|
|
|
// Validate API key and get user information
|
|
user, err := h.authConfig.ValidateAPIKey(apiKeyHash)
|
|
if err != nil {
|
|
h.logger.Error("invalid api key", "error", err)
|
|
return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error())
|
|
}
|
|
|
|
// Check user permissions
|
|
if !h.authConfig.Enabled || user.HasPermission("jobs:create") {
|
|
h.logger.Info("job queued", "job", jobName, "path", h.expManager.GetExperimentPath(commitID), "user", user.Name)
|
|
} else {
|
|
h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:create")
|
|
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions to create jobs", "")
|
|
}
|
|
|
|
// Create experiment directory and metadata
|
|
if err := h.expManager.CreateExperiment(commitID); err != nil {
|
|
h.logger.Error("failed to create experiment directory", "error", err)
|
|
return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to create experiment directory", err.Error())
|
|
}
|
|
|
|
// Add user info to experiment metadata
|
|
meta := &experiment.Metadata{
|
|
CommitID: commitID,
|
|
JobName: jobName,
|
|
User: user.Name,
|
|
Timestamp: time.Now().Unix(),
|
|
}
|
|
|
|
if err := h.expManager.WriteMetadata(meta); err != nil {
|
|
h.logger.Error("failed to save experiment metadata", "error", err)
|
|
return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to save experiment metadata", err.Error())
|
|
}
|
|
|
|
h.logger.Info("job queued", "job", jobName, "path", h.expManager.GetExperimentPath(commitID), "user", user.Name)
|
|
|
|
packet := NewSuccessPacket(fmt.Sprintf("Job '%s' queued successfully", jobName))
|
|
|
|
// Enqueue task if queue is available
|
|
if h.queue != nil {
|
|
taskID := uuid.New().String()
|
|
task := &queue.Task{
|
|
ID: taskID,
|
|
JobName: jobName,
|
|
Args: "", // TODO: Add args support
|
|
Status: "queued",
|
|
Priority: priority,
|
|
CreatedAt: time.Now(),
|
|
UserID: user.Name,
|
|
Username: user.Name,
|
|
CreatedBy: user.Name,
|
|
Metadata: map[string]string{
|
|
"commit_id": commitID,
|
|
"user_id": user.Name,
|
|
"username": user.Name,
|
|
},
|
|
}
|
|
|
|
if err := h.queue.AddTask(task); err != nil {
|
|
h.logger.Error("failed to enqueue task", "error", err)
|
|
return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue task", err.Error())
|
|
}
|
|
h.logger.Info("task enqueued", "task_id", taskID, "job", jobName, "user", user.Name)
|
|
} else {
|
|
h.logger.Warn("task queue not initialized, job not enqueued", "job", jobName)
|
|
}
|
|
|
|
packetData, err := packet.Serialize()
|
|
if err != nil {
|
|
h.logger.Error("failed to serialize packet", "error", err)
|
|
return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Internal error", "Failed to serialize response")
|
|
|
|
}
|
|
return conn.WriteMessage(websocket.BinaryMessage, packetData)
|
|
}
|
|
|
|
func (h *WSHandler) handleStatusRequest(conn *websocket.Conn, payload []byte) error {
|
|
// Protocol: [api_key_hash:64]
|
|
if len(payload) < 64 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "status request payload too short", "")
|
|
}
|
|
|
|
apiKeyHash := string(payload[0:64])
|
|
h.logger.Info("status request received", "api_key_hash", apiKeyHash[:16]+"...")
|
|
|
|
// Validate API key and get user information
|
|
user, err := h.authConfig.ValidateAPIKey(apiKeyHash)
|
|
if err != nil {
|
|
h.logger.Error("invalid api key", "error", err)
|
|
return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error())
|
|
}
|
|
|
|
// Check user permissions for viewing jobs
|
|
if !h.authConfig.Enabled || user.HasPermission("jobs:read") {
|
|
// Continue with status request
|
|
} else {
|
|
h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:read")
|
|
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions to view jobs", "")
|
|
}
|
|
|
|
// Get tasks with user filtering
|
|
var tasks []*queue.Task
|
|
if h.queue != nil {
|
|
allTasks, err := h.queue.GetAllTasks()
|
|
if err != nil {
|
|
h.logger.Error("failed to get tasks", "error", err)
|
|
return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to retrieve tasks", err.Error())
|
|
}
|
|
|
|
// Filter tasks based on user permissions
|
|
for _, task := range allTasks {
|
|
// If auth is disabled or admin can see all tasks
|
|
if !h.authConfig.Enabled || user.Admin {
|
|
tasks = append(tasks, task)
|
|
continue
|
|
}
|
|
|
|
// Users can only see their own tasks
|
|
if task.UserID == user.Name || task.CreatedBy == user.Name {
|
|
tasks = append(tasks, task)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Build status response with user-specific data
|
|
status := map[string]interface{}{
|
|
"user": map[string]interface{}{
|
|
"name": user.Name,
|
|
"admin": user.Admin,
|
|
"roles": user.Roles,
|
|
},
|
|
"tasks": map[string]interface{}{
|
|
"total": len(tasks),
|
|
"queued": countTasksByStatus(tasks, "queued"),
|
|
"running": countTasksByStatus(tasks, "running"),
|
|
"failed": countTasksByStatus(tasks, "failed"),
|
|
"completed": countTasksByStatus(tasks, "completed"),
|
|
},
|
|
"queue": tasks, // Include filtered tasks
|
|
}
|
|
|
|
packet := NewSuccessPacketWithPayload("Status retrieved", status)
|
|
packetData, err := packet.Serialize()
|
|
if err != nil {
|
|
h.logger.Error("failed to serialize packet", "error", err)
|
|
return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Internal error", "Failed to serialize response")
|
|
|
|
}
|
|
return conn.WriteMessage(websocket.BinaryMessage, packetData)
|
|
}
|
|
|
|
// countTasksByStatus counts tasks by their status
|
|
func countTasksByStatus(tasks []*queue.Task, status string) int {
|
|
count := 0
|
|
for _, task := range tasks {
|
|
if task.Status == status {
|
|
count++
|
|
}
|
|
}
|
|
return count
|
|
}
|
|
|
|
func (h *WSHandler) handleCancelJob(conn *websocket.Conn, payload []byte) error {
|
|
// Protocol: [api_key_hash:64][job_name_len:1][job_name:var]
|
|
if len(payload) < 65 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "cancel job payload too short", "")
|
|
}
|
|
|
|
// Parse 64-byte hex API key hash
|
|
apiKeyHash := string(payload[0:64])
|
|
jobNameLen := int(payload[64])
|
|
|
|
if len(payload) < 65+jobNameLen {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "")
|
|
}
|
|
|
|
jobName := string(payload[65 : 65+jobNameLen])
|
|
|
|
h.logger.Info("cancel job request", "job", jobName)
|
|
|
|
// Validate API key and get user information
|
|
user, err := h.authConfig.ValidateAPIKey(apiKeyHash)
|
|
if err != nil {
|
|
h.logger.Error("invalid api key", "error", err)
|
|
return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error())
|
|
}
|
|
|
|
// Check user permissions for canceling jobs
|
|
if !h.authConfig.Enabled || user.HasPermission("jobs:update") {
|
|
// Continue with cancel request
|
|
} else {
|
|
h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:update")
|
|
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions to cancel jobs", "")
|
|
}
|
|
|
|
// Find the task and verify ownership
|
|
if h.queue != nil {
|
|
task, err := h.queue.GetTaskByName(jobName)
|
|
if err != nil {
|
|
h.logger.Error("task not found", "job", jobName, "error", err)
|
|
return h.sendErrorPacket(conn, ErrorCodeJobNotFound, "Job not found", err.Error())
|
|
}
|
|
|
|
// Check if user can cancel this task (admin or owner)
|
|
if !h.authConfig.Enabled || user.Admin || task.UserID == user.Name || task.CreatedBy == user.Name {
|
|
// User can cancel the task
|
|
} else {
|
|
h.logger.Error("unauthorized job cancellation attempt", "user", user.Name, "job", jobName, "task_owner", task.UserID)
|
|
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "You can only cancel your own jobs", "")
|
|
}
|
|
|
|
// Cancel the task
|
|
if err := h.queue.CancelTask(task.ID); err != nil {
|
|
h.logger.Error("failed to cancel task", "job", jobName, "task_id", task.ID, "error", err)
|
|
return h.sendErrorPacket(conn, ErrorCodeJobExecutionFailed, "Failed to cancel job", err.Error())
|
|
}
|
|
|
|
h.logger.Info("job cancelled", "job", jobName, "task_id", task.ID, "user", user.Name)
|
|
} else {
|
|
h.logger.Warn("task queue not initialized, cannot cancel job", "job", jobName)
|
|
}
|
|
|
|
packet := NewSuccessPacket(fmt.Sprintf("Job '%s' cancelled successfully", jobName))
|
|
packetData, err := packet.Serialize()
|
|
if err != nil {
|
|
h.logger.Error("failed to serialize packet", "error", err)
|
|
return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Internal error", "Failed to serialize response")
|
|
|
|
}
|
|
return conn.WriteMessage(websocket.BinaryMessage, packetData)
|
|
}
|
|
|
|
func (h *WSHandler) handlePrune(conn *websocket.Conn, payload []byte) error {
|
|
// Protocol: [api_key_hash:64][prune_type:1][value:4]
|
|
if len(payload) < 69 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "prune payload too short", "")
|
|
}
|
|
|
|
// Parse 64-byte hex API key hash
|
|
apiKeyHash := string(payload[0:64])
|
|
pruneType := payload[64]
|
|
value := binary.BigEndian.Uint32(payload[65:69])
|
|
|
|
h.logger.Info("prune request", "type", pruneType, "value", value)
|
|
|
|
// Verify API key
|
|
if h.authConfig != nil && h.authConfig.Enabled {
|
|
if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
|
|
h.logger.Error("api key verification failed", "error", err)
|
|
return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Authentication failed", err.Error())
|
|
}
|
|
}
|
|
|
|
// Convert prune parameters
|
|
var keepCount int
|
|
var olderThanDays int
|
|
|
|
switch pruneType {
|
|
case 0:
|
|
// keep N
|
|
keepCount = int(value)
|
|
olderThanDays = 0
|
|
case 1:
|
|
// older than days
|
|
keepCount = 0
|
|
olderThanDays = int(value)
|
|
default:
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, fmt.Sprintf("invalid prune type: %d", pruneType), "")
|
|
}
|
|
|
|
// Perform pruning
|
|
pruned, err := h.expManager.PruneExperiments(keepCount, olderThanDays)
|
|
if err != nil {
|
|
h.logger.Error("prune failed", "error", err)
|
|
return h.sendErrorPacket(conn, ErrorCodeStorageError, "Prune operation failed", err.Error())
|
|
}
|
|
|
|
h.logger.Info("prune completed", "count", len(pruned), "experiments", pruned)
|
|
|
|
// Send structured success response
|
|
packet := NewSuccessPacket(fmt.Sprintf("Pruned %d experiments", len(pruned)))
|
|
return h.sendResponsePacket(conn, packet)
|
|
}
|
|
|
|
func (h *WSHandler) handleLogMetric(conn *websocket.Conn, payload []byte) error {
|
|
// Protocol: [api_key_hash:64][commit_id:64][step:4][value:8][name_len:1][name:var]
|
|
if len(payload) < 141 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "log metric payload too short", "")
|
|
}
|
|
|
|
apiKeyHash := string(payload[:64])
|
|
commitID := string(payload[64:128])
|
|
step := int(binary.BigEndian.Uint32(payload[128:132]))
|
|
valueBits := binary.BigEndian.Uint64(payload[132:140])
|
|
value := math.Float64frombits(valueBits)
|
|
nameLen := int(payload[140])
|
|
|
|
if len(payload) < 141+nameLen {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid metric name length", "")
|
|
}
|
|
|
|
name := string(payload[141 : 141+nameLen])
|
|
|
|
// Verify API key
|
|
if h.authConfig != nil && h.authConfig.Enabled {
|
|
if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Authentication failed", err.Error())
|
|
}
|
|
}
|
|
|
|
if err := h.expManager.LogMetric(commitID, name, value, step); err != nil {
|
|
h.logger.Error("failed to log metric", "error", err)
|
|
return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to log metric", err.Error())
|
|
}
|
|
|
|
return h.sendResponsePacket(conn, NewSuccessPacket("Metric logged"))
|
|
}
|
|
|
|
func (h *WSHandler) handleGetExperiment(conn *websocket.Conn, payload []byte) error {
|
|
// Protocol: [api_key_hash:64][commit_id:64]
|
|
if len(payload) < 128 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "get experiment payload too short", "")
|
|
}
|
|
|
|
apiKeyHash := string(payload[:64])
|
|
commitID := string(payload[64:128])
|
|
|
|
// Verify API key
|
|
if h.authConfig != nil && h.authConfig.Enabled {
|
|
if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Authentication failed", err.Error())
|
|
}
|
|
}
|
|
|
|
meta, err := h.expManager.ReadMetadata(commitID)
|
|
if err != nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "Experiment not found", err.Error())
|
|
}
|
|
|
|
metrics, err := h.expManager.GetMetrics(commitID)
|
|
if err != nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to read metrics", err.Error())
|
|
}
|
|
|
|
response := map[string]interface{}{
|
|
"metadata": meta,
|
|
"metrics": metrics,
|
|
}
|
|
|
|
return h.sendResponsePacket(conn, NewSuccessPacketWithPayload("Experiment details", response))
|
|
}
|
|
|
|
// HashAPIKey returns the lowercase hex encoding of the SHA-256 digest of
// apiKey. The result is always 64 characters, matching the size of the
// api_key_hash field in the binary WebSocket protocol.
func HashAPIKey(apiKey string) string {
	hasher := sha256.New()
	// hash.Hash.Write is documented never to return an error.
	_, _ = hasher.Write([]byte(apiKey))
	return hex.EncodeToString(hasher.Sum(nil))
}
|
|
|
|
// SetupTLSConfig creates TLS configuration for WebSocket server
|
|
func SetupTLSConfig(certFile, keyFile string, host string) (*http.Server, error) {
|
|
var server *http.Server
|
|
|
|
if certFile != "" && keyFile != "" {
|
|
// Use provided certificates
|
|
server = &http.Server{
|
|
TLSConfig: &tls.Config{
|
|
MinVersion: tls.VersionTLS12,
|
|
CipherSuites: []uint16{
|
|
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
|
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
|
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
|
|
},
|
|
},
|
|
}
|
|
} else if host != "" {
|
|
// Use Let's Encrypt with autocert
|
|
certManager := &autocert.Manager{
|
|
Prompt: autocert.AcceptTOS,
|
|
HostPolicy: autocert.HostWhitelist(host),
|
|
Cache: autocert.DirCache("/var/www/.cache"),
|
|
}
|
|
|
|
server = &http.Server{
|
|
TLSConfig: certManager.TLSConfig(),
|
|
}
|
|
}
|
|
|
|
return server, nil
|
|
}
|
|
|
|
// verifyAPIKeyHash verifies the provided hex hash against stored API keys
|
|
func (h *WSHandler) verifyAPIKeyHash(hexHash string) error {
|
|
if h.authConfig == nil || !h.authConfig.Enabled {
|
|
return nil // No auth required
|
|
}
|
|
|
|
// For now, just check if it's a valid 64-char hex string
|
|
if len(hexHash) != 64 {
|
|
return fmt.Errorf("invalid api key hash length")
|
|
}
|
|
|
|
// Check against stored API keys
|
|
for username, entry := range h.authConfig.APIKeys {
|
|
if string(entry.Hash) == hexHash {
|
|
_ = username // Username found but not needed for verification
|
|
return nil // Valid API key found
|
|
}
|
|
}
|
|
|
|
return fmt.Errorf("invalid api key")
|
|
}
|
|
|
|
// sendErrorPacket sends an error response packet
|
|
func (h *WSHandler) sendErrorPacket(conn *websocket.Conn, errorCode byte, message string, details string) error {
|
|
packet := NewErrorPacket(errorCode, message, details)
|
|
return h.sendResponsePacket(conn, packet)
|
|
}
|
|
|
|
// sendResponsePacket sends a structured response packet
|
|
func (h *WSHandler) sendResponsePacket(conn *websocket.Conn, packet *ResponsePacket) error {
|
|
data, err := packet.Serialize()
|
|
if err != nil {
|
|
h.logger.Error("failed to serialize response packet", "error", err)
|
|
// Fallback to simple error response
|
|
return conn.WriteMessage(websocket.BinaryMessage, []byte{0xFF, 0x00})
|
|
}
|
|
|
|
return conn.WriteMessage(websocket.BinaryMessage, data)
|
|
}
|
|
|
|
// sendErrorResponse removed (unused)
|