Update utility modules: - File utilities with secure file operations - Environment pool with resource tracking - Error types with scheduler error categories - Logging with audit context support - Network/SSH with connection pooling - Privacy/PII handling with tenant boundaries - Resource manager with scheduler allocation - Security monitor with audit integration - Tracking plugins (MLflow, TensorBoard) with auth - Crypto signing with tenant keys - Database init with multi-user support
169 lines
4 KiB
Go
169 lines
4 KiB
Go
package plugins
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"sync"
|
|
|
|
"github.com/jfraeys/fetch_ml/internal/container"
|
|
"github.com/jfraeys/fetch_ml/internal/logging"
|
|
trackingpkg "github.com/jfraeys/fetch_ml/internal/tracking"
|
|
)
|
|
|
|
// TensorBoardOptions configure the TensorBoard plugin.
|
|
type TensorBoardOptions struct {
|
|
PortAllocator *trackingpkg.PortAllocator
|
|
Image string
|
|
LogBasePath string
|
|
}
|
|
|
|
// tensorboardSidecar records what Teardown needs to clean up one running
// TensorBoard container.
type tensorboardSidecar struct {
	// containerID is the identifier returned when the container was started.
	containerID string

	// port is the host port obtained from the port allocator and mapped to
	// the container's port 6006.
	port int
}
|
|
|
|
// TensorBoardPlugin exposes training logs through TensorBoard.
|
|
type TensorBoardPlugin struct {
|
|
logger *logging.Logger
|
|
podman *container.PodmanManager
|
|
sidecars map[string]*tensorboardSidecar
|
|
opts TensorBoardOptions
|
|
mu sync.Mutex
|
|
}
|
|
|
|
// NewTensorBoardPlugin constructs a TensorBoard plugin instance.
|
|
func NewTensorBoardPlugin(
|
|
logger *logging.Logger,
|
|
podman *container.PodmanManager,
|
|
opts TensorBoardOptions,
|
|
) (*TensorBoardPlugin, error) {
|
|
if podman == nil {
|
|
return nil, fmt.Errorf("podman manager is required for TensorBoard plugin")
|
|
}
|
|
if opts.Image == "" {
|
|
opts.Image = "tensorflow/tensorflow:2.17.0"
|
|
}
|
|
if opts.LogBasePath == "" {
|
|
return nil, fmt.Errorf("log base path is required for TensorBoard plugin")
|
|
}
|
|
if opts.PortAllocator == nil {
|
|
opts.PortAllocator = trackingpkg.NewPortAllocator(5700, 5900)
|
|
}
|
|
|
|
return &TensorBoardPlugin{
|
|
logger: logger,
|
|
podman: podman,
|
|
opts: opts,
|
|
sidecars: make(map[string]*tensorboardSidecar),
|
|
}, nil
|
|
}
|
|
|
|
// Name returns the plugin name.
|
|
func (t *TensorBoardPlugin) Name() string {
|
|
return "tensorboard"
|
|
}
|
|
|
|
// ProvisionSidecar starts TensorBoard to tail a task's log directory.
|
|
func (t *TensorBoardPlugin) ProvisionSidecar(
|
|
ctx context.Context,
|
|
taskID string,
|
|
config trackingpkg.ToolConfig,
|
|
) (map[string]string, error) {
|
|
if config.Mode == trackingpkg.ModeRemote || config.Mode == trackingpkg.ModeDisabled {
|
|
return nil, nil
|
|
}
|
|
|
|
jobName := trackingpkg.StringSetting(config.Settings, "job_name")
|
|
if jobName == "" {
|
|
jobName = taskID
|
|
}
|
|
|
|
logDir := filepath.Join(t.opts.LogBasePath, jobName)
|
|
if err := os.MkdirAll(logDir, 0750); err != nil {
|
|
return nil, fmt.Errorf("failed to ensure tensorboard log dir: %w", err)
|
|
}
|
|
|
|
port, err := t.opts.PortAllocator.Allocate()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
containerName := fmt.Sprintf("tensorboard-%s", taskID)
|
|
cmd := []string{
|
|
"tensorboard",
|
|
"--logdir", "/logs",
|
|
"--host", "0.0.0.0",
|
|
"--port", "6006",
|
|
}
|
|
|
|
cfg := &container.ContainerConfig{
|
|
Name: containerName,
|
|
Image: t.opts.Image,
|
|
Command: cmd,
|
|
Volumes: map[string]string{
|
|
logDir: "/logs:ro",
|
|
},
|
|
Ports: map[int]int{
|
|
port: 6006,
|
|
},
|
|
Network: container.NetworkConfig{
|
|
AllowNetwork: true,
|
|
},
|
|
}
|
|
|
|
containerID, err := t.podman.StartContainer(ctx, cfg)
|
|
if err != nil {
|
|
t.opts.PortAllocator.Release(port)
|
|
return nil, err
|
|
}
|
|
|
|
t.mu.Lock()
|
|
t.sidecars[taskID] = &tensorboardSidecar{
|
|
containerID: containerID,
|
|
port: port,
|
|
}
|
|
t.mu.Unlock()
|
|
|
|
return map[string]string{
|
|
"TENSORBOARD_URL": fmt.Sprintf("http://127.0.0.1:%d", port),
|
|
"TENSORBOARD_HOST_LOG_DIR": logDir,
|
|
}, nil
|
|
}
|
|
|
|
// Teardown stops the TensorBoard sidecar.
|
|
func (t *TensorBoardPlugin) Teardown(ctx context.Context, taskID string) error {
|
|
t.mu.Lock()
|
|
sidecar, ok := t.sidecars[taskID]
|
|
if ok {
|
|
delete(t.sidecars, taskID)
|
|
}
|
|
t.mu.Unlock()
|
|
|
|
if !ok {
|
|
return nil
|
|
}
|
|
|
|
if sidecar.port != 0 {
|
|
t.opts.PortAllocator.Release(sidecar.port)
|
|
}
|
|
|
|
if sidecar.containerID == "" {
|
|
return nil
|
|
}
|
|
|
|
if err := t.podman.StopContainer(ctx, sidecar.containerID); err != nil && t.logger != nil {
|
|
t.logger.Warn("failed to stop tensorboard container", "task_id", taskID, "error", err)
|
|
}
|
|
if err := t.podman.RemoveContainer(ctx, sidecar.containerID); err != nil && t.logger != nil {
|
|
t.logger.Warn("failed to remove tensorboard container", "task_id", taskID, "error", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// HealthCheck currently returns true; extend with HTTP checks later.
|
|
func (t *TensorBoardPlugin) HealthCheck(context.Context, trackingpkg.ToolConfig) bool {
|
|
return true
|
|
}
|