196 lines
4.6 KiB
Go
196 lines
4.6 KiB
Go
package plugins
|
|
|
|
import (
	"context"
	"fmt"
	"os"
	"path/filepath"
	"strings"
	"sync"

	"github.com/jfraeys/fetch_ml/internal/container"
	"github.com/jfraeys/fetch_ml/internal/logging"
	trackingpkg "github.com/jfraeys/fetch_ml/internal/tracking"
)
|
|
|
|
// MLflowOptions configures the MLflow plugin.
|
|
type MLflowOptions struct {
|
|
Image string
|
|
ArtifactBasePath string
|
|
DefaultTrackingURI string
|
|
PortAllocator *trackingpkg.PortAllocator
|
|
}
|
|
|
|
// mlflowSidecar records the MLflow server container started for one task.
type mlflowSidecar struct {
	containerID string // ID returned by PodmanManager.StartContainer
	port        int    // host port mapped to the server's container port 5000
}
|
|
|
|
// MLflowPlugin provisions MLflow tracking servers per task.
|
|
type MLflowPlugin struct {
|
|
logger *logging.Logger
|
|
podman *container.PodmanManager
|
|
opts MLflowOptions
|
|
|
|
mu sync.Mutex
|
|
sidecars map[string]*mlflowSidecar
|
|
}
|
|
|
|
// NewMLflowPlugin creates a new MLflow plugin instance.
|
|
func NewMLflowPlugin(
|
|
logger *logging.Logger,
|
|
podman *container.PodmanManager,
|
|
opts MLflowOptions,
|
|
) (*MLflowPlugin, error) {
|
|
if podman == nil {
|
|
return nil, fmt.Errorf("podman manager is required for MLflow plugin")
|
|
}
|
|
if opts.Image == "" {
|
|
opts.Image = "ghcr.io/mlflow/mlflow:v2.16.1"
|
|
}
|
|
if opts.ArtifactBasePath == "" {
|
|
return nil, fmt.Errorf("artifact base path is required for MLflow plugin")
|
|
}
|
|
if opts.PortAllocator == nil {
|
|
opts.PortAllocator = trackingpkg.NewPortAllocator(5500, 5700)
|
|
}
|
|
|
|
return &MLflowPlugin{
|
|
logger: logger,
|
|
podman: podman,
|
|
opts: opts,
|
|
sidecars: make(map[string]*mlflowSidecar),
|
|
}, nil
|
|
}
|
|
|
|
// Name returns the plugin name.
|
|
func (m *MLflowPlugin) Name() string {
|
|
return "mlflow"
|
|
}
|
|
|
|
// ProvisionSidecar starts an MLflow sidecar or returns remote env vars.
|
|
func (m *MLflowPlugin) ProvisionSidecar(
|
|
ctx context.Context,
|
|
taskID string,
|
|
config trackingpkg.ToolConfig,
|
|
) (map[string]string, error) {
|
|
switch config.Mode {
|
|
case trackingpkg.ModeRemote:
|
|
uri := trackingpkg.StringSetting(config.Settings, "tracking_uri")
|
|
if uri == "" {
|
|
uri = m.opts.DefaultTrackingURI
|
|
}
|
|
if uri == "" {
|
|
return nil, fmt.Errorf("mlflow remote mode requires tracking_uri")
|
|
}
|
|
return map[string]string{
|
|
"MLFLOW_TRACKING_URI": uri,
|
|
}, nil
|
|
case trackingpkg.ModeDisabled:
|
|
return nil, nil
|
|
default:
|
|
return m.provisionSidecar(ctx, taskID, config)
|
|
}
|
|
}
|
|
|
|
func (m *MLflowPlugin) provisionSidecar(
|
|
ctx context.Context,
|
|
taskID string,
|
|
config trackingpkg.ToolConfig,
|
|
) (map[string]string, error) {
|
|
jobName := trackingpkg.StringSetting(config.Settings, "job_name")
|
|
if jobName == "" {
|
|
jobName = taskID
|
|
}
|
|
|
|
taskDir := filepath.Join(m.opts.ArtifactBasePath, jobName)
|
|
if err := os.MkdirAll(taskDir, 0750); err != nil {
|
|
return nil, fmt.Errorf("failed to create MLflow artifact dir: %w", err)
|
|
}
|
|
|
|
port, err := m.opts.PortAllocator.Allocate()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
containerName := fmt.Sprintf("mlflow-%s", taskID)
|
|
cmd := []string{
|
|
"mlflow", "server",
|
|
"--host=0.0.0.0",
|
|
"--port=5000",
|
|
"--backend-store-uri=file:///mlruns",
|
|
"--default-artifact-root=file:///mlruns",
|
|
}
|
|
|
|
cfg := &container.ContainerConfig{
|
|
Name: containerName,
|
|
Image: m.opts.Image,
|
|
Command: cmd,
|
|
Env: map[string]string{
|
|
"MLFLOW_BACKEND_STORE_URI": "file:///mlruns",
|
|
"MLFLOW_ARTIFACT_URI": "file:///mlruns",
|
|
},
|
|
Volumes: map[string]string{
|
|
taskDir: "/mlruns",
|
|
},
|
|
Ports: map[int]int{
|
|
port: 5000,
|
|
},
|
|
Network: container.NetworkConfig{
|
|
AllowNetwork: true,
|
|
},
|
|
}
|
|
|
|
containerID, err := m.podman.StartContainer(ctx, cfg)
|
|
if err != nil {
|
|
m.opts.PortAllocator.Release(port)
|
|
return nil, err
|
|
}
|
|
|
|
m.mu.Lock()
|
|
m.sidecars[taskID] = &mlflowSidecar{
|
|
containerID: containerID,
|
|
port: port,
|
|
}
|
|
m.mu.Unlock()
|
|
|
|
return map[string]string{
|
|
"MLFLOW_TRACKING_URI": fmt.Sprintf("http://127.0.0.1:%d", port),
|
|
}, nil
|
|
}
|
|
|
|
// Teardown stops the MLflow sidecar for a task.
|
|
func (m *MLflowPlugin) Teardown(ctx context.Context, taskID string) error {
|
|
m.mu.Lock()
|
|
sidecar, ok := m.sidecars[taskID]
|
|
if ok {
|
|
delete(m.sidecars, taskID)
|
|
}
|
|
m.mu.Unlock()
|
|
|
|
if !ok {
|
|
return nil
|
|
}
|
|
|
|
if sidecar.port != 0 {
|
|
m.opts.PortAllocator.Release(sidecar.port)
|
|
}
|
|
|
|
if sidecar.containerID == "" {
|
|
return nil
|
|
}
|
|
|
|
if err := m.podman.StopContainer(ctx, sidecar.containerID); err != nil && m.logger != nil {
|
|
m.logger.Warn("failed to stop MLflow container", "task_id", taskID, "error", err)
|
|
}
|
|
if err := m.podman.RemoveContainer(ctx, sidecar.containerID); err != nil && m.logger != nil {
|
|
m.logger.Warn("failed to remove MLflow container", "task_id", taskID, "error", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// HealthCheck verifies plugin readiness (basic for now).
|
|
func (m *MLflowPlugin) HealthCheck(_ context.Context, _ trackingpkg.ToolConfig) bool {
|
|
// Future: Perform HTTP checks against tracking URI.
|
|
return true
|
|
}
|