// Package experiment provides ML experiment management package experiment import ( "crypto/sha256" "encoding/binary" "encoding/hex" "encoding/json" "fmt" "math" "os" "path/filepath" "strings" "time" "github.com/jfraeys/fetch_ml/internal/config" "github.com/jfraeys/fetch_ml/internal/container" "github.com/jfraeys/fetch_ml/internal/fileutil" "github.com/jfraeys/fetch_ml/internal/worker/integrity" ) // Manifest represents a content integrity manifest for experiment files type Manifest struct { CommitID string `json:"commit_id"` Files map[string]string `json:"files"` // relative path -> sha256 hex OverallSHA string `json:"overall_sha"` // sha256 of concatenated file hashes Timestamp int64 `json:"timestamp"` } // Metadata represents experiment metadata stored in meta.bin type Metadata struct { CommitID string JobName string User string Timestamp int64 } // Manager handles experiment storage and metadata type Manager struct { basePath string } // NewManager creates a new experiment manager. func NewManager(basePath string) *Manager { return &Manager{ basePath: basePath, } } // NewManagerFromPaths creates a new experiment manager using PathRegistry. 
func NewManagerFromPaths(paths *config.PathRegistry) *Manager { return &Manager{ basePath: paths.ExperimentsDir(), } } func (m *Manager) BasePath() string { if m == nil { return "" } return m.basePath } // Initialize ensures the experiment directory exists func (m *Manager) Initialize() error { // Use PathRegistry for consistent path management paths := config.FromEnv() if err := paths.EnsureDir(paths.ExperimentsDir()); err != nil { return fmt.Errorf("failed to create experiment base directory: %w", err) } return nil } // GetExperimentPath returns the path for a given commit ID func (m *Manager) GetExperimentPath(commitID string) string { return filepath.Join(m.basePath, commitID) } // GetFilesPath returns the path to the files directory for an experiment func (m *Manager) GetFilesPath(commitID string) string { return filepath.Join(m.GetExperimentPath(commitID), "files") } // GetMetadataPath returns the path to meta.bin for an experiment func (m *Manager) GetMetadataPath(commitID string) string { return filepath.Join(m.GetExperimentPath(commitID), "meta.bin") } // ExperimentExists checks if an experiment with the given commit ID exists func (m *Manager) ExperimentExists(commitID string) bool { path := m.GetExperimentPath(commitID) info, err := os.Stat(path) return err == nil && info.IsDir() } // CreateExperiment creates the directory structure for a new experiment func (m *Manager) CreateExperiment(commitID string) error { filesPath := m.GetFilesPath(commitID) if err := os.MkdirAll(filesPath, 0o750); err != nil { return fmt.Errorf("failed to create experiment directory: %w", err) } return nil } // WriteMetadata writes experiment metadata to meta.bin with crash safety (fsync) func (m *Manager) WriteMetadata(meta *Metadata) error { path := m.GetMetadataPath(meta.CommitID) // Binary format: // [version:1][timestamp:8][commit_id_len:1][commit_id:var][job_name_len:1][job_name:var] // [user_len:1][user:var] buf := make([]byte, 0, 256) // Version buf = append(buf, 0x01) 
// Timestamp ts := make([]byte, 8) binary.BigEndian.PutUint64(ts, uint64(meta.Timestamp)) //nolint:gosec buf = append(buf, ts...) // Commit ID buf = append(buf, byte(len(meta.CommitID))) buf = append(buf, []byte(meta.CommitID)...) // Job Name buf = append(buf, byte(len(meta.JobName))) buf = append(buf, []byte(meta.JobName)...) // User buf = append(buf, byte(len(meta.User))) buf = append(buf, []byte(meta.User)...) // SECURITY: Write with fsync for crash safety return fileutil.WriteFileSafe(path, buf, 0o600) } // ReadMetadata reads experiment metadata from meta.bin func (m *Manager) ReadMetadata(commitID string) (*Metadata, error) { path := m.GetMetadataPath(commitID) data, err := fileutil.SecureFileRead(path) if err != nil { return nil, fmt.Errorf("failed to read metadata: %w", err) } if len(data) < 10 { return nil, fmt.Errorf("metadata file too short") } meta := &Metadata{} offset := 0 // Version version := data[offset] offset++ if version != 0x01 { return nil, fmt.Errorf("unsupported metadata version: %d", version) } // Timestamp meta.Timestamp = int64(binary.BigEndian.Uint64(data[offset : offset+8])) //nolint:gosec offset += 8 // Commit ID commitIDLen := int(data[offset]) offset++ meta.CommitID = string(data[offset : offset+commitIDLen]) offset += commitIDLen // Job Name if offset >= len(data) { return meta, nil } jobNameLen := int(data[offset]) offset++ meta.JobName = string(data[offset : offset+jobNameLen]) offset += jobNameLen // User if offset >= len(data) { return meta, nil } userLen := int(data[offset]) offset++ meta.User = string(data[offset : offset+userLen]) return meta, nil } // ListExperiments returns all experiment commit IDs func (m *Manager) ListExperiments() ([]string, error) { entries, err := os.ReadDir(m.basePath) if err != nil { return nil, fmt.Errorf("failed to read experiments directory: %w", err) } var commitIDs []string for _, entry := range entries { if entry.IsDir() { if entry.Name() == "archive" { continue } commitIDs = append(commitIDs, 
entry.Name()) } } return commitIDs, nil } func (m *Manager) archiveExperiment(commitID string) (string, error) { if m == nil { return "", fmt.Errorf("missing manager") } commitID = strings.TrimSpace(commitID) if err := container.ValidateJobName(commitID); err != nil { return "", fmt.Errorf("invalid commit id: %w", err) } src := m.GetExperimentPath(commitID) info, err := os.Stat(src) if err != nil { return "", err } if !info.IsDir() { return "", fmt.Errorf("experiment path is not a directory") } stamp := time.Now().UTC().Format("20060102-150405") archiveRoot := filepath.Join(m.basePath, "archive", stamp) // Use PathRegistry pattern for directory creation paths := config.FromEnv() if err := paths.EnsureDir(filepath.Join(paths.ExperimentsDir(), "archive")); err != nil { return "", err } if err := os.MkdirAll(archiveRoot, 0o750); err != nil { return "", err } dst := filepath.Join(archiveRoot, commitID) if err := os.Rename(src, dst); err != nil { return "", err } return dst, nil } // PruneExperiments removes old experiments based on retention policy func (m *Manager) PruneExperiments(keepCount int, olderThanDays int) ([]string, error) { commitIDs, err := m.ListExperiments() if err != nil { return nil, err } type experiment struct { commitID string timestamp int64 } var experiments []experiment for _, commitID := range commitIDs { meta, err := m.ReadMetadata(commitID) if err != nil { continue // Skip experiments with invalid metadata } experiments = append(experiments, experiment{ commitID: commitID, timestamp: meta.Timestamp, }) } // Sort by timestamp (newest first) for i := 0; i < len(experiments); i++ { for j := i + 1; j < len(experiments); j++ { if experiments[j].timestamp > experiments[i].timestamp { experiments[i], experiments[j] = experiments[j], experiments[i] } } } var pruned []string cutoffTime := time.Now().AddDate(0, 0, -olderThanDays).Unix() for i, exp := range experiments { shouldPrune := false // Keep the newest N experiments if i >= keepCount { 
shouldPrune = true } // Also prune if older than threshold if olderThanDays > 0 && exp.timestamp < cutoffTime { shouldPrune = true } if shouldPrune { if _, err := m.archiveExperiment(exp.commitID); err != nil { continue } pruned = append(pruned, exp.commitID) } } return pruned, nil } // Metric represents a single data point in an experiment type Metric struct { Name string `json:"name"` Value float64 `json:"value"` Step int `json:"step"` Timestamp int64 `json:"timestamp"` } // GetMetricsPath returns the path to metrics.bin for an experiment func (m *Manager) GetMetricsPath(commitID string) string { return filepath.Join(m.GetExperimentPath(commitID), "metrics.bin") } // LogMetric appends a metric to the experiment's metrics file func (m *Manager) LogMetric(commitID string, name string, value float64, step int) error { path := m.GetMetricsPath(commitID) // Ensure the experiment directory exists if err := os.MkdirAll(m.GetExperimentPath(commitID), 0o750); err != nil { return fmt.Errorf("failed to create experiment directory: %w", err) } // Binary format for each metric: // [timestamp:8][step:4][value:8][name_len:1][name:var] buf := make([]byte, 0, 64) // Timestamp ts := make([]byte, 8) ts64 := uint64(time.Now().Unix()) //nolint:gosec binary.BigEndian.PutUint64(ts, ts64) buf = append(buf, ts...) // Step st := make([]byte, 4) binary.BigEndian.PutUint32(st, uint32(step)) //nolint:gosec buf = append(buf, st...) // Value (float64) val := make([]byte, 8) binary.BigEndian.PutUint64(val, math.Float64bits(value)) buf = append(buf, val...) // Name if len(name) > 255 { name = name[:255] } buf = append(buf, byte(len(name))) buf = append(buf, []byte(name)...) 
// Append to file f, err := fileutil.SecureOpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o600) if err != nil { return fmt.Errorf("failed to open metrics file: %w", err) } defer func() { _ = f.Close() }() if _, err := f.Write(buf); err != nil { return fmt.Errorf("failed to write metric: %w", err) } return nil } // GetMetrics reads all metrics for an experiment func (m *Manager) GetMetrics(commitID string) ([]Metric, error) { path := m.GetMetricsPath(commitID) data, err := fileutil.SecureFileRead(path) if err != nil { if os.IsNotExist(err) { return []Metric{}, nil } return nil, fmt.Errorf("failed to read metrics file: %w", err) } var metrics []Metric offset := 0 for offset < len(data) { if offset+21 > len(data) { // Min size check break } m := Metric{} // Timestamp ts := binary.BigEndian.Uint64(data[offset : offset+8]) if ts > math.MaxInt64 { return nil, fmt.Errorf("timestamp overflow") } m.Timestamp = int64(ts) offset += 8 // Step m.Step = int(binary.BigEndian.Uint32(data[offset : offset+4])) offset += 4 // Value bits := binary.BigEndian.Uint64(data[offset : offset+8]) m.Value = math.Float64frombits(bits) offset += 8 // Name nameLen := int(data[offset]) offset++ if offset+nameLen > len(data) { break } m.Name = string(data[offset : offset+nameLen]) offset += nameLen metrics = append(metrics, m) } return metrics, nil } // GetManifestPath returns the path to the manifest file for an experiment func (m *Manager) GetManifestPath(commitID string) string { return filepath.Join(m.GetExperimentPath(commitID), "manifest.json") } // GenerateManifest creates a content integrity manifest for all files in the experiment directory func (m *Manager) GenerateManifest(commitID string) (*Manifest, error) { filesPath := m.GetFilesPath(commitID) // Check if files directory exists if _, err := os.Stat(filesPath); os.IsNotExist(err) { return nil, fmt.Errorf("files directory does not exist: %s", filesPath) } manifest := &Manifest{ CommitID: commitID, Files: make(map[string]string), 
Timestamp: time.Now().Unix(), } // Walk the files directory and hash each file err := filepath.Walk(filesPath, func(path string, info os.FileInfo, err error) error { if err != nil { return err } // Skip directories if info.IsDir() { return nil } // Get relative path from files directory relPath, err := filepath.Rel(filesPath, path) if err != nil { return fmt.Errorf("failed to get relative path for %s: %w", path, err) } // Calculate SHA256 of file hash, err := m.hashFile(path) if err != nil { return fmt.Errorf("failed to hash file %s: %w", path, err) } manifest.Files[relPath] = hash return nil }) if err != nil { return nil, fmt.Errorf("failed to walk files directory: %w", err) } // Calculate overall SHA256 of concatenated file hashes (sorted by path for determinism) manifest.OverallSHA = m.calculateOverallSHA(manifest.Files) return manifest, nil } // WriteManifest persists the manifest to disk func (m *Manager) WriteManifest(manifest *Manifest) error { path := m.GetManifestPath(manifest.CommitID) data, err := json.MarshalIndent(manifest, "", " ") if err != nil { return fmt.Errorf("failed to marshal manifest: %w", err) } if err := fileutil.SecureFileWrite(path, data, 0o640); err != nil { return fmt.Errorf("failed to write manifest file: %w", err) } return nil } // ReadManifest loads the manifest from disk func (m *Manager) ReadManifest(commitID string) (*Manifest, error) { path := m.GetManifestPath(commitID) data, err := fileutil.SecureFileRead(path) if err != nil { return nil, fmt.Errorf("failed to read manifest file: %w", err) } var manifest Manifest if err := json.Unmarshal(data, &manifest); err != nil { return nil, fmt.Errorf("failed to unmarshal manifest: %w", err) } return &manifest, nil } // ValidateManifest verifies that the current files match the stored manifest func (m *Manager) ValidateManifest(commitID string) error { // Read stored manifest stored, err := m.ReadManifest(commitID) if err != nil { return fmt.Errorf("failed to read stored manifest: %w", 
err) } // Generate manifest from current files current, err := m.GenerateManifest(commitID) if err != nil { return fmt.Errorf("failed to generate current manifest: %w", err) } // Verify commit ID matches if stored.CommitID != current.CommitID { return fmt.Errorf("commit ID mismatch: stored=%s, current=%s", stored.CommitID, current.CommitID) } // Verify overall SHA matches if stored.OverallSHA != current.OverallSHA { return fmt.Errorf( "overall integrity checksum mismatch: stored=%s, current=%s", stored.OverallSHA, current.OverallSHA, ) } // Verify file count matches if len(stored.Files) != len(current.Files) { return fmt.Errorf( "file count mismatch: stored=%d, current=%d", len(stored.Files), len(current.Files), ) } // Verify each file hash matches for relPath, storedHash := range stored.Files { currentHash, exists := current.Files[relPath] if !exists { return fmt.Errorf("file missing in current manifest: %s", relPath) } if storedHash != currentHash { return fmt.Errorf( "file hash mismatch for %s: stored=%s, current=%s", relPath, storedHash, currentHash, ) } } return nil } // hashFile calculates SHA256 hash of a file // This delegates to the integrity package for consistent hashing. 
func (m *Manager) hashFile(path string) (string, error) {
	digest, err := integrity.FileSHA256Hex(path)
	return digest, err
}

// calculateOverallSHA returns the hex-encoded SHA-256 of every file hash in
// files, concatenated in ascending byte-wise order of relative path. Ordering
// by path makes the digest independent of Go's randomized map iteration.
func (m *Manager) calculateOverallSHA(files map[string]string) string {
	keys := make([]string, 0, len(files))
	for k := range files {
		keys = append(keys, k)
	}

	// Insertion sort into ascending order; map keys are unique, so any
	// correct sort yields the same sequence.
	for i := 1; i < len(keys); i++ {
		for j := i; j > 0 && keys[j] < keys[j-1]; j-- {
			keys[j-1], keys[j] = keys[j], keys[j-1]
		}
	}

	// Streaming each hash into the digest is equivalent to hashing their
	// concatenation; hash.Hash.Write never returns an error.
	h := sha256.New()
	for _, k := range keys {
		h.Write([]byte(files[k]))
	}
	return hex.EncodeToString(h.Sum(nil))
}