fetch_ml/internal/api/ws_datasets.go

208 lines
6.1 KiB
Go

package api
import (
	"context"
	"database/sql"
	"encoding/binary"
	"encoding/json"
	"errors"
	"net/url"
	"strings"
	"time"

	"github.com/gorilla/websocket"
	"github.com/jfraeys/fetch_ml/internal/storage"
)
// handleDatasetList serves a "list datasets" request by querying the database
// and replying with the result as a JSON data packet tagged "datasets".
//
// Payload layout: [api_key_hash:16]
//
// Failures are never returned directly; each one is reported to the client as
// an error packet, and the value returned here is whatever sending that packet
// (or the final response packet) returns.
func (h *WSHandler) handleDatasetList(conn *websocket.Conn, payload []byte) error {
	const keyHashLen = 16

	if len(payload) < keyHashLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "dataset list payload too short", "")
	}

	// The key hash is only verified when authentication is configured and enabled.
	authRequired := h.authConfig != nil && h.authConfig.Enabled
	if authRequired {
		if authErr := h.verifyAPIKeyHash(payload[:keyHashLen]); authErr != nil {
			return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Authentication failed", authErr.Error())
		}
	}

	if h.db == nil {
		return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Database not configured", "")
	}

	// Bound the database call so a slow query cannot stall the handler.
	queryCtx, stop := context.WithTimeout(context.Background(), 3*time.Second)
	defer stop()

	// NOTE(review): the second argument is passed as 0; presumably a limit
	// where 0 means "no limit" — confirm against the storage package.
	found, listErr := h.db.ListDatasets(queryCtx, 0)
	if listErr != nil {
		return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to list datasets", listErr.Error())
	}

	body, encErr := json.Marshal(found)
	if encErr != nil {
		return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to serialize response", encErr.Error())
	}

	return h.sendResponsePacket(conn, NewDataPacket("datasets", body))
}
// handleDatasetRegister serves a "register dataset" request: it decodes a
// dataset name and URL from the binary payload, validates them, and upserts the
// pair through the database handle.
//
// Payload layout: [api_key_hash:16][name_len:1][name:var][url_len:2][url:var]
// The URL length is a big-endian uint16.
//
// Failures are reported to the client as error packets; on success a
// "Dataset registered" success packet is sent. The returned error is whatever
// sending the packet returns.
func (h *WSHandler) handleDatasetRegister(conn *websocket.Conn, payload []byte) error {
	const (
		keyHashLen = 16
		nameLenSz  = 1
		urlLenSz   = 2
	)

	// Smallest possible frame: hash plus both length prefixes.
	if len(payload) < keyHashLen+nameLenSz+urlLenSz {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "dataset register payload too short", "")
	}

	// The key hash is only verified when authentication is configured and enabled.
	if authRequired := h.authConfig != nil && h.authConfig.Enabled; authRequired {
		if authErr := h.verifyAPIKeyHash(payload[:keyHashLen]); authErr != nil {
			return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Authentication failed", authErr.Error())
		}
	}

	if h.db == nil {
		return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Database not configured", "")
	}

	// rest is a cursor over the not-yet-consumed part of the payload.
	rest := payload[keyHashLen:]

	nameSize := int(rest[0])
	rest = rest[nameLenSz:]
	// The name must be non-empty and leave room for the 2-byte URL length.
	if nameSize == 0 || len(rest) < nameSize+urlLenSz {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid dataset name length", "")
	}
	name := string(rest[:nameSize])
	rest = rest[nameSize:]

	urlSize := int(binary.BigEndian.Uint16(rest[:urlLenSz]))
	rest = rest[urlLenSz:]
	if urlSize == 0 || len(rest) < urlSize {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid dataset url length", "")
	}
	urlStr := string(rest[:urlSize])

	// Minimal server-side validation: the name must contain something other
	// than whitespace, and the URL must parse and carry a scheme.
	if strings.TrimSpace(name) == "" {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "dataset name required", "")
	}
	parsed, parseErr := url.Parse(urlStr)
	if parseErr != nil || parsed.Scheme == "" {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid dataset url", "")
	}

	// Bound the database call so a slow write cannot stall the handler.
	writeCtx, stop := context.WithTimeout(context.Background(), 3*time.Second)
	defer stop()

	record := &storage.Dataset{Name: name, URL: urlStr}
	if upsertErr := h.db.UpsertDataset(writeCtx, record); upsertErr != nil {
		return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to register dataset", upsertErr.Error())
	}

	return h.sendResponsePacket(conn, NewSuccessPacket("Dataset registered"))
}
// handleDatasetInfo serves a "dataset info" request: it looks up one dataset by
// name and replies with it as a JSON data packet tagged "dataset".
//
// Payload layout: [api_key_hash:16][name_len:1][name:var]
//
// Failures are reported to the client as error packets:
//   - ErrorCodeInvalidRequest when the payload is too short or the name length
//     is zero or exceeds the remaining payload;
//   - ErrorCodeAuthenticationFailed when auth is enabled and the key hash is
//     rejected;
//   - ErrorCodeResourceNotFound when the lookup yields sql.ErrNoRows;
//   - ErrorCodeDatabaseError when no database is configured or the lookup
//     fails for any other reason;
//   - ErrorCodeServerOverloaded when the dataset cannot be serialized.
//
// The returned error is whatever sending the packet returns.
func (h *WSHandler) handleDatasetInfo(conn *websocket.Conn, payload []byte) error {
	// Smallest possible frame: hash plus the name length prefix.
	if len(payload) < 16+1 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "dataset info payload too short", "")
	}

	apiKeyHash := payload[:16]
	// The key hash is only verified when authentication is configured and enabled.
	if h.authConfig != nil && h.authConfig.Enabled {
		if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
			return h.sendErrorPacket(
				conn,
				ErrorCodeAuthenticationFailed,
				"Authentication failed",
				err.Error(),
			)
		}
	}

	if h.db == nil {
		return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Database not configured", "")
	}

	offset := 16
	nameLen := int(payload[offset])
	offset++
	// Reject an empty name or a length that runs past the end of the payload.
	if nameLen <= 0 || len(payload) < offset+nameLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid dataset name length", "")
	}
	name := string(payload[offset : offset+nameLen])

	// Bound the database call so a slow query cannot stall the handler.
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	ds, err := h.db.GetDataset(ctx, name)
	if err != nil {
		// errors.Is rather than ==, so a sql.ErrNoRows that has been wrapped
		// with extra context is still reported as "not found" instead of
		// being misclassified as a database failure.
		if errors.Is(err, sql.ErrNoRows) {
			return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "Dataset not found", "")
		}
		return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to get dataset", err.Error())
	}

	data, err := json.Marshal(ds)
	if err != nil {
		return h.sendErrorPacket(
			conn,
			ErrorCodeServerOverloaded,
			"Failed to serialize response",
			err.Error(),
		)
	}

	return h.sendResponsePacket(conn, NewDataPacket("dataset", data))
}
// handleDatasetSearch serves a "search datasets" request: it decodes a search
// term from the payload, trims surrounding whitespace, runs the search through
// the database handle, and replies with the matches as a JSON data packet
// tagged "datasets".
//
// Payload layout: [api_key_hash:16][term_len:1][term:var]
//
// A zero-length term is accepted and forwarded to the search as an empty
// string. Failures are reported to the client as error packets; the returned
// error is whatever sending the packet returns.
func (h *WSHandler) handleDatasetSearch(conn *websocket.Conn, payload []byte) error {
	// Smallest possible frame: hash plus the term length prefix.
	if len(payload) < 16+1 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "dataset search payload too short", "")
	}

	apiKeyHash := payload[:16]
	// The key hash is only verified when authentication is configured and enabled.
	if h.authConfig != nil && h.authConfig.Enabled {
		if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
			return h.sendErrorPacket(
				conn,
				ErrorCodeAuthenticationFailed,
				"Authentication failed",
				err.Error(),
			)
		}
	}

	if h.db == nil {
		return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Database not configured", "")
	}

	offset := 16
	termLen := int(payload[offset])
	offset++
	// termLen comes from a single byte, so it can never be negative; only the
	// upper bound against the remaining payload needs checking.
	if len(payload) < offset+termLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid search term length", "")
	}
	term := strings.TrimSpace(string(payload[offset : offset+termLen]))

	// Bound the database call so a slow query cannot stall the handler.
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	// NOTE(review): the third argument is passed as 0; presumably a limit
	// where 0 means "no limit" — confirm against the storage package.
	datasets, err := h.db.SearchDatasets(ctx, term, 0)
	if err != nil {
		return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to search datasets", err.Error())
	}

	data, err := json.Marshal(datasets)
	if err != nil {
		return h.sendErrorPacket(
			conn,
			ErrorCodeServerOverloaded,
			"Failed to serialize response",
			err.Error(),
		)
	}

	return h.sendResponsePacket(conn, NewDataPacket("datasets", data))
}