package plugins_test

import (
	"context"
	"errors"
	"testing"

	"github.com/jfraeys/fetch_ml/internal/container"
	"github.com/jfraeys/fetch_ml/internal/logging"
	"github.com/jfraeys/fetch_ml/internal/tracking"
	"github.com/jfraeys/fetch_ml/internal/tracking/plugins"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// mockPodmanManager implements container.PodmanInterface for testing.
//
// By default it hands out deterministic container IDs, remembers the config
// each container was started with, and records every StopContainer and
// RemoveContainer call. Each behavior can be overridden via the matching
// *Func hook. It is not synchronized; every test builds its own instance.
type mockPodmanManager struct {
	startFunc  func(ctx context.Context, cfg *container.ContainerConfig) (string, error)
	stopFunc   func(ctx context.Context, containerID string) error
	removeFunc func(ctx context.Context, containerID string) error

	// containers maps IDs returned by the default StartContainer to the
	// config they were started with.
	containers map[string]*container.ContainerConfig
	// stopped and removed hold, in call order, the IDs passed to
	// StopContainer and RemoveContainer (recorded even when a hook is set).
	stopped []string
	removed []string
}

func newMockPodmanManager() *mockPodmanManager {
	return &mockPodmanManager{
		containers: make(map[string]*container.ContainerConfig),
	}
}

// StartContainer delegates to startFunc when set; otherwise it returns
// "mock-container-"+cfg.Name and records cfg under that ID.
func (m *mockPodmanManager) StartContainer(ctx context.Context, cfg *container.ContainerConfig) (string, error) {
	if m.startFunc != nil {
		return m.startFunc(ctx, cfg)
	}
	id := "mock-container-" + cfg.Name
	m.containers[id] = cfg
	return id, nil
}

// StopContainer records containerID, then delegates to stopFunc when set.
func (m *mockPodmanManager) StopContainer(ctx context.Context, containerID string) error {
	m.stopped = append(m.stopped, containerID)
	if m.stopFunc != nil {
		return m.stopFunc(ctx, containerID)
	}
	return nil
}

// RemoveContainer records containerID, then delegates to removeFunc when
// set; by default it forgets the container.
func (m *mockPodmanManager) RemoveContainer(ctx context.Context, containerID string) error {
	m.removed = append(m.removed, containerID)
	if m.removeFunc != nil {
		return m.removeFunc(ctx, containerID)
	}
	delete(m.containers, containerID)
	return nil
}

// baseOptions returns MLflowOptions whose ArtifactBasePath is a fresh
// per-test temporary directory that the testing package removes when t
// finishes. Parallel tests therefore never share on-disk state, and nothing
// is left behind at a fixed host path.
func baseOptions(t *testing.T) plugins.MLflowOptions {
	t.Helper()
	return plugins.MLflowOptions{ArtifactBasePath: t.TempDir()}
}

// sidecarToolConfig returns the enabled sidecar-mode ToolConfig shared by the
// container-provisioning tests. A fresh Settings map is built on every call.
func sidecarToolConfig() tracking.ToolConfig {
	return tracking.ToolConfig{
		Enabled: true,
		Mode:    tracking.ModeSidecar,
		Settings: map[string]any{
			"job_name": "test-job",
		},
	}
}

// TestNewMLflowPluginNilPodman tests creation with nil podman
func TestNewMLflowPluginNilPodman(t *testing.T) {
	t.Parallel()
	logger := logging.NewLogger(0, false)

	_, err := plugins.NewMLflowPlugin(logger, nil, baseOptions(t))
	require.Error(t, err)
	assert.Contains(t, err.Error(), "podman manager is required")
}

// TestNewMLflowPluginEmptyArtifactPath tests creation with empty artifact path
func TestNewMLflowPluginEmptyArtifactPath(t *testing.T) {
	t.Parallel()
	logger := logging.NewLogger(0, false)
	mockPodman := newMockPodmanManager()

	_, err := plugins.NewMLflowPlugin(logger, mockPodman, plugins.MLflowOptions{})
	require.Error(t, err)
	assert.Contains(t, err.Error(), "artifact base path is required")
}

// TestNewMLflowPluginDefaults tests default values
func TestNewMLflowPluginDefaults(t *testing.T) {
	t.Parallel()
	logger := logging.NewLogger(0, false)
	mockPodman := newMockPodmanManager()

	plugin, err := plugins.NewMLflowPlugin(logger, mockPodman, baseOptions(t))
	require.NoError(t, err)
	require.NotNil(t, plugin)
}

// TestMLflowPluginName tests plugin name
func TestMLflowPluginName(t *testing.T) {
	t.Parallel()
	logger := logging.NewLogger(0, false)
	mockPodman := newMockPodmanManager()

	plugin, err := plugins.NewMLflowPlugin(logger, mockPodman, baseOptions(t))
	require.NoError(t, err)
	assert.Equal(t, "mlflow", plugin.Name())
}

// TestMLflowPluginProvisionSidecarDisabled tests disabled mode
func TestMLflowPluginProvisionSidecarDisabled(t *testing.T) {
	t.Parallel()
	logger := logging.NewLogger(0, false)
	mockPodman := newMockPodmanManager()

	plugin, err := plugins.NewMLflowPlugin(logger, mockPodman, baseOptions(t))
	require.NoError(t, err)

	config := tracking.ToolConfig{
		Enabled: false,
		Mode:    tracking.ModeDisabled,
	}

	env, err := plugin.ProvisionSidecar(context.Background(), "task-1", config)
	require.NoError(t, err)
	assert.Nil(t, env)
	assert.Empty(t, mockPodman.containers, "disabled mode must not start a container")
}

// TestMLflowPluginProvisionSidecarRemoteNoURI tests remote mode without URI
func TestMLflowPluginProvisionSidecarRemoteNoURI(t *testing.T) {
	t.Parallel()
	logger := logging.NewLogger(0, false)
	mockPodman := newMockPodmanManager()

	plugin, err := plugins.NewMLflowPlugin(logger, mockPodman, baseOptions(t))
	require.NoError(t, err)

	config := tracking.ToolConfig{
		Enabled:  true,
		Mode:     tracking.ModeRemote,
		Settings: map[string]any{},
	}

	_, err = plugin.ProvisionSidecar(context.Background(), "task-1", config)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "tracking_uri")
	assert.Empty(t, mockPodman.containers, "remote mode must not start a container")
}

// TestMLflowPluginProvisionSidecarRemoteWithURI tests remote mode with URI
func TestMLflowPluginProvisionSidecarRemoteWithURI(t *testing.T) {
	t.Parallel()
	logger := logging.NewLogger(0, false)
	mockPodman := newMockPodmanManager()
	opts := baseOptions(t)
	opts.DefaultTrackingURI = "http://default:5000"

	plugin, err := plugins.NewMLflowPlugin(logger, mockPodman, opts)
	require.NoError(t, err)

	config := tracking.ToolConfig{
		Enabled: true,
		Mode:    tracking.ModeRemote,
		Settings: map[string]any{
			"tracking_uri": "http://custom:5000",
		},
	}

	env, err := plugin.ProvisionSidecar(context.Background(), "task-1", config)
	require.NoError(t, err)
	require.NotNil(t, env)
	// The per-tool setting must win over the plugin-wide default.
	assert.Equal(t, "http://custom:5000", env["MLFLOW_TRACKING_URI"])
	assert.Empty(t, mockPodman.containers, "remote mode must not start a container")
}

// TestMLflowPluginProvisionSidecarRemoteWithDefaultURI tests remote mode with default URI
func TestMLflowPluginProvisionSidecarRemoteWithDefaultURI(t *testing.T) {
	t.Parallel()
	logger := logging.NewLogger(0, false)
	mockPodman := newMockPodmanManager()
	opts := baseOptions(t)
	opts.DefaultTrackingURI = "http://default:5000"

	plugin, err := plugins.NewMLflowPlugin(logger, mockPodman, opts)
	require.NoError(t, err)

	config := tracking.ToolConfig{
		Enabled:  true,
		Mode:     tracking.ModeRemote,
		Settings: map[string]any{},
	}

	env, err := plugin.ProvisionSidecar(context.Background(), "task-1", config)
	require.NoError(t, err)
	require.NotNil(t, env)
	assert.Equal(t, "http://default:5000", env["MLFLOW_TRACKING_URI"])
	assert.Empty(t, mockPodman.containers, "remote mode must not start a container")
}

// TestMLflowPluginProvisionSidecarSidecarMode tests sidecar mode (container creation)
func TestMLflowPluginProvisionSidecarSidecarMode(t *testing.T) {
	t.Parallel()
	logger := logging.NewLogger(0, false)
	mockPodman := newMockPodmanManager()
	opts := baseOptions(t)
	opts.PortAllocator = tracking.NewPortAllocator(5500, 5700)

	plugin, err := plugins.NewMLflowPlugin(logger, mockPodman, opts)
	require.NoError(t, err)

	env, err := plugin.ProvisionSidecar(context.Background(), "task-1", sidecarToolConfig())
	require.NoError(t, err)
	require.NotNil(t, env)
	assert.Contains(t, env, "MLFLOW_TRACKING_URI")
	assert.NotEmpty(t, mockPodman.containers, "sidecar mode should start a container")
}

// TestMLflowPluginProvisionSidecarStartFailure tests container start failure
func TestMLflowPluginProvisionSidecarStartFailure(t *testing.T) {
	t.Parallel()
	logger := logging.NewLogger(0, false)
	mockPodman := newMockPodmanManager()
	mockPodman.startFunc = func(context.Context, *container.ContainerConfig) (string, error) {
		return "", errors.New("failed to start container")
	}
	opts := baseOptions(t)
	opts.PortAllocator = tracking.NewPortAllocator(5500, 5700)

	plugin, err := plugins.NewMLflowPlugin(logger, mockPodman, opts)
	require.NoError(t, err)

	_, err = plugin.ProvisionSidecar(context.Background(), "task-1", sidecarToolConfig())
	require.Error(t, err)
	assert.Contains(t, err.Error(), "failed to start")
}

// TestMLflowPluginTeardownNonexistent tests teardown for nonexistent task
func TestMLflowPluginTeardownNonexistent(t *testing.T) {
	t.Parallel()
	logger := logging.NewLogger(0, false)
	mockPodman := newMockPodmanManager()

	plugin, err := plugins.NewMLflowPlugin(logger, mockPodman, baseOptions(t))
	require.NoError(t, err)

	err = plugin.Teardown(context.Background(), "nonexistent-task")
	require.NoError(t, err)
	// Nothing was provisioned, so no container may be touched.
	assert.Empty(t, mockPodman.stopped)
	assert.Empty(t, mockPodman.removed)
}

// TestMLflowPluginTeardownWithSidecar tests teardown with running sidecar
func TestMLflowPluginTeardownWithSidecar(t *testing.T) {
	t.Parallel()
	logger := logging.NewLogger(0, false)
	mockPodman := newMockPodmanManager()
	opts := baseOptions(t)
	opts.PortAllocator = tracking.NewPortAllocator(5500, 5700)

	plugin, err := plugins.NewMLflowPlugin(logger, mockPodman, opts)
	require.NoError(t, err)

	// Create a sidecar first
	_, err = plugin.ProvisionSidecar(context.Background(), "task-1", sidecarToolConfig())
	require.NoError(t, err)
	require.NotEmpty(t, mockPodman.containers, "container should have been created")

	// Now teardown
	err = plugin.Teardown(context.Background(), "task-1")
	require.NoError(t, err)

	// Without this check a no-op Teardown (leaking the container) would pass.
	touched := append(append([]string(nil), mockPodman.stopped...), mockPodman.removed...)
	assert.NotEmpty(t, touched, "teardown should stop or remove the sidecar container")
}

// TestMLflowPluginHealthCheck tests health check
func TestMLflowPluginHealthCheck(t *testing.T) {
	t.Parallel()
	logger := logging.NewLogger(0, false)
	mockPodman := newMockPodmanManager()

	plugin, err := plugins.NewMLflowPlugin(logger, mockPodman, baseOptions(t))
	require.NoError(t, err)

	// Health check always returns true for now
	healthy := plugin.HealthCheck(context.Background(), tracking.ToolConfig{})
	assert.True(t, healthy)
}

// TestMLflowPluginCustomImage tests custom image option
func TestMLflowPluginCustomImage(t *testing.T) {
	t.Parallel()
	logger := logging.NewLogger(0, false)
	mockPodman := newMockPodmanManager()
	opts := baseOptions(t)
	opts.Image = "custom/mlflow:latest"
	opts.PortAllocator = tracking.NewPortAllocator(5500, 5700)

	plugin, err := plugins.NewMLflowPlugin(logger, mockPodman, opts)
	require.NoError(t, err)
	require.NotNil(t, plugin)

	// Provision sidecar and verify custom image is used
	_, err = plugin.ProvisionSidecar(context.Background(), "task-1", sidecarToolConfig())
	require.NoError(t, err)

	// Verify the custom image was used in container config
	require.NotEmpty(t, mockPodman.containers, "container should have been created")
	for _, cfg := range mockPodman.containers {
		assert.Equal(t, "custom/mlflow:latest", cfg.Image, "custom image should be used")
	}
}

// TestMLflowPluginDefaultImage tests that default image is set
func TestMLflowPluginDefaultImage(t *testing.T) {
	t.Parallel()
	logger := logging.NewLogger(0, false)
	mockPodman := newMockPodmanManager()
	// Image not specified - should default to ghcr.io/mlflow/mlflow:v2.16.1
	opts := baseOptions(t)
	opts.PortAllocator = tracking.NewPortAllocator(5500, 5700)

	plugin, err := plugins.NewMLflowPlugin(logger, mockPodman, opts)
	require.NoError(t, err)
	require.NotNil(t, plugin)

	// Provision sidecar and verify default image is used
	_, err = plugin.ProvisionSidecar(context.Background(), "task-1", sidecarToolConfig())
	require.NoError(t, err)

	// Verify the default image was used in container config
	require.NotEmpty(t, mockPodman.containers, "container should have been created")
	for _, cfg := range mockPodman.containers {
		assert.Equal(t, "ghcr.io/mlflow/mlflow:v2.16.1", cfg.Image, "default image should be used")
	}
}