diff --git a/internal/tracking/plugin_test.go b/internal/tracking/plugin_test.go new file mode 100644 index 0000000..badf6b6 --- /dev/null +++ b/internal/tracking/plugin_test.go @@ -0,0 +1,267 @@ +package tracking_test + +import ( + "context" + "log/slog" + "testing" + + "github.com/jfraeys/fetch_ml/internal/logging" + "github.com/jfraeys/fetch_ml/internal/tracking" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockPlugin is a test plugin implementation +type mockPlugin struct { + name string + provisioned bool + tornDown bool +} + +func (m *mockPlugin) Name() string { return m.name } +func (m *mockPlugin) ProvisionSidecar(ctx context.Context, taskID string, config tracking.ToolConfig) (map[string]string, error) { + m.provisioned = true + return map[string]string{"MOCK_VAR": "value"}, nil +} +func (m *mockPlugin) Teardown(ctx context.Context, taskID string) error { + m.tornDown = true + return nil +} +func (m *mockPlugin) HealthCheck(ctx context.Context, config tracking.ToolConfig) bool { return true } + +// TestNewRegistry tests registry creation +func TestNewRegistry(t *testing.T) { + t.Parallel() + + logger := logging.NewLogger(slog.LevelInfo, false) + registry := tracking.NewRegistry(logger) + require.NotNil(t, registry) +} + +// TestRegisterAndGet tests plugin registration and retrieval +func TestRegisterAndGet(t *testing.T) { + t.Parallel() + + logger := logging.NewLogger(slog.LevelInfo, false) + registry := tracking.NewRegistry(logger) + + plugin := &mockPlugin{name: "test-plugin"} + registry.Register(plugin) + + retrieved, ok := registry.Get("test-plugin") + require.True(t, ok) + assert.Equal(t, "test-plugin", retrieved.Name()) + + // Nonexistent plugin + _, ok = registry.Get("nonexistent") + assert.False(t, ok) +} + +// TestProvisionAllEmpty tests empty config provisioning +func TestProvisionAllEmpty(t *testing.T) { + t.Parallel() + + logger := logging.NewLogger(slog.LevelInfo, false) + registry := tracking.NewRegistry(logger) + + env, err := registry.ProvisionAll(context.Background(), "task-1", nil) + require.NoError(t, err) + assert.Nil(t, env) + + env, err = registry.ProvisionAll(context.Background(), "task-1", map[string]tracking.ToolConfig{}) + require.NoError(t, err) + assert.Nil(t, env) +} + +// TestProvisionAllDisabled tests disabled plugin handling +func TestProvisionAllDisabled(t *testing.T) { + t.Parallel() + + logger := logging.NewLogger(slog.LevelInfo, false) + registry := tracking.NewRegistry(logger) + + plugin := &mockPlugin{name: "disabled-plugin"} + registry.Register(plugin) + + configs := map[string]tracking.ToolConfig{ + "disabled-plugin": {Enabled: false, Mode: tracking.ModeSidecar}, + } + + env, err := registry.ProvisionAll(context.Background(), "task-1", configs) + require.NoError(t, err) + assert.Empty(t, env) + assert.False(t, plugin.provisioned) +} + +// TestProvisionAllUnregistered tests unregistered plugin error +func TestProvisionAllUnregistered(t *testing.T) { + t.Parallel() + + logger := logging.NewLogger(slog.LevelInfo, false) + registry := tracking.NewRegistry(logger) + + configs := map[string]tracking.ToolConfig{ + "unregistered": {Enabled: true, Mode: tracking.ModeSidecar}, + } + + _, err := registry.ProvisionAll(context.Background(), "task-1", configs) + require.Error(t, err) + assert.Contains(t, err.Error(), "not registered") +} + +// TestTeardownAll tests plugin teardown +func TestTeardownAll(t *testing.T) { + t.Parallel() + + logger := logging.NewLogger(slog.LevelInfo, false) + registry := tracking.NewRegistry(logger) + + plugin := &mockPlugin{name: "teardown-test"} + registry.Register(plugin) + + // Provision first + configs := map[string]tracking.ToolConfig{ + "teardown-test": {Enabled: true, Mode: tracking.ModeSidecar}, + } + _, err := registry.ProvisionAll(context.Background(), "task-1", configs) + require.NoError(t, err) + + // Then teardown + registry.TeardownAll(context.Background(), "task-1") + assert.True(t, plugin.tornDown) + + // Teardown nonexistent task should not panic + registry.TeardownAll(context.Background(), "nonexistent-task") +} + +// TestNewPortAllocator tests allocator creation +func TestNewPortAllocator(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + start int + end int + wantStart int + wantEnd int + }{ + {"valid range", 5000, 5100, 5000, 5100}, + {"zero values", 0, 0, 5500, 5600}, + {"invalid range", 100, 50, 5500, 5600}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + allocator := tracking.NewPortAllocator(tc.start, tc.end) + require.NotNil(t, allocator) + + port, err := allocator.Allocate() + require.NoError(t, err) + assert.GreaterOrEqual(t, port, tc.wantStart) + assert.Less(t, port, tc.wantEnd) + }) + } +} + +// TestPortAllocatorAllocate tests port allocation +func TestPortAllocatorAllocate(t *testing.T) { + t.Parallel() + + allocator := tracking.NewPortAllocator(8000, 8003) + + // Allocate all ports + ports := make([]int, 0, 3) + for i := 0; i < 3; i++ { + port, err := allocator.Allocate() + require.NoError(t, err) + ports = append(ports, port) + } + + // Should be unique + assert.Len(t, ports, 3) + assert.NotEqual(t, ports[0], ports[1]) + + // No more ports available + _, err := allocator.Allocate() + require.Error(t, err) +} + +// TestPortAllocatorRelease tests port release and reuse +func TestPortAllocatorRelease(t *testing.T) { + t.Parallel() + + allocator := tracking.NewPortAllocator(9000, 9005) + + port1, err := allocator.Allocate() + require.NoError(t, err) + + // Allocate another port to move next pointer forward + port2, err := allocator.Allocate() + require.NoError(t, err) + require.NotEqual(t, port1, port2) + + // Release the first port + allocator.Release(port1) + + // Allocate again - should eventually get the released port back + // (after scanning through other ports) + var foundReleased bool + for i := 0; i < 10; i++ { + p, err := allocator.Allocate() + require.NoError(t, err) + if p == port1 { + foundReleased = true + break + } + // Release it to keep scanning + allocator.Release(p) + } + assert.True(t, foundReleased, "Should eventually get released port back") +} + +// TestStringSetting tests settings extraction +func TestStringSetting(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + settings map[string]any + key string + want string + }{ + {"valid string", map[string]any{"key": "value"}, "key", "value"}, + {"nil settings", nil, "key", ""}, + {"missing key", map[string]any{"other": "value"}, "key", ""}, + {"non-string value", map[string]any{"key": 123}, "key", ""}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := tracking.StringSetting(tc.settings, tc.key) + assert.Equal(t, tc.want, got) + }) + } +} + +// TestToolModeConstants tests mode constants +func TestToolModeConstants(t *testing.T) { + t.Parallel() + + assert.Equal(t, tracking.ToolMode("sidecar"), tracking.ModeSidecar) + assert.Equal(t, tracking.ToolMode("remote"), tracking.ModeRemote) + assert.Equal(t, tracking.ToolMode("disabled"), tracking.ModeDisabled) +} + +// TestToolConfigStructure tests config fields +func TestToolConfigStructure(t *testing.T) { + t.Parallel() + + config := tracking.ToolConfig{ + Settings: map[string]any{"port": 8080}, + Mode: tracking.ModeSidecar, + Enabled: true, + } + + assert.True(t, config.Enabled) + assert.Equal(t, tracking.ModeSidecar, config.Mode) +}