fetch_ml/internal/api/ws_jupyter.go

478 lines
15 KiB
Go

package api
import (
"encoding/binary"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/jfraeys/fetch_ml/internal/container"
"github.com/jfraeys/fetch_ml/internal/queue"
)
// jupyterTaskOutput is the JSON document decoded from the Output of a
// finished Jupyter task. Apart from Type, every field is optional and is
// omitted when empty on encoding.
type jupyterTaskOutput struct {
	// Type is the value of the "type" key.
	Type string `json:"type"`

	// Service holds the undecoded "service" value; the start handler decodes
	// it into a jupyterServiceView.
	Service json.RawMessage `json:"service,omitempty"`

	// Services holds the undecoded "services" value; the list handler
	// forwards it to the client without decoding it.
	Services json.RawMessage `json:"services,omitempty"`

	// RestorePath is the value of the "restore_path" key; the restore handler
	// includes it in its success message when it is non-blank.
	RestorePath string `json:"restore_path,omitempty"`
}
// handleRestoreJupyter serves a "restore Jupyter workspace" request.
//
// Payload layout: [api_key_hash:16][name_len:1][name:name_len].
//
// The caller is authenticated and must hold the "jupyter:manage" permission.
// A restore task is then enqueued for the resolved user and the handler waits
// up to two minutes for it to finish. Every outcome, success or failure, is
// reported to the client as a packet on conn; the returned error is whatever
// sending that packet returned.
func (h *WSHandler) handleRestoreJupyter(conn *websocket.Conn, payload []byte) error {
	// 16-byte key hash + 1 length byte + at least one byte of name.
	if len(payload) < 18 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "restore jupyter payload too short", "")
	}

	apiKeyHash := payload[:16]
	if h.authConfig != nil && h.authConfig.Enabled {
		if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
			return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Authentication failed", err.Error())
		}
	}
	user, err := h.validateWSUser(apiKeyHash)
	if err != nil {
		return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Authentication failed", err.Error())
	}
	if user == nil {
		// user.Name is needed below as the task owner; fail cleanly instead
		// of dereferencing a nil user.
		return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Authentication failed", "no user resolved for API key")
	}
	if !user.HasPermission("jupyter:manage") {
		return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions", "")
	}

	offset := 16
	nameLen := int(payload[offset])
	offset++
	if len(payload) < offset+nameLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid name length", "")
	}
	// Trim once so the metadata, the job name and the reply all agree.
	name := strings.TrimSpace(string(payload[offset : offset+nameLen]))
	if name == "" {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "workspace name is required", "")
	}

	meta := map[string]string{
		jupyterTaskActionKey: jupyterActionRestore,
		jupyterNameKey:       name,
	}
	jobName := fmt.Sprintf("jupyter-restore-%s", name)
	taskID, err := h.enqueueJupyterTask(user.Name, jobName, meta)
	if err != nil {
		h.logger.Error("failed to enqueue jupyter restore", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue Jupyter restore", "")
	}

	result, err := h.waitForTask(taskID, 2*time.Minute)
	if err != nil {
		h.logger.Error("failed waiting for jupyter restore", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "")
	}
	if result.Status != "completed" {
		return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to restore Jupyter workspace", strings.TrimSpace(result.Error))
	}

	msg := fmt.Sprintf("Restored Jupyter workspace '%s'", name)
	// The task output is optional JSON; when it names a restore path, include
	// it in the reply. Undecodable output is ignored.
	if out := strings.TrimSpace(result.Output); out != "" {
		var decoded jupyterTaskOutput
		if err := json.Unmarshal([]byte(out), &decoded); err == nil {
			if restorePath := strings.TrimSpace(decoded.RestorePath); restorePath != "" {
				msg = fmt.Sprintf("Restored Jupyter workspace '%s' to %s", name, restorePath)
			}
		}
	}
	return h.sendResponsePacket(conn, NewSuccessPacket(msg))
}
// jupyterServiceView is the subset of a service description that the start
// handler decodes from a task's "service" output.
type jupyterServiceView struct {
	// Name is the value of the "name" key.
	Name string `json:"name"`

	// URL is the value of the "url" key; when non-blank it is appended to the
	// start handler's success message.
	URL string `json:"url"`
}
// Metadata keys and values attached to queued Jupyter tasks.
const (
	// Entry that enqueueJupyterTask writes into every task's metadata.
	jupyterTaskTypeKey   = "task_type"
	jupyterTaskTypeValue = "jupyter"

	// Key selecting the requested action, followed by its possible values.
	jupyterTaskActionKey = "jupyter_action"
	jupyterActionStart   = "start"
	jupyterActionStop    = "stop"
	jupyterActionRemove  = "remove"
	jupyterActionRestore = "restore"
	jupyterActionList    = "list"

	// Keys carrying per-action arguments.
	jupyterNameKey      = "jupyter_name"
	jupyterWorkspaceKey = "jupyter_workspace"
	jupyterServiceIDKey = "jupyter_service_id"
)
// enqueueJupyterTask adds a Jupyter task to the queue on behalf of userName
// and returns the ID generated for it.
//
// An error is returned when no queue is configured, when jobName is rejected
// by container.ValidateJobName, when userName is blank, or when the queue
// refuses the task. The task-type entry is written into meta itself (a new
// map is used only when meta is nil), so a caller-supplied map is modified.
func (h *WSHandler) enqueueJupyterTask(userName, jobName string, meta map[string]string) (string, error) {
	if h.queue == nil {
		return "", fmt.Errorf("task queue not configured")
	}
	if nameErr := container.ValidateJobName(jobName); nameErr != nil {
		return "", nameErr
	}
	if len(strings.TrimSpace(userName)) == 0 {
		return "", fmt.Errorf("missing user")
	}

	if meta == nil {
		meta = map[string]string{}
	}
	meta[jupyterTaskTypeKey] = jupyterTaskTypeValue

	id := uuid.New().String()
	addErr := h.queue.AddTask(&queue.Task{
		ID:        id,
		JobName:   jobName,
		Args:      "",
		Status:    "queued",
		Priority:  100, // high priority; interactive request
		CreatedAt: time.Now(),
		UserID:    userName,
		Username:  userName,
		CreatedBy: userName,
		Metadata:  meta,
	})
	if addErr != nil {
		return "", addErr
	}
	return id, nil
}
// waitForTask polls the queue every 200ms until the task identified by taskID
// reports status "completed", "failed" or "cancelled", and returns that task.
//
// A lookup that yields no task is retried. A lookup error is returned
// immediately. If the deadline (now + timeout) passes first, a timeout error
// is returned. An error is also returned when no queue is configured.
func (h *WSHandler) waitForTask(taskID string, timeout time.Duration) (*queue.Task, error) {
	if h.queue == nil {
		return nil, fmt.Errorf("task queue not configured")
	}

	const pollInterval = 200 * time.Millisecond
	for deadline := time.Now().Add(timeout); !time.Now().After(deadline); time.Sleep(pollInterval) {
		task, err := h.queue.GetTask(taskID)
		if err != nil {
			return nil, err
		}
		if task == nil {
			// Not visible yet; sleep and look again.
			continue
		}
		switch task.Status {
		case "completed", "failed", "cancelled":
			return task, nil
		}
	}
	return nil, fmt.Errorf("timed out waiting for worker")
}
// handleStartJupyter serves a "start Jupyter service" request.
//
// Payload layout:
//
//	[api_key_hash:16][name_len:1][name][workspace_len:2, big-endian][workspace][password_len:1][password]
//
// The caller is authenticated and must hold the "jupyter:manage" permission.
// A start task carrying the name and workspace is then enqueued for the
// resolved user and the handler waits up to two minutes for it to finish.
// Every outcome is reported to the client as a packet on conn; the returned
// error is whatever sending that packet returned.
func (h *WSHandler) handleStartJupyter(conn *websocket.Conn, payload []byte) error {
	// 16-byte key hash + three length prefixes (1+2+1) + at least one byte of name.
	if len(payload) < 21 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "start jupyter payload too short", "")
	}

	apiKeyHash := payload[:16]
	if h.authConfig != nil && h.authConfig.Enabled {
		if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
			return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Authentication failed", err.Error())
		}
	}
	user, err := h.validateWSUser(apiKeyHash)
	if err != nil {
		return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Authentication failed", err.Error())
	}
	if user == nil {
		// user.Name is needed below as the task owner; fail cleanly instead
		// of dereferencing a nil user.
		return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Authentication failed", "no user resolved for API key")
	}
	if !user.HasPermission("jupyter:manage") {
		return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions", "")
	}

	offset := 16
	nameLen := int(payload[offset])
	offset++
	// The name must be followed by the 2-byte workspace length.
	if len(payload) < offset+nameLen+2 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid name length", "")
	}
	// Trim once so the metadata, the job name and the reply all agree.
	name := strings.TrimSpace(string(payload[offset : offset+nameLen]))
	offset += nameLen
	if name == "" {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "service name is required", "")
	}

	workspaceLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
	offset += 2
	// The workspace must be followed by the 1-byte password length.
	if len(payload) < offset+workspaceLen+1 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid workspace length", "")
	}
	workspace := strings.TrimSpace(string(payload[offset : offset+workspaceLen]))
	offset += workspaceLen

	// The password field is only checked for correct framing; its bytes are
	// not read and are not forwarded in the task metadata.
	passwordLen := int(payload[offset])
	offset++
	if len(payload) < offset+passwordLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid password length", "")
	}

	meta := map[string]string{
		jupyterTaskActionKey: jupyterActionStart,
		jupyterNameKey:       name,
		jupyterWorkspaceKey:  workspace,
	}
	jobName := fmt.Sprintf("jupyter-%s", name)
	taskID, err := h.enqueueJupyterTask(user.Name, jobName, meta)
	if err != nil {
		h.logger.Error("failed to enqueue jupyter task", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue Jupyter task", "")
	}

	result, err := h.waitForTask(taskID, 2*time.Minute)
	if err != nil {
		h.logger.Error("failed waiting for jupyter start", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "")
	}
	if result.Status != "completed" {
		h.logger.Error("jupyter task failed", "error", result.Error)
		details := strings.TrimSpace(result.Error)
		lower := strings.ToLower(details)
		// Report name/port collisions with a dedicated error code.
		if strings.Contains(lower, "already exists") || strings.Contains(lower, "already in use") {
			return h.sendErrorPacket(conn, ErrorCodeResourceAlreadyExists, "Jupyter workspace already exists", details)
		}
		return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to start Jupyter service", details)
	}

	msg := fmt.Sprintf("Started Jupyter service '%s'", name)
	// The task output is optional JSON; when it carries a service URL, include
	// it in the reply. Undecodable output is ignored.
	if out := strings.TrimSpace(result.Output); out != "" {
		var decoded jupyterTaskOutput
		if err := json.Unmarshal([]byte(out), &decoded); err == nil && len(decoded.Service) > 0 {
			var svc jupyterServiceView
			if err := json.Unmarshal(decoded.Service, &svc); err == nil {
				if url := strings.TrimSpace(svc.URL); url != "" {
					msg = fmt.Sprintf("Started Jupyter service '%s' at %s", name, url)
				}
			}
		}
	}
	return h.sendResponsePacket(conn, NewSuccessPacket(msg))
}
// handleStopJupyter serves a "stop Jupyter service" request.
//
// Payload layout: [api_key_hash:16][service_id_len:1][service_id:service_id_len].
//
// The caller is authenticated and must hold the "jupyter:manage" permission.
// A stop task is then enqueued for the resolved user and the handler waits up
// to two minutes for it to finish. Every outcome is reported to the client as
// a packet on conn; the returned error is whatever sending that packet
// returned.
func (h *WSHandler) handleStopJupyter(conn *websocket.Conn, payload []byte) error {
	// 16-byte key hash + 1 length byte + at least one byte of service id.
	if len(payload) < 18 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "stop jupyter payload too short", "")
	}

	apiKeyHash := payload[:16]
	if h.authConfig != nil && h.authConfig.Enabled {
		if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
			return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Authentication failed", err.Error())
		}
	}
	user, err := h.validateWSUser(apiKeyHash)
	if err != nil {
		return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Authentication failed", err.Error())
	}
	if user == nil {
		// user.Name is needed below as the task owner; fail cleanly instead
		// of dereferencing a nil user.
		return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Authentication failed", "no user resolved for API key")
	}
	if !user.HasPermission("jupyter:manage") {
		return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions", "")
	}

	offset := 16
	idLen := int(payload[offset])
	offset++
	if len(payload) < offset+idLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid service id length", "")
	}
	// Trim once so the metadata, the job name and the reply all agree.
	serviceID := strings.TrimSpace(string(payload[offset : offset+idLen]))
	if serviceID == "" {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "service id is required", "")
	}

	meta := map[string]string{
		jupyterTaskActionKey: jupyterActionStop,
		jupyterServiceIDKey:  serviceID,
	}
	jobName := fmt.Sprintf("jupyter-stop-%s", serviceID)
	taskID, err := h.enqueueJupyterTask(user.Name, jobName, meta)
	if err != nil {
		h.logger.Error("failed to enqueue jupyter stop", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue Jupyter stop", "")
	}

	result, err := h.waitForTask(taskID, 2*time.Minute)
	if err != nil {
		h.logger.Error("failed waiting for jupyter stop", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "")
	}
	if result.Status != "completed" {
		return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to stop Jupyter service", strings.TrimSpace(result.Error))
	}
	return h.sendResponsePacket(conn, NewSuccessPacket(fmt.Sprintf("Stopped Jupyter service %s", serviceID)))
}
// handleRemoveJupyter serves a "remove Jupyter service" request.
//
// Payload layout:
//
//	[api_key_hash:16][service_id_len:1][service_id:service_id_len][purge:1, optional]
//
// The purge flag is true only when the optional trailing byte is 0x01; when it
// is absent or has any other value the task is enqueued with purge=false.
//
// The caller is authenticated and must hold the "jupyter:manage" permission
// before the rest of the payload is interpreted. A remove task is then
// enqueued for the resolved user and the handler waits up to two minutes for
// it to finish. Every outcome is reported to the client as a packet on conn;
// the returned error is whatever sending that packet returned.
func (h *WSHandler) handleRemoveJupyter(conn *websocket.Conn, payload []byte) error {
	// 16-byte key hash + 1 length byte + at least one byte of service id.
	if len(payload) < 18 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "remove jupyter payload too short", "")
	}

	// Authenticate and authorise first, matching the other Jupyter handlers,
	// so unauthenticated callers get no feedback about payload parsing.
	apiKeyHash := payload[:16]
	if h.authConfig != nil && h.authConfig.Enabled {
		if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
			return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Authentication failed", err.Error())
		}
	}
	user, err := h.validateWSUser(apiKeyHash)
	if err != nil {
		return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Authentication failed", err.Error())
	}
	if user == nil {
		// user.Name is needed below as the task owner; fail cleanly instead
		// of dereferencing a nil user.
		return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Authentication failed", "no user resolved for API key")
	}
	if !user.HasPermission("jupyter:manage") {
		return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions", "")
	}

	offset := 16
	idLen := int(payload[offset])
	offset++
	if len(payload) < offset+idLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid service id length", "")
	}
	// Trim once so the metadata, the job name and the reply all agree.
	serviceID := strings.TrimSpace(string(payload[offset : offset+idLen]))
	offset += idLen
	if serviceID == "" {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "service id is required", "")
	}

	// Optional: purge flag (1 byte). Default false for trash-first behavior.
	purge := len(payload) > offset && payload[offset] == 0x01

	meta := map[string]string{
		jupyterTaskActionKey: jupyterActionRemove,
		jupyterServiceIDKey:  serviceID,
		"jupyter_purge":      fmt.Sprintf("%t", purge),
	}
	jobName := fmt.Sprintf("jupyter-remove-%s", serviceID)
	taskID, err := h.enqueueJupyterTask(user.Name, jobName, meta)
	if err != nil {
		h.logger.Error("failed to enqueue jupyter remove", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue Jupyter remove", "")
	}

	result, err := h.waitForTask(taskID, 2*time.Minute)
	if err != nil {
		h.logger.Error("failed waiting for jupyter remove", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "")
	}
	if result.Status != "completed" {
		return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to remove Jupyter service", strings.TrimSpace(result.Error))
	}
	return h.sendResponsePacket(conn, NewSuccessPacket(fmt.Sprintf("Removed Jupyter service %s", serviceID)))
}
// handleListJupyter serves a "list Jupyter services" request.
//
// Payload layout: [api_key_hash:16].
//
// The caller is authenticated and must hold the "jupyter:read" permission. A
// list task is then enqueued for the resolved user and the handler waits up
// to two minutes for it to finish. On success the client receives a
// "jupyter_services" data packet containing the task's "services" JSON; when
// the task produced no output, undecodable output, or a missing/null
// "services" value, the packet contains an empty JSON array instead. Failures
// are reported as error packets. The returned error is whatever sending the
// packet returned.
func (h *WSHandler) handleListJupyter(conn *websocket.Conn, payload []byte) error {
	if len(payload) < 16 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "list jupyter payload too short", "")
	}

	apiKeyHash := payload[:16]
	if h.authConfig != nil && h.authConfig.Enabled {
		if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
			return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Authentication failed", err.Error())
		}
	}
	user, err := h.validateWSUser(apiKeyHash)
	if err != nil {
		return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Authentication failed", err.Error())
	}
	if user == nil {
		// user.Name is needed below for the job name and as the task owner;
		// fail cleanly instead of dereferencing a nil user.
		return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Authentication failed", "no user resolved for API key")
	}
	if !user.HasPermission("jupyter:read") {
		return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions", "")
	}

	meta := map[string]string{
		jupyterTaskActionKey: jupyterActionList,
	}
	jobName := fmt.Sprintf("jupyter-list-%s", user.Name)
	taskID, err := h.enqueueJupyterTask(user.Name, jobName, meta)
	if err != nil {
		h.logger.Error("failed to enqueue jupyter list", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue Jupyter list", "")
	}

	result, err := h.waitForTask(taskID, 2*time.Minute)
	if err != nil {
		h.logger.Error("failed waiting for jupyter list", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "")
	}
	if result.Status != "completed" {
		return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to list Jupyter services", strings.TrimSpace(result.Error))
	}

	// Always answer with an array (even if empty) so clients can render a
	// stable table.
	out := strings.TrimSpace(result.Output)
	if out == "" {
		return h.sendResponsePacket(conn, NewDataPacket("jupyter_services", []byte("[]")))
	}
	var decoded jupyterTaskOutput
	if err := json.Unmarshal([]byte(out), &decoded); err != nil {
		// Fallback: return empty array on unexpected output.
		return h.sendResponsePacket(conn, NewDataPacket("jupyter_services", []byte("[]")))
	}
	services := decoded.Services
	// json.RawMessage stores an explicit JSON null as the literal "null", so
	// treat that the same as a missing field.
	if raw := strings.TrimSpace(string(services)); raw == "" || raw == "null" {
		services = []byte("[]")
	}
	return h.sendResponsePacket(conn, NewDataPacket("jupyter_services", services))
}