diff --git a/internal/tracking/factory/factory_test.go b/internal/tracking/factory/factory_test.go new file mode 100644 index 0000000..781aa16 --- /dev/null +++ b/internal/tracking/factory/factory_test.go @@ -0,0 +1,213 @@ +package factory_test + +import ( + "log/slog" + "testing" + + "github.com/jfraeys/fetch_ml/internal/container" + "github.com/jfraeys/fetch_ml/internal/logging" + "github.com/jfraeys/fetch_ml/internal/tracking" + "github.com/jfraeys/fetch_ml/internal/tracking/factory" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// helper to create a test logger +func testLogger() *logging.Logger { + return logging.NewLogger(slog.LevelInfo, false) +} + +// TestNewPluginLoader tests plugin loader creation +func TestNewPluginLoader(t *testing.T) { + t.Parallel() + + logger := testLogger() + podman := &container.PodmanManager{} + + loader := factory.NewPluginLoader(logger, podman) + require.NotNil(t, loader) +} + +// TestRegisterFactory tests factory registration +func TestRegisterFactory(t *testing.T) { + t.Parallel() + + logger := testLogger() + podman := &container.PodmanManager{} + + loader := factory.NewPluginLoader(logger, podman) + + // Register a custom factory + customFactory := func(l *logging.Logger, p *container.PodmanManager, cfg factory.PluginConfig) (tracking.Plugin, error) { + return nil, nil // Return nil for simplicity + } + + loader.RegisterFactory("custom", customFactory) + // Should not panic +} + +// TestLoadPluginsEmpty tests loading empty plugins +func TestLoadPluginsEmpty(t *testing.T) { + t.Parallel() + + logger := testLogger() + podman := &container.PodmanManager{} + + loader := factory.NewPluginLoader(logger, podman) + registry := tracking.NewRegistry(logger) + + emptyPlugins := make(map[string]factory.PluginConfig) + err := loader.LoadPlugins(emptyPlugins, registry) + require.NoError(t, err) +} + +// TestLoadPluginsDisabled tests that disabled plugins are skipped +func TestLoadPluginsDisabled(t *testing.T) { + t.Parallel() + + logger := testLogger() + podman := &container.PodmanManager{} + + loader := factory.NewPluginLoader(logger, podman) + registry := tracking.NewRegistry(logger) + + plugins := map[string]factory.PluginConfig{ + "mlflow": { + Enabled: false, + Image: "mlflow:latest", + }, + } + + err := loader.LoadPlugins(plugins, registry) + require.NoError(t, err) + // Disabled plugin should not be registered +} + +// TestLoadPluginsUnknown tests that unknown plugins are skipped with warning +func TestLoadPluginsUnknown(t *testing.T) { + t.Parallel() + + logger := testLogger() + podman := &container.PodmanManager{} + + loader := factory.NewPluginLoader(logger, podman) + registry := tracking.NewRegistry(logger) + + plugins := map[string]factory.PluginConfig{ + "unknown-plugin": { + Enabled: true, + Image: "unknown:latest", + }, + } + + err := loader.LoadPlugins(plugins, registry) + require.NoError(t, err) + // Unknown plugin should be skipped but not cause error +} + +// TestPluginConfigStructure tests plugin config fields +func TestPluginConfigStructure(t *testing.T) { + t.Parallel() + + config := factory.PluginConfig{ + Settings: map[string]any{"key": "value"}, + Image: "test-image:latest", + Mode: "sidecar", + LogBasePath: "/var/log", + ArtifactPath: "/artifacts", + Enabled: true, + } + + assert.Equal(t, "test-image:latest", config.Image) + assert.Equal(t, "sidecar", config.Mode) + assert.Equal(t, "/var/log", config.LogBasePath) + assert.Equal(t, "/artifacts", config.ArtifactPath) + assert.True(t, config.Enabled) + assert.Equal(t, "value", config.Settings["key"]) +} + +// TestPluginFactoryType tests the factory function type +func TestPluginFactoryType(t *testing.T) { + t.Parallel() + + // Verify factory function signature + var _ factory.PluginFactory = func( + logger *logging.Logger, + podman *container.PodmanManager, + cfg factory.PluginConfig, + ) (tracking.Plugin, error) { + return nil, nil + } +} + +// TestLoadPluginsMLflow tests loading MLflow plugin config +func TestLoadPluginsMLflow(t *testing.T) { + t.Parallel() + + logger := testLogger() + podman := &container.PodmanManager{} + + loader := factory.NewPluginLoader(logger, podman) + registry := tracking.NewRegistry(logger) + + // MLflow plugin config (disabled to avoid actual container creation) + plugins := map[string]factory.PluginConfig{ + "mlflow": { + Enabled: false, // Disabled to avoid container operations + Image: "mlflow:latest", + ArtifactPath: "/artifacts", + Mode: "sidecar", + Settings: map[string]any{"port": 5000}, + }, + } + + err := loader.LoadPlugins(plugins, registry) + require.NoError(t, err) +} + +// TestLoadPluginsTensorBoard tests loading TensorBoard plugin config +func TestLoadPluginsTensorBoard(t *testing.T) { + t.Parallel() + + logger := testLogger() + podman := &container.PodmanManager{} + + loader := factory.NewPluginLoader(logger, podman) + registry := tracking.NewRegistry(logger) + + // TensorBoard plugin config (disabled) + plugins := map[string]factory.PluginConfig{ + "tensorboard": { + Enabled: false, + Image: "tensorboard:latest", + LogBasePath: "/logs", + Mode: "sidecar", + }, + } + + err := loader.LoadPlugins(plugins, registry) + require.NoError(t, err) +} + +// TestLoadPluginsWandb tests loading Wandb plugin config +func TestLoadPluginsWandb(t *testing.T) { + t.Parallel() + + logger := testLogger() + podman := &container.PodmanManager{} + + loader := factory.NewPluginLoader(logger, podman) + registry := tracking.NewRegistry(logger) + + // Wandb plugin config (disabled) + plugins := map[string]factory.PluginConfig{ + "wandb": { + Enabled: true, // Wandb doesn't require podman + Image: "wandb:latest", + Mode: "native", + }, + } + + err := loader.LoadPlugins(plugins, registry) + require.NoError(t, err) +}