fetch_ml/internal/worker/gpu.go
Jeremie Fraeys c46be7f815
refactor: Phase 4 deferred - Extract GPU utilities and execution helpers
Extracted from execution.go to focused packages:

1. internal/worker/gpu.go (60 lines)
   - gpuVisibleDevicesString() - GPU device string formatting
   - filterExistingDevicePaths() - Device path filtering
   - gpuVisibleEnvVarName() - GPU env var selection
   - Reuses GPUType constants from gpu_detector.go

2. internal/worker/execution/setup.go (108 lines)
   - SetupJobDirectories() - Job directory creation
   - CopyDir() - Directory tree copying
   - copyFile() - Single file copy helper

3. internal/worker/execution/snapshot.go (52 lines)
   - StageSnapshot() - Snapshot staging for jobs
   - StageSnapshotFromPath() - Snapshot staging from path

Updated execution.go:
- Removed 64 lines of GPU utilities (now in gpu.go)
- Reduced from 1,082 to ~1,018 lines
- Still contains main execution flow (runJob, executeJob, etc.)

Build status: Compiles successfully
2026-02-17 14:03:11 -05:00

75 lines
1.7 KiB
Go

package worker
import (
"os"
"strconv"
"strings"
)
// gpuVisibleDevicesString renders the GPU selection held in cfg as a
// comma-separated list.
//
// Sources are consulted in this order:
//  1. cfg.GPUVisibleDeviceIDs, when non-empty: each entry is whitespace-trimmed
//     and blank entries are dropped.
//  2. cfg.GPUVisibleDevices, when non-empty: negative indices are dropped and
//     the rest are formatted in base 10.
//  3. Otherwise (including a nil cfg): the whitespace-trimmed fallback.
//
// A non-empty list whose entries are all filtered out yields "" rather than
// the fallback.
func gpuVisibleDevicesString(cfg *Config, fallback string) string {
	switch {
	case cfg == nil:
		return strings.TrimSpace(fallback)

	case len(cfg.GPUVisibleDeviceIDs) > 0:
		ids := make([]string, 0, len(cfg.GPUVisibleDeviceIDs))
		for _, raw := range cfg.GPUVisibleDeviceIDs {
			if trimmed := strings.TrimSpace(raw); trimmed != "" {
				ids = append(ids, trimmed)
			}
		}
		return strings.Join(ids, ",")

	case len(cfg.GPUVisibleDevices) > 0:
		indices := make([]string, 0, len(cfg.GPUVisibleDevices))
		for _, idx := range cfg.GPUVisibleDevices {
			if idx >= 0 {
				indices = append(indices, strconv.Itoa(idx))
			}
		}
		return strings.Join(indices, ",")

	default:
		return strings.TrimSpace(fallback)
	}
}
// filterExistingDevicePaths returns the entries of paths that os.Stat can
// resolve, in their original order.
//
// Each entry is whitespace-trimmed first; blank entries and repeats of an
// already accepted path are skipped. Any os.Stat error (not only "does not
// exist") excludes the path. An empty or nil input returns nil; otherwise the
// result is a non-nil slice, which may be empty.
func filterExistingDevicePaths(paths []string) []string {
	if len(paths) == 0 {
		return nil
	}

	kept := make([]string, 0, len(paths))
	accepted := make(map[string]struct{}, len(paths))

	for _, raw := range paths {
		candidate := strings.TrimSpace(raw)
		if candidate == "" {
			continue
		}
		if _, dup := accepted[candidate]; dup {
			continue
		}
		// Only paths that stat successfully are remembered, so a path that
		// failed earlier is checked again if it reappears in the input.
		if _, err := os.Stat(candidate); err == nil {
			accepted[candidate] = struct{}{}
			kept = append(kept, candidate)
		}
	}
	return kept
}
// gpuVisibleEnvVarName returns the name of the environment variable used to
// restrict GPU visibility for the vendor configured in cfg.
//
// cfg.GPUVendor is trimmed and lower-cased before matching:
//   - "amd" selects "HIP_VISIBLE_DEVICES";
//   - GPUTypeApple and GPUTypeNone select "" (no variable);
//   - anything else, including an empty vendor or a nil cfg, selects
//     "CUDA_VISIBLE_DEVICES".
func gpuVisibleEnvVarName(cfg *Config) string {
	const cudaVar = "CUDA_VISIBLE_DEVICES"

	if cfg == nil {
		return cudaVar
	}

	vendor := strings.ToLower(strings.TrimSpace(cfg.GPUVendor))
	if vendor == "amd" {
		return "HIP_VISIBLE_DEVICES"
	}
	if vendor == string(GPUTypeApple) || vendor == string(GPUTypeNone) {
		return ""
	}
	return cudaVar
}