fetch_ml/internal/worker/process/network_policy.go
Jeremie Fraeys 95adcba437
feat(worker): add Jupyter/vLLM plugins and process isolation
Extend worker capabilities with new execution plugins and security features:
- Jupyter plugin for notebook-based ML experiments
- vLLM plugin for LLM inference workloads
- Cross-platform process isolation (Unix/Windows)
- Network policy enforcement with platform-specific implementations
- Service manager integration for lifecycle management
- Scheduler backend integration for queue coordination

Update lifecycle management:
- Enhanced runloop with state transitions
- Service manager integration for plugin coordination
- Improved state persistence and recovery

Add test coverage:
- Unit tests for Jupyter and vLLM plugins
- Updated worker execution tests
2026-02-26 12:03:59 -05:00

278 lines
8.2 KiB
Go

// Package process provides process isolation and security enforcement for worker tasks.
// This file implements Network Micro-Segmentation enforcement hooks.
//go:build linux
// +build linux
package process
import (
	"fmt"
	"net"
	"os"
	"os/exec"
	"path/filepath"
	"strconv"
	"strings"
)
// NetworkPolicy describes the network segmentation rules applied to a single task.
type NetworkPolicy struct {
	// Mode names the network isolation mode. The values recognized by
	// Validate are "none", "bridge", "container" and "host".
	Mode string

	// AllowedEndpoints holds the endpoints the task may reach, each written
	// as host:port. It is meant for the "bridge" and "container" modes.
	AllowedEndpoints []string

	// BlockedSubnets holds CIDR ranges the task must not reach.
	BlockedSubnets []string

	// DNSResolution reports whether DNS lookups are permitted
	// (true = allow, false = block).
	DNSResolution bool

	// OutboundTraffic reports whether outbound connections are permitted
	// (true = allow, false = block).
	OutboundTraffic bool

	// InboundTraffic reports whether inbound connections are permitted
	// (true = allow, false = block).
	InboundTraffic bool
}
// DefaultNetworkPolicy returns the hardened baseline policy used for Network
// Micro-Segmentation: network mode "none", no allowed endpoints, the RFC 1918
// private ranges listed as blocked, and DNS, outbound and inbound traffic all
// disabled.
func DefaultNetworkPolicy() NetworkPolicy {
	var policy NetworkPolicy

	policy.Mode = "none"
	policy.AllowedEndpoints = []string{}
	policy.BlockedSubnets = []string{"10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16"}

	// Everything is denied by default; spelled out so the intent is explicit.
	policy.DNSResolution = false
	policy.OutboundTraffic = false
	policy.InboundTraffic = false

	return policy
}
// HIPAACompliantPolicy builds a bridge-mode policy intended for HIPAA
// workloads: every destination is listed as blocked and only the endpoints in
// allowlist are permitted.
//
// DNS resolution and outbound traffic are enabled only when allowlist is
// non-empty; inbound traffic is always disabled.
func HIPAACompliantPolicy(allowlist []string) NetworkPolicy {
	// With nothing allowlisted there is no reason to permit DNS or egress.
	hasAllowlist := len(allowlist) > 0

	return NetworkPolicy{
		Mode:             "bridge",
		AllowedEndpoints: allowlist,
		BlockedSubnets:   []string{"0.0.0.0/0"}, // block everything by default
		DNSResolution:    hasAllowlist,
		OutboundTraffic:  hasAllowlist,
		InboundTraffic:   false, // never allow inbound
	}
}
// Validate reports the first problem found in the policy, or nil when the
// policy is acceptable.
//
// It rejects unknown network modes, always rejects the "host" mode, and
// requires every allowed endpoint and every blocked subnet to be well formed.
func (np *NetworkPolicy) Validate() error {
	switch np.Mode {
	case "none", "bridge", "container":
		// Accepted isolation modes.
	case "host":
		// A known mode, but sharing the host network defeats isolation.
		return fmt.Errorf("host network mode is not allowed for security reasons")
	default:
		return fmt.Errorf("invalid network mode: %q", np.Mode)
	}

	for i := range np.AllowedEndpoints {
		if ep := np.AllowedEndpoints[i]; !isValidEndpoint(ep) {
			return fmt.Errorf("invalid endpoint format: %q (expected host:port)", ep)
		}
	}

	for i := range np.BlockedSubnets {
		if subnet := np.BlockedSubnets[i]; !isValidCIDR(subnet) {
			return fmt.Errorf("invalid CIDR format: %q", subnet)
		}
	}

	return nil
}
// isValidEndpoint reports whether endpoint is a usable "host:port" pair.
//
// The endpoint must contain exactly one colon, a non-empty host, and a decimal
// port in the range 1-65535. The single-colon requirement matches how
// parseEndpoint splits endpoints, so anything accepted here can be split there.
func isValidEndpoint(endpoint string) bool {
	parts := strings.Split(endpoint, ":")
	if len(parts) != 2 {
		return false
	}

	host, port := parts[0], parts[1]
	if host == "" {
		return false
	}

	// bitSize 16 caps the value at 65535; ParseUint also rejects an empty
	// string and any sign prefix.
	n, err := strconv.ParseUint(port, 10, 16)
	return err == nil && n > 0
}
// isValidCIDR reports whether cidr is a well-formed IPv4 or IPv6 CIDR block
// such as "10.0.0.0/8" or "fd00::/8".
//
// Both the address and the prefix length are checked, so inputs like
// "abc/24", "10.0.0.0/" or "10.0.0.0/33" are rejected.
func isValidCIDR(cidr string) bool {
	_, _, err := net.ParseCIDR(cidr)
	return err == nil
}
// parsePort parses s as an unsigned decimal number no larger than 65535.
//
// It returns an error when s is empty, contains anything other than ASCII
// digits, or exceeds 65535. Zero is accepted on purpose: this helper may be
// used for numeric fields other than ports (for example a prefix length),
// where 0 is meaningful. Callers that need a usable port must reject 0
// themselves.
func parsePort(s string) (int, error) {
	const maxPort = 65535

	if s == "" {
		return 0, fmt.Errorf("invalid port")
	}

	port := 0
	for _, c := range s {
		if c < '0' || c > '9' {
			return 0, fmt.Errorf("invalid port")
		}
		port = port*10 + int(c-'0')
		// Checking inside the loop keeps long digit strings from overflowing int.
		if port > maxPort {
			return 0, fmt.Errorf("invalid port")
		}
	}
	return port, nil
}
// ApplyNetworkPolicy validates policy and returns a copy of baseArgs extended
// with the podman options that express it.
//
// The returned slice always gains "--network <mode>". In bridge mode with a
// non-empty allowlist it also gains environment variables that tell the
// container it is restricted. When DNS resolution is disabled, an empty file
// is bind-mounted read-only over /etc/resolv.conf.
//
// This function does not install firewall rules itself; egress filtering is
// applied separately (see SetupExternalFirewall).
//
// baseArgs is never modified. An error is returned when the policy is invalid,
// or when DNS must be disabled but the empty resolv.conf cannot be prepared
// for a mode other than "none".
func ApplyNetworkPolicy(policy NetworkPolicy, baseArgs []string) ([]string, error) {
	if err := policy.Validate(); err != nil {
		return nil, fmt.Errorf("invalid network policy: %w", err)
	}

	// Build a fresh slice: appending to baseArgs directly could write into
	// spare capacity of the caller's backing array. At most 8 items are added.
	args := make([]string, 0, len(baseArgs)+8)
	args = append(args, baseArgs...)
	args = append(args, "--network", policy.Mode)

	// In bridge mode with an allowlist the container starts on the bridge
	// network and connectivity is narrowed by externally applied firewall
	// rules; these variables only inform the workload about that.
	if policy.Mode == "bridge" && len(policy.AllowedEndpoints) > 0 {
		args = append(args, "-e", "FETCHML_NETWORK_RESTRICTED=1")
		if !policy.DNSResolution {
			args = append(args, "-e", "FETCHML_DNS_DISABLED=1")
		}
	}

	// Disable DNS by mounting an empty resolv.conf over the container's copy.
	if !policy.DNSResolution {
		emptyResolv, err := createEmptyResolvConf()
		switch {
		case err == nil:
			args = append(args, "-v", fmt.Sprintf("%s:/etc/resolv.conf:ro", emptyResolv))
		case policy.Mode != "none":
			// Fail closed: starting a networked container without the mount
			// would silently leave DNS enabled.
			return nil, fmt.Errorf("disabling DNS: %w", err)
		default:
			// NOTE(review): with "--network none" the container is expected
			// to have no external connectivity, so the mount is treated as
			// best-effort here, as before — confirm against podman behavior.
		}
	}

	return args, nil
}
// createEmptyResolvConf returns the path of an empty file named
// "empty-resolv.conf" in the OS temp directory, creating it when absent.
//
// The file is created with O_EXCL so a pre-existing entry (including a
// symlink) is never written through. When something already exists at the
// path it is reused only if it is a regular file of size zero; otherwise an
// error is returned, because mounting a non-empty or redirected file would
// hand the container an attacker-chosen resolv.conf.
//
// NOTE(review): the path is still predictable and lives in a shared
// directory; a dedicated, worker-owned directory would be more robust.
func createEmptyResolvConf() (string, error) {
	path := filepath.Join(os.TempDir(), "empty-resolv.conf")

	f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0644)
	if err == nil {
		if cerr := f.Close(); cerr != nil {
			return "", cerr
		}
		return path, nil
	}
	if !os.IsExist(err) {
		return "", err
	}

	// Something is already there: inspect it without following symlinks.
	info, err := os.Lstat(path)
	if err != nil {
		return "", err
	}
	if !info.Mode().IsRegular() || info.Size() != 0 {
		return "", fmt.Errorf("%s exists but is not an empty regular file", path)
	}
	return path, nil
}
// SetupExternalFirewall installs egress-filtering rules inside the network
// namespace of the container identified by containerID.
//
// It is meant to run after the container has started, from a privileged
// helper: it shells out to podman, nsenter and iptables/ip6tables and needs
// root or CAP_NET_ADMIN.
//
// Rules are installed so that the result is fail-closed while setup runs:
//
//  1. if OutboundTraffic is false, a catch-all DROP is appended to OUTPUT;
//  2. a DROP is appended for every entry in BlockedSubnets;
//  3. ACCEPT rules are inserted at the top of OUTPUT for loopback, for DNS
//     (only when DNSResolution is true) and for every allowed endpoint, so
//     they take precedence over the DROP rules above.
//
// Nothing is done when outbound traffic is allowed and both lists are empty.
// The first failing command aborts setup and its error is returned; rules
// already installed are left in place.
//
// NOTE(review): only IPv6 CIDRs in BlockedSubnets are handled via ip6tables;
// the catch-all DROP and the ACCEPT rules are IPv4-only, as before.
func SetupExternalFirewall(containerID string, policy NetworkPolicy) error {
	if policy.OutboundTraffic && len(policy.BlockedSubnets) == 0 && len(policy.AllowedEndpoints) == 0 {
		return nil // No rules to apply
	}

	// The container PID is used to target its network namespace.
	pid, err := getContainerPID(containerID)
	if err != nil {
		return fmt.Errorf("failed to get container PID: %w", err)
	}

	// Default-deny first so no window exists where egress is unrestricted.
	if !policy.OutboundTraffic {
		if err := runInNetNS(pid, "iptables", "-A", "OUTPUT", "-j", "DROP"); err != nil {
			return fmt.Errorf("failed to block outbound traffic: %w", err)
		}
	}

	// Enforce the blocked subnets; IPv6 ranges need ip6tables.
	for _, cidr := range policy.BlockedSubnets {
		tool := "iptables"
		if strings.Contains(cidr, ":") {
			tool = "ip6tables"
		}
		if err := runInNetNS(pid, tool, "-A", "OUTPUT", "-d", cidr, "-j", "DROP"); err != nil {
			return fmt.Errorf("failed to block subnet %s: %w", cidr, err)
		}
	}

	// Loopback stays reachable: a blanket DROP (e.g. 0.0.0.0/0) would
	// otherwise also cut local, in-container connections.
	if err := runInNetNS(pid, "iptables", "-I", "OUTPUT", "1", "-o", "lo", "-j", "ACCEPT"); err != nil {
		return fmt.Errorf("failed to allow loopback traffic: %w", err)
	}

	// When the policy permits DNS, lookups must survive the DROP rules.
	if policy.DNSResolution {
		for _, proto := range []string{"udp", "tcp"} {
			if err := runInNetNS(pid, "iptables", "-I", "OUTPUT", "1",
				"-p", proto, "--dport", "53", "-j", "ACCEPT"); err != nil {
				return fmt.Errorf("failed to allow DNS over %s: %w", proto, err)
			}
		}
	}

	// Allow specific endpoints.
	for _, endpoint := range policy.AllowedEndpoints {
		host, port := parseEndpoint(endpoint)
		if host == "" {
			// Skipping silently would leave the endpoint unreachable with no
			// indication of why.
			return fmt.Errorf("failed to allow endpoint %s: malformed endpoint", endpoint)
		}
		if err := runInNetNS(pid, "iptables", "-I", "OUTPUT", "1",
			"-p", "tcp", "-d", host, "--dport", port, "-j", "ACCEPT"); err != nil {
			return fmt.Errorf("failed to allow endpoint %s: %w", endpoint, err)
		}
	}

	return nil
}

// runInNetNS runs command inside the network namespace of process pid via
// nsenter, attaching any output the command printed to the returned error.
func runInNetNS(pid string, command ...string) error {
	args := append([]string{"-t", pid, "-n"}, command...)
	out, err := exec.Command("nsenter", args...).CombinedOutput()
	if err != nil {
		if msg := strings.TrimSpace(string(out)); msg != "" {
			return fmt.Errorf("%w: %s", err, msg)
		}
		return err
	}
	return nil
}
// getContainerPID returns, as a decimal string, the PID that
// "podman inspect" reports for containerID.
//
// An error is returned when containerID is empty or begins with "-" (podman
// would parse it as a flag rather than a container name), when podman fails,
// or when the reported PID is not a positive integer.
func getContainerPID(containerID string) (string, error) {
	if containerID == "" || strings.HasPrefix(containerID, "-") {
		return "", fmt.Errorf("invalid container ID: %q", containerID)
	}

	output, err := exec.Command("podman", "inspect", "-f", "{{.State.Pid}}", containerID).Output()
	if err != nil {
		return "", err
	}

	pid := strings.TrimSpace(string(output))
	// NOTE(review): a PID of 0 presumably means the container is not running;
	// either way it cannot be used to enter a namespace.
	if n, err := strconv.Atoi(pid); err != nil || n <= 0 {
		return "", fmt.Errorf("container %q has no running process (pid %q)", containerID, pid)
	}
	return pid, nil
}
// parseEndpoint breaks a "host:port" endpoint into its two halves.
// When endpoint does not contain exactly one colon, both results are empty.
func parseEndpoint(endpoint string) (host, port string) {
	if strings.Count(endpoint, ":") != 1 {
		return "", ""
	}
	sep := strings.IndexByte(endpoint, ':')
	return endpoint[:sep], endpoint[sep+1:]
}
// NetworkPolicyFromSandbox assembles a NetworkPolicy from sandbox settings.
//
// An empty networkMode falls back to "none", and an empty blockedSubnets falls
// back to the subnets of DefaultNetworkPolicy. DNS resolution and outbound
// traffic are enabled only when the mode is not "none" and at least one
// endpoint is allowlisted; inbound traffic is always disabled.
func NetworkPolicyFromSandbox(
	networkMode string,
	allowedEndpoints []string,
	blockedSubnets []string,
) NetworkPolicy {
	mode := networkMode
	if mode == "" {
		mode = "none"
	}

	subnets := blockedSubnets
	if len(subnets) == 0 {
		subnets = DefaultNetworkPolicy().BlockedSubnets
	}

	// Egress (and the DNS it needs) only makes sense with a network and an allowlist.
	egress := mode != "none" && len(allowedEndpoints) > 0

	var policy NetworkPolicy
	policy.Mode = mode
	policy.AllowedEndpoints = allowedEndpoints
	policy.BlockedSubnets = subnets
	policy.DNSResolution = egress
	policy.OutboundTraffic = egress
	policy.InboundTraffic = false // never allow inbound by default
	return policy
}