478 lines
15 KiB
Go
478 lines
15 KiB
Go
package api
|
|
|
|
import (
|
|
"encoding/binary"
|
|
"encoding/json"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/gorilla/websocket"
|
|
"github.com/jfraeys/fetch_ml/internal/container"
|
|
"github.com/jfraeys/fetch_ml/internal/queue"
|
|
)
|
|
|
|
// jupyterTaskOutput is the JSON document a worker leaves in a jupyter task's
// Output field. Which fields are populated depends on the action: handlers in
// this file read Service (start), Services (list) and RestorePath (restore).
// Service and Services stay raw so each handler decides how to decode or
// forward them.
type jupyterTaskOutput struct {
	// Type is written by the worker; presumably it names the output kind
	// (not read by the handlers in this file).
	Type string `json:"type"`
	// Service holds a single service object (decoded as jupyterServiceView).
	Service json.RawMessage `json:"service,omitempty"`
	// Services is forwarded verbatim to clients by the list handler.
	Services json.RawMessage `json:"services,omitempty"`
	// RestorePath is where a restored workspace was placed, if reported.
	RestorePath string `json:"restore_path,omitempty"`
}
|
|
|
|
func (h *WSHandler) handleRestoreJupyter(conn *websocket.Conn, payload []byte) error {
|
|
// Protocol: [api_key_hash:16][name_len:1][name:var]
|
|
if len(payload) < 18 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "restore jupyter payload too short", "")
|
|
}
|
|
|
|
apiKeyHash := payload[:16]
|
|
|
|
if h.authConfig != nil && h.authConfig.Enabled {
|
|
if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
|
|
return h.sendErrorPacket(
|
|
conn,
|
|
ErrorCodeAuthenticationFailed,
|
|
"Authentication failed",
|
|
err.Error(),
|
|
)
|
|
}
|
|
}
|
|
user, err := h.validateWSUser(apiKeyHash)
|
|
if err != nil {
|
|
return h.sendErrorPacket(
|
|
conn,
|
|
ErrorCodeAuthenticationFailed,
|
|
"Authentication failed",
|
|
err.Error(),
|
|
)
|
|
}
|
|
if user != nil && !user.HasPermission("jupyter:manage") {
|
|
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions", "")
|
|
}
|
|
|
|
offset := 16
|
|
nameLen := int(payload[offset])
|
|
offset++
|
|
if len(payload) < offset+nameLen {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid name length", "")
|
|
}
|
|
name := string(payload[offset : offset+nameLen])
|
|
|
|
meta := map[string]string{
|
|
jupyterTaskActionKey: jupyterActionRestore,
|
|
jupyterNameKey: strings.TrimSpace(name),
|
|
}
|
|
jobName := fmt.Sprintf("jupyter-restore-%s", strings.TrimSpace(name))
|
|
taskID, err := h.enqueueJupyterTask(user.Name, jobName, meta)
|
|
if err != nil {
|
|
h.logger.Error("failed to enqueue jupyter restore", "error", err)
|
|
return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue Jupyter restore", "")
|
|
}
|
|
|
|
result, err := h.waitForTask(taskID, 2*time.Minute)
|
|
if err != nil {
|
|
h.logger.Error("failed waiting for jupyter restore", "error", err)
|
|
return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "")
|
|
}
|
|
if result.Status != "completed" {
|
|
return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to restore Jupyter workspace", strings.TrimSpace(result.Error))
|
|
}
|
|
|
|
msg := fmt.Sprintf("Restored Jupyter workspace '%s'", strings.TrimSpace(name))
|
|
out := strings.TrimSpace(result.Output)
|
|
if out != "" {
|
|
var payloadOut jupyterTaskOutput
|
|
if err := json.Unmarshal([]byte(out), &payloadOut); err == nil {
|
|
if strings.TrimSpace(payloadOut.RestorePath) != "" {
|
|
msg = fmt.Sprintf("Restored Jupyter workspace '%s' to %s", strings.TrimSpace(name), strings.TrimSpace(payloadOut.RestorePath))
|
|
}
|
|
}
|
|
}
|
|
|
|
return h.sendResponsePacket(conn, NewSuccessPacket(msg))
|
|
}
|
|
|
|
// jupyterServiceView is the part of a worker-reported Jupyter service that the
// API decodes; handleStartJupyter uses URL to build its success message.
type jupyterServiceView struct {
	Name string `json:"name"` // service name as reported by the worker
	URL  string `json:"url"`  // service URL as reported by the worker
}
|
|
|
|
// Metadata key/value pair that marks a queued task as a Jupyter task.
const (
	jupyterTaskTypeKey   = "task_type"
	jupyterTaskTypeValue = "jupyter"
)

// jupyterTaskActionKey is the metadata key naming the Jupyter operation a task
// requests; the jupyterAction* constants are the values handlers set for it.
const (
	jupyterTaskActionKey = "jupyter_action"

	jupyterActionStart   = "start"
	jupyterActionStop    = "stop"
	jupyterActionRemove  = "remove"
	jupyterActionRestore = "restore"
	jupyterActionList    = "list"
)

// Metadata keys carrying the per-action arguments of a Jupyter task.
const (
	jupyterNameKey      = "jupyter_name"
	jupyterWorkspaceKey = "jupyter_workspace"
	jupyterServiceIDKey = "jupyter_service_id"
)
|
|
|
|
func (h *WSHandler) enqueueJupyterTask(userName, jobName string, meta map[string]string) (string, error) {
|
|
if h.queue == nil {
|
|
return "", fmt.Errorf("task queue not configured")
|
|
}
|
|
if err := container.ValidateJobName(jobName); err != nil {
|
|
return "", err
|
|
}
|
|
if strings.TrimSpace(userName) == "" {
|
|
return "", fmt.Errorf("missing user")
|
|
}
|
|
if meta == nil {
|
|
meta = make(map[string]string)
|
|
}
|
|
meta[jupyterTaskTypeKey] = jupyterTaskTypeValue
|
|
|
|
taskID := uuid.New().String()
|
|
task := &queue.Task{
|
|
ID: taskID,
|
|
JobName: jobName,
|
|
Args: "",
|
|
Status: "queued",
|
|
Priority: 100, // high priority; interactive request
|
|
CreatedAt: time.Now(),
|
|
UserID: userName,
|
|
Username: userName,
|
|
CreatedBy: userName,
|
|
Metadata: meta,
|
|
}
|
|
if err := h.queue.AddTask(task); err != nil {
|
|
return "", err
|
|
}
|
|
return taskID, nil
|
|
}
|
|
|
|
func (h *WSHandler) waitForTask(taskID string, timeout time.Duration) (*queue.Task, error) {
|
|
if h.queue == nil {
|
|
return nil, fmt.Errorf("task queue not configured")
|
|
}
|
|
deadline := time.Now().Add(timeout)
|
|
for {
|
|
if time.Now().After(deadline) {
|
|
return nil, fmt.Errorf("timed out waiting for worker")
|
|
}
|
|
t, err := h.queue.GetTask(taskID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if t == nil {
|
|
time.Sleep(200 * time.Millisecond)
|
|
continue
|
|
}
|
|
if t.Status == "completed" || t.Status == "failed" || t.Status == "cancelled" {
|
|
return t, nil
|
|
}
|
|
time.Sleep(200 * time.Millisecond)
|
|
}
|
|
}
|
|
|
|
func (h *WSHandler) handleStartJupyter(conn *websocket.Conn, payload []byte) error {
|
|
// Protocol:
|
|
// [api_key_hash:16][name][workspace][password]
|
|
if len(payload) < 21 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "start jupyter payload too short", "")
|
|
}
|
|
|
|
apiKeyHash := payload[:16]
|
|
|
|
// Verify API key
|
|
if h.authConfig != nil && h.authConfig.Enabled {
|
|
if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
|
|
return h.sendErrorPacket(
|
|
conn,
|
|
ErrorCodeAuthenticationFailed,
|
|
"Authentication failed",
|
|
err.Error(),
|
|
)
|
|
}
|
|
}
|
|
user, err := h.validateWSUser(apiKeyHash)
|
|
if err != nil {
|
|
return h.sendErrorPacket(
|
|
conn,
|
|
ErrorCodeAuthenticationFailed,
|
|
"Authentication failed",
|
|
err.Error(),
|
|
)
|
|
}
|
|
if user != nil && !user.HasPermission("jupyter:manage") {
|
|
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions", "")
|
|
}
|
|
|
|
offset := 16
|
|
nameLen := int(payload[offset])
|
|
offset++
|
|
|
|
if len(payload) < offset+nameLen+2 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid name length", "")
|
|
}
|
|
|
|
name := string(payload[offset : offset+nameLen])
|
|
offset += nameLen
|
|
|
|
workspaceLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
|
|
offset += 2
|
|
|
|
if len(payload) < offset+workspaceLen+1 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid workspace length", "")
|
|
}
|
|
|
|
workspace := string(payload[offset : offset+workspaceLen])
|
|
offset += workspaceLen
|
|
|
|
passwordLen := int(payload[offset])
|
|
offset++
|
|
|
|
if len(payload) < offset+passwordLen {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid password length", "")
|
|
}
|
|
|
|
// Password is parsed but not used in StartRequest
|
|
// offset += passwordLen (already advanced during parsing)
|
|
|
|
meta := map[string]string{
|
|
jupyterTaskActionKey: jupyterActionStart,
|
|
jupyterNameKey: strings.TrimSpace(name),
|
|
jupyterWorkspaceKey: strings.TrimSpace(workspace),
|
|
}
|
|
jobName := fmt.Sprintf("jupyter-%s", strings.TrimSpace(name))
|
|
taskID, err := h.enqueueJupyterTask(user.Name, jobName, meta)
|
|
if err != nil {
|
|
h.logger.Error("failed to enqueue jupyter task", "error", err)
|
|
return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue Jupyter task", "")
|
|
}
|
|
|
|
result, err := h.waitForTask(taskID, 2*time.Minute)
|
|
if err != nil {
|
|
h.logger.Error("failed waiting for jupyter start", "error", err)
|
|
return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "")
|
|
}
|
|
if result.Status != "completed" {
|
|
h.logger.Error("jupyter task failed", "error", result.Error)
|
|
details := strings.TrimSpace(result.Error)
|
|
lower := strings.ToLower(details)
|
|
if strings.Contains(lower, "already exists") || strings.Contains(lower, "already in use") {
|
|
return h.sendErrorPacket(conn, ErrorCodeResourceAlreadyExists, "Jupyter workspace already exists", details)
|
|
}
|
|
return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to start Jupyter service", details)
|
|
}
|
|
|
|
msg := fmt.Sprintf("Started Jupyter service '%s'", strings.TrimSpace(name))
|
|
out := strings.TrimSpace(result.Output)
|
|
if out != "" {
|
|
var payloadOut jupyterTaskOutput
|
|
if err := json.Unmarshal([]byte(out), &payloadOut); err == nil && len(payloadOut.Service) > 0 {
|
|
var svc jupyterServiceView
|
|
if err := json.Unmarshal(payloadOut.Service, &svc); err == nil {
|
|
if strings.TrimSpace(svc.URL) != "" {
|
|
msg = fmt.Sprintf("Started Jupyter service '%s' at %s", strings.TrimSpace(name), strings.TrimSpace(svc.URL))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return h.sendResponsePacket(conn, NewSuccessPacket(msg))
|
|
}
|
|
|
|
func (h *WSHandler) handleStopJupyter(conn *websocket.Conn, payload []byte) error {
|
|
// Protocol: [api_key_hash:16][service_id_len:1][service_id:var]
|
|
if len(payload) < 18 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "stop jupyter payload too short", "")
|
|
}
|
|
|
|
apiKeyHash := payload[:16]
|
|
|
|
if h.authConfig != nil && h.authConfig.Enabled {
|
|
if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
|
|
return h.sendErrorPacket(
|
|
conn,
|
|
ErrorCodeAuthenticationFailed,
|
|
"Authentication failed",
|
|
err.Error(),
|
|
)
|
|
}
|
|
}
|
|
user, err := h.validateWSUser(apiKeyHash)
|
|
if err != nil {
|
|
return h.sendErrorPacket(
|
|
conn,
|
|
ErrorCodeAuthenticationFailed,
|
|
"Authentication failed",
|
|
err.Error(),
|
|
)
|
|
}
|
|
if user != nil && !user.HasPermission("jupyter:manage") {
|
|
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions", "")
|
|
}
|
|
|
|
offset := 16
|
|
idLen := int(payload[offset])
|
|
offset++
|
|
|
|
if len(payload) < offset+idLen {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid service id length", "")
|
|
}
|
|
|
|
serviceID := string(payload[offset : offset+idLen])
|
|
|
|
meta := map[string]string{
|
|
jupyterTaskActionKey: jupyterActionStop,
|
|
jupyterServiceIDKey: strings.TrimSpace(serviceID),
|
|
}
|
|
jobName := fmt.Sprintf("jupyter-stop-%s", strings.TrimSpace(serviceID))
|
|
taskID, err := h.enqueueJupyterTask(user.Name, jobName, meta)
|
|
if err != nil {
|
|
h.logger.Error("failed to enqueue jupyter stop", "error", err)
|
|
return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue Jupyter stop", "")
|
|
}
|
|
result, err := h.waitForTask(taskID, 2*time.Minute)
|
|
if err != nil {
|
|
h.logger.Error("failed waiting for jupyter stop", "error", err)
|
|
return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "")
|
|
}
|
|
if result.Status != "completed" {
|
|
return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to stop Jupyter service", strings.TrimSpace(result.Error))
|
|
}
|
|
return h.sendResponsePacket(conn, NewSuccessPacket(fmt.Sprintf("Stopped Jupyter service %s", serviceID)))
|
|
}
|
|
|
|
func (h *WSHandler) handleRemoveJupyter(conn *websocket.Conn, payload []byte) error {
|
|
// Protocol: [api_key_hash:16][service_id_len:1][service_id:var]
|
|
if len(payload) < 18 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "remove jupyter payload too short", "")
|
|
}
|
|
|
|
apiKeyHash := payload[:16]
|
|
|
|
offset := 16
|
|
idLen := int(payload[offset])
|
|
offset++
|
|
if len(payload) < offset+idLen {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid service id length", "")
|
|
}
|
|
serviceID := string(payload[offset : offset+idLen])
|
|
offset += idLen
|
|
|
|
// Optional: purge flag (1 byte). Default false for trash-first behavior.
|
|
purge := false
|
|
if len(payload) > offset {
|
|
purge = payload[offset] == 0x01
|
|
}
|
|
|
|
if h.authConfig != nil && h.authConfig.Enabled {
|
|
if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
|
|
return h.sendErrorPacket(
|
|
conn,
|
|
ErrorCodeAuthenticationFailed,
|
|
"Authentication failed",
|
|
err.Error(),
|
|
)
|
|
}
|
|
}
|
|
user, err := h.validateWSUser(apiKeyHash)
|
|
if err != nil {
|
|
return h.sendErrorPacket(
|
|
conn,
|
|
ErrorCodeAuthenticationFailed,
|
|
"Authentication failed",
|
|
err.Error(),
|
|
)
|
|
}
|
|
if user != nil && !user.HasPermission("jupyter:manage") {
|
|
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions", "")
|
|
}
|
|
|
|
meta := map[string]string{
|
|
jupyterTaskActionKey: jupyterActionRemove,
|
|
jupyterServiceIDKey: strings.TrimSpace(serviceID),
|
|
"jupyter_purge": fmt.Sprintf("%t", purge),
|
|
}
|
|
jobName := fmt.Sprintf("jupyter-remove-%s", strings.TrimSpace(serviceID))
|
|
taskID, err := h.enqueueJupyterTask(user.Name, jobName, meta)
|
|
if err != nil {
|
|
h.logger.Error("failed to enqueue jupyter remove", "error", err)
|
|
return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue Jupyter remove", "")
|
|
}
|
|
|
|
result, err := h.waitForTask(taskID, 2*time.Minute)
|
|
if err != nil {
|
|
h.logger.Error("failed waiting for jupyter remove", "error", err)
|
|
return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "")
|
|
}
|
|
if result.Status != "completed" {
|
|
return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to remove Jupyter service", strings.TrimSpace(result.Error))
|
|
}
|
|
return h.sendResponsePacket(conn, NewSuccessPacket(fmt.Sprintf("Removed Jupyter service %s", serviceID)))
|
|
}
|
|
|
|
func (h *WSHandler) handleListJupyter(conn *websocket.Conn, payload []byte) error {
|
|
// Protocol: [api_key_hash:16]
|
|
if len(payload) < 16 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "list jupyter payload too short", "")
|
|
}
|
|
|
|
apiKeyHash := payload[:16]
|
|
|
|
if h.authConfig != nil && h.authConfig.Enabled {
|
|
if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
|
|
return h.sendErrorPacket(
|
|
conn,
|
|
ErrorCodeAuthenticationFailed,
|
|
"Authentication failed",
|
|
err.Error(),
|
|
)
|
|
}
|
|
}
|
|
user, err := h.validateWSUser(apiKeyHash)
|
|
if err != nil {
|
|
return h.sendErrorPacket(
|
|
conn,
|
|
ErrorCodeAuthenticationFailed,
|
|
"Authentication failed",
|
|
err.Error(),
|
|
)
|
|
}
|
|
if user != nil && !user.HasPermission("jupyter:read") {
|
|
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions", "")
|
|
}
|
|
|
|
meta := map[string]string{
|
|
jupyterTaskActionKey: jupyterActionList,
|
|
}
|
|
jobName := fmt.Sprintf("jupyter-list-%s", user.Name)
|
|
taskID, err := h.enqueueJupyterTask(user.Name, jobName, meta)
|
|
if err != nil {
|
|
h.logger.Error("failed to enqueue jupyter list", "error", err)
|
|
return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue Jupyter list", "")
|
|
}
|
|
result, err := h.waitForTask(taskID, 2*time.Minute)
|
|
if err != nil {
|
|
h.logger.Error("failed waiting for jupyter list", "error", err)
|
|
return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "")
|
|
}
|
|
if result.Status != "completed" {
|
|
return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to list Jupyter services", strings.TrimSpace(result.Error))
|
|
}
|
|
|
|
out := strings.TrimSpace(result.Output)
|
|
if out == "" {
|
|
empty, _ := json.Marshal([]any{})
|
|
return h.sendResponsePacket(conn, NewDataPacket("jupyter_services", empty))
|
|
}
|
|
var payloadOut jupyterTaskOutput
|
|
if err := json.Unmarshal([]byte(out), &payloadOut); err == nil {
|
|
// Always return an array payload (even if empty) so clients can render a stable table.
|
|
payload := payloadOut.Services
|
|
if len(payload) == 0 {
|
|
payload = []byte("[]")
|
|
}
|
|
return h.sendResponsePacket(conn, NewDataPacket("jupyter_services", payload))
|
|
}
|
|
// Fallback: return empty array on unexpected output.
|
|
return h.sendResponsePacket(conn, NewDataPacket("jupyter_services", []byte("[]")))
|
|
}
|