//! Layered configuration for the FetchML CLI.
//!
//! Values are resolved with priority: CLI > Env > Project > Global > Default.

const std = @import("std");
const security = @import("security.zig");

/// Contents of the `[experiment]` section of a config file.
pub const ExperimentConfig = struct {
    name: []const u8,
    entrypoint: []const u8,
};

/// URI-based configuration for FetchML
/// Supports: sqlite:///path/to.db or wss://server.com/ws
///
/// Ownership: `deinit` frees every string field (including the strings inside
/// `experiment` and `default_gpu_memory`) with the allocator it is given, so a
/// `Config` handed to `deinit` must hold only zero-length slices or memory
/// obtained from that same allocator.
pub const Config = struct {
    // Primary storage URI for local mode
    tracking_uri: []const u8,
    // Artifacts directory (for local storage)
    artifact_path: []const u8,
    // Sync target URI (for pushing local runs to server)
    sync_uri: []const u8,
    // Force local mode regardless of server config
    force_local: bool,
    // Experiment configuration ([experiment] section)
    experiment: ?ExperimentConfig,

    // Legacy server config (for runner mode)
    worker_host: []const u8,
    worker_user: []const u8,
    worker_base: []const u8,
    worker_port: u16,
    api_key: []const u8,

    // Default resource requests
    default_cpu: u8,
    default_memory: u8,
    default_gpu: u8,
    default_gpu_memory: ?[]const u8,

    // CLI behavior defaults
    default_dry_run: bool,
    default_validate: bool,
    default_json: bool,
    default_priority: u8,

    const sqlite_scheme = "sqlite://";

    /// Built-in scalar defaults. `apply` treats a layer value equal to one of
    /// these as "not set", because the parser has no separate presence flag.
    const defaults = struct {
        const worker_port: u16 = 22;
        const cpu: u8 = 2;
        const memory: u8 = 8;
        const gpu: u8 = 0;
        const priority: u8 = 5;
    };

    /// Keys recognised outside the `[experiment]` section.
    const RootKey = enum {
        tracking_uri,
        artifact_path,
        sync_uri,
        force_local,
        worker_host,
        worker_user,
        worker_base,
        worker_port,
        api_key,
        default_cpu,
        default_memory,
        default_gpu,
        default_gpu_memory,
        default_dry_run,
        default_validate,
        default_json,
        default_priority,
    };

    /// Returns a config that owns no heap memory: every string is empty and
    /// every scalar carries its built-in default.
    fn initEmpty() Config {
        return .{
            .tracking_uri = "",
            .artifact_path = "",
            .sync_uri = "",
            .force_local = false,
            .experiment = null,
            .worker_host = "",
            .worker_user = "",
            .worker_base = "",
            .worker_port = defaults.worker_port,
            .api_key = "",
            .default_cpu = defaults.cpu,
            .default_memory = defaults.memory,
            .default_gpu = defaults.gpu,
            .default_gpu_memory = null,
            .default_dry_run = false,
            .default_validate = false,
            .default_json = false,
            .default_priority = defaults.priority,
        };
    }

    /// Stores a fresh copy of `value` in `slot`. The copy is made before the
    /// previous value is freed, so on allocation failure `slot` is untouched
    /// and still valid.
    fn replaceOwned(
        allocator: std.mem.Allocator,
        slot: *[]const u8,
        value: []const u8,
    ) std.mem.Allocator.Error!void {
        const copy = try allocator.dupe(u8, value);
        allocator.free(slot.*);
        slot.* = copy;
    }

    /// Same as `replaceOwned`, for slots that may currently be null.
    fn replaceOwnedOptional(
        allocator: std.mem.Allocator,
        slot: *?[]const u8,
        value: []const u8,
    ) std.mem.Allocator.Error!void {
        const copy = try allocator.dupe(u8, value);
        if (slot.*) |previous| allocator.free(previous);
        slot.* = copy;
    }

    /// Strips one pair of surrounding double quotes, if present.
    fn unquote(raw: []const u8) []const u8 {
        const quoted = raw.len >= 2 and raw[0] == '"' and raw[raw.len - 1] == '"';
        return if (quoted) raw[1 .. raw.len - 1] else raw;
    }

    /// Renders a bool the way the config file spells it.
    fn boolText(value: bool) []const u8 {
        return if (value) "true" else "false";
    }

    /// Check if this is local mode (sqlite://) or runner mode (wss://)
    pub fn isLocalMode(self: Config) bool {
        return std.mem.startsWith(u8, self.tracking_uri, sqlite_scheme);
    }

    /// Get the database path from tracking_uri (removes sqlite:// prefix).
    ///
    /// A path that is exactly `~` or starts with `~/` has the `~` replaced by
    /// `$HOME`; any other path (including `~user/...`) is returned verbatim.
    /// Caller owns the returned slice.
    ///
    /// Errors: `error.InvalidTrackingURI` when the URI is not `sqlite://`,
    /// `error.NoHomeDir` when expansion is needed but `HOME` is unset, or an
    /// allocation failure.
    pub fn getDBPath(self: Config, allocator: std.mem.Allocator) ![]const u8 {
        if (!std.mem.startsWith(u8, self.tracking_uri, sqlite_scheme)) {
            return error.InvalidTrackingURI;
        }
        const path = self.tracking_uri[sqlite_scheme.len..];

        const needs_home = std.mem.eql(u8, path, "~") or std.mem.startsWith(u8, path, "~/");
        if (!needs_home) return allocator.dupe(u8, path);

        const home = std.posix.getenv("HOME") orelse return error.NoHomeDir;
        return std.fmt.allocPrint(allocator, "{s}{s}", .{ home, path[1..] });
    }

    /// Checks the server settings; a local-mode config is always accepted.
    ///
    /// Errors: `error.EmptyHost`, `error.InvalidPort`, `error.EmptyAPIKey`,
    /// `error.EmptyBasePath`.
    pub fn validate(self: Config) !void {
        if (self.isLocalMode()) return;

        if (self.worker_host.len == 0) return error.EmptyHost;
        // worker_port is a u16, so zero is the only out-of-range value.
        if (self.worker_port == 0) return error.InvalidPort;
        if (self.api_key.len == 0) return error.EmptyAPIKey;
        if (self.worker_base.len == 0) return error.EmptyBasePath;
    }

    /// Load config with priority: CLI > Env > Project > Global > Default
    ///
    /// The CLI arguments are copied. Caller owns the result and must release
    /// it with `deinit` using the same allocator.
    pub fn loadWithOverrides(
        allocator: std.mem.Allocator,
        cli_tracking_uri: ?[]const u8,
        cli_artifact_path: ?[]const u8,
        cli_sync_uri: ?[]const u8,
    ) !Config {
        var config = try loadDefaults(allocator);
        errdefer config.deinit(allocator);

        // Priority 4: global config, if present.
        if (try loadGlobalConfig(allocator)) |loaded| {
            var layer = loaded;
            // After apply() the layer holds the values it displaced.
            defer layer.deinit(allocator);
            config.apply(&layer);
        }

        // Priority 3: project config, if present.
        if (try loadProjectConfig(allocator)) |loaded| {
            var layer = loaded;
            defer layer.deinit(allocator);
            config.apply(&layer);
        }

        // Priority 2: environment variables.
        try config.applyEnv(allocator);

        // Priority 1: CLI overrides.
        if (cli_tracking_uri) |uri| try replaceOwned(allocator, &config.tracking_uri, uri);
        if (cli_artifact_path) |path| try replaceOwned(allocator, &config.artifact_path, path);
        if (cli_sync_uri) |uri| try replaceOwned(allocator, &config.sync_uri, uri);

        return config;
    }

    /// Legacy load function (no overrides)
    pub fn load(allocator: std.mem.Allocator) !Config {
        return loadWithOverrides(allocator, null, null, null);
    }

    /// Load default configuration
    fn loadDefaults(allocator: std.mem.Allocator) !Config {
        var config = initEmpty();
        errdefer config.deinit(allocator);

        try replaceOwned(allocator, &config.tracking_uri, "sqlite://./fetch_ml.db");
        try replaceOwned(allocator, &config.artifact_path, "./experiments/");
        return config;
    }

    /// Stores one root-level `key = value` pair. Unknown keys are ignored;
    /// numeric keys propagate `std.fmt.parseInt` errors.
    fn setRootValue(
        self: *Config,
        allocator: std.mem.Allocator,
        key: []const u8,
        value: []const u8,
    ) !void {
        const root_key = std.meta.stringToEnum(RootKey, key) orelse return;
        switch (root_key) {
            .tracking_uri => try replaceOwned(allocator, &self.tracking_uri, value),
            .artifact_path => try replaceOwned(allocator, &self.artifact_path, value),
            .sync_uri => try replaceOwned(allocator, &self.sync_uri, value),
            .force_local => self.force_local = std.mem.eql(u8, value, "true"),
            .worker_host => try replaceOwned(allocator, &self.worker_host, value),
            .worker_user => try replaceOwned(allocator, &self.worker_user, value),
            .worker_base => try replaceOwned(allocator, &self.worker_base, value),
            .worker_port => self.worker_port = try std.fmt.parseInt(u16, value, 10),
            .api_key => try replaceOwned(allocator, &self.api_key, value),
            .default_cpu => self.default_cpu = try std.fmt.parseInt(u8, value, 10),
            .default_memory => self.default_memory = try std.fmt.parseInt(u8, value, 10),
            .default_gpu => self.default_gpu = try std.fmt.parseInt(u8, value, 10),
            .default_gpu_memory => if (value.len > 0) {
                try replaceOwnedOptional(allocator, &self.default_gpu_memory, value);
            },
            .default_dry_run => self.default_dry_run = std.mem.eql(u8, value, "true"),
            .default_validate => self.default_validate = std.mem.eql(u8, value, "true"),
            .default_json => self.default_json = std.mem.eql(u8, value, "true"),
            .default_priority => self.default_priority = try std.fmt.parseInt(u8, value, 10),
        }
    }

    /// Parses a config file (at most 1 MiB) into a new `Config`.
    ///
    /// This is a simple line parser, not a full TOML implementation: it
    /// understands `# comment` lines, `[section]` headers and `key = value`
    /// pairs with optional double quotes around the value. Each line is split
    /// at the first `=`, so values may themselves contain `=`. Keys inside
    /// `[experiment]` other than `name`/`entrypoint` are ignored; keys in any
    /// other section are treated as root-level keys. When a key repeats, the
    /// last occurrence wins. `experiment` is populated only when both `name`
    /// and `entrypoint` were given.
    ///
    /// Caller owns the result. Nothing is leaked on error.
    fn loadFromFile(allocator: std.mem.Allocator, file: std.fs.File) !Config {
        const content = try file.readToEndAlloc(allocator, 1024 * 1024);
        defer allocator.free(content);

        var config = initEmpty();
        errdefer config.deinit(allocator);

        // Held separately until the end: the section is kept only if complete.
        var experiment_name: ?[]const u8 = null;
        errdefer if (experiment_name) |name| allocator.free(name);
        var experiment_entrypoint: ?[]const u8 = null;
        errdefer if (experiment_entrypoint) |entrypoint| allocator.free(entrypoint);

        var current_section: []const u8 = "root";
        var lines = std.mem.splitScalar(u8, content, '\n');
        while (lines.next()) |line| {
            const trimmed = std.mem.trim(u8, line, " \t\r");
            if (trimmed.len == 0 or trimmed[0] == '#') continue;

            if (trimmed[0] == '[' and trimmed[trimmed.len - 1] == ']') {
                current_section = trimmed[1 .. trimmed.len - 1];
                continue;
            }

            const eq_index = std.mem.indexOfScalar(u8, trimmed, '=') orelse continue;
            const key = std.mem.trim(u8, trimmed[0..eq_index], " \t");
            const value = unquote(std.mem.trim(u8, trimmed[eq_index + 1 ..], " \t"));

            if (std.mem.eql(u8, current_section, "experiment")) {
                if (std.mem.eql(u8, key, "name")) {
                    try replaceOwnedOptional(allocator, &experiment_name, value);
                } else if (std.mem.eql(u8, key, "entrypoint")) {
                    try replaceOwnedOptional(allocator, &experiment_entrypoint, value);
                }
            } else {
                try config.setRootValue(allocator, key, value);
            }
        }

        // No fallible operation follows, so the errdefers above cannot run
        // after ownership has moved into `config` or been released here.
        if (experiment_name) |name| {
            if (experiment_entrypoint) |entrypoint| {
                config.experiment = .{ .name = name, .entrypoint = entrypoint };
            } else {
                allocator.free(name);
            }
        } else if (experiment_entrypoint) |entrypoint| {
            allocator.free(entrypoint);
        }

        return config;
    }

    /// Writes this config to `$HOME/.ml/config.toml`, creating `$HOME/.ml` if
    /// it does not exist.
    ///
    /// The text is fully rendered before the file is opened, so an allocation
    /// failure does not truncate an existing file. The `[experiment]` section
    /// is written last because `loadFromFile` treats every key after a section
    /// header as belonging to that section.
    ///
    /// Values are written verbatim between double quotes (no escaping).
    ///
    /// Errors: `error.NoHomeDir` when `HOME` is unset, plus allocation and
    /// file-system errors.
    pub fn save(self: Config, allocator: std.mem.Allocator) !void {
        const home = std.posix.getenv("HOME") orelse return error.NoHomeDir;

        const ml_dir = try std.fmt.allocPrint(allocator, "{s}/.ml", .{home});
        defer allocator.free(ml_dir);
        std.fs.makeDirAbsolute(ml_dir) catch |err| switch (err) {
            error.PathAlreadyExists => {},
            else => return err,
        };

        const config_path = try std.fmt.allocPrint(allocator, "{s}/config.toml", .{ml_dir});
        defer allocator.free(config_path);

        // Optional fragments. When absent they are empty literals; freeing a
        // zero-length slice is a no-op, so the defers are safe either way.
        const gpu_memory_line: []const u8 = if (self.default_gpu_memory) |gpu_mem|
            try std.fmt.allocPrint(allocator, "default_gpu_memory = \"{s}\"\n", .{gpu_mem})
        else
            "";
        defer allocator.free(gpu_memory_line);

        const experiment_section: []const u8 = if (self.experiment) |exp|
            try std.fmt.allocPrint(
                allocator,
                "\n[experiment]\nname = \"{s}\"\nentrypoint = \"{s}\"\n",
                .{ exp.name, exp.entrypoint },
            )
        else
            "";
        defer allocator.free(experiment_section);

        const content = try std.fmt.allocPrint(allocator,
            \\# FetchML Configuration
            \\tracking_uri = "{s}"
            \\artifact_path = "{s}"
            \\sync_uri = "{s}"
            \\force_local = {s}
            \\
            \\# Server config (for runner mode)
            \\worker_host = "{s}"
            \\worker_user = "{s}"
            \\worker_base = "{s}"
            \\worker_port = {d}
            \\api_key = "{s}"
            \\
            \\# Default resource requests
            \\default_cpu = {d}
            \\default_memory = {d}
            \\default_gpu = {d}
            \\{s}
            \\# CLI behavior defaults
            \\default_dry_run = {s}
            \\default_validate = {s}
            \\default_json = {s}
            \\default_priority = {d}
            \\{s}
        , .{
            self.tracking_uri,
            self.artifact_path,
            self.sync_uri,
            boolText(self.force_local),
            self.worker_host,
            self.worker_user,
            self.worker_base,
            self.worker_port,
            self.api_key,
            self.default_cpu,
            self.default_memory,
            self.default_gpu,
            gpu_memory_line,
            boolText(self.default_dry_run),
            boolText(self.default_validate),
            boolText(self.default_json),
            self.default_priority,
            experiment_section,
        });
        defer allocator.free(content);

        const file = try std.fs.createFileAbsolute(config_path, .{});
        defer file.close();
        try file.writeAll(content);
    }

    /// Frees every string owned by this config. `allocator` must be the one
    /// the strings were allocated with.
    pub fn deinit(self: *Config, allocator: std.mem.Allocator) void {
        allocator.free(self.tracking_uri);
        allocator.free(self.artifact_path);
        allocator.free(self.sync_uri);
        if (self.experiment) |exp| {
            allocator.free(exp.name);
            allocator.free(exp.entrypoint);
        }
        allocator.free(self.worker_host);
        allocator.free(self.worker_user);
        allocator.free(self.worker_base);
        allocator.free(self.api_key);
        if (self.default_gpu_memory) |gpu_mem| {
            allocator.free(gpu_mem);
        }
    }

    /// Swaps `src` into `dst` when `src` is non-empty, leaving the displaced
    /// value in `src`.
    fn takeIfSet(dst: *[]const u8, src: *[]const u8) void {
        if (src.*.len > 0) std.mem.swap([]const u8, dst, src);
    }

    /// Apply settings from a higher-priority layer.
    ///
    /// Ownership is transferred by swapping: each value that `other` sets is
    /// moved into `self`, and the value it displaces is moved into `other`.
    /// The caller must still `deinit` `other`, which then releases exactly the
    /// displaced values. Nothing is allocated, so this cannot fail.
    ///
    /// A string counts as set when non-empty, a flag when true, and a number
    /// when it differs from its built-in default. A layer therefore cannot
    /// reset a value back to the built-in default.
    fn apply(self: *Config, other: *Config) void {
        takeIfSet(&self.tracking_uri, &other.tracking_uri);
        takeIfSet(&self.artifact_path, &other.artifact_path);
        takeIfSet(&self.sync_uri, &other.sync_uri);
        if (other.force_local) self.force_local = true;

        // The higher-priority layer's experiment replaces any earlier one.
        if (other.experiment != null) {
            std.mem.swap(?ExperimentConfig, &self.experiment, &other.experiment);
        }

        takeIfSet(&self.worker_host, &other.worker_host);
        takeIfSet(&self.worker_user, &other.worker_user);
        takeIfSet(&self.worker_base, &other.worker_base);
        if (other.worker_port != defaults.worker_port) self.worker_port = other.worker_port;
        takeIfSet(&self.api_key, &other.api_key);

        if (other.default_cpu != defaults.cpu) self.default_cpu = other.default_cpu;
        if (other.default_memory != defaults.memory) self.default_memory = other.default_memory;
        if (other.default_gpu != defaults.gpu) self.default_gpu = other.default_gpu;
        if (other.default_gpu_memory != null) {
            std.mem.swap(?[]const u8, &self.default_gpu_memory, &other.default_gpu_memory);
        }

        if (other.default_dry_run) self.default_dry_run = true;
        if (other.default_validate) self.default_validate = true;
        if (other.default_json) self.default_json = true;
        if (other.default_priority != defaults.priority) self.default_priority = other.default_priority;
    }

    /// Apply environment variable overrides.
    ///
    /// String variables are copied into the config. An unparsable
    /// `FETCH_ML_CLI_PORT` is ignored and the current port is kept.
    fn applyEnv(self: *Config, allocator: std.mem.Allocator) std.mem.Allocator.Error!void {
        const string_overrides = [_]struct { name: []const u8, slot: *[]const u8 }{
            // FETCHML_* environment variables for URI-based config
            .{ .name = "FETCHML_TRACKING_URI", .slot = &self.tracking_uri },
            .{ .name = "FETCHML_ARTIFACT_PATH", .slot = &self.artifact_path },
            .{ .name = "FETCHML_SYNC_URI", .slot = &self.sync_uri },
            // Legacy FETCH_ML_CLI_* variables
            .{ .name = "FETCH_ML_CLI_HOST", .slot = &self.worker_host },
            .{ .name = "FETCH_ML_CLI_USER", .slot = &self.worker_user },
            .{ .name = "FETCH_ML_CLI_BASE", .slot = &self.worker_base },
            .{ .name = "FETCH_ML_CLI_API_KEY", .slot = &self.api_key },
        };
        for (string_overrides) |override| {
            if (std.posix.getenv(override.name)) |value| {
                try replaceOwned(allocator, override.slot, value);
            }
        }

        if (std.posix.getenv("FETCH_ML_CLI_PORT")) |port_text| {
            self.worker_port = std.fmt.parseInt(u16, port_text, 10) catch self.worker_port;
        }
    }

    /// Get WebSocket URL for connecting to the server.
    ///
    /// Uses `wss` when `worker_port` is 443 and `ws` otherwise. Caller owns
    /// the returned slice.
    pub fn getWebSocketUrl(self: Config, allocator: std.mem.Allocator) ![]u8 {
        const protocol = if (self.worker_port == 443) "wss" else "ws";
        return std.fmt.allocPrint(allocator, "{s}://{s}:{d}/ws", .{
            protocol,
            self.worker_host,
            self.worker_port,
        });
    }
};

/// Load global config from ~/.ml/config.toml
///
/// Returns null when `HOME` is unset or the file does not exist. Caller owns
/// the returned config.
fn loadGlobalConfig(allocator: std.mem.Allocator) !?Config {
    const home = std.posix.getenv("HOME") orelse return null;
    const config_path = try std.fmt.allocPrint(allocator, "{s}/.ml/config.toml", .{home});
    defer allocator.free(config_path);

    const file = std.fs.openFileAbsolute(config_path, .{}) catch |err| switch (err) {
        error.FileNotFound => return null,
        else => return err,
    };
    defer file.close();

    return try Config.loadFromFile(allocator, file);
}

/// Load project config from .fetchml/config.toml in CWD
///
/// The path is relative, so it is opened through `std.fs.cwd()`. Returns null
/// when the file does not exist. Caller owns the returned config.
fn loadProjectConfig(allocator: std.mem.Allocator) !?Config {
    const file = std.fs.cwd().openFile(".fetchml/config.toml", .{}) catch |err| switch (err) {
        error.FileNotFound => return null,
        else => return err,
    };
    defer file.close();

    return try Config.loadFromFile(allocator, file);
}