From dab680a60d9eb0f8a66fb3339e1ea2c5d5a2f55e Mon Sep 17 00:00:00 2001 From: Jeremie Fraeys Date: Mon, 5 Jan 2026 12:33:57 -0500 Subject: [PATCH] feat(tracking): add pluggable tracking backends and audit support --- internal/audit/audit.go | 171 +++++++++++++++++++ internal/tracking/factory/loader.go | 116 +++++++++++++ internal/tracking/plugin.go | 209 +++++++++++++++++++++++ internal/tracking/plugins/mlflow.go | 196 +++++++++++++++++++++ internal/tracking/plugins/tensorboard.go | 170 ++++++++++++++++++ internal/tracking/plugins/wandb.go | 69 ++++++++ 6 files changed, 931 insertions(+) create mode 100644 internal/audit/audit.go create mode 100644 internal/tracking/factory/loader.go create mode 100644 internal/tracking/plugin.go create mode 100644 internal/tracking/plugins/mlflow.go create mode 100644 internal/tracking/plugins/tensorboard.go create mode 100644 internal/tracking/plugins/wandb.go diff --git a/internal/audit/audit.go b/internal/audit/audit.go new file mode 100644 index 0000000..822f797 --- /dev/null +++ b/internal/audit/audit.go @@ -0,0 +1,171 @@ +package audit + +import ( + "encoding/json" + "fmt" + "os" + "sync" + "time" + + "github.com/jfraeys/fetch_ml/internal/logging" +) + +// EventType represents the type of audit event +type EventType string + +const ( + EventAuthAttempt EventType = "authentication_attempt" + EventAuthSuccess EventType = "authentication_success" + EventAuthFailure EventType = "authentication_failure" + EventJobQueued EventType = "job_queued" + EventJobStarted EventType = "job_started" + EventJobCompleted EventType = "job_completed" + EventJobFailed EventType = "job_failed" + EventJupyterStart EventType = "jupyter_start" + EventJupyterStop EventType = "jupyter_stop" + EventExperimentCreated EventType = "experiment_created" + EventExperimentDeleted EventType = "experiment_deleted" +) + +// Event represents an audit log event +type Event struct { + Timestamp time.Time `json:"timestamp"` + EventType EventType `json:"event_type"` + UserID string `json:"user_id,omitempty"` + IPAddress string `json:"ip_address,omitempty"` + Resource string `json:"resource,omitempty"` + Action string `json:"action,omitempty"` + Success bool `json:"success"` + ErrorMsg string `json:"error,omitempty"` + Metadata map[string]interface{} `json:"metadata,omitempty"` +} + +// Logger handles audit logging +type Logger struct { + enabled bool + filePath string + file *os.File + mu sync.Mutex + logger *logging.Logger +} + +// NewLogger creates a new audit logger +func NewLogger(enabled bool, filePath string, logger *logging.Logger) (*Logger, error) { + al := &Logger{ + enabled: enabled, + filePath: filePath, + logger: logger, + } + + if enabled && filePath != "" { + file, err := os.OpenFile(filePath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600) + if err != nil { + return nil, fmt.Errorf("failed to open audit log file: %w", err) + } + al.file = file + } + + return al, nil +} + +// Log logs an audit event +func (al *Logger) Log(event Event) { + if !al.enabled { + return + } + + event.Timestamp = time.Now().UTC() + + al.mu.Lock() + defer al.mu.Unlock() + + // Marshal to JSON + data, err := json.Marshal(event) + if err != nil { + if al.logger != nil { + al.logger.Error("failed to marshal audit event", "error", err) + } + return + } + + // Write to file if configured + if al.file != nil { + _, err = al.file.Write(append(data, '\n')) + if err != nil && al.logger != nil { + al.logger.Error("failed to write audit event", "error", err) + } + } + + // Also log via structured logger + if al.logger != nil { + al.logger.Info("audit_event", + "event_type", event.EventType, + "user_id", event.UserID, + "resource", event.Resource, + "success", event.Success, + ) + } +} + +// LogAuthAttempt logs an authentication attempt +func (al *Logger) LogAuthAttempt(userID, ipAddr string, success bool, errMsg string) { + eventType := EventAuthSuccess + if !success { + eventType = EventAuthFailure + } + + al.Log(Event{ + EventType: eventType, + UserID: userID, + IPAddress: ipAddr, + Success: success, + ErrorMsg: errMsg, + }) +} + +// LogJobOperation logs a job-related operation +func (al *Logger) LogJobOperation( + eventType EventType, + userID, jobID, ipAddr string, + success bool, + errMsg string, +) { + al.Log(Event{ + EventType: eventType, + UserID: userID, + IPAddress: ipAddr, + Resource: jobID, + Action: "job_operation", + Success: success, + ErrorMsg: errMsg, + }) +} + +// LogJupyterOperation logs a Jupyter service operation +func (al *Logger) LogJupyterOperation( + eventType EventType, + userID, serviceID, ipAddr string, + success bool, + errMsg string, +) { + al.Log(Event{ + EventType: eventType, + UserID: userID, + IPAddress: ipAddr, + Resource: serviceID, + Action: "jupyter_operation", + Success: success, + ErrorMsg: errMsg, + }) +} + +// Close closes the audit logger +func (al *Logger) Close() error { + al.mu.Lock() + defer al.mu.Unlock() + + if al.file != nil { + return al.file.Close() + } + return nil +} diff --git a/internal/tracking/factory/loader.go b/internal/tracking/factory/loader.go new file mode 100644 index 0000000..e8d92e3 --- /dev/null +++ b/internal/tracking/factory/loader.go @@ -0,0 +1,116 @@ +package factory + +import ( + "fmt" + + "github.com/jfraeys/fetch_ml/internal/container" + "github.com/jfraeys/fetch_ml/internal/logging" + "github.com/jfraeys/fetch_ml/internal/tracking" + "github.com/jfraeys/fetch_ml/internal/tracking/plugins" +) + +// PluginConfig represents the configuration for a single plugin. +type PluginConfig struct { + Enabled bool `toml:"enabled" yaml:"enabled"` + Image string `toml:"image" yaml:"image"` + Mode string `toml:"mode" yaml:"mode"` + LogBasePath string `toml:"log_base_path" yaml:"log_base_path"` + ArtifactPath string `toml:"artifact_path" yaml:"artifact_path"` + Settings map[string]any `toml:"settings" yaml:"settings"` +} + +// PluginFactory is a function that creates a Plugin instance. +type PluginFactory func( + logger *logging.Logger, + podman *container.PodmanManager, + cfg PluginConfig, +) (tracking.Plugin, error) + +// PluginLoader uses dependency injection to load plugins. +type PluginLoader struct { + logger *logging.Logger + podman *container.PodmanManager + factories map[string]PluginFactory +} + +// NewPluginLoader creates a new PluginLoader. +func NewPluginLoader(logger *logging.Logger, podman *container.PodmanManager) *PluginLoader { + loader := &PluginLoader{ + logger: logger, + podman: podman, + factories: make(map[string]PluginFactory), + } + + // Register default factories + loader.RegisterFactory("mlflow", createMLflowPlugin) + loader.RegisterFactory("tensorboard", createTensorBoardPlugin) + loader.RegisterFactory("wandb", createWandbPlugin) + + return loader +} + +// RegisterFactory allows external packages to register new plugin types (Marketplace ready). +func (l *PluginLoader) RegisterFactory(name string, factory PluginFactory) { + l.factories[name] = factory +} + +// LoadPlugins loads plugins from the provided configuration map and registers them. +func (l *PluginLoader) LoadPlugins( + plugins map[string]PluginConfig, + registry *tracking.Registry, +) error { + for name, pluginCfg := range plugins { + if !pluginCfg.Enabled { + continue + } + + factory, ok := l.factories[name] + if !ok { + l.logger.Warn("unknown plugin type", "name", name) + continue + } + + plugin, err := factory(l.logger, l.podman, pluginCfg) + if err != nil { + return fmt.Errorf("failed to create plugin %s: %w", name, err) + } + + registry.Register(plugin) + l.logger.Info("plugin loaded", "name", name, "mode", pluginCfg.Mode) + } + + return nil +} + +// Factory Implementations +func createMLflowPlugin( + logger *logging.Logger, + podman *container.PodmanManager, + cfg PluginConfig, +) (tracking.Plugin, error) { + opts := plugins.MLflowOptions{ + Image: cfg.Image, + ArtifactBasePath: cfg.ArtifactPath, + } + return plugins.NewMLflowPlugin(logger, podman, opts) +} + +func createTensorBoardPlugin( + logger *logging.Logger, + podman *container.PodmanManager, + cfg PluginConfig, +) (tracking.Plugin, error) { + opts := plugins.TensorBoardOptions{ + Image: cfg.Image, + LogBasePath: cfg.LogBasePath, + } + return plugins.NewTensorBoardPlugin(logger, podman, opts) +} + +func createWandbPlugin( + logger *logging.Logger, + _ *container.PodmanManager, + _ PluginConfig, +) (tracking.Plugin, error) { + return plugins.NewWandbPlugin(), nil +} diff --git a/internal/tracking/plugin.go b/internal/tracking/plugin.go new file mode 100644 index 0000000..d714a17 --- /dev/null +++ b/internal/tracking/plugin.go @@ -0,0 +1,209 @@ +package tracking + +import ( + "context" + "fmt" + "sync" + + "github.com/jfraeys/fetch_ml/internal/logging" +) + +// ToolMode represents the provisioning mode for a tracking tool. +type ToolMode string + +const ( + // ModeSidecar provisions the tool as a sidecar container. + ModeSidecar ToolMode = "sidecar" + // ModeRemote points to a remotely managed instance (no local provisioning). + ModeRemote ToolMode = "remote" + // ModeDisabled skips provisioning entirely. + ModeDisabled ToolMode = "disabled" +) + +// ToolConfig specifies how a plugin should be provisioned for a task. +type ToolConfig struct { + Enabled bool + Mode ToolMode + Settings map[string]any +} + +// Plugin defines the behaviour every tracking integration must implement. +type Plugin interface { + Name() string + ProvisionSidecar(ctx context.Context, taskID string, config ToolConfig) (map[string]string, error) + Teardown(ctx context.Context, taskID string) error + HealthCheck(ctx context.Context, config ToolConfig) bool +} + +// Registry keeps track of registered plugins and their lifecycle per task. +type Registry struct { + logger *logging.Logger + mu sync.Mutex + plugins map[string]Plugin + active map[string][]string +} + +// NewRegistry returns a new plugin registry. +func NewRegistry(logger *logging.Logger) *Registry { + return &Registry{ + logger: logger, + plugins: make(map[string]Plugin), + active: make(map[string][]string), + } +} + +// Register adds a plugin to the registry. +func (r *Registry) Register(p Plugin) { + r.mu.Lock() + defer r.mu.Unlock() + r.plugins[p.Name()] = p +} + +// Get retrieves a plugin by name. +func (r *Registry) Get(name string) (Plugin, bool) { + r.mu.Lock() + defer r.mu.Unlock() + p, ok := r.plugins[name] + return p, ok +} + +// ProvisionAll provisions configured plugins for a task and merges their environment variables. +func (r *Registry) ProvisionAll( + ctx context.Context, + taskID string, + configs map[string]ToolConfig, +) (map[string]string, error) { + if len(configs) == 0 { + return nil, nil + } + + env := make(map[string]string) + var provisioned []string + + for name, cfg := range configs { + if !cfg.Enabled || cfg.Mode == ModeDisabled { + continue + } + + plugin, ok := r.Get(name) + if !ok { + return nil, fmt.Errorf("tracking plugin %s not registered", name) + } + + settingsEnv, err := plugin.ProvisionSidecar(ctx, taskID, cfg) + if err != nil { + r.rollback(ctx, taskID, provisioned) + return nil, fmt.Errorf("failed to provision %s: %w", name, err) + } + + for k, v := range settingsEnv { + env[k] = v + } + + if cfg.Mode == ModeSidecar { + provisioned = append(provisioned, name) + } + } + + if len(provisioned) > 0 { + r.mu.Lock() + r.active[taskID] = append(r.active[taskID], provisioned...) + r.mu.Unlock() + } + + return env, nil +} + +// TeardownAll stops every plugin that was provisioned for a task. +func (r *Registry) TeardownAll(ctx context.Context, taskID string) { + r.mu.Lock() + plugins := r.active[taskID] + delete(r.active, taskID) + r.mu.Unlock() + + for _, name := range plugins { + plugin, ok := r.Get(name) + if !ok { + continue + } + if err := plugin.Teardown(ctx, taskID); err != nil && r.logger != nil { + r.logger.Warn("tracking teardown failed", "plugin", name, "task_id", taskID, "error", err) + } + } +} + +func (r *Registry) rollback(ctx context.Context, taskID string, provisioned []string) { + for i := len(provisioned) - 1; i >= 0; i-- { + name := provisioned[i] + plugin, ok := r.Get(name) + if !ok { + continue + } + if err := plugin.Teardown(ctx, taskID); err != nil && r.logger != nil { + r.logger.Warn("rollback failed", "plugin", name, "task_id", taskID, "error", err) + } + } +} + +// PortAllocator manages dynamic port assignments for sidecars. +type PortAllocator struct { + mu sync.Mutex + start int + end int + next int + used map[int]bool +} + +// NewPortAllocator creates a new allocator for a port range. +func NewPortAllocator(start, end int) *PortAllocator { + if start <= 0 || end <= 0 || end <= start { + start = 5500 + end = 5600 + } + return &PortAllocator{ + start: start, + end: end, + next: start, + used: make(map[int]bool), + } +} + +// Allocate reserves the next available port. +func (p *PortAllocator) Allocate() (int, error) { + p.mu.Lock() + defer p.mu.Unlock() + + for i := 0; i < p.end-p.start; i++ { + port := p.next + p.next++ + if p.next >= p.end { + p.next = p.start + } + if !p.used[port] { + p.used[port] = true + return port, nil + } + } + + return 0, fmt.Errorf("no ports available in range %d-%d", p.start, p.end) +} + +// Release frees a previously allocated port. +func (p *PortAllocator) Release(port int) { + p.mu.Lock() + defer p.mu.Unlock() + delete(p.used, port) +} + +// StringSetting safely reads a string from plugin settings. +func StringSetting(settings map[string]any, key string) string { + if settings == nil { + return "" + } + if v, ok := settings[key]; ok { + if str, ok := v.(string); ok { + return str + } + } + return "" +} diff --git a/internal/tracking/plugins/mlflow.go b/internal/tracking/plugins/mlflow.go new file mode 100644 index 0000000..86fd820 --- /dev/null +++ b/internal/tracking/plugins/mlflow.go @@ -0,0 +1,196 @@ +package plugins + +import ( + "context" + "fmt" + "os" + "path/filepath" + "sync" + + "github.com/jfraeys/fetch_ml/internal/container" + "github.com/jfraeys/fetch_ml/internal/logging" + trackingpkg "github.com/jfraeys/fetch_ml/internal/tracking" +) + +// MLflowOptions configures the MLflow plugin. +type MLflowOptions struct { + Image string + ArtifactBasePath string + DefaultTrackingURI string + PortAllocator *trackingpkg.PortAllocator +} + +type mlflowSidecar struct { + containerID string + port int +} + +// MLflowPlugin provisions MLflow tracking servers per task. +type MLflowPlugin struct { + logger *logging.Logger + podman *container.PodmanManager + opts MLflowOptions + + mu sync.Mutex + sidecars map[string]*mlflowSidecar +} + +// NewMLflowPlugin creates a new MLflow plugin instance. +func NewMLflowPlugin( + logger *logging.Logger, + podman *container.PodmanManager, + opts MLflowOptions, +) (*MLflowPlugin, error) { + if podman == nil { + return nil, fmt.Errorf("podman manager is required for MLflow plugin") + } + if opts.Image == "" { + opts.Image = "ghcr.io/mlflow/mlflow:v2.16.1" + } + if opts.ArtifactBasePath == "" { + return nil, fmt.Errorf("artifact base path is required for MLflow plugin") + } + if opts.PortAllocator == nil { + opts.PortAllocator = trackingpkg.NewPortAllocator(5500, 5700) + } + + return &MLflowPlugin{ + logger: logger, + podman: podman, + opts: opts, + sidecars: make(map[string]*mlflowSidecar), + }, nil +} + +// Name returns the plugin name. +func (m *MLflowPlugin) Name() string { + return "mlflow" +} + +// ProvisionSidecar starts an MLflow sidecar or returns remote env vars. +func (m *MLflowPlugin) ProvisionSidecar( + ctx context.Context, + taskID string, + config trackingpkg.ToolConfig, +) (map[string]string, error) { + switch config.Mode { + case trackingpkg.ModeRemote: + uri := trackingpkg.StringSetting(config.Settings, "tracking_uri") + if uri == "" { + uri = m.opts.DefaultTrackingURI + } + if uri == "" { + return nil, fmt.Errorf("mlflow remote mode requires tracking_uri") + } + return map[string]string{ + "MLFLOW_TRACKING_URI": uri, + }, nil + case trackingpkg.ModeDisabled: + return nil, nil + default: + return m.provisionSidecar(ctx, taskID, config) + } +} + +func (m *MLflowPlugin) provisionSidecar( + ctx context.Context, + taskID string, + config trackingpkg.ToolConfig, +) (map[string]string, error) { + jobName := trackingpkg.StringSetting(config.Settings, "job_name") + if jobName == "" { + jobName = taskID + } + + taskDir := filepath.Join(m.opts.ArtifactBasePath, jobName) + if err := os.MkdirAll(taskDir, 0750); err != nil { + return nil, fmt.Errorf("failed to create MLflow artifact dir: %w", err) + } + + port, err := m.opts.PortAllocator.Allocate() + if err != nil { + return nil, err + } + + containerName := fmt.Sprintf("mlflow-%s", taskID) + cmd := []string{ + "mlflow", "server", + "--host=0.0.0.0", + "--port=5000", + "--backend-store-uri=file:///mlruns", + "--default-artifact-root=file:///mlruns", + } + + cfg := &container.ContainerConfig{ + Name: containerName, + Image: m.opts.Image, + Command: cmd, + Env: map[string]string{ + "MLFLOW_BACKEND_STORE_URI": "file:///mlruns", + "MLFLOW_ARTIFACT_URI": "file:///mlruns", + }, + Volumes: map[string]string{ + taskDir: "/mlruns", + }, + Ports: map[int]int{ + port: 5000, + }, + Network: container.NetworkConfig{ + AllowNetwork: true, + }, + } + + containerID, err := m.podman.StartContainer(ctx, cfg) + if err != nil { + m.opts.PortAllocator.Release(port) + return nil, err + } + + m.mu.Lock() + m.sidecars[taskID] = &mlflowSidecar{ + containerID: containerID, + port: port, + } + m.mu.Unlock() + + return map[string]string{ + "MLFLOW_TRACKING_URI": fmt.Sprintf("http://127.0.0.1:%d", port), + }, nil +} + +// Teardown stops the MLflow sidecar for a task. +func (m *MLflowPlugin) Teardown(ctx context.Context, taskID string) error { + m.mu.Lock() + sidecar, ok := m.sidecars[taskID] + if ok { + delete(m.sidecars, taskID) + } + m.mu.Unlock() + + if !ok { + return nil + } + + if sidecar.port != 0 { + m.opts.PortAllocator.Release(sidecar.port) + } + + if sidecar.containerID == "" { + return nil + } + + if err := m.podman.StopContainer(ctx, sidecar.containerID); err != nil && m.logger != nil { + m.logger.Warn("failed to stop MLflow container", "task_id", taskID, "error", err) + } + if err := m.podman.RemoveContainer(ctx, sidecar.containerID); err != nil && m.logger != nil { + m.logger.Warn("failed to remove MLflow container", "task_id", taskID, "error", err) + } + + return nil +} + +// HealthCheck verifies plugin readiness (basic for now). +func (m *MLflowPlugin) HealthCheck(_ context.Context, _ trackingpkg.ToolConfig) bool { + // Future: Perform HTTP checks against tracking URI. + return true +} diff --git a/internal/tracking/plugins/tensorboard.go b/internal/tracking/plugins/tensorboard.go new file mode 100644 index 0000000..3930431 --- /dev/null +++ b/internal/tracking/plugins/tensorboard.go @@ -0,0 +1,170 @@ +package plugins + +import ( + "context" + "fmt" + "os" + "path/filepath" + "sync" + + "github.com/jfraeys/fetch_ml/internal/container" + "github.com/jfraeys/fetch_ml/internal/logging" + trackingpkg "github.com/jfraeys/fetch_ml/internal/tracking" +) + +// TensorBoardOptions configure the TensorBoard plugin. +type TensorBoardOptions struct { + Image string + LogBasePath string + PortAllocator *trackingpkg.PortAllocator +} + +type tensorboardSidecar struct { + containerID string + port int +} + +// TensorBoardPlugin exposes training logs through TensorBoard. +type TensorBoardPlugin struct { + logger *logging.Logger + podman *container.PodmanManager + opts TensorBoardOptions + + mu sync.Mutex + sidecars map[string]*tensorboardSidecar +} + +// NewTensorBoardPlugin constructs a TensorBoard plugin instance. +func NewTensorBoardPlugin( + logger *logging.Logger, + podman *container.PodmanManager, + opts TensorBoardOptions, +) (*TensorBoardPlugin, error) { + if podman == nil { + return nil, fmt.Errorf("podman manager is required for TensorBoard plugin") + } + if opts.Image == "" { + opts.Image = "tensorflow/tensorflow:2.17.0" + } + if opts.LogBasePath == "" { + return nil, fmt.Errorf("log base path is required for TensorBoard plugin") + } + if opts.PortAllocator == nil { + opts.PortAllocator = trackingpkg.NewPortAllocator(5700, 5900) + } + + return &TensorBoardPlugin{ + logger: logger, + podman: podman, + opts: opts, + sidecars: make(map[string]*tensorboardSidecar), + }, nil +} + +// Name returns the plugin name. +func (t *TensorBoardPlugin) Name() string { + return "tensorboard" +} + +// ProvisionSidecar starts TensorBoard to tail a task's log directory. +func (t *TensorBoardPlugin) ProvisionSidecar( + ctx context.Context, + taskID string, + config trackingpkg.ToolConfig, +) (map[string]string, error) { + if config.Mode == trackingpkg.ModeRemote || config.Mode == trackingpkg.ModeDisabled { + return nil, nil + } + + jobName := trackingpkg.StringSetting(config.Settings, "job_name") + if jobName == "" { + jobName = taskID + } + + logDir := filepath.Join(t.opts.LogBasePath, jobName) + if err := os.MkdirAll(logDir, 0750); err != nil { + return nil, fmt.Errorf("failed to ensure tensorboard log dir: %w", err) + } + + port, err := t.opts.PortAllocator.Allocate() + if err != nil { + return nil, err + } + + containerName := fmt.Sprintf("tensorboard-%s", taskID) + cmd := []string{ + "tensorboard", + "--logdir", "/logs", + "--host", "0.0.0.0", + "--port", "6006", + } + + cfg := &container.ContainerConfig{ + Name: containerName, + Image: t.opts.Image, + Command: cmd, + Volumes: map[string]string{ + logDir: "/logs:ro", + }, + Ports: map[int]int{ + port: 6006, + }, + Network: container.NetworkConfig{ + AllowNetwork: true, + }, + } + + containerID, err := t.podman.StartContainer(ctx, cfg) + if err != nil { + t.opts.PortAllocator.Release(port) + return nil, err + } + + t.mu.Lock() + t.sidecars[taskID] = &tensorboardSidecar{ + containerID: containerID, + port: port, + } + t.mu.Unlock() + + return map[string]string{ + "TENSORBOARD_URL": fmt.Sprintf("http://127.0.0.1:%d", port), + "TENSORBOARD_HOST_LOG_DIR": logDir, + }, nil +} + +// Teardown stops the TensorBoard sidecar. +func (t *TensorBoardPlugin) Teardown(ctx context.Context, taskID string) error { + t.mu.Lock() + sidecar, ok := t.sidecars[taskID] + if ok { + delete(t.sidecars, taskID) + } + t.mu.Unlock() + + if !ok { + return nil + } + + if sidecar.port != 0 { + t.opts.PortAllocator.Release(sidecar.port) + } + + if sidecar.containerID == "" { + return nil + } + + if err := t.podman.StopContainer(ctx, sidecar.containerID); err != nil && t.logger != nil { + t.logger.Warn("failed to stop tensorboard container", "task_id", taskID, "error", err) + } + if err := t.podman.RemoveContainer(ctx, sidecar.containerID); err != nil && t.logger != nil { + t.logger.Warn("failed to remove tensorboard container", "task_id", taskID, "error", err) + } + + return nil +} + +// HealthCheck currently returns true; extend with HTTP checks later. +func (t *TensorBoardPlugin) HealthCheck(context.Context, trackingpkg.ToolConfig) bool { + return true +} diff --git a/internal/tracking/plugins/wandb.go b/internal/tracking/plugins/wandb.go new file mode 100644 index 0000000..d6b97a8 --- /dev/null +++ b/internal/tracking/plugins/wandb.go @@ -0,0 +1,69 @@ +package plugins + +import ( + "context" + "fmt" + + trackingpkg "github.com/jfraeys/fetch_ml/internal/tracking" +) + +// WandbPlugin forwards credentials to the task container without provisioning a sidecar. +type WandbPlugin struct{} + +// NewWandbPlugin constructs a Wandb plugin. +func NewWandbPlugin() *WandbPlugin { + return &WandbPlugin{} +} + +// Name returns the plugin name. +func (w *WandbPlugin) Name() string { + return "wandb" +} + +// ProvisionSidecar validates configuration and returns environment variables. +func (w *WandbPlugin) ProvisionSidecar( + _ context.Context, + _ string, + config trackingpkg.ToolConfig, +) (map[string]string, error) { + if config.Mode == trackingpkg.ModeDisabled { + return nil, nil + } + + apiKey := trackingpkg.StringSetting(config.Settings, "api_key") + project := trackingpkg.StringSetting(config.Settings, "project") + entity := trackingpkg.StringSetting(config.Settings, "entity") + + if config.Mode == trackingpkg.ModeRemote && apiKey == "" { + return nil, fmt.Errorf("wandb remote mode requires api_key") + } + + env := map[string]string{} + if apiKey != "" { + env["WANDB_API_KEY"] = apiKey + } + if project != "" { + env["WANDB_PROJECT"] = project + } + if entity != "" { + env["WANDB_ENTITY"] = entity + } + + return env, nil +} + +// Teardown is a no-op for Wandb. +func (w *WandbPlugin) Teardown(context.Context, string) error { + return nil +} + +// HealthCheck ensures required params exist. +func (w *WandbPlugin) HealthCheck(_ context.Context, config trackingpkg.ToolConfig) bool { + if !config.Enabled { + return true + } + if config.Mode == trackingpkg.ModeRemote { + return trackingpkg.StringSetting(config.Settings, "api_key") != "" + } + return true +}