Refactor plugins to use interface for testability: - Add PodmanInterface to container package (StartContainer, StopContainer, RemoveContainer) - Update MLflow plugin to use container.PodmanInterface - Update TensorBoard plugin to use container.PodmanInterface - Add comprehensive mocked tests for all three plugins (wandb, mlflow, tensorboard) - Coverage increased from 18% to 91.4%
267 lines
6.9 KiB
Go
267 lines
6.9 KiB
Go
package tracking_test
|
|
|
|
import (
|
|
"context"
|
|
"log/slog"
|
|
"testing"
|
|
|
|
"github.com/jfraeys/fetch_ml/internal/logging"
|
|
"github.com/jfraeys/fetch_ml/internal/tracking"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// mockPlugin is a test plugin implementation
|
|
type mockPlugin struct {
|
|
name string
|
|
provisioned bool
|
|
tornDown bool
|
|
}
|
|
|
|
func (m *mockPlugin) Name() string { return m.name }
|
|
func (m *mockPlugin) ProvisionSidecar(ctx context.Context, taskID string, config tracking.ToolConfig) (map[string]string, error) {
|
|
m.provisioned = true
|
|
return map[string]string{"MOCK_VAR": "value"}, nil
|
|
}
|
|
func (m *mockPlugin) Teardown(ctx context.Context, taskID string) error {
|
|
m.tornDown = true
|
|
return nil
|
|
}
|
|
func (m *mockPlugin) HealthCheck(ctx context.Context, config tracking.ToolConfig) bool { return true }
|
|
|
|
// TestNewRegistry tests registry creation
|
|
func TestNewRegistry(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
logger := logging.NewLogger(slog.LevelInfo, false)
|
|
registry := tracking.NewRegistry(logger)
|
|
require.NotNil(t, registry)
|
|
}
|
|
|
|
// TestRegisterAndGet tests plugin registration and retrieval
|
|
func TestRegisterAndGet(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
logger := logging.NewLogger(slog.LevelInfo, false)
|
|
registry := tracking.NewRegistry(logger)
|
|
|
|
plugin := &mockPlugin{name: "test-plugin"}
|
|
registry.Register(plugin)
|
|
|
|
retrieved, ok := registry.Get("test-plugin")
|
|
require.True(t, ok)
|
|
assert.Equal(t, "test-plugin", retrieved.Name())
|
|
|
|
// Nonexistent plugin
|
|
_, ok = registry.Get("nonexistent")
|
|
assert.False(t, ok)
|
|
}
|
|
|
|
// TestProvisionAllEmpty tests empty config provisioning
|
|
func TestProvisionAllEmpty(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
logger := logging.NewLogger(slog.LevelInfo, false)
|
|
registry := tracking.NewRegistry(logger)
|
|
|
|
env, err := registry.ProvisionAll(context.Background(), "task-1", nil)
|
|
require.NoError(t, err)
|
|
assert.Nil(t, env)
|
|
|
|
env, err = registry.ProvisionAll(context.Background(), "task-1", map[string]tracking.ToolConfig{})
|
|
require.NoError(t, err)
|
|
assert.Nil(t, env)
|
|
}
|
|
|
|
// TestProvisionAllDisabled tests disabled plugin handling
|
|
func TestProvisionAllDisabled(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
logger := logging.NewLogger(slog.LevelInfo, false)
|
|
registry := tracking.NewRegistry(logger)
|
|
|
|
plugin := &mockPlugin{name: "disabled-plugin"}
|
|
registry.Register(plugin)
|
|
|
|
configs := map[string]tracking.ToolConfig{
|
|
"disabled-plugin": {Enabled: false, Mode: tracking.ModeSidecar},
|
|
}
|
|
|
|
env, err := registry.ProvisionAll(context.Background(), "task-1", configs)
|
|
require.NoError(t, err)
|
|
assert.Empty(t, env)
|
|
assert.False(t, plugin.provisioned)
|
|
}
|
|
|
|
// TestProvisionAllUnregistered tests unregistered plugin error
|
|
func TestProvisionAllUnregistered(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
logger := logging.NewLogger(slog.LevelInfo, false)
|
|
registry := tracking.NewRegistry(logger)
|
|
|
|
configs := map[string]tracking.ToolConfig{
|
|
"unregistered": {Enabled: true, Mode: tracking.ModeSidecar},
|
|
}
|
|
|
|
_, err := registry.ProvisionAll(context.Background(), "task-1", configs)
|
|
require.Error(t, err)
|
|
assert.Contains(t, err.Error(), "not registered")
|
|
}
|
|
|
|
// TestTeardownAll tests plugin teardown
|
|
func TestTeardownAll(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
logger := logging.NewLogger(slog.LevelInfo, false)
|
|
registry := tracking.NewRegistry(logger)
|
|
|
|
plugin := &mockPlugin{name: "teardown-test"}
|
|
registry.Register(plugin)
|
|
|
|
// Provision first
|
|
configs := map[string]tracking.ToolConfig{
|
|
"teardown-test": {Enabled: true, Mode: tracking.ModeSidecar},
|
|
}
|
|
_, err := registry.ProvisionAll(context.Background(), "task-1", configs)
|
|
require.NoError(t, err)
|
|
|
|
// Then teardown
|
|
registry.TeardownAll(context.Background(), "task-1")
|
|
assert.True(t, plugin.tornDown)
|
|
|
|
// Teardown nonexistent task should not panic
|
|
registry.TeardownAll(context.Background(), "nonexistent-task")
|
|
}
|
|
|
|
// TestNewPortAllocator tests allocator creation
|
|
func TestNewPortAllocator(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
cases := []struct {
|
|
name string
|
|
start int
|
|
end int
|
|
wantStart int
|
|
wantEnd int
|
|
}{
|
|
{"valid range", 5000, 5100, 5000, 5100},
|
|
{"zero values", 0, 0, 5500, 5600},
|
|
{"invalid range", 100, 50, 5500, 5600},
|
|
}
|
|
|
|
for _, tc := range cases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
allocator := tracking.NewPortAllocator(tc.start, tc.end)
|
|
require.NotNil(t, allocator)
|
|
|
|
port, err := allocator.Allocate()
|
|
require.NoError(t, err)
|
|
assert.GreaterOrEqual(t, port, tc.wantStart)
|
|
assert.Less(t, port, tc.wantEnd)
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestPortAllocatorAllocate tests port allocation
|
|
func TestPortAllocatorAllocate(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
allocator := tracking.NewPortAllocator(8000, 8003)
|
|
|
|
// Allocate all ports
|
|
ports := make([]int, 0, 3)
|
|
for range 3 {
|
|
port, err := allocator.Allocate()
|
|
require.NoError(t, err)
|
|
ports = append(ports, port)
|
|
}
|
|
|
|
// Should be unique
|
|
assert.Len(t, ports, 3)
|
|
assert.NotEqual(t, ports[0], ports[1])
|
|
|
|
// No more ports available
|
|
_, err := allocator.Allocate()
|
|
require.Error(t, err)
|
|
}
|
|
|
|
// TestPortAllocatorRelease tests port release and reuse
|
|
func TestPortAllocatorRelease(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
allocator := tracking.NewPortAllocator(9000, 9005)
|
|
|
|
port1, err := allocator.Allocate()
|
|
require.NoError(t, err)
|
|
|
|
// Allocate another port to move next pointer forward
|
|
port2, err := allocator.Allocate()
|
|
require.NoError(t, err)
|
|
require.NotEqual(t, port1, port2)
|
|
|
|
// Release the first port
|
|
allocator.Release(port1)
|
|
|
|
// Allocate again - should eventually get the released port back
|
|
// (after scanning through other ports)
|
|
var foundReleased bool
|
|
for range 10 {
|
|
p, err := allocator.Allocate()
|
|
require.NoError(t, err)
|
|
if p == port1 {
|
|
foundReleased = true
|
|
break
|
|
}
|
|
// Release it to keep scanning
|
|
allocator.Release(p)
|
|
}
|
|
assert.True(t, foundReleased, "Should eventually get released port back")
|
|
}
|
|
|
|
// TestStringSetting tests settings extraction
|
|
func TestStringSetting(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
cases := []struct {
|
|
name string
|
|
settings map[string]any
|
|
key string
|
|
want string
|
|
}{
|
|
{"valid string", map[string]any{"key": "value"}, "key", "value"},
|
|
{"nil settings", nil, "key", ""},
|
|
{"missing key", map[string]any{"other": "value"}, "key", ""},
|
|
{"non-string value", map[string]any{"key": 123}, "key", ""},
|
|
}
|
|
|
|
for _, tc := range cases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
got := tracking.StringSetting(tc.settings, tc.key)
|
|
assert.Equal(t, tc.want, got)
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestToolModeConstants tests mode constants
|
|
func TestToolModeConstants(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
assert.Equal(t, tracking.ToolMode("sidecar"), tracking.ModeSidecar)
|
|
assert.Equal(t, tracking.ToolMode("remote"), tracking.ModeRemote)
|
|
assert.Equal(t, tracking.ToolMode("disabled"), tracking.ModeDisabled)
|
|
}
|
|
|
|
// TestToolConfigStructure tests config fields
|
|
func TestToolConfigStructure(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
config := tracking.ToolConfig{
|
|
Settings: map[string]any{"port": 8080},
|
|
Mode: tracking.ModeSidecar,
|
|
Enabled: true,
|
|
}
|
|
|
|
assert.True(t, config.Enabled)
|
|
assert.Equal(t, tracking.ModeSidecar, config.Mode)
|
|
}
|