fetch_ml/internal/network/ssh.go

339 lines
8.8 KiB
Go

// Package network provides SSH client and retry utilities.
package network
import (
"context"
"fmt"
"log"
"net"
"os"
"os/exec"
"path/filepath"
"strings"
"time"
"github.com/jfraeys/fetch_ml/internal/config"
"github.com/jfraeys/fetch_ml/internal/fileutil"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
"golang.org/x/crypto/ssh/knownhosts"
)
// SSHClient provides SSH connection and command execution. When it holds no
// SSH connection (local mode), its methods run against this machine instead.
type SSHClient struct {
	// client is the SSH connection; nil selects local mode, in which methods
	// use the local shell and filesystem.
	client *ssh.Client
	// host is the remote host name, "localhost" for clients built by
	// NewLocalClient, or "" for the local-mode client from NewSSHClient.
	host string
	// basePath, when non-empty, is the working directory for local-mode
	// commands run by ExecContext. In this file only NewLocalClient sets it.
	basePath string
}
// NewSSHClient creates an SSH client connected to host:port as user,
// authenticating with the private key stored at keyPath.
//
// If host or keyPath is empty, no connection is made and a local-mode client
// (nil SSH connection, empty host) is returned; its methods act on this machine.
//
// If the key file is passphrase protected, the first key offered by the
// running ssh-agent (SSH_AUTH_SOCK) is used instead.
//
// knownHostsPath is optional. When it names a parseable known_hosts file, the
// server's host key is verified against it; otherwise a warning is logged (if
// a path was given) and host key verification is disabled.
func NewSSHClient(host, user, keyPath string, port int, knownHostsPath string) (*SSHClient, error) {
	if host == "" || keyPath == "" {
		// Local mode - no SSH connection needed.
		return &SSHClient{client: nil, host: ""}, nil
	}

	keyPath = config.ExpandPath(keyPath)
	if strings.HasPrefix(keyPath, "~") {
		// Defensive: expand a leading "~" that config.ExpandPath left in place
		// (presumably ExpandPath already handles it — kept as a fallback).
		home, err := os.UserHomeDir()
		if err != nil {
			return nil, fmt.Errorf("failed to resolve home directory for SSH key %q: %w", keyPath, err)
		}
		keyPath = filepath.Join(home, keyPath[1:])
	}

	key, err := fileutil.SecureFileRead(keyPath)
	if err != nil {
		return nil, fmt.Errorf("failed to read SSH key: %w", err)
	}

	signer, err := ssh.ParsePrivateKey(key)
	if err != nil {
		if _, ok := err.(*ssh.PassphraseMissingError); !ok {
			return nil, fmt.Errorf("failed to parse SSH key: %w", err)
		}
		// Passphrase-protected key: fall back to a key held by ssh-agent.
		agentSigner, agentErr := sshAgentSigner()
		if agentErr != nil {
			return nil, fmt.Errorf(
				"SSH key is passphrase protected and ssh-agent unavailable: %w (agent: %w)",
				err,
				agentErr,
			)
		}
		signer = agentSigner
	}

	// Accept any host key unless a usable known_hosts file is found below.
	// InsecureIgnoreHostKey used as fallback - security implications reviewed
	//nolint:gosec // G106: Use of InsecureIgnoreHostKey is intentional fallback
	hostKeyCallback := ssh.InsecureIgnoreHostKey()
	if knownHostsPath != "" {
		knownHostsPath = config.ExpandPath(knownHostsPath)
		_, statErr := os.Stat(knownHostsPath)
		switch {
		case statErr == nil:
			callback, parseErr := knownhosts.New(knownHostsPath)
			if parseErr != nil {
				log.Printf(
					"Warning: failed to parse known_hosts: %v; using insecure host key verification",
					parseErr,
				)
			} else {
				hostKeyCallback = callback
			}
		case os.IsNotExist(statErr):
			log.Printf(
				"Warning: known_hosts not found at %s; using insecure host key verification",
				knownHostsPath,
			)
		default:
			// Exists but unreadable (permissions, I/O error, ...).
			log.Printf(
				"Warning: cannot access known_hosts at %s: %v; using insecure host key verification",
				knownHostsPath,
				statErr,
			)
		}
	}

	sshConfig := &ssh.ClientConfig{
		User:            user,
		Auth:            []ssh.AuthMethod{ssh.PublicKeys(signer)},
		HostKeyCallback: hostKeyCallback,
		Timeout:         10 * time.Second,
		HostKeyAlgorithms: []string{
			ssh.KeyAlgoRSA,
			ssh.KeyAlgoRSASHA256,
			ssh.KeyAlgoRSASHA512,
			ssh.KeyAlgoED25519,
			ssh.KeyAlgoECDSA256,
			ssh.KeyAlgoECDSA384,
			ssh.KeyAlgoECDSA521,
		},
	}

	// JoinHostPort brackets IPv6 literals ("[::1]:22"); plain "%s:%d" does not.
	addr := net.JoinHostPort(host, fmt.Sprint(port))
	client, err := ssh.Dial("tcp", addr, sshConfig)
	if err != nil {
		return nil, fmt.Errorf("SSH connection failed: %w", err)
	}
	return &SSHClient{client: client, host: host}, nil
}
// NewLocalClient returns an SSHClient with no SSH connection, so every method
// acts on this machine and Host reports "localhost". A non-empty basePath is
// passed through config.ExpandPath and becomes the working directory for
// commands run by ExecContext; an empty basePath leaves the working
// directory unchanged.
func NewLocalClient(basePath string) *SSHClient {
	local := &SSHClient{host: "localhost"}
	if basePath == "" {
		return local
	}
	local.basePath = config.ExpandPath(basePath)
	return local
}
// Exec runs cmd (remotely over SSH, or locally in local mode) and returns its
// combined output. It is ExecContext with a background context, so the
// command cannot be cancelled through this entry point.
func (c *SSHClient) Exec(cmd string) (string, error) {
	background := context.Background()
	return c.ExecContext(background, cmd)
}
// ExecContext runs cmd and returns its combined stdout/stderr output,
// honoring ctx for cancellation and timeouts.
//
// If ctx is already done, ctx.Err() is returned and nothing is run.
//
// Local mode (no SSH connection): cmd runs via "sh -c" on this machine, in
// basePath when set. Cancelling ctx kills the shell. Wait then gives any
// still-open output pipes (for example, held by background children) at most
// 5 seconds before returning.
//
// Remote mode: cmd runs in a new SSH session. On cancellation the remote
// process is sent SIGTERM and given 5 seconds to finish. After that the
// session is force-closed and the result is awaited for up to 5 more
// seconds. Every cancellation path returns an error that wraps ctx.Err().
func (c *SSHClient) ExecContext(ctx context.Context, cmd string) (string, error) {
	// Don't start a command (or open a remote session) for a dead context.
	if err := ctx.Err(); err != nil {
		return "", err
	}

	if c.client == nil {
		execCmd := exec.CommandContext(ctx, "sh", "-c", cmd)
		if c.basePath != "" {
			execCmd.Dir = c.basePath
		}
		// Killing "sh" does not close pipes inherited by its children; without
		// a bound, CombinedOutput could block long after cancellation.
		execCmd.WaitDelay = 5 * time.Second
		output, err := execCmd.CombinedOutput()
		return string(output), err
	}

	session, err := c.client.NewSession()
	if err != nil {
		return "", fmt.Errorf("create session: %w", err)
	}
	defer func() {
		// The force-close path below may already have closed the session, so
		// a failure here is expected and only logged.
		if closeErr := session.Close(); closeErr != nil {
			log.Printf("session close error (may be expected): %v", closeErr)
		}
	}()

	type result struct {
		output string
		err    error
	}
	// Buffered so the goroutine can always deliver its result and exit, even
	// after this function has stopped waiting for it.
	resultCh := make(chan result, 1)
	go func() {
		output, err := session.CombinedOutput(cmd)
		resultCh <- result{string(output), err}
	}()

	select {
	case res := <-resultCh:
		return res.output, res.err
	case <-ctx.Done():
	}

	// Cancelled: ask the remote process to shut down gracefully.
	if sigErr := session.Signal(ssh.SIGTERM); sigErr != nil {
		log.Printf("failed to send SIGTERM: %v", sigErr)
	}
	grace := time.NewTimer(5 * time.Second)
	defer grace.Stop()

	select {
	case res := <-resultCh:
		// Command finished during graceful shutdown.
		return res.output, fmt.Errorf("command cancelled: %w (output: %s)", ctx.Err(), res.output)
	case <-grace.C:
	}

	// Still running: tear down the session, then wait briefly for the result.
	if closeErr := session.Close(); closeErr != nil {
		log.Printf("failed to force close session: %v", closeErr)
	}
	select {
	case res := <-resultCh:
		return res.output, fmt.Errorf(
			"command cancelled and force closed: %w (output: %s)",
			ctx.Err(),
			res.output,
		)
	case <-time.After(5 * time.Second):
		return "", fmt.Errorf("command cancelled and cleanup timeout: %w", ctx.Err())
	}
}
// FileExists reports whether path exists.
//
// In local mode it uses os.Stat on this machine. Otherwise it runs `test -e`
// on the remote host. Any error, including a failed SSH command, is reported
// as false.
func (c *SSHClient) FileExists(path string) bool {
	if c.client == nil {
		_, err := os.Stat(path)
		return err == nil
	}
	// Single-quote the path so spaces and shell metacharacters are taken
	// literally by the remote shell. A leading "~/" stays unquoted so it still
	// expands to the remote home directory.
	arg := "'" + strings.ReplaceAll(path, "'", `'\''`) + "'"
	if rest, ok := strings.CutPrefix(path, "~/"); ok {
		arg = "~/'" + strings.ReplaceAll(rest, "'", `'\''`) + "'"
	}
	out, err := c.Exec(fmt.Sprintf("test -e %s && echo 'exists'", arg))
	if err != nil {
		return false
	}
	return strings.Contains(strings.TrimSpace(out), "exists")
}
// GetFileSize returns the size in bytes of a file or directory tree.
//
// In local mode it walks path with filepath.Walk and sums the size of every
// visited entry, directory entries included. If the walk fails, it returns
// the partial sum together with the walk error. Remotely it runs `du -sb`
// (apparent size in bytes; -b is a GNU du option, so this presumably assumes
// a Linux host) and parses the first field of the output.
func (c *SSHClient) GetFileSize(path string) (int64, error) {
	if c.client == nil {
		var total int64
		err := filepath.Walk(path, func(_ string, info os.FileInfo, walkErr error) error {
			if walkErr != nil {
				return walkErr
			}
			total += info.Size()
			return nil
		})
		return total, err
	}
	// Single-quote the path so spaces and shell metacharacters are taken
	// literally. A leading "~/" stays unquoted so the remote shell expands it.
	arg := "'" + strings.ReplaceAll(path, "'", `'\''`) + "'"
	if rest, ok := strings.CutPrefix(path, "~/"); ok {
		arg = "~/'" + strings.ReplaceAll(rest, "'", `'\''`) + "'"
	}
	// The pipeline's exit status is cut's, so a du failure shows up as an
	// unparseable (empty) output below rather than as an Exec error.
	out, err := c.Exec(fmt.Sprintf("du -sb %s | cut -f1", arg))
	if err != nil {
		return 0, err
	}
	var size int64
	if _, err := fmt.Sscanf(strings.TrimSpace(out), "%d", &size); err != nil {
		return 0, fmt.Errorf("failed to parse file size from output %q: %w", out, err)
	}
	return size, nil
}
// RemoteExists reports whether path exists. It is retained for compatibility
// and delegates to FileExists, so a local-mode client checks this machine.
func (c *SSHClient) RemoteExists(path string) bool {
	exists := c.FileExists(path)
	return exists
}
// ListDir returns the names of the entries in directory path, or nil if the
// directory cannot be read or the remote `ls` command fails.
//
// NOTE(review): local mode (os.ReadDir) includes dot-files, while remote mode
// runs `ls -1`, which omits them — confirm which behavior callers expect.
func (c *SSHClient) ListDir(path string) []string {
	if c.client == nil {
		entries, err := os.ReadDir(path)
		if err != nil {
			return nil
		}
		var items []string
		for _, entry := range entries {
			items = append(items, entry.Name())
		}
		return items
	}
	// Single-quote the path so spaces and shell metacharacters are taken
	// literally. A leading "~/" stays unquoted so the remote shell expands it.
	arg := "'" + strings.ReplaceAll(path, "'", `'\''`) + "'"
	if rest, ok := strings.CutPrefix(path, "~/"); ok {
		arg = "~/'" + strings.ReplaceAll(rest, "'", `'\''`) + "'"
	}
	out, err := c.Exec(fmt.Sprintf("ls -1 %s 2>/dev/null", arg))
	if err != nil {
		return nil
	}
	var items []string
	for line := range strings.SplitSeq(strings.TrimSpace(out), "\n") {
		if line != "" {
			items = append(items, line)
		}
	}
	return items
}
// TailFile returns the last `lines` lines of the file at path. It returns ""
// if the file cannot be read, the remote command fails, or lines is not
// positive.
//
// Remote mode runs `tail -n`. Local mode reads the whole file and splits it
// on "\n". A trailing newline is treated as a line terminator, not as an
// extra empty line, so the local result matches what `tail -n` prints.
func (c *SSHClient) TailFile(path string, lines int) string {
	// A negative count would slice past the end of the split lines below.
	if lines <= 0 {
		return ""
	}
	if c.client == nil {
		data, err := fileutil.SecureFileRead(path)
		if err != nil {
			return ""
		}
		text := string(data)
		trailingNewline := strings.HasSuffix(text, "\n")
		fileLines := strings.Split(strings.TrimSuffix(text, "\n"), "\n")
		if len(fileLines) > lines {
			fileLines = fileLines[len(fileLines)-lines:]
		}
		tail := strings.Join(fileLines, "\n")
		if trailingNewline {
			tail += "\n"
		}
		return tail
	}
	// Single-quote the path so spaces and shell metacharacters are taken
	// literally. A leading "~/" stays unquoted so the remote shell expands it.
	arg := "'" + strings.ReplaceAll(path, "'", `'\''`) + "'"
	if rest, ok := strings.CutPrefix(path, "~/"); ok {
		arg = "~/'" + strings.ReplaceAll(rest, "'", `'\''`) + "'"
	}
	out, err := c.Exec(fmt.Sprintf("tail -n %d %s 2>/dev/null", lines, arg))
	if err != nil {
		return ""
	}
	return out
}
// Close closes the underlying SSH connection. A local-mode client holds no
// connection, so for it Close does nothing and returns nil.
func (c *SSHClient) Close() error {
	if c.client == nil {
		return nil
	}
	return c.client.Close()
}
// Host returns the host this client was created for:
//   - the remote host for a connected SSH client,
//   - "localhost" for a client from NewLocalClient,
//   - "" for the local-mode client that NewSSHClient returns when host or
//     keyPath is empty.
func (c *SSHClient) Host() string {
	return c.host
}
// sshAgentSigner returns the first signer offered by the ssh-agent listening
// on SSH_AUTH_SOCK.
//
// On success the agent connection is deliberately left open. Signers obtained
// from agent.NewClient forward every Sign call over that connection, so
// closing it here (as a deferred Close would) makes SSH authentication with
// the returned signer fail. The connection is not closed explicitly. On every
// error path it is closed before returning.
//
// NOTE(review): signers[0] is not necessarily the key that was found at the
// caller's keyPath — confirm whether matching by public key is needed.
func sshAgentSigner() (ssh.Signer, error) {
	sshAuthSock := os.Getenv("SSH_AUTH_SOCK")
	if sshAuthSock == "" {
		return nil, fmt.Errorf("SSH_AUTH_SOCK not set")
	}
	conn, err := (&net.Dialer{}).DialContext(context.Background(), "unix", sshAuthSock)
	if err != nil {
		return nil, fmt.Errorf("failed to connect to ssh-agent: %w", err)
	}
	// closeConn releases the agent connection; used only on failure paths.
	closeConn := func() {
		if closeErr := conn.Close(); closeErr != nil {
			log.Printf("warning: failed to close ssh-agent connection: %v", closeErr)
		}
	}

	signers, err := agent.NewClient(conn).Signers()
	if err != nil {
		closeConn()
		return nil, fmt.Errorf("failed to get signers from ssh-agent: %w", err)
	}
	if len(signers) == 0 {
		closeConn()
		return nil, fmt.Errorf("no signers available in ssh-agent")
	}
	return signers[0], nil
}