fetch_ml/tests/unit/network/ssh_pool_test.go
Jeremie Fraeys c980167041 test: implement comprehensive test suite with multiple test types
- Add end-to-end tests for complete workflow validation
- Include integration tests for API and database interactions
- Add unit tests for all major components and utilities
- Include performance tests for payload handling
- Add CLI API integration tests
- Include Podman container integration tests
- Add WebSocket and queue execution tests
- Include shell script tests for setup validation

Provides comprehensive test coverage ensuring platform reliability
and functionality across all components and interactions.
2025-12-04 16:55:13 -05:00

120 lines
2.5 KiB
Go

package tests
import (
	"context"
	"errors"
	"log/slog"
	"sync/atomic"
	"testing"
	"time"

	"github.com/jfraeys/fetch_ml/internal/logging"
	"github.com/jfraeys/fetch_ml/internal/network"
)
// newTestLogger builds the logger handed to every pool constructed in this
// file. It logs at slog.LevelInfo; the second NewLogger argument is passed as
// false (NOTE(review): its meaning is defined in the logging package — confirm
// there).
func newTestLogger() *logging.Logger {
	const level = slog.LevelInfo
	testLogger := logging.NewLogger(level, false)
	return testLogger
}
// TestSSHPool_GetBlocksUntilConnectionReturned verifies that once maxConns
// connections have been handed out, a further Get does not return until one of
// them is given back with Put.
//
// The test also asserts that the factory was invoked exactly maxConns times
// while filling the pool.
func TestSSHPool_GetBlocksUntilConnectionReturned(t *testing.T) {
	logger := newTestLogger()
	maxConns := 2

	// Counts factory invocations; atomic because the pool may call the
	// factory from a goroutine other than the test's.
	var created atomic.Int32
	p := network.NewSSHPool(maxConns, func() (*network.SSHClient, error) {
		created.Add(1)
		return &network.SSHClient{}, nil
	}, logger)
	t.Cleanup(p.Close)

	// A cancellable context bounds the lifetime of the waiter goroutine
	// below: if the test bails out through t.Fatalf while that goroutine is
	// still parked in Get, the deferred cancel releases it instead of leaving
	// it blocked on a context that never ends.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	conn1, err := p.Get(ctx)
	if err != nil {
		t.Fatalf("first Get failed: %v", err)
	}
	conn2, err := p.Get(ctx)
	if err != nil {
		t.Fatalf("second Get failed: %v", err)
	}
	if got := created.Load(); got != int32(maxConns) {
		t.Fatalf("expected %d creations, got %d", maxConns, got)
	}

	// Buffered so the goroutine can always deliver its result and exit, even
	// if the test has already stopped listening.
	blocked := make(chan error, 1)
	go func() {
		conn, err := p.Get(ctx)
		if err == nil && conn != nil {
			p.Put(conn)
		}
		blocked <- err
	}()

	// With both connections checked out, the third Get must still be waiting
	// after a short grace period.
	select {
	case err := <-blocked:
		t.Fatalf("expected call to block, got err=%v", err)
	case <-time.After(50 * time.Millisecond):
	}

	// Returning one connection should let the waiter through.
	p.Put(conn1)
	select {
	case err := <-blocked:
		if err != nil {
			t.Fatalf("blocked Get returned error: %v", err)
		}
	case <-time.After(time.Second):
		t.Fatal("expected blocked Get to proceed after Put")
	}
	p.Put(conn2)
}
// TestSSHPool_GetReturnsContextErrorWhenWaiting verifies that a Get waiting on
// a fully checked-out pool returns the context's deadline error once the
// context times out.
func TestSSHPool_GetReturnsContextErrorWhenWaiting(t *testing.T) {
	logger := newTestLogger()
	p := network.NewSSHPool(1, func() (*network.SSHClient, error) {
		return &network.SSHClient{}, nil
	}, logger)
	t.Cleanup(p.Close)

	// Take the pool's only connection so the next Get has to wait.
	ctx := context.Background()
	conn, err := p.Get(ctx)
	if err != nil {
		t.Fatalf("initial Get failed: %v", err)
	}

	waitCtx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel()

	_, err = p.Get(waitCtx)
	// errors.Is rather than != so the assertion still holds if the pool
	// wraps the context error with additional detail.
	if !errors.Is(err, context.DeadlineExceeded) {
		t.Fatalf("expected deadline exceeded, got %v", err)
	}
	p.Put(conn)
}
func TestSSHPool_ReusesReturnedConnections(t *testing.T) {
logger := newTestLogger()
p := network.NewSSHPool(1, func() (*network.SSHClient, error) {
return &network.SSHClient{}, nil
}, logger)
t.Cleanup(p.Close)
ctx := context.Background()
conn, err := p.Get(ctx)
if err != nil {
t.Fatalf("first Get failed: %v", err)
}
p.Put(conn)
conn2, err := p.Get(ctx)
if err != nil {
t.Fatalf("second Get failed: %v", err)
}
if conn2 != conn {
t.Fatalf("expected pooled connection reuse, got different pointer")
}
p.Put(conn2)
}