feat(tracking): add pluggable tracking backends and audit support

This commit is contained in:
Jeremie Fraeys 2026-01-05 12:33:57 -05:00
parent a8287f3087
commit dab680a60d
6 changed files with 931 additions and 0 deletions

171
internal/audit/audit.go Normal file
View file

@ -0,0 +1,171 @@
package audit
import (
"encoding/json"
"fmt"
"os"
"sync"
"time"
"github.com/jfraeys/fetch_ml/internal/logging"
)
// EventType represents the type of audit event
type EventType string
const (
EventAuthAttempt EventType = "authentication_attempt"
EventAuthSuccess EventType = "authentication_success"
EventAuthFailure EventType = "authentication_failure"
EventJobQueued EventType = "job_queued"
EventJobStarted EventType = "job_started"
EventJobCompleted EventType = "job_completed"
EventJobFailed EventType = "job_failed"
EventJupyterStart EventType = "jupyter_start"
EventJupyterStop EventType = "jupyter_stop"
EventExperimentCreated EventType = "experiment_created"
EventExperimentDeleted EventType = "experiment_deleted"
)
// Event represents an audit log event
type Event struct {
Timestamp time.Time `json:"timestamp"`
EventType EventType `json:"event_type"`
UserID string `json:"user_id,omitempty"`
IPAddress string `json:"ip_address,omitempty"`
Resource string `json:"resource,omitempty"`
Action string `json:"action,omitempty"`
Success bool `json:"success"`
ErrorMsg string `json:"error,omitempty"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
}
// Logger handles audit logging
type Logger struct {
enabled bool
filePath string
file *os.File
mu sync.Mutex
logger *logging.Logger
}
// NewLogger creates a new audit logger
func NewLogger(enabled bool, filePath string, logger *logging.Logger) (*Logger, error) {
al := &Logger{
enabled: enabled,
filePath: filePath,
logger: logger,
}
if enabled && filePath != "" {
file, err := os.OpenFile(filePath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600)
if err != nil {
return nil, fmt.Errorf("failed to open audit log file: %w", err)
}
al.file = file
}
return al, nil
}
// Log logs an audit event
func (al *Logger) Log(event Event) {
if !al.enabled {
return
}
event.Timestamp = time.Now().UTC()
al.mu.Lock()
defer al.mu.Unlock()
// Marshal to JSON
data, err := json.Marshal(event)
if err != nil {
if al.logger != nil {
al.logger.Error("failed to marshal audit event", "error", err)
}
return
}
// Write to file if configured
if al.file != nil {
_, err = al.file.Write(append(data, '\n'))
if err != nil && al.logger != nil {
al.logger.Error("failed to write audit event", "error", err)
}
}
// Also log via structured logger
if al.logger != nil {
al.logger.Info("audit_event",
"event_type", event.EventType,
"user_id", event.UserID,
"resource", event.Resource,
"success", event.Success,
)
}
}
// LogAuthAttempt logs an authentication attempt
func (al *Logger) LogAuthAttempt(userID, ipAddr string, success bool, errMsg string) {
eventType := EventAuthSuccess
if !success {
eventType = EventAuthFailure
}
al.Log(Event{
EventType: eventType,
UserID: userID,
IPAddress: ipAddr,
Success: success,
ErrorMsg: errMsg,
})
}
// LogJobOperation logs a job-related operation
func (al *Logger) LogJobOperation(
eventType EventType,
userID, jobID, ipAddr string,
success bool,
errMsg string,
) {
al.Log(Event{
EventType: eventType,
UserID: userID,
IPAddress: ipAddr,
Resource: jobID,
Action: "job_operation",
Success: success,
ErrorMsg: errMsg,
})
}
// LogJupyterOperation logs a Jupyter service operation
func (al *Logger) LogJupyterOperation(
eventType EventType,
userID, serviceID, ipAddr string,
success bool,
errMsg string,
) {
al.Log(Event{
EventType: eventType,
UserID: userID,
IPAddress: ipAddr,
Resource: serviceID,
Action: "jupyter_operation",
Success: success,
ErrorMsg: errMsg,
})
}
// Close closes the audit logger
func (al *Logger) Close() error {
al.mu.Lock()
defer al.mu.Unlock()
if al.file != nil {
return al.file.Close()
}
return nil
}

View file

@ -0,0 +1,116 @@
package factory
import (
"fmt"
"github.com/jfraeys/fetch_ml/internal/container"
"github.com/jfraeys/fetch_ml/internal/logging"
"github.com/jfraeys/fetch_ml/internal/tracking"
"github.com/jfraeys/fetch_ml/internal/tracking/plugins"
)
// PluginConfig represents the configuration for a single plugin.
type PluginConfig struct {
Enabled bool `toml:"enabled" yaml:"enabled"`
Image string `toml:"image" yaml:"image"`
Mode string `toml:"mode" yaml:"mode"`
LogBasePath string `toml:"log_base_path" yaml:"log_base_path"`
ArtifactPath string `toml:"artifact_path" yaml:"artifact_path"`
Settings map[string]any `toml:"settings" yaml:"settings"`
}
// PluginFactory is a function that creates a Plugin instance.
type PluginFactory func(
logger *logging.Logger,
podman *container.PodmanManager,
cfg PluginConfig,
) (tracking.Plugin, error)
// PluginLoader uses dependency injection to load plugins.
type PluginLoader struct {
logger *logging.Logger
podman *container.PodmanManager
factories map[string]PluginFactory
}
// NewPluginLoader creates a new PluginLoader.
func NewPluginLoader(logger *logging.Logger, podman *container.PodmanManager) *PluginLoader {
loader := &PluginLoader{
logger: logger,
podman: podman,
factories: make(map[string]PluginFactory),
}
// Register default factories
loader.RegisterFactory("mlflow", createMLflowPlugin)
loader.RegisterFactory("tensorboard", createTensorBoardPlugin)
loader.RegisterFactory("wandb", createWandbPlugin)
return loader
}
// RegisterFactory allows external packages to register new plugin types (Marketplace ready).
func (l *PluginLoader) RegisterFactory(name string, factory PluginFactory) {
l.factories[name] = factory
}
// LoadPlugins loads plugins from the provided configuration map and registers them.
func (l *PluginLoader) LoadPlugins(
plugins map[string]PluginConfig,
registry *tracking.Registry,
) error {
for name, pluginCfg := range plugins {
if !pluginCfg.Enabled {
continue
}
factory, ok := l.factories[name]
if !ok {
l.logger.Warn("unknown plugin type", "name", name)
continue
}
plugin, err := factory(l.logger, l.podman, pluginCfg)
if err != nil {
return fmt.Errorf("failed to create plugin %s: %w", name, err)
}
registry.Register(plugin)
l.logger.Info("plugin loaded", "name", name, "mode", pluginCfg.Mode)
}
return nil
}
// Factory Implementations
func createMLflowPlugin(
logger *logging.Logger,
podman *container.PodmanManager,
cfg PluginConfig,
) (tracking.Plugin, error) {
opts := plugins.MLflowOptions{
Image: cfg.Image,
ArtifactBasePath: cfg.ArtifactPath,
}
return plugins.NewMLflowPlugin(logger, podman, opts)
}
func createTensorBoardPlugin(
logger *logging.Logger,
podman *container.PodmanManager,
cfg PluginConfig,
) (tracking.Plugin, error) {
opts := plugins.TensorBoardOptions{
Image: cfg.Image,
LogBasePath: cfg.LogBasePath,
}
return plugins.NewTensorBoardPlugin(logger, podman, opts)
}
func createWandbPlugin(
logger *logging.Logger,
_ *container.PodmanManager,
_ PluginConfig,
) (tracking.Plugin, error) {
return plugins.NewWandbPlugin(), nil
}

209
internal/tracking/plugin.go Normal file
View file

@ -0,0 +1,209 @@
package tracking
import (
"context"
"fmt"
"sync"
"github.com/jfraeys/fetch_ml/internal/logging"
)
// ToolMode represents the provisioning mode for a tracking tool.
type ToolMode string
const (
// ModeSidecar provisions the tool as a sidecar container.
ModeSidecar ToolMode = "sidecar"
// ModeRemote points to a remotely managed instance (no local provisioning).
ModeRemote ToolMode = "remote"
// ModeDisabled skips provisioning entirely.
ModeDisabled ToolMode = "disabled"
)
// ToolConfig specifies how a plugin should be provisioned for a task.
type ToolConfig struct {
Enabled bool
Mode ToolMode
Settings map[string]any
}
// Plugin defines the behaviour every tracking integration must implement.
type Plugin interface {
Name() string
ProvisionSidecar(ctx context.Context, taskID string, config ToolConfig) (map[string]string, error)
Teardown(ctx context.Context, taskID string) error
HealthCheck(ctx context.Context, config ToolConfig) bool
}
// Registry keeps track of registered plugins and their lifecycle per task.
type Registry struct {
logger *logging.Logger
mu sync.Mutex
plugins map[string]Plugin
active map[string][]string
}
// NewRegistry returns a new plugin registry.
func NewRegistry(logger *logging.Logger) *Registry {
return &Registry{
logger: logger,
plugins: make(map[string]Plugin),
active: make(map[string][]string),
}
}
// Register adds a plugin to the registry.
func (r *Registry) Register(p Plugin) {
r.mu.Lock()
defer r.mu.Unlock()
r.plugins[p.Name()] = p
}
// Get retrieves a plugin by name.
func (r *Registry) Get(name string) (Plugin, bool) {
r.mu.Lock()
defer r.mu.Unlock()
p, ok := r.plugins[name]
return p, ok
}
// ProvisionAll provisions configured plugins for a task and merges their environment variables.
func (r *Registry) ProvisionAll(
ctx context.Context,
taskID string,
configs map[string]ToolConfig,
) (map[string]string, error) {
if len(configs) == 0 {
return nil, nil
}
env := make(map[string]string)
var provisioned []string
for name, cfg := range configs {
if !cfg.Enabled || cfg.Mode == ModeDisabled {
continue
}
plugin, ok := r.Get(name)
if !ok {
return nil, fmt.Errorf("tracking plugin %s not registered", name)
}
settingsEnv, err := plugin.ProvisionSidecar(ctx, taskID, cfg)
if err != nil {
r.rollback(ctx, taskID, provisioned)
return nil, fmt.Errorf("failed to provision %s: %w", name, err)
}
for k, v := range settingsEnv {
env[k] = v
}
if cfg.Mode == ModeSidecar {
provisioned = append(provisioned, name)
}
}
if len(provisioned) > 0 {
r.mu.Lock()
r.active[taskID] = append(r.active[taskID], provisioned...)
r.mu.Unlock()
}
return env, nil
}
// TeardownAll stops every plugin that was provisioned for a task.
func (r *Registry) TeardownAll(ctx context.Context, taskID string) {
r.mu.Lock()
plugins := r.active[taskID]
delete(r.active, taskID)
r.mu.Unlock()
for _, name := range plugins {
plugin, ok := r.Get(name)
if !ok {
continue
}
if err := plugin.Teardown(ctx, taskID); err != nil && r.logger != nil {
r.logger.Warn("tracking teardown failed", "plugin", name, "task_id", taskID, "error", err)
}
}
}
func (r *Registry) rollback(ctx context.Context, taskID string, provisioned []string) {
for i := len(provisioned) - 1; i >= 0; i-- {
name := provisioned[i]
plugin, ok := r.Get(name)
if !ok {
continue
}
if err := plugin.Teardown(ctx, taskID); err != nil && r.logger != nil {
r.logger.Warn("rollback failed", "plugin", name, "task_id", taskID, "error", err)
}
}
}
// PortAllocator manages dynamic port assignments for sidecars.
type PortAllocator struct {
mu sync.Mutex
start int
end int
next int
used map[int]bool
}
// NewPortAllocator creates a new allocator for a port range.
func NewPortAllocator(start, end int) *PortAllocator {
if start <= 0 || end <= 0 || end <= start {
start = 5500
end = 5600
}
return &PortAllocator{
start: start,
end: end,
next: start,
used: make(map[int]bool),
}
}
// Allocate reserves the next available port.
func (p *PortAllocator) Allocate() (int, error) {
p.mu.Lock()
defer p.mu.Unlock()
for i := 0; i < p.end-p.start; i++ {
port := p.next
p.next++
if p.next >= p.end {
p.next = p.start
}
if !p.used[port] {
p.used[port] = true
return port, nil
}
}
return 0, fmt.Errorf("no ports available in range %d-%d", p.start, p.end)
}
// Release frees a previously allocated port.
func (p *PortAllocator) Release(port int) {
p.mu.Lock()
defer p.mu.Unlock()
delete(p.used, port)
}
// StringSetting safely reads a string from plugin settings.
func StringSetting(settings map[string]any, key string) string {
if settings == nil {
return ""
}
if v, ok := settings[key]; ok {
if str, ok := v.(string); ok {
return str
}
}
return ""
}

View file

@ -0,0 +1,196 @@
package plugins
import (
"context"
"fmt"
"os"
"path/filepath"
"sync"
"github.com/jfraeys/fetch_ml/internal/container"
"github.com/jfraeys/fetch_ml/internal/logging"
trackingpkg "github.com/jfraeys/fetch_ml/internal/tracking"
)
// MLflowOptions configures the MLflow plugin.
type MLflowOptions struct {
Image string
ArtifactBasePath string
DefaultTrackingURI string
PortAllocator *trackingpkg.PortAllocator
}
type mlflowSidecar struct {
containerID string
port int
}
// MLflowPlugin provisions MLflow tracking servers per task.
type MLflowPlugin struct {
logger *logging.Logger
podman *container.PodmanManager
opts MLflowOptions
mu sync.Mutex
sidecars map[string]*mlflowSidecar
}
// NewMLflowPlugin creates a new MLflow plugin instance.
func NewMLflowPlugin(
logger *logging.Logger,
podman *container.PodmanManager,
opts MLflowOptions,
) (*MLflowPlugin, error) {
if podman == nil {
return nil, fmt.Errorf("podman manager is required for MLflow plugin")
}
if opts.Image == "" {
opts.Image = "ghcr.io/mlflow/mlflow:v2.16.1"
}
if opts.ArtifactBasePath == "" {
return nil, fmt.Errorf("artifact base path is required for MLflow plugin")
}
if opts.PortAllocator == nil {
opts.PortAllocator = trackingpkg.NewPortAllocator(5500, 5700)
}
return &MLflowPlugin{
logger: logger,
podman: podman,
opts: opts,
sidecars: make(map[string]*mlflowSidecar),
}, nil
}
// Name returns the plugin name.
func (m *MLflowPlugin) Name() string {
return "mlflow"
}
// ProvisionSidecar starts an MLflow sidecar or returns remote env vars.
func (m *MLflowPlugin) ProvisionSidecar(
ctx context.Context,
taskID string,
config trackingpkg.ToolConfig,
) (map[string]string, error) {
switch config.Mode {
case trackingpkg.ModeRemote:
uri := trackingpkg.StringSetting(config.Settings, "tracking_uri")
if uri == "" {
uri = m.opts.DefaultTrackingURI
}
if uri == "" {
return nil, fmt.Errorf("mlflow remote mode requires tracking_uri")
}
return map[string]string{
"MLFLOW_TRACKING_URI": uri,
}, nil
case trackingpkg.ModeDisabled:
return nil, nil
default:
return m.provisionSidecar(ctx, taskID, config)
}
}
func (m *MLflowPlugin) provisionSidecar(
ctx context.Context,
taskID string,
config trackingpkg.ToolConfig,
) (map[string]string, error) {
jobName := trackingpkg.StringSetting(config.Settings, "job_name")
if jobName == "" {
jobName = taskID
}
taskDir := filepath.Join(m.opts.ArtifactBasePath, jobName)
if err := os.MkdirAll(taskDir, 0750); err != nil {
return nil, fmt.Errorf("failed to create MLflow artifact dir: %w", err)
}
port, err := m.opts.PortAllocator.Allocate()
if err != nil {
return nil, err
}
containerName := fmt.Sprintf("mlflow-%s", taskID)
cmd := []string{
"mlflow", "server",
"--host=0.0.0.0",
"--port=5000",
"--backend-store-uri=file:///mlruns",
"--default-artifact-root=file:///mlruns",
}
cfg := &container.ContainerConfig{
Name: containerName,
Image: m.opts.Image,
Command: cmd,
Env: map[string]string{
"MLFLOW_BACKEND_STORE_URI": "file:///mlruns",
"MLFLOW_ARTIFACT_URI": "file:///mlruns",
},
Volumes: map[string]string{
taskDir: "/mlruns",
},
Ports: map[int]int{
port: 5000,
},
Network: container.NetworkConfig{
AllowNetwork: true,
},
}
containerID, err := m.podman.StartContainer(ctx, cfg)
if err != nil {
m.opts.PortAllocator.Release(port)
return nil, err
}
m.mu.Lock()
m.sidecars[taskID] = &mlflowSidecar{
containerID: containerID,
port: port,
}
m.mu.Unlock()
return map[string]string{
"MLFLOW_TRACKING_URI": fmt.Sprintf("http://127.0.0.1:%d", port),
}, nil
}
// Teardown stops the MLflow sidecar for a task.
func (m *MLflowPlugin) Teardown(ctx context.Context, taskID string) error {
m.mu.Lock()
sidecar, ok := m.sidecars[taskID]
if ok {
delete(m.sidecars, taskID)
}
m.mu.Unlock()
if !ok {
return nil
}
if sidecar.port != 0 {
m.opts.PortAllocator.Release(sidecar.port)
}
if sidecar.containerID == "" {
return nil
}
if err := m.podman.StopContainer(ctx, sidecar.containerID); err != nil && m.logger != nil {
m.logger.Warn("failed to stop MLflow container", "task_id", taskID, "error", err)
}
if err := m.podman.RemoveContainer(ctx, sidecar.containerID); err != nil && m.logger != nil {
m.logger.Warn("failed to remove MLflow container", "task_id", taskID, "error", err)
}
return nil
}
// HealthCheck verifies plugin readiness (basic for now).
func (m *MLflowPlugin) HealthCheck(_ context.Context, _ trackingpkg.ToolConfig) bool {
// Future: Perform HTTP checks against tracking URI.
return true
}

View file

@ -0,0 +1,170 @@
package plugins
import (
"context"
"fmt"
"os"
"path/filepath"
"sync"
"github.com/jfraeys/fetch_ml/internal/container"
"github.com/jfraeys/fetch_ml/internal/logging"
trackingpkg "github.com/jfraeys/fetch_ml/internal/tracking"
)
// TensorBoardOptions configure the TensorBoard plugin.
type TensorBoardOptions struct {
Image string
LogBasePath string
PortAllocator *trackingpkg.PortAllocator
}
type tensorboardSidecar struct {
containerID string
port int
}
// TensorBoardPlugin exposes training logs through TensorBoard.
type TensorBoardPlugin struct {
logger *logging.Logger
podman *container.PodmanManager
opts TensorBoardOptions
mu sync.Mutex
sidecars map[string]*tensorboardSidecar
}
// NewTensorBoardPlugin constructs a TensorBoard plugin instance.
func NewTensorBoardPlugin(
logger *logging.Logger,
podman *container.PodmanManager,
opts TensorBoardOptions,
) (*TensorBoardPlugin, error) {
if podman == nil {
return nil, fmt.Errorf("podman manager is required for TensorBoard plugin")
}
if opts.Image == "" {
opts.Image = "tensorflow/tensorflow:2.17.0"
}
if opts.LogBasePath == "" {
return nil, fmt.Errorf("log base path is required for TensorBoard plugin")
}
if opts.PortAllocator == nil {
opts.PortAllocator = trackingpkg.NewPortAllocator(5700, 5900)
}
return &TensorBoardPlugin{
logger: logger,
podman: podman,
opts: opts,
sidecars: make(map[string]*tensorboardSidecar),
}, nil
}
// Name returns the plugin name.
func (t *TensorBoardPlugin) Name() string {
return "tensorboard"
}
// ProvisionSidecar starts TensorBoard to tail a task's log directory.
func (t *TensorBoardPlugin) ProvisionSidecar(
ctx context.Context,
taskID string,
config trackingpkg.ToolConfig,
) (map[string]string, error) {
if config.Mode == trackingpkg.ModeRemote || config.Mode == trackingpkg.ModeDisabled {
return nil, nil
}
jobName := trackingpkg.StringSetting(config.Settings, "job_name")
if jobName == "" {
jobName = taskID
}
logDir := filepath.Join(t.opts.LogBasePath, jobName)
if err := os.MkdirAll(logDir, 0750); err != nil {
return nil, fmt.Errorf("failed to ensure tensorboard log dir: %w", err)
}
port, err := t.opts.PortAllocator.Allocate()
if err != nil {
return nil, err
}
containerName := fmt.Sprintf("tensorboard-%s", taskID)
cmd := []string{
"tensorboard",
"--logdir", "/logs",
"--host", "0.0.0.0",
"--port", "6006",
}
cfg := &container.ContainerConfig{
Name: containerName,
Image: t.opts.Image,
Command: cmd,
Volumes: map[string]string{
logDir: "/logs:ro",
},
Ports: map[int]int{
port: 6006,
},
Network: container.NetworkConfig{
AllowNetwork: true,
},
}
containerID, err := t.podman.StartContainer(ctx, cfg)
if err != nil {
t.opts.PortAllocator.Release(port)
return nil, err
}
t.mu.Lock()
t.sidecars[taskID] = &tensorboardSidecar{
containerID: containerID,
port: port,
}
t.mu.Unlock()
return map[string]string{
"TENSORBOARD_URL": fmt.Sprintf("http://127.0.0.1:%d", port),
"TENSORBOARD_HOST_LOG_DIR": logDir,
}, nil
}
// Teardown stops the TensorBoard sidecar.
func (t *TensorBoardPlugin) Teardown(ctx context.Context, taskID string) error {
t.mu.Lock()
sidecar, ok := t.sidecars[taskID]
if ok {
delete(t.sidecars, taskID)
}
t.mu.Unlock()
if !ok {
return nil
}
if sidecar.port != 0 {
t.opts.PortAllocator.Release(sidecar.port)
}
if sidecar.containerID == "" {
return nil
}
if err := t.podman.StopContainer(ctx, sidecar.containerID); err != nil && t.logger != nil {
t.logger.Warn("failed to stop tensorboard container", "task_id", taskID, "error", err)
}
if err := t.podman.RemoveContainer(ctx, sidecar.containerID); err != nil && t.logger != nil {
t.logger.Warn("failed to remove tensorboard container", "task_id", taskID, "error", err)
}
return nil
}
// HealthCheck currently returns true; extend with HTTP checks later.
func (t *TensorBoardPlugin) HealthCheck(context.Context, trackingpkg.ToolConfig) bool {
return true
}

View file

@ -0,0 +1,69 @@
package plugins
import (
"context"
"fmt"
trackingpkg "github.com/jfraeys/fetch_ml/internal/tracking"
)
// WandbPlugin forwards credentials to the task container without provisioning a sidecar.
type WandbPlugin struct{}
// NewWandbPlugin constructs a Wandb plugin.
func NewWandbPlugin() *WandbPlugin {
return &WandbPlugin{}
}
// Name returns the plugin name.
func (w *WandbPlugin) Name() string {
return "wandb"
}
// ProvisionSidecar validates configuration and returns environment variables.
func (w *WandbPlugin) ProvisionSidecar(
_ context.Context,
_ string,
config trackingpkg.ToolConfig,
) (map[string]string, error) {
if config.Mode == trackingpkg.ModeDisabled {
return nil, nil
}
apiKey := trackingpkg.StringSetting(config.Settings, "api_key")
project := trackingpkg.StringSetting(config.Settings, "project")
entity := trackingpkg.StringSetting(config.Settings, "entity")
if config.Mode == trackingpkg.ModeRemote && apiKey == "" {
return nil, fmt.Errorf("wandb remote mode requires api_key")
}
env := map[string]string{}
if apiKey != "" {
env["WANDB_API_KEY"] = apiKey
}
if project != "" {
env["WANDB_PROJECT"] = project
}
if entity != "" {
env["WANDB_ENTITY"] = entity
}
return env, nil
}
// Teardown is a no-op for Wandb.
func (w *WandbPlugin) Teardown(context.Context, string) error {
return nil
}
// HealthCheck ensures required params exist.
func (w *WandbPlugin) HealthCheck(_ context.Context, config trackingpkg.ToolConfig) bool {
if !config.Enabled {
return true
}
if config.Mode == trackingpkg.ModeRemote {
return trackingpkg.StringSetting(config.Settings, "api_key") != ""
}
return true
}