fetch_ml/internal/tracking/plugins/tensorboard.go

170 lines
4 KiB
Go

package plugins
import (
"context"
"fmt"
"os"
"path/filepath"
"sync"
"github.com/jfraeys/fetch_ml/internal/container"
"github.com/jfraeys/fetch_ml/internal/logging"
trackingpkg "github.com/jfraeys/fetch_ml/internal/tracking"
)
// TensorBoardOptions configures the TensorBoard plugin.
type TensorBoardOptions struct {
// Image is the container image the sidecar runs. NewTensorBoardPlugin
// substitutes "tensorflow/tensorflow:2.17.0" when it is empty.
Image string
// LogBasePath is the host directory under which ProvisionSidecar creates one
// log directory per job. NewTensorBoardPlugin rejects an empty value.
LogBasePath string
// PortAllocator supplies the host port published for each sidecar. When nil,
// NewTensorBoardPlugin creates one with NewPortAllocator(5700, 5900).
PortAllocator *trackingpkg.PortAllocator
}
// tensorboardSidecar records what ProvisionSidecar created for one task so
// that Teardown can undo it.
type tensorboardSidecar struct {
// containerID is the identifier returned by PodmanManager.StartContainer.
containerID string
// port is the host port obtained from the PortAllocator and mapped to the
// container's port 6006.
port int
}
// TensorBoardPlugin exposes training logs through TensorBoard by running one
// TensorBoard container per task.
type TensorBoardPlugin struct {
// logger receives warnings from Teardown; Teardown skips logging when it is nil.
logger *logging.Logger
// podman starts, stops, and removes the sidecar containers.
podman *container.PodmanManager
// opts holds the options after NewTensorBoardPlugin has applied defaults.
opts TensorBoardOptions
// mu is held around every access to sidecars.
mu sync.Mutex
// sidecars maps a task ID to the sidecar provisioned for it.
sidecars map[string]*tensorboardSidecar
}
// NewTensorBoardPlugin builds a TensorBoard plugin from its dependencies.
//
// A nil podman manager or an empty opts.LogBasePath is rejected with an error.
// An empty opts.Image falls back to "tensorflow/tensorflow:2.17.0", and a nil
// opts.PortAllocator is replaced by trackingpkg.NewPortAllocator(5700, 5900).
// The logger is stored as given and is not validated here.
func NewTensorBoardPlugin(
	logger *logging.Logger,
	podman *container.PodmanManager,
	opts TensorBoardOptions,
) (*TensorBoardPlugin, error) {
	// Validate required inputs up front; opts is a by-value copy, so applying
	// defaults afterwards is not observable to the caller on the error paths.
	switch {
	case podman == nil:
		return nil, fmt.Errorf("podman manager is required for TensorBoard plugin")
	case opts.LogBasePath == "":
		return nil, fmt.Errorf("log base path is required for TensorBoard plugin")
	}

	if opts.Image == "" {
		opts.Image = "tensorflow/tensorflow:2.17.0"
	}
	if opts.PortAllocator == nil {
		opts.PortAllocator = trackingpkg.NewPortAllocator(5700, 5900)
	}

	plugin := &TensorBoardPlugin{
		logger:   logger,
		podman:   podman,
		opts:     opts,
		sidecars: map[string]*tensorboardSidecar{},
	}
	return plugin, nil
}
// Name reports the identifier of this plugin, "tensorboard".
func (t *TensorBoardPlugin) Name() string {
	const pluginName = "tensorboard"
	return pluginName
}
// ProvisionSidecar starts a TensorBoard container that serves a task's log
// directory and returns environment variables describing it.
//
// When config.Mode is ModeRemote or ModeDisabled nothing is started and
// (nil, nil) is returned. Otherwise the log directory is
// <LogBasePath>/<job_name>, where job_name comes from config.Settings and
// falls back to taskID when unset. The directory is created if missing and
// mounted read-only at /logs inside the container, whose port 6006 is
// published on a host port taken from the PortAllocator.
//
// On success the returned map contains:
//   - TENSORBOARD_URL: http://127.0.0.1:<host port>
//   - TENSORBOARD_HOST_LOG_DIR: the host log directory
//
// An error is returned when a sidecar is already recorded for taskID, when the
// resolved log directory would fall outside LogBasePath, or when creating the
// directory, allocating a port, or starting the container fails. The allocated
// port is released again if the container cannot be started.
func (t *TensorBoardPlugin) ProvisionSidecar(
	ctx context.Context,
	taskID string,
	config trackingpkg.ToolConfig,
) (map[string]string, error) {
	if config.Mode == trackingpkg.ModeRemote || config.Mode == trackingpkg.ModeDisabled {
		return nil, nil
	}

	// Storing a second sidecar under the same task ID would overwrite the
	// first entry, leaving its container and port with no owner to tear them
	// down. This is a plain check, not a reservation: concurrent calls for
	// the same task are not serialized here.
	t.mu.Lock()
	_, exists := t.sidecars[taskID]
	t.mu.Unlock()
	if exists {
		return nil, fmt.Errorf("tensorboard sidecar already provisioned for task %q", taskID)
	}

	jobName := trackingpkg.StringSetting(config.Settings, "job_name")
	if jobName == "" {
		jobName = taskID
	}
	logDir := filepath.Join(t.opts.LogBasePath, jobName)

	// filepath.Join resolves ".." elements, so a job name such as "../../etc"
	// would otherwise create and mount a host directory outside the base path.
	rel, err := filepath.Rel(t.opts.LogBasePath, logDir)
	parentPrefix := ".." + string(filepath.Separator)
	escapes := err != nil ||
		rel == ".." ||
		(len(rel) >= len(parentPrefix) && rel[:len(parentPrefix)] == parentPrefix)
	if escapes {
		return nil, fmt.Errorf(
			"tensorboard log dir for job %q escapes base path %q",
			jobName, t.opts.LogBasePath,
		)
	}

	if err := os.MkdirAll(logDir, 0750); err != nil {
		return nil, fmt.Errorf("failed to ensure tensorboard log dir: %w", err)
	}

	port, err := t.opts.PortAllocator.Allocate()
	if err != nil {
		return nil, fmt.Errorf("failed to allocate tensorboard port: %w", err)
	}

	containerName := fmt.Sprintf("tensorboard-%s", taskID)
	cmd := []string{
		"tensorboard",
		"--logdir", "/logs",
		"--host", "0.0.0.0",
		"--port", "6006",
	}
	cfg := &container.ContainerConfig{
		Name:    containerName,
		Image:   t.opts.Image,
		Command: cmd,
		Volumes: map[string]string{
			logDir: "/logs:ro",
		},
		Ports: map[int]int{
			port: 6006,
		},
		Network: container.NetworkConfig{
			AllowNetwork: true,
		},
	}

	containerID, err := t.podman.StartContainer(ctx, cfg)
	if err != nil {
		// Hand the port back so a failed start does not shrink the pool.
		t.opts.PortAllocator.Release(port)
		return nil, fmt.Errorf("failed to start tensorboard container: %w", err)
	}

	t.mu.Lock()
	t.sidecars[taskID] = &tensorboardSidecar{
		containerID: containerID,
		port:        port,
	}
	t.mu.Unlock()

	return map[string]string{
		"TENSORBOARD_URL":          fmt.Sprintf("http://127.0.0.1:%d", port),
		"TENSORBOARD_HOST_LOG_DIR": logDir,
	}, nil
}
// Teardown stops and removes the TensorBoard sidecar recorded for taskID and
// then returns its host port to the PortAllocator.
//
// It is a no-op when no sidecar is recorded for taskID. Failures to stop or
// remove the container are logged as warnings (when a logger is configured)
// and otherwise ignored, so Teardown always returns nil.
func (t *TensorBoardPlugin) Teardown(ctx context.Context, taskID string) error {
	// Take ownership of the entry under the lock so that a second Teardown
	// for the same task finds nothing to do.
	t.mu.Lock()
	sidecar, ok := t.sidecars[taskID]
	if ok {
		delete(t.sidecars, taskID)
	}
	t.mu.Unlock()
	if !ok {
		return nil
	}

	if sidecar.containerID != "" {
		if err := t.podman.StopContainer(ctx, sidecar.containerID); err != nil && t.logger != nil {
			t.logger.Warn("failed to stop tensorboard container", "task_id", taskID, "error", err)
		}
		if err := t.podman.RemoveContainer(ctx, sidecar.containerID); err != nil && t.logger != nil {
			t.logger.Warn("failed to remove tensorboard container", "task_id", taskID, "error", err)
		}
	}

	// Release the port only after the stop/remove attempts: handing it back
	// while the container may still be bound to it would let a concurrent
	// ProvisionSidecar be given a host port that is still in use.
	if sidecar.port != 0 {
		t.opts.PortAllocator.Release(sidecar.port)
	}
	return nil
}
// HealthCheck unconditionally reports the plugin as healthy. Both arguments
// are ignored; a real probe (for example an HTTP check) is still to be added.
func (t *TensorBoardPlugin) HealthCheck(_ context.Context, _ trackingpkg.ToolConfig) bool {
	const healthy = true
	return healthy
}