package tests

import (
	"context"
	"errors"
	"log/slog"
	"sync/atomic"
	"testing"
	"time"

	"github.com/jfraeys/fetch_ml/internal/logging"
	"github.com/jfraeys/fetch_ml/internal/network"
)

// newTestLogger returns the logger handed to every pool under test, at Info
// level. The second logging.NewLogger argument is passed as false
// (NOTE(review): its meaning is defined in the logging package — confirm).
func newTestLogger() *logging.Logger {
	return logging.NewLogger(slog.LevelInfo, false)
}

// TestSSHPool_GetBlocksUntilConnectionReturned checks that once maxConns
// connections are checked out, an additional Get waits, and that it proceeds
// after one connection is handed back with Put.
func TestSSHPool_GetBlocksUntilConnectionReturned(t *testing.T) {
	logger := newTestLogger()
	maxConns := 2

	var created atomic.Int32
	p := network.NewSSHPool(maxConns, func() (*network.SSHClient, error) {
		created.Add(1)
		return &network.SSHClient{}, nil
	}, logger)
	t.Cleanup(p.Close)

	// Cancellable so the waiting goroutine below cannot outlive a failed test.
	// Cleanups run last-added-first, so cancel fires before p.Close.
	ctx, cancel := context.WithCancel(context.Background())
	t.Cleanup(cancel)

	conn1, err := p.Get(ctx)
	if err != nil {
		t.Fatalf("first Get failed: %v", err)
	}
	conn2, err := p.Get(ctx)
	if err != nil {
		t.Fatalf("second Get failed: %v", err)
	}
	if got := created.Load(); got != int32(maxConns) {
		t.Fatalf("expected %d creations, got %d", maxConns, got)
	}

	// Buffered so the goroutine's send never blocks, even after the test returns.
	blocked := make(chan error, 1)
	go func() {
		conn, err := p.Get(ctx)
		if err == nil && conn != nil {
			p.Put(conn)
		}
		blocked <- err
	}()

	// Both connections are checked out, so the extra Get must still be waiting.
	select {
	case err := <-blocked:
		t.Fatalf("expected call to block, got err=%v", err)
	case <-time.After(50 * time.Millisecond):
	}

	// Returning one connection should release the waiter.
	p.Put(conn1)
	select {
	case err := <-blocked:
		if err != nil {
			t.Fatalf("blocked Get returned error: %v", err)
		}
	case <-time.After(time.Second):
		t.Fatal("expected blocked Get to proceed after Put")
	}

	p.Put(conn2)
}

// TestSSHPool_GetReturnsContextErrorWhenWaiting checks that a Get waiting on a
// fully checked-out pool returns the context's deadline error once the
// deadline passes.
func TestSSHPool_GetReturnsContextErrorWhenWaiting(t *testing.T) {
	logger := newTestLogger()
	p := network.NewSSHPool(1, func() (*network.SSHClient, error) {
		return &network.SSHClient{}, nil
	}, logger)
	t.Cleanup(p.Close)

	ctx := context.Background()
	conn, err := p.Get(ctx)
	if err != nil {
		t.Fatalf("initial Get failed: %v", err)
	}

	waitCtx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel()

	// errors.Is (rather than ==) keeps the test valid if the pool wraps the
	// context error with extra detail.
	_, err = p.Get(waitCtx)
	if !errors.Is(err, context.DeadlineExceeded) {
		t.Fatalf("expected deadline exceeded, got %v", err)
	}

	p.Put(conn)
}

// TestSSHPool_ReusesReturnedConnections checks that a connection returned with
// Put is handed out again (same pointer) by the next Get on a size-1 pool.
func TestSSHPool_ReusesReturnedConnections(t *testing.T) {
	logger := newTestLogger()
	p := network.NewSSHPool(1, func() (*network.SSHClient, error) {
		return &network.SSHClient{}, nil
	}, logger)
	t.Cleanup(p.Close)

	ctx := context.Background()
	conn, err := p.Get(ctx)
	if err != nil {
		t.Fatalf("first Get failed: %v", err)
	}
	p.Put(conn)

	conn2, err := p.Get(ctx)
	if err != nil {
		t.Fatalf("second Get failed: %v", err)
	}
	if conn2 != conn {
		t.Fatalf("expected pooled connection reuse, got different pointer")
	}
	p.Put(conn2)
}