- Fix YAML tags in auth config struct (json -> yaml)
- Update CLI configs to use pre-hashed API keys
- Remove double hashing in WebSocket client
- Fix port mapping (9102 -> 9103) in CLI commands
- Update permission keys to use jobs:read, jobs:create, etc.
- Clean up all debug logging from CLI and server
- All user roles now authenticate correctly:
  * Admin: can queue jobs and see all jobs
  * Researcher: can queue jobs and see own jobs
  * Analyst: can see status (read-only access)

Multi-user authentication is now fully functional.
261 lines
6.8 KiB
Go
261 lines
6.8 KiB
Go
package tests
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"os"
|
|
"path/filepath"
|
|
"slices"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/jfraeys/fetch_ml/internal/network"
|
|
)
|
|
|
|
func TestSSHClient_ExecContext(t *testing.T) {
|
|
t.Parallel() // Enable parallel execution
|
|
client, err := network.NewSSHClient("", "", "", 0, "")
|
|
if err != nil {
|
|
t.Fatalf("NewSSHClient failed: %v", err)
|
|
}
|
|
defer func() { _ = client.Close() }()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) // Reduced from 5 seconds
|
|
defer cancel()
|
|
|
|
out, err := client.ExecContext(ctx, "echo 'test'")
|
|
if err != nil {
|
|
t.Errorf("ExecContext failed: %v", err)
|
|
}
|
|
|
|
if out != "test\n" {
|
|
t.Errorf("Expected 'test\\n', got %q", out)
|
|
}
|
|
}
|
|
|
|
func TestSSHClient_RemoteExists(t *testing.T) {
|
|
t.Parallel() // Enable parallel execution
|
|
client, err := network.NewSSHClient("", "", "", 0, "")
|
|
if err != nil {
|
|
t.Fatalf("NewSSHClient failed: %v", err)
|
|
}
|
|
defer func() { _ = client.Close() }()
|
|
|
|
dir := t.TempDir()
|
|
file := filepath.Join(dir, "exists.txt")
|
|
if writeErr := os.WriteFile(file, []byte("data"), 0o600); writeErr != nil {
|
|
t.Fatalf("failed to create temp file: %v", writeErr)
|
|
}
|
|
|
|
if !client.RemoteExists(file) {
|
|
t.Fatal("expected RemoteExists to return true for existing file")
|
|
}
|
|
|
|
missing := filepath.Join(dir, "missing.txt")
|
|
if client.RemoteExists(missing) {
|
|
t.Fatal("expected RemoteExists to return false for missing file")
|
|
}
|
|
}
|
|
|
|
func TestSSHClient_GetFileSizeError(t *testing.T) {
|
|
t.Parallel() // Enable parallel execution
|
|
client, err := network.NewSSHClient("", "", "", 0, "")
|
|
if err != nil {
|
|
t.Fatalf("NewSSHClient failed: %v", err)
|
|
}
|
|
defer func() { _ = client.Close() }()
|
|
|
|
if _, err := client.GetFileSize("/path/that/does/not/exist"); err == nil {
|
|
t.Fatal("expected GetFileSize to error for missing path")
|
|
}
|
|
}
|
|
|
|
func TestSSHClient_TailFileMissingReturnsEmpty(t *testing.T) {
|
|
t.Parallel() // Enable parallel execution
|
|
client, err := network.NewSSHClient("", "", "", 0, "")
|
|
if err != nil {
|
|
t.Fatalf("NewSSHClient failed: %v", err)
|
|
}
|
|
defer func() { _ = client.Close() }()
|
|
|
|
if out := client.TailFile("/path/that/does/not/exist", 5); out != "" {
|
|
t.Fatalf("expected empty TailFile output for missing file, got %q", out)
|
|
}
|
|
}
|
|
|
|
func TestSSHClient_ExecContextCancellationDuringRun(t *testing.T) {
|
|
t.Parallel() // Enable parallel execution
|
|
client, err := network.NewSSHClient("", "", "", 0, "")
|
|
if err != nil {
|
|
t.Fatalf("NewSSHClient failed: %v", err)
|
|
}
|
|
defer func() { _ = client.Close() }()
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
done := make(chan error, 1)
|
|
go func() {
|
|
_, runErr := client.ExecContext(ctx, "sleep 5")
|
|
done <- runErr
|
|
}()
|
|
|
|
time.Sleep(100 * time.Millisecond)
|
|
cancel()
|
|
|
|
select {
|
|
case err := <-done:
|
|
if err == nil {
|
|
t.Fatal("expected cancellation error, got nil")
|
|
}
|
|
if !errors.Is(err, context.Canceled) && !strings.Contains(err.Error(), "signal: killed") {
|
|
t.Fatalf("expected context cancellation or killed signal, got %v", err)
|
|
}
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("ExecContext did not return after cancellation")
|
|
}
|
|
}
|
|
|
|
func TestSSHClient_ContextCancellation(t *testing.T) {
|
|
t.Parallel() // Enable parallel execution
|
|
client, _ := network.NewSSHClient("", "", "", 0, "")
|
|
defer func() { _ = client.Close() }()
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
cancel() // Cancel immediately
|
|
|
|
_, err := client.ExecContext(ctx, "sleep 10")
|
|
if err == nil {
|
|
t.Error("Expected error from cancelled context")
|
|
}
|
|
|
|
// Check that it's a context cancellation error
|
|
if !strings.Contains(err.Error(), "context canceled") {
|
|
t.Errorf("Expected context cancellation error, got: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestSSHClient_LocalMode(t *testing.T) {
|
|
t.Parallel() // Enable parallel execution
|
|
client, err := network.NewSSHClient("", "", "", 0, "")
|
|
if err != nil {
|
|
t.Fatalf("NewSSHClient failed: %v", err)
|
|
}
|
|
defer func() { _ = client.Close() }()
|
|
|
|
// Test basic command
|
|
out, err := client.Exec("pwd")
|
|
if err != nil {
|
|
t.Errorf("Exec failed: %v", err)
|
|
}
|
|
|
|
if out == "" {
|
|
t.Error("Expected non-empty output from pwd")
|
|
}
|
|
}
|
|
|
|
func TestSSHClient_NewLocalClient(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
basePath := t.TempDir()
|
|
client := network.NewLocalClient(basePath)
|
|
defer func() { _ = client.Close() }()
|
|
|
|
// Verify client is in local mode
|
|
if client.Host() != "localhost" {
|
|
t.Errorf("Expected host 'localhost', got %q", client.Host())
|
|
}
|
|
|
|
// Test that commands execute in the base path
|
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
|
defer cancel()
|
|
|
|
// Create a test file in the base path
|
|
testFile := filepath.Join(basePath, "local_test.txt")
|
|
if err := os.WriteFile(testFile, []byte("local mode"), 0o600); err != nil {
|
|
t.Fatalf("Failed to create test file: %v", err)
|
|
}
|
|
|
|
// Execute a command that should run from base path
|
|
out, err := client.ExecContext(ctx, "cat local_test.txt")
|
|
if err != nil {
|
|
t.Errorf("ExecContext failed: %v", err)
|
|
}
|
|
|
|
if out != "local mode" {
|
|
t.Errorf("Expected 'local mode', got %q", out)
|
|
}
|
|
}
|
|
|
|
func TestSSHClient_FileExists(t *testing.T) {
|
|
t.Parallel() // Enable parallel execution
|
|
client, err := network.NewSSHClient("", "", "", 0, "")
|
|
if err != nil {
|
|
t.Fatalf("NewSSHClient failed: %v", err)
|
|
}
|
|
defer func() { _ = client.Close() }()
|
|
|
|
// Test existing file
|
|
if !client.FileExists("/etc/passwd") {
|
|
t.Error("FileExists should return true for /etc/passwd")
|
|
}
|
|
|
|
// Test non-existing file
|
|
if client.FileExists("/non/existing/file") {
|
|
t.Error("FileExists should return false for non-existing file")
|
|
}
|
|
}
|
|
|
|
func TestSSHClient_GetFileSize(t *testing.T) {
|
|
t.Parallel() // Enable parallel execution
|
|
client, err := network.NewSSHClient("", "", "", 0, "")
|
|
if err != nil {
|
|
t.Fatalf("NewSSHClient failed: %v", err)
|
|
}
|
|
defer func() { _ = client.Close() }()
|
|
|
|
size, err := client.GetFileSize("/etc/passwd")
|
|
if err != nil {
|
|
t.Errorf("GetFileSize failed: %v", err)
|
|
}
|
|
|
|
if size <= 0 {
|
|
t.Errorf("Expected positive size for /etc/passwd, got %d", size)
|
|
}
|
|
}
|
|
|
|
func TestSSHClient_ListDir(t *testing.T) {
|
|
t.Parallel() // Enable parallel execution
|
|
client, err := network.NewSSHClient("", "", "", 0, "")
|
|
if err != nil {
|
|
t.Fatalf("NewSSHClient failed: %v", err)
|
|
}
|
|
defer func() { _ = client.Close() }()
|
|
|
|
entries := client.ListDir("/etc")
|
|
if entries == nil {
|
|
t.Error("ListDir should return non-nil slice")
|
|
}
|
|
|
|
if !slices.Contains(entries, "passwd") {
|
|
t.Error("ListDir should include 'passwd' in /etc directory")
|
|
}
|
|
}
|
|
|
|
func TestSSHClient_TailFile(t *testing.T) {
|
|
t.Parallel() // Enable parallel execution
|
|
client, err := network.NewSSHClient("", "", "", 0, "")
|
|
if err != nil {
|
|
t.Fatalf("NewSSHClient failed: %v", err)
|
|
}
|
|
defer func() { _ = client.Close() }()
|
|
|
|
content := client.TailFile("/etc/passwd", 5)
|
|
if content == "" {
|
|
t.Error("TailFile should return non-empty content")
|
|
}
|
|
|
|
lines := len(strings.Split(strings.TrimSpace(content), "\n"))
|
|
if lines > 5 {
|
|
t.Errorf("Expected at most 5 lines, got %d", lines)
|
|
}
|
|
}
|