package plugins import ( "context" "fmt" "os" "path/filepath" "sync" "github.com/jfraeys/fetch_ml/internal/container" "github.com/jfraeys/fetch_ml/internal/logging" trackingpkg "github.com/jfraeys/fetch_ml/internal/tracking" ) // TensorBoardOptions configure the TensorBoard plugin. type TensorBoardOptions struct { PortAllocator *trackingpkg.PortAllocator Image string LogBasePath string } type tensorboardSidecar struct { containerID string port int } // TensorBoardPlugin exposes training logs through TensorBoard. type TensorBoardPlugin struct { logger *logging.Logger podman *container.PodmanManager sidecars map[string]*tensorboardSidecar opts TensorBoardOptions mu sync.Mutex } // NewTensorBoardPlugin constructs a TensorBoard plugin instance. func NewTensorBoardPlugin( logger *logging.Logger, podman *container.PodmanManager, opts TensorBoardOptions, ) (*TensorBoardPlugin, error) { if podman == nil { return nil, fmt.Errorf("podman manager is required for TensorBoard plugin") } if opts.Image == "" { opts.Image = "tensorflow/tensorflow:2.17.0" } if opts.LogBasePath == "" { return nil, fmt.Errorf("log base path is required for TensorBoard plugin") } if opts.PortAllocator == nil { opts.PortAllocator = trackingpkg.NewPortAllocator(5700, 5900) } return &TensorBoardPlugin{ logger: logger, podman: podman, opts: opts, sidecars: make(map[string]*tensorboardSidecar), }, nil } // Name returns the plugin name. func (t *TensorBoardPlugin) Name() string { return "tensorboard" } // ProvisionSidecar starts TensorBoard to tail a task's log directory. func (t *TensorBoardPlugin) ProvisionSidecar( ctx context.Context, taskID string, config trackingpkg.ToolConfig, ) (map[string]string, error) { if config.Mode == trackingpkg.ModeRemote || config.Mode == trackingpkg.ModeDisabled { return nil, nil } jobName := trackingpkg.StringSetting(config.Settings, "job_name") if jobName == "" { jobName = taskID } logDir := filepath.Join(t.opts.LogBasePath, jobName) if err := os.MkdirAll(logDir, 0750); err != nil { return nil, fmt.Errorf("failed to ensure tensorboard log dir: %w", err) } port, err := t.opts.PortAllocator.Allocate() if err != nil { return nil, err } containerName := fmt.Sprintf("tensorboard-%s", taskID) cmd := []string{ "tensorboard", "--logdir", "/logs", "--host", "0.0.0.0", "--port", "6006", } cfg := &container.ContainerConfig{ Name: containerName, Image: t.opts.Image, Command: cmd, Volumes: map[string]string{ logDir: "/logs:ro", }, Ports: map[int]int{ port: 6006, }, Network: container.NetworkConfig{ AllowNetwork: true, }, } containerID, err := t.podman.StartContainer(ctx, cfg) if err != nil { t.opts.PortAllocator.Release(port) return nil, err } t.mu.Lock() t.sidecars[taskID] = &tensorboardSidecar{ containerID: containerID, port: port, } t.mu.Unlock() return map[string]string{ "TENSORBOARD_URL": fmt.Sprintf("http://127.0.0.1:%d", port), "TENSORBOARD_HOST_LOG_DIR": logDir, }, nil } // Teardown stops the TensorBoard sidecar. func (t *TensorBoardPlugin) Teardown(ctx context.Context, taskID string) error { t.mu.Lock() sidecar, ok := t.sidecars[taskID] if ok { delete(t.sidecars, taskID) } t.mu.Unlock() if !ok { return nil } if sidecar.port != 0 { t.opts.PortAllocator.Release(sidecar.port) } if sidecar.containerID == "" { return nil } if err := t.podman.StopContainer(ctx, sidecar.containerID); err != nil && t.logger != nil { t.logger.Warn("failed to stop tensorboard container", "task_id", taskID, "error", err) } if err := t.podman.RemoveContainer(ctx, sidecar.containerID); err != nil && t.logger != nil { t.logger.Warn("failed to remove tensorboard container", "task_id", taskID, "error", err) } return nil } // HealthCheck currently returns true; extend with HTTP checks later. func (t *TensorBoardPlugin) HealthCheck(context.Context, trackingpkg.ToolConfig) bool { return true }