fetch_ml/internal/tracking/plugins/wandb.go

69 lines
1.6 KiB
Go

package plugins
import (
"context"
"fmt"
trackingpkg "github.com/jfraeys/fetch_ml/internal/tracking"
)
// WandbPlugin is the wandb tracking plugin. It never starts a sidecar; it only
// hands the configured wandb settings to the task container as environment
// variables.
type WandbPlugin struct{}

// NewWandbPlugin returns a new Wandb plugin. The type carries no state, so
// the zero value returned here is immediately usable.
func NewWandbPlugin() *WandbPlugin {
	return new(WandbPlugin)
}

// Name returns the plugin's identifier, "wandb".
func (w *WandbPlugin) Name() string {
	const pluginName = "wandb"
	return pluginName
}
// ProvisionSidecar validates the wandb configuration and returns the
// environment variables to inject into the task container; no sidecar process
// is started. The context and the second (string) argument are ignored.
//
// Returns (nil, nil) when config.Mode is ModeDisabled. In ModeRemote an empty
// "api_key" setting is rejected with an error. Otherwise the result maps
// WANDB_API_KEY, WANDB_PROJECT and WANDB_ENTITY to the "api_key", "project"
// and "entity" settings, omitting any that are empty (the map may be empty
// but is never nil on success).
//
// NOTE(review): this gates on config.Mode while HealthCheck gates on
// config.Enabled; confirm callers never provision a tool with Enabled=false.
func (w *WandbPlugin) ProvisionSidecar(
	_ context.Context,
	_ string,
	config trackingpkg.ToolConfig,
) (map[string]string, error) {
	if config.Mode == trackingpkg.ModeDisabled {
		return nil, nil
	}

	settings := config.Settings
	apiKey := trackingpkg.StringSetting(settings, "api_key")
	// Setting -> environment variable, looked up in the same order as before
	// validation so every lookup happens regardless of the outcome.
	vars := [...]struct{ name, value string }{
		{name: "WANDB_API_KEY", value: apiKey},
		{name: "WANDB_PROJECT", value: trackingpkg.StringSetting(settings, "project")},
		{name: "WANDB_ENTITY", value: trackingpkg.StringSetting(settings, "entity")},
	}

	if config.Mode == trackingpkg.ModeRemote && apiKey == "" {
		return nil, fmt.Errorf("wandb remote mode requires api_key")
	}

	env := make(map[string]string, len(vars))
	for _, v := range vars {
		if v.value == "" {
			continue // unset settings are not exported at all
		}
		env[v.name] = v.value
	}
	return env, nil
}
// Teardown does nothing and always returns nil: ProvisionSidecar starts no
// sidecar, so there is nothing to stop or clean up.
func (w *WandbPlugin) Teardown(_ context.Context, _ string) error {
	return nil
}
// HealthCheck reports whether the wandb configuration is usable. A tool that
// is not enabled, or that runs in any mode other than ModeRemote, is always
// considered healthy; in ModeRemote a non-empty "api_key" setting is required.
// The context is ignored.
func (w *WandbPlugin) HealthCheck(_ context.Context, config trackingpkg.ToolConfig) bool {
	// Only an enabled remote configuration has a required setting to check.
	if !config.Enabled || config.Mode != trackingpkg.ModeRemote {
		return true
	}
	key := trackingpkg.StringSetting(config.Settings, "api_key")
	return key != ""
}