feat(tracking): add pluggable tracking backends and audit support
This commit is contained in:
parent
a8287f3087
commit
dab680a60d
6 changed files with 931 additions and 0 deletions
171
internal/audit/audit.go
Normal file
171
internal/audit/audit.go
Normal file
|
|
@ -0,0 +1,171 @@
|
|||
package audit
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/logging"
|
||||
)
|
||||
|
||||
// EventType represents the type of audit event
type EventType string

// Audit event types covering authentication, job lifecycle, Jupyter
// service lifecycle, and experiment management.
const (
	// Authentication lifecycle.
	// NOTE(review): EventAuthAttempt is declared but not emitted by any
	// helper in this file (LogAuthAttempt maps to success/failure) —
	// confirm whether callers use it directly.
	EventAuthAttempt EventType = "authentication_attempt"
	EventAuthSuccess EventType = "authentication_success"
	EventAuthFailure EventType = "authentication_failure"
	// Job lifecycle.
	EventJobQueued    EventType = "job_queued"
	EventJobStarted   EventType = "job_started"
	EventJobCompleted EventType = "job_completed"
	EventJobFailed    EventType = "job_failed"
	// Jupyter service lifecycle.
	EventJupyterStart EventType = "jupyter_start"
	EventJupyterStop  EventType = "jupyter_stop"
	// Experiment management.
	EventExperimentCreated EventType = "experiment_created"
	EventExperimentDeleted EventType = "experiment_deleted"
)
|
||||
|
||||
// Event represents an audit log event. Events are serialized as one JSON
// object per line when written to the audit file.
type Event struct {
	// Timestamp is set by Logger.Log (UTC); callers need not populate it.
	Timestamp time.Time `json:"timestamp"`
	EventType EventType `json:"event_type"`
	UserID    string    `json:"user_id,omitempty"`
	IPAddress string    `json:"ip_address,omitempty"`
	// Resource identifies the object acted on (e.g. a job or service ID).
	Resource string `json:"resource,omitempty"`
	Action   string `json:"action,omitempty"`
	Success  bool   `json:"success"`
	ErrorMsg string `json:"error,omitempty"`
	// Metadata carries free-form extra context for the event.
	Metadata map[string]interface{} `json:"metadata,omitempty"`
}
|
||||
|
||||
// Logger handles audit logging. Events are appended as JSON lines to an
// optional file and mirrored to the structured logger when one is set.
type Logger struct {
	enabled  bool     // when false, Log is a no-op
	filePath string   // audit file path; empty disables file output
	file     *os.File // open audit file, nil when not configured
	mu       sync.Mutex
	logger   *logging.Logger // optional structured-log mirror; may be nil
}
|
||||
|
||||
// NewLogger creates a new audit logger
|
||||
func NewLogger(enabled bool, filePath string, logger *logging.Logger) (*Logger, error) {
|
||||
al := &Logger{
|
||||
enabled: enabled,
|
||||
filePath: filePath,
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
if enabled && filePath != "" {
|
||||
file, err := os.OpenFile(filePath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open audit log file: %w", err)
|
||||
}
|
||||
al.file = file
|
||||
}
|
||||
|
||||
return al, nil
|
||||
}
|
||||
|
||||
// Log logs an audit event
|
||||
func (al *Logger) Log(event Event) {
|
||||
if !al.enabled {
|
||||
return
|
||||
}
|
||||
|
||||
event.Timestamp = time.Now().UTC()
|
||||
|
||||
al.mu.Lock()
|
||||
defer al.mu.Unlock()
|
||||
|
||||
// Marshal to JSON
|
||||
data, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
if al.logger != nil {
|
||||
al.logger.Error("failed to marshal audit event", "error", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Write to file if configured
|
||||
if al.file != nil {
|
||||
_, err = al.file.Write(append(data, '\n'))
|
||||
if err != nil && al.logger != nil {
|
||||
al.logger.Error("failed to write audit event", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Also log via structured logger
|
||||
if al.logger != nil {
|
||||
al.logger.Info("audit_event",
|
||||
"event_type", event.EventType,
|
||||
"user_id", event.UserID,
|
||||
"resource", event.Resource,
|
||||
"success", event.Success,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// LogAuthAttempt logs an authentication attempt
|
||||
func (al *Logger) LogAuthAttempt(userID, ipAddr string, success bool, errMsg string) {
|
||||
eventType := EventAuthSuccess
|
||||
if !success {
|
||||
eventType = EventAuthFailure
|
||||
}
|
||||
|
||||
al.Log(Event{
|
||||
EventType: eventType,
|
||||
UserID: userID,
|
||||
IPAddress: ipAddr,
|
||||
Success: success,
|
||||
ErrorMsg: errMsg,
|
||||
})
|
||||
}
|
||||
|
||||
// LogJobOperation logs a job-related operation
|
||||
func (al *Logger) LogJobOperation(
|
||||
eventType EventType,
|
||||
userID, jobID, ipAddr string,
|
||||
success bool,
|
||||
errMsg string,
|
||||
) {
|
||||
al.Log(Event{
|
||||
EventType: eventType,
|
||||
UserID: userID,
|
||||
IPAddress: ipAddr,
|
||||
Resource: jobID,
|
||||
Action: "job_operation",
|
||||
Success: success,
|
||||
ErrorMsg: errMsg,
|
||||
})
|
||||
}
|
||||
|
||||
// LogJupyterOperation logs a Jupyter service operation
|
||||
func (al *Logger) LogJupyterOperation(
|
||||
eventType EventType,
|
||||
userID, serviceID, ipAddr string,
|
||||
success bool,
|
||||
errMsg string,
|
||||
) {
|
||||
al.Log(Event{
|
||||
EventType: eventType,
|
||||
UserID: userID,
|
||||
IPAddress: ipAddr,
|
||||
Resource: serviceID,
|
||||
Action: "jupyter_operation",
|
||||
Success: success,
|
||||
ErrorMsg: errMsg,
|
||||
})
|
||||
}
|
||||
|
||||
// Close closes the audit logger
|
||||
func (al *Logger) Close() error {
|
||||
al.mu.Lock()
|
||||
defer al.mu.Unlock()
|
||||
|
||||
if al.file != nil {
|
||||
return al.file.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
116
internal/tracking/factory/loader.go
Normal file
116
internal/tracking/factory/loader.go
Normal file
|
|
@ -0,0 +1,116 @@
|
|||
package factory
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/container"
|
||||
"github.com/jfraeys/fetch_ml/internal/logging"
|
||||
"github.com/jfraeys/fetch_ml/internal/tracking"
|
||||
"github.com/jfraeys/fetch_ml/internal/tracking/plugins"
|
||||
)
|
||||
|
||||
// PluginConfig represents the configuration for a single plugin.
type PluginConfig struct {
	Enabled bool   `toml:"enabled" yaml:"enabled"`
	Image   string `toml:"image" yaml:"image"` // container image override
	Mode    string `toml:"mode" yaml:"mode"`
	// LogBasePath is consumed by the tensorboard factory below.
	LogBasePath string `toml:"log_base_path" yaml:"log_base_path"`
	// ArtifactPath is consumed by the mlflow factory below.
	ArtifactPath string `toml:"artifact_path" yaml:"artifact_path"`
	// Settings carries plugin-specific options.
	Settings map[string]any `toml:"settings" yaml:"settings"`
}

// PluginFactory is a function that creates a Plugin instance.
type PluginFactory func(
	logger *logging.Logger,
	podman *container.PodmanManager,
	cfg PluginConfig,
) (tracking.Plugin, error)

// PluginLoader uses dependency injection to load plugins.
type PluginLoader struct {
	logger    *logging.Logger
	podman    *container.PodmanManager
	factories map[string]PluginFactory // plugin name -> constructor
}
|
||||
|
||||
// NewPluginLoader creates a new PluginLoader.
|
||||
func NewPluginLoader(logger *logging.Logger, podman *container.PodmanManager) *PluginLoader {
|
||||
loader := &PluginLoader{
|
||||
logger: logger,
|
||||
podman: podman,
|
||||
factories: make(map[string]PluginFactory),
|
||||
}
|
||||
|
||||
// Register default factories
|
||||
loader.RegisterFactory("mlflow", createMLflowPlugin)
|
||||
loader.RegisterFactory("tensorboard", createTensorBoardPlugin)
|
||||
loader.RegisterFactory("wandb", createWandbPlugin)
|
||||
|
||||
return loader
|
||||
}
|
||||
|
||||
// RegisterFactory allows external packages to register new plugin types (Marketplace ready).
|
||||
func (l *PluginLoader) RegisterFactory(name string, factory PluginFactory) {
|
||||
l.factories[name] = factory
|
||||
}
|
||||
|
||||
// LoadPlugins loads plugins from the provided configuration map and registers them.
|
||||
func (l *PluginLoader) LoadPlugins(
|
||||
plugins map[string]PluginConfig,
|
||||
registry *tracking.Registry,
|
||||
) error {
|
||||
for name, pluginCfg := range plugins {
|
||||
if !pluginCfg.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
factory, ok := l.factories[name]
|
||||
if !ok {
|
||||
l.logger.Warn("unknown plugin type", "name", name)
|
||||
continue
|
||||
}
|
||||
|
||||
plugin, err := factory(l.logger, l.podman, pluginCfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create plugin %s: %w", name, err)
|
||||
}
|
||||
|
||||
registry.Register(plugin)
|
||||
l.logger.Info("plugin loaded", "name", name, "mode", pluginCfg.Mode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Factory Implementations
|
||||
func createMLflowPlugin(
|
||||
logger *logging.Logger,
|
||||
podman *container.PodmanManager,
|
||||
cfg PluginConfig,
|
||||
) (tracking.Plugin, error) {
|
||||
opts := plugins.MLflowOptions{
|
||||
Image: cfg.Image,
|
||||
ArtifactBasePath: cfg.ArtifactPath,
|
||||
}
|
||||
return plugins.NewMLflowPlugin(logger, podman, opts)
|
||||
}
|
||||
|
||||
func createTensorBoardPlugin(
|
||||
logger *logging.Logger,
|
||||
podman *container.PodmanManager,
|
||||
cfg PluginConfig,
|
||||
) (tracking.Plugin, error) {
|
||||
opts := plugins.TensorBoardOptions{
|
||||
Image: cfg.Image,
|
||||
LogBasePath: cfg.LogBasePath,
|
||||
}
|
||||
return plugins.NewTensorBoardPlugin(logger, podman, opts)
|
||||
}
|
||||
|
||||
func createWandbPlugin(
|
||||
logger *logging.Logger,
|
||||
_ *container.PodmanManager,
|
||||
_ PluginConfig,
|
||||
) (tracking.Plugin, error) {
|
||||
return plugins.NewWandbPlugin(), nil
|
||||
}
|
||||
209
internal/tracking/plugin.go
Normal file
209
internal/tracking/plugin.go
Normal file
|
|
@ -0,0 +1,209 @@
|
|||
package tracking
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/logging"
|
||||
)
|
||||
|
||||
// ToolMode represents the provisioning mode for a tracking tool.
type ToolMode string

const (
	// ModeSidecar provisions the tool as a sidecar container.
	ModeSidecar ToolMode = "sidecar"
	// ModeRemote points to a remotely managed instance (no local provisioning).
	ModeRemote ToolMode = "remote"
	// ModeDisabled skips provisioning entirely.
	ModeDisabled ToolMode = "disabled"
)

// ToolConfig specifies how a plugin should be provisioned for a task.
type ToolConfig struct {
	Enabled  bool           // disabled entries are skipped by ProvisionAll
	Mode     ToolMode       // sidecar, remote, or disabled
	Settings map[string]any // plugin-specific options, read via StringSetting
}

// Plugin defines the behaviour every tracking integration must implement.
type Plugin interface {
	// Name returns the unique name the plugin is registered under.
	Name() string
	// ProvisionSidecar prepares the tool for a task and returns environment
	// variables to inject into the task container.
	ProvisionSidecar(ctx context.Context, taskID string, config ToolConfig) (map[string]string, error)
	// Teardown releases whatever ProvisionSidecar created for the task.
	Teardown(ctx context.Context, taskID string) error
	// HealthCheck reports whether the plugin is usable with the given config.
	HealthCheck(ctx context.Context, config ToolConfig) bool
}
|
||||
|
||||
// Registry keeps track of registered plugins and their lifecycle per task.
|
||||
type Registry struct {
|
||||
logger *logging.Logger
|
||||
mu sync.Mutex
|
||||
plugins map[string]Plugin
|
||||
active map[string][]string
|
||||
}
|
||||
|
||||
// NewRegistry returns a new plugin registry.
|
||||
func NewRegistry(logger *logging.Logger) *Registry {
|
||||
return &Registry{
|
||||
logger: logger,
|
||||
plugins: make(map[string]Plugin),
|
||||
active: make(map[string][]string),
|
||||
}
|
||||
}
|
||||
|
||||
// Register adds a plugin to the registry.
|
||||
func (r *Registry) Register(p Plugin) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.plugins[p.Name()] = p
|
||||
}
|
||||
|
||||
// Get retrieves a plugin by name.
|
||||
func (r *Registry) Get(name string) (Plugin, bool) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
p, ok := r.plugins[name]
|
||||
return p, ok
|
||||
}
|
||||
|
||||
// ProvisionAll provisions configured plugins for a task and merges their environment variables.
|
||||
func (r *Registry) ProvisionAll(
|
||||
ctx context.Context,
|
||||
taskID string,
|
||||
configs map[string]ToolConfig,
|
||||
) (map[string]string, error) {
|
||||
if len(configs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
env := make(map[string]string)
|
||||
var provisioned []string
|
||||
|
||||
for name, cfg := range configs {
|
||||
if !cfg.Enabled || cfg.Mode == ModeDisabled {
|
||||
continue
|
||||
}
|
||||
|
||||
plugin, ok := r.Get(name)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("tracking plugin %s not registered", name)
|
||||
}
|
||||
|
||||
settingsEnv, err := plugin.ProvisionSidecar(ctx, taskID, cfg)
|
||||
if err != nil {
|
||||
r.rollback(ctx, taskID, provisioned)
|
||||
return nil, fmt.Errorf("failed to provision %s: %w", name, err)
|
||||
}
|
||||
|
||||
for k, v := range settingsEnv {
|
||||
env[k] = v
|
||||
}
|
||||
|
||||
if cfg.Mode == ModeSidecar {
|
||||
provisioned = append(provisioned, name)
|
||||
}
|
||||
}
|
||||
|
||||
if len(provisioned) > 0 {
|
||||
r.mu.Lock()
|
||||
r.active[taskID] = append(r.active[taskID], provisioned...)
|
||||
r.mu.Unlock()
|
||||
}
|
||||
|
||||
return env, nil
|
||||
}
|
||||
|
||||
// TeardownAll stops every plugin that was provisioned for a task.
|
||||
func (r *Registry) TeardownAll(ctx context.Context, taskID string) {
|
||||
r.mu.Lock()
|
||||
plugins := r.active[taskID]
|
||||
delete(r.active, taskID)
|
||||
r.mu.Unlock()
|
||||
|
||||
for _, name := range plugins {
|
||||
plugin, ok := r.Get(name)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if err := plugin.Teardown(ctx, taskID); err != nil && r.logger != nil {
|
||||
r.logger.Warn("tracking teardown failed", "plugin", name, "task_id", taskID, "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Registry) rollback(ctx context.Context, taskID string, provisioned []string) {
|
||||
for i := len(provisioned) - 1; i >= 0; i-- {
|
||||
name := provisioned[i]
|
||||
plugin, ok := r.Get(name)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if err := plugin.Teardown(ctx, taskID); err != nil && r.logger != nil {
|
||||
r.logger.Warn("rollback failed", "plugin", name, "task_id", taskID, "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// PortAllocator manages dynamic port assignments for sidecars over the
// half-open range [start, end). It is safe for concurrent use.
type PortAllocator struct {
	mu    sync.Mutex
	start int
	end   int
	next  int          // next candidate port to try
	used  map[int]bool // ports currently handed out
}

// NewPortAllocator creates a new allocator for a port range. An invalid
// range (non-positive bounds or end <= start) falls back to 5500-5600.
func NewPortAllocator(start, end int) *PortAllocator {
	if start <= 0 || end <= 0 || end <= start {
		start, end = 5500, 5600
	}
	return &PortAllocator{
		start: start,
		end:   end,
		next:  start,
		used:  map[int]bool{},
	}
}

// Allocate reserves the next available port, scanning at most once around
// the range, or returns an error when every port is in use.
func (p *PortAllocator) Allocate() (int, error) {
	p.mu.Lock()
	defer p.mu.Unlock()

	span := p.end - p.start
	for attempts := 0; attempts < span; attempts++ {
		candidate := p.next
		p.next++
		if p.next >= p.end {
			p.next = p.start // wrap around
		}
		if p.used[candidate] {
			continue
		}
		p.used[candidate] = true
		return candidate, nil
	}

	return 0, fmt.Errorf("no ports available in range %d-%d", p.start, p.end)
}

// Release frees a previously allocated port so it can be handed out again.
func (p *PortAllocator) Release(port int) {
	p.mu.Lock()
	defer p.mu.Unlock()
	delete(p.used, port)
}
|
||||
|
||||
// StringSetting safely reads a string from plugin settings. It returns ""
// when the map is nil, the key is absent, or the value is not a string.
func StringSetting(settings map[string]any, key string) string {
	// Indexing a nil map yields the zero value, so no nil check is needed.
	v, ok := settings[key]
	if !ok {
		return ""
	}
	s, _ := v.(string)
	return s
}
|
||||
196
internal/tracking/plugins/mlflow.go
Normal file
196
internal/tracking/plugins/mlflow.go
Normal file
|
|
@ -0,0 +1,196 @@
|
|||
package plugins
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/container"
|
||||
"github.com/jfraeys/fetch_ml/internal/logging"
|
||||
trackingpkg "github.com/jfraeys/fetch_ml/internal/tracking"
|
||||
)
|
||||
|
||||
// MLflowOptions configures the MLflow plugin.
type MLflowOptions struct {
	// Image is the container image; NewMLflowPlugin defaults it to
	// ghcr.io/mlflow/mlflow:v2.16.1 when empty.
	Image string
	// ArtifactBasePath is the host directory under which per-job artifact
	// dirs are created (required).
	ArtifactBasePath string
	// DefaultTrackingURI is used in remote mode when settings carry no
	// tracking_uri.
	DefaultTrackingURI string
	// PortAllocator supplies host ports; defaulted to 5500-5700 when nil.
	PortAllocator *trackingpkg.PortAllocator
}

// mlflowSidecar records the container and host port backing one task's
// MLflow server.
type mlflowSidecar struct {
	containerID string
	port        int
}

// MLflowPlugin provisions MLflow tracking servers per task.
type MLflowPlugin struct {
	logger *logging.Logger
	podman *container.PodmanManager
	opts   MLflowOptions

	mu       sync.Mutex
	sidecars map[string]*mlflowSidecar // task ID -> running sidecar
}
|
||||
|
||||
// NewMLflowPlugin creates a new MLflow plugin instance.
|
||||
func NewMLflowPlugin(
|
||||
logger *logging.Logger,
|
||||
podman *container.PodmanManager,
|
||||
opts MLflowOptions,
|
||||
) (*MLflowPlugin, error) {
|
||||
if podman == nil {
|
||||
return nil, fmt.Errorf("podman manager is required for MLflow plugin")
|
||||
}
|
||||
if opts.Image == "" {
|
||||
opts.Image = "ghcr.io/mlflow/mlflow:v2.16.1"
|
||||
}
|
||||
if opts.ArtifactBasePath == "" {
|
||||
return nil, fmt.Errorf("artifact base path is required for MLflow plugin")
|
||||
}
|
||||
if opts.PortAllocator == nil {
|
||||
opts.PortAllocator = trackingpkg.NewPortAllocator(5500, 5700)
|
||||
}
|
||||
|
||||
return &MLflowPlugin{
|
||||
logger: logger,
|
||||
podman: podman,
|
||||
opts: opts,
|
||||
sidecars: make(map[string]*mlflowSidecar),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Name returns the plugin name.
|
||||
func (m *MLflowPlugin) Name() string {
|
||||
return "mlflow"
|
||||
}
|
||||
|
||||
// ProvisionSidecar starts an MLflow sidecar or returns remote env vars.
|
||||
func (m *MLflowPlugin) ProvisionSidecar(
|
||||
ctx context.Context,
|
||||
taskID string,
|
||||
config trackingpkg.ToolConfig,
|
||||
) (map[string]string, error) {
|
||||
switch config.Mode {
|
||||
case trackingpkg.ModeRemote:
|
||||
uri := trackingpkg.StringSetting(config.Settings, "tracking_uri")
|
||||
if uri == "" {
|
||||
uri = m.opts.DefaultTrackingURI
|
||||
}
|
||||
if uri == "" {
|
||||
return nil, fmt.Errorf("mlflow remote mode requires tracking_uri")
|
||||
}
|
||||
return map[string]string{
|
||||
"MLFLOW_TRACKING_URI": uri,
|
||||
}, nil
|
||||
case trackingpkg.ModeDisabled:
|
||||
return nil, nil
|
||||
default:
|
||||
return m.provisionSidecar(ctx, taskID, config)
|
||||
}
|
||||
}
|
||||
|
||||
// provisionSidecar starts a dedicated MLflow tracking server container for
// the task and returns the environment variables the task needs to reach
// it (MLFLOW_TRACKING_URI on the allocated host port).
func (m *MLflowPlugin) provisionSidecar(
	ctx context.Context,
	taskID string,
	config trackingpkg.ToolConfig,
) (map[string]string, error) {
	// Artifacts are grouped by job name when provided; fall back to taskID.
	jobName := trackingpkg.StringSetting(config.Settings, "job_name")
	if jobName == "" {
		jobName = taskID
	}

	taskDir := filepath.Join(m.opts.ArtifactBasePath, jobName)
	if err := os.MkdirAll(taskDir, 0750); err != nil {
		return nil, fmt.Errorf("failed to create MLflow artifact dir: %w", err)
	}

	// Reserve a host port; it is released again on startup failure below.
	port, err := m.opts.PortAllocator.Allocate()
	if err != nil {
		return nil, err
	}

	containerName := fmt.Sprintf("mlflow-%s", taskID)
	// The server always listens on 5000 inside the container; the host
	// port is mapped onto it via the Ports config below.
	cmd := []string{
		"mlflow", "server",
		"--host=0.0.0.0",
		"--port=5000",
		"--backend-store-uri=file:///mlruns",
		"--default-artifact-root=file:///mlruns",
	}

	cfg := &container.ContainerConfig{
		Name:    containerName,
		Image:   m.opts.Image,
		Command: cmd,
		Env: map[string]string{
			"MLFLOW_BACKEND_STORE_URI": "file:///mlruns",
			"MLFLOW_ARTIFACT_URI":      "file:///mlruns",
		},
		// The per-job host artifact dir backs /mlruns in the container.
		Volumes: map[string]string{
			taskDir: "/mlruns",
		},
		Ports: map[int]int{
			port: 5000,
		},
		Network: container.NetworkConfig{
			AllowNetwork: true,
		},
	}

	containerID, err := m.podman.StartContainer(ctx, cfg)
	if err != nil {
		m.opts.PortAllocator.Release(port)
		return nil, err
	}

	// Record the sidecar so Teardown can stop it and free the port.
	m.mu.Lock()
	m.sidecars[taskID] = &mlflowSidecar{
		containerID: containerID,
		port:        port,
	}
	m.mu.Unlock()

	return map[string]string{
		"MLFLOW_TRACKING_URI": fmt.Sprintf("http://127.0.0.1:%d", port),
	}, nil
}
|
||||
|
||||
// Teardown stops the MLflow sidecar for a task.
|
||||
func (m *MLflowPlugin) Teardown(ctx context.Context, taskID string) error {
|
||||
m.mu.Lock()
|
||||
sidecar, ok := m.sidecars[taskID]
|
||||
if ok {
|
||||
delete(m.sidecars, taskID)
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
if sidecar.port != 0 {
|
||||
m.opts.PortAllocator.Release(sidecar.port)
|
||||
}
|
||||
|
||||
if sidecar.containerID == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := m.podman.StopContainer(ctx, sidecar.containerID); err != nil && m.logger != nil {
|
||||
m.logger.Warn("failed to stop MLflow container", "task_id", taskID, "error", err)
|
||||
}
|
||||
if err := m.podman.RemoveContainer(ctx, sidecar.containerID); err != nil && m.logger != nil {
|
||||
m.logger.Warn("failed to remove MLflow container", "task_id", taskID, "error", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// HealthCheck verifies plugin readiness (basic for now).
|
||||
func (m *MLflowPlugin) HealthCheck(_ context.Context, _ trackingpkg.ToolConfig) bool {
|
||||
// Future: Perform HTTP checks against tracking URI.
|
||||
return true
|
||||
}
|
||||
170
internal/tracking/plugins/tensorboard.go
Normal file
170
internal/tracking/plugins/tensorboard.go
Normal file
|
|
@ -0,0 +1,170 @@
|
|||
package plugins
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/container"
|
||||
"github.com/jfraeys/fetch_ml/internal/logging"
|
||||
trackingpkg "github.com/jfraeys/fetch_ml/internal/tracking"
|
||||
)
|
||||
|
||||
// TensorBoardOptions configure the TensorBoard plugin.
type TensorBoardOptions struct {
	// Image is the container image; NewTensorBoardPlugin defaults it to
	// tensorflow/tensorflow:2.17.0 when empty.
	Image string
	// LogBasePath is the host directory containing per-job log dirs
	// (required).
	LogBasePath string
	// PortAllocator supplies host ports; defaulted to 5700-5900 when nil.
	PortAllocator *trackingpkg.PortAllocator
}

// tensorboardSidecar records the container and host port backing one
// task's TensorBoard instance.
type tensorboardSidecar struct {
	containerID string
	port        int
}

// TensorBoardPlugin exposes training logs through TensorBoard.
type TensorBoardPlugin struct {
	logger *logging.Logger
	podman *container.PodmanManager
	opts   TensorBoardOptions

	mu       sync.Mutex
	sidecars map[string]*tensorboardSidecar // task ID -> running sidecar
}
|
||||
|
||||
// NewTensorBoardPlugin constructs a TensorBoard plugin instance.
|
||||
func NewTensorBoardPlugin(
|
||||
logger *logging.Logger,
|
||||
podman *container.PodmanManager,
|
||||
opts TensorBoardOptions,
|
||||
) (*TensorBoardPlugin, error) {
|
||||
if podman == nil {
|
||||
return nil, fmt.Errorf("podman manager is required for TensorBoard plugin")
|
||||
}
|
||||
if opts.Image == "" {
|
||||
opts.Image = "tensorflow/tensorflow:2.17.0"
|
||||
}
|
||||
if opts.LogBasePath == "" {
|
||||
return nil, fmt.Errorf("log base path is required for TensorBoard plugin")
|
||||
}
|
||||
if opts.PortAllocator == nil {
|
||||
opts.PortAllocator = trackingpkg.NewPortAllocator(5700, 5900)
|
||||
}
|
||||
|
||||
return &TensorBoardPlugin{
|
||||
logger: logger,
|
||||
podman: podman,
|
||||
opts: opts,
|
||||
sidecars: make(map[string]*tensorboardSidecar),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Name returns the plugin name.
|
||||
func (t *TensorBoardPlugin) Name() string {
|
||||
return "tensorboard"
|
||||
}
|
||||
|
||||
// ProvisionSidecar starts TensorBoard to tail a task's log directory and
// returns TENSORBOARD_URL plus the host log dir as environment variables.
// Remote and disabled modes are no-ops for TensorBoard.
func (t *TensorBoardPlugin) ProvisionSidecar(
	ctx context.Context,
	taskID string,
	config trackingpkg.ToolConfig,
) (map[string]string, error) {
	if config.Mode == trackingpkg.ModeRemote || config.Mode == trackingpkg.ModeDisabled {
		return nil, nil
	}

	// Logs are grouped by job name when provided; fall back to taskID.
	jobName := trackingpkg.StringSetting(config.Settings, "job_name")
	if jobName == "" {
		jobName = taskID
	}

	logDir := filepath.Join(t.opts.LogBasePath, jobName)
	if err := os.MkdirAll(logDir, 0750); err != nil {
		return nil, fmt.Errorf("failed to ensure tensorboard log dir: %w", err)
	}

	// Reserve a host port; released again if the container fails to start.
	port, err := t.opts.PortAllocator.Allocate()
	if err != nil {
		return nil, err
	}

	containerName := fmt.Sprintf("tensorboard-%s", taskID)
	// TensorBoard always listens on 6006 inside the container; the host
	// port is mapped onto it via the Ports config below.
	cmd := []string{
		"tensorboard",
		"--logdir", "/logs",
		"--host", "0.0.0.0",
		"--port", "6006",
	}

	cfg := &container.ContainerConfig{
		Name:    containerName,
		Image:   t.opts.Image,
		Command: cmd,
		// Mounted read-only: TensorBoard only consumes the logs.
		Volumes: map[string]string{
			logDir: "/logs:ro",
		},
		Ports: map[int]int{
			port: 6006,
		},
		Network: container.NetworkConfig{
			AllowNetwork: true,
		},
	}

	containerID, err := t.podman.StartContainer(ctx, cfg)
	if err != nil {
		t.opts.PortAllocator.Release(port)
		return nil, err
	}

	// Record the sidecar so Teardown can stop it and free the port.
	t.mu.Lock()
	t.sidecars[taskID] = &tensorboardSidecar{
		containerID: containerID,
		port:        port,
	}
	t.mu.Unlock()

	return map[string]string{
		"TENSORBOARD_URL":          fmt.Sprintf("http://127.0.0.1:%d", port),
		"TENSORBOARD_HOST_LOG_DIR": logDir,
	}, nil
}
|
||||
|
||||
// Teardown stops the TensorBoard sidecar.
|
||||
func (t *TensorBoardPlugin) Teardown(ctx context.Context, taskID string) error {
|
||||
t.mu.Lock()
|
||||
sidecar, ok := t.sidecars[taskID]
|
||||
if ok {
|
||||
delete(t.sidecars, taskID)
|
||||
}
|
||||
t.mu.Unlock()
|
||||
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
if sidecar.port != 0 {
|
||||
t.opts.PortAllocator.Release(sidecar.port)
|
||||
}
|
||||
|
||||
if sidecar.containerID == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := t.podman.StopContainer(ctx, sidecar.containerID); err != nil && t.logger != nil {
|
||||
t.logger.Warn("failed to stop tensorboard container", "task_id", taskID, "error", err)
|
||||
}
|
||||
if err := t.podman.RemoveContainer(ctx, sidecar.containerID); err != nil && t.logger != nil {
|
||||
t.logger.Warn("failed to remove tensorboard container", "task_id", taskID, "error", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// HealthCheck currently returns true; extend with HTTP checks later.
|
||||
func (t *TensorBoardPlugin) HealthCheck(context.Context, trackingpkg.ToolConfig) bool {
|
||||
return true
|
||||
}
|
||||
69
internal/tracking/plugins/wandb.go
Normal file
69
internal/tracking/plugins/wandb.go
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
package plugins
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
trackingpkg "github.com/jfraeys/fetch_ml/internal/tracking"
|
||||
)
|
||||
|
||||
// WandbPlugin forwards Weights & Biases credentials to the task container;
// it never provisions a local sidecar.
type WandbPlugin struct{}

// NewWandbPlugin constructs a Wandb plugin.
func NewWandbPlugin() *WandbPlugin { return &WandbPlugin{} }

// Name returns the plugin identifier used in configuration.
func (w *WandbPlugin) Name() string { return "wandb" }
|
||||
|
||||
// ProvisionSidecar validates configuration and returns environment variables.
|
||||
func (w *WandbPlugin) ProvisionSidecar(
|
||||
_ context.Context,
|
||||
_ string,
|
||||
config trackingpkg.ToolConfig,
|
||||
) (map[string]string, error) {
|
||||
if config.Mode == trackingpkg.ModeDisabled {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
apiKey := trackingpkg.StringSetting(config.Settings, "api_key")
|
||||
project := trackingpkg.StringSetting(config.Settings, "project")
|
||||
entity := trackingpkg.StringSetting(config.Settings, "entity")
|
||||
|
||||
if config.Mode == trackingpkg.ModeRemote && apiKey == "" {
|
||||
return nil, fmt.Errorf("wandb remote mode requires api_key")
|
||||
}
|
||||
|
||||
env := map[string]string{}
|
||||
if apiKey != "" {
|
||||
env["WANDB_API_KEY"] = apiKey
|
||||
}
|
||||
if project != "" {
|
||||
env["WANDB_PROJECT"] = project
|
||||
}
|
||||
if entity != "" {
|
||||
env["WANDB_ENTITY"] = entity
|
||||
}
|
||||
|
||||
return env, nil
|
||||
}
|
||||
|
||||
// Teardown is a no-op for Wandb.
|
||||
func (w *WandbPlugin) Teardown(context.Context, string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// HealthCheck ensures required params exist.
|
||||
func (w *WandbPlugin) HealthCheck(_ context.Context, config trackingpkg.ToolConfig) bool {
|
||||
if !config.Enabled {
|
||||
return true
|
||||
}
|
||||
if config.Mode == trackingpkg.ModeRemote {
|
||||
return trackingpkg.StringSetting(config.Settings, "api_key") != ""
|
||||
}
|
||||
return true
|
||||
}
|
||||
Loading…
Reference in a new issue