fetch_ml/internal/tracking/plugins/mlflow_test.go
Jeremie Fraeys f827ee522a
test(tracking/plugins): add PodmanInterface and comprehensive plugin tests for 91% coverage
Refactor plugins to use interface for testability:
- Add PodmanInterface to container package (StartContainer, StopContainer, RemoveContainer)
- Update MLflow plugin to use container.PodmanInterface
- Update TensorBoard plugin to use container.PodmanInterface
- Add comprehensive mocked tests for all three plugins (wandb, mlflow, tensorboard)
- Coverage increased from 18% to 91.4%
2026-03-14 16:59:16 -04:00

370 lines
10 KiB
Go

package plugins_test
import (
"context"
"errors"
"testing"
"github.com/jfraeys/fetch_ml/internal/container"
"github.com/jfraeys/fetch_ml/internal/logging"
"github.com/jfraeys/fetch_ml/internal/tracking"
"github.com/jfraeys/fetch_ml/internal/tracking/plugins"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// mockPodmanManager is a test double for container.PodmanInterface.
//
// Each hook (startFunc, stopFunc, removeFunc) replaces the corresponding
// method when non-nil. Otherwise the method falls back to a simple in-memory
// default: StartContainer records containers in the containers map,
// StopContainer succeeds, and RemoveContainer deletes from the map.
type mockPodmanManager struct {
	// startFunc, when set, replaces the default StartContainer behavior.
	startFunc func(ctx context.Context, cfg *container.ContainerConfig) (string, error)
	// stopFunc, when set, replaces the default (always-succeeding) StopContainer.
	stopFunc func(ctx context.Context, containerID string) error
	// removeFunc, when set, replaces the default RemoveContainer behavior.
	removeFunc func(ctx context.Context, containerID string) error
	// containers maps IDs returned by the default StartContainer to the
	// config each container was started with.
	containers map[string]*container.ContainerConfig
}

// Compile-time check that the mock satisfies the interface the plugins consume.
var _ container.PodmanInterface = (*mockPodmanManager)(nil)
// newMockPodmanManager returns a mock with no hooks installed and an empty,
// ready-to-use container registry.
func newMockPodmanManager() *mockPodmanManager {
	m := new(mockPodmanManager)
	m.containers = map[string]*container.ContainerConfig{}
	return m
}
// StartContainer delegates to startFunc when it is set. Otherwise it builds the
// ID "mock-container-<cfg.Name>", records cfg under that ID, and succeeds.
func (m *mockPodmanManager) StartContainer(ctx context.Context, cfg *container.ContainerConfig) (string, error) {
	if m.startFunc == nil {
		containerID := "mock-container-" + cfg.Name
		m.containers[containerID] = cfg
		return containerID, nil
	}
	return m.startFunc(ctx, cfg)
}
// StopContainer delegates to stopFunc when it is set and otherwise reports
// success without touching any mock state.
func (m *mockPodmanManager) StopContainer(ctx context.Context, containerID string) error {
	if m.stopFunc == nil {
		return nil
	}
	return m.stopFunc(ctx, containerID)
}
// RemoveContainer delegates to removeFunc when it is set. Otherwise it drops
// containerID from the in-memory registry (unknown IDs are ignored) and succeeds.
func (m *mockPodmanManager) RemoveContainer(ctx context.Context, containerID string) error {
	if m.removeFunc == nil {
		delete(m.containers, containerID)
		return nil
	}
	return m.removeFunc(ctx, containerID)
}
// TestNewMLflowPluginNilPodman tests creation with nil podman
func TestNewMLflowPluginNilPodman(t *testing.T) {
t.Parallel()
logger := logging.NewLogger(0, false)
opts := plugins.MLflowOptions{
ArtifactBasePath: "/tmp/mlflow",
}
_, err := plugins.NewMLflowPlugin(logger, nil, opts)
require.Error(t, err)
assert.Contains(t, err.Error(), "podman manager is required")
}
// TestNewMLflowPluginEmptyArtifactPath verifies that the constructor rejects
// zero-value options, which carry no artifact base path.
func TestNewMLflowPluginEmptyArtifactPath(t *testing.T) {
	t.Parallel()

	log := logging.NewLogger(0, false)
	podman := newMockPodmanManager()

	_, err := plugins.NewMLflowPlugin(log, podman, plugins.MLflowOptions{})

	require.Error(t, err)
	assert.Contains(t, err.Error(), "artifact base path is required")
}
// TestNewMLflowPluginDefaults tests default values
func TestNewMLflowPluginDefaults(t *testing.T) {
t.Parallel()
logger := logging.NewLogger(0, false)
mockPodman := newMockPodmanManager()
opts := plugins.MLflowOptions{
ArtifactBasePath: "/tmp/mlflow",
}
plugin, err := plugins.NewMLflowPlugin(logger, mockPodman, opts)
require.NoError(t, err)
require.NotNil(t, plugin)
}
// TestMLflowPluginName tests plugin name
func TestMLflowPluginName(t *testing.T) {
t.Parallel()
logger := logging.NewLogger(0, false)
mockPodman := newMockPodmanManager()
opts := plugins.MLflowOptions{
ArtifactBasePath: "/tmp/mlflow",
}
plugin, err := plugins.NewMLflowPlugin(logger, mockPodman, opts)
require.NoError(t, err)
assert.Equal(t, "mlflow", plugin.Name())
}
// TestMLflowPluginProvisionSidecarDisabled tests disabled mode
func TestMLflowPluginProvisionSidecarDisabled(t *testing.T) {
t.Parallel()
logger := logging.NewLogger(0, false)
mockPodman := newMockPodmanManager()
opts := plugins.MLflowOptions{
ArtifactBasePath: "/tmp/mlflow",
}
plugin, err := plugins.NewMLflowPlugin(logger, mockPodman, opts)
require.NoError(t, err)
config := tracking.ToolConfig{
Enabled: false,
Mode: tracking.ModeDisabled,
}
env, err := plugin.ProvisionSidecar(context.Background(), "task-1", config)
require.NoError(t, err)
assert.Nil(t, env)
}
// TestMLflowPluginProvisionSidecarRemoteNoURI tests remote mode without URI
func TestMLflowPluginProvisionSidecarRemoteNoURI(t *testing.T) {
t.Parallel()
logger := logging.NewLogger(0, false)
mockPodman := newMockPodmanManager()
opts := plugins.MLflowOptions{
ArtifactBasePath: "/tmp/mlflow",
}
plugin, err := plugins.NewMLflowPlugin(logger, mockPodman, opts)
require.NoError(t, err)
config := tracking.ToolConfig{
Enabled: true,
Mode: tracking.ModeRemote,
Settings: map[string]any{},
}
_, err = plugin.ProvisionSidecar(context.Background(), "task-1", config)
require.Error(t, err)
assert.Contains(t, err.Error(), "tracking_uri")
}
// TestMLflowPluginProvisionSidecarRemoteWithURI tests remote mode with URI
func TestMLflowPluginProvisionSidecarRemoteWithURI(t *testing.T) {
t.Parallel()
logger := logging.NewLogger(0, false)
mockPodman := newMockPodmanManager()
opts := plugins.MLflowOptions{
ArtifactBasePath: "/tmp/mlflow",
DefaultTrackingURI: "http://default:5000",
}
plugin, err := plugins.NewMLflowPlugin(logger, mockPodman, opts)
require.NoError(t, err)
config := tracking.ToolConfig{
Enabled: true,
Mode: tracking.ModeRemote,
Settings: map[string]any{
"tracking_uri": "http://custom:5000",
},
}
env, err := plugin.ProvisionSidecar(context.Background(), "task-1", config)
require.NoError(t, err)
require.NotNil(t, env)
assert.Equal(t, "http://custom:5000", env["MLFLOW_TRACKING_URI"])
}
// TestMLflowPluginProvisionSidecarRemoteWithDefaultURI tests remote mode with default URI
func TestMLflowPluginProvisionSidecarRemoteWithDefaultURI(t *testing.T) {
t.Parallel()
logger := logging.NewLogger(0, false)
mockPodman := newMockPodmanManager()
opts := plugins.MLflowOptions{
ArtifactBasePath: "/tmp/mlflow",
DefaultTrackingURI: "http://default:5000",
}
plugin, err := plugins.NewMLflowPlugin(logger, mockPodman, opts)
require.NoError(t, err)
config := tracking.ToolConfig{
Enabled: true,
Mode: tracking.ModeRemote,
Settings: map[string]any{},
}
env, err := plugin.ProvisionSidecar(context.Background(), "task-1", config)
require.NoError(t, err)
require.NotNil(t, env)
assert.Equal(t, "http://default:5000", env["MLFLOW_TRACKING_URI"])
}
// TestMLflowPluginProvisionSidecarSidecarMode verifies that sidecar mode starts a
// container through the podman manager and returns an environment exposing a
// non-empty MLFLOW_TRACKING_URI.
func TestMLflowPluginProvisionSidecarSidecarMode(t *testing.T) {
	t.Parallel()
	logger := logging.NewLogger(0, false)
	mockPodman := newMockPodmanManager()
	allocator := tracking.NewPortAllocator(5500, 5700)
	opts := plugins.MLflowOptions{
		ArtifactBasePath: "/tmp/mlflow",
		PortAllocator:    allocator,
	}
	plugin, err := plugins.NewMLflowPlugin(logger, mockPodman, opts)
	require.NoError(t, err)
	config := tracking.ToolConfig{
		Enabled: true,
		Mode:    tracking.ModeSidecar,
		Settings: map[string]any{
			"job_name": "test-job",
		},
	}
	env, err := plugin.ProvisionSidecar(context.Background(), "task-1", config)
	require.NoError(t, err)
	require.NotNil(t, env)
	assert.Contains(t, env, "MLFLOW_TRACKING_URI")
	assert.NotEmpty(t, env["MLFLOW_TRACKING_URI"])
	// The default mock StartContainer records every container it starts, so an
	// empty registry means the sidecar was never launched.
	assert.NotEmpty(t, mockPodman.containers, "expected the sidecar container to be started")
}
// TestMLflowPluginProvisionSidecarStartFailure verifies that an error returned by
// the podman manager while starting the sidecar is surfaced to the caller.
func TestMLflowPluginProvisionSidecarStartFailure(t *testing.T) {
	t.Parallel()

	log := logging.NewLogger(0, false)
	podman := newMockPodmanManager()
	podman.startFunc = func(context.Context, *container.ContainerConfig) (string, error) {
		return "", errors.New("failed to start container")
	}

	p, err := plugins.NewMLflowPlugin(log, podman, plugins.MLflowOptions{
		ArtifactBasePath: "/tmp/mlflow",
		PortAllocator:    tracking.NewPortAllocator(5500, 5700),
	})
	require.NoError(t, err)

	sidecar := tracking.ToolConfig{
		Enabled:  true,
		Mode:     tracking.ModeSidecar,
		Settings: map[string]any{"job_name": "test-job"},
	}

	_, err = p.ProvisionSidecar(context.Background(), "task-1", sidecar)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "failed to start")
}
// TestMLflowPluginTeardownNonexistent tests teardown for nonexistent task
func TestMLflowPluginTeardownNonexistent(t *testing.T) {
t.Parallel()
logger := logging.NewLogger(0, false)
mockPodman := newMockPodmanManager()
opts := plugins.MLflowOptions{
ArtifactBasePath: "/tmp/mlflow",
}
plugin, err := plugins.NewMLflowPlugin(logger, mockPodman, opts)
require.NoError(t, err)
err = plugin.Teardown(context.Background(), "nonexistent-task")
require.NoError(t, err)
}
// TestMLflowPluginTeardownWithSidecar provisions a sidecar and verifies that
// tearing it down succeeds and asks the podman manager to stop and/or remove a
// container. A no-op Teardown would fail this test.
func TestMLflowPluginTeardownWithSidecar(t *testing.T) {
	t.Parallel()
	logger := logging.NewLogger(0, false)
	mockPodman := newMockPodmanManager()

	// Record every container ID the plugin hands to podman for cleanup. The
	// remove hook keeps the mock's default bookkeeping (deleting from the map).
	var cleanedUp []string
	mockPodman.stopFunc = func(_ context.Context, containerID string) error {
		cleanedUp = append(cleanedUp, containerID)
		return nil
	}
	mockPodman.removeFunc = func(_ context.Context, containerID string) error {
		cleanedUp = append(cleanedUp, containerID)
		delete(mockPodman.containers, containerID)
		return nil
	}

	allocator := tracking.NewPortAllocator(5500, 5700)
	opts := plugins.MLflowOptions{
		ArtifactBasePath: "/tmp/mlflow",
		PortAllocator:    allocator,
	}
	plugin, err := plugins.NewMLflowPlugin(logger, mockPodman, opts)
	require.NoError(t, err)

	// Create a sidecar first.
	config := tracking.ToolConfig{
		Enabled: true,
		Mode:    tracking.ModeSidecar,
		Settings: map[string]any{
			"job_name": "test-job",
		},
	}
	_, err = plugin.ProvisionSidecar(context.Background(), "task-1", config)
	require.NoError(t, err)
	require.NotEmpty(t, mockPodman.containers, "expected a sidecar container to be started")

	// Now tear it down.
	err = plugin.Teardown(context.Background(), "task-1")
	require.NoError(t, err)
	assert.NotEmpty(t, cleanedUp, "teardown should stop or remove the sidecar container")
}
// TestMLflowPluginHealthCheck tests health check
func TestMLflowPluginHealthCheck(t *testing.T) {
t.Parallel()
logger := logging.NewLogger(0, false)
mockPodman := newMockPodmanManager()
opts := plugins.MLflowOptions{
ArtifactBasePath: "/tmp/mlflow",
}
plugin, err := plugins.NewMLflowPlugin(logger, mockPodman, opts)
require.NoError(t, err)
// Health check always returns true for now
healthy := plugin.HealthCheck(context.Background(), tracking.ToolConfig{})
assert.True(t, healthy)
}
// TestMLflowPluginCustomImage verifies that MLflowOptions.Image is used as the
// image of the sidecar container handed to the podman manager.
func TestMLflowPluginCustomImage(t *testing.T) {
	t.Parallel()
	logger := logging.NewLogger(0, false)
	mockPodman := newMockPodmanManager()

	// Capture the image the plugin asks podman to run.
	var startedImage string
	mockPodman.startFunc = func(_ context.Context, cfg *container.ContainerConfig) (string, error) {
		startedImage = cfg.Image
		return "mock-container-" + cfg.Name, nil
	}

	opts := plugins.MLflowOptions{
		ArtifactBasePath: "/tmp/mlflow",
		Image:            "custom/mlflow:latest",
		PortAllocator:    tracking.NewPortAllocator(5500, 5700),
	}
	plugin, err := plugins.NewMLflowPlugin(logger, mockPodman, opts)
	require.NoError(t, err)
	require.NotNil(t, plugin)

	// The image only becomes observable once a sidecar is actually started.
	config := tracking.ToolConfig{
		Enabled:  true,
		Mode:     tracking.ModeSidecar,
		Settings: map[string]any{"job_name": "test-job"},
	}
	_, err = plugin.ProvisionSidecar(context.Background(), "task-1", config)
	require.NoError(t, err)
	assert.Equal(t, "custom/mlflow:latest", startedImage)
}
// TestMLflowPluginDefaultImage verifies that, when MLflowOptions.Image is left
// empty, the plugin still starts the sidecar with a non-empty default image.
func TestMLflowPluginDefaultImage(t *testing.T) {
	t.Parallel()
	logger := logging.NewLogger(0, false)
	mockPodman := newMockPodmanManager()

	// Capture the image the plugin asks podman to run.
	var startedImage string
	mockPodman.startFunc = func(_ context.Context, cfg *container.ContainerConfig) (string, error) {
		startedImage = cfg.Image
		return "mock-container-" + cfg.Name, nil
	}

	opts := plugins.MLflowOptions{
		ArtifactBasePath: "/tmp/mlflow",
		PortAllocator:    tracking.NewPortAllocator(5500, 5700),
		// Image not specified. The original author noted it should default to
		// ghcr.io/mlflow/mlflow:v2.16.1; only non-emptiness is asserted here.
	}
	plugin, err := plugins.NewMLflowPlugin(logger, mockPodman, opts)
	require.NoError(t, err)
	require.NotNil(t, plugin)

	config := tracking.ToolConfig{
		Enabled:  true,
		Mode:     tracking.ModeSidecar,
		Settings: map[string]any{"job_name": "test-job"},
	}
	_, err = plugin.ProvisionSidecar(context.Background(), "task-1", config)
	require.NoError(t, err)
	assert.NotEmpty(t, startedImage, "expected a default image to be applied")
}