fetch_ml/tests/unit/network/ssh_test.go
Jeremie Fraeys ea15af1833 Fix multi-user authentication and clean up debug code
- Fix YAML tags in auth config struct (json -> yaml)
- Update CLI configs to use pre-hashed API keys
- Remove double hashing in WebSocket client
- Fix port mapping (9102 -> 9103) in CLI commands
- Update permission keys to use jobs:read, jobs:create, etc.
- Clean up all debug logging from CLI and server
- All user roles now authenticate correctly:
  * Admin: Can queue jobs and see all jobs
  * Researcher: Can queue jobs and see own jobs
  * Analyst: Can see status (read-only access)

Multi-user authentication is now fully functional.
2025-12-06 12:35:32 -05:00

261 lines
6.8 KiB
Go

package tests
import (
"context"
"errors"
"os"
"path/filepath"
"slices"
"strings"
"testing"
"time"
"github.com/jfraeys/fetch_ml/internal/network"
)
// TestSSHClient_ExecContext verifies that a client built with empty connection
// parameters runs a command through ExecContext and returns its stdout
// verbatim, including the trailing newline emitted by echo.
func TestSSHClient_ExecContext(t *testing.T) {
	t.Parallel()

	client, err := network.NewSSHClient("", "", "", 0, "")
	if err != nil {
		t.Fatalf("NewSSHClient failed: %v", err)
	}
	defer func() { _ = client.Close() }() // Close error is irrelevant to this test.

	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	out, err := client.ExecContext(ctx, "echo 'test'")
	if err != nil {
		// Stop here: the output of a failed command is meaningless, and
		// comparing it would only add a misleading second failure.
		t.Fatalf("ExecContext failed: %v", err)
	}
	if out != "test\n" {
		t.Errorf("Expected 'test\\n', got %q", out)
	}
}
// TestSSHClient_RemoteExists checks RemoteExists against a freshly written
// temporary file (expected true) and a sibling path that was never created
// (expected false).
func TestSSHClient_RemoteExists(t *testing.T) {
	t.Parallel()

	c, newErr := network.NewSSHClient("", "", "", 0, "")
	if newErr != nil {
		t.Fatalf("NewSSHClient failed: %v", newErr)
	}
	defer func() { _ = c.Close() }()

	tmp := t.TempDir()
	present := filepath.Join(tmp, "exists.txt")
	absent := filepath.Join(tmp, "missing.txt")

	if err := os.WriteFile(present, []byte("data"), 0o600); err != nil {
		t.Fatalf("failed to create temp file: %v", err)
	}

	// Cases are evaluated in order; the second probe only runs if the first passes.
	switch {
	case !c.RemoteExists(present):
		t.Fatal("expected RemoteExists to return true for existing file")
	case c.RemoteExists(absent):
		t.Fatal("expected RemoteExists to return false for missing file")
	}
}
// TestSSHClient_GetFileSizeError ensures GetFileSize reports an error when
// asked about a path that does not exist.
func TestSSHClient_GetFileSizeError(t *testing.T) {
	t.Parallel()

	c, err := network.NewSSHClient("", "", "", 0, "")
	if err != nil {
		t.Fatalf("NewSSHClient failed: %v", err)
	}
	defer func() { _ = c.Close() }()

	const bogusPath = "/path/that/does/not/exist"
	_, sizeErr := c.GetFileSize(bogusPath)
	if sizeErr == nil {
		t.Fatal("expected GetFileSize to error for missing path")
	}
}
// TestSSHClient_TailFileMissingReturnsEmpty confirms that TailFile yields an
// empty string for a file that does not exist.
func TestSSHClient_TailFileMissingReturnsEmpty(t *testing.T) {
	t.Parallel()

	c, err := network.NewSSHClient("", "", "", 0, "")
	if err != nil {
		t.Fatalf("NewSSHClient failed: %v", err)
	}
	defer func() { _ = c.Close() }()

	const (
		bogusPath = "/path/that/does/not/exist"
		lineCount = 5
	)
	got := c.TailFile(bogusPath, lineCount)
	if got != "" {
		t.Fatalf("expected empty TailFile output for missing file, got %q", got)
	}
}
// TestSSHClient_ExecContextCancellationDuringRun starts a long-running command
// in a goroutine, cancels its context shortly afterwards, and requires
// ExecContext to return within two seconds with either context.Canceled or an
// error mentioning "signal: killed".
func TestSSHClient_ExecContextCancellationDuringRun(t *testing.T) {
	t.Parallel()

	c, err := network.NewSSHClient("", "", "", 0, "")
	if err != nil {
		t.Fatalf("NewSSHClient failed: %v", err)
	}
	defer func() { _ = c.Close() }()

	ctx, stop := context.WithCancel(context.Background())

	// Buffered so the goroutine can always deliver its result and exit, even
	// if the test has already stopped waiting for it.
	result := make(chan error, 1)
	go func() {
		_, execErr := c.ExecContext(ctx, "sleep 5")
		result <- execErr
	}()

	// Give the command a moment to start before cancelling it.
	time.Sleep(100 * time.Millisecond)
	stop()

	timeout := time.After(2 * time.Second)
	select {
	case <-timeout:
		t.Fatal("ExecContext did not return after cancellation")
	case execErr := <-result:
		if execErr == nil {
			t.Fatal("expected cancellation error, got nil")
		}
		cancelled := errors.Is(execErr, context.Canceled)
		killed := strings.Contains(execErr.Error(), "signal: killed")
		if !(cancelled || killed) {
			t.Fatalf("expected context cancellation or killed signal, got %v", execErr)
		}
	}
}
// TestSSHClient_ContextCancellation verifies that ExecContext returns an error
// when handed a context that was cancelled before the call, and that the
// error reflects the cancellation.
func TestSSHClient_ContextCancellation(t *testing.T) {
	t.Parallel()

	// Check the constructor error: a failed construction would otherwise
	// surface later as a confusing failure (e.g. a nil-pointer panic on the
	// client) instead of a clear message.
	client, err := network.NewSSHClient("", "", "", 0, "")
	if err != nil {
		t.Fatalf("NewSSHClient failed: %v", err)
	}
	defer func() { _ = client.Close() }()

	ctx, cancel := context.WithCancel(context.Background())
	cancel() // Cancel before the command is even started.

	_, err = client.ExecContext(ctx, "sleep 10")
	if err == nil {
		// Fatal, not Error: the check below calls err.Error(), which would
		// panic on a nil error.
		t.Fatal("Expected error from cancelled context")
	}
	// Accept a (possibly wrapped) context.Canceled as well as an error that
	// only carries the cancellation message text.
	if !errors.Is(err, context.Canceled) && !strings.Contains(err.Error(), "context canceled") {
		t.Errorf("Expected context cancellation error, got: %v", err)
	}
}
// TestSSHClient_LocalMode runs a basic command (pwd) through Exec on a client
// built with empty connection parameters and expects non-empty output.
func TestSSHClient_LocalMode(t *testing.T) {
	t.Parallel()

	client, err := network.NewSSHClient("", "", "", 0, "")
	if err != nil {
		t.Fatalf("NewSSHClient failed: %v", err)
	}
	defer func() { _ = client.Close() }()

	out, err := client.Exec("pwd")
	if err != nil {
		// No point inspecting the output of a command that failed; doing so
		// would only add a spurious second failure.
		t.Fatalf("Exec failed: %v", err)
	}
	if out == "" {
		t.Error("Expected non-empty output from pwd")
	}
}
// TestSSHClient_NewLocalClient checks two things about NewLocalClient:
//   - it reports "localhost" as its host;
//   - it runs commands with basePath as the working directory, so a relative
//     `cat` of a file written into basePath returns that file's contents.
func TestSSHClient_NewLocalClient(t *testing.T) {
	t.Parallel()

	basePath := t.TempDir()
	client := network.NewLocalClient(basePath)
	defer func() { _ = client.Close() }()

	if client.Host() != "localhost" {
		t.Errorf("Expected host 'localhost', got %q", client.Host())
	}

	// Write the fixture first so the command timeout below covers only the
	// command itself.
	testFile := filepath.Join(basePath, "local_test.txt")
	if err := os.WriteFile(testFile, []byte("local mode"), 0o600); err != nil {
		t.Fatalf("Failed to create test file: %v", err)
	}

	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	// Relative path: this only succeeds if the command runs in basePath.
	out, err := client.ExecContext(ctx, "cat local_test.txt")
	if err != nil {
		// The output of a failed command is meaningless; stop here.
		t.Fatalf("ExecContext failed: %v", err)
	}
	// The fixture has no trailing newline, so cat's output must match exactly.
	if out != "local mode" {
		t.Errorf("Expected 'local mode', got %q", out)
	}
}
// TestSSHClient_FileExists probes FileExists with a path expected to exist on
// Unix-like hosts (/etc/passwd) and one that should not exist.
func TestSSHClient_FileExists(t *testing.T) {
	t.Parallel()

	c, err := network.NewSSHClient("", "", "", 0, "")
	if err != nil {
		t.Fatalf("NewSSHClient failed: %v", err)
	}
	defer func() { _ = c.Close() }()

	cases := []struct {
		path string
		want bool
		msg  string
	}{
		{"/etc/passwd", true, "FileExists should return true for /etc/passwd"},
		{"/non/existing/file", false, "FileExists should return false for non-existing file"},
	}
	for _, tc := range cases {
		if c.FileExists(tc.path) != tc.want {
			t.Error(tc.msg)
		}
	}
}
// TestSSHClient_GetFileSize expects GetFileSize to return a positive size for
// /etc/passwd, a file assumed to be present and non-empty on Unix-like hosts.
func TestSSHClient_GetFileSize(t *testing.T) {
	t.Parallel()

	client, err := network.NewSSHClient("", "", "", 0, "")
	if err != nil {
		t.Fatalf("NewSSHClient failed: %v", err)
	}
	defer func() { _ = client.Close() }()

	size, err := client.GetFileSize("/etc/passwd")
	if err != nil {
		// The size is meaningless when an error is returned; checking it would
		// only report a second, misleading failure.
		t.Fatalf("GetFileSize failed: %v", err)
	}
	if size <= 0 {
		t.Errorf("Expected positive size for /etc/passwd, got %d", size)
	}
}
// TestSSHClient_ListDir lists /etc and expects a non-nil result that contains
// the "passwd" entry.
func TestSSHClient_ListDir(t *testing.T) {
	t.Parallel()

	client, err := network.NewSSHClient("", "", "", 0, "")
	if err != nil {
		t.Fatalf("NewSSHClient failed: %v", err)
	}
	defer func() { _ = client.Close() }()

	entries := client.ListDir("/etc")
	if entries == nil {
		// A nil result cannot contain "passwd"; stop instead of reporting the
		// same problem twice.
		t.Fatal("ListDir should return non-nil slice")
	}
	if !slices.Contains(entries, "passwd") {
		t.Error("ListDir should include 'passwd' in /etc directory")
	}
}
// TestSSHClient_TailFile reads the tail of /etc/passwd with a five-line limit.
// It checks that the result is non-empty and spans no more than that many lines.
func TestSSHClient_TailFile(t *testing.T) {
	t.Parallel()

	c, err := network.NewSSHClient("", "", "", 0, "")
	if err != nil {
		t.Fatalf("NewSSHClient failed: %v", err)
	}
	defer func() { _ = c.Close() }()

	const maxLines = 5
	tail := c.TailFile("/etc/passwd", maxLines)
	if tail == "" {
		t.Error("TailFile should return non-empty content")
	}

	// Number of newline-separated segments after trimming surrounding
	// whitespace; an empty string counts as one segment.
	n := strings.Count(strings.TrimSpace(tail), "\n") + 1
	if n > maxLines {
		t.Errorf("Expected at most 5 lines, got %d", n)
	}
}