// Package network provides SSH client and retry utilities. package network import ( "context" "fmt" "log" "net" "os" "os/exec" "path/filepath" "strings" "time" "github.com/jfraeys/fetch_ml/internal/config" "github.com/jfraeys/fetch_ml/internal/fileutil" "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/agent" "golang.org/x/crypto/ssh/knownhosts" ) // SSHClient provides SSH connection and command execution type SSHClient struct { client *ssh.Client host string basePath string } // NewSSHClient creates a new SSH client. If host or keyPath is empty, returns a local-mode client. // knownHostsPath is optional - if provided, will use known_hosts verification func NewSSHClient(host, user, keyPath string, port int, knownHostsPath string) (*SSHClient, error) { if host == "" || keyPath == "" { // Local mode - no SSH connection needed return &SSHClient{client: nil, host: ""}, nil } keyPath = config.ExpandPath(keyPath) if strings.HasPrefix(keyPath, "~") { home, _ := os.UserHomeDir() keyPath = filepath.Join(home, keyPath[1:]) } key, err := fileutil.SecureFileRead(keyPath) if err != nil { return nil, fmt.Errorf("failed to read SSH key: %w", err) } var signer ssh.Signer if signer, err = ssh.ParsePrivateKey(key); err != nil { if _, ok := err.(*ssh.PassphraseMissingError); ok { // Try to use ssh-agent for passphrase-protected keys if agentSigner, agentErr := sshAgentSigner(); agentErr == nil { signer = agentSigner } else { return nil, fmt.Errorf("SSH key is passphrase protected and ssh-agent unavailable: %w", err) } } else { return nil, fmt.Errorf("failed to parse SSH key: %w", err) } } // InsecureIgnoreHostKey used as fallback - security implications reviewed //nolint:gosec // G106: Use of InsecureIgnoreHostKey is intentional fallback hostKeyCallback := ssh.InsecureIgnoreHostKey() if knownHostsPath != "" { knownHostsPath = config.ExpandPath(knownHostsPath) if _, err := os.Stat(knownHostsPath); err == nil { callback, err := knownhosts.New(knownHostsPath) if err != nil { log.Printf("Warning: failed to parse known_hosts: %v; using insecure host key verification", err) } else { hostKeyCallback = callback } } else if !os.IsNotExist(err) { log.Printf("Warning: known_hosts not found at %s; using insecure host key verification", knownHostsPath) } } sshConfig := &ssh.ClientConfig{ User: user, Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)}, HostKeyCallback: hostKeyCallback, Timeout: 10 * time.Second, HostKeyAlgorithms: []string{ ssh.KeyAlgoRSA, ssh.KeyAlgoRSASHA256, ssh.KeyAlgoRSASHA512, ssh.KeyAlgoED25519, ssh.KeyAlgoECDSA256, ssh.KeyAlgoECDSA384, ssh.KeyAlgoECDSA521, }, } addr := fmt.Sprintf("%s:%d", host, port) client, err := ssh.Dial("tcp", addr, sshConfig) if err != nil { return nil, fmt.Errorf("SSH connection failed: %w", err) } return &SSHClient{client: client, host: host}, nil } // NewLocalClient creates a local-mode SSHClient that executes commands on the host using the provided base path. func NewLocalClient(basePath string) *SSHClient { if basePath != "" { basePath = config.ExpandPath(basePath) } return &SSHClient{ client: nil, host: "localhost", basePath: basePath, } } // Exec executes a command remotely via SSH or locally if in local mode func (c *SSHClient) Exec(cmd string) (string, error) { return c.ExecContext(context.Background(), cmd) } // ExecContext executes a command with context support for cancellation and timeouts func (c *SSHClient) ExecContext(ctx context.Context, cmd string) (string, error) { if c.client == nil { // Local mode - execute command locally with context execCmd := exec.CommandContext(ctx, "sh", "-c", cmd) if c.basePath != "" { execCmd.Dir = c.basePath } output, err := execCmd.CombinedOutput() return string(output), err } session, err := c.client.NewSession() if err != nil { return "", fmt.Errorf("create session: %w", err) } defer func() { if closeErr := session.Close(); closeErr != nil { // Session may already be closed, so we just log at debug level log.Printf("session close error (may be expected): %v", closeErr) } }() // Run command with context cancellation type result struct { output string err error } resultCh := make(chan result, 1) go func() { output, err := session.CombinedOutput(cmd) resultCh <- result{string(output), err} }() select { case <-ctx.Done(): // FIXED: Check error return value if sigErr := session.Signal(ssh.SIGTERM); sigErr != nil { log.Printf("failed to send SIGTERM: %v", sigErr) } // Give process time to cleanup gracefully timer := time.NewTimer(5 * time.Second) defer timer.Stop() select { case res := <-resultCh: // Command finished during graceful shutdown return res.output, fmt.Errorf("command cancelled: %w (output: %s)", ctx.Err(), res.output) case <-timer.C: if closeErr := session.Close(); closeErr != nil { log.Printf("failed to force close session: %v", closeErr) } // Wait a bit more for final result select { case res := <-resultCh: return res.output, fmt.Errorf("command cancelled and force closed: %w (output: %s)", ctx.Err(), res.output) case <-time.After(5 * time.Second): return "", fmt.Errorf("command cancelled and cleanup timeout: %w", ctx.Err()) } } case res := <-resultCh: return res.output, res.err } } // FileExists checks if a file exists remotely or locally func (c *SSHClient) FileExists(path string) bool { if c.client == nil { // Local mode - check file locally _, err := os.Stat(path) return err == nil } out, err := c.Exec(fmt.Sprintf("test -e %s && echo 'exists'", path)) if err != nil { return false } return strings.Contains(strings.TrimSpace(out), "exists") } // GetFileSize gets the size of a file or directory remotely or locally func (c *SSHClient) GetFileSize(path string) (int64, error) { if c.client == nil { // Local mode - get size locally var size int64 err := filepath.Walk(path, func(_ string, info os.FileInfo, err error) error { if err != nil { return err } size += info.Size() return nil }) return size, err } out, err := c.Exec(fmt.Sprintf("du -sb %s | cut -f1", path)) if err != nil { return 0, err } var size int64 if _, err := fmt.Sscanf(strings.TrimSpace(out), "%d", &size); err != nil { return 0, fmt.Errorf("failed to parse file size from output %q: %w", out, err) } return size, nil } // RemoteExists checks if a remote path exists (alias for FileExists for compatibility) func (c *SSHClient) RemoteExists(path string) bool { return c.FileExists(path) } // ListDir lists directory contents remotely or locally func (c *SSHClient) ListDir(path string) []string { if c.client == nil { // Local mode entries, err := os.ReadDir(path) if err != nil { return nil } var items []string for _, entry := range entries { items = append(items, entry.Name()) } return items } out, err := c.Exec(fmt.Sprintf("ls -1 %s 2>/dev/null", path)) if err != nil { return nil } var items []string for line := range strings.SplitSeq(strings.TrimSpace(out), "\n") { if line != "" { items = append(items, line) } } return items } // TailFile gets the last N lines of a file remotely or locally func (c *SSHClient) TailFile(path string, lines int) string { if c.client == nil { // Local mode - read file and return last N lines data, err := fileutil.SecureFileRead(path) if err != nil { return "" } fileLines := strings.Split(string(data), "\n") if len(fileLines) > lines { fileLines = fileLines[len(fileLines)-lines:] } return strings.Join(fileLines, "\n") } out, err := c.Exec(fmt.Sprintf("tail -n %d %s 2>/dev/null", lines, path)) if err != nil { return "" } return out } // Close closes the SSH connection func (c *SSHClient) Close() error { if c.client != nil { return c.client.Close() } return nil } // Host returns the host (localhost for local mode, remote host otherwise) func (c *SSHClient) Host() string { return c.host } // sshAgentSigner attempts to get a signer from ssh-agent func sshAgentSigner() (ssh.Signer, error) { sshAuthSock := os.Getenv("SSH_AUTH_SOCK") if sshAuthSock == "" { return nil, fmt.Errorf("SSH_AUTH_SOCK not set") } conn, err := (&net.Dialer{}).DialContext(context.Background(), "unix", sshAuthSock) if err != nil { return nil, fmt.Errorf("failed to connect to ssh-agent: %w", err) } defer func() { if closeErr := conn.Close(); closeErr != nil { log.Printf("warning: failed to close ssh-agent connection: %v", closeErr) } }() agentClient := agent.NewClient(conn) signers, err := agentClient.Signers() if err != nil { return nil, fmt.Errorf("failed to get signers from ssh-agent: %w", err) } if len(signers) == 0 { return nil, fmt.Errorf("no signers available in ssh-agent") } return signers[0], nil }