fetch_ml/internal/tracking/plugins/mlflow.go

196 lines
4.6 KiB
Go

package plugins
import (
"context"
"fmt"
"os"
"path/filepath"
"sync"
"github.com/jfraeys/fetch_ml/internal/container"
"github.com/jfraeys/fetch_ml/internal/logging"
trackingpkg "github.com/jfraeys/fetch_ml/internal/tracking"
)
// MLflowOptions configures the MLflow plugin.
//
// NewMLflowPlugin fills in defaults for Image and PortAllocator when they are
// left empty; ArtifactBasePath has no default and must be set.
type MLflowOptions struct {
// Image is the container image used for local MLflow sidecars.
// NewMLflowPlugin substitutes "ghcr.io/mlflow/mlflow:v2.16.1" when empty.
Image string
// ArtifactBasePath is the host directory under which a per-job
// subdirectory is created and mounted into each sidecar at /mlruns.
// Required: NewMLflowPlugin returns an error when it is empty.
ArtifactBasePath string
// DefaultTrackingURI is the MLFLOW_TRACKING_URI used in remote mode when
// the tool config does not provide a "tracking_uri" setting.
DefaultTrackingURI string
// PortAllocator hands out host ports for sidecars. When nil,
// NewMLflowPlugin creates one with trackingpkg.NewPortAllocator(5500, 5700).
PortAllocator *trackingpkg.PortAllocator
}
// mlflowSidecar records a locally started MLflow container so that Teardown
// can stop/remove it and hand its host port back to the allocator.
type mlflowSidecar struct {
// containerID is the ID returned by PodmanManager.StartContainer.
containerID string
// port is the host port taken from the PortAllocator and mapped to the
// container's MLflow port 5000.
port int
}
// MLflowPlugin provisions MLflow tracking servers per task.
//
// In remote mode it only returns a tracking URI; in any other enabled mode it
// starts one MLflow container per task ID and tracks it until Teardown.
// Construct instances with NewMLflowPlugin: the zero value has a nil sidecars
// map and cannot record sidecars.
type MLflowPlugin struct {
// logger may be nil; Teardown skips logging in that case.
logger *logging.Logger
podman *container.PodmanManager
opts MLflowOptions
// mu guards sidecars.
mu sync.Mutex
// sidecars maps a task ID to the sidecar container started for it.
sidecars map[string]*mlflowSidecar
}
// NewMLflowPlugin creates a new MLflow plugin instance.
//
// podman and opts.ArtifactBasePath are mandatory. An empty opts.Image falls
// back to the pinned upstream MLflow image, and a nil opts.PortAllocator is
// replaced by an allocator built with NewPortAllocator(5500, 5700). logger may
// be nil.
func NewMLflowPlugin(
	logger *logging.Logger,
	podman *container.PodmanManager,
	opts MLflowOptions,
) (*MLflowPlugin, error) {
	// Validate required inputs first; podman is checked before the base path.
	switch {
	case podman == nil:
		return nil, fmt.Errorf("podman manager is required for MLflow plugin")
	case opts.ArtifactBasePath == "":
		return nil, fmt.Errorf("artifact base path is required for MLflow plugin")
	}

	// opts is received by value, so filling in defaults never touches the
	// caller's copy.
	if opts.Image == "" {
		opts.Image = "ghcr.io/mlflow/mlflow:v2.16.1"
	}
	if opts.PortAllocator == nil {
		opts.PortAllocator = trackingpkg.NewPortAllocator(5500, 5700)
	}

	plugin := &MLflowPlugin{
		logger:   logger,
		podman:   podman,
		opts:     opts,
		sidecars: map[string]*mlflowSidecar{},
	}
	return plugin, nil
}
// Name returns the plugin's identifier, "mlflow".
func (m *MLflowPlugin) Name() string {
	const pluginName = "mlflow"
	return pluginName
}
// ProvisionSidecar prepares MLflow tracking for a task and returns the
// environment variables the task should run with.
//
//   - ModeRemote: no container is started. MLFLOW_TRACKING_URI is taken from
//     the "tracking_uri" setting, falling back to
//     MLflowOptions.DefaultTrackingURI; it is an error if both are empty.
//   - ModeDisabled: returns a nil map and a nil error.
//   - any other mode: starts a local MLflow sidecar container for taskID.
func (m *MLflowPlugin) ProvisionSidecar(
	ctx context.Context,
	taskID string,
	config trackingpkg.ToolConfig,
) (map[string]string, error) {
	if config.Mode == trackingpkg.ModeRemote {
		trackingURI := trackingpkg.StringSetting(config.Settings, "tracking_uri")
		if trackingURI == "" {
			trackingURI = m.opts.DefaultTrackingURI
		}
		if trackingURI == "" {
			return nil, fmt.Errorf("mlflow remote mode requires tracking_uri")
		}
		return map[string]string{"MLFLOW_TRACKING_URI": trackingURI}, nil
	}

	if config.Mode == trackingpkg.ModeDisabled {
		return nil, nil
	}

	// Every remaining mode is served by a per-task local container.
	return m.provisionSidecar(ctx, taskID, config)
}
// provisionSidecar starts a dedicated MLflow server container for taskID and
// returns an MLFLOW_TRACKING_URI pointing at it.
//
// The server's file store lives in <ArtifactBasePath>/<job_name> on the host
// (job_name defaults to taskID) and is mounted at /mlruns in the container.
// MLflow listens on container port 5000, which is published on a host port
// taken from the plugin's PortAllocator. On success the container ID and host
// port are recorded so Teardown can stop the container and release the port.
//
// An error is returned, and nothing is recorded, when:
//   - a sidecar is already tracked for taskID (starting another one would
//     overwrite the record and orphan the first container and its port);
//   - job_name resolves outside ArtifactBasePath or to ArtifactBasePath itself;
//   - the artifact directory cannot be created, no port can be allocated, or
//     the container fails to start (the port is released in that case).
func (m *MLflowPlugin) provisionSidecar(
	ctx context.Context,
	taskID string,
	config trackingpkg.ToolConfig,
) (map[string]string, error) {
	m.mu.Lock()
	_, exists := m.sidecars[taskID]
	m.mu.Unlock()
	if exists {
		return nil, fmt.Errorf("mlflow sidecar already provisioned for task %q", taskID)
	}

	jobName := trackingpkg.StringSetting(config.Settings, "job_name")
	if jobName == "" {
		jobName = taskID
	}
	taskDir, err := mlflowArtifactDir(m.opts.ArtifactBasePath, jobName)
	if err != nil {
		return nil, err
	}
	if err := os.MkdirAll(taskDir, 0750); err != nil {
		return nil, fmt.Errorf("failed to create MLflow artifact dir: %w", err)
	}

	port, err := m.opts.PortAllocator.Allocate()
	if err != nil {
		return nil, fmt.Errorf("allocating host port for MLflow sidecar of task %q: %w", taskID, err)
	}

	cfg := &container.ContainerConfig{
		Name:  fmt.Sprintf("mlflow-%s", taskID),
		Image: m.opts.Image,
		Command: []string{
			"mlflow", "server",
			"--host=0.0.0.0",
			"--port=5000",
			"--backend-store-uri=file:///mlruns",
			"--default-artifact-root=file:///mlruns",
		},
		Env: map[string]string{
			"MLFLOW_BACKEND_STORE_URI": "file:///mlruns",
			"MLFLOW_ARTIFACT_URI":      "file:///mlruns",
		},
		Volumes: map[string]string{
			taskDir: "/mlruns",
		},
		Ports: map[int]int{
			port: 5000,
		},
		Network: container.NetworkConfig{
			AllowNetwork: true,
		},
	}
	containerID, err := m.podman.StartContainer(ctx, cfg)
	if err != nil {
		m.opts.PortAllocator.Release(port)
		return nil, fmt.Errorf("starting MLflow sidecar for task %q: %w", taskID, err)
	}

	m.mu.Lock()
	m.sidecars[taskID] = &mlflowSidecar{
		containerID: containerID,
		port:        port,
	}
	m.mu.Unlock()

	// NOTE(review): assumes the task can reach the published port on the
	// host loopback interface; confirm for tasks that run in their own
	// network namespace.
	return map[string]string{
		"MLFLOW_TRACKING_URI": fmt.Sprintf("http://127.0.0.1:%d", port),
	}, nil
}

// mlflowArtifactDir joins jobName onto base and checks that the result is a
// strict subdirectory of base. This stops a crafted job_name such as "..",
// "../other", "." or "/" from pointing the sidecar's /mlruns mount at an
// unrelated host directory or at the shared base directory itself.
func mlflowArtifactDir(base, jobName string) (string, error) {
	cleanBase := filepath.Clean(base)
	dir := filepath.Join(cleanBase, jobName)
	rel, err := filepath.Rel(cleanBase, dir)
	// rel is already cleaned. It leaves base exactly when it is ".." or
	// starts with "../", and "." means jobName collapsed onto base itself.
	up := ".." + string(filepath.Separator)
	escapes := rel == ".." || (len(rel) >= len(up) && rel[:len(up)] == up)
	if err != nil || rel == "." || escapes {
		return "", fmt.Errorf("invalid MLflow job name %q: must resolve to a subdirectory of %q", jobName, base)
	}
	return dir, nil
}
// Teardown stops the MLflow sidecar for a task.
//
// The container started for taskID is stopped and removed, and its host port
// is returned to the allocator. Tasks with no tracked sidecar are a no-op:
// remote or disabled modes, or a repeated call. Teardown is best-effort.
// Stop/remove failures are logged when a logger is configured and are never
// returned, so the returned error is always nil.
func (m *MLflowPlugin) Teardown(ctx context.Context, taskID string) error {
	m.mu.Lock()
	sidecar, ok := m.sidecars[taskID]
	delete(m.sidecars, taskID) // no-op when the task is not tracked
	m.mu.Unlock()
	if !ok {
		return nil
	}

	// Release the host port only after the container has been stopped and
	// removed. Releasing it first would let a concurrent provisionSidecar
	// hand the same port to a new container while the old one may still
	// have it published.
	if sidecar.port != 0 {
		defer m.opts.PortAllocator.Release(sidecar.port)
	}
	if sidecar.containerID == "" {
		return nil
	}

	if err := m.podman.StopContainer(ctx, sidecar.containerID); err != nil && m.logger != nil {
		m.logger.Warn("failed to stop MLflow container", "task_id", taskID, "error", err)
	}
	// Removal is attempted even if stopping failed, so a stuck container is
	// still cleaned up where possible.
	if err := m.podman.RemoveContainer(ctx, sidecar.containerID); err != nil && m.logger != nil {
		m.logger.Warn("failed to remove MLflow container", "task_id", taskID, "error", err)
	}
	return nil
}
// HealthCheck verifies plugin readiness (basic for now).
func (m *MLflowPlugin) HealthCheck(_ context.Context, _ trackingpkg.ToolConfig) bool {
// Future: Perform HTTP checks against tracking URI.
return true
}