- Add API server with WebSocket support and REST endpoints - Implement authentication system with API keys and permissions - Add task queue system with Redis backend and error handling - Include storage layer with database migrations and schemas - Add comprehensive logging, metrics, and telemetry - Implement security middleware and network utilities - Add experiment management and container orchestration - Include configuration management with smart defaults
304 lines
8 KiB
Go
304 lines
8 KiB
Go
// Package network provides shared networking utilities for the fetch_ml project, including SSH-based remote command execution.
|
|
package network
|
|
|
|
import (
	"context"
	"errors"
	"fmt"
	"io"
	"log"
	"net"
	"os"
	"os/exec"
	"path/filepath"
	"strings"
	"time"

	"github.com/jfraeys/fetch_ml/internal/config"
	"golang.org/x/crypto/ssh"
	"golang.org/x/crypto/ssh/agent"
	"golang.org/x/crypto/ssh/knownhosts"
)
|
|
|
|
// SSHClient provides SSH connection and command execution.
//
// A client whose client field is nil is in "local mode": every method then
// operates on the local machine instead of a remote host.
type SSHClient struct {
	// client is the established SSH connection; nil in local mode.
	client *ssh.Client
	// host is the remote host name the client was created for; "" in local mode.
	host string
}
|
|
|
|
// NewSSHClient creates a new SSH client. If host or keyPath is empty, returns a local-mode client.
|
|
// knownHostsPath is optional - if provided, will use known_hosts verification
|
|
func NewSSHClient(host, user, keyPath string, port int, knownHostsPath string) (*SSHClient, error) {
|
|
if host == "" || keyPath == "" {
|
|
// Local mode - no SSH connection needed
|
|
return &SSHClient{client: nil, host: ""}, nil
|
|
}
|
|
|
|
keyPath = config.ExpandPath(keyPath)
|
|
if strings.HasPrefix(keyPath, "~") {
|
|
home, _ := os.UserHomeDir()
|
|
keyPath = filepath.Join(home, keyPath[1:])
|
|
}
|
|
|
|
key, err := os.ReadFile(keyPath)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to read SSH key: %w", err)
|
|
}
|
|
|
|
var signer ssh.Signer
|
|
if signer, err = ssh.ParsePrivateKey(key); err != nil {
|
|
if _, ok := err.(*ssh.PassphraseMissingError); ok {
|
|
// Try to use ssh-agent for passphrase-protected keys
|
|
if agentSigner, agentErr := sshAgentSigner(); agentErr == nil {
|
|
signer = agentSigner
|
|
} else {
|
|
return nil, fmt.Errorf("SSH key is passphrase protected and ssh-agent unavailable: %w", err)
|
|
}
|
|
} else {
|
|
return nil, fmt.Errorf("failed to parse SSH key: %w", err)
|
|
}
|
|
}
|
|
|
|
hostKeyCallback := ssh.InsecureIgnoreHostKey()
|
|
if knownHostsPath != "" {
|
|
knownHostsPath = config.ExpandPath(knownHostsPath)
|
|
if _, err := os.Stat(knownHostsPath); err == nil {
|
|
callback, err := knownhosts.New(knownHostsPath)
|
|
if err != nil {
|
|
log.Printf("Warning: failed to parse known_hosts: %v; using insecure host key verification", err)
|
|
} else {
|
|
hostKeyCallback = callback
|
|
}
|
|
} else if !os.IsNotExist(err) {
|
|
log.Printf("Warning: known_hosts not found at %s; using insecure host key verification", knownHostsPath)
|
|
}
|
|
}
|
|
|
|
sshConfig := &ssh.ClientConfig{
|
|
User: user,
|
|
Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)},
|
|
HostKeyCallback: hostKeyCallback,
|
|
Timeout: 10 * time.Second,
|
|
HostKeyAlgorithms: []string{
|
|
ssh.KeyAlgoRSA,
|
|
ssh.KeyAlgoRSASHA256,
|
|
ssh.KeyAlgoRSASHA512,
|
|
ssh.KeyAlgoED25519,
|
|
ssh.KeyAlgoECDSA256,
|
|
ssh.KeyAlgoECDSA384,
|
|
ssh.KeyAlgoECDSA521,
|
|
},
|
|
}
|
|
|
|
addr := fmt.Sprintf("%s:%d", host, port)
|
|
client, err := ssh.Dial("tcp", addr, sshConfig)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("SSH connection failed: %w", err)
|
|
}
|
|
|
|
return &SSHClient{client: client, host: host}, nil
|
|
}
|
|
|
|
// Exec executes a command remotely via SSH or locally if in local mode
|
|
func (c *SSHClient) Exec(cmd string) (string, error) {
|
|
return c.ExecContext(context.Background(), cmd)
|
|
}
|
|
|
|
// ExecContext executes a command with context support for cancellation and timeouts
|
|
func (c *SSHClient) ExecContext(ctx context.Context, cmd string) (string, error) {
|
|
if c.client == nil {
|
|
// Local mode - execute command locally with context
|
|
execCmd := exec.CommandContext(ctx, "sh", "-c", cmd)
|
|
output, err := execCmd.CombinedOutput()
|
|
return string(output), err
|
|
}
|
|
|
|
session, err := c.client.NewSession()
|
|
if err != nil {
|
|
return "", fmt.Errorf("create session: %w", err)
|
|
}
|
|
defer func() {
|
|
if closeErr := session.Close(); closeErr != nil {
|
|
// Session may already be closed, so we just log at debug level
|
|
log.Printf("session close error (may be expected): %v", closeErr)
|
|
}
|
|
}()
|
|
|
|
// Run command with context cancellation
|
|
type result struct {
|
|
output string
|
|
err error
|
|
}
|
|
resultCh := make(chan result, 1)
|
|
|
|
go func() {
|
|
output, err := session.CombinedOutput(cmd)
|
|
resultCh <- result{string(output), err}
|
|
}()
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
// FIXED: Check error return value
|
|
if sigErr := session.Signal(ssh.SIGTERM); sigErr != nil {
|
|
log.Printf("failed to send SIGTERM: %v", sigErr)
|
|
}
|
|
|
|
// Give process time to cleanup gracefully
|
|
timer := time.NewTimer(5 * time.Second)
|
|
defer timer.Stop()
|
|
|
|
select {
|
|
case res := <-resultCh:
|
|
// Command finished during graceful shutdown
|
|
return res.output, fmt.Errorf("command cancelled: %w (output: %s)", ctx.Err(), res.output)
|
|
case <-timer.C:
|
|
if closeErr := session.Close(); closeErr != nil {
|
|
log.Printf("failed to force close session: %v", closeErr)
|
|
}
|
|
|
|
// Wait a bit more for final result
|
|
select {
|
|
case res := <-resultCh:
|
|
return res.output, fmt.Errorf("command cancelled and force closed: %w (output: %s)", ctx.Err(), res.output)
|
|
case <-time.After(5 * time.Second):
|
|
return "", fmt.Errorf("command cancelled and cleanup timeout: %w", ctx.Err())
|
|
}
|
|
}
|
|
case res := <-resultCh:
|
|
return res.output, res.err
|
|
}
|
|
}
|
|
|
|
// FileExists checks if a file exists remotely or locally
|
|
func (c *SSHClient) FileExists(path string) bool {
|
|
if c.client == nil {
|
|
// Local mode - check file locally
|
|
_, err := os.Stat(path)
|
|
return err == nil
|
|
}
|
|
|
|
out, err := c.Exec(fmt.Sprintf("test -e %s && echo 'exists'", path))
|
|
if err != nil {
|
|
return false
|
|
}
|
|
return strings.Contains(strings.TrimSpace(out), "exists")
|
|
}
|
|
|
|
// GetFileSize gets the size of a file or directory remotely or locally
|
|
func (c *SSHClient) GetFileSize(path string) (int64, error) {
|
|
if c.client == nil {
|
|
// Local mode - get size locally
|
|
var size int64
|
|
err := filepath.Walk(path, func(_ string, info os.FileInfo, err error) error {
|
|
if err != nil {
|
|
return err
|
|
}
|
|
size += info.Size()
|
|
return nil
|
|
})
|
|
return size, err
|
|
}
|
|
|
|
out, err := c.Exec(fmt.Sprintf("du -sb %s | cut -f1", path))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
var size int64
|
|
if _, err := fmt.Sscanf(strings.TrimSpace(out), "%d", &size); err != nil {
|
|
return 0, fmt.Errorf("failed to parse file size from output %q: %w", out, err)
|
|
}
|
|
return size, nil
|
|
}
|
|
|
|
// RemoteExists checks if a remote path exists (alias for FileExists for compatibility)
|
|
func (c *SSHClient) RemoteExists(path string) bool {
|
|
return c.FileExists(path)
|
|
}
|
|
|
|
// ListDir lists directory contents remotely or locally
|
|
func (c *SSHClient) ListDir(path string) []string {
|
|
if c.client == nil {
|
|
// Local mode
|
|
entries, err := os.ReadDir(path)
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
var items []string
|
|
for _, entry := range entries {
|
|
items = append(items, entry.Name())
|
|
}
|
|
return items
|
|
}
|
|
|
|
out, err := c.Exec(fmt.Sprintf("ls -1 %s 2>/dev/null", path))
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
|
|
var items []string
|
|
for line := range strings.SplitSeq(strings.TrimSpace(out), "\n") {
|
|
if line != "" {
|
|
items = append(items, line)
|
|
}
|
|
}
|
|
return items
|
|
}
|
|
|
|
// TailFile gets the last N lines of a file remotely or locally
|
|
func (c *SSHClient) TailFile(path string, lines int) string {
|
|
if c.client == nil {
|
|
// Local mode - read file and return last N lines
|
|
data, err := os.ReadFile(path)
|
|
if err != nil {
|
|
return ""
|
|
}
|
|
fileLines := strings.Split(string(data), "\n")
|
|
if len(fileLines) > lines {
|
|
fileLines = fileLines[len(fileLines)-lines:]
|
|
}
|
|
return strings.Join(fileLines, "\n")
|
|
}
|
|
|
|
out, err := c.Exec(fmt.Sprintf("tail -n %d %s 2>/dev/null", lines, path))
|
|
if err != nil {
|
|
return ""
|
|
}
|
|
return out
|
|
}
|
|
|
|
// Close closes the SSH connection
|
|
func (c *SSHClient) Close() error {
|
|
if c.client != nil {
|
|
return c.client.Close()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// sshAgentSigner attempts to get a signer from ssh-agent
|
|
func sshAgentSigner() (ssh.Signer, error) {
|
|
sshAuthSock := os.Getenv("SSH_AUTH_SOCK")
|
|
if sshAuthSock == "" {
|
|
return nil, fmt.Errorf("SSH_AUTH_SOCK not set")
|
|
}
|
|
|
|
conn, err := net.Dial("unix", sshAuthSock)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to connect to ssh-agent: %w", err)
|
|
}
|
|
defer func() {
|
|
if closeErr := conn.Close(); closeErr != nil {
|
|
log.Printf("warning: failed to close ssh-agent connection: %v", closeErr)
|
|
}
|
|
}()
|
|
|
|
agentClient := agent.NewClient(conn)
|
|
signers, err := agentClient.Signers()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get signers from ssh-agent: %w", err)
|
|
}
|
|
|
|
if len(signers) == 0 {
|
|
return nil, fmt.Errorf("no signers available in ssh-agent")
|
|
}
|
|
|
|
return signers[0], nil
|
|
}
|