208 lines
6.1 KiB
Go
208 lines
6.1 KiB
Go
package api
|
|
|
|
import (
	"context"
	"database/sql"
	"encoding/binary"
	"encoding/json"
	"errors"
	"net/url"
	"strings"
	"time"

	"github.com/gorilla/websocket"
	"github.com/jfraeys/fetch_ml/internal/storage"
)
|
|
|
|
func (h *WSHandler) handleDatasetList(conn *websocket.Conn, payload []byte) error {
|
|
// Protocol: [api_key_hash:16]
|
|
if len(payload) < 16 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "dataset list payload too short", "")
|
|
}
|
|
apiKeyHash := payload[:16]
|
|
|
|
if h.authConfig != nil && h.authConfig.Enabled {
|
|
if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
|
|
return h.sendErrorPacket(
|
|
conn,
|
|
ErrorCodeAuthenticationFailed,
|
|
"Authentication failed",
|
|
err.Error(),
|
|
)
|
|
}
|
|
}
|
|
if h.db == nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Database not configured", "")
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
|
defer cancel()
|
|
|
|
datasets, err := h.db.ListDatasets(ctx, 0)
|
|
if err != nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to list datasets", err.Error())
|
|
}
|
|
|
|
data, err := json.Marshal(datasets)
|
|
if err != nil {
|
|
return h.sendErrorPacket(
|
|
conn,
|
|
ErrorCodeServerOverloaded,
|
|
"Failed to serialize response",
|
|
err.Error(),
|
|
)
|
|
}
|
|
return h.sendResponsePacket(conn, NewDataPacket("datasets", data))
|
|
}
|
|
|
|
func (h *WSHandler) handleDatasetRegister(conn *websocket.Conn, payload []byte) error {
|
|
// Protocol: [api_key_hash:16][name_len:1][name:var][url_len:2][url:var]
|
|
if len(payload) < 16+1+2 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "dataset register payload too short", "")
|
|
}
|
|
apiKeyHash := payload[:16]
|
|
if h.authConfig != nil && h.authConfig.Enabled {
|
|
if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
|
|
return h.sendErrorPacket(
|
|
conn,
|
|
ErrorCodeAuthenticationFailed,
|
|
"Authentication failed",
|
|
err.Error(),
|
|
)
|
|
}
|
|
}
|
|
if h.db == nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Database not configured", "")
|
|
}
|
|
|
|
offset := 16
|
|
nameLen := int(payload[offset])
|
|
offset++
|
|
if nameLen <= 0 || len(payload) < offset+nameLen+2 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid dataset name length", "")
|
|
}
|
|
name := string(payload[offset : offset+nameLen])
|
|
offset += nameLen
|
|
|
|
urlLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
|
|
offset += 2
|
|
if urlLen <= 0 || len(payload) < offset+urlLen {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid dataset url length", "")
|
|
}
|
|
urlStr := string(payload[offset : offset+urlLen])
|
|
|
|
// Minimal validation (server-side authoritative): name non-empty and url parseable.
|
|
if strings.TrimSpace(name) == "" {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "dataset name required", "")
|
|
}
|
|
if u, err := url.Parse(urlStr); err != nil || u.Scheme == "" {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid dataset url", "")
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
|
defer cancel()
|
|
|
|
if err := h.db.UpsertDataset(ctx, &storage.Dataset{Name: name, URL: urlStr}); err != nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to register dataset", err.Error())
|
|
}
|
|
return h.sendResponsePacket(conn, NewSuccessPacket("Dataset registered"))
|
|
}
|
|
|
|
func (h *WSHandler) handleDatasetInfo(conn *websocket.Conn, payload []byte) error {
|
|
// Protocol: [api_key_hash:16][name_len:1][name:var]
|
|
if len(payload) < 16+1 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "dataset info payload too short", "")
|
|
}
|
|
apiKeyHash := payload[:16]
|
|
if h.authConfig != nil && h.authConfig.Enabled {
|
|
if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
|
|
return h.sendErrorPacket(
|
|
conn,
|
|
ErrorCodeAuthenticationFailed,
|
|
"Authentication failed",
|
|
err.Error(),
|
|
)
|
|
}
|
|
}
|
|
if h.db == nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Database not configured", "")
|
|
}
|
|
|
|
offset := 16
|
|
nameLen := int(payload[offset])
|
|
offset++
|
|
if nameLen <= 0 || len(payload) < offset+nameLen {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid dataset name length", "")
|
|
}
|
|
name := string(payload[offset : offset+nameLen])
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
|
defer cancel()
|
|
|
|
ds, err := h.db.GetDataset(ctx, name)
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "Dataset not found", "")
|
|
}
|
|
return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to get dataset", err.Error())
|
|
}
|
|
|
|
data, err := json.Marshal(ds)
|
|
if err != nil {
|
|
return h.sendErrorPacket(
|
|
conn,
|
|
ErrorCodeServerOverloaded,
|
|
"Failed to serialize response",
|
|
err.Error(),
|
|
)
|
|
}
|
|
return h.sendResponsePacket(conn, NewDataPacket("dataset", data))
|
|
}
|
|
|
|
func (h *WSHandler) handleDatasetSearch(conn *websocket.Conn, payload []byte) error {
|
|
// Protocol: [api_key_hash:16][term_len:1][term:var]
|
|
if len(payload) < 16+1 {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "dataset search payload too short", "")
|
|
}
|
|
apiKeyHash := payload[:16]
|
|
if h.authConfig != nil && h.authConfig.Enabled {
|
|
if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
|
|
return h.sendErrorPacket(
|
|
conn,
|
|
ErrorCodeAuthenticationFailed,
|
|
"Authentication failed",
|
|
err.Error(),
|
|
)
|
|
}
|
|
}
|
|
if h.db == nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Database not configured", "")
|
|
}
|
|
|
|
offset := 16
|
|
termLen := int(payload[offset])
|
|
offset++
|
|
if termLen < 0 || len(payload) < offset+termLen {
|
|
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid search term length", "")
|
|
}
|
|
term := string(payload[offset : offset+termLen])
|
|
term = strings.TrimSpace(term)
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
|
defer cancel()
|
|
|
|
datasets, err := h.db.SearchDatasets(ctx, term, 0)
|
|
if err != nil {
|
|
return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to search datasets", err.Error())
|
|
}
|
|
|
|
data, err := json.Marshal(datasets)
|
|
if err != nil {
|
|
return h.sendErrorPacket(
|
|
conn,
|
|
ErrorCodeServerOverloaded,
|
|
"Failed to serialize response",
|
|
err.Error(),
|
|
)
|
|
}
|
|
return h.sendResponsePacket(conn, NewDataPacket("datasets", data))
|
|
}
|