69 lines
1.6 KiB
Go
69 lines
1.6 KiB
Go
package plugins
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
|
|
trackingpkg "github.com/jfraeys/fetch_ml/internal/tracking"
|
|
)
|
|
|
|
// WandbPlugin is the tracking plugin registered under the name "wandb".
// It forwards wandb credentials and project settings to the task container as
// environment variables and never provisions a sidecar process. The type has
// no fields, so its zero value is ready to use.
type WandbPlugin struct{}
|
|
|
|
// NewWandbPlugin constructs a Wandb plugin.
|
|
func NewWandbPlugin() *WandbPlugin {
|
|
return &WandbPlugin{}
|
|
}
|
|
|
|
// Name returns the plugin name.
|
|
func (w *WandbPlugin) Name() string {
|
|
return "wandb"
|
|
}
|
|
|
|
// ProvisionSidecar validates configuration and returns environment variables.
|
|
func (w *WandbPlugin) ProvisionSidecar(
|
|
_ context.Context,
|
|
_ string,
|
|
config trackingpkg.ToolConfig,
|
|
) (map[string]string, error) {
|
|
if config.Mode == trackingpkg.ModeDisabled {
|
|
return nil, nil
|
|
}
|
|
|
|
apiKey := trackingpkg.StringSetting(config.Settings, "api_key")
|
|
project := trackingpkg.StringSetting(config.Settings, "project")
|
|
entity := trackingpkg.StringSetting(config.Settings, "entity")
|
|
|
|
if config.Mode == trackingpkg.ModeRemote && apiKey == "" {
|
|
return nil, fmt.Errorf("wandb remote mode requires api_key")
|
|
}
|
|
|
|
env := map[string]string{}
|
|
if apiKey != "" {
|
|
env["WANDB_API_KEY"] = apiKey
|
|
}
|
|
if project != "" {
|
|
env["WANDB_PROJECT"] = project
|
|
}
|
|
if entity != "" {
|
|
env["WANDB_ENTITY"] = entity
|
|
}
|
|
|
|
return env, nil
|
|
}
|
|
|
|
// Teardown is a no-op for Wandb.
|
|
func (w *WandbPlugin) Teardown(context.Context, string) error {
|
|
return nil
|
|
}
|
|
|
|
// HealthCheck ensures required params exist.
|
|
func (w *WandbPlugin) HealthCheck(_ context.Context, config trackingpkg.ToolConfig) bool {
|
|
if !config.Enabled {
|
|
return true
|
|
}
|
|
if config.Mode == trackingpkg.ModeRemote {
|
|
return trackingpkg.StringSetting(config.Settings, "api_key") != ""
|
|
}
|
|
return true
|
|
}
|