package plugins import ( "context" "fmt" "os" "path/filepath" "sync" "github.com/jfraeys/fetch_ml/internal/container" "github.com/jfraeys/fetch_ml/internal/logging" trackingpkg "github.com/jfraeys/fetch_ml/internal/tracking" ) // MLflowOptions configures the MLflow plugin. type MLflowOptions struct { Image string ArtifactBasePath string DefaultTrackingURI string PortAllocator *trackingpkg.PortAllocator } type mlflowSidecar struct { containerID string port int } // MLflowPlugin provisions MLflow tracking servers per task. type MLflowPlugin struct { logger *logging.Logger podman *container.PodmanManager opts MLflowOptions mu sync.Mutex sidecars map[string]*mlflowSidecar } // NewMLflowPlugin creates a new MLflow plugin instance. func NewMLflowPlugin( logger *logging.Logger, podman *container.PodmanManager, opts MLflowOptions, ) (*MLflowPlugin, error) { if podman == nil { return nil, fmt.Errorf("podman manager is required for MLflow plugin") } if opts.Image == "" { opts.Image = "ghcr.io/mlflow/mlflow:v2.16.1" } if opts.ArtifactBasePath == "" { return nil, fmt.Errorf("artifact base path is required for MLflow plugin") } if opts.PortAllocator == nil { opts.PortAllocator = trackingpkg.NewPortAllocator(5500, 5700) } return &MLflowPlugin{ logger: logger, podman: podman, opts: opts, sidecars: make(map[string]*mlflowSidecar), }, nil } // Name returns the plugin name. func (m *MLflowPlugin) Name() string { return "mlflow" } // ProvisionSidecar starts an MLflow sidecar or returns remote env vars. func (m *MLflowPlugin) ProvisionSidecar( ctx context.Context, taskID string, config trackingpkg.ToolConfig, ) (map[string]string, error) { switch config.Mode { case trackingpkg.ModeRemote: uri := trackingpkg.StringSetting(config.Settings, "tracking_uri") if uri == "" { uri = m.opts.DefaultTrackingURI } if uri == "" { return nil, fmt.Errorf("mlflow remote mode requires tracking_uri") } return map[string]string{ "MLFLOW_TRACKING_URI": uri, }, nil case trackingpkg.ModeDisabled: return nil, nil default: return m.provisionSidecar(ctx, taskID, config) } } func (m *MLflowPlugin) provisionSidecar( ctx context.Context, taskID string, config trackingpkg.ToolConfig, ) (map[string]string, error) { jobName := trackingpkg.StringSetting(config.Settings, "job_name") if jobName == "" { jobName = taskID } taskDir := filepath.Join(m.opts.ArtifactBasePath, jobName) if err := os.MkdirAll(taskDir, 0750); err != nil { return nil, fmt.Errorf("failed to create MLflow artifact dir: %w", err) } port, err := m.opts.PortAllocator.Allocate() if err != nil { return nil, err } containerName := fmt.Sprintf("mlflow-%s", taskID) cmd := []string{ "mlflow", "server", "--host=0.0.0.0", "--port=5000", "--backend-store-uri=file:///mlruns", "--default-artifact-root=file:///mlruns", } cfg := &container.ContainerConfig{ Name: containerName, Image: m.opts.Image, Command: cmd, Env: map[string]string{ "MLFLOW_BACKEND_STORE_URI": "file:///mlruns", "MLFLOW_ARTIFACT_URI": "file:///mlruns", }, Volumes: map[string]string{ taskDir: "/mlruns", }, Ports: map[int]int{ port: 5000, }, Network: container.NetworkConfig{ AllowNetwork: true, }, } containerID, err := m.podman.StartContainer(ctx, cfg) if err != nil { m.opts.PortAllocator.Release(port) return nil, err } m.mu.Lock() m.sidecars[taskID] = &mlflowSidecar{ containerID: containerID, port: port, } m.mu.Unlock() return map[string]string{ "MLFLOW_TRACKING_URI": fmt.Sprintf("http://127.0.0.1:%d", port), }, nil } // Teardown stops the MLflow sidecar for a task. func (m *MLflowPlugin) Teardown(ctx context.Context, taskID string) error { m.mu.Lock() sidecar, ok := m.sidecars[taskID] if ok { delete(m.sidecars, taskID) } m.mu.Unlock() if !ok { return nil } if sidecar.port != 0 { m.opts.PortAllocator.Release(sidecar.port) } if sidecar.containerID == "" { return nil } if err := m.podman.StopContainer(ctx, sidecar.containerID); err != nil && m.logger != nil { m.logger.Warn("failed to stop MLflow container", "task_id", taskID, "error", err) } if err := m.podman.RemoveContainer(ctx, sidecar.containerID); err != nil && m.logger != nil { m.logger.Warn("failed to remove MLflow container", "task_id", taskID, "error", err) } return nil } // HealthCheck verifies plugin readiness (basic for now). func (m *MLflowPlugin) HealthCheck(_ context.Context, _ trackingpkg.ToolConfig) bool { // Future: Perform HTTP checks against tracking URI. return true }