fetch_ml/cli/src/config.zig
Jeremie Fraeys 2c596038b5
refactor(cli): update build system and config for local mode
- Update Makefile with build-sqlite target matching rsync pattern
- Fix build.zig to handle SQLite assets and dataset_hash linking
- Add SQLite asset detection mirroring rsync binary detection
- Update CLI README with local mode documentation
- Restructure rsync assets into rsync/ subdirectory
- Remove obsolete files (fix_arraylist.sh, old rsync_placeholder.bin)
- Add build_rsync.sh script to fetch/build rsync from source
2026-02-20 15:50:52 -05:00

392 lines
15 KiB
Zig

const std = @import("std");
const security = @import("security.zig");
/// URI-based configuration for FetchML.
///
/// `tracking_uri` selects the mode: a `sqlite://` URI (for example
/// `sqlite:///path/to.db`) means local mode; any other value (for example
/// `wss://server.com/ws`) is treated as runner mode.
///
/// Ownership: every string field of a `Config` returned by `load` or
/// `loadWithOverrides` is allocated with the allocator passed to that call
/// and must be released with `deinit` using the same allocator.
pub const Config = struct {
    // Primary storage URI for local mode
    tracking_uri: []const u8,
    // Artifacts directory (for local storage)
    artifact_path: []const u8,
    // Sync target URI (for pushing local runs to server)
    sync_uri: []const u8,

    // Legacy server config (for runner mode)
    worker_host: []const u8,
    worker_user: []const u8,
    worker_base: []const u8,
    worker_port: u16,
    api_key: []const u8,

    // Default resource requests
    default_cpu: u8,
    default_memory: u8,
    default_gpu: u8,
    default_gpu_memory: ?[]const u8,

    // CLI behavior defaults
    default_dry_run: bool,
    default_validate: bool,
    default_json: bool,
    default_priority: u8,

    /// URI prefix that marks local mode.
    const sqlite_prefix = "sqlite://";

    // Built-in scalar defaults. `apply` treats a layer value equal to its
    // built-in default as "not set", so `blank` and `apply` must agree on them.
    const builtin_worker_port: u16 = 22;
    const builtin_cpu: u8 = 2;
    const builtin_memory: u8 = 8;
    const builtin_gpu: u8 = 0;
    const builtin_priority: u8 = 5;

    /// Names of the non-optional string fields that own heap memory.
    const owned_strings = [_][]const u8{
        "tracking_uri",
        "artifact_path",
        "sync_uri",
        "worker_host",
        "worker_user",
        "worker_base",
        "api_key",
    };

    /// Pairs an environment variable with the string field it overrides.
    const EnvBinding = struct { env: []const u8, field: []const u8 };

    const env_bindings = [_]EnvBinding{
        // FETCHML_* environment variables for URI-based config
        .{ .env = "FETCHML_TRACKING_URI", .field = "tracking_uri" },
        .{ .env = "FETCHML_ARTIFACT_PATH", .field = "artifact_path" },
        .{ .env = "FETCHML_SYNC_URI", .field = "sync_uri" },
        // Legacy FETCH_ML_CLI_* variables
        .{ .env = "FETCH_ML_CLI_HOST", .field = "worker_host" },
        .{ .env = "FETCH_ML_CLI_USER", .field = "worker_user" },
        .{ .env = "FETCH_ML_CLI_BASE", .field = "worker_base" },
        .{ .env = "FETCH_ML_CLI_API_KEY", .field = "api_key" },
    };

    /// Returns true when `tracking_uri` starts with `sqlite://` (local mode);
    /// otherwise the config is treated as runner mode.
    pub fn isLocalMode(self: Config) bool {
        return std.mem.startsWith(u8, self.tracking_uri, sqlite_prefix);
    }

    /// Returns the database path from `tracking_uri` with the `sqlite://`
    /// prefix removed. A path that is exactly `~` or starts with `~/` has the
    /// tilde replaced by `$HOME`; other paths (including `~user/...`) are
    /// returned unchanged.
    ///
    /// Caller owns the returned slice.
    /// Errors: `error.InvalidTrackingURI` when the URI is not `sqlite://`,
    /// `error.NoHomeDir` when expansion is needed but `HOME` is unset, and
    /// allocation failures.
    pub fn getDBPath(self: Config, allocator: std.mem.Allocator) ![]const u8 {
        if (!self.isLocalMode()) return error.InvalidTrackingURI;

        const path = self.tracking_uri[sqlite_prefix.len..];
        const home_relative = std.mem.eql(u8, path, "~") or std.mem.startsWith(u8, path, "~/");
        if (home_relative) {
            const home = std.posix.getenv("HOME") orelse return error.NoHomeDir;
            return std.fmt.allocPrint(allocator, "{s}{s}", .{ home, path[1..] });
        }
        return allocator.dupe(u8, path);
    }

    /// Validates the server settings. Local mode needs none of them, so it
    /// always passes.
    ///
    /// Errors: `error.EmptyHost`, `error.InvalidPort`, `error.EmptyAPIKey`,
    /// `error.EmptyBasePath`.
    pub fn validate(self: Config) !void {
        if (self.isLocalMode()) return;

        if (self.worker_host.len == 0) return error.EmptyHost;
        // `worker_port` is a u16, so zero is the only out-of-range value.
        if (self.worker_port == 0) return error.InvalidPort;
        if (self.api_key.len == 0) return error.EmptyAPIKey;
        if (self.worker_base.len == 0) return error.EmptyBasePath;
    }

    /// Load config with priority: CLI > Env > Project > Global > Default.
    ///
    /// The returned config owns all of its strings; release it with `deinit`.
    /// On error nothing is leaked.
    pub fn loadWithOverrides(allocator: std.mem.Allocator, cli_tracking_uri: ?[]const u8, cli_artifact_path: ?[]const u8, cli_sync_uri: ?[]const u8) !Config {
        // Start with defaults
        var config = try loadDefaults(allocator);
        errdefer config.deinit(allocator);

        // Priority 4: Apply global config if exists
        if (try loadGlobalConfig(allocator)) |global| {
            var layer = global;
            config.apply(allocator, &layer);
            config.deinitGlobal(allocator, layer);
        }

        // Priority 3: Apply project config if exists
        if (try loadProjectConfig(allocator)) |project| {
            var layer = project;
            config.apply(allocator, &layer);
            config.deinitGlobal(allocator, layer);
        }

        // Priority 2: Apply environment variables
        try config.applyEnv(allocator);

        // Priority 1: Apply CLI overrides
        if (cli_tracking_uri) |uri| try setString(allocator, &config.tracking_uri, uri);
        if (cli_artifact_path) |path| try setString(allocator, &config.artifact_path, path);
        if (cli_sync_uri) |uri| try setString(allocator, &config.sync_uri, uri);

        return config;
    }

    /// Legacy load function (no overrides)
    pub fn load(allocator: std.mem.Allocator) !Config {
        return loadWithOverrides(allocator, null, null, null);
    }

    /// Returns a config with the built-in scalar defaults and empty strings
    /// that are not heap-allocated. Passing it to `deinit` is safe because
    /// freeing a zero-length slice is a no-op.
    fn blank() Config {
        return .{
            .tracking_uri = "",
            .artifact_path = "",
            .sync_uri = "",
            .worker_host = "",
            .worker_user = "",
            .worker_base = "",
            .worker_port = builtin_worker_port,
            .api_key = "",
            .default_cpu = builtin_cpu,
            .default_memory = builtin_memory,
            .default_gpu = builtin_gpu,
            .default_gpu_memory = null,
            .default_dry_run = false,
            .default_validate = false,
            .default_json = false,
            .default_priority = builtin_priority,
        };
    }

    /// Replaces `slot.*` with a copy of `value`. The previous value is freed
    /// only after the copy succeeded, so `slot.*` is never left dangling.
    fn setString(allocator: std.mem.Allocator, slot: *[]const u8, value: []const u8) std.mem.Allocator.Error!void {
        const copy = try allocator.dupe(u8, value);
        allocator.free(slot.*);
        slot.* = copy;
    }

    /// Strips one pair of surrounding double quotes, if present.
    fn unquote(raw: []const u8) []const u8 {
        if (raw.len >= 2 and raw[0] == '"' and raw[raw.len - 1] == '"') {
            return raw[1 .. raw.len - 1];
        }
        return raw;
    }

    /// Load default configuration
    fn loadDefaults(allocator: std.mem.Allocator) !Config {
        var config = blank();
        errdefer config.deinit(allocator);
        try setString(allocator, &config.tracking_uri, "sqlite://./fetch_ml.db");
        try setString(allocator, &config.artifact_path, "./experiments/");
        return config;
    }

    /// Reads `file` (at most 1 MiB) and parses it as `key = value` lines.
    /// Keys absent from the file keep the values from `blank`. The returned
    /// config owns its strings.
    fn loadFromFile(allocator: std.mem.Allocator, file: std.fs.File) !Config {
        const content = try file.readToEndAlloc(allocator, 1024 * 1024);
        defer allocator.free(content);
        return parseContent(allocator, content);
    }

    /// Simple TOML subset parser: one `key = value` pair per line, `#` starts
    /// a comment line, values may be wrapped in double quotes. Lines without
    /// `=` and unknown keys are ignored. Invalid numbers return the
    /// `std.fmt.parseInt` error.
    fn parseContent(allocator: std.mem.Allocator, content: []const u8) !Config {
        var config = blank();
        errdefer config.deinit(allocator);

        var lines = std.mem.splitScalar(u8, content, '\n');
        while (lines.next()) |line| {
            const trimmed = std.mem.trim(u8, line, " \t\r");
            if (trimmed.len == 0 or trimmed[0] == '#') continue;

            // Split on the first '=' only, so values may themselves contain '='.
            const eq = std.mem.indexOfScalar(u8, trimmed, '=') orelse continue;
            const key = std.mem.trim(u8, trimmed[0..eq], " \t");
            const value = unquote(std.mem.trim(u8, trimmed[eq + 1 ..], " \t"));

            // String settings; a repeated key replaces (and frees) the earlier value.
            inline for (owned_strings) |name| {
                if (std.mem.eql(u8, key, name)) {
                    try setString(allocator, &@field(config, name), value);
                }
            }

            if (std.mem.eql(u8, key, "worker_port")) {
                config.worker_port = try std.fmt.parseInt(u16, value, 10);
            } else if (std.mem.eql(u8, key, "default_cpu")) {
                config.default_cpu = try std.fmt.parseInt(u8, value, 10);
            } else if (std.mem.eql(u8, key, "default_memory")) {
                config.default_memory = try std.fmt.parseInt(u8, value, 10);
            } else if (std.mem.eql(u8, key, "default_gpu")) {
                config.default_gpu = try std.fmt.parseInt(u8, value, 10);
            } else if (std.mem.eql(u8, key, "default_gpu_memory")) {
                if (value.len > 0) {
                    const copy = try allocator.dupe(u8, value);
                    if (config.default_gpu_memory) |old| allocator.free(old);
                    config.default_gpu_memory = copy;
                }
            } else if (std.mem.eql(u8, key, "default_dry_run")) {
                config.default_dry_run = std.mem.eql(u8, value, "true");
            } else if (std.mem.eql(u8, key, "default_validate")) {
                config.default_validate = std.mem.eql(u8, value, "true");
            } else if (std.mem.eql(u8, key, "default_json")) {
                config.default_json = std.mem.eql(u8, value, "true");
            } else if (std.mem.eql(u8, key, "default_priority")) {
                config.default_priority = try std.fmt.parseInt(u8, value, 10);
            }
        }
        return config;
    }

    /// Writes this config to `$HOME/.ml/config.toml`, creating `$HOME/.ml`
    /// if it does not exist and overwriting any existing file.
    ///
    /// Errors: `error.NoHomeDir` when `HOME` is unset, plus allocation and
    /// file-system errors.
    pub fn save(self: Config, allocator: std.mem.Allocator) !void {
        const home = std.posix.getenv("HOME") orelse return error.NoHomeDir;

        // Create .ml directory
        const ml_dir = try std.fmt.allocPrint(allocator, "{s}/.ml", .{home});
        defer allocator.free(ml_dir);
        std.fs.makeDirAbsolute(ml_dir) catch |err| {
            if (err != error.PathAlreadyExists) return err;
        };

        const config_path = try std.fmt.allocPrint(allocator, "{s}/config.toml", .{ml_dir});
        defer allocator.free(config_path);

        // NOTE(review): this file stores `api_key` and is created with the
        // default file mode; consider a more restrictive mode.
        const file = try std.fs.createFileAbsolute(config_path, .{});
        defer file.close();

        const writer = file.writer();
        try writer.print("# FetchML Configuration\n", .{});
        try writer.print("tracking_uri = \"{s}\"\n", .{self.tracking_uri});
        try writer.print("artifact_path = \"{s}\"\n", .{self.artifact_path});
        try writer.print("sync_uri = \"{s}\"\n", .{self.sync_uri});
        try writer.print("\n# Server config (for runner mode)\n", .{});
        try writer.print("worker_host = \"{s}\"\n", .{self.worker_host});
        try writer.print("worker_user = \"{s}\"\n", .{self.worker_user});
        try writer.print("worker_base = \"{s}\"\n", .{self.worker_base});
        try writer.print("worker_port = {d}\n", .{self.worker_port});
        try writer.print("api_key = \"{s}\"\n", .{self.api_key});
        try writer.print("\n# Default resource requests\n", .{});
        try writer.print("default_cpu = {d}\n", .{self.default_cpu});
        try writer.print("default_memory = {d}\n", .{self.default_memory});
        try writer.print("default_gpu = {d}\n", .{self.default_gpu});
        if (self.default_gpu_memory) |gpu_mem| {
            try writer.print("default_gpu_memory = \"{s}\"\n", .{gpu_mem});
        }
        try writer.print("\n# CLI behavior defaults\n", .{});
        try writer.print("default_dry_run = {s}\n", .{if (self.default_dry_run) "true" else "false"});
        try writer.print("default_validate = {s}\n", .{if (self.default_validate) "true" else "false"});
        try writer.print("default_json = {s}\n", .{if (self.default_json) "true" else "false"});
        try writer.print("default_priority = {d}\n", .{self.default_priority});
    }

    /// Frees every string owned by this config. `allocator` must be the one
    /// the strings were allocated with.
    pub fn deinit(self: *Config, allocator: std.mem.Allocator) void {
        inline for (owned_strings) |name| {
            allocator.free(@field(self.*, name));
        }
        if (self.default_gpu_memory) |gpu_mem| {
            allocator.free(gpu_mem);
        }
    }

    /// Apply settings from another config (for layering).
    ///
    /// Ownership of every string taken from `other` MOVES into `self`: the
    /// value it replaces is freed and the field in `other` is reset, so a
    /// later `deinit`/`deinitGlobal` of `other` cannot free memory that
    /// `self` now uses.
    ///
    /// A setting counts as "set" when a string is non-empty, a number differs
    /// from its built-in default, a flag is true, or `default_gpu_memory` is
    /// non-null. Consequently a layer cannot reset a value back to its
    /// built-in default.
    fn apply(self: *Config, allocator: std.mem.Allocator, other: *Config) void {
        inline for (owned_strings) |name| {
            const incoming = @field(other.*, name);
            if (incoming.len > 0) {
                allocator.free(@field(self.*, name));
                @field(self.*, name) = incoming;
                @field(other.*, name) = "";
            }
        }

        if (other.worker_port != builtin_worker_port) self.worker_port = other.worker_port;
        if (other.default_cpu != builtin_cpu) self.default_cpu = other.default_cpu;
        if (other.default_memory != builtin_memory) self.default_memory = other.default_memory;
        if (other.default_gpu != builtin_gpu) self.default_gpu = other.default_gpu;
        if (other.default_priority != builtin_priority) self.default_priority = other.default_priority;

        if (other.default_gpu_memory) |gpu_mem| {
            if (self.default_gpu_memory) |old| allocator.free(old);
            self.default_gpu_memory = gpu_mem;
            other.default_gpu_memory = null;
        }

        if (other.default_dry_run) self.default_dry_run = true;
        if (other.default_validate) self.default_validate = true;
        if (other.default_json) self.default_json = true;
    }

    /// Deinit a config that was loaded temporarily. Frees whatever strings
    /// `other` still owns; `self` is unused.
    fn deinitGlobal(self: Config, allocator: std.mem.Allocator, other: Config) void {
        _ = self;
        var leftover = other;
        leftover.deinit(allocator);
    }

    /// Apply environment variable overrides.
    ///
    /// Each string override is copied before the old value is released. An
    /// unparsable `FETCH_ML_CLI_PORT` is reported with a warning and ignored.
    fn applyEnv(self: *Config, allocator: std.mem.Allocator) std.mem.Allocator.Error!void {
        inline for (env_bindings) |binding| {
            if (std.posix.getenv(binding.env)) |value| {
                try setString(allocator, &@field(self.*, binding.field), value);
            }
        }

        if (std.posix.getenv("FETCH_ML_CLI_PORT")) |port_str| {
            if (std.fmt.parseInt(u16, port_str, 10)) |port| {
                self.worker_port = port;
            } else |err| {
                std.log.warn("ignoring invalid FETCH_ML_CLI_PORT '{s}': {s}", .{ port_str, @errorName(err) });
            }
        }
    }

    /// Get WebSocket URL for connecting to the server: `wss` when
    /// `worker_port` is 443, `ws` otherwise. Caller owns the returned slice.
    pub fn getWebSocketUrl(self: Config, allocator: std.mem.Allocator) ![]u8 {
        const protocol = if (self.worker_port == 443) "wss" else "ws";
        return std.fmt.allocPrint(allocator, "{s}://{s}:{d}/ws", .{
            protocol, self.worker_host, self.worker_port,
        });
    }
};
/// Reads the per-user configuration layer from `$HOME/.ml/config.toml`.
/// Returns `null` when `HOME` is unset or the file does not exist; any other
/// open, read or parse failure is returned as an error.
fn loadGlobalConfig(allocator: std.mem.Allocator) !?Config {
    const home_dir = std.posix.getenv("HOME") orelse return null;

    const global_path = try std.fmt.allocPrint(allocator, "{s}/.ml/config.toml", .{home_dir});
    defer allocator.free(global_path);

    const handle = std.fs.openFileAbsolute(global_path, .{}) catch |err| switch (err) {
        // A missing file just means this layer is not configured.
        error.FileNotFound => return null,
        else => return err,
    };
    defer handle.close();

    return try Config.loadFromFile(allocator, handle);
}
/// Load project config from `.fetchml/config.toml` in the current working
/// directory. Returns `null` when the file does not exist; any other open,
/// read or parse failure is returned as an error.
fn loadProjectConfig(allocator: std.mem.Allocator) !?Config {
    // The path is relative, so it must be opened through the cwd handle:
    // `std.fs.openFileAbsolute` asserts that its argument is absolute.
    const file = std.fs.cwd().openFile(".fetchml/config.toml", .{}) catch |err| switch (err) {
        error.FileNotFound => return null,
        else => return err,
    };
    defer file.close();
    return try Config.loadFromFile(allocator, file);
}